From c5eee908e5cf5f7dd46bf745b2e64a0bd9ccd369 Mon Sep 17 00:00:00 2001 From: Leynos Date: Sun, 27 Jul 2025 01:21:55 +0100 Subject: [PATCH 1/2] Improve readiness signalling and panic handling --- src/server.rs | 60 ++++++++++++++++++++++++------------ tests/world.rs | 83 +++++++++++++++++++++++++++++--------------------- 2 files changed, 89 insertions(+), 54 deletions(-) diff --git a/src/server.rs b/src/server.rs index 1fbe302d..e35b66d9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -58,6 +58,18 @@ where workers: usize, on_preamble_success: Option>, on_preamble_failure: Option, + /// Channel used to notify when the server is ready. + /// + /// # Thread Safety + /// This sender is `Send` and may be moved between threads safely. + /// + /// # Single-use Semantics + /// A `oneshot::Sender` can transmit only one readiness notification. After + /// sending, the sender is consumed and cannot be reused. + /// + /// # Implications + /// Because only one notification may be sent, a new `ready_tx` must be + /// provided each time the server is started. ready_tx: Option>, _preamble: PhantomData, } @@ -195,6 +207,9 @@ where /// Configure a channel used to signal when the server is ready to accept /// connections. + /// + /// The provided `oneshot::Sender` is consumed after the first signal. Use a + /// fresh sender for each server run. #[must_use] pub fn ready_signal(mut self, tx: oneshot::Sender<()>) -> Self { self.ready_tx = Some(tx); @@ -324,8 +339,10 @@ where S: futures::Future + Send, { let listener = self.listener.expect("`bind` must be called before `run`"); - if let Some(tx) = self.ready_tx { - let _ = tx.send(()); + if let Some(tx) = self.ready_tx + && tx.send(()).is_err() + { + tracing::warn!("Failed to send readiness signal: receiver dropped"); } let shutdown_token = CancellationToken::new(); @@ -354,6 +371,23 @@ where } } +async fn catch_and_log_unwind(fut: Fut, peer_addr: Option) +where + Fut: std::future::Future + Send + 'static, +{ + use futures::FutureExt as _; + if let Err(panic) = std::panic::AssertUnwindSafe(fut).catch_unwind().await { + let panic_msg = if let Some(s) = panic.downcast_ref::<&str>() { + (*s).to_string() + } else if let Some(s) = panic.downcast_ref::() { + s.clone() + } else { + format!("{panic:?}") + }; + tracing::error!(panic = %panic_msg, ?peer_addr, "connection task panicked"); + } +} + /// Runs a worker task that accepts incoming TCP connections and processes them asynchronously. /// /// Each accepted connection is handled in a separate task, with optional callbacks for preamble @@ -386,24 +420,10 @@ async fn worker_task( let t = tracker.clone(); // Capture peer address for better error context let peer_addr = stream.peer_addr().ok(); - t.spawn(async move { - use futures::FutureExt as _; - if let Err(panic) = std::panic::AssertUnwindSafe( - process_stream(stream, factory, success, failure), - ) - .catch_unwind() - .await - { - let panic_msg = if let Some(s) = panic.downcast_ref::<&str>() { - (*s).to_string() - } else if let Some(s) = panic.downcast_ref::() { - s.clone() - } else { - format!("{panic:?}") - }; - tracing::error!(panic = %panic_msg, ?peer_addr, "connection task panicked"); - } - }); + t.spawn(catch_and_log_unwind( + process_stream(stream, factory, success, failure), + peer_addr, + )); delay = Duration::from_millis(10); } Err(e) => { diff --git a/tests/world.rs b/tests/world.rs index 10f14329..228b7ddb 100644 --- a/tests/world.rs +++ b/tests/world.rs @@ -6,26 +6,18 @@ use std::net::SocketAddr; use cucumber::World; -use tokio::{ - net::TcpStream, - sync::oneshot::{self, Sender}, -}; +use tokio::{net::TcpStream, sync::oneshot}; use wireframe::{app::WireframeApp, server::WireframeServer}; -#[derive(Debug, Default, World)] -pub struct PanicWorld { - pub addr: Option, - pub attempts: usize, - pub shutdown: Option>, - pub handle: Option>, +#[derive(Debug)] +struct PanicServer { + addr: SocketAddr, + shutdown: Option>, + handle: tokio::task::JoinHandle<()>, } -impl PanicWorld { - /// Start a server that panics during connection setup. - /// - /// # Panics - /// Panics if binding the server fails or the server task fails. - pub async fn start_panic_server(&mut self) { +impl PanicServer { + async fn spawn() -> Self { let factory = || { WireframeApp::new() .expect("Failed to create WireframeApp") @@ -37,32 +29,58 @@ impl PanicWorld { .bind("127.0.0.1:0".parse().expect("Failed to parse address")) .expect("bind"); - self.addr = Some(server.local_addr().expect("Failed to get server address")); - let (tx, rx) = oneshot::channel(); - let (ready_tx, ready_rx) = oneshot::channel(); - self.shutdown = Some(tx); + let addr = server.local_addr().expect("Failed to get server address"); + let (tx_shutdown, rx_shutdown) = oneshot::channel(); + let (tx_ready, rx_ready) = oneshot::channel(); - self.handle = Some(tokio::spawn(async move { + let handle = tokio::spawn(async move { server - .ready_signal(ready_tx) + .ready_signal(tx_ready) .run_with_shutdown(async { - let _ = rx.await; + let _ = rx_shutdown.await; }) .await .expect("Server task failed"); - })); + }); + rx_ready.await.expect("Server did not signal ready"); - ready_rx.await.expect("Server did not signal ready"); + Self { + addr, + shutdown: Some(tx_shutdown), + handle, + } } +} + +impl Drop for PanicServer { + fn drop(&mut self) { + if let Some(tx) = self.shutdown.take() { + let _ = tx.send(()); + } + let _ = futures::executor::block_on(&mut self.handle); + } +} + +#[derive(Debug, Default, World)] +pub struct PanicWorld { + server: Option, + attempts: usize, +} + +impl PanicWorld { + /// Start a server that panics during connection setup. + /// + /// # Panics + /// Panics if binding the server fails or the server task fails. + pub async fn start_panic_server(&mut self) { self.server.replace(PanicServer::spawn().await); } /// Connect to the running server once. /// /// # Panics /// Panics if the server address is unknown or the connection fails. pub async fn connect_once(&mut self) { - TcpStream::connect(self.addr.expect("Server address not set")) - .await - .expect("Failed to connect"); + let addr = self.server.as_ref().expect("Server not started").addr; + TcpStream::connect(addr).await.expect("Failed to connect"); self.attempts += 1; } @@ -72,11 +90,8 @@ impl PanicWorld { /// Panics if the connection attempts do not match the expected count. pub async fn verify_and_shutdown(&mut self) { assert_eq!(self.attempts, 2); - if let Some(tx) = self.shutdown.take() { - let _ = tx.send(()); - } - if let Some(handle) = self.handle.take() { - handle.await.expect("Server task join failed"); - } + // dropping PanicServer will shut it down + self.server.take(); + tokio::task::yield_now().await; } } From f8707e3c1bfe63cb137a0090a431dd9452627383 Mon Sep 17 00:00:00 2001 From: Leynos Date: Sun, 27 Jul 2025 01:43:25 +0100 Subject: [PATCH 2/2] Improve server readiness handling --- src/server.rs | 46 +++++++++++++++++++++------------------------- tests/server.rs | 36 ++++++++++++++++++++++++++++++++++++ tests/world.rs | 10 +++++++++- 3 files changed, 66 insertions(+), 26 deletions(-) diff --git a/src/server.rs b/src/server.rs index e35b66d9..bb556905 100644 --- a/src/server.rs +++ b/src/server.rs @@ -339,10 +339,11 @@ where S: futures::Future + Send, { let listener = self.listener.expect("`bind` must be called before `run`"); - if let Some(tx) = self.ready_tx - && tx.send(()).is_err() - { - tracing::warn!("Failed to send readiness signal: receiver dropped"); + if let Some(tx) = self.ready_tx { + let result = tx.send(()); + if result.is_err() { + tracing::warn!("Failed to send readiness signal: receiver dropped"); + } } let shutdown_token = CancellationToken::new(); @@ -371,23 +372,6 @@ where } } -async fn catch_and_log_unwind(fut: Fut, peer_addr: Option) -where - Fut: std::future::Future + Send + 'static, -{ - use futures::FutureExt as _; - if let Err(panic) = std::panic::AssertUnwindSafe(fut).catch_unwind().await { - let panic_msg = if let Some(s) = panic.downcast_ref::<&str>() { - (*s).to_string() - } else if let Some(s) = panic.downcast_ref::() { - s.clone() - } else { - format!("{panic:?}") - }; - tracing::error!(panic = %panic_msg, ?peer_addr, "connection task panicked"); - } -} - /// Runs a worker task that accepts incoming TCP connections and processes them asynchronously. /// /// Each accepted connection is handled in a separate task, with optional callbacks for preamble @@ -420,10 +404,22 @@ async fn worker_task( let t = tracker.clone(); // Capture peer address for better error context let peer_addr = stream.peer_addr().ok(); - t.spawn(catch_and_log_unwind( - process_stream(stream, factory, success, failure), - peer_addr, - )); + t.spawn(async move { + use futures::FutureExt as _; + let fut = std::panic::AssertUnwindSafe( + process_stream(stream, factory, success, failure), + ) + .catch_unwind(); + + if let Err(panic) = fut.await { + let panic_msg = panic + .downcast_ref::<&str>() + .copied() + .or_else(|| panic.downcast_ref::().map(String::as_str)) + .unwrap_or(""); + tracing::error!(panic = %panic_msg, ?peer_addr, "connection task panicked"); + } + }); delay = Duration::from_millis(10); } Err(e) => { diff --git a/tests/server.rs b/tests/server.rs index 031ec01d..3eda6da0 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -28,3 +28,39 @@ fn workers_accepts_large_values() { .workers(128); assert_eq!(server.worker_count(), 128); } + +/// Ensure dropping the readiness receiver logs a warning and does not +/// prevent the server from accepting connections. +#[tokio::test] +async fn readiness_receiver_dropped() { + use tokio::{ + net::TcpStream, + sync::oneshot, + time::{Duration, sleep}, + }; + + let factory = || WireframeApp::new().expect("WireframeApp::new failed"); + let server = WireframeServer::new(factory) + .workers(1) + .bind("127.0.0.1:0".parse().unwrap()) + .unwrap(); + + let addr = server.local_addr().unwrap(); + // Create channel and immediately drop receiver to force send failure + let (tx_ready, rx_ready) = oneshot::channel(); + drop(rx_ready); + + tokio::spawn(async move { + server + .ready_signal(tx_ready) + .run_with_shutdown(tokio::time::sleep(Duration::from_millis(200))) + .await + .unwrap(); + }); + + // Wait briefly to ensure server attempted to send readiness signal + sleep(Duration::from_millis(100)).await; + + // Server should still accept connections + let _stream = TcpStream::connect(addr).await.expect("connect failed"); +} diff --git a/tests/world.rs b/tests/world.rs index 228b7ddb..f1a22e98 100644 --- a/tests/world.rs +++ b/tests/world.rs @@ -54,10 +54,18 @@ impl PanicServer { impl Drop for PanicServer { fn drop(&mut self) { + use std::time::Duration; + if let Some(tx) = self.shutdown.take() { let _ = tx.send(()); } - let _ = futures::executor::block_on(&mut self.handle); + let timeout = Duration::from_secs(5); + let joined = futures::executor::block_on(tokio::time::timeout(timeout, &mut self.handle)); + match joined { + Ok(Ok(())) => {} + Ok(Err(e)) => eprintln!("PanicServer task panicked: {e:?}"), + Err(_) => eprintln!("PanicServer task did not shut down within timeout"), + } } }