diff --git a/crates/wasi-http/src/types_impl.rs b/crates/wasi-http/src/types_impl.rs index b09cd20ed38f..2b6b5c15b04e 100644 --- a/crates/wasi-http/src/types_impl.rs +++ b/crates/wasi-http/src/types_impl.rs @@ -705,7 +705,7 @@ where let body = self.table().get_mut(&id)?; if let Some(stream) = body.take_stream() { - let stream = InputStream::Host(Box::new(stream)); + let stream: InputStream = Box::new(stream); let stream = self.table().push_child(stream, &id)?; return Ok(Ok(stream)); } diff --git a/crates/wasi/src/bindings.rs b/crates/wasi/src/bindings.rs index edbdda385b93..3bc96b5c6546 100644 --- a/crates/wasi/src/bindings.rs +++ b/crates/wasi/src/bindings.rs @@ -367,17 +367,15 @@ mod async_io { "[method]descriptor.unlink-file-at", "[method]descriptor.unlock", "[method]descriptor.write", - "[method]input-stream.read", "[method]input-stream.blocking-read", "[method]input-stream.blocking-skip", - "[method]input-stream.skip", - "[method]output-stream.forward", - "[method]output-stream.splice", + "[drop]input-stream", "[method]output-stream.blocking-splice", "[method]output-stream.blocking-flush", "[method]output-stream.blocking-write", "[method]output-stream.blocking-write-and-flush", "[method]output-stream.blocking-write-zeroes-and-flush", + "[drop]output-stream", "[method]directory-entry-stream.read-directory-entry", "poll", "[method]pollable.block", diff --git a/crates/wasi/src/filesystem.rs b/crates/wasi/src/filesystem.rs index f95bcc035450..a924fa96a70c 100644 --- a/crates/wasi/src/filesystem.rs +++ b/crates/wasi/src/filesystem.rs @@ -1,6 +1,8 @@ use crate::bindings::filesystem::types; use crate::runtime::{spawn_blocking, AbortOnDropJoinHandle}; -use crate::{HostOutputStream, StreamError, Subscribe, TrappableError}; +use crate::{ + HostInputStream, HostOutputStream, StreamError, StreamResult, Subscribe, TrappableError, +}; use anyhow::anyhow; use bytes::{Bytes, BytesMut}; use std::io; @@ -112,19 +114,40 @@ impl File { } } - /// Spawn a task on tokio's blocking thread for performing blocking - /// syscalls on the underlying [`cap_std::fs::File`]. - pub(crate) async fn spawn_blocking(&self, body: F) -> R + /// Execute the blocking `body` function. + /// + /// Depending on how the WasiCtx was configured, the body may either be: + /// - Executed directly on the current thread. In this case the `async` + /// signature of this method is effectively a lie and the returned + /// Future will always be immediately Ready. Or: + /// - Spawned on a background thread using [`tokio::task::spawn_blocking`] + /// and immediately awaited. + /// + /// Intentionally blocking the executor thread might seem unorthodox, but is + /// not actually a problem for specific workloads. See: + /// - [`crate::WasiCtxBuilder::allow_blocking_current_thread`] + /// - [Poor performance of wasmtime file I/O maybe because tokio](https://github.com/bytecodealliance/wasmtime/issues/7973) + /// - [Implement opt-in for enabling WASI to block the current thread](https://github.com/bytecodealliance/wasmtime/pull/8190) + pub(crate) async fn run_blocking(&self, body: F) -> R where F: FnOnce(&cap_std::fs::File) -> R + Send + 'static, R: Send + 'static, { - match self._spawn_blocking(body) { - SpawnBlocking::Done(result) => result, - SpawnBlocking::Spawned(task) => task.await, + match self.as_blocking_file() { + Some(file) => body(file), + None => self.spawn_blocking(body).await, } } + pub(crate) fn spawn_blocking(&self, body: F) -> AbortOnDropJoinHandle + where + F: FnOnce(&cap_std::fs::File) -> R + Send + 'static, + R: Send + 'static, + { + let f = self.file.clone(); + spawn_blocking(move || body(&f)) + } + /// Returns `Some` when the current thread is allowed to block in filesystem /// operations, and otherwise returns `None` to indicate that /// `spawn_blocking` must be used. @@ -135,25 +158,6 @@ impl File { None } } - - fn _spawn_blocking(&self, body: F) -> SpawnBlocking - where - F: FnOnce(&cap_std::fs::File) -> R + Send + 'static, - R: Send + 'static, - { - match self.as_blocking_file() { - Some(file) => SpawnBlocking::Done(body(file)), - None => { - let f = self.file.clone(); - SpawnBlocking::Spawned(spawn_blocking(move || body(&f))) - } - } - } -} - -enum SpawnBlocking { - Done(T), - Spawned(AbortOnDropJoinHandle), } bitflags::bitflags! { @@ -217,9 +221,21 @@ impl Dir { } } - /// Spawn a task on tokio's blocking thread for performing blocking - /// syscalls on the underlying [`cap_std::fs::Dir`]. - pub(crate) async fn spawn_blocking(&self, body: F) -> R + /// Execute the blocking `body` function. + /// + /// Depending on how the WasiCtx was configured, the body may either be: + /// - Executed directly on the current thread. In this case the `async` + /// signature of this method is effectively a lie and the returned + /// Future will always be immediately Ready. Or: + /// - Spawned on a background thread using [`tokio::task::spawn_blocking`] + /// and immediately awaited. + /// + /// Intentionally blocking the executor thread might seem unorthodox, but is + /// not actually a problem for specific workloads. See: + /// - [`crate::WasiCtxBuilder::allow_blocking_current_thread`] + /// - [Poor performance of wasmtime file I/O maybe because tokio](https://github.com/bytecodealliance/wasmtime/issues/7973) + /// - [Implement opt-in for enabling WASI to block the current thread](https://github.com/bytecodealliance/wasmtime/pull/8190) + pub(crate) async fn run_blocking(&self, body: F) -> R where F: FnOnce(&cap_std::fs::Dir) -> R + Send + 'static, R: Send + 'static, @@ -236,45 +252,137 @@ impl Dir { pub struct FileInputStream { file: File, position: u64, + state: ReadState, +} +enum ReadState { + Idle, + Waiting(AbortOnDropJoinHandle), + DataAvailable(Bytes), + Error(io::Error), + Closed, } impl FileInputStream { pub fn new(file: &File, position: u64) -> Self { Self { file: file.clone(), position, + state: ReadState::Idle, } } - pub async fn read(&mut self, size: usize) -> Result { + fn blocking_read(file: &cap_std::fs::File, offset: u64, size: usize) -> ReadState { use system_interface::fs::FileIoExt; - let p = self.position; - let (r, mut buf) = self - .file - .spawn_blocking(move |f| { - let mut buf = BytesMut::zeroed(size); - let r = f.read_at(&mut buf, p); - (r, buf) - }) - .await; - let n = read_result(r, size)?; - buf.truncate(n); - self.position += n as u64; - Ok(buf.freeze()) - } - - pub async fn skip(&mut self, nelem: usize) -> Result { - let bs = self.read(nelem).await?; - Ok(bs.len()) + let mut buf = BytesMut::zeroed(size); + loop { + match file.read_at(&mut buf, offset) { + Ok(0) => return ReadState::Closed, + Ok(n) => { + buf.truncate(n); + return ReadState::DataAvailable(buf.freeze()); + } + Err(e) if e.kind() == std::io::ErrorKind::Interrupted => { + // Try again, continue looping + } + Err(e) => return ReadState::Error(e), + } + } + } + + /// Wait for existing background task to finish, without starting any new background reads. + async fn wait_ready(&mut self) { + match &mut self.state { + ReadState::Waiting(task) => { + self.state = task.await; + } + _ => {} + } } } +#[async_trait::async_trait] +impl HostInputStream for FileInputStream { + fn read(&mut self, size: usize) -> StreamResult { + match &mut self.state { + ReadState::Idle => { + if size == 0 { + return Ok(Bytes::new()); + } -fn read_result(r: io::Result, size: usize) -> Result { - match r { - Ok(0) if size > 0 => Err(StreamError::Closed), - Ok(n) => Ok(n), - Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok(0), - Err(e) => Err(StreamError::LastOperationFailed(e.into())), + let p = self.position; + self.state = ReadState::Waiting( + self.file + .spawn_blocking(move |f| Self::blocking_read(f, p, size)), + ); + Ok(Bytes::new()) + } + ReadState::DataAvailable(b) => { + let min_len = b.len().min(size); + let chunk = b.split_to(min_len); + if b.len() == 0 { + self.state = ReadState::Idle; + } + self.position += min_len as u64; + Ok(chunk) + } + ReadState::Waiting(_) => Ok(Bytes::new()), + ReadState::Error(_) => match mem::replace(&mut self.state, ReadState::Closed) { + ReadState::Error(e) => Err(StreamError::LastOperationFailed(e.into())), + _ => unreachable!(), + }, + ReadState::Closed => Err(StreamError::Closed), + } + } + /// Specialized blocking_* variant to bypass tokio's task spawning & joining + /// overhead on synchronous file I/O. + async fn blocking_read(&mut self, size: usize) -> StreamResult { + self.wait_ready().await; + + // Before we defer to the regular `read`, make sure it has data ready to go: + if let ReadState::Idle = self.state { + let p = self.position; + self.state = self + .file + .run_blocking(move |f| Self::blocking_read(f, p, size)) + .await; + } + + self.read(size) + } + async fn cancel(&mut self) { + match mem::replace(&mut self.state, ReadState::Closed) { + ReadState::Waiting(task) => { + // The task was created using `spawn_blocking`, so unless we're + // lucky enough that the task hasn't started yet, the abort + // signal won't have any effect and we're forced to wait for it + // to run to completion. + // From the guest's point of view, `input-stream::drop` then + // appears to block. Certainly less than ideal, but arguably still + // better than letting the guest rack up an unbounded number of + // background tasks. Also, the guest is only blocked if + // the stream was dropped mid-read, which we don't expect to + // occur frequently. + task.abort_wait().await; + } + _ => {} + } + } +} +#[async_trait::async_trait] +impl Subscribe for FileInputStream { + async fn ready(&mut self) { + if let ReadState::Idle = self.state { + // The guest hasn't initiated any read, but is nonetheless waiting + // for data to be available. We'll start a read for them: + + const DEFAULT_READ_SIZE: usize = 4096; + let p = self.position; + self.state = ReadState::Waiting( + self.file + .spawn_blocking(move |f| Self::blocking_read(f, p, DEFAULT_READ_SIZE)), + ); + } + + self.wait_ready().await } } @@ -316,14 +424,51 @@ impl FileOutputStream { state: OutputState::Ready, } } + + fn blocking_write( + file: &cap_std::fs::File, + mut buf: Bytes, + mode: FileOutputMode, + ) -> io::Result { + use system_interface::fs::FileIoExt; + + match mode { + FileOutputMode::Position(mut p) => { + let mut total = 0; + loop { + let nwritten = file.write_at(buf.as_ref(), p)?; + // afterwards buf contains [nwritten, len): + let _ = buf.split_to(nwritten); + p += nwritten as u64; + total += nwritten; + if buf.is_empty() { + break; + } + } + Ok(total) + } + FileOutputMode::Append => { + let mut total = 0; + loop { + let nwritten = file.append(buf.as_ref())?; + let _ = buf.split_to(nwritten); + total += nwritten; + if buf.is_empty() { + break; + } + } + Ok(total) + } + } + } } // FIXME: configurable? determine from how much space left in file? const FILE_WRITE_CAPACITY: usize = 1024 * 1024; +#[async_trait::async_trait] impl HostOutputStream for FileOutputStream { fn write(&mut self, buf: Bytes) -> Result<(), StreamError> { - use system_interface::fs::FileIoExt; match self.state { OutputState::Ready => {} OutputState::Closed => return Err(StreamError::Closed), @@ -336,49 +481,45 @@ impl HostOutputStream for FileOutputStream { } let m = self.mode; - let result = self.file._spawn_blocking(move |f| { - match m { - FileOutputMode::Position(mut p) => { - let mut total = 0; - let mut buf = buf; - loop { - let nwritten = f.write_at(buf.as_ref(), p)?; - // afterwards buf contains [nwritten, len): - let _ = buf.split_to(nwritten); - p += nwritten as u64; - total += nwritten; - if buf.is_empty() { - break; - } - } - Ok(total) - } - FileOutputMode::Append => { - let mut total = 0; - let mut buf = buf; - loop { - let nwritten = f.append(buf.as_ref())?; - let _ = buf.split_to(nwritten); - total += nwritten; - if buf.is_empty() { - break; - } - } - Ok(total) - } - } - }); - self.state = match result { - SpawnBlocking::Done(Ok(nwritten)) => { + self.state = OutputState::Waiting( + self.file + .spawn_blocking(move |f| Self::blocking_write(f, buf, m)), + ); + Ok(()) + } + /// Specialized blocking_* variant to bypass tokio's task spawning & joining + /// overhead on synchronous file I/O. + async fn blocking_write_and_flush(&mut self, buf: Bytes) -> StreamResult<()> { + self.ready().await; + + match self.state { + OutputState::Ready => {} + OutputState::Closed => return Err(StreamError::Closed), + OutputState::Error(_) => match mem::replace(&mut self.state, OutputState::Closed) { + OutputState::Error(e) => return Err(StreamError::LastOperationFailed(e.into())), + _ => unreachable!(), + }, + OutputState::Waiting(_) => unreachable!("we've just waited for readiness"), + } + + let m = self.mode; + match self + .file + .run_blocking(move |f| Self::blocking_write(f, buf, m)) + .await + { + Ok(nwritten) => { if let FileOutputMode::Position(ref mut p) = &mut self.mode { *p += nwritten as u64; } - OutputState::Ready + self.state = OutputState::Ready; + Ok(()) } - SpawnBlocking::Done(Err(e)) => OutputState::Error(e), - SpawnBlocking::Spawned(task) => OutputState::Waiting(task), - }; - Ok(()) + Err(e) => { + self.state = OutputState::Closed; + Err(StreamError::LastOperationFailed(e.into())) + } + } } fn flush(&mut self) -> Result<(), StreamError> { match self.state { @@ -404,6 +545,24 @@ impl HostOutputStream for FileOutputStream { OutputState::Waiting(_) => Ok(0), } } + async fn cancel(&mut self) { + match mem::replace(&mut self.state, OutputState::Closed) { + OutputState::Waiting(task) => { + // The task was created using `spawn_blocking`, so unless we're + // lucky enough that the task hasn't started yet, the abort + // signal won't have any effect and we're forced to wait for it + // to run to completion. + // From the guest's point of view, `output-stream::drop` then + // appears to block. Certainly less than ideal, but arguably still + // better than letting the guest rack up an unbounded number of + // background tasks. Also, the guest is only blocked if + // the stream was dropped mid-write, which we don't expect to + // occur frequently. + task.abort_wait().await; + } + _ => {} + } + } } #[async_trait::async_trait] diff --git a/crates/wasi/src/host/filesystem.rs b/crates/wasi/src/host/filesystem.rs index d3ac08f7706e..04cc8670c4ad 100644 --- a/crates/wasi/src/host/filesystem.rs +++ b/crates/wasi/src/host/filesystem.rs @@ -82,7 +82,7 @@ where }; let f = self.table().get(&fd)?.file()?; - f.spawn_blocking(move |f| f.advise(offset, len, advice)) + f.run_blocking(move |f| f.advise(offset, len, advice)) .await?; Ok(()) } @@ -92,7 +92,7 @@ where match descriptor { Descriptor::File(f) => { - match f.spawn_blocking(|f| f.sync_data()).await { + match f.run_blocking(|f| f.sync_data()).await { Ok(()) => Ok(()), // On windows, `sync_data` uses `FileFlushBuffers` which fails with // `ERROR_ACCESS_DENIED` if the file is not upen for writing. Ignore @@ -108,7 +108,7 @@ where } } Descriptor::Dir(d) => { - d.spawn_blocking(|d| Ok(d.open(std::path::Component::CurDir)?.sync_data()?)) + d.run_blocking(|d| Ok(d.open(std::path::Component::CurDir)?.sync_data()?)) .await } } @@ -138,7 +138,7 @@ where let descriptor = self.table().get(&fd)?; match descriptor { Descriptor::File(f) => { - let flags = f.spawn_blocking(|f| f.get_fd_flags()).await?; + let flags = f.run_blocking(|f| f.get_fd_flags()).await?; let mut flags = get_from_fdflags(flags); if f.open_mode.contains(OpenMode::READ) { flags |= DescriptorFlags::READ; @@ -149,7 +149,7 @@ where Ok(flags) } Descriptor::Dir(d) => { - let flags = d.spawn_blocking(|d| d.get_fd_flags()).await?; + let flags = d.run_blocking(|d| d.get_fd_flags()).await?; let mut flags = get_from_fdflags(flags); if d.open_mode.contains(OpenMode::READ) { flags |= DescriptorFlags::READ; @@ -170,7 +170,7 @@ where match descriptor { Descriptor::File(f) => { - let meta = f.spawn_blocking(|f| f.metadata()).await?; + let meta = f.run_blocking(|f| f.metadata()).await?; Ok(descriptortype_from(meta.file_type())) } Descriptor::Dir(_) => Ok(types::DescriptorType::Directory), @@ -186,7 +186,7 @@ where if !f.perms.contains(FilePerms::WRITE) { Err(ErrorCode::NotPermitted)?; } - f.spawn_blocking(move |f| f.set_len(size)).await?; + f.run_blocking(move |f| f.set_len(size)).await?; Ok(()) } @@ -206,7 +206,7 @@ where } let atim = systemtimespec_from(atim)?; let mtim = systemtimespec_from(mtim)?; - f.spawn_blocking(|f| f.set_times(atim, mtim)).await?; + f.run_blocking(|f| f.set_times(atim, mtim)).await?; Ok(()) } Descriptor::Dir(d) => { @@ -215,7 +215,7 @@ where } let atim = systemtimespec_from(atim)?; let mtim = systemtimespec_from(mtim)?; - d.spawn_blocking(|d| d.set_times(atim, mtim)).await?; + d.run_blocking(|d| d.set_times(atim, mtim)).await?; Ok(()) } } @@ -238,7 +238,7 @@ where } let (mut buffer, r) = f - .spawn_blocking(move |f| { + .run_blocking(move |f| { let mut buffer = vec![0; len.try_into().unwrap_or(usize::MAX)]; let r = f.read_vectored_at(&mut [IoSliceMut::new(&mut buffer)], offset); (buffer, r) @@ -275,7 +275,7 @@ where } let bytes_written = f - .spawn_blocking(move |f| f.write_vectored_at(&[IoSlice::new(&buf)], offset)) + .run_blocking(move |f| f.write_vectored_at(&[IoSlice::new(&buf)], offset)) .await?; Ok(types::Filesize::try_from(bytes_written).expect("usize fits in Filesize")) @@ -302,7 +302,7 @@ where } let entries = d - .spawn_blocking(|d| { + .run_blocking(|d| { // Both `entries` and `metadata` perform syscalls, which is why they are done // within this `block` call, rather than delay calculating the metadata // for entries when they're demanded later in the iterator chain. @@ -351,7 +351,7 @@ where match descriptor { Descriptor::File(f) => { - match f.spawn_blocking(|f| f.sync_all()).await { + match f.run_blocking(|f| f.sync_all()).await { Ok(()) => Ok(()), // On windows, `sync_data` uses `FileFlushBuffers` which fails with // `ERROR_ACCESS_DENIED` if the file is not upen for writing. Ignore @@ -367,7 +367,7 @@ where } } Descriptor::Dir(d) => { - d.spawn_blocking(|d| Ok(d.open(std::path::Component::CurDir)?.sync_all()?)) + d.run_blocking(|d| Ok(d.open(std::path::Component::CurDir)?.sync_all()?)) .await } } @@ -383,7 +383,7 @@ where if !d.perms.contains(DirPerms::MUTATE) { return Err(ErrorCode::NotPermitted.into()); } - d.spawn_blocking(move |d| d.create_dir(&path)).await?; + d.run_blocking(move |d| d.create_dir(&path)).await?; Ok(()) } @@ -392,12 +392,12 @@ where match descriptor { Descriptor::File(f) => { // No permissions check on stat: if opened, allowed to stat it - let meta = f.spawn_blocking(|f| f.metadata()).await?; + let meta = f.run_blocking(|f| f.metadata()).await?; Ok(descriptorstat_from(meta)) } Descriptor::Dir(d) => { // No permissions check on stat: if opened, allowed to stat it - let meta = d.spawn_blocking(|d| d.dir_metadata()).await?; + let meta = d.run_blocking(|d| d.dir_metadata()).await?; Ok(descriptorstat_from(meta)) } } @@ -416,9 +416,9 @@ where } let meta = if symlink_follow(path_flags) { - d.spawn_blocking(move |d| d.metadata(&path)).await? + d.run_blocking(move |d| d.metadata(&path)).await? } else { - d.spawn_blocking(move |d| d.symlink_metadata(&path)).await? + d.run_blocking(move |d| d.symlink_metadata(&path)).await? }; Ok(descriptorstat_from(meta)) } @@ -441,7 +441,7 @@ where let atim = systemtimespec_from(atim)?; let mtim = systemtimespec_from(mtim)?; if symlink_follow(path_flags) { - d.spawn_blocking(move |d| { + d.run_blocking(move |d| { d.set_times( &path, atim.map(cap_fs_ext::SystemTimeSpec::from_std), @@ -450,7 +450,7 @@ where }) .await?; } else { - d.spawn_blocking(move |d| { + d.run_blocking(move |d| { d.set_symlink_times( &path, atim.map(cap_fs_ext::SystemTimeSpec::from_std), @@ -485,7 +485,7 @@ where } let new_dir_handle = std::sync::Arc::clone(&new_dir.dir); old_dir - .spawn_blocking(move |d| d.hard_link(&old_path, &new_dir_handle, &new_path)) + .run_blocking(move |d| d.hard_link(&old_path, &new_dir_handle, &new_path)) .await?; Ok(()) } @@ -595,7 +595,7 @@ where } let opened = d - .spawn_blocking::<_, std::io::Result>(move |d| { + .run_blocking::<_, std::io::Result>(move |d| { let mut opened = d.open_with(&path, &opts)?; if opened.metadata()?.is_dir() { Ok(OpenResult::Dir(cap_std::fs::Dir::from_std_file( @@ -656,7 +656,7 @@ where if !d.perms.contains(DirPerms::READ) { return Err(ErrorCode::NotPermitted.into()); } - let link = d.spawn_blocking(move |d| d.read_link(&path)).await?; + let link = d.run_blocking(move |d| d.read_link(&path)).await?; Ok(link .into_os_string() .into_string() @@ -673,7 +673,7 @@ where if !d.perms.contains(DirPerms::MUTATE) { return Err(ErrorCode::NotPermitted.into()); } - Ok(d.spawn_blocking(move |d| d.remove_dir(&path)).await?) + Ok(d.run_blocking(move |d| d.remove_dir(&path)).await?) } async fn rename_at( @@ -694,7 +694,7 @@ where } let new_dir_handle = std::sync::Arc::clone(&new_dir.dir); Ok(old_dir - .spawn_blocking(move |d| d.rename(&old_path, &new_dir_handle, &new_path)) + .run_blocking(move |d| d.rename(&old_path, &new_dir_handle, &new_path)) .await?) } @@ -713,7 +713,7 @@ where if !d.perms.contains(DirPerms::MUTATE) { return Err(ErrorCode::NotPermitted.into()); } - Ok(d.spawn_blocking(move |d| d.symlink(&src_path, &dest_path)) + Ok(d.run_blocking(move |d| d.symlink(&src_path, &dest_path)) .await?) } @@ -729,7 +729,7 @@ where if !d.perms.contains(DirPerms::MUTATE) { return Err(ErrorCode::NotPermitted.into()); } - Ok(d.spawn_blocking(move |d| d.remove_file_or_symlink(&path)) + Ok(d.run_blocking(move |d| d.remove_file_or_symlink(&path)) .await?) } @@ -746,10 +746,10 @@ where } // Create a stream view for it. - let reader = FileInputStream::new(f, offset); + let reader: InputStream = Box::new(FileInputStream::new(f, offset)); // Insert the stream view into the table. Trap if the table is full. - let index = self.table().push(InputStream::File(reader))?; + let index = self.table().push(reader)?; Ok(index) } @@ -842,7 +842,7 @@ where let d = table.get(&fd)?.dir()?; // No permissions check on metadata: if dir opened, allowed to stat it let meta = d - .spawn_blocking(move |d| { + .run_blocking(move |d| { if symlink_follow(path_flags) { d.metadata(path) } else { @@ -878,11 +878,11 @@ async fn get_descriptor_metadata(fd: &types::Descriptor) -> FsResult { // No permissions check on metadata: if opened, allowed to stat it - Ok(f.spawn_blocking(|f| f.metadata()).await?) + Ok(f.run_blocking(|f| f.metadata()).await?) } Descriptor::Dir(d) => { // No permissions check on metadata: if opened, allowed to stat it - Ok(d.spawn_blocking(|d| d.dir_metadata()).await?) + Ok(d.run_blocking(|d| d.dir_metadata()).await?) } } } diff --git a/crates/wasi/src/host/io.rs b/crates/wasi/src/host/io.rs index 9f257af5aa11..b5cecca8132c 100644 --- a/crates/wasi/src/host/io.rs +++ b/crates/wasi/src/host/io.rs @@ -42,8 +42,8 @@ impl streams::HostOutputStream for WasiImpl where T: WasiView, { - fn drop(&mut self, stream: Resource) -> anyhow::Result<()> { - self.table().delete(stream)?; + async fn drop(&mut self, stream: Resource) -> anyhow::Result<()> { + self.table().delete(stream)?.cancel().await; Ok(()) } @@ -66,29 +66,16 @@ where stream: Resource, bytes: Vec, ) -> StreamResult<()> { - let s = self.table().get_mut(&stream)?; - if bytes.len() > 4096 { return Err(StreamError::trap( "Buffer too large for blocking-write-and-flush (expected at most 4096)", )); } - let mut bytes = bytes::Bytes::from(bytes); - loop { - let permit = s.write_ready().await?; - let len = bytes.len().min(permit); - let chunk = bytes.split_to(len); - s.write(chunk)?; - if bytes.is_empty() { - break; - } - } - - s.flush()?; - s.write_ready().await?; - - Ok(()) + self.table() + .get_mut(&stream)? + .blocking_write_and_flush(bytes.into()) + .await } async fn blocking_write_zeroes_and_flush( @@ -96,26 +83,16 @@ where stream: Resource, len: u64, ) -> StreamResult<()> { - let s = self.table().get_mut(&stream)?; - if len > 4096 { return Err(StreamError::trap( "Buffer too large for blocking-write-zeroes-and-flush (expected at most 4096)", )); } - let mut len = len; - while len > 0 { - let permit = s.write_ready().await?; - let this_len = len.min(permit as u64); - s.write_zeroes(this_len as usize)?; - len -= this_len; - } - - s.flush()?; - s.write_ready().await?; - - Ok(()) + self.table() + .get_mut(&stream)? + .blocking_write_zeroes_and_flush(len as usize) + .await } fn write_zeroes(&mut self, stream: Resource, len: u64) -> StreamResult<()> { @@ -135,7 +112,7 @@ where Ok(()) } - async fn splice( + fn splice( &mut self, dest: Resource, src: Resource, @@ -152,10 +129,7 @@ where return Ok(0); } - let contents = match self.table().get_mut(&src)? { - InputStream::Host(h) => h.read(len)?, - InputStream::File(f) => f.read(len).await?, - }; + let contents = self.table().get_mut(&src)?.read(len)?; let len = contents.len(); if len == 0 { @@ -173,13 +147,27 @@ where src: Resource, len: u64, ) -> StreamResult { - use crate::Subscribe; + let len = len.try_into().unwrap_or(usize::MAX); + + let permit = { + let output = self.table().get_mut(&dest)?; + output.write_ready().await? + }; + let len = len.min(permit); + if len == 0 { + return Ok(0); + } - self.table().get_mut(&dest)?.ready().await; + let contents = self.table().get_mut(&src)?.blocking_read(len).await?; - self.table().get_mut(&src)?.ready().await; + let len = contents.len(); + if len == 0 { + return Ok(0); + } - self.splice(dest, src, len).await + let output = self.table().get_mut(&dest)?; + output.blocking_write_and_flush(contents).await?; + Ok(len.try_into().expect("usize can fit in u64")) } } @@ -188,17 +176,14 @@ impl streams::HostInputStream for WasiImpl where T: WasiView, { - fn drop(&mut self, stream: Resource) -> anyhow::Result<()> { - self.table().delete(stream)?; + async fn drop(&mut self, stream: Resource) -> anyhow::Result<()> { + self.table().delete(stream)?.cancel().await; Ok(()) } - async fn read(&mut self, stream: Resource, len: u64) -> StreamResult> { + fn read(&mut self, stream: Resource, len: u64) -> StreamResult> { let len = len.try_into().unwrap_or(usize::MAX); - let bytes = match self.table().get_mut(&stream)? { - InputStream::Host(s) => s.read(len)?, - InputStream::File(s) => s.read(len).await?, - }; + let bytes = self.table().get_mut(&stream)?.read(len)?; debug_assert!(bytes.len() <= len); Ok(bytes.into()) } @@ -208,18 +193,15 @@ where stream: Resource, len: u64, ) -> StreamResult> { - if let InputStream::Host(s) = self.table().get_mut(&stream)? { - s.ready().await; - } - self.read(stream, len).await + let len = len.try_into().unwrap_or(usize::MAX); + let bytes = self.table().get_mut(&stream)?.blocking_read(len).await?; + debug_assert!(bytes.len() <= len); + Ok(bytes.into()) } - async fn skip(&mut self, stream: Resource, len: u64) -> StreamResult { + fn skip(&mut self, stream: Resource, len: u64) -> StreamResult { let len = len.try_into().unwrap_or(usize::MAX); - let written = match self.table().get_mut(&stream)? { - InputStream::Host(s) => s.skip(len)?, - InputStream::File(s) => s.skip(len).await?, - }; + let written = self.table().get_mut(&stream)?.skip(len)?; Ok(written.try_into().expect("usize always fits in u64")) } @@ -228,10 +210,9 @@ where stream: Resource, len: u64, ) -> StreamResult { - if let InputStream::Host(s) = self.table().get_mut(&stream)? { - s.ready().await; - } - self.skip(stream, len).await + let len = len.try_into().unwrap_or(usize::MAX); + let written = self.table().get_mut(&stream)?.blocking_skip(len).await?; + Ok(written.try_into().expect("usize always fits in u64")) } fn subscribe(&mut self, stream: Resource) -> anyhow::Result> { @@ -278,7 +259,7 @@ pub mod sync { T: WasiView, { fn drop(&mut self, stream: Resource) -> anyhow::Result<()> { - AsyncHostOutputStream::drop(self, stream) + in_tokio(async { AsyncHostOutputStream::drop(self, stream).await }) } fn check_write(&mut self, stream: Resource) -> StreamResult { @@ -340,7 +321,7 @@ pub mod sync { src: Resource, len: u64, ) -> StreamResult { - in_tokio(async { AsyncHostOutputStream::splice(self, dst, src, len).await }) + AsyncHostOutputStream::splice(self, dst, src, len) } fn blocking_splice( @@ -358,11 +339,11 @@ pub mod sync { T: WasiView, { fn drop(&mut self, stream: Resource) -> anyhow::Result<()> { - AsyncHostInputStream::drop(self, stream) + in_tokio(async { AsyncHostInputStream::drop(self, stream).await }) } fn read(&mut self, stream: Resource, len: u64) -> StreamResult> { - in_tokio(async { AsyncHostInputStream::read(self, stream, len).await }) + AsyncHostInputStream::read(self, stream, len) } fn blocking_read( @@ -374,7 +355,7 @@ pub mod sync { } fn skip(&mut self, stream: Resource, len: u64) -> StreamResult { - in_tokio(async { AsyncHostInputStream::skip(self, stream, len).await }) + AsyncHostInputStream::skip(self, stream, len) } fn blocking_skip(&mut self, stream: Resource, len: u64) -> StreamResult { diff --git a/crates/wasi/src/pipe.rs b/crates/wasi/src/pipe.rs index abc71888beb6..4b723cc0c13a 100644 --- a/crates/wasi/src/pipe.rs +++ b/crates/wasi/src/pipe.rs @@ -112,7 +112,7 @@ pub struct AsyncReadStream { closed: bool, buffer: Option>, receiver: mpsc::Receiver>, - _join_handle: crate::runtime::AbortOnDropJoinHandle<()>, + join_handle: Option>, } impl AsyncReadStream { @@ -143,7 +143,7 @@ impl AsyncReadStream { closed: false, buffer: None, receiver, - _join_handle: join_handle, + join_handle: Some(join_handle), } } } @@ -190,6 +190,13 @@ impl HostInputStream for AsyncReadStream { ))), } } + + async fn cancel(&mut self) { + match self.join_handle.take() { + Some(task) => _ = task.abort_wait().await, + None => {} + } + } } #[async_trait::async_trait] impl Subscribe for AsyncReadStream { diff --git a/crates/wasi/src/preview1.rs b/crates/wasi/src/preview1.rs index ee4178c9ce73..62022ab109e5 100644 --- a/crates/wasi/src/preview1.rs +++ b/crates/wasi/src/preview1.rs @@ -183,6 +183,12 @@ struct File { blocking_mode: BlockingMode, } +/// NB: preview1 files always use blocking writes regardless of what +/// they're configured to use since OSes don't have nonblocking +/// reads/writes anyway. This behavior originated in the first +/// implementation of WASIp1 where flags were propagated to the +/// OS and the OS ignored the nonblocking flag for files +/// generally. #[derive(Clone, Copy, Debug)] enum BlockingMode { Blocking, @@ -203,23 +209,11 @@ impl BlockingMode { max_size: usize, ) -> Result, types::Error> { let max_size = max_size.try_into().unwrap_or(u64::MAX); - match self { - BlockingMode::Blocking => { - match streams::HostInputStream::blocking_read(host, input_stream, max_size).await { - Ok(r) if r.is_empty() => Err(types::Errno::Intr.into()), - Ok(r) => Ok(r), - Err(StreamError::Closed) => Ok(Vec::new()), - Err(e) => Err(e.into()), - } - } - - BlockingMode::NonBlocking => { - match streams::HostInputStream::read(host, input_stream, max_size).await { - Ok(r) => Ok(r), - Err(StreamError::Closed) => Ok(Vec::new()), - Err(e) => Err(e.into()), - } - } + match streams::HostInputStream::blocking_read(host, input_stream, max_size).await { + Ok(r) if r.is_empty() => Err(types::Errno::Intr.into()), + Ok(r) => Ok(r), + Err(StreamError::Closed) => Ok(Vec::new()), + Err(e) => Err(e.into()), } } async fn write( @@ -236,52 +230,18 @@ impl BlockingMode { .map_err(|e| StreamError::Trap(e.into()))?; let mut bytes = &bytes[..]; - match self { - BlockingMode::Blocking => { - let total = bytes.len(); - while !bytes.is_empty() { - // NOTE: blocking_write_and_flush takes at most one 4k buffer. - let len = bytes.len().min(4096); - let (chunk, rest) = bytes.split_at(len); - bytes = rest; - - Streams::blocking_write_and_flush( - host, - output_stream.borrowed(), - Vec::from(chunk), - ) - .await? - } + let total = bytes.len(); + while !bytes.is_empty() { + // NOTE: blocking_write_and_flush takes at most one 4k buffer. + let len = bytes.len().min(4096); + let (chunk, rest) = bytes.split_at(len); + bytes = rest; - Ok(total) - } - BlockingMode::NonBlocking => { - let n = match Streams::check_write(host, output_stream.borrowed()) { - Ok(n) => n, - Err(StreamError::Closed) => 0, - Err(e) => Err(e)?, - }; - - let len = bytes.len().min(n as usize); - if len == 0 { - return Ok(0); - } - - match Streams::write(host, output_stream.borrowed(), bytes[..len].to_vec()) { - Ok(()) => {} - Err(StreamError::Closed) => return Ok(0), - Err(e) => Err(e)?, - } - - match Streams::blocking_flush(host, output_stream.borrowed()).await { - Ok(()) => {} - Err(StreamError::Closed) => return Ok(0), - Err(e) => Err(e)?, - }; - - Ok(len) - } + Streams::blocking_write_and_flush(host, output_stream.borrowed(), Vec::from(chunk)) + .await? } + + Ok(total) } } @@ -655,7 +615,7 @@ impl WasiP1Ctx { // block. None => { let buf = memory.to_vec(buf)?; - f.spawn_blocking(move |f| do_write(f, &buf)).await + f.run_blocking(move |f| do_write(f, &buf)).await } }; @@ -1368,10 +1328,12 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiP1Ctx { match desc { Descriptor::Stdin { stream, .. } => { streams::HostInputStream::drop(&mut self.as_wasi_impl(), stream) + .await .context("failed to call `drop` on `input-stream`") } Descriptor::Stdout { stream, .. } | Descriptor::Stderr { stream, .. } => { streams::HostOutputStream::drop(&mut self.as_wasi_impl(), stream) + .await .context("failed to call `drop` on `output-stream`") } Descriptor::File(File { fd, .. }) | Descriptor::Directory { fd, .. } => { @@ -1728,7 +1690,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiP1Ctx { drop(buf); let mut buf = vec![0; iov.len() as usize]; let buf = file - .spawn_blocking(move |file| -> Result<_, types::Error> { + .run_blocking(move |file| -> Result<_, types::Error> { let bytes_read = file .read_at(&mut buf, pos) .map_err(|e| StreamError::LastOperationFailed(e.into()))?; @@ -1805,6 +1767,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiP1Ctx { ) .await; streams::HostInputStream::drop(&mut self.as_wasi_impl(), stream) + .await .map_err(|e| types::Error::trap(e))?; (buf, read?) } diff --git a/crates/wasi/src/runtime.rs b/crates/wasi/src/runtime.rs index da41faa38e9e..dc29fef97239 100644 --- a/crates/wasi/src/runtime.rs +++ b/crates/wasi/src/runtime.rs @@ -39,6 +39,18 @@ pub(crate) static RUNTIME: once_cell::sync::Lazy = /// by keeping this handle owned by the Resource. #[derive(Debug)] pub struct AbortOnDropJoinHandle(tokio::task::JoinHandle); +impl AbortOnDropJoinHandle { + /// Abort the task and wait for it to finish. Optionally returns the result + /// of the task if it ran to completion prior to being aborted. + pub(crate) async fn abort_wait(mut self) -> Option { + self.0.abort(); + match (&mut self.0).await { + Ok(value) => Some(value), + Err(err) if err.is_cancelled() => None, + Err(err) => std::panic::resume_unwind(err.into_panic()), + } + } +} impl Drop for AbortOnDropJoinHandle { fn drop(&mut self) { self.0.abort() diff --git a/crates/wasi/src/stdio.rs b/crates/wasi/src/stdio.rs index 72658ec27c97..4fb359f394fb 100644 --- a/crates/wasi/src/stdio.rs +++ b/crates/wasi/src/stdio.rs @@ -103,6 +103,7 @@ impl StdinStream for AsyncStdinStream { } } +#[async_trait::async_trait] impl HostInputStream for AsyncStdinStream { fn read(&mut self, size: usize) -> Result { match self.0.try_lock() { @@ -116,6 +117,15 @@ impl HostInputStream for AsyncStdinStream { Err(_) => Err(StreamError::trap("concurrent skips are not supported")), } } + async fn cancel(&mut self) { + // Cancel the inner stream if we're the last reference to it: + if let Some(mutex) = Arc::get_mut(&mut self.0) { + match mutex.try_lock() { + Ok(mut stream) => stream.cancel().await, + Err(_) => {} + } + } + } } #[async_trait::async_trait] @@ -355,6 +365,7 @@ impl StdoutStream for AsyncStdoutStream { // won't attempt to interleave async IO from these disparate uses of stdio. // If that expectation doesn't turn out to be true, and you find yourself at // this comment to correct it: sorry about that. +#[async_trait::async_trait] impl HostOutputStream for AsyncStdoutStream { fn check_write(&mut self) -> Result { match self.0.try_lock() { @@ -374,6 +385,15 @@ impl HostOutputStream for AsyncStdoutStream { Err(_) => Err(StreamError::trap("concurrent flushes not supported yet")), } } + async fn cancel(&mut self) { + // Cancel the inner stream if we're the last reference to it: + if let Some(mutex) = Arc::get_mut(&mut self.0) { + match mutex.try_lock() { + Ok(mut stream) => stream.cancel().await, + Err(_) => {} + } + } + } } #[async_trait::async_trait] @@ -395,7 +415,7 @@ where { fn get_stdin(&mut self) -> Result, anyhow::Error> { let stream = self.ctx().stdin.stream(); - Ok(self.table().push(streams::InputStream::Host(stream))?) + Ok(self.table().push(stream)?) } } diff --git a/crates/wasi/src/stream.rs b/crates/wasi/src/stream.rs index 4fcb4c46551e..251133cdac66 100644 --- a/crates/wasi/src/stream.rs +++ b/crates/wasi/src/stream.rs @@ -1,4 +1,3 @@ -use crate::filesystem::FileInputStream; use crate::poll::Subscribe; use anyhow::Result; use bytes::Bytes; @@ -21,6 +20,13 @@ pub trait HostInputStream: Subscribe { /// closed, when a read fails, or when a trap should be generated. fn read(&mut self, size: usize) -> StreamResult; + /// Similar to `read`, except that it blocks until at least one byte can be + /// read. + async fn blocking_read(&mut self, size: usize) -> StreamResult { + self.ready().await; + self.read(size) + } + /// Same as the `read` method except that bytes are skipped. /// /// Note that this method is non-blocking like `read` and returns the same @@ -29,6 +35,16 @@ pub trait HostInputStream: Subscribe { let bs = self.read(nelem)?; Ok(bs.len()) } + + /// Similar to `skip`, except that it blocks until at least one byte can be + /// skipped. + async fn blocking_skip(&mut self, nelem: usize) -> StreamResult { + let bs = self.blocking_read(nelem).await?; + Ok(bs.len()) + } + + /// Cancel any asynchronous work and wait for it to wrap up. + async fn cancel(&mut self) {} } /// Representation of the `error` resource type in the `wasi:io/error` @@ -135,6 +151,47 @@ pub trait HostOutputStream: Subscribe { /// - prior operation ([`write`](Self::write) or [`flush`](Self::flush)) failed fn check_write(&mut self) -> StreamResult; + /// Perform a write of up to 4096 bytes, and then flush the stream. Block + /// until all of these operations are complete, or an error occurs. + /// + /// This is a convenience wrapper around the use of `check-write`, + /// `subscribe`, `write`, and `flush`, and is implemented with the + /// following pseudo-code: + /// + /// ```text + /// let pollable = this.subscribe(); + /// while !contents.is_empty() { + /// // Wait for the stream to become writable + /// pollable.block(); + /// let Ok(n) = this.check-write(); // eliding error handling + /// let len = min(n, contents.len()); + /// let (chunk, rest) = contents.split_at(len); + /// this.write(chunk ); // eliding error handling + /// contents = rest; + /// } + /// this.flush(); + /// // Wait for completion of `flush` + /// pollable.block(); + /// // Check for any errors that arose during `flush` + /// let _ = this.check-write(); // eliding error handling + /// ``` + async fn blocking_write_and_flush(&mut self, mut bytes: Bytes) -> StreamResult<()> { + loop { + let permit = self.write_ready().await?; + let len = bytes.len().min(permit); + let chunk = bytes.split_to(len); + self.write(chunk)?; + if bytes.is_empty() { + break; + } + } + + self.flush()?; + self.write_ready().await?; + + Ok(()) + } + /// Repeatedly write a byte to a stream. /// Important: this write must be non-blocking! /// Returning an Err which downcasts to a [`StreamError`] will be @@ -147,12 +204,46 @@ pub trait HostOutputStream: Subscribe { Ok(()) } + /// Perform a write of up to 4096 zeroes, and then flush the stream. + /// Block until all of these operations are complete, or an error + /// occurs. + /// + /// This is a convenience wrapper around the use of `check-write`, + /// `subscribe`, `write-zeroes`, and `flush`, and is implemented with + /// the following pseudo-code: + /// + /// ```text + /// let pollable = this.subscribe(); + /// while num_zeroes != 0 { + /// // Wait for the stream to become writable + /// pollable.block(); + /// let Ok(n) = this.check-write(); // eliding error handling + /// let len = min(n, num_zeroes); + /// this.write-zeroes(len); // eliding error handling + /// num_zeroes -= len; + /// } + /// this.flush(); + /// // Wait for completion of `flush` + /// pollable.block(); + /// // Check for any errors that arose during `flush` + /// let _ = this.check-write(); // eliding error handling + /// ``` + async fn blocking_write_zeroes_and_flush(&mut self, nelem: usize) -> StreamResult<()> { + // TODO: We could optimize this to not allocate one big zeroed buffer, and instead write + // repeatedly from a 'static buffer of zeros. + let bs = Bytes::from_iter(core::iter::repeat(0).take(nelem)); + self.blocking_write_and_flush(bs).await + } + /// Simultaneously waits for this stream to be writable and then returns how /// much may be written or the last error that happened. async fn write_ready(&mut self) -> StreamResult { self.ready().await; self.check_write() } + + /// Cancel any asynchronous work and wait for it to wrap up. + async fn cancel(&mut self) {} } #[async_trait::async_trait] @@ -162,20 +253,13 @@ impl Subscribe for Box { } } -pub enum InputStream { - Host(Box), - File(FileInputStream), -} - #[async_trait::async_trait] -impl Subscribe for InputStream { +impl Subscribe for Box { async fn ready(&mut self) { - match self { - InputStream::Host(stream) => stream.ready().await, - // Files are always ready - InputStream::File(_) => {} - } + (**self).ready().await } } +pub type InputStream = Box; + pub type OutputStream = Box; diff --git a/crates/wasi/src/tcp.rs b/crates/wasi/src/tcp.rs index 4bc811f49f03..32feef95a683 100644 --- a/crates/wasi/src/tcp.rs +++ b/crates/wasi/src/tcp.rs @@ -278,8 +278,7 @@ impl TcpSocket { Ok(stream) => { let stream = Arc::new(stream); self.tcp_state = TcpState::Connected(stream.clone()); - let input: InputStream = - InputStream::Host(Box::new(TcpReadStream::new(stream.clone()))); + let input: InputStream = Box::new(TcpReadStream::new(stream.clone())); let output: OutputStream = Box::new(TcpWriteStream::new(stream)); Ok((input, output)) } @@ -428,7 +427,7 @@ impl TcpSocket { let client = Arc::new(client); - let input: InputStream = InputStream::Host(Box::new(TcpReadStream::new(client.clone()))); + let input: InputStream = Box::new(TcpReadStream::new(client.clone())); let output: OutputStream = Box::new(TcpWriteStream::new(client.clone())); let tcp_socket = TcpSocket::from_state(TcpState::Connected(client), self.family)?; @@ -787,6 +786,7 @@ impl TcpWriteStream { } } +#[async_trait::async_trait] impl HostOutputStream for TcpWriteStream { fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> { match self.last_write { @@ -853,6 +853,12 @@ impl HostOutputStream for TcpWriteStream { } Ok(SOCKET_READY_SIZE) } + async fn cancel(&mut self) { + match mem::replace(&mut self.last_write, LastWrite::Closed) { + LastWrite::Waiting(task) => _ = task.abort_wait().await, + _ => {} + } + } } #[async_trait::async_trait] diff --git a/crates/wasi/src/write_stream.rs b/crates/wasi/src/write_stream.rs index e08546b60d4c..33ae910481b7 100644 --- a/crates/wasi/src/write_stream.rs +++ b/crates/wasi/src/write_stream.rs @@ -139,7 +139,7 @@ impl Worker { /// Provides a [`HostOutputStream`] impl from a [`tokio::io::AsyncWrite`] impl pub struct AsyncWriteStream { worker: Arc, - _join_handle: crate::runtime::AbortOnDropJoinHandle<()>, + join_handle: Option>, } impl AsyncWriteStream { @@ -156,11 +156,12 @@ impl AsyncWriteStream { AsyncWriteStream { worker, - _join_handle: join_handle, + join_handle: Some(join_handle), } } } +#[async_trait::async_trait] impl HostOutputStream for AsyncWriteStream { fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> { let mut state = self.worker.state(); @@ -194,6 +195,13 @@ impl HostOutputStream for AsyncWriteStream { fn check_write(&mut self) -> Result { self.worker.check_write() } + + async fn cancel(&mut self) { + match self.join_handle.take() { + Some(task) => _ = task.abort_wait().await, + None => {} + } + } } #[async_trait::async_trait] impl Subscribe for AsyncWriteStream {