diff --git a/Cargo.toml b/Cargo.toml index fa68f45..c237caf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ json = ["dep:serde", "dep:serde_json"] [dependencies] async-task.workspace = true +futures-core.workspace = true http.workspace = true itoa.workspace = true pin-project-lite.workspace = true diff --git a/src/io/streams.rs b/src/io/streams.rs index d8b35ec..7bb3939 100644 --- a/src/io/streams.rs +++ b/src/io/streams.rs @@ -1,15 +1,19 @@ use super::{AsyncPollable, AsyncRead, AsyncWrite}; -use std::cell::OnceCell; -use std::io::Result; +use crate::runtime::WaitFor; +use std::future::{poll_fn, Future}; +use std::pin::Pin; +use std::sync::{Mutex, OnceLock}; +use std::task::{Context, Poll}; use wasip2::io::streams::{InputStream, OutputStream, StreamError}; /// A wrapper for WASI's `InputStream` resource that provides implementations of `AsyncRead` and /// `AsyncPollable`. #[derive(Debug)] pub struct AsyncInputStream { + wait_for: Mutex>>>, // Lazily initialized pollable, used for lifetime of stream to check readiness. // Field ordering matters: this child must be dropped before stream - subscription: OnceCell, + subscription: OnceLock, stream: InputStream, } @@ -17,22 +21,34 @@ impl AsyncInputStream { /// Construct an `AsyncInputStream` from a WASI `InputStream` resource. pub fn new(stream: InputStream) -> Self { Self { - subscription: OnceCell::new(), + wait_for: Mutex::new(None), + subscription: OnceLock::new(), stream, } } - /// Await for read readiness. - async fn ready(&self) { + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> { // Lazily initialize the AsyncPollable let subscription = self .subscription .get_or_init(|| AsyncPollable::new(self.stream.subscribe())); - // Wait on readiness - subscription.wait_for().await; + // Lazily initialize the WaitFor. Clear it after it becomes ready. + let mut wait_for_slot = self.wait_for.lock().unwrap(); + let wait_for = wait_for_slot.get_or_insert_with(|| Box::pin(subscription.wait_for())); + match wait_for.as_mut().poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(()) => { + let _ = wait_for_slot.take(); + Poll::Ready(()) + } + } + } + /// Await for read readiness. + async fn ready(&self) { + poll_fn(|cx| self.poll_ready(cx)).await } /// Asynchronously read from the input stream. /// This method is the same as [`AsyncRead::read`], but doesn't require a `&mut self`. - pub async fn read(&self, buf: &mut [u8]) -> Result { + pub async fn read(&self, buf: &mut [u8]) -> std::io::Result { let read = loop { self.ready().await; // Ideally, the ABI would be able to read directly into buf. @@ -56,10 +72,40 @@ impl AsyncInputStream { buf[0..len].copy_from_slice(&read); Ok(len) } + + /// Use this `AsyncInputStream` as a `futures_core::stream::Stream` with + /// items of `Result, std::io::Error>`. The returned byte vectors + /// will be at most 8k. If you want to control chunk size, use + /// `Self::into_stream_of`. + pub fn into_stream(self) -> AsyncInputChunkStream { + AsyncInputChunkStream { + stream: self, + chunk_size: 8 * 1024, + } + } + + /// Use this `AsyncInputStream` as a `futures_core::stream::Stream` with + /// items of `Result, std::io::Error>`. The returned byte vectors + /// will be at most the `chunk_size` argument specified. + pub fn into_stream_of(self, chunk_size: usize) -> AsyncInputChunkStream { + AsyncInputChunkStream { + stream: self, + chunk_size, + } + } + + /// Use this `AsyncInputStream` as a `futures_core::stream::Stream` with + /// items of `Result`. + pub fn into_bytestream(self) -> AsyncInputByteStream { + AsyncInputByteStream { + stream: self.into_stream(), + buffer: std::io::Read::bytes(std::io::Cursor::new(Vec::new())), + } + } } impl AsyncRead for AsyncInputStream { - async fn read(&mut self, buf: &mut [u8]) -> Result { + async fn read(&mut self, buf: &mut [u8]) -> std::io::Result { Self::read(self, buf).await } @@ -69,13 +115,94 @@ impl AsyncRead for AsyncInputStream { } } +/// Wrapper of `AsyncInputStream` that impls `futures_core::stream::Stream` +/// with an item of `Result, std::io::Error>` +pub struct AsyncInputChunkStream { + stream: AsyncInputStream, + chunk_size: usize, +} + +impl AsyncInputChunkStream { + /// Extract the `AsyncInputStream` which backs this stream. + pub fn into_inner(self) -> AsyncInputStream { + self.stream + } +} + +impl futures_core::stream::Stream for AsyncInputChunkStream { + type Item = Result, std::io::Error>; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.stream.poll_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(()) => match self.stream.stream.read(self.chunk_size as u64) { + Ok(r) if r.is_empty() => Poll::Pending, + Ok(r) => Poll::Ready(Some(Ok(r))), + Err(StreamError::LastOperationFailed(err)) => { + Poll::Ready(Some(Err(std::io::Error::other(err.to_debug_string())))) + } + Err(StreamError::Closed) => Poll::Ready(None), + }, + } + } +} + +pin_project_lite::pin_project! { + /// Wrapper of `AsyncInputStream` that impls + /// `futures_core::stream::Stream` with item `Result`. + pub struct AsyncInputByteStream { + #[pin] + stream: AsyncInputChunkStream, + buffer: std::io::Bytes>>, + } +} + +impl AsyncInputByteStream { + /// Extract the `AsyncInputStream` which backs this stream, and any bytes + /// read from the `AsyncInputStream` which have not yet been yielded by + /// the byte stream. + pub fn into_inner(self) -> (AsyncInputStream, Vec) { + ( + self.stream.into_inner(), + self.buffer + .collect::, std::io::Error>>() + .expect("read of Cursor> is infallible"), + ) + } +} + +impl futures_core::stream::Stream for AsyncInputByteStream { + type Item = Result; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + match this.buffer.next() { + Some(byte) => Poll::Ready(Some(Ok(byte.expect("cursor on Vec is infallible")))), + None => match futures_core::stream::Stream::poll_next(this.stream, cx) { + Poll::Ready(Some(Ok(bytes))) => { + let mut bytes = std::io::Read::bytes(std::io::Cursor::new(bytes)); + match bytes.next() { + Some(Ok(byte)) => { + *this.buffer = bytes; + Poll::Ready(Some(Ok(byte))) + } + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } + } + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + }, + } + } +} + /// A wrapper for WASI's `output-stream` resource that provides implementations of `AsyncWrite` and /// `AsyncPollable`. #[derive(Debug)] pub struct AsyncOutputStream { // Lazily initialized pollable, used for lifetime of stream to check readiness. // Field ordering matters: this child must be dropped before stream - subscription: OnceCell, + subscription: OnceLock, stream: OutputStream, } @@ -83,7 +210,7 @@ impl AsyncOutputStream { /// Construct an `AsyncOutputStream` from a WASI `OutputStream` resource. pub fn new(stream: OutputStream) -> Self { Self { - subscription: OnceCell::new(), + subscription: OnceLock::new(), stream, } } @@ -104,7 +231,7 @@ impl AsyncOutputStream { /// a `std::io::Error` indicating either an error returned by the stream write /// using the debug string provided by the WASI error, or else that the, /// indicated by `std::io::ErrorKind::ConnectionReset`. - pub async fn write(&self, buf: &[u8]) -> Result { + pub async fn write(&self, buf: &[u8]) -> std::io::Result { // Loops at most twice. loop { match self.stream.check_write() { @@ -145,7 +272,7 @@ impl AsyncOutputStream { /// the stream flush, using the debug string provided by the WASI error, /// or else that the stream is closed, indicated by /// `std::io::ErrorKind::ConnectionReset`. - pub async fn flush(&self) -> Result<()> { + pub async fn flush(&self) -> std::io::Result<()> { match self.stream.flush() { Ok(()) => { self.ready().await; @@ -162,10 +289,10 @@ impl AsyncOutputStream { } impl AsyncWrite for AsyncOutputStream { // Required methods - async fn write(&mut self, buf: &[u8]) -> Result { + async fn write(&mut self, buf: &[u8]) -> std::io::Result { Self::write(self, buf).await } - async fn flush(&mut self) -> Result<()> { + async fn flush(&mut self) -> std::io::Result<()> { Self::flush(self).await } @@ -180,11 +307,10 @@ pub(crate) async fn splice( reader: &AsyncInputStream, writer: &AsyncOutputStream, len: u64, -) -> core::result::Result { +) -> Result { // Wait for both streams to be ready. - let r = reader.ready(); + reader.ready().await; writer.ready().await; - r.await; writer.stream.splice(&reader.stream, len) }