diff --git a/src/app/builder.rs b/src/app/builder.rs deleted file mode 100644 index 42162c1b..00000000 --- a/src/app/builder.rs +++ /dev/null @@ -1,335 +0,0 @@ -//! Application builder configuring routes and middleware. -//! [`WireframeApp`] is an Actix-inspired builder for managing connection -//! state, routing, and middleware in a `WireframeServer`. It exposes -//! convenience methods to register handlers and lifecycle hooks, and -//! serializes messages using a configurable serializer. - -use std::{ - any::{Any, TypeId}, - collections::HashMap, - sync::Arc, -}; - -use tokio::sync::{OnceCell, mpsc}; - -use super::{ - builder_defaults::{ - DEFAULT_READ_TIMEOUT_MS, - MAX_READ_TIMEOUT_MS, - MIN_READ_TIMEOUT_MS, - default_fragmentation, - }, - envelope::{Envelope, Packet}, - error::{Result, WireframeError}, - lifecycle::{ConnectionSetup, ConnectionTeardown}, - middleware_types::{Handler, Middleware}, -}; -use crate::{ - codec::{FrameCodec, LengthDelimitedFrameCodec, clamp_frame_length}, - fragment::FragmentationConfig, - hooks::WireframeProtocol, - message_assembler::MessageAssembler, - middleware::HandlerService, - serializer::{BincodeSerializer, Serializer}, -}; - -/// Configures routing and middleware for a `WireframeServer`. -/// -/// The builder stores registered routes and middleware without enforcing an -/// ordering. Methods return [`Result`] so registrations can be chained -/// ergonomically. -pub struct WireframeApp< - S: Serializer + Send + Sync = BincodeSerializer, - C: Send + 'static = (), - E: Packet = Envelope, - F: FrameCodec = LengthDelimitedFrameCodec, -> { - pub(super) handlers: HashMap>, - pub(super) routes: OnceCell>>>, - pub(super) middleware: Vec>>, - pub(super) serializer: S, - pub(super) app_data: HashMap>, - pub(super) on_connect: Option>>, - pub(super) on_disconnect: Option>>, - pub(super) protocol: Option>>, - pub(super) push_dlq: Option>>, - pub(super) codec: F, - pub(super) read_timeout_ms: u64, - pub(super) fragmentation: Option, - pub(super) message_assembler: Option>, -} - -impl Default for WireframeApp -where - S: Serializer + Default + Send + Sync, - C: Send + 'static, - E: Packet, - F: FrameCodec + Default, -{ - /// Initializes empty routes, middleware, and application data with the - /// default serializer and no lifecycle hooks. - fn default() -> Self { - let codec = F::default(); - let max_frame_length = codec.max_frame_length(); - Self { - handlers: HashMap::new(), - routes: OnceCell::new(), - middleware: Vec::new(), - serializer: S::default(), - app_data: HashMap::new(), - on_connect: None, - on_disconnect: None, - protocol: None, - push_dlq: None, - codec, - read_timeout_ms: DEFAULT_READ_TIMEOUT_MS, - fragmentation: default_fragmentation(max_frame_length), - message_assembler: None, - } - } -} - -impl WireframeApp -where - S: Serializer + Default + Send + Sync, - C: Send + 'static, - E: Packet, - F: FrameCodec + Default, -{ - /// Construct a new empty application builder. - /// - /// # Errors - /// - /// This function currently never returns an error but uses [`Result`] for - /// forward compatibility. - /// - /// # Examples - /// - /// ``` - /// use wireframe::app::WireframeApp; - /// WireframeApp::<_, _, wireframe::app::Envelope>::new().expect("failed to initialize app"); - /// ``` - pub fn new() -> Result { Ok(Self::default()) } - - /// Construct a new application builder using the provided serializer. - /// - /// # Errors - /// - /// This function currently never returns an error but uses [`Result`] for - /// forward compatibility. - pub fn with_serializer(serializer: S) -> Result { - Ok(Self { - serializer, - ..Self::default() - }) - } -} - -impl WireframeApp -where - S: Serializer + Send + Sync, - C: Send + 'static, - E: Packet, - F: FrameCodec, -{ - /// Helper to rebuild the app when changing type parameters. - /// - /// This centralises the field-by-field reconstruction required when - /// transforming between different serializer or codec types. - #[expect( - clippy::too_many_arguments, - reason = "internal helper grouping fields for type-transitioning builders" - )] - fn rebuild_with_params( - self, - serializer: S2, - codec: F2, - protocol: Option>>, - fragmentation: Option, - message_assembler: Option>, - ) -> WireframeApp - where - S2: Serializer + Send + Sync, - F2: FrameCodec, - { - WireframeApp { - handlers: self.handlers, - routes: OnceCell::new(), - middleware: self.middleware, - serializer, - app_data: self.app_data, - on_connect: self.on_connect, - on_disconnect: self.on_disconnect, - protocol, - push_dlq: self.push_dlq, - codec, - read_timeout_ms: self.read_timeout_ms, - fragmentation, - message_assembler, - } - } - - /// Helper to rebuild the app when changing the connection state type. - pub(super) fn rebuild_with_connection_type( - self, - on_connect: Option>>, - on_disconnect: Option>>, - ) -> WireframeApp - where - C2: Send + 'static, - { - WireframeApp { - handlers: self.handlers, - routes: OnceCell::new(), - middleware: self.middleware, - serializer: self.serializer, - app_data: self.app_data, - on_connect, - on_disconnect, - protocol: self.protocol, - push_dlq: self.push_dlq, - codec: self.codec, - read_timeout_ms: self.read_timeout_ms, - fragmentation: self.fragmentation, - message_assembler: self.message_assembler, - } - } - - /// Replace the frame codec used for framing I/O. - /// - /// This resets any installed protocol hooks because the frame type may - /// change across codecs. Fragmentation configuration is reset to the - /// codec-derived default. - #[must_use] - pub fn with_codec(mut self, codec: F2) -> WireframeApp - where - S: Default, - { - let fragmentation = default_fragmentation(codec.max_frame_length()); - let serializer = std::mem::take(&mut self.serializer); - let message_assembler = self.message_assembler.take(); - self.rebuild_with_params(serializer, codec, None, fragmentation, message_assembler) - } - - /// Register a route that maps `id` to `handler`. - /// - /// # Errors - /// - /// Returns [`WireframeError::DuplicateRoute`] if a handler for `id` - /// has already been registered. - pub fn route(mut self, id: u32, handler: Handler) -> Result { - if self.handlers.contains_key(&id) { - return Err(WireframeError::DuplicateRoute(id)); - } - self.handlers.insert(id, handler); - self.routes = OnceCell::new(); - Ok(self) - } - - /// Store a shared state value accessible to request extractors. - /// - /// The value can later be retrieved using [`crate::extractor::SharedState`]. Registering - /// another value of the same type overwrites the previous one. - #[must_use] - pub fn app_data(mut self, state: T) -> Self - where - T: Send + Sync + 'static, - { - self.app_data.insert( - TypeId::of::(), - Arc::new(state) as Arc, - ); - self - } - - /// Add a middleware component to the processing pipeline. - /// - /// # Errors - /// - /// This function currently always succeeds. - pub fn wrap(mut self, mw: M) -> Result - where - M: Middleware + 'static, - { - self.middleware.push(Box::new(mw)); - self.routes = OnceCell::new(); - Ok(self) - } - - /// Configure a Dead Letter Queue for dropped push frames. - /// - /// ```rust,no_run - /// use tokio::sync::mpsc; - /// use wireframe::app::WireframeApp; - /// - /// # fn build() -> WireframeApp { - /// # WireframeApp::new().expect("builder creation should not fail") - /// # } - /// # fn main() { - /// let (tx, _rx) = mpsc::channel(16); - /// let app = build().with_push_dlq(tx); - /// # let _ = app; - /// # } - /// ``` - #[must_use] - pub fn with_push_dlq(self, dlq: mpsc::Sender>) -> Self { - WireframeApp { - push_dlq: Some(dlq), - ..self - } - } - - /// Replace the serializer used for messages. - #[must_use] - pub fn serializer(mut self, serializer: Ser) -> WireframeApp - where - Ser: Serializer + Send + Sync, - F: Default, - { - let codec = std::mem::take(&mut self.codec); - let protocol = self.protocol.take(); - let fragmentation = self.fragmentation.take(); - let message_assembler = self.message_assembler.take(); - self.rebuild_with_params( - serializer, - codec, - protocol, - fragmentation, - message_assembler, - ) - } - - /// Configure the read timeout in milliseconds. - /// Clamped between 1 and 86 400 000 milliseconds (24 h). - #[must_use] - pub fn read_timeout_ms(mut self, timeout_ms: u64) -> Self { - self.read_timeout_ms = timeout_ms.clamp(MIN_READ_TIMEOUT_MS, MAX_READ_TIMEOUT_MS); - self - } - - /// Override the fragmentation configuration. - /// - /// Provide `None` to disable fragmentation entirely. - #[must_use] - pub fn fragmentation(mut self, config: Option) -> Self { - self.fragmentation = config; - self - } -} - -impl WireframeApp -where - S: Serializer + Send + Sync, - C: Send + 'static, - E: Packet, -{ - /// Set the initial buffer capacity for framed reads. - /// Clamped between 64 bytes and 16 MiB. - #[must_use] - pub fn buffer_capacity(mut self, capacity: usize) -> Self { - let capacity = clamp_frame_length(capacity); - self.codec = LengthDelimitedFrameCodec::new(capacity); - self.fragmentation = default_fragmentation(capacity); - self - } -} diff --git a/src/app/builder/codec.rs b/src/app/builder/codec.rs new file mode 100644 index 00000000..934c65dd --- /dev/null +++ b/src/app/builder/codec.rs @@ -0,0 +1,69 @@ +//! Codec and serializer configuration for `WireframeApp`. + +use super::WireframeApp; +use crate::{ + app::{Packet, builder_defaults::default_fragmentation}, + codec::{FrameCodec, LengthDelimitedFrameCodec, clamp_frame_length}, + serializer::Serializer, +}; + +impl WireframeApp +where + S: Serializer + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec, +{ + /// Replace the frame codec used for framing I/O. + /// + /// This resets any installed protocol hooks because the frame type may + /// change across codecs. Fragmentation configuration is reset to the + /// codec-derived default. + #[must_use] + pub fn with_codec(mut self, codec: F2) -> WireframeApp + where + S: Default, + { + let fragmentation = default_fragmentation(codec.max_frame_length()); + let serializer = std::mem::take(&mut self.serializer); + let message_assembler = self.message_assembler.take(); + self.rebuild_with_params(serializer, codec, None, fragmentation, message_assembler) + } + + /// Replace the serializer used for messages. + #[must_use] + pub fn serializer(mut self, serializer: Ser) -> WireframeApp + where + Ser: Serializer + Send + Sync, + F: Default, + { + let codec = std::mem::take(&mut self.codec); + let protocol = self.protocol.take(); + let fragmentation = self.fragmentation.take(); + let message_assembler = self.message_assembler.take(); + self.rebuild_with_params( + serializer, + codec, + protocol, + fragmentation, + message_assembler, + ) + } +} + +impl WireframeApp +where + S: Serializer + Send + Sync, + C: Send + 'static, + E: Packet, +{ + /// Set the initial buffer capacity for framed reads. + /// Clamped between 64 bytes and 16 MiB. + #[must_use] + pub fn buffer_capacity(mut self, capacity: usize) -> Self { + let capacity = clamp_frame_length(capacity); + self.codec = LengthDelimitedFrameCodec::new(capacity); + self.fragmentation = default_fragmentation(capacity); + self + } +} diff --git a/src/app/builder/config.rs b/src/app/builder/config.rs new file mode 100644 index 00000000..dac658d9 --- /dev/null +++ b/src/app/builder/config.rs @@ -0,0 +1,62 @@ +//! General configuration methods for `WireframeApp`. + +use tokio::sync::mpsc; + +use super::WireframeApp; +use crate::{ + app::{ + Packet, + builder_defaults::{MAX_READ_TIMEOUT_MS, MIN_READ_TIMEOUT_MS}, + }, + codec::FrameCodec, + fragment::FragmentationConfig, + serializer::Serializer, +}; + +impl WireframeApp +where + S: Serializer + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec, +{ + /// Configure the read timeout in milliseconds. + /// Clamped between 1 and 86,400,000 milliseconds (24 h). + #[must_use] + pub fn read_timeout_ms(mut self, timeout_ms: u64) -> Self { + self.read_timeout_ms = timeout_ms.clamp(MIN_READ_TIMEOUT_MS, MAX_READ_TIMEOUT_MS); + self + } + + /// Override the fragmentation configuration. + /// + /// Provide `None` to disable fragmentation entirely. + #[must_use] + pub fn fragmentation(mut self, config: Option) -> Self { + self.fragmentation = config; + self + } + + /// Configure a Dead Letter Queue for dropped push frames. + /// + /// ```rust,no_run + /// use tokio::sync::mpsc; + /// use wireframe::app::WireframeApp; + /// + /// # fn build() -> WireframeApp { + /// # WireframeApp::new().expect("builder creation should not fail") + /// # } + /// # fn main() { + /// let (tx, _rx) = mpsc::channel(16); + /// let app = build().with_push_dlq(tx); + /// # let _ = app; + /// # } + /// ``` + #[must_use] + pub fn with_push_dlq(self, dlq: mpsc::Sender>) -> Self { + WireframeApp { + push_dlq: Some(dlq), + ..self + } + } +} diff --git a/src/app/builder/core.rs b/src/app/builder/core.rs new file mode 100644 index 00000000..8ebc6795 --- /dev/null +++ b/src/app/builder/core.rs @@ -0,0 +1,162 @@ +//! Core builder types for `WireframeApp`. + +use std::{ + any::{Any, TypeId}, + collections::HashMap, + sync::Arc, +}; + +use tokio::sync::{OnceCell, mpsc}; + +use crate::{ + app::{ + builder_defaults::{DEFAULT_READ_TIMEOUT_MS, default_fragmentation}, + envelope::{Envelope, Packet}, + error::Result, + lifecycle::{ConnectionSetup, ConnectionTeardown}, + middleware_types::{Handler, Middleware}, + }, + codec::{FrameCodec, LengthDelimitedFrameCodec}, + hooks::WireframeProtocol, + message_assembler::MessageAssembler, + middleware::HandlerService, + serializer::{BincodeSerializer, Serializer}, +}; + +/// Configures routing and middleware for a `WireframeServer`. +/// +/// The builder stores registered routes and middleware without enforcing an +/// ordering. Methods return [`Result`] so registrations can be chained +/// ergonomically. +pub struct WireframeApp< + S: Serializer + Send + Sync = BincodeSerializer, + C: Send + 'static = (), + E: Packet = Envelope, + F: FrameCodec = LengthDelimitedFrameCodec, +> { + pub(in crate::app) handlers: HashMap>, + pub(in crate::app) routes: OnceCell>>>, + pub(in crate::app) middleware: Vec>>, + pub(in crate::app) serializer: S, + pub(in crate::app) app_data: HashMap>, + pub(in crate::app) on_connect: Option>>, + pub(in crate::app) on_disconnect: Option>>, + pub(in crate::app) protocol: + Option>>, + pub(in crate::app) push_dlq: Option>>, + pub(in crate::app) codec: F, + pub(in crate::app) read_timeout_ms: u64, + pub(in crate::app) fragmentation: Option, + pub(in crate::app) message_assembler: Option>, +} + +impl Default for WireframeApp +where + S: Serializer + Default + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec + Default, +{ + /// Initializes empty routes, middleware, and application data with the + /// default serializer and no lifecycle hooks. + fn default() -> Self { + let codec = F::default(); + let max_frame_length = codec.max_frame_length(); + Self { + handlers: HashMap::new(), + routes: OnceCell::new(), + middleware: Vec::new(), + serializer: S::default(), + app_data: HashMap::new(), + on_connect: None, + on_disconnect: None, + protocol: None, + push_dlq: None, + codec, + read_timeout_ms: DEFAULT_READ_TIMEOUT_MS, + fragmentation: default_fragmentation(max_frame_length), + message_assembler: None, + } + } +} + +impl WireframeApp +where + S: Serializer + Default + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec + Default, +{ + /// Construct a new empty application builder. + /// + /// # Errors + /// + /// This function currently never returns an error but uses [`Result`] for + /// forward compatibility. + /// + /// # Examples + /// + /// ``` + /// use wireframe::app::WireframeApp; + /// WireframeApp::<_, _, wireframe::app::Envelope>::new().expect("failed to initialize app"); + /// ``` + pub fn new() -> Result { Ok(Self::default()) } + + /// Construct a new application builder using the provided serializer. + /// + /// # Errors + /// + /// This function currently never returns an error but uses [`Result`] for + /// forward compatibility. + pub fn with_serializer(serializer: S) -> Result { + Ok(Self { + serializer, + ..Self::default() + }) + } +} + +impl WireframeApp +where + S: Serializer + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec, +{ + /// Helper to rebuild the app when changing type parameters. + /// + /// This centralises the field-by-field reconstruction required when + /// transforming between different serializer or codec types. + #[expect( + clippy::too_many_arguments, + reason = "internal helper grouping fields for type-transitioning builders" + )] + pub(super) fn rebuild_with_params( + self, + serializer: S2, + codec: F2, + protocol: Option>>, + fragmentation: Option, + message_assembler: Option>, + ) -> WireframeApp + where + S2: Serializer + Send + Sync, + F2: FrameCodec, + { + WireframeApp { + handlers: self.handlers, + routes: OnceCell::new(), + middleware: self.middleware, + serializer, + app_data: self.app_data, + on_connect: self.on_connect, + on_disconnect: self.on_disconnect, + protocol, + push_dlq: self.push_dlq, + codec, + read_timeout_ms: self.read_timeout_ms, + fragmentation, + message_assembler, + } + } +} diff --git a/src/app/builder_lifecycle.rs b/src/app/builder/lifecycle.rs similarity index 65% rename from src/app/builder_lifecycle.rs rename to src/app/builder/lifecycle.rs index 33c2ca58..e460732d 100644 --- a/src/app/builder_lifecycle.rs +++ b/src/app/builder/lifecycle.rs @@ -1,9 +1,13 @@ -//! Connection lifecycle builder methods for [`WireframeApp`]. +//! Connection lifecycle hook configuration for `WireframeApp`. use std::{future::Future, sync::Arc}; -use super::{builder::WireframeApp, envelope::Packet, error::Result}; -use crate::{codec::FrameCodec, serializer::Serializer}; +use super::WireframeApp; +use crate::{ + app::{Packet, error::Result}, + codec::FrameCodec, + serializer::Serializer, +}; impl WireframeApp where @@ -21,8 +25,8 @@ where /// # Type Parameters /// /// This method changes the connection state type parameter from `C` to `C2`. - /// This means that any subsequent builder methods will operate on the new connection state - /// type `C2`. Be aware of this type transition when chaining builder methods. + /// This means that any subsequent builder methods will operate on the new connection state type + /// `C2`. Be aware of this type transition when chaining builder methods. /// /// # Errors /// @@ -37,7 +41,21 @@ where Fut: Future + Send + 'static, C2: Send + 'static, { - Ok(self.rebuild_with_connection_type(Some(Arc::new(move || Box::pin(f()))), None)) + Ok(WireframeApp { + handlers: self.handlers, + routes: tokio::sync::OnceCell::new(), + middleware: self.middleware, + serializer: self.serializer, + app_data: self.app_data, + on_connect: Some(Arc::new(move || Box::pin(f()))), + on_disconnect: None, + protocol: self.protocol, + push_dlq: self.push_dlq, + codec: self.codec, + read_timeout_ms: self.read_timeout_ms, + fragmentation: self.fragmentation, + message_assembler: self.message_assembler, + }) } /// Register a callback invoked when a connection is closed. diff --git a/src/app/builder/mod.rs b/src/app/builder/mod.rs new file mode 100644 index 00000000..afa80911 --- /dev/null +++ b/src/app/builder/mod.rs @@ -0,0 +1,15 @@ +//! Application builder configuring routes and middleware. +//! [`WireframeApp`] is an Actix-inspired builder for managing connection +//! state, routing, and middleware in a `WireframeServer`. It exposes +//! convenience methods to register handlers and lifecycle hooks, and +//! serializes messages using a configurable serializer. + +mod codec; +mod config; +mod core; +mod lifecycle; +mod protocol; +mod routing; +mod state; + +pub use core::WireframeApp; diff --git a/src/app/builder_protocol.rs b/src/app/builder/protocol.rs similarity index 97% rename from src/app/builder_protocol.rs rename to src/app/builder/protocol.rs index 19132f20..c303569f 100644 --- a/src/app/builder_protocol.rs +++ b/src/app/builder/protocol.rs @@ -1,9 +1,10 @@ -//! Protocol configuration builder methods for [`WireframeApp`]. +//! Protocol and message assembly configuration for `WireframeApp`. use std::sync::Arc; -use super::{builder::WireframeApp, envelope::Packet}; +use super::WireframeApp; use crate::{ + app::Packet, codec::FrameCodec, hooks::{ProtocolHooks, WireframeProtocol}, message_assembler::MessageAssembler, @@ -75,6 +76,27 @@ where } } + /// Get a clone of the configured protocol, if any. + /// + /// Returns `None` if no protocol was installed via [`with_protocol`](Self::with_protocol). + #[must_use] + pub fn protocol( + &self, + ) -> Option>> { + self.protocol.clone() + } + + /// Return protocol hooks derived from the installed protocol. + /// + /// If no protocol is installed, returns default (no-op) hooks. + #[must_use] + pub fn protocol_hooks(&self) -> ProtocolHooks { + self.protocol + .as_ref() + .map(ProtocolHooks::from_protocol) + .unwrap_or_default() + } + /// Get the configured message assembler, if any. /// /// # Examples @@ -105,25 +127,4 @@ where pub fn message_assembler(&self) -> Option<&Arc> { self.message_assembler.as_ref() } - - /// Get a clone of the configured protocol, if any. - /// - /// Returns `None` if no protocol was installed via [`with_protocol`](Self::with_protocol). - #[must_use] - pub fn protocol( - &self, - ) -> Option>> { - self.protocol.clone() - } - - /// Return protocol hooks derived from the installed protocol. - /// - /// If no protocol is installed, returns default (no-op) hooks. - #[must_use] - pub fn protocol_hooks(&self) -> ProtocolHooks { - self.protocol - .as_ref() - .map(ProtocolHooks::from_protocol) - .unwrap_or_default() - } } diff --git a/src/app/builder/routing.rs b/src/app/builder/routing.rs new file mode 100644 index 00000000..945b58b1 --- /dev/null +++ b/src/app/builder/routing.rs @@ -0,0 +1,49 @@ +//! Routing and middleware configuration for `WireframeApp`. + +use super::WireframeApp; +use crate::{ + app::{ + Packet, + error::{Result, WireframeError}, + middleware_types::{Handler, Middleware}, + }, + codec::FrameCodec, + serializer::Serializer, +}; + +impl WireframeApp +where + S: Serializer + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec, +{ + /// Register a route that maps `id` to `handler`. + /// + /// # Errors + /// + /// Returns [`WireframeError::DuplicateRoute`] if a handler for `id` + /// has already been registered. + pub fn route(mut self, id: u32, handler: Handler) -> Result { + if self.handlers.contains_key(&id) { + return Err(WireframeError::DuplicateRoute(id)); + } + self.handlers.insert(id, handler); + self.routes = tokio::sync::OnceCell::new(); + Ok(self) + } + + /// Add a middleware component to the processing pipeline. + /// + /// # Errors + /// + /// This function currently always succeeds. + pub fn wrap(mut self, mw: M) -> Result + where + M: Middleware + 'static, + { + self.middleware.push(Box::new(mw)); + self.routes = tokio::sync::OnceCell::new(); + Ok(self) + } +} diff --git a/src/app/builder/state.rs b/src/app/builder/state.rs new file mode 100644 index 00000000..904b31f8 --- /dev/null +++ b/src/app/builder/state.rs @@ -0,0 +1,30 @@ +//! Shared state configuration for `WireframeApp`. + +use std::{any::TypeId, sync::Arc}; + +use super::WireframeApp; +use crate::{app::Packet, codec::FrameCodec, serializer::Serializer}; + +impl WireframeApp +where + S: Serializer + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec, +{ + /// Store a shared state value accessible to request extractors. + /// + /// The value can later be retrieved using [`crate::extractor::SharedState`]. Registering + /// another value of the same type overwrites the previous one. + #[must_use] + pub fn app_data(mut self, state: T) -> Self + where + T: Send + Sync + 'static, + { + self.app_data.insert( + TypeId::of::(), + Arc::new(state) as Arc, + ); + self + } +} diff --git a/src/app/builder_defaults.rs b/src/app/builder_defaults.rs index fa539a29..b1c7c280 100644 --- a/src/app/builder_defaults.rs +++ b/src/app/builder_defaults.rs @@ -6,6 +6,7 @@ use crate::{codec::clamp_frame_length, fragment::FragmentationConfig}; pub(super) const MIN_READ_TIMEOUT_MS: u64 = 1; pub(super) const MAX_READ_TIMEOUT_MS: u64 = 86_400_000; +/// Default preamble read timeout in milliseconds. pub(super) const DEFAULT_READ_TIMEOUT_MS: u64 = 100; const DEFAULT_FRAGMENT_TIMEOUT: Duration = Duration::from_secs(30); const DEFAULT_MESSAGE_SIZE_MULTIPLIER: usize = 16; diff --git a/src/app/frame_handling.rs b/src/app/frame_handling.rs deleted file mode 100644 index 72ae1248..00000000 --- a/src/app/frame_handling.rs +++ /dev/null @@ -1,350 +0,0 @@ -//! Shared helpers for frame decoding, reassembly, and response forwarding. -//! -//! Extracted from `connection.rs` to keep modules small and focused. - -use std::io; - -use bytes::Bytes; -use futures::SinkExt; -use log::warn; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::codec::Framed; - -use super::{ - Envelope, - Packet, - PacketParts, - combined_codec::ConnectionCodec, - fragmentation_state::{FragmentProcessError, FragmentationState}, -}; -use crate::{ - codec::FrameCodec, - middleware::{HandlerService, Service, ServiceRequest}, - serializer::Serializer, -}; - -/// Tracks consecutive deserialization failures and enforces a per-connection limit. -/// -/// The counter increments on each failure; reaching `limit` terminates processing. -struct DeserFailureTracker<'a> { - count: &'a mut u32, - limit: u32, -} - -impl<'a> DeserFailureTracker<'a> { - fn new(count: &'a mut u32, limit: u32) -> Self { Self { count, limit } } - - fn record( - &mut self, - correlation_id: Option, - context: &str, - err: impl std::fmt::Debug, - ) -> io::Result<()> { - *self.count = self.count.saturating_add(1); - warn!("{context}: correlation_id={correlation_id:?}, error={err:?}"); - crate::metrics::inc_deser_errors(); - if *self.count >= self.limit { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "too many deserialization failures", - )); - } - Ok(()) - } -} - -/// Context for writing handler responses to the framed stream. -/// -/// Carries the serializer, codec, and mutable framing state for a connection. -pub(crate) struct ResponseContext<'a, S, W, F> -where - S: Serializer + Send + Sync, - W: AsyncRead + AsyncWrite + Unpin, - F: FrameCodec, -{ - pub(crate) serializer: &'a S, - pub(crate) framed: &'a mut Framed>, - pub(crate) fragmentation: &'a mut Option, - pub(crate) codec: &'a F, -} - -/// Attempt to reassemble a potentially fragmented envelope. -pub(crate) fn reassemble_if_needed( - fragmentation: &mut Option, - deser_failures: &mut u32, - env: Envelope, - max_deser_failures: u32, -) -> io::Result> { - let mut failures = DeserFailureTracker::new(deser_failures, max_deser_failures); - - if let Some(state) = fragmentation.as_mut() { - let correlation_id = env.correlation_id; - match state.reassemble(env) { - Ok(Some(env)) => Ok(Some(env)), - Ok(None) => Ok(None), - Err(FragmentProcessError::Decode(err)) => { - failures.record(correlation_id, "failed to decode fragment header", err)?; - Ok(None) - } - Err(FragmentProcessError::Reassembly(err)) => { - failures.record(correlation_id, "fragment reassembly failed", err)?; - Ok(None) - } - } - } else { - Ok(Some(env)) - } -} - -/// Forward a handler response, fragmenting if required, and write to the framed stream. -pub(crate) async fn forward_response( - env: Envelope, - service: &HandlerService, - ctx: ResponseContext<'_, S, W, F>, -) -> io::Result<()> -where - S: Serializer + Send + Sync, - E: Packet, - W: AsyncRead + AsyncWrite + Unpin, - F: FrameCodec, -{ - let request = ServiceRequest::new(env.payload, env.correlation_id); - let resp = match service.call(request).await { - Ok(resp) => resp, - Err(e) => { - warn!( - "handler error: id={id}, correlation_id={correlation_id:?}, error={error:?}", - id = env.id, - correlation_id = env.correlation_id, - error = e - ); - crate::metrics::inc_handler_errors(); - return Ok(()); - } - }; - - let parts = PacketParts::new(env.id, resp.correlation_id(), resp.into_inner()) - .inherit_correlation(env.correlation_id); - let correlation_id = parts.correlation_id(); - let Ok(responses) = fragment_responses(ctx.fragmentation, parts, env.id, correlation_id) else { - return Ok(()); // already logged - }; - - for response in responses { - let Ok(bytes) = serialize_response(ctx.serializer, &response, env.id, correlation_id) - else { - break; // already logged - }; - - let send_result = - send_response_payload::(ctx.codec, ctx.framed, Bytes::from(bytes), &response) - .await; - match send_result { - Ok(()) => {} - Err(err) if should_drop_response_send_error(&err) => break, // already logged - Err(err) => return Err(err), - } - } - - Ok(()) -} - -fn should_drop_response_send_error(error: &io::Error) -> bool { - matches!( - error.kind(), - io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData - ) -} - -fn fragment_responses( - fragmentation: &mut Option, - parts: PacketParts, - id: u32, - correlation_id: Option, -) -> io::Result> { - let envelope = Envelope::from_parts(parts); - match fragmentation.as_mut() { - Some(state) => match state.fragment(envelope) { - Ok(fragmented) => Ok(fragmented), - Err(err) => { - warn!( - concat!( - "failed to fragment response: id={id}, correlation_id={correlation_id:?}, ", - "error={err:?}" - ), - id = id, - correlation_id = correlation_id, - err = err - ); - crate::metrics::inc_handler_errors(); - Err(io::Error::other("fragmentation failed")) - } - }, - None => Ok(vec![envelope]), - } -} - -fn serialize_response( - serializer: &S, - response: &Envelope, - id: u32, - correlation_id: Option, -) -> io::Result> { - match serializer.serialize(response) { - Ok(bytes) => Ok(bytes), - Err(e) => { - warn!( - concat!( - "failed to serialize response: id={id}, correlation_id={correlation_id:?}, ", - "error={e:?}" - ), - id = id, - correlation_id = correlation_id, - e = e - ); - crate::metrics::inc_handler_errors(); - Err(io::Error::other("serialization failed")) - } - } -} - -/// Send a response payload over the framed stream using codec-aware wrapping. -/// -/// Wraps the raw payload bytes in the codec's native frame format via -/// [`FrameCodec::wrap_payload`] before writing to the underlying stream. -/// This ensures responses are encoded correctly for the configured protocol. -async fn send_response_payload( - codec: &F, - framed: &mut Framed>, - payload: Bytes, - response: &Envelope, -) -> io::Result<()> -where - W: AsyncRead + AsyncWrite + Unpin, - F: FrameCodec, -{ - let frame = codec.wrap_payload(payload); - if let Err(e) = framed.send(frame).await { - let id = response.id; - let correlation_id = response.correlation_id; - warn!("failed to send response: id={id}, correlation_id={correlation_id:?}, error={e:?}"); - crate::metrics::inc_handler_errors(); - return Err(e); - } - Ok(()) -} - -#[cfg(all(test, not(loom)))] -mod tests { - //! Tests for frame handling helpers and response sending. - - use bytes::Bytes; - use futures::StreamExt; - use rstest::{fixture, rstest}; - use tokio::io::DuplexStream; - - use super::*; - use crate::{ - app::combined_codec::CombinedCodec, - test_helpers::{TestAdapter, TestCodec}, - }; - - struct FramedHarness { - codec: TestCodec, - server_framed: Framed>, - client_framed: Framed>, - } - - fn build_harness(max_frame_length: usize) -> FramedHarness { - let codec = TestCodec::new(max_frame_length); - let client_codec = TestCodec::new(max_frame_length); - let (client, server) = tokio::io::duplex(256); - let server_codec = CombinedCodec::new(codec.decoder(), codec.encoder()); - let client_codec = CombinedCodec::new(client_codec.decoder(), client_codec.encoder()); - let server_framed = Framed::new(server, server_codec); - let client_framed = Framed::new(client, client_codec); - - FramedHarness { - codec, - server_framed, - client_framed, - } - } - - #[fixture] - fn harness() -> FramedHarness { - // Keep fixture setup explicit to avoid duplicated per-test harness creation. - build_harness(64) - } - - #[rstest] - #[case::ok(vec![1, 2, 3, 4], false)] - #[case::oversized(vec![0u8; 100], true)] - #[tokio::test] - async fn send_response_payload_behaviour( - #[case] payload: Vec, - #[case] expect_error: bool, - mut harness: FramedHarness, - ) { - let response = Envelope::new(1, Some(99), payload.clone()); - let result = send_response_payload::( - &harness.codec, - &mut harness.server_framed, - Bytes::from(payload.clone()), - &response, - ) - .await; - - if expect_error { - assert!( - result.is_err(), - "expected send to fail for oversized payload" - ); - assert_eq!( - result - .expect_err("oversized payload should return an error") - .kind(), - io::ErrorKind::InvalidInput - ); - return; - } - - result.expect("send should succeed"); - let frame = harness - .client_framed - .next() - .await - .expect("expected a frame") - .expect("decode should succeed"); - - assert_eq!(frame.tag, 0x42, "wrap_payload should set tag to 0x42"); - assert_eq!(frame.payload, payload, "payload should match"); - assert_eq!( - harness.codec.wraps(), - 1, - "wrap_payload should advance codec state" - ); - } - - /// Verify `ResponseContext` fields are accessible and usable. - #[rstest] - #[tokio::test] - async fn response_context_holds_references(mut harness: FramedHarness) { - use crate::serializer::BincodeSerializer; - - let serializer = BincodeSerializer; - let mut fragmentation: Option = None; - - let ctx: ResponseContext<'_, BincodeSerializer, _, TestCodec> = ResponseContext { - serializer: &serializer, - framed: &mut harness.server_framed, - fragmentation: &mut fragmentation, - codec: &harness.codec, - }; - - // Verify fields are accessible (compile-time check with runtime assertion) - assert!(ctx.fragmentation.is_none()); - } - - // Covered by `send_response_payload_behaviour` cases. -} diff --git a/src/app/frame_handling/core.rs b/src/app/frame_handling/core.rs new file mode 100644 index 00000000..2fff9b7e --- /dev/null +++ b/src/app/frame_handling/core.rs @@ -0,0 +1,54 @@ +//! Core frame handling context types. + +use std::io; + +use log::warn; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; + +use crate::{ + app::{combined_codec::ConnectionCodec, fragmentation_state::FragmentationState}, + codec::FrameCodec, + serializer::Serializer, +}; + +/// Tracks deserialization failures and enforces a maximum error threshold. +pub(super) struct DeserFailureTracker<'a> { + count: &'a mut u32, + limit: u32, +} + +impl<'a> DeserFailureTracker<'a> { + pub(super) fn new(count: &'a mut u32, limit: u32) -> Self { Self { count, limit } } + + pub(super) fn record( + &mut self, + correlation_id: Option, + context: &str, + err: impl std::fmt::Debug, + ) -> io::Result<()> { + *self.count = (*self.count).saturating_add(1); + warn!("{context}: correlation_id={correlation_id:?}, error={err:?}"); + crate::metrics::inc_deser_errors(); + if *self.count >= self.limit { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "too many deserialization failures", + )); + } + Ok(()) + } +} + +/// Bundles shared dependencies for response forwarding. +pub(crate) struct ResponseContext<'a, S, W, F> +where + S: Serializer + Send + Sync, + W: AsyncRead + AsyncWrite + Unpin, + F: FrameCodec, +{ + pub(crate) serializer: &'a S, + pub(crate) framed: &'a mut Framed>, + pub(crate) fragmentation: &'a mut Option, + pub(crate) codec: &'a F, +} diff --git a/src/app/frame_handling/mod.rs b/src/app/frame_handling/mod.rs new file mode 100644 index 00000000..9cce4b6f --- /dev/null +++ b/src/app/frame_handling/mod.rs @@ -0,0 +1,15 @@ +//! Shared helpers for frame decoding, reassembly, and response forwarding. +//! +//! Extracted from `connection.rs` to keep modules small and focused. + +mod core; +mod reassembly; +mod response; + +pub(crate) use core::ResponseContext; + +pub(crate) use reassembly::reassemble_if_needed; +pub(crate) use response::forward_response; + +#[cfg(all(test, not(loom)))] +mod tests; diff --git a/src/app/frame_handling/reassembly.rs b/src/app/frame_handling/reassembly.rs new file mode 100644 index 00000000..7d9abfb4 --- /dev/null +++ b/src/app/frame_handling/reassembly.rs @@ -0,0 +1,37 @@ +//! Helpers for fragment reassembly. + +use std::io; + +use super::core::DeserFailureTracker; +use crate::app::{ + Envelope, + fragmentation_state::{FragmentProcessError, FragmentationState}, +}; + +/// Attempt to reassemble a potentially fragmented envelope. +pub(crate) fn reassemble_if_needed( + fragmentation: &mut Option, + deser_failures: &mut u32, + env: Envelope, + max_deser_failures: u32, +) -> io::Result> { + let mut failures = DeserFailureTracker::new(deser_failures, max_deser_failures); + + if let Some(state) = fragmentation.as_mut() { + let correlation_id = env.correlation_id; + match state.reassemble(env) { + Ok(Some(env)) => Ok(Some(env)), + Ok(None) => Ok(None), + Err(FragmentProcessError::Decode(err)) => { + failures.record(correlation_id, "failed to decode fragment header", err)?; + Ok(None) + } + Err(FragmentProcessError::Reassembly(err)) => { + failures.record(correlation_id, "fragment reassembly failed", err)?; + Ok(None) + } + } + } else { + Ok(Some(env)) + } +} diff --git a/src/app/frame_handling/response.rs b/src/app/frame_handling/response.rs new file mode 100644 index 00000000..edcb3d33 --- /dev/null +++ b/src/app/frame_handling/response.rs @@ -0,0 +1,147 @@ +//! Response forwarding helpers for frame handling. + +use std::io; + +use bytes::Bytes; +use futures::SinkExt; +use log::warn; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; + +use crate::{ + app::{ + Envelope, + Packet, + PacketParts, + combined_codec::ConnectionCodec, + fragmentation_state::FragmentationState, + }, + codec::FrameCodec, + middleware::{HandlerService, Service, ServiceRequest}, + serializer::Serializer, +}; + +/// Forward a handler response, fragmenting if required, and write to the framed stream. +/// +/// `forward_response` accepts an [`Envelope`], builds a [`ServiceRequest`], and +/// invokes `service.call(request)`. If the handler returns `Err(e)`, this is +/// treated as an application-level failure: the error is logged, +/// [`crate::metrics::inc_handler_errors()`] is incremented, and the function +/// returns `Ok(())` (intentional log-and-continue behaviour). Transport-level +/// I/O failures (for example during fragmentation, serialization, or frame +/// send) still return `io::Error` and are propagated to the caller. +pub(crate) async fn forward_response( + env: Envelope, + service: &HandlerService, + ctx: super::ResponseContext<'_, S, W, F>, +) -> io::Result<()> +where + S: Serializer + Send + Sync, + E: Packet, + W: AsyncRead + AsyncWrite + Unpin, + F: FrameCodec, +{ + let request = ServiceRequest::new(env.payload, env.correlation_id); + let resp = match service.call(request).await { + Ok(resp) => resp, + Err(e) => { + warn!( + "handler error: id={id}, correlation_id={correlation_id:?}, error={e:?}", + id = env.id, + correlation_id = env.correlation_id + ); + crate::metrics::inc_handler_errors(); + return Ok(()); + } + }; + + let parts = PacketParts::new(env.id, resp.correlation_id(), resp.into_inner()) + .inherit_correlation(env.correlation_id); + let correlation_id = parts.correlation_id(); + let responses = fragment_responses(ctx.fragmentation, parts, env.id, correlation_id)?; + + for response in responses { + let bytes = serialize_response(ctx.serializer, &response, env.id, correlation_id)?; + send_response_payload::(ctx.codec, ctx.framed, Bytes::from(bytes), &response).await?; + } + + Ok(()) +} + +fn fragment_responses( + fragmentation: &mut Option, + parts: PacketParts, + id: u32, + correlation_id: Option, +) -> io::Result> { + let envelope = Envelope::from_parts(parts); + match fragmentation.as_mut() { + Some(state) => match state.fragment(envelope) { + Ok(fragmented) => Ok(fragmented), + Err(err) => { + warn!( + concat!( + "failed to fragment response: id={id}, correlation_id={correlation_id:?}, ", + "error={err:?}" + ), + id = id, + correlation_id = correlation_id, + err = err + ); + crate::metrics::inc_handler_errors(); + Err(io::Error::other(err)) + } + }, + None => Ok(vec![envelope]), + } +} + +fn serialize_response( + serializer: &S, + response: &Envelope, + id: u32, + correlation_id: Option, +) -> io::Result> { + match serializer.serialize(response) { + Ok(bytes) => Ok(bytes), + Err(e) => { + warn!( + concat!( + "failed to serialize response: id={id}, correlation_id={correlation_id:?}, ", + "error={e:?}" + ), + id = id, + correlation_id = correlation_id, + e = e + ); + crate::metrics::inc_handler_errors(); + Err(io::Error::other(e)) + } + } +} + +/// Send a response payload over the framed stream using codec-aware wrapping. +/// +/// Wraps the raw payload bytes in the codec's native frame format via +/// [`FrameCodec::wrap_payload`] before writing to the underlying stream. +/// This ensures responses are encoded correctly for the configured protocol. +pub(super) async fn send_response_payload( + codec: &F, + framed: &mut Framed>, + payload: Bytes, + response: &Envelope, +) -> io::Result<()> +where + W: AsyncRead + AsyncWrite + Unpin, + F: FrameCodec, +{ + let frame = codec.wrap_payload(payload); + if let Err(e) = framed.send(frame).await { + let id = response.id; + let correlation_id = response.correlation_id; + warn!("failed to send response: id={id}, correlation_id={correlation_id:?}, error={e:?}"); + crate::metrics::inc_handler_errors(); + return Err(io::Error::other(e)); + } + Ok(()) +} diff --git a/src/app/frame_handling/tests.rs b/src/app/frame_handling/tests.rs new file mode 100644 index 00000000..e693f68b --- /dev/null +++ b/src/app/frame_handling/tests.rs @@ -0,0 +1,249 @@ +//! Tests for frame handling helpers. + +use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, +}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures::StreamExt; +use rstest::{fixture, rstest}; +use tokio::io::DuplexStream; +use tokio_util::codec::{Decoder, Encoder}; + +use super::{ResponseContext, response::send_response_payload}; +use crate::{ + app::{Envelope, combined_codec::CombinedCodec, fragmentation_state::FragmentationState}, + codec::FrameCodec, +}; + +/// Test frame carrying a tag byte and payload. +#[derive(Clone, Debug)] +struct TestFrame { + tag: u8, + payload: Vec, +} + +/// Test codec that wraps payloads with a distinctive tag byte. +#[derive(Clone, Debug)] +struct TestCodec { + max_frame_length: usize, + counter: Arc, +} + +impl TestCodec { + fn new(max_frame_length: usize) -> Self { + Self { + max_frame_length, + counter: Arc::new(AtomicUsize::new(0)), + } + } + + fn wraps(&self) -> usize { self.counter.load(Ordering::SeqCst) } +} + +#[derive(Clone, Debug)] +struct TestAdapter { + max_frame_length: usize, +} + +impl TestAdapter { + fn new(max_frame_length: usize) -> Self { Self { max_frame_length } } +} + +impl Decoder for TestAdapter { + type Item = TestFrame; + type Error = std::io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + const HEADER_LEN: usize = 2; + if src.len() < HEADER_LEN { + return Ok(None); + } + + let mut header = src.as_ref(); + let tag = header.get_u8(); + let payload_len = header.get_u8() as usize; + if payload_len > self.max_frame_length { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "payload too large", + )); + } + if src.len() < HEADER_LEN + payload_len { + return Ok(None); + } + + let mut frame_bytes = src.split_to(HEADER_LEN + payload_len); + frame_bytes.advance(HEADER_LEN); + let payload = frame_bytes.to_vec(); + + Ok(Some(TestFrame { tag, payload })) + } +} + +impl Encoder for TestAdapter { + type Error = std::io::Error; + + fn encode(&mut self, item: TestFrame, dst: &mut BytesMut) -> Result<(), Self::Error> { + if item.payload.len() > self.max_frame_length { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "payload too large", + )); + } + + let payload_len = u8::try_from(item.payload.len()).map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidInput, "payload too long") + })?; + dst.reserve(2 + item.payload.len()); + dst.put_u8(item.tag); + dst.put_u8(payload_len); + dst.extend_from_slice(&item.payload); + Ok(()) + } +} + +impl FrameCodec for TestCodec { + type Frame = TestFrame; + type Decoder = TestAdapter; + type Encoder = TestAdapter; + + fn decoder(&self) -> Self::Decoder { TestAdapter::new(self.max_frame_length) } + + fn encoder(&self) -> Self::Encoder { TestAdapter::new(self.max_frame_length) } + + fn frame_payload(frame: &Self::Frame) -> &[u8] { frame.payload.as_slice() } + + /// Wraps payload with tag byte 0x42 to verify codec-aware wrapping. + fn wrap_payload(&self, payload: Bytes) -> Self::Frame { + self.counter.fetch_add(1, Ordering::SeqCst); + TestFrame { + tag: 0x42, + payload: payload.to_vec(), + } + } + + fn correlation_id(frame: &Self::Frame) -> Option { Some(u64::from(frame.tag)) } + + fn max_frame_length(&self) -> usize { self.max_frame_length } +} + +struct TestHarness { + codec: TestCodec, + framed: tokio_util::codec::Framed>, + client: DuplexStream, +} + +#[fixture] +fn harness() -> TestHarness { + let max_frame_length = 64; + build_harness(max_frame_length) +} + +#[fixture] +fn small_harness() -> TestHarness { + let max_frame_length = 4; + build_harness(max_frame_length) +} + +fn build_harness(max_frame_length: usize) -> TestHarness { + let codec = TestCodec::new(max_frame_length); + let (client, server) = tokio::io::duplex(256); + let combined = CombinedCodec::new(codec.decoder(), codec.encoder()); + let framed = tokio_util::codec::Framed::new(server, combined); + + TestHarness { + codec, + framed, + client, + } +} + +/// Verify `send_response_payload` uses `F::wrap_payload` to frame responses. +#[rstest] +#[tokio::test] +async fn send_response_payload_wraps_with_codec(harness: TestHarness) { + let TestHarness { + codec, + mut framed, + client, + } = harness; + + let payload = vec![1, 2, 3, 4]; + let response = Envelope::new(1, Some(99), payload.clone()); + send_response_payload::( + &codec, + &mut framed, + Bytes::from(payload.clone()), + &response, + ) + .await + .expect("send should succeed"); + + drop(framed); + + let combined_client = CombinedCodec::new(codec.decoder(), codec.encoder()); + let mut client_framed = tokio_util::codec::Framed::new(client, combined_client); + let frame = client_framed + .next() + .await + .expect("expected a frame") + .expect("decode should succeed"); + + assert_eq!(frame.tag, 0x42, "wrap_payload should set tag to 0x42"); + assert_eq!(frame.payload, payload, "payload should match"); + assert_eq!(codec.wraps(), 1, "wrap_payload should advance codec state"); +} + +/// Verify `ResponseContext` fields are accessible and usable. +#[rstest] +#[tokio::test] +async fn response_context_holds_references(harness: TestHarness) { + use crate::serializer::BincodeSerializer; + + let TestHarness { + codec, + mut framed, + client: _client, + } = harness; + let serializer = BincodeSerializer; + let mut fragmentation: Option = None; + + let ctx: ResponseContext<'_, BincodeSerializer, _, TestCodec> = ResponseContext { + serializer: &serializer, + framed: &mut framed, + fragmentation: &mut fragmentation, + codec: &codec, + }; + + // Verify fields are accessible (compile-time check with runtime assertion) + assert!(ctx.fragmentation.is_none()); +} + +/// Verify `send_response_payload` returns error on send failure. +#[rstest] +#[tokio::test] +async fn send_response_payload_returns_error_on_failure(small_harness: TestHarness) { + let TestHarness { + codec, + mut framed, + client: _client, + } = small_harness; + + // Payload exceeds max_frame_length, so encode will fail + let oversized_payload = vec![0u8; 100]; + let response = Envelope::new(1, Some(99), oversized_payload.clone()); + let result = send_response_payload::( + &codec, + &mut framed, + Bytes::from(oversized_payload), + &response, + ) + .await; + + assert!( + result.is_err(), + "expected send to fail for oversized payload" + ); +} diff --git a/src/app/mod.rs b/src/app/mod.rs index dd1f4544..74fdcff0 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -10,8 +10,6 @@ mod builder; mod builder_defaults; -mod builder_lifecycle; -mod builder_protocol; mod combined_codec; mod connection; mod envelope; diff --git a/src/client/builder.rs b/src/client/builder.rs deleted file mode 100644 index ae5f9ed1..00000000 --- a/src/client/builder.rs +++ /dev/null @@ -1,425 +0,0 @@ -//! Builder for configuring and connecting a wireframe client. - -use std::{ - future::Future, - net::SocketAddr, - sync::{Arc, atomic::AtomicU64}, -}; - -use bincode::Encode; -use tokio::net::TcpSocket; -use tokio_util::codec::Framed; - -use super::{ - ClientCodecConfig, - ClientError, - SocketOptions, - WireframeClient, - hooks::LifecycleHooks, - preamble_exchange::{PreambleConfig, perform_preamble_exchange}, -}; -use crate::{ - frame::LengthFormat, - rewind_stream::RewindStream, - serializer::{BincodeSerializer, Serializer}, -}; - -const INITIAL_READ_BUFFER_CAPACITY_LIMIT: usize = 64 * 1024; - -/// Reconstructs `WireframeClientBuilder` with one field updated to a new value. -/// -/// This macro reduces duplication in type-changing builder methods that need to -/// create a new builder instance with different generic parameters. When a type -/// parameter changes, struct update syntax (`..self`) cannot be used, so fields -/// must be copied explicitly. -/// -/// The `lifecycle_hooks` field requires special handling because `LifecycleHooks` -/// is parameterized by the connection state type. When changing `S` or `P`, the -/// hooks can be moved directly since `C` is unchanged. When changing `C` via -/// `on_connection_setup`, a new `LifecycleHooks` must be constructed. -macro_rules! builder_field_update { - // Serializer change: preserves P and C, moves lifecycle_hooks unchanged - ($self:expr,serializer = $value:expr) => { - WireframeClientBuilder { - serializer: $value, - codec_config: $self.codec_config, - socket_options: $self.socket_options, - preamble_config: $self.preamble_config, - lifecycle_hooks: $self.lifecycle_hooks, - } - }; - // Preamble change: preserves S and C, moves lifecycle_hooks unchanged - ($self:expr,preamble_config = $value:expr) => { - WireframeClientBuilder { - serializer: $self.serializer, - codec_config: $self.codec_config, - socket_options: $self.socket_options, - preamble_config: $value, - lifecycle_hooks: $self.lifecycle_hooks, - } - }; - // Lifecycle hooks change: preserves S and P, changes C - ($self:expr,lifecycle_hooks = $value:expr) => { - WireframeClientBuilder { - serializer: $self.serializer, - codec_config: $self.codec_config, - socket_options: $self.socket_options, - preamble_config: $self.preamble_config, - lifecycle_hooks: $value, - } - }; -} - -/// Builder for [`WireframeClient`]. -/// -/// The builder supports three generic type parameters: -/// - `S`: The serializer type (default: `BincodeSerializer`) -/// - `P`: The preamble type (default: `()`) -/// - `C`: The connection state type returned by the setup hook (default: `()`) -/// -/// # Examples -/// -/// ``` -/// use wireframe::client::WireframeClientBuilder; -/// -/// let builder = WireframeClientBuilder::new(); -/// let _ = builder; -/// ``` -pub struct WireframeClientBuilder { - pub(crate) serializer: S, - pub(crate) codec_config: ClientCodecConfig, - pub(crate) socket_options: SocketOptions, - pub(crate) preamble_config: Option>, - pub(crate) lifecycle_hooks: LifecycleHooks, -} - -impl WireframeClientBuilder { - /// Create a new builder with default settings. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::WireframeClientBuilder; - /// - /// let builder = WireframeClientBuilder::new(); - /// let _ = builder; - /// ``` - #[must_use] - pub fn new() -> Self { - Self { - serializer: BincodeSerializer, - codec_config: ClientCodecConfig::default(), - socket_options: SocketOptions::default(), - preamble_config: None, - lifecycle_hooks: LifecycleHooks::default(), - } - } -} - -impl Default for WireframeClientBuilder { - fn default() -> Self { Self::new() } -} - -impl WireframeClientBuilder -where - S: Serializer + Send + Sync, -{ - /// Replace the serializer used for encoding and decoding messages. - /// - /// # Examples - /// - /// ``` - /// use wireframe::{BincodeSerializer, client::WireframeClientBuilder}; - /// - /// let builder = WireframeClientBuilder::new().serializer(BincodeSerializer); - /// let _ = builder; - /// ``` - #[must_use] - pub fn serializer(self, serializer: Ser) -> WireframeClientBuilder - where - Ser: Serializer + Send + Sync, - { - builder_field_update!(self, serializer = serializer) - } - - /// Configure codec settings for the connection. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::{ClientCodecConfig, WireframeClientBuilder}; - /// - /// let codec = ClientCodecConfig::default().max_frame_length(2048); - /// let builder = WireframeClientBuilder::new().codec_config(codec); - /// let _ = builder; - /// ``` - #[must_use] - pub fn codec_config(mut self, codec_config: ClientCodecConfig) -> Self { - self.codec_config = codec_config; - self - } - - /// Configure the maximum frame length for the connection. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::WireframeClientBuilder; - /// - /// let builder = WireframeClientBuilder::new().max_frame_length(2048); - /// let _ = builder; - /// ``` - #[must_use] - pub fn max_frame_length(mut self, max_frame_length: usize) -> Self { - self.codec_config = self.codec_config.max_frame_length(max_frame_length); - self - } - - /// Configure the length prefix format for the connection. - /// - /// # Examples - /// - /// ``` - /// use wireframe::{client::WireframeClientBuilder, frame::LengthFormat}; - /// - /// let builder = WireframeClientBuilder::new().length_format(LengthFormat::u16_be()); - /// let _ = builder; - /// ``` - #[must_use] - pub fn length_format(mut self, length_format: LengthFormat) -> Self { - self.codec_config = self.codec_config.length_format(length_format); - self - } - - // Socket option convenience methods (nodelay, keepalive, linger, etc.) are - // in `socket_option_methods.rs` to keep this file under 400 lines. - - /// Configure a preamble to send before exchanging frames. - /// - /// The preamble is written to the server immediately after establishing - /// the TCP connection, before the framing layer begins. Use - /// [`on_preamble_success`](Self::on_preamble_success) to read the server's - /// response and [`preamble_timeout`](Self::preamble_timeout) to bound the - /// exchange. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::WireframeClientBuilder; - /// - /// #[derive(bincode::Encode)] - /// struct MyPreamble { - /// version: u16, - /// } - /// - /// let builder = WireframeClientBuilder::new().with_preamble(MyPreamble { version: 1 }); - /// let _ = builder; - /// ``` - #[must_use] - pub fn with_preamble(self, preamble: Q) -> WireframeClientBuilder - where - Q: Encode + Send + Sync + 'static, - { - builder_field_update!(self, preamble_config = Some(PreambleConfig::new(preamble))) - } - - /// Register a callback invoked when the connection is established. - /// - /// The callback can perform authentication or other setup tasks and returns - /// connection-specific state stored for the connection's lifetime. This - /// hook is invoked after the preamble exchange (if configured) succeeds. - /// - /// # Type Parameters - /// - /// This method changes the connection state type parameter from `C` to - /// `C2`. Subsequent builder methods will operate on the new connection - /// state type. Note that any previously configured `on_connection_teardown` - /// hook is cleared because its type signature depends on the old state - /// type. The `on_error` hook is preserved since it does not depend on - /// the connection state type. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::WireframeClientBuilder; - /// - /// struct Session { - /// id: u64, - /// } - /// - /// let builder = - /// WireframeClientBuilder::new().on_connection_setup(|| async { Session { id: 42 } }); - /// let _ = builder; - /// ``` - #[must_use] - pub fn on_connection_setup(self, f: F) -> WireframeClientBuilder - where - F: Fn() -> Fut + Send + Sync + 'static, - Fut: Future + Send + 'static, - C2: Send + 'static, - { - // Preserve on_error since it is not parameterized by C. - // on_disconnect must be cleared because its signature depends on C. - let on_error = self.lifecycle_hooks.on_error; - builder_field_update!( - self, - lifecycle_hooks = LifecycleHooks { - on_connect: Some(Arc::new(move || Box::pin(f()))), - on_disconnect: None, - on_error, - } - ) - } -} - -impl WireframeClientBuilder -where - S: Serializer + Send + Sync, - C: Send + 'static, -{ - /// Register a callback invoked when the connection is closed. - /// - /// The callback receives the connection state produced by - /// [`on_connection_setup`](Self::on_connection_setup). The teardown hook - /// is invoked when [`WireframeClient::close`](super::WireframeClient::close) - /// is called. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::WireframeClientBuilder; - /// - /// struct Session { - /// id: u64, - /// } - /// - /// let builder = WireframeClientBuilder::new() - /// .on_connection_setup(|| async { Session { id: 42 } }) - /// .on_connection_teardown(|session| async move { - /// println!("Session {} closed", session.id); - /// }); - /// let _ = builder; - /// ``` - #[must_use] - pub fn on_connection_teardown(mut self, f: F) -> Self - where - F: Fn(C) -> Fut + Send + Sync + 'static, - Fut: Future + Send + 'static, - { - self.lifecycle_hooks.on_disconnect = Some(Arc::new(move |c| Box::pin(f(c)))); - self - } - - /// Register a callback invoked when an error occurs. - /// - /// The callback receives a reference to the error and can perform logging - /// or recovery actions. The handler is invoked before the error is returned - /// to the caller. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::WireframeClientBuilder; - /// - /// let builder = WireframeClientBuilder::new().on_error(|err| { - /// let message = err.to_string(); - /// async move { - /// eprintln!("Client error: {message}"); - /// } - /// }); - /// let _ = builder; - /// ``` - #[must_use] - pub fn on_error(mut self, f: F) -> Self - where - F: for<'a> Fn(&'a ClientError) -> Fut + Send + Sync + 'static, - Fut: Future + Send + 'static, - { - self.lifecycle_hooks.on_error = Some(Arc::new(move |e| Box::pin(f(e)))); - self - } -} - -impl WireframeClientBuilder -where - S: Serializer + Send + Sync, - P: Encode + Send + Sync + 'static, - C: Send + 'static, -{ - /// Establish a connection and return a configured client. - /// - /// If a preamble is configured, it is written to the server before the - /// framing layer is established. The success callback (if registered) is - /// invoked after writing the preamble and may read the server's response. - /// If a connection setup hook is registered, it is invoked after the - /// preamble exchange succeeds. - /// - /// # Errors - /// - /// Returns [`ClientError`] if socket configuration, connection, or - /// preamble exchange fails. - /// - /// # Examples - /// - /// ```no_run - /// use std::net::SocketAddr; - /// - /// use wireframe::{ClientError, WireframeClient}; - /// - /// # #[tokio::main] - /// # async fn main() -> Result<(), ClientError> { - /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address"); - /// let _client = WireframeClient::builder().connect(addr).await?; - /// # Ok(()) - /// # } - /// ``` - pub async fn connect( - self, - addr: SocketAddr, - ) -> Result, C>, ClientError> { - let socket = if addr.is_ipv4() { - TcpSocket::new_v4()? - } else { - TcpSocket::new_v6()? - }; - self.socket_options.apply(&socket)?; - let mut stream = socket.connect(addr).await?; - - // Perform preamble exchange if configured. - let leftover = if let Some(config) = self.preamble_config { - perform_preamble_exchange(&mut stream, config).await? - } else { - Vec::new() - }; - - // Build framed codec, always wrapping in RewindStream for type consistency. - // When leftover is empty, RewindStream has negligible overhead. - let codec_config = self.codec_config; - let codec = codec_config.build_codec(); - let mut framed = Framed::new(RewindStream::new(leftover, stream), codec); - let initial_read_buffer_capacity = core::cmp::min( - INITIAL_READ_BUFFER_CAPACITY_LIMIT, - codec_config.max_frame_length_value(), - ); - framed - .read_buffer_mut() - .reserve(initial_read_buffer_capacity); - - // Invoke connection setup hook if configured. - let connection_state = if let Some(ref setup) = self.lifecycle_hooks.on_connect { - Some(setup().await) - } else { - None - }; - - Ok(WireframeClient { - framed, - serializer: self.serializer, - codec_config, - connection_state, - on_disconnect: self.lifecycle_hooks.on_disconnect, - on_error: self.lifecycle_hooks.on_error, - correlation_counter: AtomicU64::new(1), - }) - } -} diff --git a/src/client/builder/codec.rs b/src/client/builder/codec.rs new file mode 100644 index 00000000..d28523a8 --- /dev/null +++ b/src/client/builder/codec.rs @@ -0,0 +1,58 @@ +//! Codec configuration methods for `WireframeClientBuilder`. + +use super::WireframeClientBuilder; +use crate::{client::ClientCodecConfig, frame::LengthFormat, serializer::Serializer}; + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, +{ + /// Configure codec settings for the connection. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::{ClientCodecConfig, WireframeClientBuilder}; + /// + /// let codec = ClientCodecConfig::default().max_frame_length(2048); + /// let builder = WireframeClientBuilder::new().codec_config(codec); + /// let _ = builder; + /// ``` + #[must_use] + pub fn codec_config(mut self, codec_config: ClientCodecConfig) -> Self { + self.codec_config = codec_config; + self + } + + /// Configure the maximum frame length for the connection. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new().max_frame_length(2048); + /// let _ = builder; + /// ``` + #[must_use] + pub fn max_frame_length(mut self, max_frame_length: usize) -> Self { + self.codec_config = self.codec_config.max_frame_length(max_frame_length); + self + } + + /// Configure the length prefix format for the connection. + /// + /// # Examples + /// + /// ``` + /// use wireframe::{client::WireframeClientBuilder, frame::LengthFormat}; + /// + /// let builder = WireframeClientBuilder::new().length_format(LengthFormat::u16_be()); + /// let _ = builder; + /// ``` + #[must_use] + pub fn length_format(mut self, length_format: LengthFormat) -> Self { + self.codec_config = self.codec_config.length_format(length_format); + self + } +} diff --git a/src/client/builder/connect.rs b/src/client/builder/connect.rs new file mode 100644 index 00000000..f72ed4b8 --- /dev/null +++ b/src/client/builder/connect.rs @@ -0,0 +1,100 @@ +//! Connection establishment for `WireframeClientBuilder`. + +use std::{net::SocketAddr, sync::atomic::AtomicU64}; + +use bincode::Encode; +use tokio::net::TcpSocket; +use tokio_util::codec::Framed; + +use super::WireframeClientBuilder; +use crate::{ + client::{ClientError, WireframeClient, preamble_exchange::perform_preamble_exchange}, + rewind_stream::RewindStream, + serializer::Serializer, +}; + +const INITIAL_READ_BUFFER_CAPACITY: usize = 64 * 1024; + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, + P: Encode + Send + Sync + 'static, + C: Send + 'static, +{ + /// Establish a connection and return a configured client. + /// + /// If a preamble is configured, it is written to the server before the + /// framing layer is established. The success callback (if registered) is + /// invoked after writing the preamble and may read the server's response. + /// If a connection setup hook is registered, it is invoked after the + /// preamble exchange succeeds. + /// + /// # Errors + /// + /// Returns [`ClientError`] if socket configuration, connection, or + /// preamble exchange fails. + /// + /// # Examples + /// + /// ```no_run + /// use std::net::SocketAddr; + /// + /// use wireframe::{ClientError, WireframeClient}; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), ClientError> { + /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address"); + /// let _client = WireframeClient::builder().connect(addr).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn connect( + self, + addr: SocketAddr, + ) -> Result, C>, ClientError> { + let socket = if addr.is_ipv4() { + TcpSocket::new_v4()? + } else { + TcpSocket::new_v6()? + }; + self.socket_options.apply(&socket)?; + let mut stream = socket.connect(addr).await?; + + // Perform preamble exchange if configured. + let leftover = if let Some(config) = self.preamble_config { + perform_preamble_exchange(&mut stream, config).await? + } else { + Vec::new() + }; + + // Build framed codec, always wrapping in RewindStream for type consistency. + // When leftover is empty, RewindStream has negligible overhead. + let codec_config = self.codec_config; + let codec = codec_config.build_codec(); + let mut framed = Framed::new(RewindStream::new(leftover, stream), codec); + let initial_read_buffer_capacity = core::cmp::min( + INITIAL_READ_BUFFER_CAPACITY, + codec_config.max_frame_length_value(), + ); + framed + .read_buffer_mut() + .reserve(initial_read_buffer_capacity); + + // Invoke connection setup hook if configured. + let connection_state = if let Some(ref setup) = self.lifecycle_hooks.on_connect { + Some(setup().await) + } else { + None + }; + + Ok(WireframeClient { + framed, + serializer: self.serializer, + codec_config, + connection_state, + on_disconnect: self.lifecycle_hooks.on_disconnect, + on_error: self.lifecycle_hooks.on_error, + correlation_counter: AtomicU64::new(1), + }) + } +} diff --git a/src/client/builder/core.rs b/src/client/builder/core.rs new file mode 100644 index 00000000..13727aa8 --- /dev/null +++ b/src/client/builder/core.rs @@ -0,0 +1,61 @@ +//! Core wireframe client builder type. + +use crate::{ + client::{ + ClientCodecConfig, + SocketOptions, + hooks::LifecycleHooks, + preamble_exchange::PreambleConfig, + }, + serializer::BincodeSerializer, +}; + +/// Builder for [`WireframeClient`](crate::client::WireframeClient). +/// +/// The builder supports three generic type parameters: +/// - `S`: The serializer type (default: `BincodeSerializer`) +/// - `P`: The preamble type (default: `()`) +/// - `C`: The connection state type returned by the setup hook (default: `()`) +/// +/// # Examples +/// +/// ``` +/// use wireframe::client::WireframeClientBuilder; +/// +/// let builder = WireframeClientBuilder::new(); +/// let _ = builder; +/// ``` +pub struct WireframeClientBuilder { + pub(crate) serializer: S, + pub(crate) codec_config: ClientCodecConfig, + pub(crate) socket_options: SocketOptions, + pub(crate) preamble_config: Option>, + pub(crate) lifecycle_hooks: LifecycleHooks, +} + +impl WireframeClientBuilder { + /// Create a new builder with default settings. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new(); + /// let _ = builder; + /// ``` + #[must_use] + pub fn new() -> Self { + Self { + serializer: BincodeSerializer, + codec_config: ClientCodecConfig::default(), + socket_options: SocketOptions::default(), + preamble_config: None, + lifecycle_hooks: LifecycleHooks::default(), + } + } +} + +impl Default for WireframeClientBuilder { + fn default() -> Self { Self::new() } +} diff --git a/src/client/builder/lifecycle.rs b/src/client/builder/lifecycle.rs new file mode 100644 index 00000000..1ffcbe86 --- /dev/null +++ b/src/client/builder/lifecycle.rs @@ -0,0 +1,130 @@ +//! Lifecycle hook methods for `WireframeClientBuilder`. + +use std::{future::Future, sync::Arc}; + +use super::WireframeClientBuilder; +use crate::{ + client::{ClientError, hooks::LifecycleHooks}, + serializer::Serializer, +}; + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, +{ + /// Register a callback invoked when the connection is established. + /// + /// The callback can perform authentication or other setup tasks and returns + /// connection-specific state stored for the connection's lifetime. This + /// hook is invoked after the preamble exchange (if configured) succeeds. + /// + /// # Type Parameters + /// + /// This method changes the connection state type parameter from `C` to + /// `C2`. Subsequent builder methods will operate on the new connection + /// state type. Note that any previously configured `on_connection_teardown` + /// hook is cleared because its type signature depends on the old state + /// type. The `on_error` hook is preserved since it does not depend on + /// the connection state type. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// struct Session { + /// id: u64, + /// } + /// + /// let builder = + /// WireframeClientBuilder::new().on_connection_setup(|| async { Session { id: 42 } }); + /// let _ = builder; + /// ``` + #[must_use] + pub fn on_connection_setup(self, f: F) -> WireframeClientBuilder + where + F: Fn() -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + C2: Send + 'static, + { + // Preserve on_error since it is not parameterized by C. + // on_disconnect must be cleared because its signature depends on C. + let on_error = self.lifecycle_hooks.on_error; + builder_field_update!( + self, + lifecycle_hooks = LifecycleHooks { + on_connect: Some(Arc::new(move || Box::pin(f()))), + on_disconnect: None, + on_error, + } + ) + } +} + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, + C: Send + 'static, +{ + /// Register a callback invoked when the connection is closed. + /// + /// The callback receives the connection state produced by + /// [`on_connection_setup`](Self::on_connection_setup). The teardown hook + /// is invoked when [`WireframeClient::close`](crate::client::WireframeClient::close) + /// is called. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// struct Session { + /// id: u64, + /// } + /// + /// let builder = WireframeClientBuilder::new() + /// .on_connection_setup(|| async { Session { id: 42 } }) + /// .on_connection_teardown(|session| async move { + /// println!("Session {} closed", session.id); + /// }); + /// let _ = builder; + /// ``` + #[must_use] + pub fn on_connection_teardown(mut self, f: F) -> Self + where + F: Fn(C) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.lifecycle_hooks.on_disconnect = Some(Arc::new(move |c| Box::pin(f(c)))); + self + } + + /// Register a callback invoked when an error occurs. + /// + /// The callback receives a reference to the error and can perform logging + /// or recovery actions. The handler is invoked before the error is returned + /// to the caller. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new().on_error(|err| { + /// let message = err.to_string(); + /// async move { + /// eprintln!("Client error: {message}"); + /// } + /// }); + /// let _ = builder; + /// ``` + #[must_use] + pub fn on_error(mut self, f: F) -> Self + where + F: for<'a> Fn(&'a ClientError) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.lifecycle_hooks.on_error = Some(Arc::new(move |e| Box::pin(f(e)))); + self + } +} diff --git a/src/client/builder/mod.rs b/src/client/builder/mod.rs new file mode 100644 index 00000000..56ba8343 --- /dev/null +++ b/src/client/builder/mod.rs @@ -0,0 +1,54 @@ +//! Builder for configuring and connecting a wireframe client. + +/// Reconstructs `WireframeClientBuilder` with one field updated to a new value. +/// +/// This macro reduces duplication in type-changing builder methods that need to +/// create a new builder instance with different generic parameters. When a type +/// parameter changes, struct update syntax (`..self`) cannot be used, so fields +/// must be copied explicitly. +/// +/// The `lifecycle_hooks` field requires special handling because `LifecycleHooks` +/// is parameterized by the connection state type. When changing `S` or `P`, the +/// hooks can be moved directly since `C` is unchanged. When changing `C` via +/// `on_connection_setup`, a new `LifecycleHooks` must be constructed. +macro_rules! builder_field_update { + // Serializer change: preserves P and C, moves lifecycle_hooks unchanged + ($self:expr,serializer = $value:expr) => { + WireframeClientBuilder { + serializer: $value, + codec_config: $self.codec_config, + socket_options: $self.socket_options, + preamble_config: $self.preamble_config, + lifecycle_hooks: $self.lifecycle_hooks, + } + }; + // Preamble change: preserves S and C, moves lifecycle_hooks unchanged + ($self:expr,preamble_config = $value:expr) => { + WireframeClientBuilder { + serializer: $self.serializer, + codec_config: $self.codec_config, + socket_options: $self.socket_options, + preamble_config: $value, + lifecycle_hooks: $self.lifecycle_hooks, + } + }; + // Lifecycle hooks change: preserves S and P, changes C + ($self:expr,lifecycle_hooks = $value:expr) => { + WireframeClientBuilder { + serializer: $self.serializer, + codec_config: $self.codec_config, + socket_options: $self.socket_options, + preamble_config: $self.preamble_config, + lifecycle_hooks: $value, + } + }; +} + +mod codec; +mod connect; +mod core; +mod lifecycle; +mod preamble; +mod serializer; + +pub use core::WireframeClientBuilder; diff --git a/src/client/builder/preamble.rs b/src/client/builder/preamble.rs new file mode 100644 index 00000000..6fb57889 --- /dev/null +++ b/src/client/builder/preamble.rs @@ -0,0 +1,40 @@ +//! Preamble configuration methods for `WireframeClientBuilder`. + +use bincode::Encode; + +use super::WireframeClientBuilder; +use crate::{client::preamble_exchange::PreambleConfig, serializer::Serializer}; + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, +{ + /// Configure a preamble to send before exchanging frames. + /// + /// The preamble is written to the server immediately after establishing + /// the TCP connection, before the framing layer begins. Use + /// [`on_preamble_success`](Self::on_preamble_success) to read the server's + /// response and [`preamble_timeout`](Self::preamble_timeout) to bound the + /// exchange. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// #[derive(bincode::Encode)] + /// struct MyPreamble { + /// version: u16, + /// } + /// + /// let builder = WireframeClientBuilder::new().with_preamble(MyPreamble { version: 1 }); + /// let _ = builder; + /// ``` + #[must_use] + pub fn with_preamble(self, preamble: Q) -> WireframeClientBuilder + where + Q: Encode + Send + Sync + 'static, + { + builder_field_update!(self, preamble_config = Some(PreambleConfig::new(preamble))) + } +} diff --git a/src/client/builder/serializer.rs b/src/client/builder/serializer.rs new file mode 100644 index 00000000..be646f27 --- /dev/null +++ b/src/client/builder/serializer.rs @@ -0,0 +1,27 @@ +//! Serializer configuration methods for `WireframeClientBuilder`. + +use super::WireframeClientBuilder; +use crate::serializer::Serializer; + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, +{ + /// Replace the serializer used for encoding and decoding messages. + /// + /// # Examples + /// + /// ``` + /// use wireframe::{BincodeSerializer, client::WireframeClientBuilder}; + /// + /// let builder = WireframeClientBuilder::new().serializer(BincodeSerializer); + /// let _ = builder; + /// ``` + #[must_use] + pub fn serializer(self, serializer: Ser) -> WireframeClientBuilder + where + Ser: Serializer + Send + Sync, + { + builder_field_update!(self, serializer = serializer) + } +} diff --git a/src/codec/recovery.rs b/src/codec/recovery.rs deleted file mode 100644 index 040398ee..00000000 --- a/src/codec/recovery.rs +++ /dev/null @@ -1,331 +0,0 @@ -//! Recovery policy types and hooks for codec error handling. -//! -//! This module provides infrastructure for customising how codec errors are -//! handled, including recovery policies, error context for structured logging, -//! and hooks for application-specific error handling. -//! -//! # Recovery Policies -//! -//! When a codec error occurs, the framework applies a recovery policy: -//! -//! - [`RecoveryPolicy::Drop`]: Discard the malformed frame and continue processing. Suitable for -//! recoverable errors like oversized frames. -//! - [`RecoveryPolicy::Quarantine`]: Pause the connection temporarily before retrying. Useful for -//! rate-limiting misbehaving clients. -//! - [`RecoveryPolicy::Disconnect`]: Terminate the connection immediately. Required for -//! unrecoverable errors like I/O failures. -//! -//! # Custom Recovery Hooks -//! -//! Applications can customise error handling by implementing -//! [`RecoveryPolicyHook`]: -//! -//! ``` -//! use std::time::Duration; -//! -//! use wireframe::codec::{CodecError, CodecErrorContext, RecoveryPolicy, RecoveryPolicyHook}; -//! -//! struct StrictRecovery; -//! -//! impl RecoveryPolicyHook for StrictRecovery { -//! fn recovery_policy(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { -//! // Disconnect on any error -//! RecoveryPolicy::Disconnect -//! } -//! } -//! ``` - -use std::{net::SocketAddr, time::Duration}; - -use super::error::CodecError; - -/// Recovery policies for codec errors. -/// -/// Each policy defines how the framework responds to a codec error. -/// -/// # Default Behaviour -/// -/// [`CodecError::default_recovery_policy`] returns the recommended policy for -/// each error type. Applications can override this via [`RecoveryPolicyHook`]. -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] -pub enum RecoveryPolicy { - /// Discard the malformed frame and continue processing. - /// - /// This is the default for recoverable errors like oversized frames or - /// protocol violations that affect only a single message. The connection - /// remains open and can process subsequent frames. - /// - /// # When to Use - /// - /// - Oversized frames that exceed `max_frame_length` - /// - Empty frames where non-empty is expected - /// - Protocol-level errors (unknown message type, sequence violation) - #[default] - Drop, - - /// Pause the connection temporarily before retrying. - /// - /// The connection enters a quarantine state for a configurable duration. - /// During quarantine, no frames are processed. After the timeout, normal - /// processing resumes. - /// - /// # When to Use - /// - /// - Rate-limiting misbehaving clients - /// - Temporary back-off after repeated errors - /// - Giving time for upstream issues to resolve - Quarantine, - - /// Terminate the connection immediately. - /// - /// The connection is closed without processing further frames. This is - /// required for unrecoverable errors where the framing state is corrupted - /// or the transport has failed. - /// - /// # When to Use - /// - /// - I/O errors (socket closed, write failed) - /// - Invalid frame length encoding (framing state corrupted) - /// - EOF conditions (connection ending) - Disconnect, -} - -impl RecoveryPolicy { - /// Returns the policy name as a static string for metrics and logging. - /// - /// # Examples - /// - /// ``` - /// use wireframe::codec::RecoveryPolicy; - /// - /// assert_eq!(RecoveryPolicy::Drop.as_str(), "drop"); - /// assert_eq!(RecoveryPolicy::Quarantine.as_str(), "quarantine"); - /// assert_eq!(RecoveryPolicy::Disconnect.as_str(), "disconnect"); - /// ``` - #[must_use] - pub const fn as_str(self) -> &'static str { - match self { - Self::Drop => "drop", - Self::Quarantine => "quarantine", - Self::Disconnect => "disconnect", - } - } -} - -/// Structured context for codec error logging and diagnostics. -/// -/// This struct captures connection-specific information to include in -/// structured logs and metrics when codec errors occur. -/// -/// # Examples -/// -/// ``` -/// use wireframe::codec::CodecErrorContext; -/// -/// let ctx = CodecErrorContext::new() -/// .with_connection_id(42) -/// .with_correlation_id(123); -/// -/// assert_eq!(ctx.connection_id, Some(42)); -/// assert_eq!(ctx.correlation_id, Some(123)); -/// ``` -#[derive(Clone, Debug, Default)] -pub struct CodecErrorContext { - /// Unique identifier for the connection. - pub connection_id: Option, - - /// Remote peer address. - pub peer_address: Option, - - /// Correlation identifier from the frame, if available. - /// - /// This helps attribute errors to specific requests when debugging. - pub correlation_id: Option, - - /// Codec instance state for debugging. - /// - /// May include sequence numbers, bytes processed, or other codec-specific - /// state information. - pub codec_state: Option, -} - -impl CodecErrorContext { - /// Create a new empty context. - #[must_use] - pub fn new() -> Self { Self::default() } - - /// Set the connection identifier. - #[must_use] - pub fn with_connection_id(mut self, id: u64) -> Self { - self.connection_id = Some(id); - self - } - - /// Set the peer address. - #[must_use] - pub fn with_peer_address(mut self, addr: SocketAddr) -> Self { - self.peer_address = Some(addr); - self - } - - /// Set the correlation identifier. - #[must_use] - pub fn with_correlation_id(mut self, id: u64) -> Self { - self.correlation_id = Some(id); - self - } - - /// Set codec-specific state information. - #[must_use] - pub fn with_codec_state(mut self, state: impl Into) -> Self { - self.codec_state = Some(state.into()); - self - } -} - -/// Hook trait for customising codec error recovery behaviour. -/// -/// Implementations can override default recovery policies based on -/// application-specific requirements or connection state. -/// -/// # Default Implementation -/// -/// The default implementation ([`DefaultRecoveryPolicy`]) delegates to -/// [`CodecError::default_recovery_policy`] for all errors. -/// -/// # Examples -/// -/// ``` -/// use std::time::Duration; -/// -/// use wireframe::codec::{ -/// CodecError, -/// CodecErrorContext, -/// EofError, -/// RecoveryPolicy, -/// RecoveryPolicyHook, -/// }; -/// -/// /// Quarantine connections that close unexpectedly. -/// struct QuarantineOnPrematureEof; -/// -/// impl RecoveryPolicyHook for QuarantineOnPrematureEof { -/// fn recovery_policy(&self, error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { -/// match error { -/// CodecError::Eof(EofError::MidFrame { .. }) => RecoveryPolicy::Quarantine, -/// _ => error.default_recovery_policy(), -/// } -/// } -/// -/// fn quarantine_duration(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> Duration { -/// Duration::from_secs(60) -/// } -/// } -/// ``` -pub trait RecoveryPolicyHook: Send + Sync { - /// Determine the recovery policy for a codec error. - /// - /// The default implementation delegates to - /// [`CodecError::default_recovery_policy`]. - fn recovery_policy(&self, error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { - error.default_recovery_policy() - } - - /// Returns the quarantine duration when [`RecoveryPolicy::Quarantine`] is - /// selected. - /// - /// Default: 30 seconds. - fn quarantine_duration(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> Duration { - Duration::from_secs(30) - } - - /// Called before applying the recovery policy. - /// - /// Use this hook for logging, metrics emission, or state updates. The - /// default implementation does nothing. - fn on_error(&self, _error: &CodecError, _ctx: &CodecErrorContext, _policy: RecoveryPolicy) {} -} - -/// Default recovery policy implementation. -/// -/// This implementation uses the built-in default policies from -/// [`CodecError::default_recovery_policy`] without any customisation. -#[derive(Clone, Copy, Debug, Default)] -pub struct DefaultRecoveryPolicy; - -impl RecoveryPolicyHook for DefaultRecoveryPolicy {} - -/// Configuration for recovery policy behaviour. -/// -/// Use this to configure global recovery settings on the application builder. -/// -/// # Examples -/// -/// ``` -/// use std::time::Duration; -/// -/// use wireframe::codec::RecoveryConfig; -/// -/// let config = RecoveryConfig::default() -/// .max_consecutive_drops(5) -/// .quarantine_duration(Duration::from_secs(60)) -/// .log_dropped_frames(true); -/// -/// assert_eq!(config.max_consecutive_drops, 5); -/// ``` -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct RecoveryConfig { - /// Maximum consecutive dropped frames before escalating to disconnect. - /// - /// When this threshold is exceeded, the recovery policy escalates from - /// [`RecoveryPolicy::Drop`] to [`RecoveryPolicy::Disconnect`]. - /// - /// Default: 10. - pub max_consecutive_drops: u32, - - /// Default quarantine duration. - /// - /// Default: 30 seconds. - pub quarantine_duration: Duration, - - /// Whether to log dropped frames at warn level. - /// - /// Default: true. - pub log_dropped_frames: bool, -} - -impl Default for RecoveryConfig { - fn default() -> Self { - Self { - max_consecutive_drops: 10, - quarantine_duration: Duration::from_secs(30), - log_dropped_frames: true, - } - } -} - -impl RecoveryConfig { - /// Set the maximum consecutive dropped frames before disconnect. - #[must_use] - pub fn max_consecutive_drops(mut self, count: u32) -> Self { - self.max_consecutive_drops = count; - self - } - - /// Set the default quarantine duration. - #[must_use] - pub fn quarantine_duration(mut self, duration: Duration) -> Self { - self.quarantine_duration = duration; - self - } - - /// Set whether to log dropped frames. - #[must_use] - pub fn log_dropped_frames(mut self, enabled: bool) -> Self { - self.log_dropped_frames = enabled; - self - } -} - -#[cfg(test)] -mod tests; diff --git a/src/codec/recovery/config.rs b/src/codec/recovery/config.rs new file mode 100644 index 00000000..69e75958 --- /dev/null +++ b/src/codec/recovery/config.rs @@ -0,0 +1,76 @@ +//! Configuration for codec recovery policies. + +use std::time::Duration; + +/// Configuration for recovery policy behaviour. +/// +/// Use this to configure global recovery settings on the application builder. +/// +/// # Examples +/// +/// ``` +/// use std::time::Duration; +/// +/// use wireframe::codec::RecoveryConfig; +/// +/// let config = RecoveryConfig::default() +/// .max_consecutive_drops(5) +/// .quarantine_duration(Duration::from_secs(60)) +/// .log_dropped_frames(true); +/// +/// assert_eq!(config.max_consecutive_drops, 5); +/// ``` +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct RecoveryConfig { + /// Maximum consecutive dropped frames before escalating to disconnect. + /// + /// When this threshold is exceeded, the recovery policy escalates from + /// [`RecoveryPolicy::Drop`](crate::codec::RecoveryPolicy::Drop) to + /// [`RecoveryPolicy::Disconnect`](crate::codec::RecoveryPolicy::Disconnect). + /// + /// Default: 10. + pub max_consecutive_drops: u32, + + /// Default quarantine duration. + /// + /// Default: 30 seconds. + pub quarantine_duration: Duration, + + /// Whether to log dropped frames at warn level. + /// + /// Default: true. + pub log_dropped_frames: bool, +} + +impl Default for RecoveryConfig { + fn default() -> Self { + Self { + max_consecutive_drops: 10, + quarantine_duration: Duration::from_secs(30), + log_dropped_frames: true, + } + } +} + +impl RecoveryConfig { + /// Set the maximum consecutive dropped frames before disconnect. + #[must_use] + pub fn max_consecutive_drops(mut self, count: u32) -> Self { + self.max_consecutive_drops = count; + self + } + + /// Set the default quarantine duration. + #[must_use] + pub fn quarantine_duration(mut self, duration: Duration) -> Self { + self.quarantine_duration = duration; + self + } + + /// Set whether to log dropped frames. + #[must_use] + pub fn log_dropped_frames(mut self, enabled: bool) -> Self { + self.log_dropped_frames = enabled; + self + } +} diff --git a/src/codec/recovery/context.rs b/src/codec/recovery/context.rs new file mode 100644 index 00000000..5ab103a2 --- /dev/null +++ b/src/codec/recovery/context.rs @@ -0,0 +1,74 @@ +//! Structured context for codec error recovery and logging. + +use std::net::SocketAddr; + +/// Structured context for codec error logging and diagnostics. +/// +/// This struct captures connection-specific information to include in +/// structured logs and metrics when codec errors occur. +/// +/// # Examples +/// +/// ``` +/// use wireframe::codec::CodecErrorContext; +/// +/// let ctx = CodecErrorContext::new() +/// .with_connection_id(42) +/// .with_correlation_id(123); +/// +/// assert_eq!(ctx.connection_id, Some(42)); +/// assert_eq!(ctx.correlation_id, Some(123)); +/// ``` +#[derive(Clone, Debug, Default)] +pub struct CodecErrorContext { + /// Unique identifier for the connection. + pub connection_id: Option, + + /// Remote peer address. + pub peer_address: Option, + + /// Correlation identifier from the frame, if available. + /// + /// This helps attribute errors to specific requests when debugging. + pub correlation_id: Option, + + /// Codec instance state for debugging. + /// + /// May include sequence numbers, bytes processed, or other codec-specific + /// state information. + pub codec_state: Option, +} + +impl CodecErrorContext { + /// Create a new empty context. + #[must_use] + pub fn new() -> Self { Self::default() } + + /// Set the connection identifier. + #[must_use] + pub fn with_connection_id(mut self, id: u64) -> Self { + self.connection_id = Some(id); + self + } + + /// Set the peer address. + #[must_use] + pub fn with_peer_address(mut self, addr: SocketAddr) -> Self { + self.peer_address = Some(addr); + self + } + + /// Set the correlation identifier. + #[must_use] + pub fn with_correlation_id(mut self, id: u64) -> Self { + self.correlation_id = Some(id); + self + } + + /// Set codec-specific state information. + #[must_use] + pub fn with_codec_state(mut self, state: impl Into) -> Self { + self.codec_state = Some(state.into()); + self + } +} diff --git a/src/codec/recovery/hook.rs b/src/codec/recovery/hook.rs new file mode 100644 index 00000000..aec5d1e6 --- /dev/null +++ b/src/codec/recovery/hook.rs @@ -0,0 +1,78 @@ +//! Hooks for customising codec error recovery behaviour. + +use std::time::Duration; + +use super::{CodecErrorContext, RecoveryPolicy}; +use crate::codec::error::CodecError; + +/// Hook trait for customising codec error recovery behaviour. +/// +/// Implementations can override default recovery policies based on +/// application-specific requirements or connection state. +/// +/// # Default Implementation +/// +/// The default implementation ([`DefaultRecoveryPolicy`]) delegates to +/// [`CodecError::default_recovery_policy`] for all errors. +/// +/// # Examples +/// +/// ``` +/// use std::time::Duration; +/// +/// use wireframe::codec::{ +/// CodecError, +/// CodecErrorContext, +/// EofError, +/// RecoveryPolicy, +/// RecoveryPolicyHook, +/// }; +/// +/// /// Quarantine connections that close unexpectedly. +/// struct QuarantineOnPrematureEof; +/// +/// impl RecoveryPolicyHook for QuarantineOnPrematureEof { +/// fn recovery_policy(&self, error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { +/// match error { +/// CodecError::Eof(EofError::MidFrame { .. }) => RecoveryPolicy::Quarantine, +/// _ => error.default_recovery_policy(), +/// } +/// } +/// +/// fn quarantine_duration(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> Duration { +/// Duration::from_secs(60) +/// } +/// } +/// ``` +pub trait RecoveryPolicyHook: Send + Sync { + /// Determine the recovery policy for a codec error. + /// + /// The default implementation delegates to + /// [`CodecError::default_recovery_policy`]. + fn recovery_policy(&self, error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { + error.default_recovery_policy() + } + + /// Returns the quarantine duration when [`RecoveryPolicy::Quarantine`] is + /// selected. + /// + /// Default: 30 seconds. + fn quarantine_duration(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> Duration { + Duration::from_secs(30) + } + + /// Called before applying the recovery policy. + /// + /// Use this hook for logging, metrics emission, or state updates. The + /// default implementation does nothing. + fn on_error(&self, _error: &CodecError, _ctx: &CodecErrorContext, _policy: RecoveryPolicy) {} +} + +/// Default recovery policy implementation. +/// +/// This implementation uses the built-in default policies from +/// [`CodecError::default_recovery_policy`] without any customisation. +#[derive(Clone, Copy, Debug, Default)] +pub struct DefaultRecoveryPolicy; + +impl RecoveryPolicyHook for DefaultRecoveryPolicy {} diff --git a/src/codec/recovery/mod.rs b/src/codec/recovery/mod.rs new file mode 100644 index 00000000..c5f39e6a --- /dev/null +++ b/src/codec/recovery/mod.rs @@ -0,0 +1,49 @@ +//! Recovery policy types and hooks for codec error handling. +//! +//! This module provides infrastructure for customising how codec errors are +//! handled, including recovery policies, error context for structured logging, +//! and hooks for application-specific error handling. +//! +//! # Recovery Policies +//! +//! When a codec error occurs, the framework applies a recovery policy: +//! +//! - [`RecoveryPolicy::Drop`]: Discard the malformed frame and continue processing. Suitable for +//! recoverable errors like oversized frames. +//! - [`RecoveryPolicy::Quarantine`]: Pause the connection temporarily before retrying. Useful for +//! rate-limiting misbehaving clients. +//! - [`RecoveryPolicy::Disconnect`]: Terminate the connection immediately. Required for +//! unrecoverable errors like I/O failures. +//! +//! # Custom Recovery Hooks +//! +//! Applications can customise error handling by implementing +//! [`RecoveryPolicyHook`]: +//! +//! ``` +//! use std::time::Duration; +//! +//! use wireframe::codec::{CodecError, CodecErrorContext, RecoveryPolicy, RecoveryPolicyHook}; +//! +//! struct StrictRecovery; +//! +//! impl RecoveryPolicyHook for StrictRecovery { +//! fn recovery_policy(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { +//! // Disconnect on any error +//! RecoveryPolicy::Disconnect +//! } +//! } +//! ``` + +mod config; +mod context; +mod hook; +mod policy; + +pub use config::RecoveryConfig; +pub use context::CodecErrorContext; +pub use hook::{DefaultRecoveryPolicy, RecoveryPolicyHook}; +pub use policy::RecoveryPolicy; + +#[cfg(test)] +mod tests; diff --git a/src/codec/recovery/policy.rs b/src/codec/recovery/policy.rs new file mode 100644 index 00000000..865e12d9 --- /dev/null +++ b/src/codec/recovery/policy.rs @@ -0,0 +1,75 @@ +//! Recovery policies for codec errors. + +/// Recovery policies for codec errors. +/// +/// Each policy defines how the framework responds to a codec error. +/// +/// # Default Behaviour +/// +/// [`CodecError::default_recovery_policy`](crate::codec::CodecError::default_recovery_policy) +/// returns the recommended policy for each error type. Applications can +/// override this via [`RecoveryPolicyHook`](crate::codec::RecoveryPolicyHook). +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum RecoveryPolicy { + /// Discard the malformed frame and continue processing. + /// + /// This is the default for recoverable errors like oversized frames or + /// protocol violations that affect only a single message. The connection + /// remains open and can process subsequent frames. + /// + /// # When to Use + /// + /// - Oversized frames that exceed `max_frame_length` + /// - Empty frames where non-empty is expected + /// - Protocol-level errors (unknown message type, sequence violation) + #[default] + Drop, + + /// Pause the connection temporarily before retrying. + /// + /// The connection enters a quarantine state for a configurable duration. + /// During quarantine, no frames are processed. After the timeout, normal + /// processing resumes. + /// + /// # When to Use + /// + /// - Rate-limiting misbehaving clients + /// - Temporary back-off after repeated errors + /// - Giving time for upstream issues to resolve + Quarantine, + + /// Terminate the connection immediately. + /// + /// The connection is closed without processing further frames. This is + /// required for unrecoverable errors where the framing state is corrupted + /// or the transport has failed. + /// + /// # When to Use + /// + /// - I/O errors (socket closed, write failed) + /// - Invalid frame length encoding (framing state corrupted) + /// - EOF conditions (connection ending) + Disconnect, +} + +impl RecoveryPolicy { + /// Returns the policy name as a static string for metrics and logging. + /// + /// # Examples + /// + /// ``` + /// use wireframe::codec::RecoveryPolicy; + /// + /// assert_eq!(RecoveryPolicy::Drop.as_str(), "drop"); + /// assert_eq!(RecoveryPolicy::Quarantine.as_str(), "quarantine"); + /// assert_eq!(RecoveryPolicy::Disconnect.as_str(), "disconnect"); + /// ``` + #[must_use] + pub const fn as_str(self) -> &'static str { + match self { + Self::Drop => "drop", + Self::Quarantine => "quarantine", + Self::Disconnect => "disconnect", + } + } +} diff --git a/src/codec/recovery/tests.rs b/src/codec/recovery/tests.rs index 7534667f..57e0daf3 100644 --- a/src/codec/recovery/tests.rs +++ b/src/codec/recovery/tests.rs @@ -1,22 +1,12 @@ -//! Tests for codec recovery policies, hooks, and configuration. +//! Tests for codec recovery policies. -use std::{io, net::SocketAddr, time::Duration}; - -use rstest::{fixture, rstest}; +use std::{io, time::Duration}; use super::*; - -#[fixture] -fn default_hook() -> DefaultRecoveryPolicy { - // Use the framework default hook for baseline policy assertions. - DefaultRecoveryPolicy -} - -#[fixture] -fn context() -> CodecErrorContext { - // These tests exercise hook behaviour without connection metadata. - CodecErrorContext::new() -} +use crate::codec::{ + CodecError, + error::{EofError, FramingError}, +}; #[test] fn recovery_policy_default_is_drop() { @@ -37,47 +27,44 @@ fn context_builder_sets_fields() { #[test] fn context_with_peer_address() { - let addr: SocketAddr = "127.0.0.1:8080".parse().expect("valid test address"); + let addr: std::net::SocketAddr = "127.0.0.1:8080".parse().expect("valid test address"); let ctx = CodecErrorContext::new().with_peer_address(addr); assert_eq!(ctx.peer_address, Some(addr)); } -#[rstest] -fn default_recovery_policy_delegates_to_error( - default_hook: DefaultRecoveryPolicy, - context: CodecErrorContext, -) { - use crate::codec::error::{EofError, FramingError}; +#[test] +fn default_recovery_policy_delegates_to_error() { + let default_hook = DefaultRecoveryPolicy; + let default_ctx = CodecErrorContext::new(); // Check various error types let err = CodecError::Framing(FramingError::OversizedFrame { size: 100, max: 50 }); assert_eq!( - default_hook.recovery_policy(&err, &context), + default_hook.recovery_policy(&err, &default_ctx), RecoveryPolicy::Drop ); let err = CodecError::Io(io::Error::other("test")); assert_eq!( - default_hook.recovery_policy(&err, &context), + default_hook.recovery_policy(&err, &default_ctx), RecoveryPolicy::Disconnect ); let err = CodecError::Eof(EofError::CleanClose); assert_eq!( - default_hook.recovery_policy(&err, &context), + default_hook.recovery_policy(&err, &default_ctx), RecoveryPolicy::Disconnect ); } -#[rstest] -fn default_quarantine_duration_is_30_seconds( - default_hook: DefaultRecoveryPolicy, - context: CodecErrorContext, -) { - let io_error = CodecError::Io(io::Error::other("test")); +#[test] +fn default_quarantine_duration_is_30_seconds() { + let default_hook = DefaultRecoveryPolicy; + let default_ctx = CodecErrorContext::new(); + let err = CodecError::Io(io::Error::other("test")); assert_eq!( - default_hook.quarantine_duration(&io_error, &context), + default_hook.quarantine_duration(&err, &default_ctx), Duration::from_secs(30) ); } diff --git a/src/extractor/connection_info.rs b/src/extractor/connection_info.rs new file mode 100644 index 00000000..9ac30659 --- /dev/null +++ b/src/extractor/connection_info.rs @@ -0,0 +1,48 @@ +//! Extractor for connection metadata. + +use std::net::SocketAddr; + +use super::{FromMessageRequest, MessageRequest, Payload}; + +/// Extractor providing peer connection metadata. +#[derive(Debug, Clone, Copy)] +pub struct ConnectionInfo { + peer_addr: Option, +} + +impl ConnectionInfo { + /// Returns the peer's socket address for the current connection, if available. + /// + /// # Examples + /// + /// ```rust + /// use std::net::SocketAddr; + /// + /// use wireframe::extractor::{ConnectionInfo, FromMessageRequest, MessageRequest, Payload}; + /// + /// let addr: SocketAddr = "127.0.0.1:8080".parse().expect("valid socket address"); + /// let req = MessageRequest::new().with_peer_addr(Some(addr)); + /// let info = ConnectionInfo::from_message_request(&req, &mut Payload::default()) + /// .expect("connection info extraction should succeed"); + /// assert_eq!(info.peer_addr(), Some(addr)); + /// ``` + #[must_use] + pub fn peer_addr(&self) -> Option { self.peer_addr } +} + +impl FromMessageRequest for ConnectionInfo { + type Error = std::convert::Infallible; + + /// Extracts connection metadata from the message request. + /// + /// Returns a `ConnectionInfo` containing the peer's socket address, if available. This + /// extraction is infallible. + fn from_message_request( + req: &MessageRequest, + _payload: &mut Payload<'_>, + ) -> Result { + Ok(Self { + peer_addr: req.peer_addr, + }) + } +} diff --git a/src/extractor/error.rs b/src/extractor/error.rs new file mode 100644 index 00000000..6d04be73 --- /dev/null +++ b/src/extractor/error.rs @@ -0,0 +1,25 @@ +//! Error types for built-in extractors. + +use thiserror::Error; + +/// Errors that can occur when extracting built-in types. +/// +/// This enum is marked `#[non_exhaustive]` so more variants may be added in +/// the future without breaking changes. +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum ExtractError { + /// No shared state of the requested type was found. + #[error("no shared state registered for {0}")] + MissingState(&'static str), + /// Failed to decode the message payload. + #[error("failed to decode payload: {0}")] + InvalidPayload(#[source] bincode::error::DecodeError), + /// No streaming body was available for this request. + /// + /// This occurs when: + /// - The request was not configured for streaming consumption + /// - The stream was already consumed by another extractor + #[error("no streaming body available for this request")] + MissingBodyStream, +} diff --git a/src/extractor/extractors.rs b/src/extractor/extractors.rs deleted file mode 100644 index 266e2efb..00000000 --- a/src/extractor/extractors.rs +++ /dev/null @@ -1,182 +0,0 @@ -//! Built-in extractor implementations for message, streaming body, and connection metadata. - -use std::net::SocketAddr; - -use super::{ExtractError, FromMessageRequest, MessageRequest, Payload}; -use crate::message::Message as WireMessage; - -/// Extractor that deserializes the message payload into `T`. -#[derive(Debug, Clone)] -pub struct Message(T); - -impl Message { - /// Consumes the extractor and returns the inner deserialized message value. - #[must_use] - pub fn into_inner(self) -> T { self.0 } -} - -impl std::ops::Deref for Message { - type Target = T; - - /// Returns a reference to the inner value. - /// - /// This enables transparent access to the wrapped type via dereferencing. - fn deref(&self) -> &Self::Target { &self.0 } -} - -impl FromMessageRequest for Message -where - T: WireMessage, -{ - type Error = ExtractError; - - /// Attempts to extract and deserialize a message of type `T` from the payload. - /// - /// Advances the payload by the number of bytes consumed during deserialization. - /// Returns an error if the payload cannot be decoded into the target type. - /// - /// # Returns - /// - `Ok(Self)`: The successfully extracted and deserialized message. - /// - `Err(ExtractError::InvalidPayload)`: If deserialization fails. - fn from_message_request( - _req: &MessageRequest, - payload: &mut Payload<'_>, - ) -> Result { - let (msg, consumed) = - T::from_bytes(payload.as_ref()).map_err(ExtractError::InvalidPayload)?; - payload.advance(consumed); - Ok(Self(msg)) - } -} - -/// Extractor providing streaming access to the request body. -/// -/// Unlike [`Payload`] which borrows buffered bytes, this extractor -/// takes ownership of a streaming body channel. Handlers opting into -/// streaming receive chunks incrementally via a [`RequestBodyStream`]. -/// -/// This type is the inbound counterpart to [`crate::Response::Stream`]. -/// -/// # Examples -/// -/// ``` -/// use bytes::Bytes; -/// use tokio::io::AsyncReadExt; -/// use wireframe::{extractor::StreamingBody, request::body_channel}; -/// -/// # #[tokio::main] -/// # async fn main() { -/// let (tx, stream) = body_channel(4); -/// -/// tokio::spawn(async move { -/// let _ = tx.send(Ok(Bytes::from_static(b"payload"))).await; -/// }); -/// -/// let body = StreamingBody::new(stream); -/// let mut reader = body.into_reader(); -/// let mut buf = Vec::new(); -/// reader.read_to_end(&mut buf).await.expect("read body"); -/// assert_eq!(buf, b"payload"); -/// # } -/// ``` -/// -/// [`RequestBodyStream`]: crate::request::RequestBodyStream -pub struct StreamingBody { - stream: crate::request::RequestBodyStream, -} - -impl StreamingBody { - /// Create a streaming body from the given stream. - /// - /// Typically constructed by the framework when a handler opts into - /// streaming request consumption. - #[must_use] - pub fn new(stream: crate::request::RequestBodyStream) -> Self { Self { stream } } - - /// Consume the extractor and return the underlying stream. - /// - /// Use this when you need direct access to the stream for custom - /// processing with [`futures::StreamExt`] methods. - #[must_use] - pub fn into_stream(self) -> crate::request::RequestBodyStream { self.stream } - - /// Convert to an [`AsyncRead`] adapter. - /// - /// Protocol crates can use this to feed streaming bytes into parsers - /// that operate on readers rather than streams. - /// - /// [`AsyncRead`]: tokio::io::AsyncRead - #[must_use] - pub fn into_reader(self) -> crate::request::RequestBodyReader { - crate::request::RequestBodyReader::new(self.stream) - } -} - -impl std::fmt::Debug for StreamingBody { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("StreamingBody").finish_non_exhaustive() - } -} - -impl FromMessageRequest for StreamingBody { - type Error = ExtractError; - - /// Extract the streaming body from the request. - /// - /// # Errors - /// - /// Returns [`ExtractError::MissingBodyStream`] if: - /// - The request was not configured for streaming consumption - /// - The stream was already consumed by another extractor - fn from_message_request( - req: &MessageRequest, - _payload: &mut Payload<'_>, - ) -> Result { - req.take_body_stream() - .map(Self::new) - .ok_or(ExtractError::MissingBodyStream) - } -} - -/// Extractor providing peer connection metadata. -#[derive(Debug, Clone, Copy)] -pub struct ConnectionInfo { - peer_addr: Option, -} - -impl ConnectionInfo { - /// Returns the peer's socket address for the current connection, if available. - /// - /// # Examples - /// - /// ```rust - /// use std::net::SocketAddr; - /// - /// use wireframe::extractor::{ConnectionInfo, FromMessageRequest, MessageRequest, Payload}; - /// - /// let addr: SocketAddr = "127.0.0.1:8080".parse().expect("valid socket address"); - /// let req = MessageRequest::new().with_peer_addr(Some(addr)); - /// let info = ConnectionInfo::from_message_request(&req, &mut Payload::default()) - /// .expect("connection info extraction should succeed"); - /// assert_eq!(info.peer_addr(), Some(addr)); - /// ``` - #[must_use] - pub fn peer_addr(&self) -> Option { self.peer_addr } -} - -impl FromMessageRequest for ConnectionInfo { - type Error = std::convert::Infallible; - - /// Extracts connection metadata from the message request. - /// - /// Returns a `ConnectionInfo` containing the peer's socket address, if available. This - /// extraction is infallible. - fn from_message_request( - req: &MessageRequest, - _payload: &mut Payload<'_>, - ) -> Result { - Ok(Self { - peer_addr: req.peer_addr, - }) - } -} diff --git a/src/extractor/message.rs b/src/extractor/message.rs new file mode 100644 index 00000000..153a3204 --- /dev/null +++ b/src/extractor/message.rs @@ -0,0 +1,47 @@ +//! Message extractor for deserialized payloads. + +use super::{ExtractError, FromMessageRequest, MessageRequest, Payload}; +use crate::message::Message as WireMessage; + +/// Extractor that deserializes the message payload into `T`. +#[derive(Debug, Clone)] +pub struct Message(T); + +impl Message { + /// Consumes the extractor and returns the inner deserialized message value. + #[must_use] + pub fn into_inner(self) -> T { self.0 } +} + +impl std::ops::Deref for Message { + type Target = T; + + /// Returns a reference to the inner value. + /// + /// This enables transparent access to the wrapped type via dereferencing. + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl FromMessageRequest for Message +where + T: WireMessage, +{ + type Error = ExtractError; + + /// Attempts to extract and deserialize a message of type `T` from the payload. + /// + /// Advances the payload by the number of bytes consumed during deserialization. + /// Returns an error if the payload cannot be decoded into the target type. + /// + /// # Returns + /// - `Ok(Self)`: The successfully extracted and deserialized message. + /// - `Err(ExtractError::InvalidPayload)`: If deserialization fails. + fn from_message_request( + _req: &MessageRequest, + payload: &mut Payload<'_>, + ) -> Result { + let (msg, consumed) = T::from_bytes(payload.data).map_err(ExtractError::InvalidPayload)?; + payload.advance(consumed); + Ok(Self(msg)) + } +} diff --git a/src/extractor/mod.rs b/src/extractor/mod.rs index a1a158f5..2174bb8e 100644 --- a/src/extractor/mod.rs +++ b/src/extractor/mod.rs @@ -4,333 +4,18 @@ //! state. Implement [`FromMessageRequest`] for custom extractors to parse //! payload bytes or inspect connection info before your handler runs. -use std::{ - any::{Any, TypeId}, - collections::HashMap, - net::SocketAddr, - sync::{Arc, Mutex}, -}; - -use thiserror::Error; - -use crate::request::RequestBodyStream; - -mod extractors; - -pub use extractors::{ConnectionInfo, Message, StreamingBody}; - -/// Request context passed to extractors. -/// -/// This type contains metadata about the current connection and provides -/// access to application state registered with [`crate::app::WireframeApp`]. -#[derive(Default)] -pub struct MessageRequest { - /// Address of the peer that sent the current message. - pub peer_addr: Option, - /// Shared state values registered with the application. - /// - /// Values are keyed by their [`TypeId`]. Registering additional - /// state of the same type will replace the previous entry. - pub app_data: HashMap>, - /// Optional streaming body for handlers that opt into streaming consumption. - /// - /// When present, the [`StreamingBody`] extractor can take ownership of this - /// stream. Only one handler/extractor may consume the stream; subsequent - /// extractions will receive [`ExtractError::MissingBodyStream`]. - body_stream: Option>>, -} - -impl MessageRequest { - /// Create a new empty message request. - /// - /// Use [`Self::with_peer_addr`] to configure connection metadata. - #[must_use] - pub fn new() -> Self { Self::default() } - - /// Set the peer address for this request. - /// - /// # Examples - /// - /// ```rust - /// use std::net::SocketAddr; - /// - /// use wireframe::extractor::MessageRequest; - /// - /// let addr: SocketAddr = "127.0.0.1:8080".parse().expect("valid socket address"); - /// let req = MessageRequest::new().with_peer_addr(Some(addr)); - /// assert!(req.peer_addr.is_some()); - /// ``` - #[must_use] - pub fn with_peer_addr(mut self, addr: Option) -> Self { - self.peer_addr = addr; - self - } - - /// Retrieve shared state of type `T` if available. - /// - /// Returns `None` when no value of type `T` was registered. - /// - /// # Examples - /// - /// ```rust,no_run - /// use wireframe::{ - /// app::WireframeApp, - /// extractor::{MessageRequest, SharedState}, - /// }; - /// - /// let _app = WireframeApp::new() - /// .expect("failed to initialize app") - /// .app_data(5u32); - /// // The framework populates the request with application data. - /// # let mut req = MessageRequest::default(); - /// # req.insert_state(5u32); - /// let val: Option> = req.state(); - /// assert_eq!(*val.expect("state should be available"), 5); - /// ``` - #[must_use] - pub fn state(&self) -> Option> - where - T: Send + Sync + 'static, - { - self.app_data - .get(&TypeId::of::()) - .and_then(|data| data.clone().downcast::().ok()) - .map(SharedState) - } - - /// Insert shared state of type `T` into the request. - /// - /// This replaces any existing value of the same type. - /// - /// # Examples - /// - /// ```rust - /// use wireframe::extractor::{MessageRequest, SharedState}; - /// - /// let mut req = MessageRequest::default(); - /// req.insert_state(5u32); - /// let val: Option> = req.state(); - /// assert_eq!(*val.expect("state should be available"), 5); - /// ``` - pub fn insert_state(&mut self, state: T) - where - T: Send + Sync + 'static, - { - self.app_data.insert( - TypeId::of::(), - Arc::new(state) as Arc, - ); - } - - /// Set the streaming body for this request. - /// - /// The framework calls this when a handler opts into streaming consumption. - /// The stream can later be taken by the [`StreamingBody`] extractor. - /// - /// # Examples - /// - /// ```rust - /// use wireframe::{extractor::MessageRequest, request::body_channel}; - /// - /// let mut req = MessageRequest::default(); - /// let (_tx, stream) = body_channel(4); - /// req.set_body_stream(stream); - /// ``` - pub fn set_body_stream(&mut self, stream: RequestBodyStream) { - self.body_stream = Some(Mutex::new(Some(stream))); - } - - /// Take the streaming body from this request, if present. - /// - /// Returns `None` if no body stream was set or if it was already taken - /// by a previous extractor. This ensures only one consumer receives the - /// stream. - /// - /// # Mutex poisoning - /// - /// If the internal mutex is poisoned (for example, due to a panic in another - /// thread while holding the lock), this method returns `None` rather than - /// propagating the panic. This behaviour prioritizes availability over - /// strict correctness: a poisoned mutex typically indicates a serious bug - /// elsewhere, but crashing additional handlers would only compound the - /// problem. The missing stream will surface as an [`ExtractError::MissingBodyStream`] - /// in the handler, which can be logged and investigated. - #[must_use] - pub fn take_body_stream(&self) -> Option { - self.body_stream - .as_ref() - .and_then(|mutex| mutex.lock().ok()) - .and_then(|mut guard| guard.take()) - } -} - -/// Raw payload buffer handed to extractors. -/// -/// Create a `Payload` from a slice using [`Payload::new`] or `into`: -/// -/// ```rust -/// use wireframe::extractor::Payload; -/// -/// let p1 = Payload::new(b"abc"); -/// let p2: Payload<'_> = b"xyz".as_slice().into(); -/// assert_eq!(p1.as_ref(), b"abc" as &[u8]); -/// assert_eq!(p2.as_ref(), b"xyz" as &[u8]); -/// ``` -#[derive(Default)] -pub struct Payload<'a> { - /// Incoming bytes not yet processed. - data: &'a [u8], -} - -impl<'a> Payload<'a> { - /// Creates a new `Payload` from the provided byte slice. - /// - /// # Examples - /// - /// ```rust,no_run - /// use wireframe::extractor::Payload; - /// - /// let payload = Payload::new(b"data"); - /// assert_eq!(payload.as_ref(), b"data" as &[u8]); - /// ``` - #[must_use] - #[inline] - pub fn new(data: &'a [u8]) -> Self { Self { data } } -} - -impl<'a> From<&'a [u8]> for Payload<'a> { - #[inline] - fn from(data: &'a [u8]) -> Self { Self { data } } -} - -impl AsRef<[u8]> for Payload<'_> { - fn as_ref(&self) -> &[u8] { self.data } -} - -impl Payload<'_> { - /// Advances the payload by `count` bytes. - /// - /// Consumes up to `count` bytes from the front of the slice, ensuring we - /// never slice beyond the available buffer. - /// - /// # Examples - /// - /// ```rust,no_run - /// use wireframe::extractor::Payload; - /// - /// let mut payload = Payload::new(b"abcd"); - /// payload.advance(2); - /// assert_eq!(payload.as_ref(), b"cd" as &[u8]); - /// ``` - pub fn advance(&mut self, count: usize) { - let n = count.min(self.data.len()); - self.data = self.data.get(n..).unwrap_or_default(); - } - - /// Returns the number of bytes remaining. - /// - /// # Examples - /// - /// ```rust,no_run - /// use wireframe::extractor::Payload; - /// - /// let mut payload = Payload::new(b"bytes"); - /// assert_eq!(payload.remaining(), 5); - /// payload.advance(2); - /// assert_eq!(payload.remaining(), 3); - /// ``` - #[must_use] - pub fn remaining(&self) -> usize { self.data.len() } -} - -/// Trait for extracting data from a [`MessageRequest`]. -/// -/// Types implementing this trait can be used as parameters to handler -/// functions. When invoked, `wireframe` passes the current request metadata and -/// message payload, allowing the extractor to parse bytes or inspect -/// connection information. This makes it easy to share common parsing and -/// validation logic across handlers. -pub trait FromMessageRequest: Sized { - /// Error type returned when extraction fails. - type Error: std::error::Error + Send + Sync + 'static; - - /// Perform extraction from the request and payload. - /// - /// # Errors - /// - /// Returns an error if extraction fails. - fn from_message_request( - req: &MessageRequest, - payload: &mut Payload<'_>, - ) -> Result; -} - -/// Shared application state accessible to handlers. -#[derive(Clone)] -pub struct SharedState(Arc); - -impl From> for SharedState { - fn from(inner: Arc) -> Self { Self(inner) } -} - -impl From for SharedState { - fn from(inner: T) -> Self { Self(Arc::new(inner)) } -} - -/// Errors that can occur when extracting built-in types. -/// -/// This enum is marked `#[non_exhaustive]` so more variants may be added in -/// the future without breaking changes. -#[derive(Debug, Error)] -#[non_exhaustive] -pub enum ExtractError { - /// No shared state of the requested type was found. - #[error("no shared state registered for {0}")] - MissingState(&'static str), - /// Failed to decode the message payload. - #[error("failed to decode payload: {0}")] - InvalidPayload(#[source] bincode::error::DecodeError), - /// No streaming body was available for this request. - /// - /// This occurs when: - /// - The request was not configured for streaming consumption - /// - The stream was already consumed by another extractor - #[error("no streaming body available for this request")] - MissingBodyStream, -} - -impl FromMessageRequest for SharedState -where - T: Send + Sync + 'static, -{ - type Error = ExtractError; - - fn from_message_request( - req: &MessageRequest, - _payload: &mut Payload<'_>, - ) -> Result { - req.state::() - .ok_or(ExtractError::MissingState(std::any::type_name::())) - } -} - -impl std::ops::Deref for SharedState { - type Target = T; - - /// Returns a reference to the inner shared state value. - /// - /// This allows transparent access to the underlying state managed by `SharedState`. - /// - /// # Examples - /// - /// ```rust,no_run - /// use std::sync::Arc; - /// - /// use wireframe::extractor::SharedState; - /// - /// let state = Arc::new(42); - /// let shared: SharedState = state.clone().into(); - /// assert_eq!(*shared, 42); - /// ``` - fn deref(&self) -> &Self::Target { &self.0 } -} +mod connection_info; +mod error; +mod message; +mod request; +mod shared_state; +mod streaming; +mod trait_def; + +pub use connection_info::ConnectionInfo; +pub use error::ExtractError; +pub use message::Message; +pub use request::{MessageRequest, Payload}; +pub use shared_state::SharedState; +pub use streaming::StreamingBody; +pub use trait_def::FromMessageRequest; diff --git a/src/extractor/request.rs b/src/extractor/request.rs new file mode 100644 index 00000000..1b372595 --- /dev/null +++ b/src/extractor/request.rs @@ -0,0 +1,239 @@ +//! Request context and payload buffer types for extractors. + +use std::{ + any::{Any, TypeId}, + collections::HashMap, + net::SocketAddr, + sync::{Arc, Mutex}, +}; + +use super::SharedState; +use crate::request::RequestBodyStream; + +/// Request context passed to extractors. +/// +/// This type contains metadata about the current connection and provides +/// access to application state registered with [`crate::app::WireframeApp`]. +#[derive(Default)] +pub struct MessageRequest { + /// Address of the peer that sent the current message. + pub peer_addr: Option, + /// Shared state values registered with the application. + /// + /// Values are keyed by their [`TypeId`]. Registering additional + /// state of the same type will replace the previous entry. + pub(crate) app_data: HashMap>, + /// Optional streaming body for handlers that opt into streaming consumption. + /// + /// When present, the [`StreamingBody`](crate::extractor::StreamingBody) + /// extractor can take ownership of this stream. Only one handler/extractor + /// may consume the stream; subsequent extractions will receive + /// [`ExtractError::MissingBodyStream`]. + body_stream: Option>>, +} + +impl MessageRequest { + /// Create a new empty message request. + /// + /// Use [`Self::with_peer_addr`] to configure connection metadata. + #[must_use] + pub fn new() -> Self { Self::default() } + + /// Set the peer address for this request. + /// + /// # Examples + /// + /// ```rust + /// use std::net::SocketAddr; + /// + /// use wireframe::extractor::MessageRequest; + /// + /// let req = MessageRequest::new().with_peer_addr(Some( + /// "127.0.0.1:8080".parse().expect("valid socket address"), + /// )); + /// assert!(req.peer_addr.is_some()); + /// ``` + #[must_use] + pub fn with_peer_addr(mut self, addr: Option) -> Self { + self.peer_addr = addr; + self + } + + /// Retrieve shared state of type `T` if available. + /// + /// Returns `None` when no value of type `T` was registered. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::{ + /// app::WireframeApp, + /// extractor::{MessageRequest, SharedState}, + /// }; + /// + /// let _app = WireframeApp::new() + /// .expect("failed to initialize app") + /// .app_data(5u32); + /// // The framework populates the request with application data. + /// # let mut req = MessageRequest::default(); + /// # req.insert_state(5u32); + /// let val: Option> = req.state(); + /// assert_eq!(*val.expect("shared state missing"), 5); + /// ``` + #[must_use] + pub fn state(&self) -> Option> + where + T: Send + Sync + 'static, + { + self.app_data + .get(&TypeId::of::()) + .and_then(|data| data.clone().downcast::().ok()) + .map(SharedState::from) + } + + /// Insert shared state of type `T` into the request. + /// + /// This replaces any existing value of the same type. + /// + /// # Examples + /// + /// ```rust + /// use wireframe::extractor::{MessageRequest, SharedState}; + /// + /// let mut req = MessageRequest::default(); + /// req.insert_state(5u32); + /// let val: Option> = req.state(); + /// assert_eq!(*val.expect("shared state missing"), 5); + /// ``` + pub fn insert_state(&mut self, state: T) + where + T: Send + Sync + 'static, + { + self.app_data.insert( + TypeId::of::(), + Arc::new(state) as Arc, + ); + } + + /// Set the streaming body for this request. + /// + /// The framework calls this when a handler opts into streaming consumption. + /// The stream can later be taken by the + /// [`StreamingBody`](crate::extractor::StreamingBody) extractor. + /// + /// # Examples + /// + /// ```rust + /// use wireframe::{extractor::MessageRequest, request::body_channel}; + /// + /// let mut req = MessageRequest::default(); + /// let (_tx, stream) = body_channel(4); + /// req.set_body_stream(stream); + /// ``` + pub fn set_body_stream(&mut self, stream: RequestBodyStream) { + self.body_stream = Some(Mutex::new(Some(stream))); + } + + /// Take the streaming body from this request, if present. + /// + /// Returns `None` if no body stream was set or if it was already taken + /// by a previous extractor. This ensures only one consumer receives the + /// stream. + /// + /// # Mutex poisoning + /// + /// If the internal mutex is poisoned (for example, due to a panic in another + /// thread while holding the lock), this method returns `None` rather than + /// propagating the panic. This behaviour prioritizes availability over + /// strict correctness: a poisoned mutex typically indicates a serious bug + /// elsewhere, but crashing additional handlers would only compound the + /// problem. The missing stream will surface as an + /// [`ExtractError::MissingBodyStream`](crate::extractor::ExtractError::MissingBodyStream) + /// in the handler, which can be logged and investigated. + #[must_use] + pub fn take_body_stream(&self) -> Option { + self.body_stream + .as_ref() + .and_then(|mutex| mutex.lock().ok()) + .and_then(|mut guard| guard.take()) + } +} + +/// Raw payload buffer handed to extractors. +/// +/// Create a `Payload` from a slice using [`Payload::new`] or `into`: +/// +/// ```rust +/// use wireframe::extractor::Payload; +/// +/// let p1 = Payload::new(b"abc"); +/// let p2: Payload<'_> = b"xyz".as_slice().into(); +/// assert_eq!(p1.as_ref(), b"abc" as &[u8]); +/// assert_eq!(p2.as_ref(), b"xyz" as &[u8]); +/// ``` +#[derive(Default)] +pub struct Payload<'a> { + /// Incoming bytes not yet processed. + pub(super) data: &'a [u8], +} + +impl<'a> Payload<'a> { + /// Creates a new `Payload` from the provided byte slice. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::extractor::Payload; + /// + /// let payload = Payload::new(b"data"); + /// assert_eq!(payload.as_ref(), b"data" as &[u8]); + /// ``` + #[must_use] + #[inline] + pub fn new(data: &'a [u8]) -> Self { Self { data } } +} + +impl<'a> From<&'a [u8]> for Payload<'a> { + #[inline] + fn from(data: &'a [u8]) -> Self { Self { data } } +} + +impl AsRef<[u8]> for Payload<'_> { + fn as_ref(&self) -> &[u8] { self.data } +} + +impl Payload<'_> { + /// Advances the payload by `count` bytes. + /// + /// Consumes up to `count` bytes from the front of the slice, ensuring we + /// never slice beyond the available buffer. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::extractor::Payload; + /// + /// let mut payload = Payload::new(b"abcd"); + /// payload.advance(2); + /// assert_eq!(payload.as_ref(), b"cd" as &[u8]); + /// ``` + pub fn advance(&mut self, count: usize) { + let n = count.min(self.data.len()); + self.data = self.data.get(n..).unwrap_or_default(); + } + + /// Returns the number of bytes remaining. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::extractor::Payload; + /// + /// let mut payload = Payload::new(b"bytes"); + /// assert_eq!(payload.remaining(), 5); + /// payload.advance(2); + /// assert_eq!(payload.remaining(), 3); + /// ``` + #[must_use] + pub fn remaining(&self) -> usize { self.data.len() } +} diff --git a/src/extractor/shared_state.rs b/src/extractor/shared_state.rs new file mode 100644 index 00000000..1589c814 --- /dev/null +++ b/src/extractor/shared_state.rs @@ -0,0 +1,59 @@ +//! Shared application state extractor. + +use std::sync::Arc; + +use super::{ExtractError, FromMessageRequest, MessageRequest, Payload}; + +/// Shared application state accessible to handlers. +#[derive(Clone)] +pub struct SharedState(Arc); + +impl std::fmt::Debug for SharedState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SharedState").finish_non_exhaustive() + } +} + +impl From> for SharedState { + fn from(inner: Arc) -> Self { Self(inner) } +} + +impl From for SharedState { + fn from(inner: T) -> Self { Self(Arc::new(inner)) } +} + +impl FromMessageRequest for SharedState +where + T: Send + Sync + 'static, +{ + type Error = ExtractError; + + fn from_message_request( + req: &MessageRequest, + _payload: &mut Payload<'_>, + ) -> Result { + req.state::() + .ok_or(ExtractError::MissingState(std::any::type_name::())) + } +} + +impl std::ops::Deref for SharedState { + type Target = T; + + /// Returns a reference to the inner shared state value. + /// + /// This allows transparent access to the underlying state managed by `SharedState`. + /// + /// # Examples + /// + /// ```rust,no_run + /// use std::sync::Arc; + /// + /// use wireframe::extractor::SharedState; + /// + /// let state = Arc::new(42); + /// let shared: SharedState = state.clone().into(); + /// assert_eq!(*shared, 42); + /// ``` + fn deref(&self) -> &Self::Target { &self.0 } +} diff --git a/src/extractor/streaming.rs b/src/extractor/streaming.rs new file mode 100644 index 00000000..6bcd70db --- /dev/null +++ b/src/extractor/streaming.rs @@ -0,0 +1,92 @@ +//! Streaming body extractor. + +use super::{ExtractError, FromMessageRequest, MessageRequest, Payload}; + +/// Extractor providing streaming access to the request body. +/// +/// Unlike [`Payload`] which borrows buffered bytes, this extractor +/// takes ownership of a streaming body channel. Handlers opting into +/// streaming receive chunks incrementally via a [`RequestBodyStream`]. +/// +/// This type is the inbound counterpart to [`crate::Response::Stream`]. +/// +/// # Examples +/// +/// ``` +/// use bytes::Bytes; +/// use tokio::io::AsyncReadExt; +/// use wireframe::{extractor::StreamingBody, request::body_channel}; +/// +/// # #[tokio::main] +/// # async fn main() { +/// let (tx, stream) = body_channel(4); +/// +/// tokio::spawn(async move { +/// let _ = tx.send(Ok(Bytes::from_static(b"payload"))).await; +/// }); +/// +/// let body = StreamingBody::new(stream); +/// let mut reader = body.into_reader(); +/// let mut buf = Vec::new(); +/// reader.read_to_end(&mut buf).await.expect("read body"); +/// assert_eq!(buf, b"payload"); +/// # } +/// ``` +/// +/// [`RequestBodyStream`]: crate::request::RequestBodyStream +pub struct StreamingBody { + stream: crate::request::RequestBodyStream, +} + +impl StreamingBody { + /// Create a streaming body from the given stream. + /// + /// Typically constructed by the framework when a handler opts into + /// streaming request consumption. + #[must_use] + pub fn new(stream: crate::request::RequestBodyStream) -> Self { Self { stream } } + + /// Consume the extractor and return the underlying stream. + /// + /// Use this when you need direct access to the stream for custom + /// processing with [`futures::StreamExt`] methods. + #[must_use] + pub fn into_stream(self) -> crate::request::RequestBodyStream { self.stream } + + /// Convert to an [`AsyncRead`] adapter. + /// + /// Protocol crates can use this to feed streaming bytes into parsers + /// that operate on readers rather than streams. + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + #[must_use] + pub fn into_reader(self) -> crate::request::RequestBodyReader { + crate::request::RequestBodyReader::new(self.stream) + } +} + +impl std::fmt::Debug for StreamingBody { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StreamingBody").finish_non_exhaustive() + } +} + +impl FromMessageRequest for StreamingBody { + type Error = ExtractError; + + /// Extract the streaming body from the request. + /// + /// # Errors + /// + /// Returns [`ExtractError::MissingBodyStream`] if: + /// - The request was not configured for streaming consumption + /// - The stream was already consumed by another extractor + fn from_message_request( + req: &MessageRequest, + _payload: &mut Payload<'_>, + ) -> Result { + req.take_body_stream() + .map(Self::new) + .ok_or(ExtractError::MissingBodyStream) + } +} diff --git a/src/extractor/trait_def.rs b/src/extractor/trait_def.rs new file mode 100644 index 00000000..870faf93 --- /dev/null +++ b/src/extractor/trait_def.rs @@ -0,0 +1,25 @@ +//! Trait definition for extractor types. + +use super::{MessageRequest, Payload}; + +/// Trait for extracting data from a [`MessageRequest`]. +/// +/// Types implementing this trait can be used as parameters to handler +/// functions. When invoked, `wireframe` passes the current request metadata and +/// message payload, allowing the extractor to parse bytes or inspect +/// connection information. This makes it easy to share common parsing and +/// validation logic across handlers. +pub trait FromMessageRequest: Sized { + /// Error type returned when extraction fails. + type Error: std::error::Error + Send + Sync + 'static; + + /// Perform extraction from the request and payload. + /// + /// # Errors + /// + /// Returns an error if extraction fails. + fn from_message_request( + req: &MessageRequest, + payload: &mut Payload<'_>, + ) -> Result; +} diff --git a/src/server/config/mod.rs b/src/server/config/mod.rs index 31480b75..06d87a62 100644 --- a/src/server/config/mod.rs +++ b/src/server/config/mod.rs @@ -49,9 +49,9 @@ macro_rules! builder_callback { pub mod binding; pub mod preamble; -fn default_worker_count() -> usize { - std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get) -} +fn default_worker_count() -> usize { super::default_worker_count() } +#[cfg(test)] +mod tests; impl WireframeServer where diff --git a/src/server/config/tests.rs b/src/server/config/tests.rs deleted file mode 100644 index 5e32879e..00000000 --- a/src/server/config/tests.rs +++ /dev/null @@ -1,406 +0,0 @@ -//! Tests for server configuration utilities. -//! -//! This module exercises the `WireframeServer` builder, covering worker counts, -//! binding behaviour, preamble handling, handler registration, and method -//! chaining. Fixtures from `test_util` provide shared setup and parameterised -//! cases via `rstest`. - -use std::{ - io, - sync::{ - Arc, - atomic::{AtomicUsize, Ordering}, - }, - time::Duration, -}; - -use bincode::error::DecodeError; -use rstest::{fixture, rstest}; -use tokio::net::{TcpListener, TcpStream}; - -use super::*; -use crate::server::{ - test_util::{ - TestPreamble, - bind_server, - factory, - free_listener, - listener_addr, - server_with_preamble, - }, - BackoffConfig, -}; - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum PreambleHandlerKind { - Success, - Failure, -} - -fn assert_local_addr_matches_listener( - server: WireframeServer, - expected: std::net::SocketAddr, -) where - F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, - T: crate::preamble::Preamble, - S: crate::server::ServerState, - Ser: crate::serializer::Serializer + Send + Sync, - Ctx: Send + 'static, - E: crate::app::Packet, - Codec: crate::codec::FrameCodec, -{ - let local_addr = server.local_addr().expect("local address missing"); - assert_eq!(local_addr, expected); -} - -#[fixture] -async fn connected_streams() -> io::Result<(TcpStream, TcpStream)> { - let listener = TcpListener::bind("127.0.0.1:0").await?; - let addr = listener.local_addr()?; - let client = TcpStream::connect(addr).await?; - let (server, _) = listener.accept().await?; - Ok((client, server)) -} - -#[rstest] -fn test_new_server_creation(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let server = WireframeServer::new(factory); - assert!(server.worker_count() >= 1); - assert!(server.local_addr().is_none()); -} - -#[rstest] -fn test_new_server_default_worker_count( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let server = WireframeServer::new(factory); - assert_eq!(server.worker_count(), default_worker_count()); -} - -#[rstest] -fn test_workers_configuration(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let server = WireframeServer::new(factory); - let server = server.workers(4); - assert_eq!(server.worker_count(), 4); - let server = server.workers(100); - assert_eq!(server.worker_count(), 100); - assert_eq!(server.workers(0).worker_count(), 1); -} - -#[rstest] -fn test_with_preamble_type_conversion( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let server = WireframeServer::new(factory).with_preamble::(); - assert_eq!(server.worker_count(), default_worker_count()); -} - -#[rstest] -fn test_preamble_timeout_configuration( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let timeout = Duration::from_millis(25); - let server = WireframeServer::new(factory).preamble_timeout(timeout); - assert_eq!(server.preamble_timeout, Some(timeout)); - - let clamped = WireframeServer::new(factory).preamble_timeout(Duration::ZERO); - assert_eq!(clamped.preamble_timeout, Some(Duration::from_millis(1))); -} - -#[rstest] -#[tokio::test] -async fn test_bind_success( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - free_listener: std::net::TcpListener, -) { - let expected = listener_addr(&free_listener); - let server = WireframeServer::new(factory) - .bind_existing_listener(free_listener) - .expect("Failed to bind"); - assert_local_addr_matches_listener(server, expected); -} - -#[rstest] -fn test_local_addr_before_bind(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - assert!(WireframeServer::new(factory).local_addr().is_none()); -} - -#[rstest] -#[tokio::test] -async fn test_local_addr_after_bind( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - free_listener: std::net::TcpListener, -) { - let expected = listener_addr(&free_listener); - let server = bind_server(factory, free_listener); - assert_local_addr_matches_listener(server, expected); -} - -#[rstest] -#[case::success(PreambleHandlerKind::Success)] -#[case::failure(PreambleHandlerKind::Failure)] -#[tokio::test] -async fn test_preamble_handler_registration( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - #[case] handler: PreambleHandlerKind, - connected_streams: io::Result<(TcpStream, TcpStream)>, -) -> io::Result<()> { - let counter = Arc::new(AtomicUsize::new(0)); - let c = counter.clone(); - - let server = server_with_preamble(factory); - let server = match handler { - PreambleHandlerKind::Success => server.on_preamble_decode_success(move |_p: &TestPreamble, _| { - let c = c.clone(); - Box::pin(async move { - c.fetch_add(1, Ordering::SeqCst); - Ok(()) - }) - }), - PreambleHandlerKind::Failure => server.on_preamble_decode_failure( - move |_err: &DecodeError, _stream| { - let c = c.clone(); - Box::pin(async move { - c.fetch_add(1, Ordering::SeqCst); - Ok::<(), io::Error>(()) - }) - }, - ), - }; - - assert_eq!(counter.load(Ordering::SeqCst), 0); - match handler { - PreambleHandlerKind::Success => { - let handler = server - .on_preamble_success - .as_ref() - .expect("success handler missing"); - let (_client, mut stream) = connected_streams?; - let preamble = TestPreamble { id: 0, message: String::new() }; - handler(&preamble, &mut stream).await?; - } - PreambleHandlerKind::Failure => { - let handler = server - .on_preamble_failure - .as_ref() - .expect("failure handler missing"); - let (_client, mut stream) = connected_streams?; - handler(&DecodeError::UnexpectedEnd, &mut stream).await?; - } - } - assert_eq!(counter.load(Ordering::SeqCst), 1); - Ok(()) -} - -#[rstest] -#[tokio::test] -async fn test_method_chaining( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - free_listener: std::net::TcpListener, -) { - let handler_invoked = Arc::new(AtomicUsize::new(0)); - let counter = handler_invoked.clone(); - let server = WireframeServer::new(factory) - .workers(2) - .with_preamble::() - .on_preamble_decode_success(move |_p: &TestPreamble, _| { - let c = counter.clone(); - Box::pin(async move { - c.fetch_add(1, Ordering::SeqCst); - Ok(()) - }) - }) - .on_preamble_decode_failure(|_: &DecodeError, _| Box::pin(async { Ok::<(), io::Error>(()) })) - .bind_existing_listener(free_listener) - .expect("Failed to bind"); - assert_eq!(server.worker_count(), 2); - assert!(server.local_addr().is_some()); - assert_eq!(handler_invoked.load(Ordering::SeqCst), 0); -} - -#[rstest] -#[tokio::test] -async fn test_server_configuration_persistence( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - free_listener: std::net::TcpListener, -) { - let server = WireframeServer::new(factory) - .workers(5) - .bind_existing_listener(free_listener) - .expect("Failed to bind"); - assert_eq!(server.worker_count(), 5); - assert!(server.local_addr().is_some()); -} - -#[rstest] -fn test_extreme_worker_counts(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let server = WireframeServer::new(factory); - let server = server.workers(usize::MAX); - assert_eq!(server.worker_count(), usize::MAX); - assert_eq!(server.workers(0).worker_count(), 1); -} - -#[rstest] -#[tokio::test] -async fn test_bind_to_multiple_addresses( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - free_listener: std::net::TcpListener, -) { - let addr1 = listener_addr(&free_listener); - - let server = WireframeServer::new(factory); - let server = server - .bind_existing_listener(free_listener) - .expect("Failed to bind first address"); - let first = server.local_addr().expect("first bound address missing"); - assert_eq!(first, addr1); - - let server = server - .bind(std::net::SocketAddr::new(addr1.ip(), 0)) - .expect("Failed to bind second address"); - let second = server.local_addr().expect("second bound address missing"); - assert_eq!(second.ip(), addr1.ip()); - assert_ne!(first.port(), second.port()); -} - -#[derive(Debug)] -struct BackoffScenario { - description: &'static str, - config: BackoffConfig, - expected_initial: Duration, - expected_max: Duration, -} - -#[rstest] -#[case::accept_config(BackoffScenario { - description: "accepts explicit delays", - config: BackoffConfig { - initial_delay: Duration::from_millis(5), - max_delay: Duration::from_millis(500), - }, - expected_initial: Duration::from_millis(5), - expected_max: Duration::from_millis(500), -})] -#[case::accept_initial_delay(BackoffScenario { - description: "accepts initial delay with default max", - config: BackoffConfig { - initial_delay: Duration::from_millis(20), - ..BackoffConfig::default() - }, - expected_initial: Duration::from_millis(20), - expected_max: BackoffConfig::default().max_delay, -})] -#[case::accept_max_delay(BackoffScenario { - description: "accepts max delay with default initial", - config: BackoffConfig { - max_delay: Duration::from_millis(2000), - ..BackoffConfig::default() - }, - expected_initial: BackoffConfig::default().initial_delay, - expected_max: Duration::from_millis(2000), -})] -#[case::clamp_zero_initial(BackoffScenario { - description: "clamps zero initial delay", - config: BackoffConfig { - initial_delay: Duration::ZERO, - ..BackoffConfig::default() - }, - expected_initial: Duration::from_millis(1), - expected_max: BackoffConfig::default().max_delay, -})] -#[case::swap_initial_gt_max(BackoffScenario { - description: "swaps initial and max delays when inverted", - config: BackoffConfig { - initial_delay: Duration::from_millis(100), - max_delay: Duration::from_millis(50), - }, - expected_initial: Duration::from_millis(50), - expected_max: Duration::from_millis(100), -})] -#[case::swap_initial_over_default_max(BackoffScenario { - description: "swaps initial and max delays when initial exceeds default max", - config: BackoffConfig { - initial_delay: Duration::from_secs(2), - max_delay: Duration::from_secs(1), - }, - expected_initial: Duration::from_secs(1), - expected_max: Duration::from_secs(2), -})] -#[case::swap_small_values(BackoffScenario { - description: "swaps small inverted delays", - config: BackoffConfig { - initial_delay: Duration::from_millis(5), - max_delay: Duration::from_millis(1), - }, - expected_initial: Duration::from_millis(1), - expected_max: Duration::from_millis(5), -})] -#[case::clamp_zero_both(BackoffScenario { - description: "clamps zero initial and max delays", - config: BackoffConfig { - initial_delay: Duration::ZERO, - max_delay: Duration::ZERO, - }, - expected_initial: Duration::from_millis(1), - expected_max: Duration::from_millis(1), -})] -fn test_accept_backoff_scenarios( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - #[case] scenario: BackoffScenario, -) { - let server = WireframeServer::new(factory).accept_backoff(scenario.config); - assert_eq!( - server.backoff_config.initial_delay, - scenario.expected_initial, - "scenario: {}", - scenario.description - ); - assert_eq!( - server.backoff_config.max_delay, - scenario.expected_max, - "scenario: {}", - scenario.description - ); -} - -/// Behaviour test verifying exponential delay doubling and capping. -#[test] -fn test_accept_exponential_backoff_doubles_and_caps() { - let initial = Duration::from_millis(10); - let max = Duration::from_millis(80); - let attempts = 5; - - let sequence = backoff_sequence(initial, max, attempts); - - let expected_delays = [ - initial, - std::cmp::min(initial.saturating_mul(2), max), - std::cmp::min(initial.saturating_mul(4), max), - std::cmp::min(initial.saturating_mul(8), max), - max, - ]; - - assert_eq!(&sequence[..], &expected_delays); -} - -fn backoff_sequence(initial: Duration, max: Duration, attempts: usize) -> Vec { - let mut sequence = Vec::with_capacity(attempts); - let mut backoff = initial; - - for _ in 0..attempts { - sequence.push(backoff); - backoff = std::cmp::min(backoff * 2, max); - } - - sequence -} - -#[rstest] -fn test_backoff_default_values(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let server = WireframeServer::new(factory); - assert_eq!( - server.backoff_config.initial_delay, - Duration::from_millis(10) - ); - assert_eq!(server.backoff_config.max_delay, Duration::from_secs(1)); -} diff --git a/src/server/config/tests/mod.rs b/src/server/config/tests/mod.rs new file mode 100644 index 00000000..f34ff9e6 --- /dev/null +++ b/src/server/config/tests/mod.rs @@ -0,0 +1,18 @@ +//! Tests for server configuration utilities. +//! +//! This module exercises the `WireframeServer` builder, covering worker counts, +//! binding behaviour, preamble handling, handler registration, and method +//! chaining. Fixtures from `test_util` provide shared setup and parameterised +//! cases via `rstest`. + +mod tests_backoff; +mod tests_basic; +mod tests_binding; +mod tests_integration; +mod tests_preamble; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum PreambleHandlerKind { + Success, + Failure, +} diff --git a/src/server/config/tests/tests_backoff.rs b/src/server/config/tests/tests_backoff.rs new file mode 100644 index 00000000..429e791f --- /dev/null +++ b/src/server/config/tests/tests_backoff.rs @@ -0,0 +1,164 @@ +//! Backoff configuration tests for `WireframeServer`. + +use std::time::Duration; + +use rstest::rstest; + +use crate::{ + app::WireframeApp, + server::{BackoffConfig, WireframeServer, test_util::factory}, +}; + +#[derive(Debug)] +struct BackoffScenario { + config: BackoffConfig, + expected_initial: Duration, + expected_max: Duration, + description: &'static str, +} + +#[rstest] +#[case::custom_backoff( + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::from_millis(5), + max_delay: Duration::from_millis(500), + }, + expected_initial: Duration::from_millis(5), + expected_max: Duration::from_millis(500), + description: "custom backoff", + } +)] +#[case::custom_initial_delay( + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::from_millis(20), + ..BackoffConfig::default() + }, + expected_initial: Duration::from_millis(20), + expected_max: Duration::from_secs(1), + description: "custom initial delay", + } +)] +#[case::custom_max_delay( + BackoffScenario { + config: BackoffConfig { + max_delay: Duration::from_millis(2000), + ..BackoffConfig::default() + }, + expected_initial: Duration::from_millis(10), + expected_max: Duration::from_millis(2000), + description: "custom max delay", + } +)] +#[case::clamp_initial_delay( + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::ZERO, + ..BackoffConfig::default() + }, + expected_initial: Duration::from_millis(1), + expected_max: Duration::from_secs(1), + description: "clamp initial delay", + } +)] +#[case::swap_shorter_max( + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::from_millis(100), + max_delay: Duration::from_millis(50), + }, + expected_initial: Duration::from_millis(50), + expected_max: Duration::from_millis(100), + description: "swap shorter max", + } +)] +#[case::swap_with_default_max( + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::from_secs(2), + max_delay: Duration::from_secs(1), + }, + expected_initial: Duration::from_secs(1), + expected_max: Duration::from_secs(2), + description: "swap with default max", + } +)] +#[case::swap_small_values( + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::from_millis(5), + max_delay: Duration::from_millis(1), + }, + expected_initial: Duration::from_millis(1), + expected_max: Duration::from_millis(5), + description: "swap small values", + } +)] +#[case::clamp_both_zero( + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::ZERO, + max_delay: Duration::ZERO, + }, + expected_initial: Duration::from_millis(1), + expected_max: Duration::from_millis(1), + description: "clamp both zero", + } +)] +fn test_backoff_configuration_scenarios( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + #[case] scenario: BackoffScenario, +) { + let server = WireframeServer::new(factory).accept_backoff(scenario.config); + assert_eq!( + server.backoff_config.initial_delay, scenario.expected_initial, + "scenario: {}", + scenario.description + ); + assert_eq!( + server.backoff_config.max_delay, scenario.expected_max, + "scenario: {}", + scenario.description + ); +} + +/// Behaviour test verifying exponential delay doubling and capping. +#[test] +fn test_accept_exponential_backoff_doubles_and_caps() { + let initial = Duration::from_millis(10); + let max = Duration::from_millis(80); + let attempts = 5; + let delays = backoff_sequence(initial, max, attempts); + let expected_delays = vec![ + initial, + std::cmp::min(initial * 2, max), + std::cmp::min(initial * 4, max), + std::cmp::min(initial * 8, max), + max, + ]; + + assert_eq!(delays, expected_delays); +} + +fn backoff_sequence(initial: Duration, max: Duration, attempts: usize) -> Vec { + let mut backoff = initial; + let mut delays = Vec::with_capacity(attempts); + + for _ in 0..attempts { + delays.push(backoff); + backoff = std::cmp::min(backoff * 2, max); + } + + delays +} + +#[rstest] +fn test_backoff_default_values(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { + let server = WireframeServer::new(factory); + assert_eq!( + server.backoff_config.initial_delay, + Duration::from_millis(10) + ); + assert_eq!(server.backoff_config.max_delay, Duration::from_secs(1)); +} diff --git a/src/server/config/tests/tests_basic.rs b/src/server/config/tests/tests_basic.rs new file mode 100644 index 00000000..b41ca92f --- /dev/null +++ b/src/server/config/tests/tests_basic.rs @@ -0,0 +1,55 @@ +//! Basic configuration tests for `WireframeServer`. + +use rstest::rstest; + +use crate::{ + app::WireframeApp, + server::{ + WireframeServer, + default_worker_count, + test_util::{bind_server, factory, free_listener, listener_addr}, + }, +}; + +#[rstest] +fn test_new_server_creation(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { + let server = WireframeServer::new(factory); + assert!(server.worker_count() >= 1); + assert!(server.local_addr().is_none()); +} + +#[rstest] +fn test_new_server_default_worker_count( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) { + let server = WireframeServer::new(factory); + assert_eq!(server.worker_count(), default_worker_count()); +} + +#[rstest] +fn test_workers_configuration(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { + let server = WireframeServer::new(factory); + let server = server.workers(4); + assert_eq!(server.worker_count(), 4); + let server = server.workers(100); + assert_eq!(server.worker_count(), 100); + assert_eq!(server.workers(0).worker_count(), 1); +} + +#[rstest] +fn test_local_addr_before_bind(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { + assert!(WireframeServer::new(factory).local_addr().is_none()); +} + +#[rstest] +#[tokio::test] +async fn test_local_addr_after_bind( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_listener: std::net::TcpListener, +) { + let expected = listener_addr(&free_listener); + let local_addr = bind_server(factory, free_listener) + .local_addr() + .expect("local address missing"); + assert_eq!(local_addr, expected); +} diff --git a/src/server/config/tests/tests_binding.rs b/src/server/config/tests/tests_binding.rs new file mode 100644 index 00000000..5b624b2c --- /dev/null +++ b/src/server/config/tests/tests_binding.rs @@ -0,0 +1,49 @@ +//! Binding behaviour tests for `WireframeServer`. + +use rstest::rstest; + +use crate::{ + app::WireframeApp, + server::{ + WireframeServer, + test_util::{factory, free_listener, listener_addr}, + }, +}; + +#[rstest] +#[tokio::test] +async fn test_bind_success( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_listener: std::net::TcpListener, +) { + let expected = listener_addr(&free_listener); + let local_addr = WireframeServer::new(factory) + .bind_existing_listener(free_listener) + .expect("Failed to bind") + .local_addr() + .expect("local address missing"); + assert_eq!(local_addr, expected); +} + +#[rstest] +#[tokio::test] +async fn test_bind_to_multiple_addresses( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_listener: std::net::TcpListener, +) { + let addr1 = listener_addr(&free_listener); + + let server = WireframeServer::new(factory); + let server = server + .bind_existing_listener(free_listener) + .expect("Failed to bind first address"); + let first = server.local_addr().expect("first bound address missing"); + assert_eq!(first, addr1); + + let server = server + .bind(std::net::SocketAddr::new(addr1.ip(), 0)) + .expect("Failed to bind second address"); + let second = server.local_addr().expect("second bound address missing"); + assert_eq!(second.ip(), addr1.ip()); + assert_ne!(first.port(), second.port()); +} diff --git a/src/server/config/tests/tests_integration.rs b/src/server/config/tests/tests_integration.rs new file mode 100644 index 00000000..f9ce7457 --- /dev/null +++ b/src/server/config/tests/tests_integration.rs @@ -0,0 +1,70 @@ +//! Integration-style tests for server builder composition. + +use std::{ + io, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, +}; + +use bincode::error::DecodeError; +use rstest::rstest; + +use crate::{ + app::WireframeApp, + server::{ + WireframeServer, + test_util::{TestPreamble, factory, free_listener}, + }, +}; + +#[rstest] +#[tokio::test] +async fn test_method_chaining( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_listener: std::net::TcpListener, +) { + let handler_invoked = Arc::new(AtomicUsize::new(0)); + let counter = handler_invoked.clone(); + let server = WireframeServer::new(factory) + .workers(2) + .with_preamble::() + .on_preamble_decode_success(move |_p: &TestPreamble, _| { + let c = counter.clone(); + Box::pin(async move { + c.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + }) + .on_preamble_decode_failure(|_: &DecodeError, _| { + Box::pin(async { Ok::<(), io::Error>(()) }) + }) + .bind_existing_listener(free_listener) + .expect("Failed to bind"); + assert_eq!(server.worker_count(), 2); + assert!(server.local_addr().is_some()); + assert_eq!(handler_invoked.load(Ordering::SeqCst), 0); +} + +#[rstest] +#[tokio::test] +async fn test_server_configuration_persistence( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_listener: std::net::TcpListener, +) { + let server = WireframeServer::new(factory) + .workers(5) + .bind_existing_listener(free_listener) + .expect("Failed to bind"); + assert_eq!(server.worker_count(), 5); + assert!(server.local_addr().is_some()); +} + +#[rstest] +fn test_extreme_worker_counts(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { + let server = WireframeServer::new(factory); + let server = server.workers(usize::MAX); + assert_eq!(server.worker_count(), usize::MAX); + assert_eq!(server.workers(0).worker_count(), 1); +} diff --git a/src/server/config/tests/tests_preamble.rs b/src/server/config/tests/tests_preamble.rs new file mode 100644 index 00000000..5360dfd6 --- /dev/null +++ b/src/server/config/tests/tests_preamble.rs @@ -0,0 +1,120 @@ +//! Preamble configuration tests for `WireframeServer`. + +use std::{ + io, + net::TcpListener as StdTcpListener, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + time::Duration, +}; + +use bincode::error::DecodeError; +use rstest::{fixture, rstest}; +use tokio::net::{TcpListener, TcpStream}; + +use super::PreambleHandlerKind; +use crate::{ + app::WireframeApp, + server::{ + WireframeServer, + default_worker_count, + test_util::{TestPreamble, factory, server_with_preamble}, + }, +}; + +#[rstest] +fn test_with_preamble_type_conversion( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) { + let server = WireframeServer::new(factory).with_preamble::(); + assert_eq!(server.worker_count(), default_worker_count()); +} + +#[rstest] +fn test_preamble_timeout_configuration( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) { + let timeout = Duration::from_millis(25); + let server = WireframeServer::new(factory.clone()).preamble_timeout(timeout); + assert_eq!(server.preamble_timeout, Some(timeout)); + + let clamped = WireframeServer::new(factory).preamble_timeout(Duration::ZERO); + assert_eq!(clamped.preamble_timeout, Some(Duration::from_millis(1))); +} + +#[rstest] +#[case::success(PreambleHandlerKind::Success)] +#[case::failure(PreambleHandlerKind::Failure)] +#[tokio::test] +async fn test_preamble_handler_registration( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + #[case] handler: PreambleHandlerKind, + bound_listener: io::Result, +) -> io::Result<()> { + let counter = Arc::new(AtomicUsize::new(0)); + let c = counter.clone(); + + let server = server_with_preamble(factory); + let server = match handler { + PreambleHandlerKind::Success => { + server.on_preamble_decode_success(move |_p: &TestPreamble, _| { + let c = c.clone(); + Box::pin(async move { + c.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + }) + } + PreambleHandlerKind::Failure => { + server.on_preamble_decode_failure(move |_err: &DecodeError, _stream| { + let c = c.clone(); + Box::pin(async move { + c.fetch_add(1, Ordering::SeqCst); + Ok::<(), io::Error>(()) + }) + }) + } + }; + + assert_eq!(counter.load(Ordering::SeqCst), 0); + let mut stream = accept_stream(bound_listener?).await?; + match handler { + PreambleHandlerKind::Success => { + let handler = server + .on_preamble_success + .as_ref() + .expect("success handler missing"); + let preamble = TestPreamble { + id: 0, + message: String::new(), + }; + handler(&preamble, &mut stream).await?; + } + PreambleHandlerKind::Failure => { + let handler = server + .on_preamble_failure + .as_ref() + .expect("failure handler missing"); + handler(&DecodeError::UnexpectedEnd { additional: 0 }, &mut stream).await?; + } + } + assert_eq!(counter.load(Ordering::SeqCst), 1); + Ok(()) +} + +#[fixture] +fn bound_listener() -> io::Result { + let addr = "127.0.0.1:0"; + StdTcpListener::bind(addr) +} + +async fn accept_stream(listener: StdTcpListener) -> io::Result { + let addr = listener.local_addr()?; + listener.set_nonblocking(true)?; + let listener = TcpListener::from_std(listener)?; + let _client = TcpStream::connect(addr).await?; + let (stream, _) = listener.accept().await?; + Ok(stream) +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 28a0ac78..d4e8fc4b 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -17,6 +17,11 @@ use crate::{ serializer::{BincodeSerializer, Serializer}, }; +/// Compute the default worker count from available CPU parallelism. +pub(crate) fn default_worker_count() -> usize { + std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get) +} + /// Handler invoked when a connection preamble decodes successfully. /// /// Implementors may perform asynchronous I/O on the provided stream before the diff --git a/tests/app_data.rs b/tests/app_data.rs index 7dbd533a..0b85688b 100644 --- a/tests/app_data.rs +++ b/tests/app_data.rs @@ -41,7 +41,6 @@ fn missing_shared_state_returns_error( mut empty_payload: Payload<'static>, ) { let err = SharedState::::from_message_request(&request, &mut empty_payload) - .err() - .expect("missing state error expected"); + .expect_err("missing state error expected"); assert!(matches!(err, ExtractError::MissingState(_))); }