diff --git a/Makefile b/Makefile index c252820b..284f284a 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help all clean test build release lint fmt check-fmt markdownlint nixie +.PHONY: help all clean test build release lint fmt check-fmt markdownlint nixie typecheck CRATE ?= wireframe CARGO ?= cargo @@ -22,6 +22,9 @@ test-bdd: ## Run rstest-bdd tests only test: ## Run all tests (bdd + unit/integration) RUSTFLAGS="-D warnings" $(CARGO) test --all-targets --all-features $(BUILD_JOBS) +typecheck: ## Run a workspace typecheck + RUSTFLAGS="-D warnings" $(CARGO) check --all-targets --all-features $(BUILD_JOBS) + # will match target/debug/libmy_library.rlib and target/release/libmy_library.rlib target/%/lib$(CRATE).rlib: ## Build library in debug or release $(CARGO) build $(BUILD_JOBS) \ diff --git a/docs/v0-1-0-to-v0-2-0-migration-guide.md b/docs/v0-1-0-to-v0-2-0-migration-guide.md index 8ceb27f8..8d9a72dd 100644 --- a/docs/v0-1-0-to-v0-2-0-migration-guide.md +++ b/docs/v0-1-0-to-v0-2-0-migration-guide.md @@ -1,7 +1,7 @@ # v0.1.0 to v0.2.0 migration guide -This guide highlights user-facing changes required when upgrading from v0.1.0 -to v0.2.0. +This guide summarizes the breaking changes required when migrating from +wireframe v0.1.0 to v0.2.0. ## Configuration builder naming update @@ -16,3 +16,22 @@ Update any references accordingly, including documentation and code examples. ```rust let config = BackoffConfig::normalized(...); ``` + +## Payload accessors + +The consuming payload accessors were renamed to follow Rust idioms. + +- `PacketParts::payload(self)` has been removed. Use + `PacketParts::into_payload(self)` instead. +- `FragmentParts::payload(self)` has been removed. Use + `FragmentParts::into_payload(self)` instead. + +```rust +// Before +let packet_payload = packet_parts.payload(); +let fragment_payload = fragment_parts.payload(); + +// After +let packet_payload = packet_parts.into_payload(); +let fragment_payload = fragment_parts.into_payload(); +``` diff --git a/src/app/envelope.rs b/src/app/envelope.rs index 85b843c6..ed26804f 100644 --- a/src/app/envelope.rs +++ b/src/app/envelope.rs @@ -55,7 +55,7 @@ use crate::{ /// Self { /// id: parts.id(), /// correlation_id: parts.correlation_id(), -/// payload: parts.payload(), +/// payload: parts.into_payload(), /// timestamp: 0, /// } /// } @@ -192,10 +192,10 @@ impl PacketParts { /// use wireframe::app::PacketParts; /// /// let parts = PacketParts::new(1, None, vec![7, 8]); - /// assert_eq!(parts.payload(), vec![7, 8]); + /// assert_eq!(parts.into_payload(), vec![7, 8]); /// ``` #[must_use] - pub fn payload(self) -> Vec { self.payload } + pub fn into_payload(self) -> Vec { self.payload } /// Ensure a correlation identifier is present, inheriting from `source` if missing. /// @@ -243,7 +243,7 @@ impl From for Envelope { fn from(p: PacketParts) -> Self { let id = p.id(); let correlation_id = p.correlation_id(); - let payload = p.payload(); + let payload = p.into_payload(); Envelope::new(id, correlation_id, payload) } } @@ -252,14 +252,14 @@ impl From for Envelope { impl Fragmentable for T { fn into_fragment_parts(self) -> FragmentParts { let parts = self.into_parts(); - FragmentParts::new(parts.id(), parts.correlation_id(), parts.payload()) + FragmentParts::new(parts.id(), parts.correlation_id(), parts.into_payload()) } fn from_fragment_parts(parts: FragmentParts) -> Self { T::from_parts(PacketParts::new( parts.id(), parts.correlation_id(), - parts.payload(), + parts.into_payload(), )) } } diff --git a/src/app/fragmentation_state.rs b/src/app/fragmentation_state.rs index 80b7b2ba..fb15779d 100644 --- a/src/app/fragmentation_state.rs +++ b/src/app/fragmentation_state.rs @@ -54,7 +54,7 @@ impl FragmentationState { let parts = packet.into_parts(); let id = parts.id(); let correlation = parts.correlation_id(); - let payload = parts.payload(); + let payload = parts.into_payload(); match decode_fragment_payload(&payload) { Ok(Some((header, fragment_payload))) => { diff --git a/src/client/tests/messaging.rs b/src/client/tests/messaging.rs index 9eacdf12..bc1f90ec 100644 --- a/src/client/tests/messaging.rs +++ b/src/client/tests/messaging.rs @@ -153,7 +153,7 @@ async fn receive_envelope_returns_envelope_with_correlation_id() { "response should have the same correlation ID" ); assert_eq!(response.id(), 42); - assert_eq!(response.into_parts().payload(), &[1, 2, 3]); + assert_eq!(response.into_parts().into_payload(), &[1, 2, 3]); server.abort(); } @@ -295,7 +295,7 @@ async fn round_trip_with_various_payload_sizes(#[case] payload: Vec) { .await .expect("call should succeed"); - assert_eq!(response.into_parts().payload(), payload.as_slice()); + assert_eq!(response.into_parts().into_payload(), payload.as_slice()); server.abort(); } diff --git a/src/connection/test_support.rs b/src/connection/test_support.rs index 6898d271..b3f57bf9 100644 --- a/src/connection/test_support.rs +++ b/src/connection/test_support.rs @@ -27,7 +27,7 @@ impl Packet for u8 { fn into_parts(self) -> PacketParts { PacketParts::new(0, None, vec![self]) } fn from_parts(parts: PacketParts) -> Self { - parts.payload().first().copied().unwrap_or_default() + parts.into_payload().first().copied().unwrap_or_default() } } @@ -36,7 +36,7 @@ impl Packet for Vec { fn into_parts(self) -> PacketParts { PacketParts::new(0, None, self) } - fn from_parts(parts: PacketParts) -> Self { parts.payload() } + fn from_parts(parts: PacketParts) -> Self { parts.into_payload() } } /// Build a connection actor configured with the supplied protocol hooks. diff --git a/src/fragment/packet.rs b/src/fragment/packet.rs index e8c21068..2d31c65d 100644 --- a/src/fragment/packet.rs +++ b/src/fragment/packet.rs @@ -38,8 +38,17 @@ impl FragmentParts { pub const fn correlation_id(&self) -> Option { self.correlation_id } /// Consume the parts and return the raw payload bytes. + /// + /// # Examples + /// + /// ``` + /// use wireframe::fragment::FragmentParts; + /// + /// let parts = FragmentParts::new(1, None, vec![7, 8]); + /// assert_eq!(parts.into_payload(), vec![7, 8]); + /// ``` #[must_use] - pub fn payload(self) -> Vec { self.payload } + pub fn into_payload(self) -> Vec { self.payload } } /// A packet that can be decomposed into parts and reconstructed for fragmentation. @@ -74,7 +83,7 @@ pub fn fragment_packet( let parts = packet.into_fragment_parts(); let id = parts.id(); let correlation = parts.correlation_id(); - let payload = parts.payload(); + let payload = parts.into_payload(); let batch = fragmenter.fragment_bytes(&payload)?; if !batch.is_fragmented() { diff --git a/src/middleware.rs b/src/middleware.rs index f5eca6f7..b1fc6509 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -333,7 +333,7 @@ impl Service for RouteService { (self.handler.as_ref())(&env).await; let parts = env.into_parts(); let correlation_id = parts.correlation_id(); - let payload = parts.payload(); + let payload = parts.into_payload(); Ok(ServiceResponse::new(payload, correlation_id)) } } diff --git a/tests/common/fragment_helpers.rs b/tests/common/fragment_helpers.rs index 4d18887e..4bd46f34 100644 --- a/tests/common/fragment_helpers.rs +++ b/tests/common/fragment_helpers.rs @@ -127,7 +127,7 @@ pub fn fragment_envelope(env: &Envelope, fragmenter: &Fragmenter) -> TestResult< let parts = env.clone().into_parts(); let id = parts.id(); let correlation = parts.correlation_id(); - let payload = parts.payload(); + let payload = parts.into_payload(); if payload.len() <= fragmenter.max_fragment_size().get() { return Ok(vec![Envelope::new(id, correlation, payload)]); @@ -180,7 +180,7 @@ pub async fn read_reassembled_response( while let Some(frame) = client.next().await { let bytes = frame?; let (env, _) = serializer.deserialize::(&bytes)?; - let payload = env.into_parts().payload(); + let payload = env.into_parts().into_payload(); match decode_fragment_payload(&payload)? { Some((header, fragment)) => { if let Some(message) = reassembler.push(header, fragment)? { @@ -204,7 +204,7 @@ pub fn make_handler(sender: &mpsc::UnboundedSender>) -> Handler TestResult { let mut reassembler = Reassembler::new(cfg.max_message_size, cfg.reassembly_timeout); let mut assembled: Option> = None; for env in out { - let payload = env.into_parts().payload(); + let payload = env.into_parts().into_payload(); let Some((header, frag)) = decode_fragment_payload(&payload)? else { assembled = Some(payload); continue; @@ -109,7 +109,7 @@ async fn connection_actor_passes_through_small_outbound_frames_unfragmented() -> .into_iter() .next() .ok_or("expected single frame but none found")?; - let payload_out = only.into_parts().payload(); + let payload_out = only.into_parts().into_payload(); match decode_fragment_payload(&payload_out)? { None => {} Some(_) => return Err("expected unfragmented payload".into()), diff --git a/tests/example_codecs.rs b/tests/example_codecs.rs index 49618113..ddc9d1fc 100644 --- a/tests/example_codecs.rs +++ b/tests/example_codecs.rs @@ -156,6 +156,6 @@ async fn hotline_codec_round_trips_through_app() { .deserialize::(&response_frame.payload) .expect("deserialize response"); assert_eq!(response_env.correlation_id(), Some(7)); - let response_payload = response_env.into_parts().payload(); + let response_payload = response_env.into_parts().into_payload(); assert_eq!(response_payload, b"ping".to_vec()); } diff --git a/tests/fragment_transport/rejection.rs b/tests/fragment_transport/rejection.rs index 632bab31..2947aae3 100644 --- a/tests/fragment_transport/rejection.rs +++ b/tests/fragment_transport/rejection.rs @@ -123,7 +123,7 @@ fn mutate_malformed_header(mut fragments: Vec) -> TestResult(&response_frame.payload) .expect("deserialize response"); assert_eq!(response_env.correlation_id(), Some(7)); - let response_payload = response_env.into_parts().payload(); + let response_payload = response_env.into_parts().into_payload(); assert_eq!(response_payload, b"ping".to_vec()); } diff --git a/tests/middleware_order.rs b/tests/middleware_order.rs index d6b47021..4ec577b6 100644 --- a/tests/middleware_order.rs +++ b/tests/middleware_order.rs @@ -94,7 +94,7 @@ async fn middleware_applied_in_reverse_order() -> TestResult<()> { let (resp, _) = serializer.deserialize::(first)?; let parts = wireframe::app::Packet::into_parts(resp); let correlation_id = parts.correlation_id(); - let payload = parts.payload(); + let payload = parts.into_payload(); assert_eq!( payload, [b'X', b'A', b'B', b'B', b'A'], diff --git a/tests/multi_packet_streaming.rs b/tests/multi_packet_streaming.rs index bf493867..940ebcaa 100644 --- a/tests/multi_packet_streaming.rs +++ b/tests/multi_packet_streaming.rs @@ -137,7 +137,10 @@ async fn client_receives_multi_packet_stream_with_terminator() -> TestResult<()> let out = harness.run().await?; assert_eq!(out.len(), 3, "expected two frames plus terminator"); - let payloads: Vec> = out.iter().map(|frame| parts(frame).payload()).collect(); + let payloads: Vec> = out + .iter() + .map(|frame| parts(frame).into_payload()) + .collect(); assert_eq!(payloads.first(), Some(&vec![1]), "first payload mismatch"); assert_eq!( payloads.get(1), diff --git a/tests/packet_parts.rs b/tests/packet_parts.rs index c025c96d..343f08b8 100644 --- a/tests/packet_parts.rs +++ b/tests/packet_parts.rs @@ -11,7 +11,7 @@ fn envelope_from_parts_round_trip() { let parts = rebuilt.into_parts(); let id = parts.id(); let correlation_id = parts.correlation_id(); - let payload = parts.payload(); + let payload = parts.into_payload(); assert_eq!(id, 2); assert_eq!(correlation_id, Some(5)); assert_eq!(payload, vec![1, 2]); diff --git a/tests/preamble.rs b/tests/preamble.rs index fc5a9edf..1c66b419 100644 --- a/tests/preamble.rs +++ b/tests/preamble.rs @@ -1,568 +1,15 @@ #![cfg(not(loom))] //! Tests for connection preamble reading. -use std::{ - error::Error, - io, - sync::{Arc, Mutex}, -}; - -use bincode::error::DecodeError; -use futures::future::BoxFuture; mod common; -use common::{TestResult, factory, unused_listener}; -use rstest::rstest; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt, duplex}, - net::TcpStream, - sync::oneshot, - time::{Duration, timeout}, -}; -use wireframe::{app::WireframeApp, preamble::read_preamble, server::WireframeServer}; - -#[derive(Debug, Clone, PartialEq, Eq, bincode::Encode, bincode::Decode)] -struct HotlinePreamble { - /// Should always be `b"TRTPHOTL"`. - magic: [u8; 8], - /// Minimum server version this client supports. - min_version: u16, - /// Client version. - client_version: u16, -} - -impl HotlinePreamble { - const MAGIC: [u8; 8] = *b"TRTPHOTL"; - - fn validate(&self) -> Result<(), DecodeError> { - if self.magic != Self::MAGIC { - return Err(DecodeError::Other("invalid hotline preamble")); - } - Ok(()) - } -} - -/// Create a server configured with `HotlinePreamble` handlers. -fn server_with_handlers( - factory: F, - success: S, - failure: E, -) -> WireframeServer -where - F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, - S: for<'a> Fn(&'a HotlinePreamble, &'a mut TcpStream) -> BoxFuture<'a, io::Result<()>> - + Send - + Sync - + 'static, - E: for<'a> Fn(&'a DecodeError, &'a mut TcpStream) -> BoxFuture<'a, io::Result<()>> - + Send - + Sync - + 'static, -{ - WireframeServer::new(factory) - .workers(1) - .with_preamble::() - .on_preamble_decode_success(success) - .on_preamble_decode_failure(failure) -} - -/// Run the provided server while executing `block`. -async fn with_running_server(server: WireframeServer, block: B) -> TestResult -where - F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, - T: wireframe::preamble::Preamble, - Fut: std::future::Future, - B: FnOnce(std::net::SocketAddr) -> Fut, -{ - let listener = unused_listener(); - let server = server.bind_existing_listener(listener)?; - let addr = server - .local_addr() - .ok_or_else(|| Box::::from("server missing local addr"))?; - let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); - let handle = tokio::spawn(async move { - server - .run_with_shutdown(async { - let _ = shutdown_rx.await; - }) - .await - }); - - block(addr).await?; - let _ = shutdown_tx.send(()); - let run_result = handle.await?; - run_result?; - Ok(()) -} - -#[tokio::test] -#[expect( - clippy::panic_in_result_fn, - reason = "asserts provide clearer diagnostics in tests" -)] -async fn parse_valid_preamble() -> TestResult { - let (mut client, mut server) = duplex(64); - let bytes = b"TRTPHOTL\x00\x01\x00\x02"; - client.write_all(bytes).await?; - client.shutdown().await?; - let (p, _) = read_preamble::<_, HotlinePreamble>(&mut server).await?; - p.validate()?; - assert_eq!(p.magic, HotlinePreamble::MAGIC, "preamble magic mismatch"); - assert_eq!(p.min_version, 1, "preamble minimum version mismatch"); - assert_eq!(p.client_version, 2, "preamble client version mismatch"); - Ok(()) -} - -#[tokio::test] -#[expect( - clippy::panic_in_result_fn, - reason = "asserts provide clearer diagnostics in tests" -)] -async fn invalid_magic_is_error() -> TestResult { - let (mut client, mut server) = duplex(64); - let bytes = b"WRONGMAG\x00\x01\x00\x02"; - client.write_all(bytes).await?; - client.shutdown().await?; - let (preamble, _) = read_preamble::<_, HotlinePreamble>(&mut server).await?; - assert!( - preamble.validate().is_err(), - "invalid magic should fail validation" - ); - Ok(()) -} - -#[derive(Clone, Copy)] -enum ExpectedCallback { - Success, - Failure, -} - -#[rstest] -#[case(b"TRTPHOTL\x00\x01\x00\x02", ExpectedCallback::Success)] -#[case(b"TRTPHOT", ExpectedCallback::Failure)] -#[tokio::test] -async fn server_triggers_expected_callback( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - #[case] bytes: &'static [u8], - #[case] expected: ExpectedCallback, -) -> TestResult { - let (success_tx, success_rx) = tokio::sync::oneshot::channel::(); - let (failure_tx, failure_rx) = tokio::sync::oneshot::channel::<()>(); - let success_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(success_tx))); - let failure_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(failure_tx))); - let server = server_with_handlers( - factory, - { - let success_tx = success_tx.clone(); - move |p, _| { - let success_tx = success_tx.clone(); - let clone = p.clone(); - Box::pin(async move { - if let Some(tx) = take_sender_io(&success_tx)? { - let _ = tx.send(clone); - } - Ok::<(), io::Error>(()) - }) - } - }, - { - let failure_tx = failure_tx.clone(); - move |_, _| { - let failure_tx = failure_tx.clone(); - Box::pin(async move { - if let Some(tx) = take_sender_io(&failure_tx)? { - let _ = tx.send(()); - } - Ok::<(), io::Error>(()) - }) - } - }, - ); - - with_running_server(server, |addr| async move { - let mut stream = TcpStream::connect(addr).await?; - stream.write_all(bytes).await?; - stream.shutdown().await?; - Ok(()) - }) - .await?; - - match expected { - ExpectedCallback::Success => { - let preamble = timeout(Duration::from_secs(1), success_rx).await??; - assert_eq!(preamble.magic, HotlinePreamble::MAGIC); - assert!( - timeout(Duration::from_millis(500), failure_rx) - .await - .is_err() - ); - } - ExpectedCallback::Failure => { - timeout(Duration::from_secs(1), failure_rx).await??; - assert!( - timeout(Duration::from_millis(500), success_rx) - .await - .is_err() - ); - } - } - Ok(()) -} - -#[rstest] -#[tokio::test] -async fn success_callback_can_write_response( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) -> TestResult { - let server = server_with_handlers( - factory, - |_, stream| { - Box::pin(async move { - stream.write_all(b"ACK").await?; - stream.flush().await?; - Ok::<(), io::Error>(()) - }) - }, - |_, _| Box::pin(async { Ok::<(), io::Error>(()) }), - ); - - with_running_server(server, |addr| async move { - let mut stream = TcpStream::connect(addr).await?; - let bytes = b"TRTPHOTL\x00\x01\x00\x02"; - stream.write_all(bytes).await?; - let mut buf = [0u8; 3]; - stream.read_exact(&mut buf).await?; - assert_eq!(&buf, b"ACK"); - Ok(()) - }) - .await?; - Ok(()) -} - -#[rstest] -#[tokio::test] -async fn failure_callback_can_write_response( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) -> TestResult { - let (failure_holder, failure_rx) = channel_holder(); - let server = WireframeServer::new(factory) - .with_preamble::() - .on_preamble_decode_failure(move |_, stream| { - let failure_holder = failure_holder.clone(); - Box::pin(async move { - stream.write_all(b"ERR").await?; - stream.flush().await?; - if let Some(tx) = take_sender_io(&failure_holder)? { - let _ = tx.send(()); - } - Ok::<(), io::Error>(()) - }) - }); - - with_running_server(server, |addr| async move { - let mut stream = TcpStream::connect(addr).await?; - stream.write_all(b"BAD").await?; - stream.shutdown().await?; - let mut buf = [0u8; 3]; - let read = timeout(Duration::from_secs(1), stream.read_exact(&mut buf)).await; - let result = read?; - result?; - assert_eq!(&buf, b"ERR"); - recv_within(Duration::from_millis(200), failure_rx).await?; - Ok(()) - }) - .await?; - Ok(()) -} - -#[rstest] -#[tokio::test] -async fn preamble_timeout_invokes_failure_handler_and_closes_connection( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) -> TestResult { - let (failure_holder, failure_rx) = channel_holder(); - let server = WireframeServer::new(factory) - .with_preamble::() - .preamble_timeout(Duration::from_millis(50)) - .on_preamble_decode_failure(move |err, stream| { - let failure_holder = failure_holder.clone(); - Box::pin(async move { - assert!( - matches!( - err, - DecodeError::Io { inner, .. } - if inner.kind() == io::ErrorKind::TimedOut - ), - "expected timed out error, got {err:?}" - ); - stream.write_all(b"ERR").await?; - stream.flush().await?; - stream.shutdown().await?; - if let Some(tx) = take_sender_io(&failure_holder)? { - let _ = tx.send(()); - } - Ok::<(), io::Error>(()) - }) - }); - - with_running_server(server, |addr| async move { - let mut stream = TcpStream::connect(addr).await?; - recv_within(Duration::from_secs(1), failure_rx).await?; - let mut buf = [0u8; 3]; - timeout(Duration::from_millis(500), stream.read_exact(&mut buf)).await??; - assert_eq!(&buf, b"ERR"); - let mut eof = [0u8; 1]; - let read = timeout(Duration::from_millis(200), stream.read(&mut eof)).await; - match read? { - Ok(0) => {} - Ok(n) => panic!("expected connection to close, read {n} bytes"), - Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {} - Err(e) => panic!("unexpected read error: {e:?}"), - } - Ok(()) - }) - .await?; - Ok(()) -} - -#[rstest] -#[tokio::test] -async fn success_handler_runs_without_failure_handler( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) -> TestResult { - let (success_tx, success_rx) = tokio::sync::oneshot::channel::(); - let success_tx = Arc::new(Mutex::new(Some(success_tx))); - let server = WireframeServer::new(factory) - .with_preamble::() - .on_preamble_decode_success({ - let success_tx = success_tx.clone(); - move |p, _| { - let success_tx = success_tx.clone(); - let preamble = p.clone(); - Box::pin(async move { - if let Some(tx) = take_sender_io(&success_tx)? { - let _ = tx.send(preamble); - } - Ok::<(), io::Error>(()) - }) - } - }); - - with_running_server(server, |addr| async move { - let mut stream = TcpStream::connect(addr).await?; - let bytes = b"TRTPHOTL\x00\x01\x00\x02"; - stream.write_all(bytes).await?; - stream.shutdown().await?; - let preamble = recv_within(Duration::from_secs(1), success_rx).await?; - assert_eq!(preamble.magic, HotlinePreamble::MAGIC); - Ok(()) - }) - .await?; - Ok(()) -} - -#[rstest] -#[tokio::test] -async fn preamble_timeout_allows_timely_preamble( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) -> TestResult { - let (success_holder, success_rx) = channel_holder(); - let (failure_holder, failure_rx) = channel_holder(); - let server = WireframeServer::new(factory) - .with_preamble::() - .preamble_timeout(Duration::from_millis(150)) - .on_preamble_decode_success({ - let success_holder = success_holder.clone(); - move |p, stream| { - let success_holder = success_holder.clone(); - let clone = p.clone(); - Box::pin(async move { - if let Some(tx) = take_sender_io(&success_holder)? { - let _ = tx.send(()); - } - stream.write_all(b"OK").await?; - stream.flush().await?; - // keep connection open by not shutting down here - assert_eq!(clone.magic, HotlinePreamble::MAGIC); - Ok::<(), io::Error>(()) - }) - } - }) - .on_preamble_decode_failure({ - let failure_holder = failure_holder.clone(); - move |_, _| { - let failure_holder = failure_holder.clone(); - Box::pin(async move { - if let Some(tx) = take_sender_io(&failure_holder)? { - let _ = tx.send(()); - } - Ok::<(), io::Error>(()) - }) - } - }); - - with_running_server(server, |addr| async move { - let mut stream = TcpStream::connect(addr).await?; - let bytes = b"TRTPHOTL\x00\x01\x00\x02"; - stream.write_all(bytes).await?; - - recv_within(Duration::from_millis(200), success_rx).await?; - assert!( - timeout(Duration::from_millis(150), failure_rx) - .await - .is_err(), - "failure handler should not fire for timely preamble" - ); - - let mut buf = [0u8; 2]; - stream.read_exact(&mut buf).await?; - assert_eq!(&buf, b"OK"); - Ok(()) - }) - .await?; - Ok(()) -} - -#[rstest] -#[tokio::test] -async fn failure_handler_error_is_logged_and_connection_closes( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) -> TestResult { - let (failure_holder, failure_rx) = channel_holder(); - let server = WireframeServer::new(factory) - .with_preamble::() - .on_preamble_decode_failure(move |_, _| { - let failure_holder = failure_holder.clone(); - Box::pin(async move { - if let Some(tx) = take_sender_io(&failure_holder)? { - let _ = tx.send(()); - } - Err::<(), io::Error>(io::Error::other("boom")) - }) - }); - - with_running_server(server, |addr| async move { - let mut stream = TcpStream::connect(addr).await?; - stream.write_all(b"BAD").await?; - stream.shutdown().await?; - - recv_within(Duration::from_secs(1), failure_rx).await?; - - let mut buf = [0u8; 1]; - let read = timeout(Duration::from_millis(200), stream.read(&mut buf)).await; - match read? { - Ok(0) => {} - Ok(n) => panic!("expected connection close, read {n} bytes"), - Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {} - Err(e) => panic!("unexpected read error: {e:?}"), - } - Ok(()) - }) - .await?; - Ok(()) -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, bincode::Encode, bincode::Decode)] -struct OtherPreamble(u8); - -type Holder = Arc>>>; - -fn channel_holder() -> (Holder, oneshot::Receiver<()>) { - let (tx, rx) = oneshot::channel(); - (Arc::new(Mutex::new(Some(tx))), rx) -} - -fn take_sender_io(holder: &Mutex>) -> io::Result> { - holder - .lock() - .map_err(|e| io::Error::other(format!("lock poisoned: {e}"))) - .map(|mut guard| guard.take()) -} - -async fn recv_within(duration: Duration, rx: oneshot::Receiver) -> TestResult { - Ok(timeout(duration, rx).await??) -} - -fn success_cb

