diff --git a/src/http/client.rs b/src/http/client.rs index af792d5..552d740 100644 --- a/src/http/client.rs +++ b/src/http/client.rs @@ -1,5 +1,5 @@ use super::{response::IncomingBody, Body, Error, Request, Response, Result}; -use crate::io::{self, AsyncWrite}; +use crate::io::{self, AsyncOutputStream}; use crate::runtime::Reactor; use crate::time::Duration; use wasi::http::types::{OutgoingBody, RequestOptions as WasiRequestOptions}; @@ -27,9 +27,7 @@ impl Client { let res = wasi::http::outgoing_handler::handle(wasi_req, self.wasi_options()?).unwrap(); // 2. Start sending the request body - io::copy(body, OutputStream::new(body_stream)) - .await - .expect("io::copy broke oh no"); + io::copy(body, AsyncOutputStream::new(body_stream)).await?; // 3. Finish sending the request body let trailers = None; @@ -74,33 +72,6 @@ impl Client { } } -struct OutputStream { - stream: wasi::http::types::OutputStream, -} - -impl OutputStream { - fn new(stream: wasi::http::types::OutputStream) -> Self { - Self { stream } - } -} - -impl AsyncWrite for OutputStream { - async fn write(&mut self, buf: &[u8]) -> io::Result { - let max = self.stream.check_write().unwrap() as usize; - let max = max.min(buf.len()); - let buf = &buf[0..max]; - self.stream.write(buf).unwrap(); - Reactor::current().wait_for(self.stream.subscribe()).await; - Ok(max) - } - - async fn flush(&mut self) -> io::Result<()> { - self.stream.flush().unwrap(); - Reactor::current().wait_for(self.stream.subscribe()).await; - Ok(()) - } -} - #[derive(Default, Debug)] struct RequestOptions { connect_timeout: Option, diff --git a/src/http/error.rs b/src/http/error.rs index fb785bd..a32cf1c 100644 --- a/src/http/error.rs +++ b/src/http/error.rs @@ -24,6 +24,7 @@ impl fmt::Debug for Error { ErrorVariant::HeaderName(e) => write!(f, "header name error: {e:?}"), ErrorVariant::HeaderValue(e) => write!(f, "header value error: {e:?}"), ErrorVariant::Method(e) => write!(f, "method error: {e:?}"), + ErrorVariant::BodyIo(e) => write!(f, "body error: {e:?}"), ErrorVariant::Other(e) => write!(f, "{e}"), } } @@ -37,6 +38,7 @@ impl fmt::Display for Error { ErrorVariant::HeaderName(e) => write!(f, "header name error: {e}"), ErrorVariant::HeaderValue(e) => write!(f, "header value error: {e}"), ErrorVariant::Method(e) => write!(f, "method error: {e}"), + ErrorVariant::BodyIo(e) => write!(f, "body error: {e}"), ErrorVariant::Other(e) => write!(f, "{e}"), } } @@ -100,6 +102,12 @@ impl From for Error { } } +impl From for Error { + fn from(e: std::io::Error) -> Error { + ErrorVariant::BodyIo(e).into() + } +} + #[derive(Debug)] pub enum ErrorVariant { WasiHttp(WasiHttpErrorCode), @@ -107,5 +115,6 @@ pub enum ErrorVariant { HeaderName(InvalidHeaderName), HeaderValue(InvalidHeaderValue), Method(InvalidMethod), + BodyIo(std::io::Error), Other(String), } diff --git a/src/http/response.rs b/src/http/response.rs index 0dcfe30..463cc03 100644 --- a/src/http/response.rs +++ b/src/http/response.rs @@ -1,12 +1,7 @@ use wasi::http::types::{IncomingBody as WasiIncomingBody, IncomingResponse}; -use wasi::io::streams::{InputStream, StreamError}; use super::{fields::header_map_from_wasi, Body, Error, HeaderMap, Result, StatusCode}; -use crate::io::AsyncRead; -use crate::runtime::Reactor; - -/// Stream 2kb chunks at a time -const CHUNK_SIZE: u64 = 2048; +use crate::io::{AsyncInputStream, AsyncRead}; /// An HTTP response #[derive(Debug)] @@ -57,9 +52,7 @@ impl Response { let body = IncomingBody { kind, - buf_offset: 0, - buf: None, - body_stream, + body_stream: AsyncInputStream::new(body_stream), _incoming_body: incoming_body, }; @@ -96,54 +89,15 @@ impl Response { #[derive(Debug)] pub struct IncomingBody { kind: BodyKind, - buf: Option>, - // How many bytes have we already read from the buf? - buf_offset: usize, - // IMPORTANT: the order of these fields here matters. `body_stream` must // be dropped before `_incoming_body`. - body_stream: InputStream, + body_stream: AsyncInputStream, _incoming_body: WasiIncomingBody, } impl AsyncRead for IncomingBody { async fn read(&mut self, out_buf: &mut [u8]) -> crate::io::Result { - let buf = match &mut self.buf { - Some(ref mut buf) => buf, - None => { - // Wait for an event to be ready - let pollable = self.body_stream.subscribe(); - Reactor::current().wait_for(pollable).await; - - // Read the bytes from the body stream - let buf = match self.body_stream.read(CHUNK_SIZE) { - Ok(buf) => buf, - Err(StreamError::Closed) => return Ok(0), - Err(StreamError::LastOperationFailed(err)) => { - return Err(std::io::Error::other(format!( - "last operation failed: {}", - err.to_debug_string() - ))) - } - }; - self.buf.insert(buf) - } - }; - - // copy bytes - let len = (buf.len() - self.buf_offset).min(out_buf.len()); - let max = self.buf_offset + len; - let slice = &buf[self.buf_offset..max]; - out_buf[0..len].copy_from_slice(slice); - self.buf_offset += len; - - // reset the local slice if necessary - if self.buf_offset == buf.len() { - self.buf = None; - self.buf_offset = 0; - } - - Ok(len) + self.body_stream.read(out_buf).await } } diff --git a/src/io/mod.rs b/src/io/mod.rs index cdb82f2..0f34b1b 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -5,13 +5,18 @@ mod cursor; mod empty; mod read; mod seek; +mod stdio; +mod streams; mod write; +pub use crate::runtime::AsyncPollable; pub use copy::*; pub use cursor::*; pub use empty::*; pub use read::*; pub use seek::*; +pub use stdio::*; +pub use streams::*; pub use write::*; /// The error type for I/O operations. diff --git a/src/io/stdio.rs b/src/io/stdio.rs new file mode 100644 index 0000000..b36b0c4 --- /dev/null +++ b/src/io/stdio.rs @@ -0,0 +1,139 @@ +use super::{AsyncInputStream, AsyncOutputStream}; +use std::cell::LazyCell; +use wasi::cli::terminal_input::TerminalInput; +use wasi::cli::terminal_output::TerminalOutput; + +/// Use the program's stdin as an `AsyncInputStream`. +#[derive(Debug)] +pub struct Stdin { + stream: AsyncInputStream, + terminput: LazyCell>, +} + +/// Get the program's stdin for use as an `AsyncInputStream`. +pub fn stdin() -> Stdin { + let stream = AsyncInputStream::new(wasi::cli::stdin::get_stdin()); + Stdin { + stream, + terminput: LazyCell::new(|| wasi::cli::terminal_stdin::get_terminal_stdin()), + } +} + +impl std::ops::Deref for Stdin { + type Target = AsyncInputStream; + fn deref(&self) -> &AsyncInputStream { + &self.stream + } +} +impl std::ops::DerefMut for Stdin { + fn deref_mut(&mut self) -> &mut AsyncInputStream { + &mut self.stream + } +} + +impl Stdin { + /// Check if stdin is a terminal. + pub fn is_terminal(&self) -> bool { + LazyCell::force(&self.terminput).is_some() + } +} + +/// Use the program's stdout as an `AsyncOutputStream`. +#[derive(Debug)] +pub struct Stdout { + stream: AsyncOutputStream, + termoutput: LazyCell>, +} + +/// Get the program's stdout for use as an `AsyncOutputStream`. +pub fn stdout() -> Stdout { + let stream = AsyncOutputStream::new(wasi::cli::stdout::get_stdout()); + Stdout { + stream, + termoutput: LazyCell::new(|| wasi::cli::terminal_stdout::get_terminal_stdout()), + } +} + +impl Stdout { + /// Check if stdout is a terminal. + pub fn is_terminal(&self) -> bool { + LazyCell::force(&self.termoutput).is_some() + } +} + +impl std::ops::Deref for Stdout { + type Target = AsyncOutputStream; + fn deref(&self) -> &AsyncOutputStream { + &self.stream + } +} +impl std::ops::DerefMut for Stdout { + fn deref_mut(&mut self) -> &mut AsyncOutputStream { + &mut self.stream + } +} + +/// Use the program's stdout as an `AsyncOutputStream`. +#[derive(Debug)] +pub struct Stderr { + stream: AsyncOutputStream, + termoutput: LazyCell>, +} + +/// Get the program's stdout for use as an `AsyncOutputStream`. +pub fn stderr() -> Stderr { + let stream = AsyncOutputStream::new(wasi::cli::stderr::get_stderr()); + Stderr { + stream, + termoutput: LazyCell::new(|| wasi::cli::terminal_stderr::get_terminal_stderr()), + } +} + +impl Stderr { + /// Check if stderr is a terminal. + pub fn is_terminal(&self) -> bool { + LazyCell::force(&self.termoutput).is_some() + } +} + +impl std::ops::Deref for Stderr { + type Target = AsyncOutputStream; + fn deref(&self) -> &AsyncOutputStream { + &self.stream + } +} +impl std::ops::DerefMut for Stderr { + fn deref_mut(&mut self) -> &mut AsyncOutputStream { + &mut self.stream + } +} + +#[cfg(test)] +mod test { + use crate::io::AsyncWrite; + use crate::runtime::block_on; + #[test] + // No internal predicate. Run test with --nocapture and inspect output manually. + fn stdout_println_hello_world() { + block_on(async { + let mut stdout = super::stdout(); + let term = if stdout.is_terminal() { "is" } else { "is not" }; + stdout + .write_all(format!("hello, world! stdout {term} a terminal\n",).as_bytes()) + .await + .unwrap(); + }) + } + #[test] + // No internal predicate. Run test with --nocapture and inspect output manually. + fn stderr_println_hello_world() { + block_on(async { + let mut stdout = super::stdout(); + let term = if stdout.is_terminal() { "is" } else { "is not" }; + stdout + .write_all(format!("hello, world! stderr {term} a terminal\n",).as_bytes()) + .await + .unwrap(); + }) + } +} diff --git a/src/io/streams.rs b/src/io/streams.rs new file mode 100644 index 0000000..de45d88 --- /dev/null +++ b/src/io/streams.rs @@ -0,0 +1,125 @@ +use super::{AsyncPollable, AsyncRead, AsyncWrite}; +use std::cell::RefCell; +use std::io::Result; +use wasi::io::streams::{InputStream, OutputStream, StreamError}; + +#[derive(Debug)] +pub struct AsyncInputStream { + // Lazily initialized pollable, used for lifetime of stream to check readiness. + // Field ordering matters: this child must be dropped before stream + subscription: RefCell>, + stream: InputStream, +} + +impl AsyncInputStream { + pub fn new(stream: InputStream) -> Self { + Self { + subscription: RefCell::new(None), + stream, + } + } + async fn ready(&self) { + // Lazily initialize the AsyncPollable + if self.subscription.borrow().is_none() { + self.subscription + .replace(Some(AsyncPollable::new(self.stream.subscribe()))); + } + // Wait on readiness + self.subscription + .borrow() + .as_ref() + .expect("populated refcell") + .wait_for() + .await; + } +} + +impl AsyncRead for AsyncInputStream { + async fn read(&mut self, buf: &mut [u8]) -> Result { + self.ready().await; + // Ideally, the ABI would be able to read directly into buf. However, with the default + // generated bindings, it returns a newly allocated vec, which we need to copy into buf. + let read = match self.stream.read(buf.len() as u64) { + Ok(r) => r, + Err(StreamError::Closed) => return Ok(0), + Err(StreamError::LastOperationFailed(err)) => { + return Err(std::io::Error::other(err.to_debug_string())) + } + }; + let len = read.len(); + buf[0..len].copy_from_slice(&read); + Ok(len) + } +} + +#[derive(Debug)] +pub struct AsyncOutputStream { + // Lazily initialized pollable, used for lifetime of stream to check readiness. + // Field ordering matters: this child must be dropped before stream + subscription: RefCell>, + stream: OutputStream, +} + +impl AsyncOutputStream { + pub fn new(stream: OutputStream) -> Self { + Self { + subscription: RefCell::new(None), + stream, + } + } + async fn ready(&self) { + // Lazily initialize the AsyncPollable + if self.subscription.borrow().is_none() { + self.subscription + .replace(Some(AsyncPollable::new(self.stream.subscribe()))); + } + // Wait on readiness + self.subscription + .borrow() + .as_ref() + .expect("populated refcell") + .wait_for() + .await; + } +} +impl AsyncWrite for AsyncOutputStream { + // Required methods + async fn write(&mut self, buf: &[u8]) -> Result { + // Loops at most twice. + loop { + match self.stream.check_write() { + Ok(0) => { + self.ready().await; + // Next loop guaranteed to have nonzero check_write, or error. + continue; + } + Ok(some) => { + let writable = some.try_into().unwrap_or(usize::MAX).min(buf.len()); + match self.stream.write(&buf[0..writable]) { + Ok(()) => return Ok(writable), + Err(StreamError::Closed) => return Ok(0), + Err(StreamError::LastOperationFailed(err)) => { + return Err(std::io::Error::other(err.to_debug_string())) + } + } + } + Err(StreamError::Closed) => return Ok(0), + Err(StreamError::LastOperationFailed(err)) => { + return Err(std::io::Error::other(err.to_debug_string())) + } + } + } + } + async fn flush(&mut self) -> Result<()> { + match self.stream.flush() { + Ok(()) => { + self.ready().await; + Ok(()) + } + Err(StreamError::Closed) => Ok(()), + Err(StreamError::LastOperationFailed(err)) => { + Err(std::io::Error::other(err.to_debug_string())) + } + } + } +} diff --git a/src/net/tcp_listener.rs b/src/net/tcp_listener.rs index 94352aa..7aedc71 100644 --- a/src/net/tcp_listener.rs +++ b/src/net/tcp_listener.rs @@ -3,15 +3,17 @@ use wasi::sockets::tcp::{ErrorCode, IpAddressFamily, IpSocketAddress, TcpSocket} use crate::io; use crate::iter::AsyncIterator; -use crate::runtime::Reactor; use std::io::ErrorKind; use std::net::SocketAddr; use super::TcpStream; +use crate::runtime::AsyncPollable; /// A TCP socket server, listening for connections. #[derive(Debug)] pub struct TcpListener { + // Field order matters: must drop this child before parent below + pollable: AsyncPollable, socket: TcpSocket, } @@ -40,18 +42,17 @@ impl TcpListener { } SocketAddr::V6(_) => todo!("IPv6 not yet supported in `wstd::net::TcpListener`"), }; - let reactor = Reactor::current(); - socket .start_bind(&network, local_address) .map_err(to_io_err)?; - reactor.wait_for(socket.subscribe()).await; + let pollable = AsyncPollable::new(socket.subscribe()); + pollable.wait_for().await; socket.finish_bind().map_err(to_io_err)?; socket.start_listen().map_err(to_io_err)?; - reactor.wait_for(socket.subscribe()).await; + pollable.wait_for().await; socket.finish_listen().map_err(to_io_err)?; - Ok(Self { socket }) + Ok(Self { pollable, socket }) } /// Returns the local socket address of this listener. @@ -77,18 +78,12 @@ impl<'a> AsyncIterator for Incoming<'a> { type Item = io::Result; async fn next(&mut self) -> Option { - Reactor::current() - .wait_for(self.listener.socket.subscribe()) - .await; + self.listener.pollable.wait_for().await; let (socket, input, output) = match self.listener.socket.accept().map_err(to_io_err) { Ok(accepted) => accepted, Err(err) => return Some(Err(err)), }; - Some(Ok(TcpStream { - socket, - input, - output, - })) + Some(Ok(TcpStream::new(input, output, socket))) } } diff --git a/src/net/tcp_stream.rs b/src/net/tcp_stream.rs index 0e66550..c1f663b 100644 --- a/src/net/tcp_stream.rs +++ b/src/net/tcp_stream.rs @@ -1,23 +1,27 @@ -use std::io::Error; +use std::cell::RefCell; use wasi::{ - io::streams::StreamError, - sockets::tcp::{InputStream, OutputStream, TcpSocket}, + io::streams::{InputStream, OutputStream}, + sockets::tcp::TcpSocket, }; -use crate::{ - io::{self, AsyncRead, AsyncWrite}, - runtime::Reactor, -}; +use crate::io::{self, AsyncInputStream, AsyncOutputStream, AsyncRead, AsyncWrite}; /// A TCP stream between a local and a remote socket. pub struct TcpStream { - pub(super) input: InputStream, - pub(super) output: OutputStream, - pub(super) socket: TcpSocket, + input: RefCell, + output: RefCell, + socket: TcpSocket, } impl TcpStream { + pub(crate) fn new(input: InputStream, output: OutputStream, socket: TcpSocket) -> Self { + TcpStream { + input: RefCell::new(AsyncInputStream::new(input)), + output: RefCell::new(AsyncOutputStream::new(output)), + socket, + } + } /// Returns the socket address of the remote peer of this TCP connection. pub fn peer_addr(&self) -> io::Result { let addr = self @@ -40,53 +44,33 @@ impl Drop for TcpStream { impl AsyncRead for TcpStream { async fn read(&mut self, buf: &mut [u8]) -> io::Result { - Reactor::current().wait_for(self.input.subscribe()).await; - let slice = match self.input.read(buf.len() as u64) { - Ok(slice) => slice, - Err(StreamError::Closed) => return Ok(0), - Err(e) => return Err(to_io_err(e)), - }; - let bytes_read = slice.len(); - buf[..bytes_read].clone_from_slice(&slice); - Ok(bytes_read) + self.input.borrow_mut().read(buf).await } } impl AsyncRead for &TcpStream { async fn read(&mut self, buf: &mut [u8]) -> io::Result { - Reactor::current().wait_for(self.input.subscribe()).await; - let slice = match self.input.read(buf.len() as u64) { - Ok(slice) => slice, - Err(StreamError::Closed) => return Ok(0), - Err(e) => return Err(to_io_err(e)), - }; - let bytes_read = slice.len(); - buf[..bytes_read].clone_from_slice(&slice); - Ok(bytes_read) + self.input.borrow_mut().read(buf).await } } impl AsyncWrite for TcpStream { async fn write(&mut self, buf: &[u8]) -> io::Result { - Reactor::current().wait_for(self.output.subscribe()).await; - self.output.write(buf).map_err(to_io_err)?; - Ok(buf.len()) + self.output.borrow_mut().write(buf).await } async fn flush(&mut self) -> io::Result<()> { - self.output.flush().map_err(to_io_err) + self.output.borrow_mut().flush().await } } impl AsyncWrite for &TcpStream { async fn write(&mut self, buf: &[u8]) -> io::Result { - Reactor::current().wait_for(self.output.subscribe()).await; - self.output.write(buf).map_err(to_io_err)?; - Ok(buf.len()) + self.output.borrow_mut().write(buf).await } async fn flush(&mut self) -> io::Result<()> { - self.output.flush().map_err(to_io_err) + self.output.borrow_mut().flush().await } } @@ -125,10 +109,3 @@ impl<'a> Drop for WriteHalf<'a> { .shutdown(wasi::sockets::tcp::ShutdownType::Send); } } - -fn to_io_err(err: StreamError) -> std::io::Error { - match err { - StreamError::LastOperationFailed(err) => Error::other(err.to_debug_string()), - StreamError::Closed => Error::other("Stream was closed"), - } -} diff --git a/src/runtime/reactor.rs b/src/runtime/reactor.rs index 831d598..afd001d 100644 --- a/src/runtime/reactor.rs +++ b/src/runtime/reactor.rs @@ -33,6 +33,11 @@ impl Drop for Registration { pub struct AsyncPollable(Rc); impl AsyncPollable { + /// Create an `AsyncPollable` from a Wasi `Pollable`. Schedules the `Pollable` with the current + /// `Reactor`. + pub fn new(pollable: Pollable) -> Self { + Reactor::current().schedule(pollable) + } /// Create a Future that waits for the Pollable's readiness. pub fn wait_for(&self) -> WaitFor { use std::sync::atomic::{AtomicUsize, Ordering};