diff --git a/docs/wireframe-testing-crate.md b/docs/wireframe-testing-crate.md index dc6a791f..3c00bb49 100644 --- a/docs/wireframe-testing-crate.md +++ b/docs/wireframe-testing-crate.md @@ -65,10 +65,13 @@ where These functions mirror the behaviour of `run_app_with_frame` and `run_app_with_frames` found in the repository’s test utilities. They create a -`tokio::io::duplex` stream, spawn the application as a background task, and -write the provided frame(s) to the client side of the stream. After the -application finishes processing, the helpers collect the bytes written back and -return them for inspection. +`tokio::io::duplex` stream, run the application on the server half, and write +the provided frame(s) to the client side. All helpers delegate to a single +internal function that handles this I/O plumbing, ensuring consistent +behaviour. Should the application panic, the panic message is returned as an +`io::Error` beginning with `server task failed`, helping surface failures in +tests. After the application finishes processing the input frames, the bytes +written back are collected for inspection. Any I/O errors surfaced by the duplex stream or failures while decoding a length prefix propagate through the returned `IoResult`. Malformed or truncated diff --git a/wireframe_testing/Cargo.toml b/wireframe_testing/Cargo.toml index 57643b84..9b3af948 100644 --- a/wireframe_testing/Cargo.toml +++ b/wireframe_testing/Cargo.toml @@ -12,3 +12,4 @@ rstest = "0.18.2" logtest = "2" log = "0.4" metrics-util = "0.20" +futures = "0.3" diff --git a/wireframe_testing/src/helpers.rs b/wireframe_testing/src/helpers.rs index bf187469..cc3c81da 100644 --- a/wireframe_testing/src/helpers.rs +++ b/wireframe_testing/src/helpers.rs @@ -6,7 +6,7 @@ use bincode::config; use bytes::BytesMut; use rstest::fixture; -use tokio::io::{self, AsyncReadExt, AsyncWriteExt, duplex}; +use tokio::io::{self, AsyncReadExt, AsyncWriteExt, DuplexStream, duplex}; use wireframe::{ app::{Envelope, Packet, WireframeApp}, frame::{FrameMetadata, FrameProcessor, LengthPrefixedProcessor}, @@ -33,6 +33,53 @@ impl TestSerializer for T where const DEFAULT_CAPACITY: usize = 4096; +async fn drive_internal( + server_fn: F, + frames: Vec>, + capacity: usize, +) -> io::Result> +where + F: FnOnce(DuplexStream) -> Fut, + Fut: std::future::Future + Send, +{ + let (mut client, server) = duplex(capacity); + + let server_fut = async { + use futures::FutureExt as _; + let result = std::panic::AssertUnwindSafe(server_fn(server)) + .catch_unwind() + .await; + match result { + Ok(_) => Ok(()), + Err(panic) => { + let msg = panic + .downcast_ref::<&str>() + .copied() + .or_else(|| panic.downcast_ref::().map(String::as_str)) + .unwrap_or(""); + Err(io::Error::new( + io::ErrorKind::Other, + format!("server task failed: {msg}"), + )) + } + } + }; + + let client_fut = async { + for frame in &frames { + client.write_all(frame).await?; + } + client.shutdown().await?; + + let mut buf = Vec::new(); + client.read_to_end(&mut buf).await?; + io::Result::Ok(buf) + }; + + let ((), buf) = tokio::try_join!(server_fut, client_fut)?; + Ok(buf) +} + /// Drive `app` with a single length-prefixed `frame` and return the bytes /// produced by the server. /// @@ -134,26 +181,12 @@ where C: Send + 'static, E: Packet, { - let (mut client, server) = duplex(capacity); - let server_task = tokio::spawn(async move { - app.handle_connection(server).await; - }); - - for frame in &frames { - client.write_all(frame).await?; - } - client.shutdown().await?; - - let mut buf = Vec::new(); - client.read_to_end(&mut buf).await?; - - match server_task.await { - Ok(_) => Ok(buf), - Err(e) => Err(io::Error::new( - io::ErrorKind::Other, - format!("server task failed: {e}"), - )), - } + drive_internal( + |server| async move { app.handle_connection(server).await }, + frames, + capacity, + ) + .await } /// Feed a single frame into a mutable `app`, allowing the instance to be reused @@ -248,22 +281,12 @@ where C: Send + 'static, E: Packet, { - let (mut client, server) = duplex(capacity); - - let server_fut = app.handle_connection(server); - let client_fut = async { - for frame in &frames { - client.write_all(frame).await?; - } - client.shutdown().await?; - - let mut buf = Vec::new(); - client.read_to_end(&mut buf).await?; - io::Result::Ok(buf) - }; - - let ((), buf) = tokio::join!(server_fut, client_fut); - buf + drive_internal( + |server| async { app.handle_connection(server).await }, + frames, + capacity, + ) + .await } /// Encode `msg` using bincode, frame it and drive `app`.