diff --git a/crates/test-programs/tests/wasi-preview2-components-sync.rs b/crates/test-programs/tests/wasi-preview2-components-sync.rs index 85e8cea10ca4..37a03d2aa862 100644 --- a/crates/test-programs/tests/wasi-preview2-components-sync.rs +++ b/crates/test-programs/tests/wasi-preview2-components-sync.rs @@ -245,7 +245,6 @@ fn poll_oneoff_files() { run("poll_oneoff_files", false).unwrap() } -#[cfg_attr(windows, should_panic)] #[test_log::test] fn poll_oneoff_stdio() { run("poll_oneoff_stdio", true).unwrap() diff --git a/crates/test-programs/tests/wasi-preview2-components.rs b/crates/test-programs/tests/wasi-preview2-components.rs index 021438d55814..947a9eea04e2 100644 --- a/crates/test-programs/tests/wasi-preview2-components.rs +++ b/crates/test-programs/tests/wasi-preview2-components.rs @@ -251,7 +251,6 @@ async fn poll_oneoff_files() { run("poll_oneoff_files", false).await.unwrap() } -#[cfg_attr(windows, should_panic)] #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn poll_oneoff_stdio() { run("poll_oneoff_stdio", true).await.unwrap() diff --git a/crates/test-programs/wasi-tests/src/bin/poll_oneoff_stdio.rs b/crates/test-programs/wasi-tests/src/bin/poll_oneoff_stdio.rs index a66eb94ffe28..2ddc00475281 100644 --- a/crates/test-programs/wasi-tests/src/bin/poll_oneoff_stdio.rs +++ b/crates/test-programs/wasi-tests/src/bin/poll_oneoff_stdio.rs @@ -69,7 +69,6 @@ unsafe fn test_stdin_read() { } fn writable_subs(h: &HashMap) -> Vec { - println!("writable subs: {:?}", h); h.iter() .map(|(ud, fd)| wasi::Subscription { userdata: *ud, @@ -87,7 +86,7 @@ fn writable_subs(h: &HashMap) -> Vec { unsafe fn test_stdout_stderr_write() { let mut writable: HashMap = - vec![(1, STDOUT_FD), (2, STDERR_FD)].into_iter().collect(); + [(1, STDOUT_FD), (2, STDERR_FD)].into_iter().collect(); let clock = wasi::Subscription { userdata: CLOCK_ID, diff --git a/crates/wasi/Cargo.toml b/crates/wasi/Cargo.toml index 8eb123055740..ab6655a79a02 100644 --- a/crates/wasi/Cargo.toml +++ b/crates/wasi/Cargo.toml @@ -42,6 +42,9 @@ tokio = { workspace = true, features = ["time", "sync", "io-std", "io-util", "rt [target.'cfg(unix)'.dependencies] rustix = { workspace = true, features = ["fs"], optional = true } +[target.'cfg(unix)'.dev-dependencies] +libc = { workspace = true } + [target.'cfg(windows)'.dependencies] io-extras = { workspace = true } windows-sys = { workspace = true } diff --git a/crates/wasi/src/preview2/mod.rs b/crates/wasi/src/preview2/mod.rs index 28271d26070a..0d60e947372d 100644 --- a/crates/wasi/src/preview2/mod.rs +++ b/crates/wasi/src/preview2/mod.rs @@ -148,13 +148,14 @@ pub mod bindings { pub use self::_internal_rest::wasi::*; } -static RUNTIME: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| { - tokio::runtime::Builder::new_multi_thread() - .enable_time() - .enable_io() - .build() - .unwrap() -}); +pub(crate) static RUNTIME: once_cell::sync::Lazy = + once_cell::sync::Lazy::new(|| { + tokio::runtime::Builder::new_current_thread() + .enable_time() + .enable_io() + .build() + .unwrap() + }); pub(crate) fn spawn(f: F) -> tokio::task::JoinHandle where diff --git a/crates/wasi/src/preview2/pipe.rs b/crates/wasi/src/preview2/pipe.rs index 34eb8a34f03a..f8d64546c635 100644 --- a/crates/wasi/src/preview2/pipe.rs +++ b/crates/wasi/src/preview2/pipe.rs @@ -102,6 +102,7 @@ pub struct AsyncReadStream { state: StreamState, buffer: Option>, receiver: tokio::sync::mpsc::Receiver>, + pub(crate) join_handle: tokio::task::JoinHandle<()>, } impl AsyncReadStream { @@ -109,7 +110,7 @@ impl AsyncReadStream { /// provided by this struct, the argument must impl [`tokio::io::AsyncRead`]. pub fn new(mut reader: T) -> Self { let (sender, receiver) = tokio::sync::mpsc::channel(1); - crate::preview2::spawn(async move { + let join_handle = crate::preview2::spawn(async move { loop { use tokio::io::AsyncReadExt; let mut buf = bytes::BytesMut::with_capacity(4096); @@ -130,10 +131,17 @@ impl AsyncReadStream { state: StreamState::Open, buffer: None, receiver, + join_handle, } } } +impl Drop for AsyncReadStream { + fn drop(&mut self) { + self.join_handle.abort() + } +} + #[async_trait::async_trait] impl HostInputStream for AsyncReadStream { fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> { @@ -213,6 +221,7 @@ pub struct AsyncWriteStream { state: Option, sender: tokio::sync::mpsc::Sender, result_receiver: tokio::sync::mpsc::Receiver>, + join_handle: tokio::task::JoinHandle<()>, } impl AsyncWriteStream { @@ -222,7 +231,7 @@ impl AsyncWriteStream { let (sender, mut receiver) = tokio::sync::mpsc::channel::(1); let (result_sender, result_receiver) = tokio::sync::mpsc::channel(1); - crate::preview2::spawn(async move { + let join_handle = crate::preview2::spawn(async move { 'outer: loop { use tokio::io::AsyncWriteExt; match receiver.recv().await { @@ -260,6 +269,7 @@ impl AsyncWriteStream { state: Some(WriteState::Ready), sender, result_receiver, + join_handle, } } @@ -282,6 +292,12 @@ impl AsyncWriteStream { } } +impl Drop for AsyncWriteStream { + fn drop(&mut self) { + self.join_handle.abort() + } +} + #[async_trait::async_trait] impl HostOutputStream for AsyncWriteStream { fn write(&mut self, bytes: Bytes) -> Result<(usize, StreamState), anyhow::Error> { diff --git a/crates/wasi/src/preview2/stdio.rs b/crates/wasi/src/preview2/stdio.rs index 8386f6611ae8..81083d2e0bfc 100644 --- a/crates/wasi/src/preview2/stdio.rs +++ b/crates/wasi/src/preview2/stdio.rs @@ -23,45 +23,196 @@ pub fn stderr() -> Stderr { #[cfg(all(unix, test))] mod test { + use crate::preview2::{HostInputStream, StreamState}; + use libc; + use std::fs::File; + use std::io::{BufRead, BufReader, Write}; + use std::os::fd::FromRawFd; + + fn test_child_stdin(child: T, parent: P) + where + T: FnOnce(File), + P: FnOnce(File, BufReader), + { + unsafe { + // Make pipe for emulating stdin. + let mut stdin_fds: [libc::c_int; 2] = [0; 2]; + assert_eq!( + libc::pipe(stdin_fds.as_mut_ptr()), + 0, + "Failed to create stdin pipe" + ); + let [stdin_read, stdin_write] = stdin_fds; + + // Make pipe for getting results. + let mut result_fds: [libc::c_int; 2] = [0; 2]; + assert_eq!( + libc::pipe(result_fds.as_mut_ptr()), + 0, + "Failed to create result pipe" + ); + let [result_read, result_write] = result_fds; + + let child_pid = libc::fork(); + if child_pid == 0 { + libc::close(stdin_write); + libc::close(result_read); + + libc::close(libc::STDIN_FILENO); + libc::dup2(stdin_read, libc::STDIN_FILENO); + + let result_write = File::from_raw_fd(result_write); + child(result_write); + } else { + libc::close(stdin_read); + libc::close(result_write); + + let stdin_write = File::from_raw_fd(stdin_write); + let result_read = BufReader::new(File::from_raw_fd(result_read)); + parent(stdin_write, result_read); + } + } + } + // This could even be parameterized somehow to use the worker thread stdin vs the asyncfd // stdin. + fn test_stdin_by_forking(mk_stdin: T) + where + S: HostInputStream, + T: Fn() -> S, + { + test_child_stdin( + |mut result_write| { + let mut child_running = true; + while child_running { + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + 'task: loop { + println!("child: creating stdin"); + let mut stdin = mk_stdin(); + + println!("child: checking that stdin is not ready"); + assert!( + tokio::time::timeout( + std::time::Duration::from_millis(100), + stdin.ready() + ) + .await + .is_err(), + "stdin available too soon" + ); + + writeln!(&mut result_write, "start").unwrap(); + + println!("child: started"); + + let mut buffer = String::new(); + loop { + println!("child: waiting for stdin to be ready"); + stdin.ready().await.unwrap(); + + println!("child: reading input"); + let (bytes, status) = stdin.read(1024).unwrap(); + + println!("child: {:?}, {:?}", bytes, status); + + // We can't effectively test for the case where stdin was closed. + assert_eq!(status, StreamState::Open); + + buffer.push_str(std::str::from_utf8(bytes.as_ref()).unwrap()); + if let Some((line, rest)) = buffer.split_once('\n') { + if line == "all done" { + writeln!(&mut result_write, "done").unwrap(); + println!("child: exiting..."); + child_running = false; + break 'task; + } else if line == "restart_runtime" { + writeln!(&mut result_write, "restarting").unwrap(); + println!("child: restarting runtime..."); + break 'task; + } else if line == "restart_task" { + writeln!(&mut result_write, "restarting").unwrap(); + println!("child: restarting task..."); + continue 'task; + } else { + writeln!(&mut result_write, "{}", line).unwrap(); + } + + buffer = rest.to_owned(); + } + } + } + }); + println!("runtime exited"); + } + println!("child exited"); + }, + |mut stdin_write, mut result_read| { + let mut line = String::new(); + result_read.read_line(&mut line).unwrap(); + assert_eq!(line, "start\n"); + + for i in 0..5 { + let message = format!("some bytes {}\n", i); + stdin_write.write_all(message.as_bytes()).unwrap(); + line.clear(); + result_read.read_line(&mut line).unwrap(); + assert_eq!(line, message); + } + + writeln!(&mut stdin_write, "restart_task").unwrap(); + line.clear(); + result_read.read_line(&mut line).unwrap(); + assert_eq!(line, "restarting\n"); + line.clear(); + + result_read.read_line(&mut line).unwrap(); + assert_eq!(line, "start\n"); + + for i in 0..10 { + let message = format!("more bytes {}\n", i); + stdin_write.write_all(message.as_bytes()).unwrap(); + line.clear(); + result_read.read_line(&mut line).unwrap(); + assert_eq!(line, message); + } + + writeln!(&mut stdin_write, "restart_runtime").unwrap(); + line.clear(); + result_read.read_line(&mut line).unwrap(); + assert_eq!(line, "restarting\n"); + line.clear(); + + result_read.read_line(&mut line).unwrap(); + assert_eq!(line, "start\n"); + + for i in 0..17 { + let message = format!("even more bytes {}\n", i); + stdin_write.write_all(message.as_bytes()).unwrap(); + line.clear(); + result_read.read_line(&mut line).unwrap(); + assert_eq!(line, message); + } + + writeln!(&mut stdin_write, "all done").unwrap(); + + line.clear(); + result_read.read_line(&mut line).unwrap(); + assert_eq!(line, "done\n"); + }, + ) + } + + #[test] + fn test_async_fd_stdin() { + test_stdin_by_forking(super::stdin); + } + #[test] - fn test_stdin_by_forking() { - // Make pipe for emulating stdin. - // Make pipe for getting results. - // Fork. - // When child: - // close stdin fd. - // use dup2 to turn the pipe recv end into the stdin fd. - // in a tokio runtime: - // let stdin = super::stdin(); - // // Make sure the initial state is that stdin is not ready: - // if timeout(stdin.ready().await).is_timeout() { - // send "start\n" on result pipe. - // } - // loop { - // match timeout(stdin.ready().await) { - // Ok => { - // let bytes = stdin.read(); - // if bytes == ending sentinel: - // exit - // if bytes == some other sentinel: - // return and go back to the thing where we start the tokio runtime, - // testing that when creating a new super::stdin() it works correctly - // send "got: {bytes:?}\n" on result pipe. - // } - // Err => { - // send "timed out\n" on result pipe. - // } - // } - // } - // When parent: - // wait to recv "start\n" on result pipe (or the child process exits) - // send some bytes to child stdin. - // make sure we get back "got {bytes:?}" on result pipe (or the child process exits) - // sleep for a while. - // make sure we get back "timed out" on result pipe (or the child process exits) - // send some bytes again. and etc. - // + fn test_worker_thread_stdin() { + test_stdin_by_forking(super::worker_thread_stdin::stdin); } } diff --git a/crates/wasi/src/preview2/stdio/unix.rs b/crates/wasi/src/preview2/stdio/unix.rs index 3888b5cdf96d..9e24efc95f50 100644 --- a/crates/wasi/src/preview2/stdio/unix.rs +++ b/crates/wasi/src/preview2/stdio/unix.rs @@ -5,30 +5,60 @@ use futures::ready; use std::future::Future; use std::io::{self, Read}; use std::pin::Pin; +use std::sync::{Arc, Mutex, OnceLock}; use std::task::{Context, Poll}; use tokio::io::unix::AsyncFd; use tokio::io::{AsyncRead, ReadBuf}; -// wasmtime cant use std::sync::OnceLock yet because of a llvm regression in -// 1.70. when 1.71 is released, we can switch to using std here. -use once_cell::sync::OnceCell as OnceLock; - -use std::sync::Mutex; - // We need a single global instance of the AsyncFd because creating // this instance registers the process's stdin fd with epoll, which will // return an error if an fd is registered more than once. -struct GlobalStdin(Mutex); -static STDIN: OnceLock = OnceLock::new(); +static STDIN: OnceLock = OnceLock::new(); + +#[derive(Clone)] +pub struct Stdin(Arc>); -impl GlobalStdin { - fn new() -> anyhow::Result { - Ok(Self(Mutex::new(AsyncReadStream::new(InnerStdin::new()?)))) +pub fn stdin() -> Stdin { + fn init_stdin() -> AsyncReadStream { + use crate::preview2::RUNTIME; + match tokio::runtime::Handle::try_current() { + Ok(_) => AsyncReadStream::new(InnerStdin::new().unwrap()), + Err(_) => { + let _enter = RUNTIME.enter(); + RUNTIME.block_on(async { AsyncReadStream::new(InnerStdin::new().unwrap()) }) + } + } } - fn read(&self, size: usize) -> Result<(Bytes, StreamState), Error> { + + let handle = STDIN + .get_or_init(|| Stdin(Arc::new(Mutex::new(init_stdin())))) + .clone(); + + { + let mut guard = handle.0.lock().unwrap(); + + // The backing task exited. This can happen in two cases: + // + // 1. the task crashed + // 2. the runtime has exited and been restarted in the same process + // + // As we can't tell the difference between these two, we assume the latter and restart the + // task. + if guard.join_handle.is_finished() { + *guard = init_stdin(); + } + } + + handle +} + +#[async_trait::async_trait] +impl crate::preview2::HostInputStream for Stdin { + fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> { HostInputStream::read(&mut *self.0.lock().unwrap(), size) } - fn ready<'a>(&'a self) -> impl Future> + 'a { + + async fn ready(&mut self) -> Result<(), Error> { // Custom Future impl takes the std mutex in each invocation of poll. // Required so we don't have to use a tokio mutex, which we can't take from // inside a sync context in Self::read. @@ -37,50 +67,19 @@ impl GlobalStdin { // then releasing the lock is acceptable here because the ready() future // is only ever going to await on a single channel recv, plus some management // of a state machine (for buffering). - struct Ready<'a>(&'a GlobalStdin); + struct Ready<'a> { + handle: &'a Stdin, + } impl<'a> Future for Ready<'a> { type Output = Result<(), Error>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let mut locked = self.as_mut().0 .0.lock().unwrap(); + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let mut locked = self.handle.0.lock().unwrap(); let fut = locked.ready(); tokio::pin!(fut); fut.poll(cx) } } - Ready(self) - } -} - -pub struct Stdin; -impl Stdin { - fn get_global() -> &'static GlobalStdin { - // Creation must be running in a tokio context to succeed. - match tokio::runtime::Handle::try_current() { - Ok(_) => STDIN.get_or_init(|| { - GlobalStdin::new().expect("creating AsyncFd for stdin in existing tokio context") - }), - Err(_) => STDIN.get_or_init(|| { - crate::preview2::in_tokio(async { - GlobalStdin::new() - .expect("creating AsyncFd for stdin in internal tokio context") - }) - }), - } - } -} - -pub fn stdin() -> Stdin { - Stdin -} - -#[async_trait::async_trait] -impl crate::preview2::HostInputStream for Stdin { - fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> { - Self::get_global().read(size) - } - - async fn ready(&mut self) -> Result<(), Error> { - Self::get_global().ready().await + Ready { handle: self }.await } } @@ -98,11 +97,11 @@ impl InnerStdin { let borrowed_fd = unsafe { rustix::fd::BorrowedFd::borrow_raw(stdin.as_raw_fd()) }; let flags = rustix::fs::fcntl_getfl(borrowed_fd)?; if !flags.contains(OFlags::NONBLOCK) { - rustix::fs::fcntl_setfl(borrowed_fd, flags.difference(OFlags::NONBLOCK))?; + rustix::fs::fcntl_setfl(borrowed_fd, flags.union(OFlags::NONBLOCK))?; } Ok(Self { - inner: AsyncFd::new(std::io::stdin())?, + inner: AsyncFd::new(stdin)?, }) } } @@ -122,8 +121,12 @@ impl AsyncRead for InnerStdin { buf.advance(len); return Poll::Ready(Ok(())); } - Ok(Err(err)) => return Poll::Ready(Err(err)), - Err(_would_block) => continue, + Ok(Err(err)) => { + return Poll::Ready(Err(err)); + } + Err(_would_block) => { + continue; + } } } } diff --git a/crates/wasi/src/preview2/stdio/worker_thread_stdin.rs b/crates/wasi/src/preview2/stdio/worker_thread_stdin.rs index 353b5c090e62..b4fde02465e3 100644 --- a/crates/wasi/src/preview2/stdio/worker_thread_stdin.rs +++ b/crates/wasi/src/preview2/stdio/worker_thread_stdin.rs @@ -1,7 +1,9 @@ use crate::preview2::{HostInputStream, StreamState}; -use anyhow::{Context, Error}; -use bytes::Bytes; -use tokio::sync::{mpsc, oneshot}; +use anyhow::Error; +use bytes::{Bytes, BytesMut}; +use std::io::Read; +use std::sync::Arc; +use tokio::sync::watch; // wasmtime cant use std::sync::OnceLock yet because of a llvm regression in // 1.70. when 1.71 is released, we can switch to using std here. @@ -9,109 +11,123 @@ use once_cell::sync::OnceCell as OnceLock; use std::sync::Mutex; -// We need a single global instance of the AsyncFd because creating -// this instance registers the process's stdin fd with epoll, which will -// return an error if an fd is registered more than once. struct GlobalStdin { - tx: mpsc::Sender>>, - // FIXME use a Watch to check for readiness instead of sending a oneshot sender + // Worker thread uses this to notify of new events. Ready checks use this + // to create a new Receiver via .subscribe(). The newly created receiver + // will only wait for events created after the call to subscribe(). + tx: Arc>, + // Worker thread and receivers share this state to get bytes read off + // stdin, or the error/closed state. + state: Arc>, } -static STDIN: OnceLock> = OnceLock::new(); - -fn create() -> Mutex { - let (tx, mut rx) = mpsc::channel::>>(1); - std::thread::spawn(move || { - use std::io::BufRead; - // A client is interested in stdin's readiness. - // Don't care about the None case - the GlobalStdin sender on the other - // end of this pipe will live forever, because it lives inside the OnceLock. - while let Some(msg) = rx.blocking_recv() { - // Fill buf - can we skip this if its - // already filled? - // also, this could block forever and the - // client could give up. in that case, - // another client may want to start waiting - let r = std::io::stdin() - .lock() - .fill_buf() - .map(|_| ()) - .map_err(anyhow::Error::from); - // tell the client stdin is ready for reading. - // don't care if the client happens to have died. - let _ = msg.send(r); + +#[derive(Debug)] +struct StdinState { + // Bytes read off stdin. + buffer: BytesMut, + // Error read off stdin, if any. + error: Option, + // If an error has occured in the past, we consider the stream closed. + closed: bool, +} + +static STDIN: OnceLock = OnceLock::new(); + +fn create() -> GlobalStdin { + let (tx, _rx) = watch::channel(()); + let tx = Arc::new(tx); + + let state = Arc::new(Mutex::new(StdinState { + buffer: BytesMut::new(), + error: None, + closed: false, + })); + + let ret = GlobalStdin { + state: state.clone(), + tx: tx.clone(), + }; + + std::thread::spawn(move || loop { + let mut bytes = BytesMut::zeroed(1024); + match std::io::stdin().lock().read(&mut bytes) { + // Reading `0` indicates that stdin has reached EOF, so we break + // the loop to allow the thread to exit. + Ok(0) => break, + + Ok(nbytes) => { + // Append to the buffer: + bytes.truncate(nbytes); + let mut locked = state.lock().unwrap(); + locked.buffer.extend_from_slice(&bytes); + } + Err(e) => { + // Set the error, and mark the stream as closed: + let mut locked = state.lock().unwrap(); + if locked.error.is_none() { + locked.error = Some(e) + } + locked.closed = true; + } } + // Receivers may or may not exist - fine if they dont, new + // ones will be created with subscribe() + let _ = tx.send(()); }); - - Mutex::new(GlobalStdin { tx }) + ret } +/// Only public interface is the [`HostInputStream`] impl. pub struct Stdin; impl Stdin { - fn get_global() -> &'static Mutex { + // Private! Only required internally. + fn get_global() -> &'static GlobalStdin { STDIN.get_or_init(|| create()) } } pub fn stdin() -> Stdin { - // This implementation still needs to be fixed, and we need better test coverage. - // We are deferring that work to a future PR. - // https://github.com/bytecodealliance/wasmtime/pull/6556#issuecomment-1646232646 - panic!("worker-thread based stdin is not yet implemented"); - // Stdin + Stdin } #[async_trait::async_trait] impl HostInputStream for Stdin { fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> { - use std::io::Read; - let mut buf = vec![0; size]; - // FIXME: this is actually blocking. This whole implementation is likely bogus as a result - let nbytes = std::io::stdin().read(&mut buf)?; - buf.truncate(nbytes); - Ok(( - buf.into(), - if nbytes > 0 { - StreamState::Open - } else { - StreamState::Closed - }, - )) + let g = Stdin::get_global(); + let mut locked = g.state.lock().unwrap(); + + if let Some(e) = locked.error.take() { + return Err(e.into()); + } + let size = locked.buffer.len().min(size); + let bytes = locked.buffer.split_to(size); + let state = if locked.buffer.is_empty() && locked.closed { + StreamState::Closed + } else { + StreamState::Open + }; + Ok((bytes.freeze(), state)) } async fn ready(&mut self) -> Result<(), Error> { - use mpsc::error::TrySendError; - use std::future::Future; - use std::pin::Pin; - use std::task::{Context, Poll}; - - // Custom Future impl takes the std mutex in each invocation of poll. - // Required so we don't have to use a tokio mutex, which we can't take from - // inside a sync context in Self::read. - // - // Take the lock, attempt to - struct Send(Option>>); - impl Future for Send { - type Output = anyhow::Result<()>; - fn poll(mut self: Pin<&mut Self>, _: &mut Context) -> Poll { - let locked = Stdin::get_global().lock().unwrap(); - let to_send = self.as_mut().0.take().expect("to_send should be some"); - match locked.tx.try_send(to_send) { - Ok(()) => Poll::Ready(Ok(())), - Err(TrySendError::Full(to_send)) => { - self.as_mut().0.replace(to_send); - Poll::Pending - } - Err(TrySendError::Closed(_)) => { - Poll::Ready(Err(anyhow::anyhow!("channel to GlobalStdin closed"))) - } - } + let g = Stdin::get_global(); + + // Block makes sure we dont hold the mutex across the await: + let mut rx = { + let locked = g.state.lock().unwrap(); + // read() will only return (empty, open) when the buffer is empty, + // AND there is no error AND the stream is still open: + if !locked.buffer.is_empty() || locked.error.is_some() || locked.closed { + return Ok(()); } - } + // Sender will take the mutex before updating the state of + // subscribe, so this ensures we will only await for any stdin + // events that are recorded after we drop the mutex: + g.tx.subscribe() + }; + + rx.changed().await.expect("impossible for sender to drop"); - let (result_tx, rx) = oneshot::channel::>(); - Box::pin(Send(Some(result_tx))) - .await - .context("sending message to worker thread")?; - rx.await.expect("channel is always alive") + Ok(()) } }