Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 116 additions & 31 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,12 +372,45 @@ where
}
}

/// Runs a worker task that accepts incoming TCP connections and processes them asynchronously.
///
/// Each accepted connection is handled in a separate task, with optional callbacks for preamble
/// decode success or failure. The worker listens for shutdown signals to terminate gracefully.
/// Accept errors are retried with exponential backoff.
async fn worker_task<F, T>(
/// Spawn a task to process a single TCP connection, logging and discarding any
/// panics from the task.
fn spawn_connection_task<F, T>(
stream: tokio::net::TcpStream,
factory: F,
on_success: Option<PreambleCallback<T>>,
on_failure: Option<PreambleErrorCallback>,
tracker: &TaskTracker,
) where
F: Fn() -> WireframeApp + Send + Sync + Clone + 'static,
T: Preamble,
{
let peer_addr = match stream.peer_addr() {
Ok(addr) => Some(addr),
Err(e) => {
tracing::warn!(error = %e, "Failed to retrieve peer address");
None
}
};
tracker.spawn(async move {
use futures::FutureExt as _;
let fut =
std::panic::AssertUnwindSafe(process_stream(stream, factory, on_success, on_failure))
.catch_unwind();

if let Err(panic) = fut.await {
let panic_msg = panic
.downcast_ref::<&str>()
.copied()
.or_else(|| panic.downcast_ref::<String>().map(String::as_str))
.unwrap_or("<non-string panic>");
tracing::error!(panic = %panic_msg, ?peer_addr, "connection task panicked");
}
});
}

/// Accept incoming connections until `shutdown` is triggered, retrying on
/// errors with exponential backoff.
async fn accept_loop<F, T>(
listener: Arc<TcpListener>,
factory: F,
on_success: Option<PreambleCallback<T>>,
Expand All @@ -386,7 +419,6 @@ async fn worker_task<F, T>(
tracker: TaskTracker,
) where
F: Fn() -> WireframeApp + Send + Sync + Clone + 'static,
// `Preamble` ensures `T` supports borrowed decoding.
T: Preamble,
{
let mut delay = Duration::from_millis(10);
Expand All @@ -398,28 +430,13 @@ async fn worker_task<F, T>(

res = listener.accept() => match res {
Ok((stream, _)) => {
let success = on_success.clone();
let failure = on_failure.clone();
let factory = factory.clone();
let t = tracker.clone();
// Capture peer address for better error context
let peer_addr = stream.peer_addr().ok();
t.spawn(async move {
use futures::FutureExt as _;
let fut = std::panic::AssertUnwindSafe(
process_stream(stream, factory, success, failure),
)
.catch_unwind();

if let Err(panic) = fut.await {
let panic_msg = panic
.downcast_ref::<&str>()
.copied()
.or_else(|| panic.downcast_ref::<String>().map(String::as_str))
.unwrap_or("<non-string panic>");
tracing::error!(panic = %panic_msg, ?peer_addr, "connection task panicked");
}
});
spawn_connection_task(
stream,
factory.clone(),
on_success.clone(),
on_failure.clone(),
&tracker,
);
delay = Duration::from_millis(10);
}
Err(e) => {
Expand All @@ -432,6 +449,25 @@ async fn worker_task<F, T>(
}
}

/// Worker task that delegates connection acceptance to `accept_loop`.
///
/// This function serves as an entry point for worker tasks, passing all parameters
/// to `accept_loop` which handles the actual connection acceptance and processing.
async fn worker_task<F, T>(
listener: Arc<TcpListener>,
factory: F,
on_success: Option<PreambleCallback<T>>,
on_failure: Option<PreambleErrorCallback>,
shutdown: CancellationToken,
tracker: TaskTracker,
) where
F: Fn() -> WireframeApp + Send + Sync + Clone + 'static,
// `Preamble` ensures `T` supports borrowed decoding.
T: Preamble,
{
accept_loop(listener, factory, on_success, on_failure, shutdown, tracker).await;
}

/// Processes an incoming TCP stream by decoding a preamble and dispatching the connection to a
/// `WireframeApp`.
///
Expand Down Expand Up @@ -857,14 +893,14 @@ mod tests {

#[rstest]
#[tokio::test]
async fn test_worker_task_shutdown_signal(
async fn test_accept_loop_shutdown_signal(
factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static,
) {
let token = CancellationToken::new();
let tracker = TaskTracker::new();
let listener = Arc::new(TcpListener::bind("127.0.0.1:0").await.unwrap());

tracker.spawn(worker_task::<_, ()>(
tracker.spawn(accept_loop::<_, ()>(
listener,
factory,
None,
Expand Down Expand Up @@ -922,6 +958,55 @@ mod tests {
assert!(cfg!(debug_assertions));
}

/// Panics in connection handlers are logged and do not tear down the worker.
#[rstest]
#[traced_test]
#[tokio::test]
async fn spawn_connection_task_logs_panic(
factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static,
) {
let app_factory = move || {
factory()
.on_connection_setup(|| async { panic!("boom") })
.unwrap()
};
let tracker = TaskTracker::new();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();

let handle = tokio::spawn({
let tracker = tracker.clone();
async move {
let (stream, _) = listener.accept().await.unwrap();
spawn_connection_task::<_, ()>(stream, app_factory, None, None, &tracker);
tracker.close();
tracker.wait().await;
}
});

let client = TcpStream::connect(addr).await.unwrap();
let peer_addr = client.local_addr().unwrap();
client.writable().await.unwrap();
client.try_write(&[0; 8]).unwrap();
drop(client);

handle.await.unwrap();

tokio::task::yield_now().await;

logs_assert(|lines: &[&str]| {
lines
.iter()
.find(|line| {
line.contains("connection task panicked")
&& line.contains("panic=boom")
&& line.contains(&format!("peer_addr=Some({peer_addr})"))
})
.map(|_| ())
.ok_or_else(|| "panic log not found".to_string())
});
}

/// Ensure the server survives panicking connection tasks.
///
/// The test spawns a server with a connection setup callback that
Expand Down