diff --git a/src/server.rs b/src/server.rs index 1fbe302d..bb556905 100644 --- a/src/server.rs +++ b/src/server.rs @@ -58,6 +58,18 @@ where workers: usize, on_preamble_success: Option>, on_preamble_failure: Option, + /// Channel used to notify when the server is ready. + /// + /// # Thread Safety + /// This sender is `Send` and may be moved between threads safely. + /// + /// # Single-use Semantics + /// A `oneshot::Sender` can transmit only one readiness notification. After + /// sending, the sender is consumed and cannot be reused. + /// + /// # Implications + /// Because only one notification may be sent, a new `ready_tx` must be + /// provided each time the server is started. ready_tx: Option>, _preamble: PhantomData, } @@ -195,6 +207,9 @@ where /// Configure a channel used to signal when the server is ready to accept /// connections. + /// + /// The provided `oneshot::Sender` is consumed after the first signal. Use a + /// fresh sender for each server run. #[must_use] pub fn ready_signal(mut self, tx: oneshot::Sender<()>) -> Self { self.ready_tx = Some(tx); @@ -325,7 +340,10 @@ where { let listener = self.listener.expect("`bind` must be called before `run`"); if let Some(tx) = self.ready_tx { - let _ = tx.send(()); + let result = tx.send(()); + if result.is_err() { + tracing::warn!("Failed to send readiness signal: receiver dropped"); + } } let shutdown_token = CancellationToken::new(); @@ -388,19 +406,17 @@ async fn worker_task( let peer_addr = stream.peer_addr().ok(); t.spawn(async move { use futures::FutureExt as _; - if let Err(panic) = std::panic::AssertUnwindSafe( + let fut = std::panic::AssertUnwindSafe( process_stream(stream, factory, success, failure), ) - .catch_unwind() - .await - { - let panic_msg = if let Some(s) = panic.downcast_ref::<&str>() { - (*s).to_string() - } else if let Some(s) = panic.downcast_ref::() { - s.clone() - } else { - format!("{panic:?}") - }; + .catch_unwind(); + + if let Err(panic) = fut.await { + let panic_msg = panic + .downcast_ref::<&str>() + .copied() + .or_else(|| panic.downcast_ref::().map(String::as_str)) + .unwrap_or(""); tracing::error!(panic = %panic_msg, ?peer_addr, "connection task panicked"); } }); diff --git a/tests/server.rs b/tests/server.rs index 031ec01d..3eda6da0 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -28,3 +28,39 @@ fn workers_accepts_large_values() { .workers(128); assert_eq!(server.worker_count(), 128); } + +/// Ensure dropping the readiness receiver logs a warning and does not +/// prevent the server from accepting connections. +#[tokio::test] +async fn readiness_receiver_dropped() { + use tokio::{ + net::TcpStream, + sync::oneshot, + time::{Duration, sleep}, + }; + + let factory = || WireframeApp::new().expect("WireframeApp::new failed"); + let server = WireframeServer::new(factory) + .workers(1) + .bind("127.0.0.1:0".parse().unwrap()) + .unwrap(); + + let addr = server.local_addr().unwrap(); + // Create channel and immediately drop receiver to force send failure + let (tx_ready, rx_ready) = oneshot::channel(); + drop(rx_ready); + + tokio::spawn(async move { + server + .ready_signal(tx_ready) + .run_with_shutdown(tokio::time::sleep(Duration::from_millis(200))) + .await + .unwrap(); + }); + + // Wait briefly to ensure server attempted to send readiness signal + sleep(Duration::from_millis(100)).await; + + // Server should still accept connections + let _stream = TcpStream::connect(addr).await.expect("connect failed"); +} diff --git a/tests/world.rs b/tests/world.rs index 10f14329..f1a22e98 100644 --- a/tests/world.rs +++ b/tests/world.rs @@ -6,26 +6,18 @@ use std::net::SocketAddr; use cucumber::World; -use tokio::{ - net::TcpStream, - sync::oneshot::{self, Sender}, -}; +use tokio::{net::TcpStream, sync::oneshot}; use wireframe::{app::WireframeApp, server::WireframeServer}; -#[derive(Debug, Default, World)] -pub struct PanicWorld { - pub addr: Option, - pub attempts: usize, - pub shutdown: Option>, - pub handle: Option>, +#[derive(Debug)] +struct PanicServer { + addr: SocketAddr, + shutdown: Option>, + handle: tokio::task::JoinHandle<()>, } -impl PanicWorld { - /// Start a server that panics during connection setup. - /// - /// # Panics - /// Panics if binding the server fails or the server task fails. - pub async fn start_panic_server(&mut self) { +impl PanicServer { + async fn spawn() -> Self { let factory = || { WireframeApp::new() .expect("Failed to create WireframeApp") @@ -37,32 +29,66 @@ impl PanicWorld { .bind("127.0.0.1:0".parse().expect("Failed to parse address")) .expect("bind"); - self.addr = Some(server.local_addr().expect("Failed to get server address")); - let (tx, rx) = oneshot::channel(); - let (ready_tx, ready_rx) = oneshot::channel(); - self.shutdown = Some(tx); + let addr = server.local_addr().expect("Failed to get server address"); + let (tx_shutdown, rx_shutdown) = oneshot::channel(); + let (tx_ready, rx_ready) = oneshot::channel(); - self.handle = Some(tokio::spawn(async move { + let handle = tokio::spawn(async move { server - .ready_signal(ready_tx) + .ready_signal(tx_ready) .run_with_shutdown(async { - let _ = rx.await; + let _ = rx_shutdown.await; }) .await .expect("Server task failed"); - })); + }); + rx_ready.await.expect("Server did not signal ready"); - ready_rx.await.expect("Server did not signal ready"); + Self { + addr, + shutdown: Some(tx_shutdown), + handle, + } } +} + +impl Drop for PanicServer { + fn drop(&mut self) { + use std::time::Duration; + + if let Some(tx) = self.shutdown.take() { + let _ = tx.send(()); + } + let timeout = Duration::from_secs(5); + let joined = futures::executor::block_on(tokio::time::timeout(timeout, &mut self.handle)); + match joined { + Ok(Ok(())) => {} + Ok(Err(e)) => eprintln!("PanicServer task panicked: {e:?}"), + Err(_) => eprintln!("PanicServer task did not shut down within timeout"), + } + } +} + +#[derive(Debug, Default, World)] +pub struct PanicWorld { + server: Option, + attempts: usize, +} + +impl PanicWorld { + /// Start a server that panics during connection setup. + /// + /// # Panics + /// Panics if binding the server fails or the server task fails. + pub async fn start_panic_server(&mut self) { self.server.replace(PanicServer::spawn().await); } /// Connect to the running server once. /// /// # Panics /// Panics if the server address is unknown or the connection fails. pub async fn connect_once(&mut self) { - TcpStream::connect(self.addr.expect("Server address not set")) - .await - .expect("Failed to connect"); + let addr = self.server.as_ref().expect("Server not started").addr; + TcpStream::connect(addr).await.expect("Failed to connect"); self.attempts += 1; } @@ -72,11 +98,8 @@ impl PanicWorld { /// Panics if the connection attempts do not match the expected count. pub async fn verify_and_shutdown(&mut self) { assert_eq!(self.attempts, 2); - if let Some(tx) = self.shutdown.take() { - let _ = tx.send(()); - } - if let Some(handle) = self.handle.take() { - handle.await.expect("Server task join failed"); - } + // dropping PanicServer will shut it down + self.server.take(); + tokio::task::yield_now().await; } }