diff --git a/Cargo.lock b/Cargo.lock index 12b03ee3..3cd970da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -573,6 +573,7 @@ version = "0.1.0" dependencies = [ "bincode", "bytes", + "rstest", "tokio", "wireframe", ] diff --git a/tests/lifecycle.rs b/tests/lifecycle.rs index b02a552d..b68d3566 100644 --- a/tests/lifecycle.rs +++ b/tests/lifecycle.rs @@ -1,40 +1,66 @@ -use std::sync::{ - Arc, - atomic::{AtomicUsize, Ordering}, +use std::{ + future::Future, + pin::Pin, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, }; -use tokio::io::duplex; -use wireframe::app::WireframeApp; +use bytes::BytesMut; +use wireframe::{ + app::{Envelope, Packet, WireframeApp}, + frame::{FrameProcessor, LengthPrefixedProcessor}, + serializer::{BincodeSerializer, Serializer}, +}; +use wireframe_testing::{processor, run_app_with_frame, run_with_duplex_server}; + +fn call_counting_callback( + counter: &Arc, + result: R, +) -> impl Fn(A) -> Pin + Send>> + Clone + 'static +where + A: Send + 'static, + R: Clone + Send + 'static, +{ + let counter = counter.clone(); + move |_| { + let counter = counter.clone(); + let result = result.clone(); + Box::pin(async move { + counter.fetch_add(1, Ordering::SeqCst); + result + }) + } +} + +fn wireframe_app_with_lifecycle_callbacks( + setup: &Arc, + teardown: &Arc, + state: u32, +) -> WireframeApp +where + E: Packet, +{ + let setup_cb = call_counting_callback(setup, state); + let teardown_cb = call_counting_callback(teardown, ()); + + WireframeApp::<_, _, E>::new_with_envelope() + .unwrap() + .on_connection_setup(move || setup_cb(())) + .unwrap() + .on_connection_teardown(teardown_cb) + .unwrap() +} #[tokio::test] async fn setup_and_teardown_callbacks_run() { let setup_count = Arc::new(AtomicUsize::new(0)); let teardown_count = Arc::new(AtomicUsize::new(0)); - let setup_clone = setup_count.clone(); - let teardown_clone = teardown_count.clone(); - - let app = WireframeApp::new() - .unwrap() - .on_connection_setup(move || { - let setup_clone = setup_clone.clone(); - async move { - setup_clone.fetch_add(1, Ordering::SeqCst); - 42u32 - } - }) - .unwrap() - .on_connection_teardown(move |state| { - let teardown_clone = teardown_clone.clone(); - async move { - assert_eq!(state, 42u32); - teardown_clone.fetch_add(1, Ordering::SeqCst); - } - }) - .unwrap(); + let app = wireframe_app_with_lifecycle_callbacks::(&setup_count, &teardown_count, 42); - let (_client, server) = duplex(64); - app.handle_connection(server).await; + run_with_duplex_server(app).await; assert_eq!(setup_count.load(Ordering::SeqCst), 1); assert_eq!(teardown_count.load(Ordering::SeqCst), 1); @@ -42,20 +68,14 @@ async fn setup_and_teardown_callbacks_run() { #[tokio::test] async fn setup_without_teardown_runs() { let setup_count = Arc::new(AtomicUsize::new(0)); - let setup_clone = setup_count.clone(); + let cb = call_counting_callback(&setup_count, ()); let app = WireframeApp::new() .unwrap() - .on_connection_setup(move || { - let setup_clone = setup_clone.clone(); - async move { - setup_clone.fetch_add(1, Ordering::SeqCst); - } - }) + .on_connection_setup(move || cb(())) .unwrap(); - let (_client, server) = duplex(64); - app.handle_connection(server).await; + run_with_duplex_server(app).await; assert_eq!(setup_count.load(Ordering::SeqCst), 1); } @@ -63,20 +83,54 @@ async fn setup_without_teardown_runs() { #[tokio::test] async fn teardown_without_setup_does_not_run() { let teardown_count = Arc::new(AtomicUsize::new(0)); - let teardown_clone = teardown_count.clone(); + let cb = call_counting_callback(&teardown_count, ()); let app = WireframeApp::new() .unwrap() - .on_connection_teardown(move |()| { - let teardown_clone = teardown_clone.clone(); - async move { - teardown_clone.fetch_add(1, Ordering::SeqCst); - } - }) + .on_connection_teardown(cb) .unwrap(); - let (_client, server) = duplex(64); - app.handle_connection(server).await; + run_with_duplex_server(app).await; assert_eq!(teardown_count.load(Ordering::SeqCst), 0); } + +#[derive(bincode::Encode, bincode::BorrowDecode, PartialEq, Debug)] +struct StateEnvelope { + id: u32, + msg: Vec, +} + +impl wireframe::app::Packet for StateEnvelope { + fn id(&self) -> u32 { self.id } + + fn into_parts(self) -> (u32, Vec) { (self.id, self.msg) } + + fn from_parts(id: u32, msg: Vec) -> Self { Self { id, msg } } +} + +#[tokio::test] +async fn helpers_propagate_connection_state() { + let setup = Arc::new(AtomicUsize::new(0)); + let teardown = Arc::new(AtomicUsize::new(0)); + + let app = wireframe_app_with_lifecycle_callbacks::(&setup, &teardown, 7) + .frame_processor(processor()) + .route(1, Arc::new(|_: &StateEnvelope| Box::pin(async {}))) + .unwrap(); + + let env = StateEnvelope { + id: 1, + msg: vec![1], + }; + let bytes = BincodeSerializer.serialize(&env).unwrap(); + let mut frame = BytesMut::new(); + LengthPrefixedProcessor::default() + .encode(&bytes, &mut frame) + .unwrap(); + + let out = run_app_with_frame(app, frame.to_vec()).await.unwrap(); + assert!(!out.is_empty()); + assert_eq!(setup.load(Ordering::SeqCst), 1); + assert_eq!(teardown.load(Ordering::SeqCst), 1); +} diff --git a/wireframe_testing/Cargo.toml b/wireframe_testing/Cargo.toml index de9f5b66..062b4a05 100644 --- a/wireframe_testing/Cargo.toml +++ b/wireframe_testing/Cargo.toml @@ -8,6 +8,4 @@ tokio = { version = "1", features = ["macros", "rt", "io-util"] } wireframe = { path = ".." } bincode = "^2.0" bytes = "^1.0" - -[dev-dependencies] rstest = "0.18.2" diff --git a/wireframe_testing/src/helpers.rs b/wireframe_testing/src/helpers.rs index 8a754068..a4957233 100644 --- a/wireframe_testing/src/helpers.rs +++ b/wireframe_testing/src/helpers.rs @@ -1,5 +1,6 @@ use bincode::config; use bytes::BytesMut; +use rstest::fixture; use tokio::io::{self, AsyncReadExt, AsyncWriteExt, duplex}; use wireframe::{ app::{Envelope, Packet, WireframeApp}, @@ -7,10 +8,19 @@ use wireframe::{ serializer::Serializer, }; +/// Create a default length-prefixed frame processor for tests. +#[fixture] +#[allow( + unused_braces, + reason = "Clippy is wrong here; this is not a redundant block" +)] +pub fn processor() -> LengthPrefixedProcessor { LengthPrefixedProcessor::default() } + pub trait TestSerializer: Serializer + FrameMetadata + Send + Sync + 'static { } + impl TestSerializer for T where T: Serializer + FrameMetadata + Send + Sync + 'static { @@ -172,3 +182,116 @@ where LengthPrefixedProcessor::default().encode(&bytes, &mut framed)?; drive_with_frame(app, framed.to_vec()).await } + +/// Run `app` with a single input `frame` using the default buffer capacity. +/// +/// # Errors +/// +/// Returns any I/O errors encountered while interacting with the in-memory +/// duplex stream. +pub async fn run_app_with_frame( + app: WireframeApp, + frame: Vec, +) -> io::Result> +where + S: TestSerializer, + C: Send + 'static, + E: Packet, +{ + run_app_with_frame_with_capacity(app, frame, DEFAULT_CAPACITY).await +} + +/// Drive `app` with a single frame using a duplex buffer of `capacity` bytes. +/// +/// # Errors +/// +/// Propagates any I/O errors from the in-memory connection. +/// +/// # Panics +/// +/// Panics if the spawned task running the application panics. +pub async fn run_app_with_frame_with_capacity( + app: WireframeApp, + frame: Vec, + capacity: usize, +) -> io::Result> +where + S: TestSerializer, + C: Send + 'static, + E: Packet, +{ + run_app_with_frames_with_capacity(app, vec![frame], capacity).await +} + +/// Run `app` with multiple input `frames` using the default buffer capacity. +/// +/// # Errors +/// +/// Returns any I/O errors encountered while interacting with the in-memory +/// duplex stream. +#[allow(dead_code)] +pub async fn run_app_with_frames( + app: WireframeApp, + frames: Vec>, +) -> io::Result> +where + S: TestSerializer, + C: Send + 'static, + E: Packet, +{ + run_app_with_frames_with_capacity(app, frames, DEFAULT_CAPACITY).await +} + +/// Drive `app` with multiple frames using a duplex buffer of `capacity` bytes. +/// +/// # Errors +/// +/// Propagates any I/O errors from the in-memory connection. +/// +/// # Panics +/// +/// Panics if the spawned task running the application panics. +pub async fn run_app_with_frames_with_capacity( + app: WireframeApp, + frames: Vec>, + capacity: usize, +) -> io::Result> +where + S: TestSerializer, + C: Send + 'static, + E: Packet, +{ + let (mut client, server) = duplex(capacity); + let server_task = tokio::spawn(async move { + app.handle_connection(server).await; + }); + + for frame in &frames { + client.write_all(frame).await?; + } + client.shutdown().await?; + + let mut buf = Vec::new(); + client.read_to_end(&mut buf).await?; + + server_task.await.unwrap(); + Ok(buf) +} + +/// Run `app` against an empty duplex stream. +/// +/// This helper drives the connection lifecycle without sending any frames, +/// ensuring setup and teardown callbacks execute. +/// +/// # Panics +/// +/// Panics if `handle_connection` fails. +pub async fn run_with_duplex_server(app: WireframeApp) +where + S: TestSerializer, + C: Send + 'static, + E: Packet, +{ + let (_client, server) = duplex(64); + app.handle_connection(server).await; +} diff --git a/wireframe_testing/src/lib.rs b/wireframe_testing/src/lib.rs index bec3748b..183bbf12 100644 --- a/wireframe_testing/src/lib.rs +++ b/wireframe_testing/src/lib.rs @@ -24,4 +24,10 @@ pub use helpers::{ drive_with_frames, drive_with_frames_mut, drive_with_frames_with_capacity, + processor, + run_app_with_frame, + run_app_with_frame_with_capacity, + run_app_with_frames, + run_app_with_frames_with_capacity, + run_with_duplex_server, };