diff --git a/docs/v0-1-0-to-v0-2-0-migration-guide.md b/docs/v0-1-0-to-v0-2-0-migration-guide.md index 8d9a72dd..757b9de4 100644 --- a/docs/v0-1-0-to-v0-2-0-migration-guide.md +++ b/docs/v0-1-0-to-v0-2-0-migration-guide.md @@ -1,7 +1,7 @@ # v0.1.0 to v0.2.0 migration guide -This guide summarizes the breaking changes required when migrating from -wireframe v0.1.0 to v0.2.0. +This guide summarizes the breaking changes that must be addressed when +migrating from wireframe v0.1.0 to v0.2.0. ## Configuration builder naming update diff --git a/src/app/builder.rs b/src/app/builder.rs index 8255d813..42162c1b 100644 --- a/src/app/builder.rs +++ b/src/app/builder.rs @@ -13,7 +13,12 @@ use std::{ use tokio::sync::{OnceCell, mpsc}; use super::{ - builder_defaults::{MAX_READ_TIMEOUT_MS, MIN_READ_TIMEOUT_MS, default_fragmentation}, + builder_defaults::{ + DEFAULT_READ_TIMEOUT_MS, + MAX_READ_TIMEOUT_MS, + MIN_READ_TIMEOUT_MS, + default_fragmentation, + }, envelope::{Envelope, Packet}, error::{Result, WireframeError}, lifecycle::{ConnectionSetup, ConnectionTeardown}, @@ -77,7 +82,7 @@ where protocol: None, push_dlq: None, codec, - read_timeout_ms: 100, + read_timeout_ms: DEFAULT_READ_TIMEOUT_MS, fragmentation: default_fragmentation(max_frame_length), message_assembler: None, } @@ -164,6 +169,32 @@ where } } + /// Helper to rebuild the app when changing the connection state type. + pub(super) fn rebuild_with_connection_type( + self, + on_connect: Option>>, + on_disconnect: Option>>, + ) -> WireframeApp + where + C2: Send + 'static, + { + WireframeApp { + handlers: self.handlers, + routes: OnceCell::new(), + middleware: self.middleware, + serializer: self.serializer, + app_data: self.app_data, + on_connect, + on_disconnect, + protocol: self.protocol, + push_dlq: self.push_dlq, + codec: self.codec, + read_timeout_ms: self.read_timeout_ms, + fragmentation: self.fragmentation, + message_assembler: self.message_assembler, + } + } + /// Replace the frame codec used for framing I/O. /// /// This resets any installed protocol hooks because the frame type may @@ -288,7 +319,7 @@ where impl WireframeApp where - S: Serializer + Default + Send + Sync, + S: Serializer + Send + Sync, C: Send + 'static, E: Packet, { diff --git a/src/app/builder_defaults.rs b/src/app/builder_defaults.rs index b66bf5c0..fa539a29 100644 --- a/src/app/builder_defaults.rs +++ b/src/app/builder_defaults.rs @@ -6,6 +6,7 @@ use crate::{codec::clamp_frame_length, fragment::FragmentationConfig}; pub(super) const MIN_READ_TIMEOUT_MS: u64 = 1; pub(super) const MAX_READ_TIMEOUT_MS: u64 = 86_400_000; +pub(super) const DEFAULT_READ_TIMEOUT_MS: u64 = 100; const DEFAULT_FRAGMENT_TIMEOUT: Duration = Duration::from_secs(30); const DEFAULT_MESSAGE_SIZE_MULTIPLIER: usize = 16; diff --git a/src/app/builder_lifecycle.rs b/src/app/builder_lifecycle.rs index 81552bd4..33c2ca58 100644 --- a/src/app/builder_lifecycle.rs +++ b/src/app/builder_lifecycle.rs @@ -2,8 +2,6 @@ use std::{future::Future, sync::Arc}; -use tokio::sync::OnceCell; - use super::{builder::WireframeApp, envelope::Packet, error::Result}; use crate::{codec::FrameCodec, serializer::Serializer}; @@ -39,21 +37,7 @@ where Fut: Future + Send + 'static, C2: Send + 'static, { - Ok(WireframeApp { - handlers: self.handlers, - routes: OnceCell::new(), - middleware: self.middleware, - serializer: self.serializer, - app_data: self.app_data, - on_connect: Some(Arc::new(move || Box::pin(f()))), - on_disconnect: None, - protocol: self.protocol, - push_dlq: self.push_dlq, - codec: self.codec, - read_timeout_ms: self.read_timeout_ms, - fragmentation: self.fragmentation, - message_assembler: self.message_assembler, - }) + Ok(self.rebuild_with_connection_type(Some(Arc::new(move || Box::pin(f()))), None)) } /// Register a callback invoked when a connection is closed. diff --git a/src/app/builder_protocol.rs b/src/app/builder_protocol.rs index 3e3296a0..19132f20 100644 --- a/src/app/builder_protocol.rs +++ b/src/app/builder_protocol.rs @@ -22,6 +22,11 @@ where /// The protocol defines hooks for connection setup, frame modification, and /// command completion. It is wrapped in an [`Arc`] and stored for later use /// by the connection actor. + /// + /// At present, the protocol must use `ProtocolError = ()`. This keeps the + /// protocol object safe for dynamic dispatch, maintains a uniform + /// interface across connections, and avoids leaking application-specific + /// error types into the builder API. #[must_use] pub fn with_protocol