( - holder: Arc>>>, -) -> impl for<'a> Fn(&'a P, &'a mut TcpStream) -> BoxFuture<'a, io::Result<()>> + Send + Sync + 'static -{ - move |_, _| { - let holder = holder.clone(); - Box::pin(async move { - if let Some(tx) = take_sender_io(&holder)? { - let _ = tx.send(()); - } - Ok(()) - }) - } -} - -fn failure_cb( - holder: Arc>>>, -) -> impl for<'a> Fn(&'a DecodeError, &'a mut TcpStream) -> BoxFuture<'a, io::Result<()>> -+ Send -+ Sync -+ 'static { - move |_, _| { - let holder = holder.clone(); - Box::pin(async move { - if let Some(tx) = take_sender_io(&holder)? { - let _ = tx.send(()); - } - Ok(()) - }) - } -} - -#[rstest] -#[tokio::test] -async fn callbacks_dropped_when_overriding_preamble( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) -> TestResult { - let (hotline_success, hotline_success_rx) = channel_holder(); - let (hotline_failure, hotline_failure_rx) = channel_holder(); - let (other_success, other_success_rx) = channel_holder(); - let (other_failure, other_failure_rx) = channel_holder(); - - let server = WireframeServer::new(factory.clone()) - .with_preamble::() - .on_preamble_decode_success(success_cb::(hotline_success.clone())) - .on_preamble_decode_failure(failure_cb(hotline_failure.clone())) - .with_preamble::() - .on_preamble_decode_success(success_cb::(other_success.clone())) - .on_preamble_decode_failure(failure_cb(other_failure.clone())); - with_running_server(server, |addr| async move { - let mut stream = TcpStream::connect(addr).await?; - let config = bincode::config::standard() - .with_big_endian() - .with_fixed_int_encoding(); - let mut bytes = bincode::encode_to_vec(OtherPreamble(1), config)?; - bytes.resize(8, 0); - stream.write_all(&bytes).await?; - stream.shutdown().await?; - // Wait for the success callback before shutting down the server. - recv_within(Duration::from_secs(1), other_success_rx).await?; - Ok(()) - }) - .await?; - assert!( - timeout(Duration::from_millis(500), other_failure_rx) - .await - .is_err(), - "other failure callback invoked", - ); - assert!( - timeout(Duration::from_millis(500), hotline_success_rx) - .await - .is_err(), - "hotline success callback invoked", - ); - assert!( - timeout(Duration::from_millis(500), hotline_failure_rx) - .await - .is_err(), - "hotline failure callback invoked", - ); - Ok(()) -} +#[path = "preamble/basic.rs"] +mod basic; +#[path = "preamble/callbacks.rs"] +mod callbacks; +#[path = "preamble/responses.rs"] +mod responses; +#[path = "preamble/support.rs"] +mod support; +#[path = "preamble/timeouts.rs"] +mod timeouts; diff --git a/tests/preamble/basic.rs b/tests/preamble/basic.rs new file mode 100644 index 00000000..7edc5540 --- /dev/null +++ b/tests/preamble/basic.rs @@ -0,0 +1,42 @@ +//! Basic preamble parsing tests. + +use tokio::io::{AsyncWriteExt, duplex}; +use wireframe::preamble::read_preamble; + +use crate::{common::TestResult, support::HotlinePreamble}; + +#[tokio::test] +#[expect( + clippy::panic_in_result_fn, + reason = "asserts provide clearer diagnostics in tests" +)] +async fn parse_valid_preamble() -> TestResult { + let (mut client, mut server) = duplex(64); + let bytes = b"TRTPHOTL\x00\x01\x00\x02"; + client.write_all(bytes).await?; + client.shutdown().await?; + let (p, _) = read_preamble::<_, HotlinePreamble>(&mut server).await?; + p.validate()?; + assert_eq!(p.magic, HotlinePreamble::MAGIC, "preamble magic mismatch"); + assert_eq!(p.min_version, 1, "preamble minimum version mismatch"); + assert_eq!(p.client_version, 2, "preamble client version mismatch"); + Ok(()) +} + +#[tokio::test] +#[expect( + clippy::panic_in_result_fn, + reason = "asserts provide clearer diagnostics in tests" +)] +async fn invalid_magic_is_error() -> TestResult { + let (mut client, mut server) = duplex(64); + let bytes = b"WRONGMAG\x00\x01\x00\x02"; + client.write_all(bytes).await?; + client.shutdown().await?; + let (preamble, _) = read_preamble::<_, HotlinePreamble>(&mut server).await?; + assert!( + preamble.validate().is_err(), + "invalid magic should fail validation" + ); + Ok(()) +} diff --git a/tests/preamble/callbacks.rs b/tests/preamble/callbacks.rs new file mode 100644 index 00000000..05dbc9b5 --- /dev/null +++ b/tests/preamble/callbacks.rs @@ -0,0 +1,225 @@ +//! Callback behaviour tests for preamble handling. + +use std::{io, sync::Arc}; + +use rstest::rstest; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, + time::{Duration, timeout}, +}; +use wireframe::{app::WireframeApp, server::WireframeServer}; + +use crate::{ + common::{TestResult, factory}, + support::{ + HotlinePreamble, + OtherPreamble, + channel_holder, + failure_cb, + recv_within, + server_with_handlers, + success_cb, + take_sender_io, + with_running_server, + }, +}; + +#[derive(Clone, Copy)] +enum ExpectedCallback { + Success, + Failure, +} + +#[rstest] +#[case(b"TRTPHOTL\x00\x01\x00\x02", ExpectedCallback::Success)] +#[case(b"TRTPHOT", ExpectedCallback::Failure)] +#[tokio::test] +async fn server_triggers_expected_callback( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + #[case] bytes: &'static [u8], + #[case] expected: ExpectedCallback, +) -> TestResult { + let (success_tx, success_rx) = tokio::sync::oneshot::channel::(); + let (failure_tx, failure_rx) = tokio::sync::oneshot::channel::<()>(); + let success_tx = Arc::new(std::sync::Mutex::new(Some(success_tx))); + let failure_tx = Arc::new(std::sync::Mutex::new(Some(failure_tx))); + let server = server_with_handlers( + factory, + { + let success_tx = success_tx.clone(); + move |p, _| { + let success_tx = success_tx.clone(); + let clone = p.clone(); + Box::pin(async move { + if let Some(tx) = take_sender_io(&success_tx)? { + let _ = tx.send(clone); + } + Ok::<(), io::Error>(()) + }) + } + }, + { + let failure_tx = failure_tx.clone(); + move |_, _| { + let failure_tx = failure_tx.clone(); + Box::pin(async move { + if let Some(tx) = take_sender_io(&failure_tx)? { + let _ = tx.send(()); + } + Ok::<(), io::Error>(()) + }) + } + }, + ); + + with_running_server(server, |addr| async move { + let mut stream = TcpStream::connect(addr).await?; + stream.write_all(bytes).await?; + stream.shutdown().await?; + + match expected { + ExpectedCallback::Success => { + let preamble = timeout(Duration::from_secs(2), success_rx).await??; + assert_eq!(preamble.magic, HotlinePreamble::MAGIC); + assert!(timeout(Duration::from_secs(1), failure_rx).await.is_err()); + } + ExpectedCallback::Failure => { + timeout(Duration::from_secs(2), failure_rx).await??; + assert!(timeout(Duration::from_secs(1), success_rx).await.is_err()); + } + } + + Ok(()) + }) + .await?; + Ok(()) +} + +#[rstest] +#[tokio::test] +async fn success_handler_runs_without_failure_handler( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) -> TestResult { + let (success_tx, success_rx) = tokio::sync::oneshot::channel::(); + let success_tx = Arc::new(std::sync::Mutex::new(Some(success_tx))); + let server = WireframeServer::new(factory) + .with_preamble::() + .on_preamble_decode_success({ + let success_tx = success_tx.clone(); + move |p, _| { + let success_tx = success_tx.clone(); + let preamble = p.clone(); + Box::pin(async move { + if let Some(tx) = take_sender_io(&success_tx)? { + let _ = tx.send(preamble); + } + Ok::<(), io::Error>(()) + }) + } + }); + + with_running_server(server, |addr| async move { + let mut stream = TcpStream::connect(addr).await?; + let bytes = b"TRTPHOTL\x00\x01\x00\x02"; + stream.write_all(bytes).await?; + stream.shutdown().await?; + let preamble = recv_within(Duration::from_secs(1), success_rx).await?; + assert_eq!(preamble.magic, HotlinePreamble::MAGIC); + Ok(()) + }) + .await?; + Ok(()) +} + +#[rstest] +#[tokio::test] +async fn failure_handler_error_is_logged_and_connection_closes( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) -> TestResult { + let (failure_holder, failure_rx) = channel_holder(); + let server = WireframeServer::new(factory) + .with_preamble::() + .on_preamble_decode_failure(move |_, _| { + let failure_holder = failure_holder.clone(); + Box::pin(async move { + if let Some(tx) = take_sender_io(&failure_holder)? { + let _ = tx.send(()); + } + Err::<(), io::Error>(io::Error::other("boom")) + }) + }); + + with_running_server(server, |addr| async move { + let mut stream = TcpStream::connect(addr).await?; + stream.write_all(b"BAD").await?; + stream.shutdown().await?; + + recv_within(Duration::from_secs(1), failure_rx).await?; + + let mut buf = [0u8; 1]; + let read = timeout(Duration::from_millis(200), stream.read(&mut buf)).await; + match read? { + Ok(0) => {} + Ok(n) => panic!("expected connection close, read {n} bytes"), + Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {} + Err(e) => panic!("unexpected read error: {e:?}"), + } + Ok(()) + }) + .await?; + Ok(()) +} + +#[rstest] +#[tokio::test] +async fn callbacks_dropped_when_overriding_preamble( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) -> TestResult { + let (hotline_success, hotline_success_rx) = channel_holder(); + let (hotline_failure, hotline_failure_rx) = channel_holder(); + let (other_success, other_success_rx) = channel_holder(); + let (other_failure, other_failure_rx) = channel_holder(); + + let server = WireframeServer::new(factory.clone()) + .with_preamble::() + .on_preamble_decode_success(success_cb::(hotline_success.clone())) + .on_preamble_decode_failure(failure_cb(hotline_failure.clone())) + .with_preamble::() + .on_preamble_decode_success(success_cb::(other_success.clone())) + .on_preamble_decode_failure(failure_cb(other_failure.clone())); + + with_running_server(server, |addr| async move { + let mut stream = TcpStream::connect(addr).await?; + let config = bincode::config::standard() + .with_big_endian() + .with_fixed_int_encoding(); + let mut bytes = bincode::encode_to_vec(OtherPreamble(1), config)?; + bytes.resize(8, 0); + stream.write_all(&bytes).await?; + stream.shutdown().await?; + // Wait for the success callback before shutting down the server. + recv_within(Duration::from_secs(1), other_success_rx).await?; + Ok(()) + }) + .await?; + assert!( + timeout(Duration::from_millis(500), other_failure_rx) + .await + .is_err(), + "other failure callback invoked", + ); + assert!( + timeout(Duration::from_millis(500), hotline_success_rx) + .await + .is_err(), + "hotline success callback invoked", + ); + assert!( + timeout(Duration::from_millis(500), hotline_failure_rx) + .await + .is_err(), + "hotline failure callback invoked", + ); + Ok(()) +} diff --git a/tests/preamble/responses.rs b/tests/preamble/responses.rs new file mode 100644 index 00000000..ca66dfb2 --- /dev/null +++ b/tests/preamble/responses.rs @@ -0,0 +1,97 @@ +//! Response-writing tests for preamble handlers. + +use std::io; + +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, + sync::oneshot, + time::{Duration, timeout}, +}; +use wireframe::server::WireframeServer; + +use crate::{ + common::{TestResult, factory}, + support::{ + HotlinePreamble, + channel_holder, + notify_holder, + recv_within, + server_with_handlers, + with_running_server, + }, +}; + +#[tokio::test] +#[expect( + clippy::panic_in_result_fn, + reason = "asserts provide clearer diagnostics in tests" +)] +async fn success_callback_can_write_response() -> TestResult { + let factory = factory(); + let (response_tx, response_rx) = oneshot::channel(); + let server = server_with_handlers( + factory, + |_, stream| { + Box::pin(async move { + stream.write_all(b"ACK").await?; + stream.flush().await?; + Ok::<(), io::Error>(()) + }) + }, + |_, _| Box::pin(async { Ok::<(), io::Error>(()) }), + ); + + with_running_server(server, move |addr| async move { + let mut stream = TcpStream::connect(addr).await?; + let bytes = b"TRTPHOTL\x00\x01\x00\x02"; + stream.write_all(bytes).await?; + let mut buf = [0u8; 3]; + stream.read_exact(&mut buf).await?; + let _ = response_tx.send(buf); + Ok(()) + }) + .await?; + let buf = recv_within(Duration::from_secs(1), response_rx).await?; + assert_eq!(&buf, b"ACK"); + Ok(()) +} + +#[tokio::test] +#[expect( + clippy::panic_in_result_fn, + reason = "asserts provide clearer diagnostics in tests" +)] +async fn failure_callback_can_write_response() -> TestResult { + let factory = factory(); + let (failure_holder, failure_rx) = channel_holder(); + let (response_tx, response_rx) = oneshot::channel(); + let server = WireframeServer::new(factory) + .with_preamble::() + .on_preamble_decode_failure(move |_, stream| { + let failure_holder = failure_holder.clone(); + Box::pin(async move { + stream.write_all(b"ERR").await?; + stream.flush().await?; + notify_holder(&failure_holder)?; + Ok::<(), io::Error>(()) + }) + }); + + with_running_server(server, move |addr| async move { + let mut stream = TcpStream::connect(addr).await?; + stream.write_all(b"BAD").await?; + stream.shutdown().await?; + let mut buf = [0u8; 3]; + let read = timeout(Duration::from_secs(1), stream.read_exact(&mut buf)).await; + let result = read?; + result?; + let _ = response_tx.send(buf); + Ok(()) + }) + .await?; + let buf = recv_within(Duration::from_secs(1), response_rx).await?; + assert_eq!(&buf, b"ERR"); + recv_within(Duration::from_millis(200), failure_rx).await?; + Ok(()) +} diff --git a/tests/preamble/support.rs b/tests/preamble/support.rs new file mode 100644 index 00000000..707d3dc6 --- /dev/null +++ b/tests/preamble/support.rs @@ -0,0 +1,216 @@ +//! Shared helpers for the preamble integration tests. + +use std::{ + error::Error, + io, + sync::{Arc, Mutex}, +}; + +use bincode::error::DecodeError; +use futures::future::BoxFuture; +use tokio::{ + net::TcpStream, + sync::oneshot, + time::{Duration, timeout}, +}; +use wireframe::{app::WireframeApp, server::WireframeServer}; + +use crate::common::{TestResult, unused_listener}; + +#[derive(Debug, Clone, PartialEq, Eq, bincode::Encode, bincode::Decode)] +pub(crate) struct HotlinePreamble { + /// Should always be `b"TRTPHOTL"`. + pub(crate) magic: [u8; 8], + /// Minimum server version this client supports. + pub(crate) min_version: u16, + /// Client version. + pub(crate) client_version: u16, +} + +impl HotlinePreamble { + pub(crate) const MAGIC: [u8; 8] = *b"TRTPHOTL"; + + pub(crate) fn validate(&self) -> Result<(), DecodeError> { + if self.magic != Self::MAGIC { + return Err(DecodeError::Other("invalid hotline preamble")); + } + Ok(()) + } +} + +/// Create a server configured with `HotlinePreamble` handlers. +pub(crate) fn server_with_handlers( + factory: F, + success: S, + failure: E, +) -> WireframeServer +where + F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, + S: for<'a> Fn(&'a HotlinePreamble, &'a mut TcpStream) -> BoxFuture<'a, io::Result<()>> + + Send + + Sync + + 'static, + E: for<'a> Fn(&'a DecodeError, &'a mut TcpStream) -> BoxFuture<'a, io::Result<()>> + + Send + + Sync + + 'static, +{ + WireframeServer::new(factory) + .workers(1) + .with_preamble::() + .on_preamble_decode_success(success) + .on_preamble_decode_failure(failure) +} + +/// Run the provided server while executing `block`. +pub(crate) async fn with_running_server( + server: WireframeServer, + block: B, +) -> TestResult +where + F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, + T: wireframe::preamble::Preamble, + Fut: std::future::Future, + B: FnOnce(std::net::SocketAddr) -> Fut, +{ + let listener = unused_listener(); + let server = server.bind_existing_listener(listener)?; + let addr = server + .local_addr() + .ok_or_else(|| Box::::from("server missing local addr"))?; + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let handle = tokio::spawn(async move { + server + .run_with_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + }); + + block(addr).await?; + let _ = shutdown_tx.send(()); + let run_result = handle.await?; + run_result?; + Ok(()) +} + +/// Alternate preamble used to verify handler overrides. +/// +/// # Examples +/// ```rust,ignore +/// let preamble = OtherPreamble(1); +/// assert_eq!(preamble.0, 1); +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, bincode::Encode, bincode::Decode)] +pub(crate) struct OtherPreamble(pub(crate) u8); + +/// Shared oneshot sender holder used by callbacks. +pub(crate) type Holder = Arc>>>; + +/// Create a callback sender holder with its paired receiver. +/// +/// # Examples +/// ```rust,ignore +/// let (holder, rx) = channel_holder(); +/// assert!(holder.lock().unwrap().is_some()); +/// drop(rx); +/// ``` +pub(crate) fn channel_holder() -> (Holder, oneshot::Receiver<()>) { + let (tx, rx) = oneshot::channel(); + (Arc::new(Mutex::new(Some(tx))), rx) +} + +/// Take the sender from a mutex, returning an IO error on poison. +/// +/// # Examples +/// ```rust,ignore +/// use std::sync::Mutex; +/// +/// let holder = Mutex::new(Some(1)); +/// let value = take_sender_io(&holder).unwrap(); +/// assert_eq!(value, Some(1)); +/// ``` +pub(crate) fn take_sender_io(holder: &Mutex>) -> io::Result> { + holder + .lock() + .map_err(|e| io::Error::other(format!("lock poisoned: {e}"))) + .map(|mut guard| guard.take()) +} + +/// Signal the holder if a sender is still available. +/// +/// # Examples +/// ```rust,ignore +/// let (holder, _rx) = channel_holder(); +/// notify_holder(&holder).unwrap(); +/// ``` +pub(crate) fn notify_holder(holder: &Holder) -> io::Result<()> { + if let Some(tx) = take_sender_io(holder)? { + let _ = tx.send(()); + } + Ok(()) +} + +/// Await a oneshot receiver within the provided duration. +/// +/// # Examples +/// ```rust,ignore +/// # use tokio::sync::oneshot; +/// # use tokio::time::Duration; +/// # async fn demo() -> Result<(), Box> { +/// let (tx, rx) = oneshot::channel(); +/// let _ = tx.send(42); +/// let value = recv_within(Duration::from_millis(50), rx).await?; +/// assert_eq!(value, 42); +/// # Ok(()) +/// # } +/// ``` +pub(crate) async fn recv_within(duration: Duration, rx: oneshot::Receiver) -> TestResult { + Ok(timeout(duration, rx).await??) +} + +/// Build a success callback that signals through a shared holder. +/// +/// # Examples +/// ```rust,ignore +/// let (holder, _rx) = channel_holder(); +/// let callback = success_cb::(holder); +/// ``` +pub(crate) fn success_cb

