Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ fn poll_oneoff_files() {
run("poll_oneoff_files", false).unwrap()
}

#[cfg_attr(windows, should_panic)]
#[test_log::test]
fn poll_oneoff_stdio() {
run("poll_oneoff_stdio", true).unwrap()
Expand Down
1 change: 0 additions & 1 deletion crates/test-programs/tests/wasi-preview2-components.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ async fn poll_oneoff_files() {
run("poll_oneoff_files", false).await.unwrap()
}

#[cfg_attr(windows, should_panic)]
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn poll_oneoff_stdio() {
run("poll_oneoff_stdio", true).await.unwrap()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ unsafe fn test_stdin_read() {
}

fn writable_subs(h: &HashMap<u64, wasi::Fd>) -> Vec<wasi::Subscription> {
println!("writable subs: {:?}", h);
h.iter()
.map(|(ud, fd)| wasi::Subscription {
userdata: *ud,
Expand All @@ -87,7 +86,7 @@ fn writable_subs(h: &HashMap<u64, wasi::Fd>) -> Vec<wasi::Subscription> {

unsafe fn test_stdout_stderr_write() {
let mut writable: HashMap<u64, wasi::Fd> =
vec![(1, STDOUT_FD), (2, STDERR_FD)].into_iter().collect();
[(1, STDOUT_FD), (2, STDERR_FD)].into_iter().collect();

let clock = wasi::Subscription {
userdata: CLOCK_ID,
Expand Down
3 changes: 3 additions & 0 deletions crates/wasi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ tokio = { workspace = true, features = ["time", "sync", "io-std", "io-util", "rt
[target.'cfg(unix)'.dependencies]
rustix = { workspace = true, features = ["fs"], optional = true }

[target.'cfg(unix)'.dev-dependencies]
libc = { workspace = true }

[target.'cfg(windows)'.dependencies]
io-extras = { workspace = true }
windows-sys = { workspace = true }
Expand Down
15 changes: 8 additions & 7 deletions crates/wasi/src/preview2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,14 @@ pub mod bindings {
pub use self::_internal_rest::wasi::*;
}

static RUNTIME: once_cell::sync::Lazy<tokio::runtime::Runtime> = once_cell::sync::Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.enable_time()
.enable_io()
.build()
.unwrap()
});
pub(crate) static RUNTIME: once_cell::sync::Lazy<tokio::runtime::Runtime> =
once_cell::sync::Lazy::new(|| {
tokio::runtime::Builder::new_current_thread()
.enable_time()
.enable_io()
.build()
.unwrap()
});

pub(crate) fn spawn<F, G>(f: F) -> tokio::task::JoinHandle<G>
where
Expand Down
20 changes: 18 additions & 2 deletions crates/wasi/src/preview2/pipe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,15 @@ pub struct AsyncReadStream {
state: StreamState,
buffer: Option<Result<Bytes, std::io::Error>>,
receiver: tokio::sync::mpsc::Receiver<Result<(Bytes, StreamState), std::io::Error>>,
pub(crate) join_handle: tokio::task::JoinHandle<()>,
}

impl AsyncReadStream {
/// Create a [`AsyncReadStream`]. In order to use the [`HostInputStream`] impl
/// provided by this struct, the argument must impl [`tokio::io::AsyncRead`].
pub fn new<T: tokio::io::AsyncRead + Send + Sync + Unpin + 'static>(mut reader: T) -> Self {
let (sender, receiver) = tokio::sync::mpsc::channel(1);
crate::preview2::spawn(async move {
let join_handle = crate::preview2::spawn(async move {
loop {
use tokio::io::AsyncReadExt;
let mut buf = bytes::BytesMut::with_capacity(4096);
Expand All @@ -130,10 +131,17 @@ impl AsyncReadStream {
state: StreamState::Open,
buffer: None,
receiver,
join_handle,
}
}
}

impl Drop for AsyncReadStream {
fn drop(&mut self) {
self.join_handle.abort()
}
}

#[async_trait::async_trait]
impl HostInputStream for AsyncReadStream {
fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> {
Expand Down Expand Up @@ -213,6 +221,7 @@ pub struct AsyncWriteStream {
state: Option<WriteState>,
sender: tokio::sync::mpsc::Sender<Bytes>,
result_receiver: tokio::sync::mpsc::Receiver<Result<StreamState, std::io::Error>>,
join_handle: tokio::task::JoinHandle<()>,
}

impl AsyncWriteStream {
Expand All @@ -222,7 +231,7 @@ impl AsyncWriteStream {
let (sender, mut receiver) = tokio::sync::mpsc::channel::<Bytes>(1);
let (result_sender, result_receiver) = tokio::sync::mpsc::channel(1);

crate::preview2::spawn(async move {
let join_handle = crate::preview2::spawn(async move {
'outer: loop {
use tokio::io::AsyncWriteExt;
match receiver.recv().await {
Expand Down Expand Up @@ -260,6 +269,7 @@ impl AsyncWriteStream {
state: Some(WriteState::Ready),
sender,
result_receiver,
join_handle,
}
}

Expand All @@ -282,6 +292,12 @@ impl AsyncWriteStream {
}
}

impl Drop for AsyncWriteStream {
fn drop(&mut self) {
self.join_handle.abort()
}
}

#[async_trait::async_trait]
impl HostOutputStream for AsyncWriteStream {
fn write(&mut self, bytes: Bytes) -> Result<(usize, StreamState), anyhow::Error> {
Expand Down
225 changes: 188 additions & 37 deletions crates/wasi/src/preview2/stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,196 @@ pub fn stderr() -> Stderr {

#[cfg(all(unix, test))]
mod test {
use crate::preview2::{HostInputStream, StreamState};
use libc;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::os::fd::FromRawFd;

fn test_child_stdin<T, P>(child: T, parent: P)
where
T: FnOnce(File),
P: FnOnce(File, BufReader<File>),
{
unsafe {
// Make pipe for emulating stdin.
let mut stdin_fds: [libc::c_int; 2] = [0; 2];
assert_eq!(
libc::pipe(stdin_fds.as_mut_ptr()),
0,
"Failed to create stdin pipe"
);
let [stdin_read, stdin_write] = stdin_fds;

// Make pipe for getting results.
let mut result_fds: [libc::c_int; 2] = [0; 2];
assert_eq!(
libc::pipe(result_fds.as_mut_ptr()),
0,
"Failed to create result pipe"
);
let [result_read, result_write] = result_fds;

let child_pid = libc::fork();
if child_pid == 0 {
libc::close(stdin_write);
libc::close(result_read);

libc::close(libc::STDIN_FILENO);
libc::dup2(stdin_read, libc::STDIN_FILENO);

let result_write = File::from_raw_fd(result_write);
child(result_write);
} else {
libc::close(stdin_read);
libc::close(result_write);

let stdin_write = File::from_raw_fd(stdin_write);
let result_read = BufReader::new(File::from_raw_fd(result_read));
parent(stdin_write, result_read);
}
}
}

// This could even be parameterized somehow to use the worker thread stdin vs the asyncfd
// stdin.
fn test_stdin_by_forking<S, T>(mk_stdin: T)
where
S: HostInputStream,
T: Fn() -> S,
{
test_child_stdin(
|mut result_write| {
let mut child_running = true;
while child_running {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
'task: loop {
println!("child: creating stdin");
let mut stdin = mk_stdin();

println!("child: checking that stdin is not ready");
assert!(
tokio::time::timeout(
std::time::Duration::from_millis(100),
stdin.ready()
)
.await
.is_err(),
"stdin available too soon"
);

writeln!(&mut result_write, "start").unwrap();

println!("child: started");

let mut buffer = String::new();
loop {
println!("child: waiting for stdin to be ready");
stdin.ready().await.unwrap();

println!("child: reading input");
let (bytes, status) = stdin.read(1024).unwrap();

println!("child: {:?}, {:?}", bytes, status);

// We can't effectively test for the case where stdin was closed.
assert_eq!(status, StreamState::Open);

buffer.push_str(std::str::from_utf8(bytes.as_ref()).unwrap());
if let Some((line, rest)) = buffer.split_once('\n') {
if line == "all done" {
writeln!(&mut result_write, "done").unwrap();
println!("child: exiting...");
child_running = false;
break 'task;
} else if line == "restart_runtime" {
writeln!(&mut result_write, "restarting").unwrap();
println!("child: restarting runtime...");
break 'task;
} else if line == "restart_task" {
writeln!(&mut result_write, "restarting").unwrap();
println!("child: restarting task...");
continue 'task;
} else {
writeln!(&mut result_write, "{}", line).unwrap();
}

buffer = rest.to_owned();
}
}
}
});
println!("runtime exited");
}
println!("child exited");
},
|mut stdin_write, mut result_read| {
let mut line = String::new();
result_read.read_line(&mut line).unwrap();
assert_eq!(line, "start\n");

for i in 0..5 {
let message = format!("some bytes {}\n", i);
stdin_write.write_all(message.as_bytes()).unwrap();
line.clear();
result_read.read_line(&mut line).unwrap();
assert_eq!(line, message);
}

writeln!(&mut stdin_write, "restart_task").unwrap();
line.clear();
result_read.read_line(&mut line).unwrap();
assert_eq!(line, "restarting\n");
line.clear();

result_read.read_line(&mut line).unwrap();
assert_eq!(line, "start\n");

for i in 0..10 {
let message = format!("more bytes {}\n", i);
stdin_write.write_all(message.as_bytes()).unwrap();
line.clear();
result_read.read_line(&mut line).unwrap();
assert_eq!(line, message);
}

writeln!(&mut stdin_write, "restart_runtime").unwrap();
line.clear();
result_read.read_line(&mut line).unwrap();
assert_eq!(line, "restarting\n");
line.clear();

result_read.read_line(&mut line).unwrap();
assert_eq!(line, "start\n");

for i in 0..17 {
let message = format!("even more bytes {}\n", i);
stdin_write.write_all(message.as_bytes()).unwrap();
line.clear();
result_read.read_line(&mut line).unwrap();
assert_eq!(line, message);
}

writeln!(&mut stdin_write, "all done").unwrap();

line.clear();
result_read.read_line(&mut line).unwrap();
assert_eq!(line, "done\n");
},
)
}

#[test]
fn test_async_fd_stdin() {
test_stdin_by_forking(super::stdin);
}

#[test]
fn test_stdin_by_forking() {
// Make pipe for emulating stdin.
// Make pipe for getting results.
// Fork.
// When child:
// close stdin fd.
// use dup2 to turn the pipe recv end into the stdin fd.
// in a tokio runtime:
// let stdin = super::stdin();
// // Make sure the initial state is that stdin is not ready:
// if timeout(stdin.ready().await).is_timeout() {
// send "start\n" on result pipe.
// }
// loop {
// match timeout(stdin.ready().await) {
// Ok => {
// let bytes = stdin.read();
// if bytes == ending sentinel:
// exit
// if bytes == some other sentinel:
// return and go back to the thing where we start the tokio runtime,
// testing that when creating a new super::stdin() it works correctly
// send "got: {bytes:?}\n" on result pipe.
// }
// Err => {
// send "timed out\n" on result pipe.
// }
// }
// }
// When parent:
// wait to recv "start\n" on result pipe (or the child process exits)
// send some bytes to child stdin.
// make sure we get back "got {bytes:?}" on result pipe (or the child process exits)
// sleep for a while.
// make sure we get back "timed out" on result pipe (or the child process exits)
// send some bytes again. and etc.
//
fn test_worker_thread_stdin() {
test_stdin_by_forking(super::worker_thread_stdin::stdin);
}
}
Loading