(self, protocol: P) -> Self where diff --git a/src/app/frame_handling.rs b/src/app/frame_handling.rs index 94d54d31..e61984e9 100644 --- a/src/app/frame_handling.rs +++ b/src/app/frame_handling.rs @@ -23,6 +23,9 @@ use crate::{ serializer::Serializer, }; +/// Tracks consecutive deserialization failures and enforces a per-connection limit. +/// +/// The counter increments on each failure; reaching `limit` terminates processing. struct DeserFailureTracker<'a> { count: &'a mut u32, limit: u32, @@ -37,7 +40,7 @@ impl<'a> DeserFailureTracker<'a> { context: &str, err: impl std::fmt::Debug, ) -> io::Result<()> { - *self.count += 1; + *self.count = self.count.saturating_add(1); warn!("{context}: correlation_id={correlation_id:?}, error={err:?}"); crate::metrics::inc_deser_errors(); if *self.count >= self.limit { @@ -50,6 +53,9 @@ impl<'a> DeserFailureTracker<'a> { } } +/// Context for writing handler responses to the framed stream. +/// +/// Carries the serializer, codec, and mutable framing state for a connection. pub(crate) struct ResponseContext<'a, S, W, F> where S: Serializer + Send + Sync, @@ -107,8 +113,10 @@ where Ok(resp) => resp, Err(e) => { warn!( - "handler error: id={}, correlation_id={:?}, error={e:?}", - env.id, env.correlation_id + "handler error: id={id}, correlation_id={correlation_id:?}, error={error:?}", + id = env.id, + correlation_id = env.correlation_id, + error = e ); crate::metrics::inc_handler_errors(); return Ok(()); @@ -128,17 +136,26 @@ where break; // already logged }; - if send_response_payload::(ctx.codec, ctx.framed, Bytes::from(bytes), &response) - .await - .is_err() - { - break; + let send_result = + send_response_payload::(ctx.codec, ctx.framed, Bytes::from(bytes), &response) + .await; + match send_result { + Ok(()) => {} + Err(err) if should_drop_response_send_error(&err) => break, // already logged + Err(err) => return Err(err), } } Ok(()) } +fn should_drop_response_send_error(error: &io::Error) -> bool { + matches!( + error.kind(), + io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData + ) +} + fn fragment_responses( fragmentation: &mut Option, parts: PacketParts, @@ -151,8 +168,13 @@ fn fragment_responses( Ok(fragmented) => Ok(fragmented), Err(err) => { warn!( - "failed to fragment response: id={id}, correlation_id={correlation_id:?}, \ - error={err:?}" + concat!( + "failed to fragment response: id={id}, correlation_id={correlation_id:?}, ", + "error={err:?}" + ), + id = id, + correlation_id = correlation_id, + err = err ); crate::metrics::inc_handler_errors(); Err(io::Error::other("fragmentation failed")) @@ -172,8 +194,13 @@ fn serialize_response( Ok(bytes) => Ok(bytes), Err(e) => { warn!( - "failed to serialize response: id={id}, correlation_id={correlation_id:?}, \ - error={e:?}" + concat!( + "failed to serialize response: id={id}, correlation_id={correlation_id:?}, ", + "error={e:?}" + ), + id = id, + correlation_id = correlation_id, + e = e ); crate::metrics::inc_handler_errors(); Err(io::Error::other("serialization failed")) @@ -202,43 +229,89 @@ where let correlation_id = response.correlation_id; warn!("failed to send response: id={id}, correlation_id={correlation_id:?}, error={e:?}"); crate::metrics::inc_handler_errors(); - return Err(io::Error::other("send failed")); + return Err(e); } Ok(()) } #[cfg(test)] mod tests { + //! Tests for frame handling helpers and response sending. + use bytes::Bytes; use futures::StreamExt; + use rstest::{fixture, rstest}; + use tokio::io::DuplexStream; use super::*; - use crate::{app::combined_codec::CombinedCodec, test_helpers::TestCodec}; + use crate::{ + app::combined_codec::CombinedCodec, + test_helpers::{TestAdapter, TestCodec}, + }; - /// Verify `send_response_payload` uses `F::wrap_payload` to frame responses. - #[tokio::test] - async fn send_response_payload_wraps_with_codec() { - let codec = TestCodec::new(64); + struct FramedHarness { + codec: TestCodec, + server_framed: Framed>, + client_framed: Framed>, + } + + fn build_harness(max_frame_length: usize) -> FramedHarness { + let codec = TestCodec::new(max_frame_length); + let client_codec = TestCodec::new(max_frame_length); let (client, server) = tokio::io::duplex(256); - let combined = CombinedCodec::new(codec.decoder(), codec.encoder()); - let mut framed = Framed::new(server, combined); + let server_codec = CombinedCodec::new(codec.decoder(), codec.encoder()); + let client_codec = CombinedCodec::new(client_codec.decoder(), client_codec.encoder()); + let server_framed = Framed::new(server, server_codec); + let client_framed = Framed::new(client, client_codec); + + FramedHarness { + codec, + server_framed, + client_framed, + } + } + + #[fixture] + fn harness() -> FramedHarness { + // Keep fixture setup explicit to avoid duplicated per-test harness creation. + build_harness(64) + } - let payload = vec![1, 2, 3, 4]; + #[rstest] + #[case::ok(vec![1, 2, 3, 4], false)] + #[case::oversized(vec![0u8; 100], true)] + #[tokio::test] + async fn send_response_payload_behaviour( + #[case] payload: Vec, + #[case] expect_error: bool, + mut harness: FramedHarness, + ) { let response = Envelope::new(1, Some(99), payload.clone()); - send_response_payload::( - &codec, - &mut framed, + let result = send_response_payload::( + &harness.codec, + &mut harness.server_framed, Bytes::from(payload.clone()), &response, ) - .await - .expect("send should succeed"); + .await; - drop(framed); + if expect_error { + assert!( + result.is_err(), + "expected send to fail for oversized payload" + ); + assert_eq!( + result + .expect_err("oversized payload should return an error") + .kind(), + io::ErrorKind::InvalidInput + ); + return; + } - let combined_client = CombinedCodec::new(codec.decoder(), codec.encoder()); - let mut client_framed = Framed::new(client, combined_client); - let frame = client_framed + result.expect("send should succeed"); + let frame = harness + .client_framed .next() .await .expect("expected a frame") @@ -246,54 +319,32 @@ mod tests { assert_eq!(frame.tag, 0x42, "wrap_payload should set tag to 0x42"); assert_eq!(frame.payload, payload, "payload should match"); - assert_eq!(codec.wraps(), 1, "wrap_payload should advance codec state"); + assert_eq!( + harness.codec.wraps(), + 1, + "wrap_payload should advance codec state" + ); } /// Verify `ResponseContext` fields are accessible and usable. + #[rstest] #[tokio::test] - async fn response_context_holds_references() { + async fn response_context_holds_references(mut harness: FramedHarness) { use crate::serializer::BincodeSerializer; - let codec = TestCodec::new(64); - let (_client, server) = tokio::io::duplex(256); - let combined = CombinedCodec::new(codec.decoder(), codec.encoder()); - let mut framed = Framed::new(server, combined); let serializer = BincodeSerializer; let mut fragmentation: Option = None; let ctx: ResponseContext<'_, BincodeSerializer, _, TestCodec> = ResponseContext { serializer: &serializer, - framed: &mut framed, + framed: &mut harness.server_framed, fragmentation: &mut fragmentation, - codec: &codec, + codec: &harness.codec, }; // Verify fields are accessible (compile-time check with runtime assertion) assert!(ctx.fragmentation.is_none()); } - /// Verify `send_response_payload` returns error on send failure. - #[tokio::test] - async fn send_response_payload_returns_error_on_failure() { - let codec = TestCodec::new(4); // Small limit to trigger failure - let (_client, server) = tokio::io::duplex(256); - let combined = CombinedCodec::new(codec.decoder(), codec.encoder()); - let mut framed = Framed::new(server, combined); - - // Payload exceeds max_frame_length, so encode will fail - let oversized_payload = vec![0u8; 100]; - let response = Envelope::new(1, Some(99), oversized_payload.clone()); - let result = send_response_payload::( - &codec, - &mut framed, - Bytes::from(oversized_payload), - &response, - ) - .await; - - assert!( - result.is_err(), - "expected send to fail for oversized payload" - ); - } + // Covered by `send_response_payload_behaviour` cases. } diff --git a/src/client/builder.rs b/src/client/builder.rs index d7ff1b9f..ae5f9ed1 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -24,6 +24,8 @@ use crate::{ serializer::{BincodeSerializer, Serializer}, }; +const INITIAL_READ_BUFFER_CAPACITY_LIMIT: usize = 64 * 1024; + /// Reconstructs `WireframeClientBuilder` with one field updated to a new value. /// /// This macro reduces duplication in type-changing builder methods that need to @@ -319,8 +321,11 @@ where /// ``` /// use wireframe::client::WireframeClientBuilder; /// - /// let builder = WireframeClientBuilder::new().on_error(|err| async move { - /// eprintln!("Client error: {err}"); + /// let builder = WireframeClientBuilder::new().on_error(|err| { + /// let message = err.to_string(); + /// async move { + /// eprintln!("Client error: {message}"); + /// } /// }); /// let _ = builder; /// ``` @@ -392,8 +397,10 @@ where let codec_config = self.codec_config; let codec = codec_config.build_codec(); let mut framed = Framed::new(RewindStream::new(leftover, stream), codec); - let initial_read_buffer_capacity = - core::cmp::min(64 * 1024, codec_config.max_frame_length_value()); + let initial_read_buffer_capacity = core::cmp::min( + INITIAL_READ_BUFFER_CAPACITY_LIMIT, + codec_config.max_frame_length_value(), + ); framed .read_buffer_mut() .reserve(initial_read_buffer_capacity); diff --git a/src/codec/recovery.rs b/src/codec/recovery.rs index db520ab8..040398ee 100644 --- a/src/codec/recovery.rs +++ b/src/codec/recovery.rs @@ -227,8 +227,7 @@ pub trait RecoveryPolicyHook: Send + Sync { /// /// The default implementation delegates to /// [`CodecError::default_recovery_policy`]. - fn recovery_policy(&self, error: &CodecError, ctx: &CodecErrorContext) -> RecoveryPolicy { - let _ = ctx; + fn recovery_policy(&self, error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { error.default_recovery_policy() } @@ -274,7 +273,7 @@ impl RecoveryPolicyHook for DefaultRecoveryPolicy {} /// /// assert_eq!(config.max_consecutive_drops, 5); /// ``` -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct RecoveryConfig { /// Maximum consecutive dropped frames before escalating to disconnect. /// @@ -329,83 +328,4 @@ impl RecoveryConfig { } #[cfg(test)] -mod tests { - use std::io; - - use super::*; - - #[test] - fn recovery_policy_default_is_drop() { - assert_eq!(RecoveryPolicy::default(), RecoveryPolicy::Drop); - } - - #[test] - fn context_builder_sets_fields() { - let ctx = CodecErrorContext::new() - .with_connection_id(42) - .with_correlation_id(123) - .with_codec_state("seq=5"); - - assert_eq!(ctx.connection_id, Some(42)); - assert_eq!(ctx.correlation_id, Some(123)); - assert_eq!(ctx.codec_state, Some("seq=5".to_string())); - } - - #[test] - fn context_with_peer_address() { - let addr: SocketAddr = "127.0.0.1:8080".parse().expect("valid test address"); - let ctx = CodecErrorContext::new().with_peer_address(addr); - assert_eq!(ctx.peer_address, Some(addr)); - } - - #[test] - fn default_recovery_policy_delegates_to_error() { - use super::super::error::{EofError, FramingError}; - - let hook = DefaultRecoveryPolicy; - let ctx = CodecErrorContext::new(); - - // Check various error types - let err = CodecError::Framing(FramingError::OversizedFrame { size: 100, max: 50 }); - assert_eq!(hook.recovery_policy(&err, &ctx), RecoveryPolicy::Drop); - - let err = CodecError::Io(io::Error::other("test")); - assert_eq!(hook.recovery_policy(&err, &ctx), RecoveryPolicy::Disconnect); - - let err = CodecError::Eof(EofError::CleanClose); - assert_eq!(hook.recovery_policy(&err, &ctx), RecoveryPolicy::Disconnect); - } - - #[test] - fn default_quarantine_duration_is_30_seconds() { - let hook = DefaultRecoveryPolicy; - let ctx = CodecErrorContext::new(); - let err = CodecError::Io(io::Error::other("test")); - - assert_eq!( - hook.quarantine_duration(&err, &ctx), - Duration::from_secs(30) - ); - } - - #[test] - fn recovery_config_builder() { - let config = RecoveryConfig::default() - .max_consecutive_drops(5) - .quarantine_duration(Duration::from_secs(60)) - .log_dropped_frames(false); - - assert_eq!(config.max_consecutive_drops, 5); - assert_eq!(config.quarantine_duration, Duration::from_secs(60)); - assert!(!config.log_dropped_frames); - } - - #[test] - fn recovery_config_defaults() { - let config = RecoveryConfig::default(); - - assert_eq!(config.max_consecutive_drops, 10); - assert_eq!(config.quarantine_duration, Duration::from_secs(30)); - assert!(config.log_dropped_frames); - } -} +mod tests; diff --git a/src/codec/recovery/tests.rs b/src/codec/recovery/tests.rs new file mode 100644 index 00000000..7534667f --- /dev/null +++ b/src/codec/recovery/tests.rs @@ -0,0 +1,104 @@ +//! Tests for codec recovery policies, hooks, and configuration. + +use std::{io, net::SocketAddr, time::Duration}; + +use rstest::{fixture, rstest}; + +use super::*; + +#[fixture] +fn default_hook() -> DefaultRecoveryPolicy { + // Use the framework default hook for baseline policy assertions. + DefaultRecoveryPolicy +} + +#[fixture] +fn context() -> CodecErrorContext { + // These tests exercise hook behaviour without connection metadata. + CodecErrorContext::new() +} + +#[test] +fn recovery_policy_default_is_drop() { + assert_eq!(RecoveryPolicy::default(), RecoveryPolicy::Drop); +} + +#[test] +fn context_builder_sets_fields() { + let ctx = CodecErrorContext::new() + .with_connection_id(42) + .with_correlation_id(123) + .with_codec_state("seq=5"); + + assert_eq!(ctx.connection_id, Some(42)); + assert_eq!(ctx.correlation_id, Some(123)); + assert_eq!(ctx.codec_state, Some("seq=5".to_string())); +} + +#[test] +fn context_with_peer_address() { + let addr: SocketAddr = "127.0.0.1:8080".parse().expect("valid test address"); + let ctx = CodecErrorContext::new().with_peer_address(addr); + assert_eq!(ctx.peer_address, Some(addr)); +} + +#[rstest] +fn default_recovery_policy_delegates_to_error( + default_hook: DefaultRecoveryPolicy, + context: CodecErrorContext, +) { + use crate::codec::error::{EofError, FramingError}; + + // Check various error types + let err = CodecError::Framing(FramingError::OversizedFrame { size: 100, max: 50 }); + assert_eq!( + default_hook.recovery_policy(&err, &context), + RecoveryPolicy::Drop + ); + + let err = CodecError::Io(io::Error::other("test")); + assert_eq!( + default_hook.recovery_policy(&err, &context), + RecoveryPolicy::Disconnect + ); + + let err = CodecError::Eof(EofError::CleanClose); + assert_eq!( + default_hook.recovery_policy(&err, &context), + RecoveryPolicy::Disconnect + ); +} + +#[rstest] +fn default_quarantine_duration_is_30_seconds( + default_hook: DefaultRecoveryPolicy, + context: CodecErrorContext, +) { + let io_error = CodecError::Io(io::Error::other("test")); + + assert_eq!( + default_hook.quarantine_duration(&io_error, &context), + Duration::from_secs(30) + ); +} + +#[test] +fn recovery_config_builder() { + let config = RecoveryConfig::default() + .max_consecutive_drops(5) + .quarantine_duration(Duration::from_secs(60)) + .log_dropped_frames(false); + + assert_eq!(config.max_consecutive_drops, 5); + assert_eq!(config.quarantine_duration, Duration::from_secs(60)); + assert!(!config.log_dropped_frames); +} + +#[test] +fn recovery_config_defaults() { + let config = RecoveryConfig::default(); + + assert_eq!(config.max_consecutive_drops, 10); + assert_eq!(config.quarantine_duration, Duration::from_secs(30)); + assert!(config.log_dropped_frames); +} diff --git a/src/extractor/extractors.rs b/src/extractor/extractors.rs new file mode 100644 index 00000000..266e2efb --- /dev/null +++ b/src/extractor/extractors.rs @@ -0,0 +1,182 @@ +//! Built-in extractor implementations for message, streaming body, and connection metadata. + +use std::net::SocketAddr; + +use super::{ExtractError, FromMessageRequest, MessageRequest, Payload}; +use crate::message::Message as WireMessage; + +/// Extractor that deserializes the message payload into `T`. +#[derive(Debug, Clone)] +pub struct Message(T); + +impl Message { + /// Consumes the extractor and returns the inner deserialized message value. + #[must_use] + pub fn into_inner(self) -> T { self.0 } +} + +impl std::ops::Deref for Message { + type Target = T; + + /// Returns a reference to the inner value. + /// + /// This enables transparent access to the wrapped type via dereferencing. + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl FromMessageRequest for Message +where + T: WireMessage, +{ + type Error = ExtractError; + + /// Attempts to extract and deserialize a message of type `T` from the payload. + /// + /// Advances the payload by the number of bytes consumed during deserialization. + /// Returns an error if the payload cannot be decoded into the target type. + /// + /// # Returns + /// - `Ok(Self)`: The successfully extracted and deserialized message. + /// - `Err(ExtractError::InvalidPayload)`: If deserialization fails. + fn from_message_request( + _req: &MessageRequest, + payload: &mut Payload<'_>, + ) -> Result { + let (msg, consumed) = + T::from_bytes(payload.as_ref()).map_err(ExtractError::InvalidPayload)?; + payload.advance(consumed); + Ok(Self(msg)) + } +} + +/// Extractor providing streaming access to the request body. +/// +/// Unlike [`Payload`] which borrows buffered bytes, this extractor +/// takes ownership of a streaming body channel. Handlers opting into +/// streaming receive chunks incrementally via a [`RequestBodyStream`]. +/// +/// This type is the inbound counterpart to [`crate::Response::Stream`]. +/// +/// # Examples +/// +/// ``` +/// use bytes::Bytes; +/// use tokio::io::AsyncReadExt; +/// use wireframe::{extractor::StreamingBody, request::body_channel}; +/// +/// # #[tokio::main] +/// # async fn main() { +/// let (tx, stream) = body_channel(4); +/// +/// tokio::spawn(async move { +/// let _ = tx.send(Ok(Bytes::from_static(b"payload"))).await; +/// }); +/// +/// let body = StreamingBody::new(stream); +/// let mut reader = body.into_reader(); +/// let mut buf = Vec::new(); +/// reader.read_to_end(&mut buf).await.expect("read body"); +/// assert_eq!(buf, b"payload"); +/// # } +/// ``` +/// +/// [`RequestBodyStream`]: crate::request::RequestBodyStream +pub struct StreamingBody { + stream: crate::request::RequestBodyStream, +} + +impl StreamingBody { + /// Create a streaming body from the given stream. + /// + /// Typically constructed by the framework when a handler opts into + /// streaming request consumption. + #[must_use] + pub fn new(stream: crate::request::RequestBodyStream) -> Self { Self { stream } } + + /// Consume the extractor and return the underlying stream. + /// + /// Use this when you need direct access to the stream for custom + /// processing with [`futures::StreamExt`] methods. + #[must_use] + pub fn into_stream(self) -> crate::request::RequestBodyStream { self.stream } + + /// Convert to an [`AsyncRead`] adapter. + /// + /// Protocol crates can use this to feed streaming bytes into parsers + /// that operate on readers rather than streams. + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + #[must_use] + pub fn into_reader(self) -> crate::request::RequestBodyReader { + crate::request::RequestBodyReader::new(self.stream) + } +} + +impl std::fmt::Debug for StreamingBody { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StreamingBody").finish_non_exhaustive() + } +} + +impl FromMessageRequest for StreamingBody { + type Error = ExtractError; + + /// Extract the streaming body from the request. + /// + /// # Errors + /// + /// Returns [`ExtractError::MissingBodyStream`] if: + /// - The request was not configured for streaming consumption + /// - The stream was already consumed by another extractor + fn from_message_request( + req: &MessageRequest, + _payload: &mut Payload<'_>, + ) -> Result { + req.take_body_stream() + .map(Self::new) + .ok_or(ExtractError::MissingBodyStream) + } +} + +/// Extractor providing peer connection metadata. +#[derive(Debug, Clone, Copy)] +pub struct ConnectionInfo { + peer_addr: Option, +} + +impl ConnectionInfo { + /// Returns the peer's socket address for the current connection, if available. + /// + /// # Examples + /// + /// ```rust + /// use std::net::SocketAddr; + /// + /// use wireframe::extractor::{ConnectionInfo, FromMessageRequest, MessageRequest, Payload}; + /// + /// let addr: SocketAddr = "127.0.0.1:8080".parse().expect("valid socket address"); + /// let req = MessageRequest::new().with_peer_addr(Some(addr)); + /// let info = ConnectionInfo::from_message_request(&req, &mut Payload::default()) + /// .expect("connection info extraction should succeed"); + /// assert_eq!(info.peer_addr(), Some(addr)); + /// ``` + #[must_use] + pub fn peer_addr(&self) -> Option { self.peer_addr } +} + +impl FromMessageRequest for ConnectionInfo { + type Error = std::convert::Infallible; + + /// Extracts connection metadata from the message request. + /// + /// Returns a `ConnectionInfo` containing the peer's socket address, if available. This + /// extraction is infallible. + fn from_message_request( + req: &MessageRequest, + _payload: &mut Payload<'_>, + ) -> Result { + Ok(Self { + peer_addr: req.peer_addr, + }) + } +} diff --git a/src/extractor.rs b/src/extractor/mod.rs similarity index 59% rename from src/extractor.rs rename to src/extractor/mod.rs index f507aefb..a1a158f5 100644 --- a/src/extractor.rs +++ b/src/extractor/mod.rs @@ -11,7 +11,13 @@ use std::{ sync::{Arc, Mutex}, }; -use crate::{message::Message as WireMessage, request::RequestBodyStream}; +use thiserror::Error; + +use crate::request::RequestBodyStream; + +mod extractors; + +pub use extractors::{ConnectionInfo, Message, StreamingBody}; /// Request context passed to extractors. /// @@ -50,7 +56,8 @@ impl MessageRequest { /// /// use wireframe::extractor::MessageRequest; /// - /// let req = MessageRequest::new().with_peer_addr(Some("127.0.0.1:8080".parse().unwrap())); + /// let addr: SocketAddr = "127.0.0.1:8080".parse().expect("valid socket address"); + /// let req = MessageRequest::new().with_peer_addr(Some(addr)); /// assert!(req.peer_addr.is_some()); /// ``` #[must_use] @@ -71,12 +78,14 @@ impl MessageRequest { /// extractor::{MessageRequest, SharedState}, /// }; /// - /// let _app = WireframeApp::new().unwrap().app_data(5u32); + /// let _app = WireframeApp::new() + /// .expect("failed to initialize app") + /// .app_data(5u32); /// // The framework populates the request with application data. /// # let mut req = MessageRequest::default(); /// # req.insert_state(5u32); /// let val: Option> = req.state(); - /// assert_eq!(*val.unwrap(), 5); + /// assert_eq!(*val.expect("state should be available"), 5); /// ``` #[must_use] pub fn state(&self) -> Option> @@ -101,7 +110,7 @@ impl MessageRequest { /// let mut req = MessageRequest::default(); /// req.insert_state(5u32); /// let val: Option> = req.state(); - /// assert_eq!(*val.unwrap(), 5); + /// assert_eq!(*val.expect("state should be available"), 5); /// ``` pub fn insert_state(&mut self, state: T) where @@ -272,49 +281,24 @@ impl From for SharedState { /// /// This enum is marked `#[non_exhaustive]` so more variants may be added in /// the future without breaking changes. -#[derive(Debug)] +#[derive(Debug, Error)] #[non_exhaustive] pub enum ExtractError { /// No shared state of the requested type was found. + #[error("no shared state registered for {0}")] MissingState(&'static str), /// Failed to decode the message payload. - InvalidPayload(bincode::error::DecodeError), + #[error("failed to decode payload: {0}")] + InvalidPayload(#[source] bincode::error::DecodeError), /// No streaming body was available for this request. /// /// This occurs when: /// - The request was not configured for streaming consumption /// - The stream was already consumed by another extractor + #[error("no streaming body available for this request")] MissingBodyStream, } -impl std::fmt::Display for ExtractError { - /// Formats the `ExtractError` for display purposes. - /// - /// Displays a descriptive message for missing shared state or payload decoding errors. - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::MissingState(ty) => write!(f, "no shared state registered for {ty}"), - Self::InvalidPayload(e) => write!(f, "failed to decode payload: {e}"), - Self::MissingBodyStream => { - write!(f, "no streaming body available for this request") - } - } - } -} - -impl std::error::Error for ExtractError { - /// Returns the underlying error if this is an `InvalidPayload` variant. - /// - /// # Returns - /// An optional reference to the underlying decode error, or `None` if not applicable. - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Self::InvalidPayload(e) => Some(e), - _ => None, - } - } -} - impl FromMessageRequest for SharedState where T: Send + Sync + 'static, @@ -350,177 +334,3 @@ impl std::ops::Deref for SharedState { /// ``` fn deref(&self) -> &Self::Target { &self.0 } } - -/// Extractor that deserializes the message payload into `T`. -#[derive(Debug, Clone)] -pub struct Message(T); - -impl Message { - /// Consumes the extractor and returns the inner deserialised message value. - #[must_use] - pub fn into_inner(self) -> T { self.0 } -} - -impl std::ops::Deref for Message { - type Target = T; - - /// Returns a reference to the inner value. - /// - /// This enables transparent access to the wrapped type via dereferencing. - fn deref(&self) -> &Self::Target { &self.0 } -} - -impl FromMessageRequest for Message -where - T: WireMessage, -{ - type Error = ExtractError; - - /// Attempts to extract and deserialize a message of type `T` from the payload. - /// - /// Advances the payload by the number of bytes consumed during deserialization. - /// Returns an error if the payload cannot be decoded into the target type. - /// - /// # Returns - /// - `Ok(Self)`: The successfully extracted and deserialized message. - /// - `Err(ExtractError::InvalidPayload)`: If deserialization fails. - fn from_message_request( - _req: &MessageRequest, - payload: &mut Payload<'_>, - ) -> Result { - let (msg, consumed) = T::from_bytes(payload.data).map_err(ExtractError::InvalidPayload)?; - payload.advance(consumed); - Ok(Self(msg)) - } -} - -/// Extractor providing streaming access to the request body. -/// -/// Unlike [`Payload`] which borrows buffered bytes, this extractor -/// takes ownership of a streaming body channel. Handlers opting into -/// streaming receive chunks incrementally via a [`RequestBodyStream`]. -/// -/// This type is the inbound counterpart to [`crate::Response::Stream`]. -/// -/// # Examples -/// -/// ``` -/// use bytes::Bytes; -/// use tokio::io::AsyncReadExt; -/// use wireframe::{extractor::StreamingBody, request::body_channel}; -/// -/// # #[tokio::main] -/// # async fn main() { -/// let (tx, stream) = body_channel(4); -/// -/// tokio::spawn(async move { -/// let _ = tx.send(Ok(Bytes::from_static(b"payload"))).await; -/// }); -/// -/// let body = StreamingBody::new(stream); -/// let mut reader = body.into_reader(); -/// let mut buf = Vec::new(); -/// reader.read_to_end(&mut buf).await.expect("read body"); -/// assert_eq!(buf, b"payload"); -/// # } -/// ``` -/// -/// [`RequestBodyStream`]: crate::request::RequestBodyStream -pub struct StreamingBody { - stream: crate::request::RequestBodyStream, -} - -impl StreamingBody { - /// Create a streaming body from the given stream. - /// - /// Typically constructed by the framework when a handler opts into - /// streaming request consumption. - #[must_use] - pub fn new(stream: crate::request::RequestBodyStream) -> Self { Self { stream } } - - /// Consume the extractor and return the underlying stream. - /// - /// Use this when you need direct access to the stream for custom - /// processing with [`futures::StreamExt`] methods. - #[must_use] - pub fn into_stream(self) -> crate::request::RequestBodyStream { self.stream } - - /// Convert to an [`AsyncRead`] adapter. - /// - /// Protocol crates can use this to feed streaming bytes into parsers - /// that operate on readers rather than streams. - /// - /// [`AsyncRead`]: tokio::io::AsyncRead - #[must_use] - pub fn into_reader(self) -> crate::request::RequestBodyReader { - crate::request::RequestBodyReader::new(self.stream) - } -} - -impl std::fmt::Debug for StreamingBody { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("StreamingBody").finish_non_exhaustive() - } -} - -impl FromMessageRequest for StreamingBody { - type Error = ExtractError; - - /// Extract the streaming body from the request. - /// - /// # Errors - /// - /// Returns [`ExtractError::MissingBodyStream`] if: - /// - The request was not configured for streaming consumption - /// - The stream was already consumed by another extractor - fn from_message_request( - req: &MessageRequest, - _payload: &mut Payload<'_>, - ) -> Result { - req.take_body_stream() - .map(Self::new) - .ok_or(ExtractError::MissingBodyStream) - } -} - -/// Extractor providing peer connection metadata. -#[derive(Debug, Clone, Copy)] -pub struct ConnectionInfo { - peer_addr: Option, -} - -impl ConnectionInfo { - /// Returns the peer's socket address for the current connection, if available. - /// - /// # Examples - /// - /// ```rust - /// use std::net::SocketAddr; - /// - /// use wireframe::extractor::{ConnectionInfo, FromMessageRequest, MessageRequest, Payload}; - /// - /// let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap(); - /// let req = MessageRequest::new().with_peer_addr(Some(addr)); - /// let info = ConnectionInfo::from_message_request(&req, &mut Payload::default()).unwrap(); - /// assert_eq!(info.peer_addr(), Some(addr)); - /// ``` - #[must_use] - pub fn peer_addr(&self) -> Option { self.peer_addr } -} - -impl FromMessageRequest for ConnectionInfo { - type Error = std::convert::Infallible; - - /// Extracts connection metadata from the message request. - /// - /// Returns a `ConnectionInfo` containing the peer's socket address, if available. This - /// extraction is infallible. - fn from_message_request( - req: &MessageRequest, - _payload: &mut Payload<'_>, - ) -> Result { - Ok(Self { - peer_addr: req.peer_addr, - }) - } -} diff --git a/src/fragment/tests.rs b/src/fragment/tests.rs index c0283c47..1338aff1 100644 --- a/src/fragment/tests.rs +++ b/src/fragment/tests.rs @@ -1,3 +1,9 @@ +//! Unit tests for the fragmentation and reassembly subsystem. +//! +//! Covers `FragmentHeader` field access, `FragmentSeries` ordering and +//! validation, `Fragmenter` splitting and message ID management, and +//! `Reassembler` assembly with size limits and expiry handling. + use std::{ num::NonZeroUsize, time::{Duration, Instant}, diff --git a/src/hooks.rs b/src/hooks.rs index 86581544..4b7f1381 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -19,10 +19,19 @@ use crate::{ pub struct ConnectionContext; /// Trait encapsulating protocol-specific logic and callbacks. +/// +/// `WireframeProtocol` allows a custom `ProtocolError` type, but +/// [`crate::app::WireframeApp::with_protocol`] currently requires +/// `ProtocolError = ()` so the protocol can be stored behind dynamic dispatch +/// with a uniform interface. This constraint may be relaxed in a future +/// release. pub trait WireframeProtocol: Send + Sync + 'static { /// Frame type written to the socket. type Frame: FrameLike; /// Custom error type for protocol operations. + /// + /// When installed via [`crate::app::WireframeApp::with_protocol`], this + /// must currently be `()`. type ProtocolError; /// Called once when a new connection is established. The provided @@ -45,10 +54,10 @@ pub trait WireframeProtocol: Send + Sync + 'static { /// /// impl WireframeProtocol for MyProtocol { /// type Frame = Vec; - /// type ProtocolError = String; + /// type ProtocolError = (); /// - /// fn handle_error(&self, error: Self::ProtocolError, _ctx: &mut ConnectionContext) { - /// tracing::error!(error = %error, "protocol error"); + /// fn handle_error(&self, _error: Self::ProtocolError, _ctx: &mut ConnectionContext) { + /// tracing::error!("protocol error"); /// // Custom handling here /// } /// } @@ -81,7 +90,7 @@ pub trait WireframeProtocol: Send + Sync + 'static { /// /// impl WireframeProtocol for MyProtocol { /// type Frame = Vec; - /// type ProtocolError = String; + /// type ProtocolError = (); /// /// fn on_eof(&self, error: &EofError, partial_data: &[u8], _ctx: &mut ConnectionContext) { /// match error { diff --git a/src/message.rs b/src/message.rs index 0e055b8e..7f0b8931 100644 --- a/src/message.rs +++ b/src/message.rs @@ -31,10 +31,10 @@ pub trait Message: Encode + for<'de> BorrowDecode<'de, ()> { /// /// # Errors /// - /// Deserialises a message instance from a byte slice using the standard configuration. + /// Deserializes a message instance from a byte slice using the standard configuration. /// - /// Returns the deserialised message and the number of bytes consumed, or a [`DecodeError`] if - /// deserialisation fails. + /// Returns the deserialized message and the number of bytes consumed, or a [`DecodeError`] if + /// deserialization fails. /// /// # Examples /// diff --git a/src/server/config/mod.rs b/src/server/config/mod.rs index ed2ef86b..31480b75 100644 --- a/src/server/config/mod.rs +++ b/src/server/config/mod.rs @@ -49,6 +49,10 @@ macro_rules! builder_callback { pub mod binding; pub mod preamble; +fn default_worker_count() -> usize { + std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get) +} + impl WireframeServer where F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, @@ -75,7 +79,7 @@ where /// ``` #[must_use] pub fn new(factory: F) -> Self { - let workers = std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get); + let workers = default_worker_count(); Self { factory, workers, diff --git a/src/server/config/tests.rs b/src/server/config/tests.rs index 5dfbbdee..5e32879e 100644 --- a/src/server/config/tests.rs +++ b/src/server/config/tests.rs @@ -15,7 +15,7 @@ use std::{ }; use bincode::error::DecodeError; -use rstest::rstest; +use rstest::{fixture, rstest}; use tokio::net::{TcpListener, TcpStream}; use super::*; @@ -37,15 +37,36 @@ enum PreambleHandlerKind { Failure, } -fn expected_default_worker_count() -> usize { - // Mirror the default worker logic to keep tests aligned with `WireframeServer::new`. - std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get) +fn assert_local_addr_matches_listener( + server: WireframeServer, + expected: std::net::SocketAddr, +) where + F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, + T: crate::preamble::Preamble, + S: crate::server::ServerState, + Ser: crate::serializer::Serializer + Send + Sync, + Ctx: Send + 'static, + E: crate::app::Packet, + Codec: crate::codec::FrameCodec, +{ + let local_addr = server.local_addr().expect("local address missing"); + assert_eq!(local_addr, expected); +} + +#[fixture] +async fn connected_streams() -> io::Result<(TcpStream, TcpStream)> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let client = TcpStream::connect(addr).await?; + let (server, _) = listener.accept().await?; + Ok((client, server)) } #[rstest] fn test_new_server_creation(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { let server = WireframeServer::new(factory); - assert!(server.worker_count() >= 1 && server.local_addr().is_none()); + assert!(server.worker_count() >= 1); + assert!(server.local_addr().is_none()); } #[rstest] @@ -53,15 +74,15 @@ fn test_new_server_default_worker_count( factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, ) { let server = WireframeServer::new(factory); - assert_eq!(server.worker_count(), expected_default_worker_count()); + assert_eq!(server.worker_count(), default_worker_count()); } #[rstest] fn test_workers_configuration(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let mut server = WireframeServer::new(factory); - server = server.workers(4); + let server = WireframeServer::new(factory); + let server = server.workers(4); assert_eq!(server.worker_count(), 4); - server = server.workers(100); + let server = server.workers(100); assert_eq!(server.worker_count(), 100); assert_eq!(server.workers(0).worker_count(), 1); } @@ -71,7 +92,7 @@ fn test_with_preamble_type_conversion( factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, ) { let server = WireframeServer::new(factory).with_preamble::(); - assert_eq!(server.worker_count(), expected_default_worker_count()); + assert_eq!(server.worker_count(), default_worker_count()); } #[rstest] @@ -93,12 +114,10 @@ async fn test_bind_success( free_listener: std::net::TcpListener, ) { let expected = listener_addr(&free_listener); - let local_addr = WireframeServer::new(factory) + let server = WireframeServer::new(factory) .bind_existing_listener(free_listener) - .expect("Failed to bind") - .local_addr() - .expect("local address missing"); - assert_eq!(local_addr, expected); + .expect("Failed to bind"); + assert_local_addr_matches_listener(server, expected); } #[rstest] @@ -113,10 +132,8 @@ async fn test_local_addr_after_bind( free_listener: std::net::TcpListener, ) { let expected = listener_addr(&free_listener); - let local_addr = bind_server(factory, free_listener) - .local_addr() - .expect("local address missing"); - assert_eq!(local_addr, expected); + let server = bind_server(factory, free_listener); + assert_local_addr_matches_listener(server, expected); } #[rstest] @@ -126,7 +143,8 @@ async fn test_local_addr_after_bind( async fn test_preamble_handler_registration( factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, #[case] handler: PreambleHandlerKind, -) { + connected_streams: io::Result<(TcpStream, TcpStream)>, +) -> io::Result<()> { let counter = Arc::new(AtomicUsize::new(0)); let c = counter.clone(); @@ -153,44 +171,25 @@ async fn test_preamble_handler_registration( assert_eq!(counter.load(Ordering::SeqCst), 0); match handler { PreambleHandlerKind::Success => { - assert!(server.on_preamble_success.is_some()); let handler = server .on_preamble_success .as_ref() .expect("success handler missing"); - let listener = TcpListener::bind("127.0.0.1:0") - .await - .expect("bind listener"); - let addr = listener.local_addr().expect("listener addr"); - let _client = TcpStream::connect(addr) - .await - .expect("client connect failed"); - let (mut stream, _) = listener.accept().await.expect("accept stream"); + let (_client, mut stream) = connected_streams?; let preamble = TestPreamble { id: 0, message: String::new() }; - handler(&preamble, &mut stream) - .await - .expect("handler failed"); + handler(&preamble, &mut stream).await?; } PreambleHandlerKind::Failure => { - assert!(server.on_preamble_failure.is_some()); let handler = server .on_preamble_failure .as_ref() .expect("failure handler missing"); - let listener = TcpListener::bind("127.0.0.1:0") - .await - .expect("bind listener"); - let addr = listener.local_addr().expect("listener addr"); - let _client = TcpStream::connect(addr) - .await - .expect("client connect failed"); - let (mut stream, _) = listener.accept().await.expect("accept stream"); - handler(&DecodeError::UnexpectedEnd, &mut stream) - .await - .expect("handler failed"); + let (_client, mut stream) = connected_streams?; + handler(&DecodeError::UnexpectedEnd, &mut stream).await?; } } assert_eq!(counter.load(Ordering::SeqCst), 1); + Ok(()) } #[rstest] @@ -235,8 +234,8 @@ async fn test_server_configuration_persistence( #[rstest] fn test_extreme_worker_counts(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let mut server = WireframeServer::new(factory); - server = server.workers(usize::MAX); + let server = WireframeServer::new(factory); + let server = server.workers(usize::MAX); assert_eq!(server.worker_count(), usize::MAX); assert_eq!(server.workers(0).worker_count(), 1); } @@ -264,16 +263,104 @@ async fn test_bind_to_multiple_addresses( assert_ne!(first.port(), second.port()); } +#[derive(Debug)] +struct BackoffScenario { + description: &'static str, + config: BackoffConfig, + expected_initial: Duration, + expected_max: Duration, +} + #[rstest] -fn test_accept_backoff_configuration( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let cfg = BackoffConfig { +#[case::accept_config(BackoffScenario { + description: "accepts explicit delays", + config: BackoffConfig { initial_delay: Duration::from_millis(5), max_delay: Duration::from_millis(500), - }; - let server = WireframeServer::new(factory).accept_backoff(cfg); - assert_eq!(server.backoff_config, cfg); + }, + expected_initial: Duration::from_millis(5), + expected_max: Duration::from_millis(500), +})] +#[case::accept_initial_delay(BackoffScenario { + description: "accepts initial delay with default max", + config: BackoffConfig { + initial_delay: Duration::from_millis(20), + ..BackoffConfig::default() + }, + expected_initial: Duration::from_millis(20), + expected_max: BackoffConfig::default().max_delay, +})] +#[case::accept_max_delay(BackoffScenario { + description: "accepts max delay with default initial", + config: BackoffConfig { + max_delay: Duration::from_millis(2000), + ..BackoffConfig::default() + }, + expected_initial: BackoffConfig::default().initial_delay, + expected_max: Duration::from_millis(2000), +})] +#[case::clamp_zero_initial(BackoffScenario { + description: "clamps zero initial delay", + config: BackoffConfig { + initial_delay: Duration::ZERO, + ..BackoffConfig::default() + }, + expected_initial: Duration::from_millis(1), + expected_max: BackoffConfig::default().max_delay, +})] +#[case::swap_initial_gt_max(BackoffScenario { + description: "swaps initial and max delays when inverted", + config: BackoffConfig { + initial_delay: Duration::from_millis(100), + max_delay: Duration::from_millis(50), + }, + expected_initial: Duration::from_millis(50), + expected_max: Duration::from_millis(100), +})] +#[case::swap_initial_over_default_max(BackoffScenario { + description: "swaps initial and max delays when initial exceeds default max", + config: BackoffConfig { + initial_delay: Duration::from_secs(2), + max_delay: Duration::from_secs(1), + }, + expected_initial: Duration::from_secs(1), + expected_max: Duration::from_secs(2), +})] +#[case::swap_small_values(BackoffScenario { + description: "swaps small inverted delays", + config: BackoffConfig { + initial_delay: Duration::from_millis(5), + max_delay: Duration::from_millis(1), + }, + expected_initial: Duration::from_millis(1), + expected_max: Duration::from_millis(5), +})] +#[case::clamp_zero_both(BackoffScenario { + description: "clamps zero initial and max delays", + config: BackoffConfig { + initial_delay: Duration::ZERO, + max_delay: Duration::ZERO, + }, + expected_initial: Duration::from_millis(1), + expected_max: Duration::from_millis(1), +})] +fn test_accept_backoff_scenarios( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + #[case] scenario: BackoffScenario, +) { + let server = WireframeServer::new(factory).accept_backoff(scenario.config); + assert_eq!( + server.backoff_config.initial_delay, + scenario.expected_initial, + "scenario: {}", + scenario.description + ); + assert_eq!( + server.backoff_config.max_delay, + scenario.expected_max, + "scenario: {}", + scenario.description + ); } /// Behaviour test verifying exponential delay doubling and capping. @@ -308,47 +395,6 @@ fn backoff_sequence(initial: Duration, max: Duration, attempts: usize) -> Vec WireframeApp + Send + Sync + Clone + 'static, -) { - let delay = Duration::from_millis(20); - let cfg = BackoffConfig { initial_delay: delay, ..BackoffConfig::default() }; - let server = WireframeServer::new(factory).accept_backoff(cfg); - assert_eq!(server.backoff_config.initial_delay, delay); -} - -#[rstest] -fn test_accept_max_delay_configuration( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let delay = Duration::from_millis(2000); - let cfg = BackoffConfig { max_delay: delay, ..BackoffConfig::default() }; - let server = WireframeServer::new(factory).accept_backoff(cfg); - assert_eq!(server.backoff_config.max_delay, delay); -} - -#[rstest] -fn test_backoff_validation(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let server = WireframeServer::new(factory.clone()) - .accept_backoff(BackoffConfig { initial_delay: Duration::ZERO, ..BackoffConfig::default() }); - assert_eq!( - server.backoff_config.initial_delay, - Duration::from_millis(1) - ); - - let server = WireframeServer::new(factory) - .accept_backoff(BackoffConfig { - initial_delay: Duration::from_millis(100), - max_delay: Duration::from_millis(50), - }); - assert_eq!( - server.backoff_config.initial_delay, - Duration::from_millis(50) - ); - assert_eq!(server.backoff_config.max_delay, Duration::from_millis(100)); -} - #[rstest] fn test_backoff_default_values(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { let server = WireframeServer::new(factory); @@ -358,41 +404,3 @@ fn test_backoff_default_values(factory: impl Fn() -> WireframeApp + Send + Sync ); assert_eq!(server.backoff_config.max_delay, Duration::from_secs(1)); } - -#[rstest] -fn test_initial_delay_exceeds_default_max( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let cfg = BackoffConfig { - initial_delay: Duration::from_secs(2), - max_delay: Duration::from_secs(1), - }; - let server = WireframeServer::new(factory).accept_backoff(cfg); - assert_eq!(server.backoff_config.initial_delay, Duration::from_secs(1)); - assert_eq!(server.backoff_config.max_delay, Duration::from_secs(2)); -} - -#[rstest] -fn test_accept_backoff_parameter_swapping( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let server = WireframeServer::new(factory.clone()).accept_backoff(BackoffConfig { - initial_delay: Duration::from_millis(5), - max_delay: Duration::from_millis(1), - }); - assert_eq!( - server.backoff_config.initial_delay, - Duration::from_millis(1) - ); - assert_eq!(server.backoff_config.max_delay, Duration::from_millis(5)); - - let server = WireframeServer::new(factory).accept_backoff(BackoffConfig { - initial_delay: Duration::ZERO, - max_delay: Duration::ZERO, - }); - assert_eq!( - server.backoff_config.initial_delay, - Duration::from_millis(1) - ); - assert_eq!(server.backoff_config.max_delay, Duration::from_millis(1)); -} diff --git a/src/test_helpers/frame_codec.rs b/src/test_helpers/frame_codec.rs index f598423f..b20d7e7d 100644 --- a/src/test_helpers/frame_codec.rs +++ b/src/test_helpers/frame_codec.rs @@ -13,7 +13,7 @@ use tokio_util::codec::{Decoder, Encoder}; use crate::codec::FrameCodec; -/// Test frame that wraps payloads with a distinctive tag byte. +/// Test frame used by `TestCodec`, wrapping payloads with a distinctive tag byte. #[derive(Clone, Debug)] pub struct TestFrame { /// Tag byte stored in the frame header.