From c740f1b6256cf118e8f051d9d03638185a114b7f Mon Sep 17 00:00:00 2001 From: Pat Hickey Date: Mon, 25 Nov 2024 14:11:49 -0800 Subject: [PATCH] Reactor::wait_for takes `&Pollable` instead of `Pollable` We want to be able to reuse a pollable on many uses in the reactor. The fundamental underlying operation `wasi:io/poll.poll` takes a `list>`, or `&[&Pollable]` in Rust. The `Poller` structure needs to be changed to unsafe transmute the `&Pollable` to a common lifetime for passing to `poll`. This remains safe as long as each caller that inserts a `&Pollable` unconditionally removes it. --- src/http/client.rs | 6 +++--- src/http/response.rs | 2 +- src/net/tcp_listener.rs | 6 +++--- src/net/tcp_stream.rs | 8 ++++---- src/runtime/polling.rs | 22 +++++++++------------- src/runtime/reactor.rs | 8 ++++---- src/time/mod.rs | 2 +- 7 files changed, 25 insertions(+), 29 deletions(-) diff --git a/src/http/client.rs b/src/http/client.rs index af792d5..ea45ad3 100644 --- a/src/http/client.rs +++ b/src/http/client.rs @@ -36,7 +36,7 @@ impl Client { OutgoingBody::finish(wasi_body, trailers).unwrap(); // 4. Receive the response - Reactor::current().wait_for(res.subscribe()).await; + Reactor::current().wait_for(&res.subscribe()).await; // NOTE: the first `unwrap` is to ensure readiness, the second `unwrap` // is to trap if we try and get the response more than once. The final // `?` is to raise the actual error if there is one. @@ -90,13 +90,13 @@ impl AsyncWrite for OutputStream { let max = max.min(buf.len()); let buf = &buf[0..max]; self.stream.write(buf).unwrap(); - Reactor::current().wait_for(self.stream.subscribe()).await; + Reactor::current().wait_for(&self.stream.subscribe()).await; Ok(max) } async fn flush(&mut self) -> io::Result<()> { self.stream.flush().unwrap(); - Reactor::current().wait_for(self.stream.subscribe()).await; + Reactor::current().wait_for(&self.stream.subscribe()).await; Ok(()) } } diff --git a/src/http/response.rs b/src/http/response.rs index 0dcfe30..e48b7ff 100644 --- a/src/http/response.rs +++ b/src/http/response.rs @@ -113,7 +113,7 @@ impl AsyncRead for IncomingBody { None => { // Wait for an event to be ready let pollable = self.body_stream.subscribe(); - Reactor::current().wait_for(pollable).await; + Reactor::current().wait_for(&pollable).await; // Read the bytes from the body stream let buf = match self.body_stream.read(CHUNK_SIZE) { diff --git a/src/net/tcp_listener.rs b/src/net/tcp_listener.rs index 94352aa..9130f1d 100644 --- a/src/net/tcp_listener.rs +++ b/src/net/tcp_listener.rs @@ -45,11 +45,11 @@ impl TcpListener { socket .start_bind(&network, local_address) .map_err(to_io_err)?; - reactor.wait_for(socket.subscribe()).await; + reactor.wait_for(&socket.subscribe()).await; socket.finish_bind().map_err(to_io_err)?; socket.start_listen().map_err(to_io_err)?; - reactor.wait_for(socket.subscribe()).await; + reactor.wait_for(&socket.subscribe()).await; socket.finish_listen().map_err(to_io_err)?; Ok(Self { socket }) } @@ -78,7 +78,7 @@ impl<'a> AsyncIterator for Incoming<'a> { async fn next(&mut self) -> Option { Reactor::current() - .wait_for(self.listener.socket.subscribe()) + .wait_for(&self.listener.socket.subscribe()) .await; let (socket, input, output) = match self.listener.socket.accept().map_err(to_io_err) { Ok(accepted) => accepted, diff --git a/src/net/tcp_stream.rs b/src/net/tcp_stream.rs index 5c6ab8a..307271a 100644 --- a/src/net/tcp_stream.rs +++ b/src/net/tcp_stream.rs @@ -30,7 +30,7 @@ impl TcpStream { impl AsyncRead for TcpStream { async fn read(&mut self, buf: &mut [u8]) -> io::Result { - Reactor::current().wait_for(self.input.subscribe()).await; + Reactor::current().wait_for(&self.input.subscribe()).await; let slice = match self.input.read(buf.len() as u64) { Ok(slice) => slice, Err(StreamError::Closed) => return Ok(0), @@ -44,7 +44,7 @@ impl AsyncRead for TcpStream { impl AsyncRead for &TcpStream { async fn read(&mut self, buf: &mut [u8]) -> io::Result { - Reactor::current().wait_for(self.input.subscribe()).await; + Reactor::current().wait_for(&self.input.subscribe()).await; let slice = match self.input.read(buf.len() as u64) { Ok(slice) => slice, Err(StreamError::Closed) => return Ok(0), @@ -58,7 +58,7 @@ impl AsyncRead for &TcpStream { impl AsyncWrite for TcpStream { async fn write(&mut self, buf: &[u8]) -> io::Result { - Reactor::current().wait_for(self.output.subscribe()).await; + Reactor::current().wait_for(&self.output.subscribe()).await; self.output.write(buf).map_err(to_io_err)?; Ok(buf.len()) } @@ -70,7 +70,7 @@ impl AsyncWrite for TcpStream { impl AsyncWrite for &TcpStream { async fn write(&mut self, buf: &[u8]) -> io::Result { - Reactor::current().wait_for(self.output.subscribe()).await; + Reactor::current().wait_for(&self.output.subscribe()).await; self.output.write(buf).map_err(to_io_err)?; Ok(buf.len()) } diff --git a/src/runtime/polling.rs b/src/runtime/polling.rs index 8ddec0a..a0382c5 100644 --- a/src/runtime/polling.rs +++ b/src/runtime/polling.rs @@ -10,7 +10,7 @@ use wasi::io::poll::{poll, Pollable}; /// Waits for I/O events. #[derive(Debug)] pub(crate) struct Poller { - pub(crate) targets: Slab, + pub(crate) targets: Slab<&'static Pollable>, } impl Poller { @@ -27,21 +27,17 @@ impl Poller { } /// Insert a new `Pollable` target into `Poller` - pub(crate) fn insert(&mut self, target: Pollable) -> EventKey { - let key = self.targets.insert(target); + /// + /// Safety: Caller MUST remove the EventKey corresponding to this insert + /// during the lifetime of &Pollable. + pub(crate) unsafe fn insert(&mut self, target: &Pollable) -> EventKey { + let key = self.targets.insert(std::mem::transmute(target)); EventKey(key as u32) } - /// Get a `Pollable` if it exists. - pub(crate) fn get(&self, key: &EventKey) -> Option<&Pollable> { - self.targets.get(key.0 as usize) - } - /// Remove an instance of `Pollable` from `Poller`. - /// - /// Returns `None` if no entry was found for `key`. - pub(crate) fn remove(&mut self, key: EventKey) -> Option { - self.targets.try_remove(key.0 as usize) + pub(crate) fn remove(&mut self, key: EventKey) { + self.targets.try_remove(key.0 as usize); } /// Block the current thread until a new event has triggered. @@ -58,7 +54,7 @@ impl Poller { let mut targets = Vec::with_capacity(self.targets.len()); for (index, target) in self.targets.iter() { indexes.push(index); - targets.push(target); + targets.push(*target); } debug_assert_ne!( diff --git a/src/runtime/reactor.rs b/src/runtime/reactor.rs index 5525c8b..a1af257 100644 --- a/src/runtime/reactor.rs +++ b/src/runtime/reactor.rs @@ -72,8 +72,7 @@ impl Reactor { } /// Wait for the pollable to resolve. - pub async fn wait_for(&self, pollable: Pollable) { - let mut pollable = Some(pollable); + pub async fn wait_for(&self, pollable: &Pollable) { let mut key = None; // This function is the core loop of our function; it will be called // multiple times as the future is resolving. @@ -84,12 +83,13 @@ impl Reactor { // Schedule interest in the `pollable` on the first iteration. On // every iteration, register the waker with the reactor. - let key = key.get_or_insert_with(|| reactor.poller.insert(pollable.take().unwrap())); + // Safety: caller of insert operation must remove key during lifetime of &Pollable. + let key = key.get_or_insert_with(|| unsafe { reactor.poller.insert(pollable) }); reactor.wakers.insert(*key, cx.waker().clone()); // Check whether we're ready or need to keep waiting. If we're // ready, we clean up after ourselves. - if reactor.poller.get(key).unwrap().ready() { + if pollable.ready() { reactor.poller.remove(*key); reactor.wakers.remove(key); Poll::Ready(()) diff --git a/src/time/mod.rs b/src/time/mod.rs index 419b35e..e2b11cf 100644 --- a/src/time/mod.rs +++ b/src/time/mod.rs @@ -67,7 +67,7 @@ impl Timer { match self.0 { Some(deadline) => { Reactor::current() - .wait_for(subscribe_instant(*deadline)) + .wait_for(&subscribe_instant(*deadline)) .await } None => std::future::pending().await,