Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ where
workers: usize,
on_preamble_success: Option<PreambleCallback<T>>,
on_preamble_failure: Option<PreambleErrorCallback>,
/// Channel used to notify when the server is ready.
///
/// # Thread Safety
/// This sender is `Send` and may be moved between threads safely.
///
/// # Single-use Semantics
/// A `oneshot::Sender` can transmit only one readiness notification. After
/// sending, the sender is consumed and cannot be reused.
///
/// # Implications
/// Because only one notification may be sent, a new `ready_tx` must be
/// provided each time the server is started.
ready_tx: Option<oneshot::Sender<()>>,
_preamble: PhantomData<T>,
}
Expand Down Expand Up @@ -195,6 +207,9 @@ where

/// Configure a channel used to signal when the server is ready to accept
/// connections.
///
/// The provided `oneshot::Sender` is consumed after the first signal. Use a
/// fresh sender for each server run.
#[must_use]
pub fn ready_signal(mut self, tx: oneshot::Sender<()>) -> Self {
self.ready_tx = Some(tx);
Expand Down Expand Up @@ -325,7 +340,10 @@ where
{
let listener = self.listener.expect("`bind` must be called before `run`");
if let Some(tx) = self.ready_tx {
let _ = tx.send(());
let result = tx.send(());
if result.is_err() {
tracing::warn!("Failed to send readiness signal: receiver dropped");
}
}

let shutdown_token = CancellationToken::new();
Expand Down Expand Up @@ -388,19 +406,17 @@ async fn worker_task<F, T>(
let peer_addr = stream.peer_addr().ok();
t.spawn(async move {
use futures::FutureExt as _;
if let Err(panic) = std::panic::AssertUnwindSafe(
let fut = std::panic::AssertUnwindSafe(
process_stream(stream, factory, success, failure),
)
.catch_unwind()
.await
{
let panic_msg = if let Some(s) = panic.downcast_ref::<&str>() {
(*s).to_string()
} else if let Some(s) = panic.downcast_ref::<String>() {
s.clone()
} else {
format!("{panic:?}")
};
.catch_unwind();

if let Err(panic) = fut.await {
let panic_msg = panic
.downcast_ref::<&str>()
.copied()
.or_else(|| panic.downcast_ref::<String>().map(String::as_str))
.unwrap_or("<non-string panic>");
tracing::error!(panic = %panic_msg, ?peer_addr, "connection task panicked");
}
});
Expand Down
36 changes: 36 additions & 0 deletions tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,39 @@ fn workers_accepts_large_values() {
.workers(128);
assert_eq!(server.worker_count(), 128);
}

/// Ensure dropping the readiness receiver logs a warning and does not
/// prevent the server from accepting connections.
#[tokio::test]
async fn readiness_receiver_dropped() {
use tokio::{
net::TcpStream,
sync::oneshot,
time::{Duration, sleep},
};

let factory = || WireframeApp::new().expect("WireframeApp::new failed");
let server = WireframeServer::new(factory)
.workers(1)
.bind("127.0.0.1:0".parse().unwrap())
.unwrap();

let addr = server.local_addr().unwrap();
// Create channel and immediately drop receiver to force send failure
let (tx_ready, rx_ready) = oneshot::channel();
drop(rx_ready);

tokio::spawn(async move {
server
.ready_signal(tx_ready)
.run_with_shutdown(tokio::time::sleep(Duration::from_millis(200)))
.await
.unwrap();
});

// Wait briefly to ensure server attempted to send readiness signal
sleep(Duration::from_millis(100)).await;

// Server should still accept connections
let _stream = TcpStream::connect(addr).await.expect("connect failed");
}
91 changes: 57 additions & 34 deletions tests/world.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,18 @@
use std::net::SocketAddr;

use cucumber::World;
use tokio::{
net::TcpStream,
sync::oneshot::{self, Sender},
};
use tokio::{net::TcpStream, sync::oneshot};
use wireframe::{app::WireframeApp, server::WireframeServer};

#[derive(Debug, Default, World)]
pub struct PanicWorld {
pub addr: Option<SocketAddr>,
pub attempts: usize,
pub shutdown: Option<Sender<()>>,
pub handle: Option<tokio::task::JoinHandle<()>>,
#[derive(Debug)]
struct PanicServer {
addr: SocketAddr,
shutdown: Option<oneshot::Sender<()>>,
handle: tokio::task::JoinHandle<()>,
}

impl PanicWorld {
/// Start a server that panics during connection setup.
///
/// # Panics
/// Panics if binding the server fails or the server task fails.
pub async fn start_panic_server(&mut self) {
impl PanicServer {
async fn spawn() -> Self {
let factory = || {
WireframeApp::new()
.expect("Failed to create WireframeApp")
Expand All @@ -37,32 +29,66 @@ impl PanicWorld {
.bind("127.0.0.1:0".parse().expect("Failed to parse address"))
.expect("bind");

self.addr = Some(server.local_addr().expect("Failed to get server address"));
let (tx, rx) = oneshot::channel();
let (ready_tx, ready_rx) = oneshot::channel();
self.shutdown = Some(tx);
let addr = server.local_addr().expect("Failed to get server address");
let (tx_shutdown, rx_shutdown) = oneshot::channel();
let (tx_ready, rx_ready) = oneshot::channel();

self.handle = Some(tokio::spawn(async move {
let handle = tokio::spawn(async move {
server
.ready_signal(ready_tx)
.ready_signal(tx_ready)
.run_with_shutdown(async {
let _ = rx.await;
let _ = rx_shutdown.await;
})
.await
.expect("Server task failed");
}));
});
rx_ready.await.expect("Server did not signal ready");

ready_rx.await.expect("Server did not signal ready");
Self {
addr,
shutdown: Some(tx_shutdown),
handle,
}
}
}

impl Drop for PanicServer {
fn drop(&mut self) {
use std::time::Duration;

if let Some(tx) = self.shutdown.take() {
let _ = tx.send(());
}
let timeout = Duration::from_secs(5);
let joined = futures::executor::block_on(tokio::time::timeout(timeout, &mut self.handle));
match joined {
Ok(Ok(())) => {}
Ok(Err(e)) => eprintln!("PanicServer task panicked: {e:?}"),
Err(_) => eprintln!("PanicServer task did not shut down within timeout"),
}
}
}

#[derive(Debug, Default, World)]
pub struct PanicWorld {
server: Option<PanicServer>,
attempts: usize,
}
Comment thread
leynos marked this conversation as resolved.

impl PanicWorld {
/// Start a server that panics during connection setup.
///
/// # Panics
/// Panics if binding the server fails or the server task fails.
pub async fn start_panic_server(&mut self) { self.server.replace(PanicServer::spawn().await); }

/// Connect to the running server once.
///
/// # Panics
/// Panics if the server address is unknown or the connection fails.
pub async fn connect_once(&mut self) {
TcpStream::connect(self.addr.expect("Server address not set"))
.await
.expect("Failed to connect");
let addr = self.server.as_ref().expect("Server not started").addr;
TcpStream::connect(addr).await.expect("Failed to connect");
self.attempts += 1;
}

Expand All @@ -72,11 +98,8 @@ impl PanicWorld {
/// Panics if the connection attempts do not match the expected count.
pub async fn verify_and_shutdown(&mut self) {
assert_eq!(self.attempts, 2);
if let Some(tx) = self.shutdown.take() {
let _ = tx.send(());
}
if let Some(handle) = self.handle.take() {
handle.await.expect("Server task join failed");
}
// dropping PanicServer will shut it down
self.server.take();
tokio::task::yield_now().await;
}
}