( + holder: Arc>>>, +) -> impl for<'a> Fn(&'a P, &'a mut TcpStream) -> BoxFuture<'a, io::Result<()>> + Send + Sync + 'static +{ + move |_, _| { + let holder = holder.clone(); + Box::pin(async move { + if let Some(tx) = take_sender_io(&holder)? { + let _ = tx.send(()); + } + Ok(()) + }) + } +} + +/// Build a failure callback that signals through a shared holder. +/// +/// # Examples +/// ```rust,ignore +/// let (holder, _rx) = channel_holder(); +/// let callback = failure_cb(holder); +/// ``` +pub(crate) fn failure_cb( + holder: Arc>>>, +) -> impl for<'a> Fn(&'a DecodeError, &'a mut TcpStream) -> BoxFuture<'a, io::Result<()>> ++ Send ++ Sync ++ 'static { + move |_, _| { + let holder = holder.clone(); + Box::pin(async move { + if let Some(tx) = take_sender_io(&holder)? { + let _ = tx.send(()); + } + Ok(()) + }) + } +} diff --git a/tests/preamble/timeouts.rs b/tests/preamble/timeouts.rs new file mode 100644 index 00000000..2d331bbe --- /dev/null +++ b/tests/preamble/timeouts.rs @@ -0,0 +1,146 @@ +//! Timeout behaviour tests for preamble handling. + +use std::io; + +use bincode::error::DecodeError; +use futures::future::BoxFuture; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, + sync::oneshot, + time::{Duration, timeout}, +}; +use wireframe::server::WireframeServer; + +use crate::{ + common::{TestResult, factory}, + support::{ + Holder, + HotlinePreamble, + channel_holder, + failure_cb, + notify_holder, + recv_within, + with_running_server, + }, +}; + +fn timeout_success_handler( + holder: Holder, +) -> impl for<'a> Fn(&'a HotlinePreamble, &'a mut TcpStream) -> BoxFuture<'a, io::Result<()>> ++ Send ++ Sync ++ 'static { + move |p, stream| { + let holder = holder.clone(); + let clone = p.clone(); + Box::pin(async move { + notify_holder(&holder)?; + stream.write_all(b"OK").await?; + stream.flush().await?; + // keep connection open by not shutting down here + assert_eq!(clone.magic, HotlinePreamble::MAGIC); + Ok::<(), io::Error>(()) + }) + } +} + +#[tokio::test] +#[expect( + clippy::panic_in_result_fn, + reason = "asserts provide clearer diagnostics in tests" +)] +async fn preamble_timeout_invokes_failure_handler_and_closes_connection() -> TestResult { + let factory = factory(); + let (failure_holder, failure_rx) = channel_holder(); + let (result_tx, result_rx) = oneshot::channel(); + let server = WireframeServer::new(factory) + .with_preamble::() + .preamble_timeout(Duration::from_millis(50)) + .on_preamble_decode_failure(move |err, stream| { + let failure_holder = failure_holder.clone(); + Box::pin(async move { + assert!( + matches!( + err, + DecodeError::Io { inner, .. } + if inner.kind() == io::ErrorKind::TimedOut + ), + "expected timed out error, got {err:?}" + ); + stream.write_all(b"ERR").await?; + stream.flush().await?; + stream.shutdown().await?; + notify_holder(&failure_holder)?; + Ok::<(), io::Error>(()) + }) + }); + + with_running_server(server, move |addr| async move { + let mut stream = TcpStream::connect(addr).await?; + recv_within(Duration::from_secs(1), failure_rx).await?; + let mut buf = [0u8; 3]; + timeout(Duration::from_millis(500), stream.read_exact(&mut buf)).await??; + let mut eof = [0u8; 1]; + let read = timeout(Duration::from_millis(200), stream.read(&mut eof)).await; + let closed = match read? { + Ok(0) => true, + Ok(n) => { + return Err(io::Error::other(format!( + "expected connection to close, read {n} bytes" + )) + .into()); + } + Err(e) if e.kind() == io::ErrorKind::ConnectionReset => true, + Err(e) => return Err(e.into()), + }; + let _ = result_tx.send((buf, closed)); + Ok(()) + }) + .await?; + let (buf, closed) = recv_within(Duration::from_secs(1), result_rx).await?; + assert_eq!(&buf, b"ERR"); + assert!(closed, "expected connection to close"); + Ok(()) +} + +#[tokio::test] +#[expect( + clippy::panic_in_result_fn, + reason = "asserts provide clearer diagnostics in tests" +)] +async fn preamble_timeout_allows_timely_preamble() -> TestResult { + let factory = factory(); + let (success_holder, success_rx) = channel_holder(); + let (failure_holder, failure_rx) = channel_holder(); + let (result_tx, result_rx) = oneshot::channel(); + let server = WireframeServer::new(factory) + .with_preamble::() + .preamble_timeout(Duration::from_millis(150)) + .on_preamble_decode_success(timeout_success_handler(success_holder.clone())) + .on_preamble_decode_failure(failure_cb(failure_holder.clone())); + + with_running_server(server, move |addr| async move { + let mut stream = TcpStream::connect(addr).await?; + let bytes = b"TRTPHOTL\x00\x01\x00\x02"; + stream.write_all(bytes).await?; + + recv_within(Duration::from_millis(200), success_rx).await?; + let failure_fired = timeout(Duration::from_millis(150), failure_rx) + .await + .is_ok(); + + let mut buf = [0u8; 2]; + stream.read_exact(&mut buf).await?; + let _ = result_tx.send((buf, failure_fired)); + Ok(()) + }) + .await?; + let (buf, failure_fired) = recv_within(Duration::from_secs(1), result_rx).await?; + assert_eq!(&buf, b"OK"); + assert!( + !failure_fired, + "failure handler should not fire for timely preamble" + ); + Ok(()) +} diff --git a/tests/response.rs b/tests/response.rs index 5b3546f0..e5cd2c56 100644 --- a/tests/response.rs +++ b/tests/response.rs @@ -267,7 +267,7 @@ async fn process_stream_honours_buffer_capacity() -> TestResult { let (resp_env, _) = BincodeSerializer .deserialize::(frame) .map_err(|e| format!("deserialize failed: {e}"))?; - let resp_len = resp_env.into_parts().payload().len(); + let resp_len = resp_env.into_parts().into_payload().len(); assert_eq!(resp_len, payload.len()); Ok(()) } diff --git a/wireframe_testing/src/echo_server.rs b/wireframe_testing/src/echo_server.rs index fdc3c0e6..97264b50 100644 --- a/wireframe_testing/src/echo_server.rs +++ b/wireframe_testing/src/echo_server.rs @@ -36,7 +36,7 @@ pub fn process_frame(mode: ServerMode, bytes: &[u8]) -> Option> { ServerMode::Mismatch => { let wrong_id = envelope.correlation_id().map(|id| id.wrapping_add(999)); let parts = envelope.into_parts(); - Envelope::new(parts.id(), wrong_id, parts.payload()) + Envelope::new(parts.id(), wrong_id, parts.into_payload()) } };