From 6ecc914e993f9d5b753b753fdc8709de7bea1e37 Mon Sep 17 00:00:00 2001 From: Leynos Date: Thu, 5 Feb 2026 11:19:15 +0000 Subject: [PATCH 1/9] refactor(app builder): modularize WireframeApp and remove deprecated client builder Removed the monolithic src/app/builder.rs file and replaced it with a modular builder implementation split across multiple files (codec.rs, config.rs, core.rs, lifecycle.rs, mod.rs, protocol.rs, routing.rs, state.rs). This improves maintainability and readability by organizing related functionality into dedicated modules. Deleted the deprecated src/client/builder.rs and related client builder modules, cleaning up the codebase. Additionally, extracted frame handling helpers into a frame_handling module with submodules for core, reassembly, and response forwarding logic, keeping connection.rs smaller and more focused. Overall, this commit significantly improves project structure by organizing app builder and frame handling code into well-defined modules, and removing outdated client builder code. Co-authored-by: devboxerhub[bot] --- src/app/builder.rs | 335 -------------- src/app/builder/codec.rs | 69 +++ src/app/builder/config.rs | 62 +++ src/app/builder/core.rs | 162 +++++++ .../lifecycle.rs} | 30 +- src/app/builder/mod.rs | 15 + .../protocol.rs} | 47 +- src/app/builder/routing.rs | 49 ++ src/app/builder/state.rs | 30 ++ src/app/frame_handling.rs | 350 --------------- src/app/frame_handling/core.rs | 52 +++ src/app/frame_handling/mod.rs | 15 + src/app/frame_handling/reassembly.rs | 37 ++ src/app/frame_handling/response.rs | 139 ++++++ src/app/frame_handling/tests.rs | 209 +++++++++ src/app/mod.rs | 2 - src/client/builder.rs | 425 ------------------ src/client/builder/codec.rs | 58 +++ src/client/builder/connect.rs | 96 ++++ src/client/builder/core.rs | 61 +++ src/client/builder/lifecycle.rs | 127 ++++++ src/client/builder/mod.rs | 54 +++ src/client/builder/preamble.rs | 40 ++ src/client/builder/serializer.rs | 27 ++ src/codec/recovery.rs | 331 -------------- src/codec/recovery/config.rs | 76 ++++ src/codec/recovery/context.rs | 74 +++ src/codec/recovery/hook.rs | 79 ++++ src/codec/recovery/mod.rs | 49 ++ src/codec/recovery/policy.rs | 75 ++++ src/codec/recovery/tests.rs | 59 +-- src/extractor/connection_info.rs | 47 ++ src/extractor/error.rs | 48 ++ src/extractor/extractors.rs | 182 -------- src/extractor/message.rs | 47 ++ src/extractor/mod.rs | 345 +------------- src/extractor/request.rs | 235 ++++++++++ src/extractor/shared_state.rs | 53 +++ src/extractor/streaming.rs | 92 ++++ src/extractor/trait_def.rs | 25 ++ src/server/config/mod.rs | 2 + src/server/config/tests.rs | 406 ----------------- src/server/config/tests/mod.rs | 23 + src/server/config/tests/tests_backoff.rs | 166 +++++++ src/server/config/tests/tests_basic.rs | 54 +++ src/server/config/tests/tests_binding.rs | 49 ++ src/server/config/tests/tests_integration.rs | 70 +++ src/server/config/tests/tests_preamble.rs | 125 ++++++ 48 files changed, 2772 insertions(+), 2431 deletions(-) delete mode 100644 src/app/builder.rs create mode 100644 src/app/builder/codec.rs create mode 100644 src/app/builder/config.rs create mode 100644 src/app/builder/core.rs rename src/app/{builder_lifecycle.rs => builder/lifecycle.rs} (65%) create mode 100644 src/app/builder/mod.rs rename src/app/{builder_protocol.rs => builder/protocol.rs} (97%) create mode 100644 src/app/builder/routing.rs create mode 100644 src/app/builder/state.rs delete mode 100644 src/app/frame_handling.rs create mode 100644 src/app/frame_handling/core.rs create mode 100644 src/app/frame_handling/mod.rs create mode 100644 src/app/frame_handling/reassembly.rs create mode 100644 src/app/frame_handling/response.rs create mode 100644 src/app/frame_handling/tests.rs delete mode 100644 src/client/builder.rs create mode 100644 src/client/builder/codec.rs create mode 100644 src/client/builder/connect.rs create mode 100644 src/client/builder/core.rs create mode 100644 src/client/builder/lifecycle.rs create mode 100644 src/client/builder/mod.rs create mode 100644 src/client/builder/preamble.rs create mode 100644 src/client/builder/serializer.rs delete mode 100644 src/codec/recovery.rs create mode 100644 src/codec/recovery/config.rs create mode 100644 src/codec/recovery/context.rs create mode 100644 src/codec/recovery/hook.rs create mode 100644 src/codec/recovery/mod.rs create mode 100644 src/codec/recovery/policy.rs create mode 100644 src/extractor/connection_info.rs create mode 100644 src/extractor/error.rs delete mode 100644 src/extractor/extractors.rs create mode 100644 src/extractor/message.rs create mode 100644 src/extractor/request.rs create mode 100644 src/extractor/shared_state.rs create mode 100644 src/extractor/streaming.rs create mode 100644 src/extractor/trait_def.rs delete mode 100644 src/server/config/tests.rs create mode 100644 src/server/config/tests/mod.rs create mode 100644 src/server/config/tests/tests_backoff.rs create mode 100644 src/server/config/tests/tests_basic.rs create mode 100644 src/server/config/tests/tests_binding.rs create mode 100644 src/server/config/tests/tests_integration.rs create mode 100644 src/server/config/tests/tests_preamble.rs diff --git a/src/app/builder.rs b/src/app/builder.rs deleted file mode 100644 index 42162c1b..00000000 --- a/src/app/builder.rs +++ /dev/null @@ -1,335 +0,0 @@ -//! Application builder configuring routes and middleware. -//! [`WireframeApp`] is an Actix-inspired builder for managing connection -//! state, routing, and middleware in a `WireframeServer`. It exposes -//! convenience methods to register handlers and lifecycle hooks, and -//! serializes messages using a configurable serializer. - -use std::{ - any::{Any, TypeId}, - collections::HashMap, - sync::Arc, -}; - -use tokio::sync::{OnceCell, mpsc}; - -use super::{ - builder_defaults::{ - DEFAULT_READ_TIMEOUT_MS, - MAX_READ_TIMEOUT_MS, - MIN_READ_TIMEOUT_MS, - default_fragmentation, - }, - envelope::{Envelope, Packet}, - error::{Result, WireframeError}, - lifecycle::{ConnectionSetup, ConnectionTeardown}, - middleware_types::{Handler, Middleware}, -}; -use crate::{ - codec::{FrameCodec, LengthDelimitedFrameCodec, clamp_frame_length}, - fragment::FragmentationConfig, - hooks::WireframeProtocol, - message_assembler::MessageAssembler, - middleware::HandlerService, - serializer::{BincodeSerializer, Serializer}, -}; - -/// Configures routing and middleware for a `WireframeServer`. -/// -/// The builder stores registered routes and middleware without enforcing an -/// ordering. Methods return [`Result`] so registrations can be chained -/// ergonomically. -pub struct WireframeApp< - S: Serializer + Send + Sync = BincodeSerializer, - C: Send + 'static = (), - E: Packet = Envelope, - F: FrameCodec = LengthDelimitedFrameCodec, -> { - pub(super) handlers: HashMap>, - pub(super) routes: OnceCell>>>, - pub(super) middleware: Vec>>, - pub(super) serializer: S, - pub(super) app_data: HashMap>, - pub(super) on_connect: Option>>, - pub(super) on_disconnect: Option>>, - pub(super) protocol: Option>>, - pub(super) push_dlq: Option>>, - pub(super) codec: F, - pub(super) read_timeout_ms: u64, - pub(super) fragmentation: Option, - pub(super) message_assembler: Option>, -} - -impl Default for WireframeApp -where - S: Serializer + Default + Send + Sync, - C: Send + 'static, - E: Packet, - F: FrameCodec + Default, -{ - /// Initializes empty routes, middleware, and application data with the - /// default serializer and no lifecycle hooks. - fn default() -> Self { - let codec = F::default(); - let max_frame_length = codec.max_frame_length(); - Self { - handlers: HashMap::new(), - routes: OnceCell::new(), - middleware: Vec::new(), - serializer: S::default(), - app_data: HashMap::new(), - on_connect: None, - on_disconnect: None, - protocol: None, - push_dlq: None, - codec, - read_timeout_ms: DEFAULT_READ_TIMEOUT_MS, - fragmentation: default_fragmentation(max_frame_length), - message_assembler: None, - } - } -} - -impl WireframeApp -where - S: Serializer + Default + Send + Sync, - C: Send + 'static, - E: Packet, - F: FrameCodec + Default, -{ - /// Construct a new empty application builder. - /// - /// # Errors - /// - /// This function currently never returns an error but uses [`Result`] for - /// forward compatibility. - /// - /// # Examples - /// - /// ``` - /// use wireframe::app::WireframeApp; - /// WireframeApp::<_, _, wireframe::app::Envelope>::new().expect("failed to initialize app"); - /// ``` - pub fn new() -> Result { Ok(Self::default()) } - - /// Construct a new application builder using the provided serializer. - /// - /// # Errors - /// - /// This function currently never returns an error but uses [`Result`] for - /// forward compatibility. - pub fn with_serializer(serializer: S) -> Result { - Ok(Self { - serializer, - ..Self::default() - }) - } -} - -impl WireframeApp -where - S: Serializer + Send + Sync, - C: Send + 'static, - E: Packet, - F: FrameCodec, -{ - /// Helper to rebuild the app when changing type parameters. - /// - /// This centralises the field-by-field reconstruction required when - /// transforming between different serializer or codec types. - #[expect( - clippy::too_many_arguments, - reason = "internal helper grouping fields for type-transitioning builders" - )] - fn rebuild_with_params( - self, - serializer: S2, - codec: F2, - protocol: Option>>, - fragmentation: Option, - message_assembler: Option>, - ) -> WireframeApp - where - S2: Serializer + Send + Sync, - F2: FrameCodec, - { - WireframeApp { - handlers: self.handlers, - routes: OnceCell::new(), - middleware: self.middleware, - serializer, - app_data: self.app_data, - on_connect: self.on_connect, - on_disconnect: self.on_disconnect, - protocol, - push_dlq: self.push_dlq, - codec, - read_timeout_ms: self.read_timeout_ms, - fragmentation, - message_assembler, - } - } - - /// Helper to rebuild the app when changing the connection state type. - pub(super) fn rebuild_with_connection_type( - self, - on_connect: Option>>, - on_disconnect: Option>>, - ) -> WireframeApp - where - C2: Send + 'static, - { - WireframeApp { - handlers: self.handlers, - routes: OnceCell::new(), - middleware: self.middleware, - serializer: self.serializer, - app_data: self.app_data, - on_connect, - on_disconnect, - protocol: self.protocol, - push_dlq: self.push_dlq, - codec: self.codec, - read_timeout_ms: self.read_timeout_ms, - fragmentation: self.fragmentation, - message_assembler: self.message_assembler, - } - } - - /// Replace the frame codec used for framing I/O. - /// - /// This resets any installed protocol hooks because the frame type may - /// change across codecs. Fragmentation configuration is reset to the - /// codec-derived default. - #[must_use] - pub fn with_codec(mut self, codec: F2) -> WireframeApp - where - S: Default, - { - let fragmentation = default_fragmentation(codec.max_frame_length()); - let serializer = std::mem::take(&mut self.serializer); - let message_assembler = self.message_assembler.take(); - self.rebuild_with_params(serializer, codec, None, fragmentation, message_assembler) - } - - /// Register a route that maps `id` to `handler`. - /// - /// # Errors - /// - /// Returns [`WireframeError::DuplicateRoute`] if a handler for `id` - /// has already been registered. - pub fn route(mut self, id: u32, handler: Handler) -> Result { - if self.handlers.contains_key(&id) { - return Err(WireframeError::DuplicateRoute(id)); - } - self.handlers.insert(id, handler); - self.routes = OnceCell::new(); - Ok(self) - } - - /// Store a shared state value accessible to request extractors. - /// - /// The value can later be retrieved using [`crate::extractor::SharedState`]. Registering - /// another value of the same type overwrites the previous one. - #[must_use] - pub fn app_data(mut self, state: T) -> Self - where - T: Send + Sync + 'static, - { - self.app_data.insert( - TypeId::of::(), - Arc::new(state) as Arc, - ); - self - } - - /// Add a middleware component to the processing pipeline. - /// - /// # Errors - /// - /// This function currently always succeeds. - pub fn wrap(mut self, mw: M) -> Result - where - M: Middleware + 'static, - { - self.middleware.push(Box::new(mw)); - self.routes = OnceCell::new(); - Ok(self) - } - - /// Configure a Dead Letter Queue for dropped push frames. - /// - /// ```rust,no_run - /// use tokio::sync::mpsc; - /// use wireframe::app::WireframeApp; - /// - /// # fn build() -> WireframeApp { - /// # WireframeApp::new().expect("builder creation should not fail") - /// # } - /// # fn main() { - /// let (tx, _rx) = mpsc::channel(16); - /// let app = build().with_push_dlq(tx); - /// # let _ = app; - /// # } - /// ``` - #[must_use] - pub fn with_push_dlq(self, dlq: mpsc::Sender>) -> Self { - WireframeApp { - push_dlq: Some(dlq), - ..self - } - } - - /// Replace the serializer used for messages. - #[must_use] - pub fn serializer(mut self, serializer: Ser) -> WireframeApp - where - Ser: Serializer + Send + Sync, - F: Default, - { - let codec = std::mem::take(&mut self.codec); - let protocol = self.protocol.take(); - let fragmentation = self.fragmentation.take(); - let message_assembler = self.message_assembler.take(); - self.rebuild_with_params( - serializer, - codec, - protocol, - fragmentation, - message_assembler, - ) - } - - /// Configure the read timeout in milliseconds. - /// Clamped between 1 and 86 400 000 milliseconds (24 h). - #[must_use] - pub fn read_timeout_ms(mut self, timeout_ms: u64) -> Self { - self.read_timeout_ms = timeout_ms.clamp(MIN_READ_TIMEOUT_MS, MAX_READ_TIMEOUT_MS); - self - } - - /// Override the fragmentation configuration. - /// - /// Provide `None` to disable fragmentation entirely. - #[must_use] - pub fn fragmentation(mut self, config: Option) -> Self { - self.fragmentation = config; - self - } -} - -impl WireframeApp -where - S: Serializer + Send + Sync, - C: Send + 'static, - E: Packet, -{ - /// Set the initial buffer capacity for framed reads. - /// Clamped between 64 bytes and 16 MiB. - #[must_use] - pub fn buffer_capacity(mut self, capacity: usize) -> Self { - let capacity = clamp_frame_length(capacity); - self.codec = LengthDelimitedFrameCodec::new(capacity); - self.fragmentation = default_fragmentation(capacity); - self - } -} diff --git a/src/app/builder/codec.rs b/src/app/builder/codec.rs new file mode 100644 index 00000000..f91ebc09 --- /dev/null +++ b/src/app/builder/codec.rs @@ -0,0 +1,69 @@ +//! Codec and serializer configuration for `WireframeApp`. + +use super::WireframeApp; +use crate::{ + app::{Packet, builder_defaults::default_fragmentation}, + codec::{FrameCodec, LengthDelimitedFrameCodec, clamp_frame_length}, + serializer::Serializer, +}; + +impl WireframeApp +where + S: Serializer + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec, +{ + /// Replace the frame codec used for framing I/O. + /// + /// This resets any installed protocol hooks because the frame type may + /// change across codecs. Fragmentation configuration is reset to the + /// codec-derived default. + #[must_use] + pub fn with_codec(mut self, codec: F2) -> WireframeApp + where + S: Default, + { + let fragmentation = default_fragmentation(codec.max_frame_length()); + let serializer = std::mem::take(&mut self.serializer); + let message_assembler = self.message_assembler.take(); + self.rebuild_with_params(serializer, codec, None, fragmentation, message_assembler) + } + + /// Replace the serializer used for messages. + #[must_use] + pub fn serializer(mut self, serializer: Ser) -> WireframeApp + where + Ser: Serializer + Send + Sync, + F: Default, + { + let codec = std::mem::take(&mut self.codec); + let protocol = self.protocol.take(); + let fragmentation = self.fragmentation.take(); + let message_assembler = self.message_assembler.take(); + self.rebuild_with_params( + serializer, + codec, + protocol, + fragmentation, + message_assembler, + ) + } +} + +impl WireframeApp +where + S: Serializer + Default + Send + Sync, + C: Send + 'static, + E: Packet, +{ + /// Set the initial buffer capacity for framed reads. + /// Clamped between 64 bytes and 16 MiB. + #[must_use] + pub fn buffer_capacity(mut self, capacity: usize) -> Self { + let capacity = clamp_frame_length(capacity); + self.codec = LengthDelimitedFrameCodec::new(capacity); + self.fragmentation = default_fragmentation(capacity); + self + } +} diff --git a/src/app/builder/config.rs b/src/app/builder/config.rs new file mode 100644 index 00000000..dac658d9 --- /dev/null +++ b/src/app/builder/config.rs @@ -0,0 +1,62 @@ +//! General configuration methods for `WireframeApp`. + +use tokio::sync::mpsc; + +use super::WireframeApp; +use crate::{ + app::{ + Packet, + builder_defaults::{MAX_READ_TIMEOUT_MS, MIN_READ_TIMEOUT_MS}, + }, + codec::FrameCodec, + fragment::FragmentationConfig, + serializer::Serializer, +}; + +impl WireframeApp +where + S: Serializer + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec, +{ + /// Configure the read timeout in milliseconds. + /// Clamped between 1 and 86,400,000 milliseconds (24 h). + #[must_use] + pub fn read_timeout_ms(mut self, timeout_ms: u64) -> Self { + self.read_timeout_ms = timeout_ms.clamp(MIN_READ_TIMEOUT_MS, MAX_READ_TIMEOUT_MS); + self + } + + /// Override the fragmentation configuration. + /// + /// Provide `None` to disable fragmentation entirely. + #[must_use] + pub fn fragmentation(mut self, config: Option) -> Self { + self.fragmentation = config; + self + } + + /// Configure a Dead Letter Queue for dropped push frames. + /// + /// ```rust,no_run + /// use tokio::sync::mpsc; + /// use wireframe::app::WireframeApp; + /// + /// # fn build() -> WireframeApp { + /// # WireframeApp::new().expect("builder creation should not fail") + /// # } + /// # fn main() { + /// let (tx, _rx) = mpsc::channel(16); + /// let app = build().with_push_dlq(tx); + /// # let _ = app; + /// # } + /// ``` + #[must_use] + pub fn with_push_dlq(self, dlq: mpsc::Sender>) -> Self { + WireframeApp { + push_dlq: Some(dlq), + ..self + } + } +} diff --git a/src/app/builder/core.rs b/src/app/builder/core.rs new file mode 100644 index 00000000..cb4a03ef --- /dev/null +++ b/src/app/builder/core.rs @@ -0,0 +1,162 @@ +//! Core builder types for `WireframeApp`. + +use std::{ + any::{Any, TypeId}, + collections::HashMap, + sync::Arc, +}; + +use tokio::sync::{OnceCell, mpsc}; + +use crate::{ + app::{ + builder_defaults::default_fragmentation, + envelope::{Envelope, Packet}, + error::Result, + lifecycle::{ConnectionSetup, ConnectionTeardown}, + middleware_types::{Handler, Middleware}, + }, + codec::{FrameCodec, LengthDelimitedFrameCodec}, + hooks::WireframeProtocol, + message_assembler::MessageAssembler, + middleware::HandlerService, + serializer::{BincodeSerializer, Serializer}, +}; + +/// Configures routing and middleware for a `WireframeServer`. +/// +/// The builder stores registered routes and middleware without enforcing an +/// ordering. Methods return [`Result`] so registrations can be chained +/// ergonomically. +pub struct WireframeApp< + S: Serializer + Send + Sync = BincodeSerializer, + C: Send + 'static = (), + E: Packet = Envelope, + F: FrameCodec = LengthDelimitedFrameCodec, +> { + pub(in crate::app) handlers: HashMap>, + pub(in crate::app) routes: OnceCell>>>, + pub(in crate::app) middleware: Vec>>, + pub(in crate::app) serializer: S, + pub(in crate::app) app_data: HashMap>, + pub(in crate::app) on_connect: Option>>, + pub(in crate::app) on_disconnect: Option>>, + pub(in crate::app) protocol: + Option>>, + pub(in crate::app) push_dlq: Option>>, + pub(in crate::app) codec: F, + pub(in crate::app) read_timeout_ms: u64, + pub(in crate::app) fragmentation: Option, + pub(in crate::app) message_assembler: Option>, +} + +impl Default for WireframeApp +where + S: Serializer + Default + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec + Default, +{ + /// Initializes empty routes, middleware, and application data with the + /// default serializer and no lifecycle hooks. + fn default() -> Self { + let codec = F::default(); + let max_frame_length = codec.max_frame_length(); + Self { + handlers: HashMap::new(), + routes: OnceCell::new(), + middleware: Vec::new(), + serializer: S::default(), + app_data: HashMap::new(), + on_connect: None, + on_disconnect: None, + protocol: None, + push_dlq: None, + codec, + read_timeout_ms: 100, + fragmentation: default_fragmentation(max_frame_length), + message_assembler: None, + } + } +} + +impl WireframeApp +where + S: Serializer + Default + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec + Default, +{ + /// Construct a new empty application builder. + /// + /// # Errors + /// + /// This function currently never returns an error but uses [`Result`] for + /// forward compatibility. + /// + /// # Examples + /// + /// ``` + /// use wireframe::app::WireframeApp; + /// WireframeApp::<_, _, wireframe::app::Envelope>::new().expect("failed to initialize app"); + /// ``` + pub fn new() -> Result { Ok(Self::default()) } + + /// Construct a new application builder using the provided serializer. + /// + /// # Errors + /// + /// This function currently never returns an error but uses [`Result`] for + /// forward compatibility. + pub fn with_serializer(serializer: S) -> Result { + Ok(Self { + serializer, + ..Self::default() + }) + } +} + +impl WireframeApp +where + S: Serializer + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec, +{ + /// Helper to rebuild the app when changing type parameters. + /// + /// This centralises the field-by-field reconstruction required when + /// transforming between different serializer or codec types. + #[expect( + clippy::too_many_arguments, + reason = "internal helper grouping fields for type-transitioning builders" + )] + pub(super) fn rebuild_with_params( + self, + serializer: S2, + codec: F2, + protocol: Option>>, + fragmentation: Option, + message_assembler: Option>, + ) -> WireframeApp + where + S2: Serializer + Send + Sync, + F2: FrameCodec, + { + WireframeApp { + handlers: self.handlers, + routes: OnceCell::new(), + middleware: self.middleware, + serializer, + app_data: self.app_data, + on_connect: self.on_connect, + on_disconnect: self.on_disconnect, + protocol, + push_dlq: self.push_dlq, + codec, + read_timeout_ms: self.read_timeout_ms, + fragmentation, + message_assembler, + } + } +} diff --git a/src/app/builder_lifecycle.rs b/src/app/builder/lifecycle.rs similarity index 65% rename from src/app/builder_lifecycle.rs rename to src/app/builder/lifecycle.rs index 33c2ca58..e460732d 100644 --- a/src/app/builder_lifecycle.rs +++ b/src/app/builder/lifecycle.rs @@ -1,9 +1,13 @@ -//! Connection lifecycle builder methods for [`WireframeApp`]. +//! Connection lifecycle hook configuration for `WireframeApp`. use std::{future::Future, sync::Arc}; -use super::{builder::WireframeApp, envelope::Packet, error::Result}; -use crate::{codec::FrameCodec, serializer::Serializer}; +use super::WireframeApp; +use crate::{ + app::{Packet, error::Result}, + codec::FrameCodec, + serializer::Serializer, +}; impl WireframeApp where @@ -21,8 +25,8 @@ where /// # Type Parameters /// /// This method changes the connection state type parameter from `C` to `C2`. - /// This means that any subsequent builder methods will operate on the new connection state - /// type `C2`. Be aware of this type transition when chaining builder methods. + /// This means that any subsequent builder methods will operate on the new connection state type + /// `C2`. Be aware of this type transition when chaining builder methods. /// /// # Errors /// @@ -37,7 +41,21 @@ where Fut: Future + Send + 'static, C2: Send + 'static, { - Ok(self.rebuild_with_connection_type(Some(Arc::new(move || Box::pin(f()))), None)) + Ok(WireframeApp { + handlers: self.handlers, + routes: tokio::sync::OnceCell::new(), + middleware: self.middleware, + serializer: self.serializer, + app_data: self.app_data, + on_connect: Some(Arc::new(move || Box::pin(f()))), + on_disconnect: None, + protocol: self.protocol, + push_dlq: self.push_dlq, + codec: self.codec, + read_timeout_ms: self.read_timeout_ms, + fragmentation: self.fragmentation, + message_assembler: self.message_assembler, + }) } /// Register a callback invoked when a connection is closed. diff --git a/src/app/builder/mod.rs b/src/app/builder/mod.rs new file mode 100644 index 00000000..afa80911 --- /dev/null +++ b/src/app/builder/mod.rs @@ -0,0 +1,15 @@ +//! Application builder configuring routes and middleware. +//! [`WireframeApp`] is an Actix-inspired builder for managing connection +//! state, routing, and middleware in a `WireframeServer`. It exposes +//! convenience methods to register handlers and lifecycle hooks, and +//! serializes messages using a configurable serializer. + +mod codec; +mod config; +mod core; +mod lifecycle; +mod protocol; +mod routing; +mod state; + +pub use core::WireframeApp; diff --git a/src/app/builder_protocol.rs b/src/app/builder/protocol.rs similarity index 97% rename from src/app/builder_protocol.rs rename to src/app/builder/protocol.rs index 19132f20..c303569f 100644 --- a/src/app/builder_protocol.rs +++ b/src/app/builder/protocol.rs @@ -1,9 +1,10 @@ -//! Protocol configuration builder methods for [`WireframeApp`]. +//! Protocol and message assembly configuration for `WireframeApp`. use std::sync::Arc; -use super::{builder::WireframeApp, envelope::Packet}; +use super::WireframeApp; use crate::{ + app::Packet, codec::FrameCodec, hooks::{ProtocolHooks, WireframeProtocol}, message_assembler::MessageAssembler, @@ -75,6 +76,27 @@ where } } + /// Get a clone of the configured protocol, if any. + /// + /// Returns `None` if no protocol was installed via [`with_protocol`](Self::with_protocol). + #[must_use] + pub fn protocol( + &self, + ) -> Option>> { + self.protocol.clone() + } + + /// Return protocol hooks derived from the installed protocol. + /// + /// If no protocol is installed, returns default (no-op) hooks. + #[must_use] + pub fn protocol_hooks(&self) -> ProtocolHooks { + self.protocol + .as_ref() + .map(ProtocolHooks::from_protocol) + .unwrap_or_default() + } + /// Get the configured message assembler, if any. /// /// # Examples @@ -105,25 +127,4 @@ where pub fn message_assembler(&self) -> Option<&Arc> { self.message_assembler.as_ref() } - - /// Get a clone of the configured protocol, if any. - /// - /// Returns `None` if no protocol was installed via [`with_protocol`](Self::with_protocol). - #[must_use] - pub fn protocol( - &self, - ) -> Option>> { - self.protocol.clone() - } - - /// Return protocol hooks derived from the installed protocol. - /// - /// If no protocol is installed, returns default (no-op) hooks. - #[must_use] - pub fn protocol_hooks(&self) -> ProtocolHooks { - self.protocol - .as_ref() - .map(ProtocolHooks::from_protocol) - .unwrap_or_default() - } } diff --git a/src/app/builder/routing.rs b/src/app/builder/routing.rs new file mode 100644 index 00000000..945b58b1 --- /dev/null +++ b/src/app/builder/routing.rs @@ -0,0 +1,49 @@ +//! Routing and middleware configuration for `WireframeApp`. + +use super::WireframeApp; +use crate::{ + app::{ + Packet, + error::{Result, WireframeError}, + middleware_types::{Handler, Middleware}, + }, + codec::FrameCodec, + serializer::Serializer, +}; + +impl WireframeApp +where + S: Serializer + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec, +{ + /// Register a route that maps `id` to `handler`. + /// + /// # Errors + /// + /// Returns [`WireframeError::DuplicateRoute`] if a handler for `id` + /// has already been registered. + pub fn route(mut self, id: u32, handler: Handler) -> Result { + if self.handlers.contains_key(&id) { + return Err(WireframeError::DuplicateRoute(id)); + } + self.handlers.insert(id, handler); + self.routes = tokio::sync::OnceCell::new(); + Ok(self) + } + + /// Add a middleware component to the processing pipeline. + /// + /// # Errors + /// + /// This function currently always succeeds. + pub fn wrap(mut self, mw: M) -> Result + where + M: Middleware + 'static, + { + self.middleware.push(Box::new(mw)); + self.routes = tokio::sync::OnceCell::new(); + Ok(self) + } +} diff --git a/src/app/builder/state.rs b/src/app/builder/state.rs new file mode 100644 index 00000000..904b31f8 --- /dev/null +++ b/src/app/builder/state.rs @@ -0,0 +1,30 @@ +//! Shared state configuration for `WireframeApp`. + +use std::{any::TypeId, sync::Arc}; + +use super::WireframeApp; +use crate::{app::Packet, codec::FrameCodec, serializer::Serializer}; + +impl WireframeApp +where + S: Serializer + Send + Sync, + C: Send + 'static, + E: Packet, + F: FrameCodec, +{ + /// Store a shared state value accessible to request extractors. + /// + /// The value can later be retrieved using [`crate::extractor::SharedState`]. Registering + /// another value of the same type overwrites the previous one. + #[must_use] + pub fn app_data(mut self, state: T) -> Self + where + T: Send + Sync + 'static, + { + self.app_data.insert( + TypeId::of::(), + Arc::new(state) as Arc, + ); + self + } +} diff --git a/src/app/frame_handling.rs b/src/app/frame_handling.rs deleted file mode 100644 index 72ae1248..00000000 --- a/src/app/frame_handling.rs +++ /dev/null @@ -1,350 +0,0 @@ -//! Shared helpers for frame decoding, reassembly, and response forwarding. -//! -//! Extracted from `connection.rs` to keep modules small and focused. - -use std::io; - -use bytes::Bytes; -use futures::SinkExt; -use log::warn; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::codec::Framed; - -use super::{ - Envelope, - Packet, - PacketParts, - combined_codec::ConnectionCodec, - fragmentation_state::{FragmentProcessError, FragmentationState}, -}; -use crate::{ - codec::FrameCodec, - middleware::{HandlerService, Service, ServiceRequest}, - serializer::Serializer, -}; - -/// Tracks consecutive deserialization failures and enforces a per-connection limit. -/// -/// The counter increments on each failure; reaching `limit` terminates processing. -struct DeserFailureTracker<'a> { - count: &'a mut u32, - limit: u32, -} - -impl<'a> DeserFailureTracker<'a> { - fn new(count: &'a mut u32, limit: u32) -> Self { Self { count, limit } } - - fn record( - &mut self, - correlation_id: Option, - context: &str, - err: impl std::fmt::Debug, - ) -> io::Result<()> { - *self.count = self.count.saturating_add(1); - warn!("{context}: correlation_id={correlation_id:?}, error={err:?}"); - crate::metrics::inc_deser_errors(); - if *self.count >= self.limit { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "too many deserialization failures", - )); - } - Ok(()) - } -} - -/// Context for writing handler responses to the framed stream. -/// -/// Carries the serializer, codec, and mutable framing state for a connection. -pub(crate) struct ResponseContext<'a, S, W, F> -where - S: Serializer + Send + Sync, - W: AsyncRead + AsyncWrite + Unpin, - F: FrameCodec, -{ - pub(crate) serializer: &'a S, - pub(crate) framed: &'a mut Framed>, - pub(crate) fragmentation: &'a mut Option, - pub(crate) codec: &'a F, -} - -/// Attempt to reassemble a potentially fragmented envelope. -pub(crate) fn reassemble_if_needed( - fragmentation: &mut Option, - deser_failures: &mut u32, - env: Envelope, - max_deser_failures: u32, -) -> io::Result> { - let mut failures = DeserFailureTracker::new(deser_failures, max_deser_failures); - - if let Some(state) = fragmentation.as_mut() { - let correlation_id = env.correlation_id; - match state.reassemble(env) { - Ok(Some(env)) => Ok(Some(env)), - Ok(None) => Ok(None), - Err(FragmentProcessError::Decode(err)) => { - failures.record(correlation_id, "failed to decode fragment header", err)?; - Ok(None) - } - Err(FragmentProcessError::Reassembly(err)) => { - failures.record(correlation_id, "fragment reassembly failed", err)?; - Ok(None) - } - } - } else { - Ok(Some(env)) - } -} - -/// Forward a handler response, fragmenting if required, and write to the framed stream. -pub(crate) async fn forward_response( - env: Envelope, - service: &HandlerService, - ctx: ResponseContext<'_, S, W, F>, -) -> io::Result<()> -where - S: Serializer + Send + Sync, - E: Packet, - W: AsyncRead + AsyncWrite + Unpin, - F: FrameCodec, -{ - let request = ServiceRequest::new(env.payload, env.correlation_id); - let resp = match service.call(request).await { - Ok(resp) => resp, - Err(e) => { - warn!( - "handler error: id={id}, correlation_id={correlation_id:?}, error={error:?}", - id = env.id, - correlation_id = env.correlation_id, - error = e - ); - crate::metrics::inc_handler_errors(); - return Ok(()); - } - }; - - let parts = PacketParts::new(env.id, resp.correlation_id(), resp.into_inner()) - .inherit_correlation(env.correlation_id); - let correlation_id = parts.correlation_id(); - let Ok(responses) = fragment_responses(ctx.fragmentation, parts, env.id, correlation_id) else { - return Ok(()); // already logged - }; - - for response in responses { - let Ok(bytes) = serialize_response(ctx.serializer, &response, env.id, correlation_id) - else { - break; // already logged - }; - - let send_result = - send_response_payload::(ctx.codec, ctx.framed, Bytes::from(bytes), &response) - .await; - match send_result { - Ok(()) => {} - Err(err) if should_drop_response_send_error(&err) => break, // already logged - Err(err) => return Err(err), - } - } - - Ok(()) -} - -fn should_drop_response_send_error(error: &io::Error) -> bool { - matches!( - error.kind(), - io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData - ) -} - -fn fragment_responses( - fragmentation: &mut Option, - parts: PacketParts, - id: u32, - correlation_id: Option, -) -> io::Result> { - let envelope = Envelope::from_parts(parts); - match fragmentation.as_mut() { - Some(state) => match state.fragment(envelope) { - Ok(fragmented) => Ok(fragmented), - Err(err) => { - warn!( - concat!( - "failed to fragment response: id={id}, correlation_id={correlation_id:?}, ", - "error={err:?}" - ), - id = id, - correlation_id = correlation_id, - err = err - ); - crate::metrics::inc_handler_errors(); - Err(io::Error::other("fragmentation failed")) - } - }, - None => Ok(vec![envelope]), - } -} - -fn serialize_response( - serializer: &S, - response: &Envelope, - id: u32, - correlation_id: Option, -) -> io::Result> { - match serializer.serialize(response) { - Ok(bytes) => Ok(bytes), - Err(e) => { - warn!( - concat!( - "failed to serialize response: id={id}, correlation_id={correlation_id:?}, ", - "error={e:?}" - ), - id = id, - correlation_id = correlation_id, - e = e - ); - crate::metrics::inc_handler_errors(); - Err(io::Error::other("serialization failed")) - } - } -} - -/// Send a response payload over the framed stream using codec-aware wrapping. -/// -/// Wraps the raw payload bytes in the codec's native frame format via -/// [`FrameCodec::wrap_payload`] before writing to the underlying stream. -/// This ensures responses are encoded correctly for the configured protocol. -async fn send_response_payload( - codec: &F, - framed: &mut Framed>, - payload: Bytes, - response: &Envelope, -) -> io::Result<()> -where - W: AsyncRead + AsyncWrite + Unpin, - F: FrameCodec, -{ - let frame = codec.wrap_payload(payload); - if let Err(e) = framed.send(frame).await { - let id = response.id; - let correlation_id = response.correlation_id; - warn!("failed to send response: id={id}, correlation_id={correlation_id:?}, error={e:?}"); - crate::metrics::inc_handler_errors(); - return Err(e); - } - Ok(()) -} - -#[cfg(all(test, not(loom)))] -mod tests { - //! Tests for frame handling helpers and response sending. - - use bytes::Bytes; - use futures::StreamExt; - use rstest::{fixture, rstest}; - use tokio::io::DuplexStream; - - use super::*; - use crate::{ - app::combined_codec::CombinedCodec, - test_helpers::{TestAdapter, TestCodec}, - }; - - struct FramedHarness { - codec: TestCodec, - server_framed: Framed>, - client_framed: Framed>, - } - - fn build_harness(max_frame_length: usize) -> FramedHarness { - let codec = TestCodec::new(max_frame_length); - let client_codec = TestCodec::new(max_frame_length); - let (client, server) = tokio::io::duplex(256); - let server_codec = CombinedCodec::new(codec.decoder(), codec.encoder()); - let client_codec = CombinedCodec::new(client_codec.decoder(), client_codec.encoder()); - let server_framed = Framed::new(server, server_codec); - let client_framed = Framed::new(client, client_codec); - - FramedHarness { - codec, - server_framed, - client_framed, - } - } - - #[fixture] - fn harness() -> FramedHarness { - // Keep fixture setup explicit to avoid duplicated per-test harness creation. - build_harness(64) - } - - #[rstest] - #[case::ok(vec![1, 2, 3, 4], false)] - #[case::oversized(vec![0u8; 100], true)] - #[tokio::test] - async fn send_response_payload_behaviour( - #[case] payload: Vec, - #[case] expect_error: bool, - mut harness: FramedHarness, - ) { - let response = Envelope::new(1, Some(99), payload.clone()); - let result = send_response_payload::( - &harness.codec, - &mut harness.server_framed, - Bytes::from(payload.clone()), - &response, - ) - .await; - - if expect_error { - assert!( - result.is_err(), - "expected send to fail for oversized payload" - ); - assert_eq!( - result - .expect_err("oversized payload should return an error") - .kind(), - io::ErrorKind::InvalidInput - ); - return; - } - - result.expect("send should succeed"); - let frame = harness - .client_framed - .next() - .await - .expect("expected a frame") - .expect("decode should succeed"); - - assert_eq!(frame.tag, 0x42, "wrap_payload should set tag to 0x42"); - assert_eq!(frame.payload, payload, "payload should match"); - assert_eq!( - harness.codec.wraps(), - 1, - "wrap_payload should advance codec state" - ); - } - - /// Verify `ResponseContext` fields are accessible and usable. - #[rstest] - #[tokio::test] - async fn response_context_holds_references(mut harness: FramedHarness) { - use crate::serializer::BincodeSerializer; - - let serializer = BincodeSerializer; - let mut fragmentation: Option = None; - - let ctx: ResponseContext<'_, BincodeSerializer, _, TestCodec> = ResponseContext { - serializer: &serializer, - framed: &mut harness.server_framed, - fragmentation: &mut fragmentation, - codec: &harness.codec, - }; - - // Verify fields are accessible (compile-time check with runtime assertion) - assert!(ctx.fragmentation.is_none()); - } - - // Covered by `send_response_payload_behaviour` cases. -} diff --git a/src/app/frame_handling/core.rs b/src/app/frame_handling/core.rs new file mode 100644 index 00000000..4b0be8bf --- /dev/null +++ b/src/app/frame_handling/core.rs @@ -0,0 +1,52 @@ +//! Core frame handling context types. + +use std::io; + +use log::warn; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; + +use crate::{ + app::{combined_codec::ConnectionCodec, fragmentation_state::FragmentationState}, + codec::FrameCodec, + serializer::Serializer, +}; + +pub(super) struct DeserFailureTracker<'a> { + count: &'a mut u32, + limit: u32, +} + +impl<'a> DeserFailureTracker<'a> { + pub(super) fn new(count: &'a mut u32, limit: u32) -> Self { Self { count, limit } } + + pub(super) fn record( + &mut self, + correlation_id: Option, + context: &str, + err: impl std::fmt::Debug, + ) -> io::Result<()> { + *self.count += 1; + warn!("{context}: correlation_id={correlation_id:?}, error={err:?}"); + crate::metrics::inc_deser_errors(); + if *self.count >= self.limit { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "too many deserialization failures", + )); + } + Ok(()) + } +} + +pub(crate) struct ResponseContext<'a, S, W, F> +where + S: Serializer + Send + Sync, + W: AsyncRead + AsyncWrite + Unpin, + F: FrameCodec, +{ + pub(crate) serializer: &'a S, + pub(crate) framed: &'a mut Framed>, + pub(crate) fragmentation: &'a mut Option, + pub(crate) codec: &'a F, +} diff --git a/src/app/frame_handling/mod.rs b/src/app/frame_handling/mod.rs new file mode 100644 index 00000000..9cce4b6f --- /dev/null +++ b/src/app/frame_handling/mod.rs @@ -0,0 +1,15 @@ +//! Shared helpers for frame decoding, reassembly, and response forwarding. +//! +//! Extracted from `connection.rs` to keep modules small and focused. + +mod core; +mod reassembly; +mod response; + +pub(crate) use core::ResponseContext; + +pub(crate) use reassembly::reassemble_if_needed; +pub(crate) use response::forward_response; + +#[cfg(all(test, not(loom)))] +mod tests; diff --git a/src/app/frame_handling/reassembly.rs b/src/app/frame_handling/reassembly.rs new file mode 100644 index 00000000..7d9abfb4 --- /dev/null +++ b/src/app/frame_handling/reassembly.rs @@ -0,0 +1,37 @@ +//! Helpers for fragment reassembly. + +use std::io; + +use super::core::DeserFailureTracker; +use crate::app::{ + Envelope, + fragmentation_state::{FragmentProcessError, FragmentationState}, +}; + +/// Attempt to reassemble a potentially fragmented envelope. +pub(crate) fn reassemble_if_needed( + fragmentation: &mut Option, + deser_failures: &mut u32, + env: Envelope, + max_deser_failures: u32, +) -> io::Result> { + let mut failures = DeserFailureTracker::new(deser_failures, max_deser_failures); + + if let Some(state) = fragmentation.as_mut() { + let correlation_id = env.correlation_id; + match state.reassemble(env) { + Ok(Some(env)) => Ok(Some(env)), + Ok(None) => Ok(None), + Err(FragmentProcessError::Decode(err)) => { + failures.record(correlation_id, "failed to decode fragment header", err)?; + Ok(None) + } + Err(FragmentProcessError::Reassembly(err)) => { + failures.record(correlation_id, "fragment reassembly failed", err)?; + Ok(None) + } + } + } else { + Ok(Some(env)) + } +} diff --git a/src/app/frame_handling/response.rs b/src/app/frame_handling/response.rs new file mode 100644 index 00000000..a8ef2c02 --- /dev/null +++ b/src/app/frame_handling/response.rs @@ -0,0 +1,139 @@ +//! Response forwarding helpers for frame handling. + +use std::io; + +use bytes::Bytes; +use futures::SinkExt; +use log::warn; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; + +use crate::{ + app::{ + Envelope, + Packet, + PacketParts, + combined_codec::ConnectionCodec, + fragmentation_state::FragmentationState, + }, + codec::FrameCodec, + middleware::{HandlerService, Service, ServiceRequest}, + serializer::Serializer, +}; + +/// Forward a handler response, fragmenting if required, and write to the framed stream. +pub(crate) async fn forward_response( + env: Envelope, + service: &HandlerService, + ctx: super::ResponseContext<'_, S, W, F>, +) -> io::Result<()> +where + S: Serializer + Send + Sync, + E: Packet, + W: AsyncRead + AsyncWrite + Unpin, + F: FrameCodec, +{ + let request = ServiceRequest::new(env.payload, env.correlation_id); + let resp = match service.call(request).await { + Ok(resp) => resp, + Err(e) => { + warn!( + "handler error: id={}, correlation_id={:?}, error={e:?}", + env.id, env.correlation_id + ); + crate::metrics::inc_handler_errors(); + return Ok(()); + } + }; + + let parts = PacketParts::new(env.id, resp.correlation_id(), resp.into_inner()) + .inherit_correlation(env.correlation_id); + let correlation_id = parts.correlation_id(); + let Ok(responses) = fragment_responses(ctx.fragmentation, parts, env.id, correlation_id) else { + return Ok(()); // already logged + }; + + for response in responses { + let Ok(bytes) = serialize_response(ctx.serializer, &response, env.id, correlation_id) + else { + break; // already logged + }; + + if send_response_payload::(ctx.codec, ctx.framed, Bytes::from(bytes), &response) + .await + .is_err() + { + break; + } + } + + Ok(()) +} + +fn fragment_responses( + fragmentation: &mut Option, + parts: PacketParts, + id: u32, + correlation_id: Option, +) -> io::Result> { + let envelope = Envelope::from_parts(parts); + match fragmentation.as_mut() { + Some(state) => match state.fragment(envelope) { + Ok(fragmented) => Ok(fragmented), + Err(err) => { + warn!( + "failed to fragment response: id={id}, correlation_id={correlation_id:?}, \ + error={err:?}" + ); + crate::metrics::inc_handler_errors(); + Err(io::Error::other("fragmentation failed")) + } + }, + None => Ok(vec![envelope]), + } +} + +fn serialize_response( + serializer: &S, + response: &Envelope, + id: u32, + correlation_id: Option, +) -> io::Result> { + match serializer.serialize(response) { + Ok(bytes) => Ok(bytes), + Err(e) => { + warn!( + "failed to serialize response: id={id}, correlation_id={correlation_id:?}, \ + error={e:?}" + ); + crate::metrics::inc_handler_errors(); + Err(io::Error::other("serialization failed")) + } + } +} + +/// Send a response payload over the framed stream using codec-aware wrapping. +/// +/// Wraps the raw payload bytes in the codec's native frame format via +/// [`FrameCodec::wrap_payload`] before writing to the underlying stream. +/// This ensures responses are encoded correctly for the configured protocol. +pub(super) async fn send_response_payload( + codec: &F, + framed: &mut Framed>, + payload: Bytes, + response: &Envelope, +) -> io::Result<()> +where + W: AsyncRead + AsyncWrite + Unpin, + F: FrameCodec, +{ + let frame = codec.wrap_payload(payload); + if let Err(e) = framed.send(frame).await { + let id = response.id; + let correlation_id = response.correlation_id; + warn!("failed to send response: id={id}, correlation_id={correlation_id:?}, error={e:?}"); + crate::metrics::inc_handler_errors(); + return Err(io::Error::other("send failed")); + } + Ok(()) +} diff --git a/src/app/frame_handling/tests.rs b/src/app/frame_handling/tests.rs new file mode 100644 index 00000000..abaebe1c --- /dev/null +++ b/src/app/frame_handling/tests.rs @@ -0,0 +1,209 @@ +//! Tests for frame handling helpers. + +use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, +}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures::StreamExt; +use tokio_util::codec::{Decoder, Encoder}; + +use super::{ResponseContext, response::send_response_payload}; +use crate::{ + app::{Envelope, combined_codec::CombinedCodec, fragmentation_state::FragmentationState}, + codec::FrameCodec, +}; + +/// Test codec that wraps payloads with a distinctive tag byte. +#[derive(Clone, Debug)] +struct TestFrame { + tag: u8, + payload: Vec, +} + +#[derive(Clone, Debug)] +struct TestCodec { + max_frame_length: usize, + counter: Arc, +} + +impl TestCodec { + fn new(max_frame_length: usize) -> Self { + Self { + max_frame_length, + counter: Arc::new(AtomicUsize::new(0)), + } + } + + fn wraps(&self) -> usize { self.counter.load(Ordering::SeqCst) } +} + +#[derive(Clone, Debug)] +struct TestAdapter { + max_frame_length: usize, +} + +impl TestAdapter { + fn new(max_frame_length: usize) -> Self { Self { max_frame_length } } +} + +impl Decoder for TestAdapter { + type Item = TestFrame; + type Error = std::io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + const HEADER_LEN: usize = 2; + if src.len() < HEADER_LEN { + return Ok(None); + } + + let mut header = src.as_ref(); + let tag = header.get_u8(); + let payload_len = header.get_u8() as usize; + if payload_len > self.max_frame_length { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "payload too large", + )); + } + if src.len() < HEADER_LEN + payload_len { + return Ok(None); + } + + let mut frame_bytes = src.split_to(HEADER_LEN + payload_len); + frame_bytes.advance(HEADER_LEN); + let payload = frame_bytes.to_vec(); + + Ok(Some(TestFrame { tag, payload })) + } +} + +impl Encoder for TestAdapter { + type Error = std::io::Error; + + fn encode(&mut self, item: TestFrame, dst: &mut BytesMut) -> Result<(), Self::Error> { + if item.payload.len() > self.max_frame_length { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "payload too large", + )); + } + + let payload_len = u8::try_from(item.payload.len()).map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidInput, "payload too long") + })?; + dst.reserve(2 + item.payload.len()); + dst.put_u8(item.tag); + dst.put_u8(payload_len); + dst.extend_from_slice(&item.payload); + Ok(()) + } +} + +impl FrameCodec for TestCodec { + type Frame = TestFrame; + type Decoder = TestAdapter; + type Encoder = TestAdapter; + + fn decoder(&self) -> Self::Decoder { TestAdapter::new(self.max_frame_length) } + + fn encoder(&self) -> Self::Encoder { TestAdapter::new(self.max_frame_length) } + + fn frame_payload(frame: &Self::Frame) -> &[u8] { frame.payload.as_slice() } + + /// Wraps payload with tag byte 0x42 to verify codec-aware wrapping. + fn wrap_payload(&self, payload: Bytes) -> Self::Frame { + self.counter.fetch_add(1, Ordering::SeqCst); + TestFrame { + tag: 0x42, + payload: payload.to_vec(), + } + } + + fn correlation_id(frame: &Self::Frame) -> Option { Some(u64::from(frame.tag)) } + + fn max_frame_length(&self) -> usize { self.max_frame_length } +} + +/// Verify `send_response_payload` uses `F::wrap_payload` to frame responses. +#[tokio::test] +async fn send_response_payload_wraps_with_codec() { + let codec = TestCodec::new(64); + let (client, server) = tokio::io::duplex(256); + let combined = CombinedCodec::new(codec.decoder(), codec.encoder()); + let mut framed = tokio_util::codec::Framed::new(server, combined); + + let payload = vec![1, 2, 3, 4]; + let response = Envelope::new(1, Some(99), payload.clone()); + send_response_payload::( + &codec, + &mut framed, + Bytes::from(payload.clone()), + &response, + ) + .await + .expect("send should succeed"); + + drop(framed); + + let combined_client = CombinedCodec::new(codec.decoder(), codec.encoder()); + let mut client_framed = tokio_util::codec::Framed::new(client, combined_client); + let frame = client_framed + .next() + .await + .expect("expected a frame") + .expect("decode should succeed"); + + assert_eq!(frame.tag, 0x42, "wrap_payload should set tag to 0x42"); + assert_eq!(frame.payload, payload, "payload should match"); + assert_eq!(codec.wraps(), 1, "wrap_payload should advance codec state"); +} + +/// Verify `ResponseContext` fields are accessible and usable. +#[tokio::test] +async fn response_context_holds_references() { + use crate::serializer::BincodeSerializer; + + let codec = TestCodec::new(64); + let (_client, server) = tokio::io::duplex(256); + let combined = CombinedCodec::new(codec.decoder(), codec.encoder()); + let mut framed = tokio_util::codec::Framed::new(server, combined); + let serializer = BincodeSerializer; + let mut fragmentation: Option = None; + + let ctx: ResponseContext<'_, BincodeSerializer, _, TestCodec> = ResponseContext { + serializer: &serializer, + framed: &mut framed, + fragmentation: &mut fragmentation, + codec: &codec, + }; + + // Verify fields are accessible (compile-time check with runtime assertion) + assert!(ctx.fragmentation.is_none()); +} + +/// Verify `send_response_payload` returns error on send failure. +#[tokio::test] +async fn send_response_payload_returns_error_on_failure() { + let codec = TestCodec::new(4); // Small limit to trigger failure + let (_client, server) = tokio::io::duplex(256); + let combined = CombinedCodec::new(codec.decoder(), codec.encoder()); + let mut framed = tokio_util::codec::Framed::new(server, combined); + + // Payload exceeds max_frame_length, so encode will fail + let oversized_payload = vec![0u8; 100]; + let response = Envelope::new(1, Some(99), oversized_payload.clone()); + let result = send_response_payload::( + &codec, + &mut framed, + Bytes::from(oversized_payload), + &response, + ) + .await; + + assert!( + result.is_err(), + "expected send to fail for oversized payload" + ); +} diff --git a/src/app/mod.rs b/src/app/mod.rs index dd1f4544..74fdcff0 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -10,8 +10,6 @@ mod builder; mod builder_defaults; -mod builder_lifecycle; -mod builder_protocol; mod combined_codec; mod connection; mod envelope; diff --git a/src/client/builder.rs b/src/client/builder.rs deleted file mode 100644 index ae5f9ed1..00000000 --- a/src/client/builder.rs +++ /dev/null @@ -1,425 +0,0 @@ -//! Builder for configuring and connecting a wireframe client. - -use std::{ - future::Future, - net::SocketAddr, - sync::{Arc, atomic::AtomicU64}, -}; - -use bincode::Encode; -use tokio::net::TcpSocket; -use tokio_util::codec::Framed; - -use super::{ - ClientCodecConfig, - ClientError, - SocketOptions, - WireframeClient, - hooks::LifecycleHooks, - preamble_exchange::{PreambleConfig, perform_preamble_exchange}, -}; -use crate::{ - frame::LengthFormat, - rewind_stream::RewindStream, - serializer::{BincodeSerializer, Serializer}, -}; - -const INITIAL_READ_BUFFER_CAPACITY_LIMIT: usize = 64 * 1024; - -/// Reconstructs `WireframeClientBuilder` with one field updated to a new value. -/// -/// This macro reduces duplication in type-changing builder methods that need to -/// create a new builder instance with different generic parameters. When a type -/// parameter changes, struct update syntax (`..self`) cannot be used, so fields -/// must be copied explicitly. -/// -/// The `lifecycle_hooks` field requires special handling because `LifecycleHooks` -/// is parameterized by the connection state type. When changing `S` or `P`, the -/// hooks can be moved directly since `C` is unchanged. When changing `C` via -/// `on_connection_setup`, a new `LifecycleHooks` must be constructed. -macro_rules! builder_field_update { - // Serializer change: preserves P and C, moves lifecycle_hooks unchanged - ($self:expr,serializer = $value:expr) => { - WireframeClientBuilder { - serializer: $value, - codec_config: $self.codec_config, - socket_options: $self.socket_options, - preamble_config: $self.preamble_config, - lifecycle_hooks: $self.lifecycle_hooks, - } - }; - // Preamble change: preserves S and C, moves lifecycle_hooks unchanged - ($self:expr,preamble_config = $value:expr) => { - WireframeClientBuilder { - serializer: $self.serializer, - codec_config: $self.codec_config, - socket_options: $self.socket_options, - preamble_config: $value, - lifecycle_hooks: $self.lifecycle_hooks, - } - }; - // Lifecycle hooks change: preserves S and P, changes C - ($self:expr,lifecycle_hooks = $value:expr) => { - WireframeClientBuilder { - serializer: $self.serializer, - codec_config: $self.codec_config, - socket_options: $self.socket_options, - preamble_config: $self.preamble_config, - lifecycle_hooks: $value, - } - }; -} - -/// Builder for [`WireframeClient`]. -/// -/// The builder supports three generic type parameters: -/// - `S`: The serializer type (default: `BincodeSerializer`) -/// - `P`: The preamble type (default: `()`) -/// - `C`: The connection state type returned by the setup hook (default: `()`) -/// -/// # Examples -/// -/// ``` -/// use wireframe::client::WireframeClientBuilder; -/// -/// let builder = WireframeClientBuilder::new(); -/// let _ = builder; -/// ``` -pub struct WireframeClientBuilder { - pub(crate) serializer: S, - pub(crate) codec_config: ClientCodecConfig, - pub(crate) socket_options: SocketOptions, - pub(crate) preamble_config: Option>, - pub(crate) lifecycle_hooks: LifecycleHooks, -} - -impl WireframeClientBuilder { - /// Create a new builder with default settings. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::WireframeClientBuilder; - /// - /// let builder = WireframeClientBuilder::new(); - /// let _ = builder; - /// ``` - #[must_use] - pub fn new() -> Self { - Self { - serializer: BincodeSerializer, - codec_config: ClientCodecConfig::default(), - socket_options: SocketOptions::default(), - preamble_config: None, - lifecycle_hooks: LifecycleHooks::default(), - } - } -} - -impl Default for WireframeClientBuilder { - fn default() -> Self { Self::new() } -} - -impl WireframeClientBuilder -where - S: Serializer + Send + Sync, -{ - /// Replace the serializer used for encoding and decoding messages. - /// - /// # Examples - /// - /// ``` - /// use wireframe::{BincodeSerializer, client::WireframeClientBuilder}; - /// - /// let builder = WireframeClientBuilder::new().serializer(BincodeSerializer); - /// let _ = builder; - /// ``` - #[must_use] - pub fn serializer(self, serializer: Ser) -> WireframeClientBuilder - where - Ser: Serializer + Send + Sync, - { - builder_field_update!(self, serializer = serializer) - } - - /// Configure codec settings for the connection. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::{ClientCodecConfig, WireframeClientBuilder}; - /// - /// let codec = ClientCodecConfig::default().max_frame_length(2048); - /// let builder = WireframeClientBuilder::new().codec_config(codec); - /// let _ = builder; - /// ``` - #[must_use] - pub fn codec_config(mut self, codec_config: ClientCodecConfig) -> Self { - self.codec_config = codec_config; - self - } - - /// Configure the maximum frame length for the connection. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::WireframeClientBuilder; - /// - /// let builder = WireframeClientBuilder::new().max_frame_length(2048); - /// let _ = builder; - /// ``` - #[must_use] - pub fn max_frame_length(mut self, max_frame_length: usize) -> Self { - self.codec_config = self.codec_config.max_frame_length(max_frame_length); - self - } - - /// Configure the length prefix format for the connection. - /// - /// # Examples - /// - /// ``` - /// use wireframe::{client::WireframeClientBuilder, frame::LengthFormat}; - /// - /// let builder = WireframeClientBuilder::new().length_format(LengthFormat::u16_be()); - /// let _ = builder; - /// ``` - #[must_use] - pub fn length_format(mut self, length_format: LengthFormat) -> Self { - self.codec_config = self.codec_config.length_format(length_format); - self - } - - // Socket option convenience methods (nodelay, keepalive, linger, etc.) are - // in `socket_option_methods.rs` to keep this file under 400 lines. - - /// Configure a preamble to send before exchanging frames. - /// - /// The preamble is written to the server immediately after establishing - /// the TCP connection, before the framing layer begins. Use - /// [`on_preamble_success`](Self::on_preamble_success) to read the server's - /// response and [`preamble_timeout`](Self::preamble_timeout) to bound the - /// exchange. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::WireframeClientBuilder; - /// - /// #[derive(bincode::Encode)] - /// struct MyPreamble { - /// version: u16, - /// } - /// - /// let builder = WireframeClientBuilder::new().with_preamble(MyPreamble { version: 1 }); - /// let _ = builder; - /// ``` - #[must_use] - pub fn with_preamble(self, preamble: Q) -> WireframeClientBuilder - where - Q: Encode + Send + Sync + 'static, - { - builder_field_update!(self, preamble_config = Some(PreambleConfig::new(preamble))) - } - - /// Register a callback invoked when the connection is established. - /// - /// The callback can perform authentication or other setup tasks and returns - /// connection-specific state stored for the connection's lifetime. This - /// hook is invoked after the preamble exchange (if configured) succeeds. - /// - /// # Type Parameters - /// - /// This method changes the connection state type parameter from `C` to - /// `C2`. Subsequent builder methods will operate on the new connection - /// state type. Note that any previously configured `on_connection_teardown` - /// hook is cleared because its type signature depends on the old state - /// type. The `on_error` hook is preserved since it does not depend on - /// the connection state type. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::WireframeClientBuilder; - /// - /// struct Session { - /// id: u64, - /// } - /// - /// let builder = - /// WireframeClientBuilder::new().on_connection_setup(|| async { Session { id: 42 } }); - /// let _ = builder; - /// ``` - #[must_use] - pub fn on_connection_setup(self, f: F) -> WireframeClientBuilder - where - F: Fn() -> Fut + Send + Sync + 'static, - Fut: Future + Send + 'static, - C2: Send + 'static, - { - // Preserve on_error since it is not parameterized by C. - // on_disconnect must be cleared because its signature depends on C. - let on_error = self.lifecycle_hooks.on_error; - builder_field_update!( - self, - lifecycle_hooks = LifecycleHooks { - on_connect: Some(Arc::new(move || Box::pin(f()))), - on_disconnect: None, - on_error, - } - ) - } -} - -impl WireframeClientBuilder -where - S: Serializer + Send + Sync, - C: Send + 'static, -{ - /// Register a callback invoked when the connection is closed. - /// - /// The callback receives the connection state produced by - /// [`on_connection_setup`](Self::on_connection_setup). The teardown hook - /// is invoked when [`WireframeClient::close`](super::WireframeClient::close) - /// is called. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::WireframeClientBuilder; - /// - /// struct Session { - /// id: u64, - /// } - /// - /// let builder = WireframeClientBuilder::new() - /// .on_connection_setup(|| async { Session { id: 42 } }) - /// .on_connection_teardown(|session| async move { - /// println!("Session {} closed", session.id); - /// }); - /// let _ = builder; - /// ``` - #[must_use] - pub fn on_connection_teardown(mut self, f: F) -> Self - where - F: Fn(C) -> Fut + Send + Sync + 'static, - Fut: Future + Send + 'static, - { - self.lifecycle_hooks.on_disconnect = Some(Arc::new(move |c| Box::pin(f(c)))); - self - } - - /// Register a callback invoked when an error occurs. - /// - /// The callback receives a reference to the error and can perform logging - /// or recovery actions. The handler is invoked before the error is returned - /// to the caller. - /// - /// # Examples - /// - /// ``` - /// use wireframe::client::WireframeClientBuilder; - /// - /// let builder = WireframeClientBuilder::new().on_error(|err| { - /// let message = err.to_string(); - /// async move { - /// eprintln!("Client error: {message}"); - /// } - /// }); - /// let _ = builder; - /// ``` - #[must_use] - pub fn on_error(mut self, f: F) -> Self - where - F: for<'a> Fn(&'a ClientError) -> Fut + Send + Sync + 'static, - Fut: Future + Send + 'static, - { - self.lifecycle_hooks.on_error = Some(Arc::new(move |e| Box::pin(f(e)))); - self - } -} - -impl WireframeClientBuilder -where - S: Serializer + Send + Sync, - P: Encode + Send + Sync + 'static, - C: Send + 'static, -{ - /// Establish a connection and return a configured client. - /// - /// If a preamble is configured, it is written to the server before the - /// framing layer is established. The success callback (if registered) is - /// invoked after writing the preamble and may read the server's response. - /// If a connection setup hook is registered, it is invoked after the - /// preamble exchange succeeds. - /// - /// # Errors - /// - /// Returns [`ClientError`] if socket configuration, connection, or - /// preamble exchange fails. - /// - /// # Examples - /// - /// ```no_run - /// use std::net::SocketAddr; - /// - /// use wireframe::{ClientError, WireframeClient}; - /// - /// # #[tokio::main] - /// # async fn main() -> Result<(), ClientError> { - /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address"); - /// let _client = WireframeClient::builder().connect(addr).await?; - /// # Ok(()) - /// # } - /// ``` - pub async fn connect( - self, - addr: SocketAddr, - ) -> Result, C>, ClientError> { - let socket = if addr.is_ipv4() { - TcpSocket::new_v4()? - } else { - TcpSocket::new_v6()? - }; - self.socket_options.apply(&socket)?; - let mut stream = socket.connect(addr).await?; - - // Perform preamble exchange if configured. - let leftover = if let Some(config) = self.preamble_config { - perform_preamble_exchange(&mut stream, config).await? - } else { - Vec::new() - }; - - // Build framed codec, always wrapping in RewindStream for type consistency. - // When leftover is empty, RewindStream has negligible overhead. - let codec_config = self.codec_config; - let codec = codec_config.build_codec(); - let mut framed = Framed::new(RewindStream::new(leftover, stream), codec); - let initial_read_buffer_capacity = core::cmp::min( - INITIAL_READ_BUFFER_CAPACITY_LIMIT, - codec_config.max_frame_length_value(), - ); - framed - .read_buffer_mut() - .reserve(initial_read_buffer_capacity); - - // Invoke connection setup hook if configured. - let connection_state = if let Some(ref setup) = self.lifecycle_hooks.on_connect { - Some(setup().await) - } else { - None - }; - - Ok(WireframeClient { - framed, - serializer: self.serializer, - codec_config, - connection_state, - on_disconnect: self.lifecycle_hooks.on_disconnect, - on_error: self.lifecycle_hooks.on_error, - correlation_counter: AtomicU64::new(1), - }) - } -} diff --git a/src/client/builder/codec.rs b/src/client/builder/codec.rs new file mode 100644 index 00000000..d28523a8 --- /dev/null +++ b/src/client/builder/codec.rs @@ -0,0 +1,58 @@ +//! Codec configuration methods for `WireframeClientBuilder`. + +use super::WireframeClientBuilder; +use crate::{client::ClientCodecConfig, frame::LengthFormat, serializer::Serializer}; + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, +{ + /// Configure codec settings for the connection. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::{ClientCodecConfig, WireframeClientBuilder}; + /// + /// let codec = ClientCodecConfig::default().max_frame_length(2048); + /// let builder = WireframeClientBuilder::new().codec_config(codec); + /// let _ = builder; + /// ``` + #[must_use] + pub fn codec_config(mut self, codec_config: ClientCodecConfig) -> Self { + self.codec_config = codec_config; + self + } + + /// Configure the maximum frame length for the connection. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new().max_frame_length(2048); + /// let _ = builder; + /// ``` + #[must_use] + pub fn max_frame_length(mut self, max_frame_length: usize) -> Self { + self.codec_config = self.codec_config.max_frame_length(max_frame_length); + self + } + + /// Configure the length prefix format for the connection. + /// + /// # Examples + /// + /// ``` + /// use wireframe::{client::WireframeClientBuilder, frame::LengthFormat}; + /// + /// let builder = WireframeClientBuilder::new().length_format(LengthFormat::u16_be()); + /// let _ = builder; + /// ``` + #[must_use] + pub fn length_format(mut self, length_format: LengthFormat) -> Self { + self.codec_config = self.codec_config.length_format(length_format); + self + } +} diff --git a/src/client/builder/connect.rs b/src/client/builder/connect.rs new file mode 100644 index 00000000..ffb12c35 --- /dev/null +++ b/src/client/builder/connect.rs @@ -0,0 +1,96 @@ +//! Connection establishment for `WireframeClientBuilder`. + +use std::{net::SocketAddr, sync::atomic::AtomicU64}; + +use bincode::Encode; +use tokio::net::TcpSocket; +use tokio_util::codec::Framed; + +use super::WireframeClientBuilder; +use crate::{ + client::{ClientError, WireframeClient, preamble_exchange::perform_preamble_exchange}, + rewind_stream::RewindStream, + serializer::Serializer, +}; + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, + P: Encode + Send + Sync + 'static, + C: Send + 'static, +{ + /// Establish a connection and return a configured client. + /// + /// If a preamble is configured, it is written to the server before the + /// framing layer is established. The success callback (if registered) is + /// invoked after writing the preamble and may read the server's response. + /// If a connection setup hook is registered, it is invoked after the + /// preamble exchange succeeds. + /// + /// # Errors + /// + /// Returns [`ClientError`] if socket configuration, connection, or + /// preamble exchange fails. + /// + /// # Examples + /// + /// ```no_run + /// use std::net::SocketAddr; + /// + /// use wireframe::{ClientError, WireframeClient}; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<(), ClientError> { + /// let addr: SocketAddr = "127.0.0.1:9000".parse().expect("valid socket address"); + /// let _client = WireframeClient::builder().connect(addr).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn connect( + self, + addr: SocketAddr, + ) -> Result, C>, ClientError> { + let socket = if addr.is_ipv4() { + TcpSocket::new_v4()? + } else { + TcpSocket::new_v6()? + }; + self.socket_options.apply(&socket)?; + let mut stream = socket.connect(addr).await?; + + // Perform preamble exchange if configured. + let leftover = if let Some(config) = self.preamble_config { + perform_preamble_exchange(&mut stream, config).await? + } else { + Vec::new() + }; + + // Build framed codec, always wrapping in RewindStream for type consistency. + // When leftover is empty, RewindStream has negligible overhead. + let codec_config = self.codec_config; + let codec = codec_config.build_codec(); + let mut framed = Framed::new(RewindStream::new(leftover, stream), codec); + let initial_read_buffer_capacity = + core::cmp::min(64 * 1024, codec_config.max_frame_length_value()); + framed + .read_buffer_mut() + .reserve(initial_read_buffer_capacity); + + // Invoke connection setup hook if configured. + let connection_state = if let Some(ref setup) = self.lifecycle_hooks.on_connect { + Some(setup().await) + } else { + None + }; + + Ok(WireframeClient { + framed, + serializer: self.serializer, + codec_config, + connection_state, + on_disconnect: self.lifecycle_hooks.on_disconnect, + on_error: self.lifecycle_hooks.on_error, + correlation_counter: AtomicU64::new(1), + }) + } +} diff --git a/src/client/builder/core.rs b/src/client/builder/core.rs new file mode 100644 index 00000000..13727aa8 --- /dev/null +++ b/src/client/builder/core.rs @@ -0,0 +1,61 @@ +//! Core wireframe client builder type. + +use crate::{ + client::{ + ClientCodecConfig, + SocketOptions, + hooks::LifecycleHooks, + preamble_exchange::PreambleConfig, + }, + serializer::BincodeSerializer, +}; + +/// Builder for [`WireframeClient`](crate::client::WireframeClient). +/// +/// The builder supports three generic type parameters: +/// - `S`: The serializer type (default: `BincodeSerializer`) +/// - `P`: The preamble type (default: `()`) +/// - `C`: The connection state type returned by the setup hook (default: `()`) +/// +/// # Examples +/// +/// ``` +/// use wireframe::client::WireframeClientBuilder; +/// +/// let builder = WireframeClientBuilder::new(); +/// let _ = builder; +/// ``` +pub struct WireframeClientBuilder { + pub(crate) serializer: S, + pub(crate) codec_config: ClientCodecConfig, + pub(crate) socket_options: SocketOptions, + pub(crate) preamble_config: Option>, + pub(crate) lifecycle_hooks: LifecycleHooks, +} + +impl WireframeClientBuilder { + /// Create a new builder with default settings. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new(); + /// let _ = builder; + /// ``` + #[must_use] + pub fn new() -> Self { + Self { + serializer: BincodeSerializer, + codec_config: ClientCodecConfig::default(), + socket_options: SocketOptions::default(), + preamble_config: None, + lifecycle_hooks: LifecycleHooks::default(), + } + } +} + +impl Default for WireframeClientBuilder { + fn default() -> Self { Self::new() } +} diff --git a/src/client/builder/lifecycle.rs b/src/client/builder/lifecycle.rs new file mode 100644 index 00000000..a9269f2f --- /dev/null +++ b/src/client/builder/lifecycle.rs @@ -0,0 +1,127 @@ +//! Lifecycle hook methods for `WireframeClientBuilder`. + +use std::{future::Future, sync::Arc}; + +use super::WireframeClientBuilder; +use crate::{ + client::{ClientError, hooks::LifecycleHooks}, + serializer::Serializer, +}; + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, +{ + /// Register a callback invoked when the connection is established. + /// + /// The callback can perform authentication or other setup tasks and returns + /// connection-specific state stored for the connection's lifetime. This + /// hook is invoked after the preamble exchange (if configured) succeeds. + /// + /// # Type Parameters + /// + /// This method changes the connection state type parameter from `C` to + /// `C2`. Subsequent builder methods will operate on the new connection + /// state type. Note that any previously configured `on_connection_teardown` + /// hook is cleared because its type signature depends on the old state + /// type. The `on_error` hook is preserved since it does not depend on + /// the connection state type. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// struct Session { + /// id: u64, + /// } + /// + /// let builder = + /// WireframeClientBuilder::new().on_connection_setup(|| async { Session { id: 42 } }); + /// let _ = builder; + /// ``` + #[must_use] + pub fn on_connection_setup(self, f: F) -> WireframeClientBuilder + where + F: Fn() -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + C2: Send + 'static, + { + // Preserve on_error since it is not parameterized by C. + // on_disconnect must be cleared because its signature depends on C. + let on_error = self.lifecycle_hooks.on_error; + builder_field_update!( + self, + lifecycle_hooks = LifecycleHooks { + on_connect: Some(Arc::new(move || Box::pin(f()))), + on_disconnect: None, + on_error, + } + ) + } +} + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, + C: Send + 'static, +{ + /// Register a callback invoked when the connection is closed. + /// + /// The callback receives the connection state produced by + /// [`on_connection_setup`](Self::on_connection_setup). The teardown hook + /// is invoked when [`WireframeClient::close`](crate::client::WireframeClient::close) + /// is called. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// struct Session { + /// id: u64, + /// } + /// + /// let builder = WireframeClientBuilder::new() + /// .on_connection_setup(|| async { Session { id: 42 } }) + /// .on_connection_teardown(|session| async move { + /// println!("Session {} closed", session.id); + /// }); + /// let _ = builder; + /// ``` + #[must_use] + pub fn on_connection_teardown(mut self, f: F) -> Self + where + F: Fn(C) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.lifecycle_hooks.on_disconnect = Some(Arc::new(move |c| Box::pin(f(c)))); + self + } + + /// Register a callback invoked when an error occurs. + /// + /// The callback receives a reference to the error and can perform logging + /// or recovery actions. The handler is invoked before the error is returned + /// to the caller. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// let builder = WireframeClientBuilder::new().on_error(|err| async move { + /// eprintln!("Client error: {err}"); + /// }); + /// let _ = builder; + /// ``` + #[must_use] + pub fn on_error(mut self, f: F) -> Self + where + F: for<'a> Fn(&'a ClientError) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + { + self.lifecycle_hooks.on_error = Some(Arc::new(move |e| Box::pin(f(e)))); + self + } +} diff --git a/src/client/builder/mod.rs b/src/client/builder/mod.rs new file mode 100644 index 00000000..56ba8343 --- /dev/null +++ b/src/client/builder/mod.rs @@ -0,0 +1,54 @@ +//! Builder for configuring and connecting a wireframe client. + +/// Reconstructs `WireframeClientBuilder` with one field updated to a new value. +/// +/// This macro reduces duplication in type-changing builder methods that need to +/// create a new builder instance with different generic parameters. When a type +/// parameter changes, struct update syntax (`..self`) cannot be used, so fields +/// must be copied explicitly. +/// +/// The `lifecycle_hooks` field requires special handling because `LifecycleHooks` +/// is parameterized by the connection state type. When changing `S` or `P`, the +/// hooks can be moved directly since `C` is unchanged. When changing `C` via +/// `on_connection_setup`, a new `LifecycleHooks` must be constructed. +macro_rules! builder_field_update { + // Serializer change: preserves P and C, moves lifecycle_hooks unchanged + ($self:expr,serializer = $value:expr) => { + WireframeClientBuilder { + serializer: $value, + codec_config: $self.codec_config, + socket_options: $self.socket_options, + preamble_config: $self.preamble_config, + lifecycle_hooks: $self.lifecycle_hooks, + } + }; + // Preamble change: preserves S and C, moves lifecycle_hooks unchanged + ($self:expr,preamble_config = $value:expr) => { + WireframeClientBuilder { + serializer: $self.serializer, + codec_config: $self.codec_config, + socket_options: $self.socket_options, + preamble_config: $value, + lifecycle_hooks: $self.lifecycle_hooks, + } + }; + // Lifecycle hooks change: preserves S and P, changes C + ($self:expr,lifecycle_hooks = $value:expr) => { + WireframeClientBuilder { + serializer: $self.serializer, + codec_config: $self.codec_config, + socket_options: $self.socket_options, + preamble_config: $self.preamble_config, + lifecycle_hooks: $value, + } + }; +} + +mod codec; +mod connect; +mod core; +mod lifecycle; +mod preamble; +mod serializer; + +pub use core::WireframeClientBuilder; diff --git a/src/client/builder/preamble.rs b/src/client/builder/preamble.rs new file mode 100644 index 00000000..6fb57889 --- /dev/null +++ b/src/client/builder/preamble.rs @@ -0,0 +1,40 @@ +//! Preamble configuration methods for `WireframeClientBuilder`. + +use bincode::Encode; + +use super::WireframeClientBuilder; +use crate::{client::preamble_exchange::PreambleConfig, serializer::Serializer}; + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, +{ + /// Configure a preamble to send before exchanging frames. + /// + /// The preamble is written to the server immediately after establishing + /// the TCP connection, before the framing layer begins. Use + /// [`on_preamble_success`](Self::on_preamble_success) to read the server's + /// response and [`preamble_timeout`](Self::preamble_timeout) to bound the + /// exchange. + /// + /// # Examples + /// + /// ``` + /// use wireframe::client::WireframeClientBuilder; + /// + /// #[derive(bincode::Encode)] + /// struct MyPreamble { + /// version: u16, + /// } + /// + /// let builder = WireframeClientBuilder::new().with_preamble(MyPreamble { version: 1 }); + /// let _ = builder; + /// ``` + #[must_use] + pub fn with_preamble(self, preamble: Q) -> WireframeClientBuilder + where + Q: Encode + Send + Sync + 'static, + { + builder_field_update!(self, preamble_config = Some(PreambleConfig::new(preamble))) + } +} diff --git a/src/client/builder/serializer.rs b/src/client/builder/serializer.rs new file mode 100644 index 00000000..be646f27 --- /dev/null +++ b/src/client/builder/serializer.rs @@ -0,0 +1,27 @@ +//! Serializer configuration methods for `WireframeClientBuilder`. + +use super::WireframeClientBuilder; +use crate::serializer::Serializer; + +impl WireframeClientBuilder +where + S: Serializer + Send + Sync, +{ + /// Replace the serializer used for encoding and decoding messages. + /// + /// # Examples + /// + /// ``` + /// use wireframe::{BincodeSerializer, client::WireframeClientBuilder}; + /// + /// let builder = WireframeClientBuilder::new().serializer(BincodeSerializer); + /// let _ = builder; + /// ``` + #[must_use] + pub fn serializer(self, serializer: Ser) -> WireframeClientBuilder + where + Ser: Serializer + Send + Sync, + { + builder_field_update!(self, serializer = serializer) + } +} diff --git a/src/codec/recovery.rs b/src/codec/recovery.rs deleted file mode 100644 index 040398ee..00000000 --- a/src/codec/recovery.rs +++ /dev/null @@ -1,331 +0,0 @@ -//! Recovery policy types and hooks for codec error handling. -//! -//! This module provides infrastructure for customising how codec errors are -//! handled, including recovery policies, error context for structured logging, -//! and hooks for application-specific error handling. -//! -//! # Recovery Policies -//! -//! When a codec error occurs, the framework applies a recovery policy: -//! -//! - [`RecoveryPolicy::Drop`]: Discard the malformed frame and continue processing. Suitable for -//! recoverable errors like oversized frames. -//! - [`RecoveryPolicy::Quarantine`]: Pause the connection temporarily before retrying. Useful for -//! rate-limiting misbehaving clients. -//! - [`RecoveryPolicy::Disconnect`]: Terminate the connection immediately. Required for -//! unrecoverable errors like I/O failures. -//! -//! # Custom Recovery Hooks -//! -//! Applications can customise error handling by implementing -//! [`RecoveryPolicyHook`]: -//! -//! ``` -//! use std::time::Duration; -//! -//! use wireframe::codec::{CodecError, CodecErrorContext, RecoveryPolicy, RecoveryPolicyHook}; -//! -//! struct StrictRecovery; -//! -//! impl RecoveryPolicyHook for StrictRecovery { -//! fn recovery_policy(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { -//! // Disconnect on any error -//! RecoveryPolicy::Disconnect -//! } -//! } -//! ``` - -use std::{net::SocketAddr, time::Duration}; - -use super::error::CodecError; - -/// Recovery policies for codec errors. -/// -/// Each policy defines how the framework responds to a codec error. -/// -/// # Default Behaviour -/// -/// [`CodecError::default_recovery_policy`] returns the recommended policy for -/// each error type. Applications can override this via [`RecoveryPolicyHook`]. -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] -pub enum RecoveryPolicy { - /// Discard the malformed frame and continue processing. - /// - /// This is the default for recoverable errors like oversized frames or - /// protocol violations that affect only a single message. The connection - /// remains open and can process subsequent frames. - /// - /// # When to Use - /// - /// - Oversized frames that exceed `max_frame_length` - /// - Empty frames where non-empty is expected - /// - Protocol-level errors (unknown message type, sequence violation) - #[default] - Drop, - - /// Pause the connection temporarily before retrying. - /// - /// The connection enters a quarantine state for a configurable duration. - /// During quarantine, no frames are processed. After the timeout, normal - /// processing resumes. - /// - /// # When to Use - /// - /// - Rate-limiting misbehaving clients - /// - Temporary back-off after repeated errors - /// - Giving time for upstream issues to resolve - Quarantine, - - /// Terminate the connection immediately. - /// - /// The connection is closed without processing further frames. This is - /// required for unrecoverable errors where the framing state is corrupted - /// or the transport has failed. - /// - /// # When to Use - /// - /// - I/O errors (socket closed, write failed) - /// - Invalid frame length encoding (framing state corrupted) - /// - EOF conditions (connection ending) - Disconnect, -} - -impl RecoveryPolicy { - /// Returns the policy name as a static string for metrics and logging. - /// - /// # Examples - /// - /// ``` - /// use wireframe::codec::RecoveryPolicy; - /// - /// assert_eq!(RecoveryPolicy::Drop.as_str(), "drop"); - /// assert_eq!(RecoveryPolicy::Quarantine.as_str(), "quarantine"); - /// assert_eq!(RecoveryPolicy::Disconnect.as_str(), "disconnect"); - /// ``` - #[must_use] - pub const fn as_str(self) -> &'static str { - match self { - Self::Drop => "drop", - Self::Quarantine => "quarantine", - Self::Disconnect => "disconnect", - } - } -} - -/// Structured context for codec error logging and diagnostics. -/// -/// This struct captures connection-specific information to include in -/// structured logs and metrics when codec errors occur. -/// -/// # Examples -/// -/// ``` -/// use wireframe::codec::CodecErrorContext; -/// -/// let ctx = CodecErrorContext::new() -/// .with_connection_id(42) -/// .with_correlation_id(123); -/// -/// assert_eq!(ctx.connection_id, Some(42)); -/// assert_eq!(ctx.correlation_id, Some(123)); -/// ``` -#[derive(Clone, Debug, Default)] -pub struct CodecErrorContext { - /// Unique identifier for the connection. - pub connection_id: Option, - - /// Remote peer address. - pub peer_address: Option, - - /// Correlation identifier from the frame, if available. - /// - /// This helps attribute errors to specific requests when debugging. - pub correlation_id: Option, - - /// Codec instance state for debugging. - /// - /// May include sequence numbers, bytes processed, or other codec-specific - /// state information. - pub codec_state: Option, -} - -impl CodecErrorContext { - /// Create a new empty context. - #[must_use] - pub fn new() -> Self { Self::default() } - - /// Set the connection identifier. - #[must_use] - pub fn with_connection_id(mut self, id: u64) -> Self { - self.connection_id = Some(id); - self - } - - /// Set the peer address. - #[must_use] - pub fn with_peer_address(mut self, addr: SocketAddr) -> Self { - self.peer_address = Some(addr); - self - } - - /// Set the correlation identifier. - #[must_use] - pub fn with_correlation_id(mut self, id: u64) -> Self { - self.correlation_id = Some(id); - self - } - - /// Set codec-specific state information. - #[must_use] - pub fn with_codec_state(mut self, state: impl Into) -> Self { - self.codec_state = Some(state.into()); - self - } -} - -/// Hook trait for customising codec error recovery behaviour. -/// -/// Implementations can override default recovery policies based on -/// application-specific requirements or connection state. -/// -/// # Default Implementation -/// -/// The default implementation ([`DefaultRecoveryPolicy`]) delegates to -/// [`CodecError::default_recovery_policy`] for all errors. -/// -/// # Examples -/// -/// ``` -/// use std::time::Duration; -/// -/// use wireframe::codec::{ -/// CodecError, -/// CodecErrorContext, -/// EofError, -/// RecoveryPolicy, -/// RecoveryPolicyHook, -/// }; -/// -/// /// Quarantine connections that close unexpectedly. -/// struct QuarantineOnPrematureEof; -/// -/// impl RecoveryPolicyHook for QuarantineOnPrematureEof { -/// fn recovery_policy(&self, error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { -/// match error { -/// CodecError::Eof(EofError::MidFrame { .. }) => RecoveryPolicy::Quarantine, -/// _ => error.default_recovery_policy(), -/// } -/// } -/// -/// fn quarantine_duration(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> Duration { -/// Duration::from_secs(60) -/// } -/// } -/// ``` -pub trait RecoveryPolicyHook: Send + Sync { - /// Determine the recovery policy for a codec error. - /// - /// The default implementation delegates to - /// [`CodecError::default_recovery_policy`]. - fn recovery_policy(&self, error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { - error.default_recovery_policy() - } - - /// Returns the quarantine duration when [`RecoveryPolicy::Quarantine`] is - /// selected. - /// - /// Default: 30 seconds. - fn quarantine_duration(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> Duration { - Duration::from_secs(30) - } - - /// Called before applying the recovery policy. - /// - /// Use this hook for logging, metrics emission, or state updates. The - /// default implementation does nothing. - fn on_error(&self, _error: &CodecError, _ctx: &CodecErrorContext, _policy: RecoveryPolicy) {} -} - -/// Default recovery policy implementation. -/// -/// This implementation uses the built-in default policies from -/// [`CodecError::default_recovery_policy`] without any customisation. -#[derive(Clone, Copy, Debug, Default)] -pub struct DefaultRecoveryPolicy; - -impl RecoveryPolicyHook for DefaultRecoveryPolicy {} - -/// Configuration for recovery policy behaviour. -/// -/// Use this to configure global recovery settings on the application builder. -/// -/// # Examples -/// -/// ``` -/// use std::time::Duration; -/// -/// use wireframe::codec::RecoveryConfig; -/// -/// let config = RecoveryConfig::default() -/// .max_consecutive_drops(5) -/// .quarantine_duration(Duration::from_secs(60)) -/// .log_dropped_frames(true); -/// -/// assert_eq!(config.max_consecutive_drops, 5); -/// ``` -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct RecoveryConfig { - /// Maximum consecutive dropped frames before escalating to disconnect. - /// - /// When this threshold is exceeded, the recovery policy escalates from - /// [`RecoveryPolicy::Drop`] to [`RecoveryPolicy::Disconnect`]. - /// - /// Default: 10. - pub max_consecutive_drops: u32, - - /// Default quarantine duration. - /// - /// Default: 30 seconds. - pub quarantine_duration: Duration, - - /// Whether to log dropped frames at warn level. - /// - /// Default: true. - pub log_dropped_frames: bool, -} - -impl Default for RecoveryConfig { - fn default() -> Self { - Self { - max_consecutive_drops: 10, - quarantine_duration: Duration::from_secs(30), - log_dropped_frames: true, - } - } -} - -impl RecoveryConfig { - /// Set the maximum consecutive dropped frames before disconnect. - #[must_use] - pub fn max_consecutive_drops(mut self, count: u32) -> Self { - self.max_consecutive_drops = count; - self - } - - /// Set the default quarantine duration. - #[must_use] - pub fn quarantine_duration(mut self, duration: Duration) -> Self { - self.quarantine_duration = duration; - self - } - - /// Set whether to log dropped frames. - #[must_use] - pub fn log_dropped_frames(mut self, enabled: bool) -> Self { - self.log_dropped_frames = enabled; - self - } -} - -#[cfg(test)] -mod tests; diff --git a/src/codec/recovery/config.rs b/src/codec/recovery/config.rs new file mode 100644 index 00000000..5526fa3b --- /dev/null +++ b/src/codec/recovery/config.rs @@ -0,0 +1,76 @@ +//! Configuration for codec recovery policies. + +use std::time::Duration; + +/// Configuration for recovery policy behaviour. +/// +/// Use this to configure global recovery settings on the application builder. +/// +/// # Examples +/// +/// ``` +/// use std::time::Duration; +/// +/// use wireframe::codec::RecoveryConfig; +/// +/// let config = RecoveryConfig::default() +/// .max_consecutive_drops(5) +/// .quarantine_duration(Duration::from_secs(60)) +/// .log_dropped_frames(true); +/// +/// assert_eq!(config.max_consecutive_drops, 5); +/// ``` +#[derive(Clone, Debug)] +pub struct RecoveryConfig { + /// Maximum consecutive dropped frames before escalating to disconnect. + /// + /// When this threshold is exceeded, the recovery policy escalates from + /// [`RecoveryPolicy::Drop`](crate::codec::RecoveryPolicy::Drop) to + /// [`RecoveryPolicy::Disconnect`](crate::codec::RecoveryPolicy::Disconnect). + /// + /// Default: 10. + pub max_consecutive_drops: u32, + + /// Default quarantine duration. + /// + /// Default: 30 seconds. + pub quarantine_duration: Duration, + + /// Whether to log dropped frames at warn level. + /// + /// Default: true. + pub log_dropped_frames: bool, +} + +impl Default for RecoveryConfig { + fn default() -> Self { + Self { + max_consecutive_drops: 10, + quarantine_duration: Duration::from_secs(30), + log_dropped_frames: true, + } + } +} + +impl RecoveryConfig { + /// Set the maximum consecutive dropped frames before disconnect. + #[must_use] + pub fn max_consecutive_drops(mut self, count: u32) -> Self { + self.max_consecutive_drops = count; + self + } + + /// Set the default quarantine duration. + #[must_use] + pub fn quarantine_duration(mut self, duration: Duration) -> Self { + self.quarantine_duration = duration; + self + } + + /// Set whether to log dropped frames. + #[must_use] + pub fn log_dropped_frames(mut self, enabled: bool) -> Self { + self.log_dropped_frames = enabled; + self + } +} diff --git a/src/codec/recovery/context.rs b/src/codec/recovery/context.rs new file mode 100644 index 00000000..5ab103a2 --- /dev/null +++ b/src/codec/recovery/context.rs @@ -0,0 +1,74 @@ +//! Structured context for codec error recovery and logging. + +use std::net::SocketAddr; + +/// Structured context for codec error logging and diagnostics. +/// +/// This struct captures connection-specific information to include in +/// structured logs and metrics when codec errors occur. +/// +/// # Examples +/// +/// ``` +/// use wireframe::codec::CodecErrorContext; +/// +/// let ctx = CodecErrorContext::new() +/// .with_connection_id(42) +/// .with_correlation_id(123); +/// +/// assert_eq!(ctx.connection_id, Some(42)); +/// assert_eq!(ctx.correlation_id, Some(123)); +/// ``` +#[derive(Clone, Debug, Default)] +pub struct CodecErrorContext { + /// Unique identifier for the connection. + pub connection_id: Option, + + /// Remote peer address. + pub peer_address: Option, + + /// Correlation identifier from the frame, if available. + /// + /// This helps attribute errors to specific requests when debugging. + pub correlation_id: Option, + + /// Codec instance state for debugging. + /// + /// May include sequence numbers, bytes processed, or other codec-specific + /// state information. + pub codec_state: Option, +} + +impl CodecErrorContext { + /// Create a new empty context. + #[must_use] + pub fn new() -> Self { Self::default() } + + /// Set the connection identifier. + #[must_use] + pub fn with_connection_id(mut self, id: u64) -> Self { + self.connection_id = Some(id); + self + } + + /// Set the peer address. + #[must_use] + pub fn with_peer_address(mut self, addr: SocketAddr) -> Self { + self.peer_address = Some(addr); + self + } + + /// Set the correlation identifier. + #[must_use] + pub fn with_correlation_id(mut self, id: u64) -> Self { + self.correlation_id = Some(id); + self + } + + /// Set codec-specific state information. + #[must_use] + pub fn with_codec_state(mut self, state: impl Into) -> Self { + self.codec_state = Some(state.into()); + self + } +} diff --git a/src/codec/recovery/hook.rs b/src/codec/recovery/hook.rs new file mode 100644 index 00000000..80bce25b --- /dev/null +++ b/src/codec/recovery/hook.rs @@ -0,0 +1,79 @@ +//! Hooks for customising codec error recovery behaviour. + +use std::time::Duration; + +use super::{CodecErrorContext, RecoveryPolicy}; +use crate::codec::error::CodecError; + +/// Hook trait for customising codec error recovery behaviour. +/// +/// Implementations can override default recovery policies based on +/// application-specific requirements or connection state. +/// +/// # Default Implementation +/// +/// The default implementation ([`DefaultRecoveryPolicy`]) delegates to +/// [`CodecError::default_recovery_policy`] for all errors. +/// +/// # Examples +/// +/// ``` +/// use std::time::Duration; +/// +/// use wireframe::codec::{ +/// CodecError, +/// CodecErrorContext, +/// EofError, +/// RecoveryPolicy, +/// RecoveryPolicyHook, +/// }; +/// +/// /// Quarantine connections that close unexpectedly. +/// struct QuarantineOnPrematureEof; +/// +/// impl RecoveryPolicyHook for QuarantineOnPrematureEof { +/// fn recovery_policy(&self, error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { +/// match error { +/// CodecError::Eof(EofError::MidFrame { .. }) => RecoveryPolicy::Quarantine, +/// _ => error.default_recovery_policy(), +/// } +/// } +/// +/// fn quarantine_duration(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> Duration { +/// Duration::from_secs(60) +/// } +/// } +/// ``` +pub trait RecoveryPolicyHook: Send + Sync { + /// Determine the recovery policy for a codec error. + /// + /// The default implementation delegates to + /// [`CodecError::default_recovery_policy`]. + fn recovery_policy(&self, error: &CodecError, ctx: &CodecErrorContext) -> RecoveryPolicy { + let _ = ctx; + error.default_recovery_policy() + } + + /// Returns the quarantine duration when [`RecoveryPolicy::Quarantine`] is + /// selected. + /// + /// Default: 30 seconds. + fn quarantine_duration(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> Duration { + Duration::from_secs(30) + } + + /// Called before applying the recovery policy. + /// + /// Use this hook for logging, metrics emission, or state updates. The + /// default implementation does nothing. + fn on_error(&self, _error: &CodecError, _ctx: &CodecErrorContext, _policy: RecoveryPolicy) {} +} + +/// Default recovery policy implementation. +/// +/// This implementation uses the built-in default policies from +/// [`CodecError::default_recovery_policy`] without any customisation. +#[derive(Clone, Copy, Debug, Default)] +pub struct DefaultRecoveryPolicy; + +impl RecoveryPolicyHook for DefaultRecoveryPolicy {} diff --git a/src/codec/recovery/mod.rs b/src/codec/recovery/mod.rs new file mode 100644 index 00000000..c5f39e6a --- /dev/null +++ b/src/codec/recovery/mod.rs @@ -0,0 +1,49 @@ +//! Recovery policy types and hooks for codec error handling. +//! +//! This module provides infrastructure for customising how codec errors are +//! handled, including recovery policies, error context for structured logging, +//! and hooks for application-specific error handling. +//! +//! # Recovery Policies +//! +//! When a codec error occurs, the framework applies a recovery policy: +//! +//! - [`RecoveryPolicy::Drop`]: Discard the malformed frame and continue processing. Suitable for +//! recoverable errors like oversized frames. +//! - [`RecoveryPolicy::Quarantine`]: Pause the connection temporarily before retrying. Useful for +//! rate-limiting misbehaving clients. +//! - [`RecoveryPolicy::Disconnect`]: Terminate the connection immediately. Required for +//! unrecoverable errors like I/O failures. +//! +//! # Custom Recovery Hooks +//! +//! Applications can customise error handling by implementing +//! [`RecoveryPolicyHook`]: +//! +//! ``` +//! use std::time::Duration; +//! +//! use wireframe::codec::{CodecError, CodecErrorContext, RecoveryPolicy, RecoveryPolicyHook}; +//! +//! struct StrictRecovery; +//! +//! impl RecoveryPolicyHook for StrictRecovery { +//! fn recovery_policy(&self, _error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { +//! // Disconnect on any error +//! RecoveryPolicy::Disconnect +//! } +//! } +//! ``` + +mod config; +mod context; +mod hook; +mod policy; + +pub use config::RecoveryConfig; +pub use context::CodecErrorContext; +pub use hook::{DefaultRecoveryPolicy, RecoveryPolicyHook}; +pub use policy::RecoveryPolicy; + +#[cfg(test)] +mod tests; diff --git a/src/codec/recovery/policy.rs b/src/codec/recovery/policy.rs new file mode 100644 index 00000000..865e12d9 --- /dev/null +++ b/src/codec/recovery/policy.rs @@ -0,0 +1,75 @@ +//! Recovery policies for codec errors. + +/// Recovery policies for codec errors. +/// +/// Each policy defines how the framework responds to a codec error. +/// +/// # Default Behaviour +/// +/// [`CodecError::default_recovery_policy`](crate::codec::CodecError::default_recovery_policy) +/// returns the recommended policy for each error type. Applications can +/// override this via [`RecoveryPolicyHook`](crate::codec::RecoveryPolicyHook). +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum RecoveryPolicy { + /// Discard the malformed frame and continue processing. + /// + /// This is the default for recoverable errors like oversized frames or + /// protocol violations that affect only a single message. The connection + /// remains open and can process subsequent frames. + /// + /// # When to Use + /// + /// - Oversized frames that exceed `max_frame_length` + /// - Empty frames where non-empty is expected + /// - Protocol-level errors (unknown message type, sequence violation) + #[default] + Drop, + + /// Pause the connection temporarily before retrying. + /// + /// The connection enters a quarantine state for a configurable duration. + /// During quarantine, no frames are processed. After the timeout, normal + /// processing resumes. + /// + /// # When to Use + /// + /// - Rate-limiting misbehaving clients + /// - Temporary back-off after repeated errors + /// - Giving time for upstream issues to resolve + Quarantine, + + /// Terminate the connection immediately. + /// + /// The connection is closed without processing further frames. This is + /// required for unrecoverable errors where the framing state is corrupted + /// or the transport has failed. + /// + /// # When to Use + /// + /// - I/O errors (socket closed, write failed) + /// - Invalid frame length encoding (framing state corrupted) + /// - EOF conditions (connection ending) + Disconnect, +} + +impl RecoveryPolicy { + /// Returns the policy name as a static string for metrics and logging. + /// + /// # Examples + /// + /// ``` + /// use wireframe::codec::RecoveryPolicy; + /// + /// assert_eq!(RecoveryPolicy::Drop.as_str(), "drop"); + /// assert_eq!(RecoveryPolicy::Quarantine.as_str(), "quarantine"); + /// assert_eq!(RecoveryPolicy::Disconnect.as_str(), "disconnect"); + /// ``` + #[must_use] + pub const fn as_str(self) -> &'static str { + match self { + Self::Drop => "drop", + Self::Quarantine => "quarantine", + Self::Disconnect => "disconnect", + } + } +} diff --git a/src/codec/recovery/tests.rs b/src/codec/recovery/tests.rs index 7534667f..6b59e328 100644 --- a/src/codec/recovery/tests.rs +++ b/src/codec/recovery/tests.rs @@ -1,22 +1,9 @@ -//! Tests for codec recovery policies, hooks, and configuration. +//! Tests for codec recovery policies. -use std::{io, net::SocketAddr, time::Duration}; - -use rstest::{fixture, rstest}; +use std::{io, time::Duration}; use super::*; - -#[fixture] -fn default_hook() -> DefaultRecoveryPolicy { - // Use the framework default hook for baseline policy assertions. - DefaultRecoveryPolicy -} - -#[fixture] -fn context() -> CodecErrorContext { - // These tests exercise hook behaviour without connection metadata. - CodecErrorContext::new() -} +use crate::codec::CodecError; #[test] fn recovery_policy_default_is_drop() { @@ -37,47 +24,37 @@ fn context_builder_sets_fields() { #[test] fn context_with_peer_address() { - let addr: SocketAddr = "127.0.0.1:8080".parse().expect("valid test address"); + let addr: std::net::SocketAddr = "127.0.0.1:8080".parse().expect("valid test address"); let ctx = CodecErrorContext::new().with_peer_address(addr); assert_eq!(ctx.peer_address, Some(addr)); } -#[rstest] -fn default_recovery_policy_delegates_to_error( - default_hook: DefaultRecoveryPolicy, - context: CodecErrorContext, -) { +#[test] +fn default_recovery_policy_delegates_to_error() { use crate::codec::error::{EofError, FramingError}; + let hook = DefaultRecoveryPolicy; + let ctx = CodecErrorContext::new(); + // Check various error types let err = CodecError::Framing(FramingError::OversizedFrame { size: 100, max: 50 }); - assert_eq!( - default_hook.recovery_policy(&err, &context), - RecoveryPolicy::Drop - ); + assert_eq!(hook.recovery_policy(&err, &ctx), RecoveryPolicy::Drop); let err = CodecError::Io(io::Error::other("test")); - assert_eq!( - default_hook.recovery_policy(&err, &context), - RecoveryPolicy::Disconnect - ); + assert_eq!(hook.recovery_policy(&err, &ctx), RecoveryPolicy::Disconnect); let err = CodecError::Eof(EofError::CleanClose); - assert_eq!( - default_hook.recovery_policy(&err, &context), - RecoveryPolicy::Disconnect - ); + assert_eq!(hook.recovery_policy(&err, &ctx), RecoveryPolicy::Disconnect); } -#[rstest] -fn default_quarantine_duration_is_30_seconds( - default_hook: DefaultRecoveryPolicy, - context: CodecErrorContext, -) { - let io_error = CodecError::Io(io::Error::other("test")); +#[test] +fn default_quarantine_duration_is_30_seconds() { + let hook = DefaultRecoveryPolicy; + let ctx = CodecErrorContext::new(); + let err = CodecError::Io(io::Error::other("test")); assert_eq!( - default_hook.quarantine_duration(&io_error, &context), + hook.quarantine_duration(&err, &ctx), Duration::from_secs(30) ); } diff --git a/src/extractor/connection_info.rs b/src/extractor/connection_info.rs new file mode 100644 index 00000000..f3879d81 --- /dev/null +++ b/src/extractor/connection_info.rs @@ -0,0 +1,47 @@ +//! Extractor for connection metadata. + +use std::net::SocketAddr; + +use super::{FromMessageRequest, MessageRequest, Payload}; + +/// Extractor providing peer connection metadata. +#[derive(Debug, Clone, Copy)] +pub struct ConnectionInfo { + peer_addr: Option, +} + +impl ConnectionInfo { + /// Returns the peer's socket address for the current connection, if available. + /// + /// # Examples + /// + /// ```rust + /// use std::net::SocketAddr; + /// + /// use wireframe::extractor::{ConnectionInfo, FromMessageRequest, MessageRequest, Payload}; + /// + /// let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap(); + /// let req = MessageRequest::new().with_peer_addr(Some(addr)); + /// let info = ConnectionInfo::from_message_request(&req, &mut Payload::default()).unwrap(); + /// assert_eq!(info.peer_addr(), Some(addr)); + /// ``` + #[must_use] + pub fn peer_addr(&self) -> Option { self.peer_addr } +} + +impl FromMessageRequest for ConnectionInfo { + type Error = std::convert::Infallible; + + /// Extracts connection metadata from the message request. + /// + /// Returns a `ConnectionInfo` containing the peer's socket address, if available. This + /// extraction is infallible. + fn from_message_request( + req: &MessageRequest, + _payload: &mut Payload<'_>, + ) -> Result { + Ok(Self { + peer_addr: req.peer_addr, + }) + } +} diff --git a/src/extractor/error.rs b/src/extractor/error.rs new file mode 100644 index 00000000..c7d0a663 --- /dev/null +++ b/src/extractor/error.rs @@ -0,0 +1,48 @@ +//! Error types for built-in extractors. + +/// Errors that can occur when extracting built-in types. +/// +/// This enum is marked `#[non_exhaustive]` so more variants may be added in +/// the future without breaking changes. +#[derive(Debug)] +#[non_exhaustive] +pub enum ExtractError { + /// No shared state of the requested type was found. + MissingState(&'static str), + /// Failed to decode the message payload. + InvalidPayload(bincode::error::DecodeError), + /// No streaming body was available for this request. + /// + /// This occurs when: + /// - The request was not configured for streaming consumption + /// - The stream was already consumed by another extractor + MissingBodyStream, +} + +impl std::fmt::Display for ExtractError { + /// Formats the `ExtractError` for display purposes. + /// + /// Displays a descriptive message for missing shared state or payload decoding errors. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::MissingState(ty) => write!(f, "no shared state registered for {ty}"), + Self::InvalidPayload(e) => write!(f, "failed to decode payload: {e}"), + Self::MissingBodyStream => { + write!(f, "no streaming body available for this request") + } + } + } +} + +impl std::error::Error for ExtractError { + /// Returns the underlying error if this is an `InvalidPayload` variant. + /// + /// # Returns + /// An optional reference to the underlying decode error, or `None` if not applicable. + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::InvalidPayload(e) => Some(e), + _ => None, + } + } +} diff --git a/src/extractor/extractors.rs b/src/extractor/extractors.rs deleted file mode 100644 index 266e2efb..00000000 --- a/src/extractor/extractors.rs +++ /dev/null @@ -1,182 +0,0 @@ -//! Built-in extractor implementations for message, streaming body, and connection metadata. - -use std::net::SocketAddr; - -use super::{ExtractError, FromMessageRequest, MessageRequest, Payload}; -use crate::message::Message as WireMessage; - -/// Extractor that deserializes the message payload into `T`. -#[derive(Debug, Clone)] -pub struct Message(T); - -impl Message { - /// Consumes the extractor and returns the inner deserialized message value. - #[must_use] - pub fn into_inner(self) -> T { self.0 } -} - -impl std::ops::Deref for Message { - type Target = T; - - /// Returns a reference to the inner value. - /// - /// This enables transparent access to the wrapped type via dereferencing. - fn deref(&self) -> &Self::Target { &self.0 } -} - -impl FromMessageRequest for Message -where - T: WireMessage, -{ - type Error = ExtractError; - - /// Attempts to extract and deserialize a message of type `T` from the payload. - /// - /// Advances the payload by the number of bytes consumed during deserialization. - /// Returns an error if the payload cannot be decoded into the target type. - /// - /// # Returns - /// - `Ok(Self)`: The successfully extracted and deserialized message. - /// - `Err(ExtractError::InvalidPayload)`: If deserialization fails. - fn from_message_request( - _req: &MessageRequest, - payload: &mut Payload<'_>, - ) -> Result { - let (msg, consumed) = - T::from_bytes(payload.as_ref()).map_err(ExtractError::InvalidPayload)?; - payload.advance(consumed); - Ok(Self(msg)) - } -} - -/// Extractor providing streaming access to the request body. -/// -/// Unlike [`Payload`] which borrows buffered bytes, this extractor -/// takes ownership of a streaming body channel. Handlers opting into -/// streaming receive chunks incrementally via a [`RequestBodyStream`]. -/// -/// This type is the inbound counterpart to [`crate::Response::Stream`]. -/// -/// # Examples -/// -/// ``` -/// use bytes::Bytes; -/// use tokio::io::AsyncReadExt; -/// use wireframe::{extractor::StreamingBody, request::body_channel}; -/// -/// # #[tokio::main] -/// # async fn main() { -/// let (tx, stream) = body_channel(4); -/// -/// tokio::spawn(async move { -/// let _ = tx.send(Ok(Bytes::from_static(b"payload"))).await; -/// }); -/// -/// let body = StreamingBody::new(stream); -/// let mut reader = body.into_reader(); -/// let mut buf = Vec::new(); -/// reader.read_to_end(&mut buf).await.expect("read body"); -/// assert_eq!(buf, b"payload"); -/// # } -/// ``` -/// -/// [`RequestBodyStream`]: crate::request::RequestBodyStream -pub struct StreamingBody { - stream: crate::request::RequestBodyStream, -} - -impl StreamingBody { - /// Create a streaming body from the given stream. - /// - /// Typically constructed by the framework when a handler opts into - /// streaming request consumption. - #[must_use] - pub fn new(stream: crate::request::RequestBodyStream) -> Self { Self { stream } } - - /// Consume the extractor and return the underlying stream. - /// - /// Use this when you need direct access to the stream for custom - /// processing with [`futures::StreamExt`] methods. - #[must_use] - pub fn into_stream(self) -> crate::request::RequestBodyStream { self.stream } - - /// Convert to an [`AsyncRead`] adapter. - /// - /// Protocol crates can use this to feed streaming bytes into parsers - /// that operate on readers rather than streams. - /// - /// [`AsyncRead`]: tokio::io::AsyncRead - #[must_use] - pub fn into_reader(self) -> crate::request::RequestBodyReader { - crate::request::RequestBodyReader::new(self.stream) - } -} - -impl std::fmt::Debug for StreamingBody { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("StreamingBody").finish_non_exhaustive() - } -} - -impl FromMessageRequest for StreamingBody { - type Error = ExtractError; - - /// Extract the streaming body from the request. - /// - /// # Errors - /// - /// Returns [`ExtractError::MissingBodyStream`] if: - /// - The request was not configured for streaming consumption - /// - The stream was already consumed by another extractor - fn from_message_request( - req: &MessageRequest, - _payload: &mut Payload<'_>, - ) -> Result { - req.take_body_stream() - .map(Self::new) - .ok_or(ExtractError::MissingBodyStream) - } -} - -/// Extractor providing peer connection metadata. -#[derive(Debug, Clone, Copy)] -pub struct ConnectionInfo { - peer_addr: Option, -} - -impl ConnectionInfo { - /// Returns the peer's socket address for the current connection, if available. - /// - /// # Examples - /// - /// ```rust - /// use std::net::SocketAddr; - /// - /// use wireframe::extractor::{ConnectionInfo, FromMessageRequest, MessageRequest, Payload}; - /// - /// let addr: SocketAddr = "127.0.0.1:8080".parse().expect("valid socket address"); - /// let req = MessageRequest::new().with_peer_addr(Some(addr)); - /// let info = ConnectionInfo::from_message_request(&req, &mut Payload::default()) - /// .expect("connection info extraction should succeed"); - /// assert_eq!(info.peer_addr(), Some(addr)); - /// ``` - #[must_use] - pub fn peer_addr(&self) -> Option { self.peer_addr } -} - -impl FromMessageRequest for ConnectionInfo { - type Error = std::convert::Infallible; - - /// Extracts connection metadata from the message request. - /// - /// Returns a `ConnectionInfo` containing the peer's socket address, if available. This - /// extraction is infallible. - fn from_message_request( - req: &MessageRequest, - _payload: &mut Payload<'_>, - ) -> Result { - Ok(Self { - peer_addr: req.peer_addr, - }) - } -} diff --git a/src/extractor/message.rs b/src/extractor/message.rs new file mode 100644 index 00000000..9e818058 --- /dev/null +++ b/src/extractor/message.rs @@ -0,0 +1,47 @@ +//! Message extractor for deserialised payloads. + +use super::{ExtractError, FromMessageRequest, MessageRequest, Payload}; +use crate::message::Message as WireMessage; + +/// Extractor that deserializes the message payload into `T`. +#[derive(Debug, Clone)] +pub struct Message(T); + +impl Message { + /// Consumes the extractor and returns the inner deserialised message value. + #[must_use] + pub fn into_inner(self) -> T { self.0 } +} + +impl std::ops::Deref for Message { + type Target = T; + + /// Returns a reference to the inner value. + /// + /// This enables transparent access to the wrapped type via dereferencing. + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl FromMessageRequest for Message +where + T: WireMessage, +{ + type Error = ExtractError; + + /// Attempts to extract and deserialize a message of type `T` from the payload. + /// + /// Advances the payload by the number of bytes consumed during deserialization. + /// Returns an error if the payload cannot be decoded into the target type. + /// + /// # Returns + /// - `Ok(Self)`: The successfully extracted and deserialized message. + /// - `Err(ExtractError::InvalidPayload)`: If deserialization fails. + fn from_message_request( + _req: &MessageRequest, + payload: &mut Payload<'_>, + ) -> Result { + let (msg, consumed) = T::from_bytes(payload.data).map_err(ExtractError::InvalidPayload)?; + payload.advance(consumed); + Ok(Self(msg)) + } +} diff --git a/src/extractor/mod.rs b/src/extractor/mod.rs index a1a158f5..2174bb8e 100644 --- a/src/extractor/mod.rs +++ b/src/extractor/mod.rs @@ -4,333 +4,18 @@ //! state. Implement [`FromMessageRequest`] for custom extractors to parse //! payload bytes or inspect connection info before your handler runs. -use std::{ - any::{Any, TypeId}, - collections::HashMap, - net::SocketAddr, - sync::{Arc, Mutex}, -}; - -use thiserror::Error; - -use crate::request::RequestBodyStream; - -mod extractors; - -pub use extractors::{ConnectionInfo, Message, StreamingBody}; - -/// Request context passed to extractors. -/// -/// This type contains metadata about the current connection and provides -/// access to application state registered with [`crate::app::WireframeApp`]. -#[derive(Default)] -pub struct MessageRequest { - /// Address of the peer that sent the current message. - pub peer_addr: Option, - /// Shared state values registered with the application. - /// - /// Values are keyed by their [`TypeId`]. Registering additional - /// state of the same type will replace the previous entry. - pub app_data: HashMap>, - /// Optional streaming body for handlers that opt into streaming consumption. - /// - /// When present, the [`StreamingBody`] extractor can take ownership of this - /// stream. Only one handler/extractor may consume the stream; subsequent - /// extractions will receive [`ExtractError::MissingBodyStream`]. - body_stream: Option>>, -} - -impl MessageRequest { - /// Create a new empty message request. - /// - /// Use [`Self::with_peer_addr`] to configure connection metadata. - #[must_use] - pub fn new() -> Self { Self::default() } - - /// Set the peer address for this request. - /// - /// # Examples - /// - /// ```rust - /// use std::net::SocketAddr; - /// - /// use wireframe::extractor::MessageRequest; - /// - /// let addr: SocketAddr = "127.0.0.1:8080".parse().expect("valid socket address"); - /// let req = MessageRequest::new().with_peer_addr(Some(addr)); - /// assert!(req.peer_addr.is_some()); - /// ``` - #[must_use] - pub fn with_peer_addr(mut self, addr: Option) -> Self { - self.peer_addr = addr; - self - } - - /// Retrieve shared state of type `T` if available. - /// - /// Returns `None` when no value of type `T` was registered. - /// - /// # Examples - /// - /// ```rust,no_run - /// use wireframe::{ - /// app::WireframeApp, - /// extractor::{MessageRequest, SharedState}, - /// }; - /// - /// let _app = WireframeApp::new() - /// .expect("failed to initialize app") - /// .app_data(5u32); - /// // The framework populates the request with application data. - /// # let mut req = MessageRequest::default(); - /// # req.insert_state(5u32); - /// let val: Option> = req.state(); - /// assert_eq!(*val.expect("state should be available"), 5); - /// ``` - #[must_use] - pub fn state(&self) -> Option> - where - T: Send + Sync + 'static, - { - self.app_data - .get(&TypeId::of::()) - .and_then(|data| data.clone().downcast::().ok()) - .map(SharedState) - } - - /// Insert shared state of type `T` into the request. - /// - /// This replaces any existing value of the same type. - /// - /// # Examples - /// - /// ```rust - /// use wireframe::extractor::{MessageRequest, SharedState}; - /// - /// let mut req = MessageRequest::default(); - /// req.insert_state(5u32); - /// let val: Option> = req.state(); - /// assert_eq!(*val.expect("state should be available"), 5); - /// ``` - pub fn insert_state(&mut self, state: T) - where - T: Send + Sync + 'static, - { - self.app_data.insert( - TypeId::of::(), - Arc::new(state) as Arc, - ); - } - - /// Set the streaming body for this request. - /// - /// The framework calls this when a handler opts into streaming consumption. - /// The stream can later be taken by the [`StreamingBody`] extractor. - /// - /// # Examples - /// - /// ```rust - /// use wireframe::{extractor::MessageRequest, request::body_channel}; - /// - /// let mut req = MessageRequest::default(); - /// let (_tx, stream) = body_channel(4); - /// req.set_body_stream(stream); - /// ``` - pub fn set_body_stream(&mut self, stream: RequestBodyStream) { - self.body_stream = Some(Mutex::new(Some(stream))); - } - - /// Take the streaming body from this request, if present. - /// - /// Returns `None` if no body stream was set or if it was already taken - /// by a previous extractor. This ensures only one consumer receives the - /// stream. - /// - /// # Mutex poisoning - /// - /// If the internal mutex is poisoned (for example, due to a panic in another - /// thread while holding the lock), this method returns `None` rather than - /// propagating the panic. This behaviour prioritizes availability over - /// strict correctness: a poisoned mutex typically indicates a serious bug - /// elsewhere, but crashing additional handlers would only compound the - /// problem. The missing stream will surface as an [`ExtractError::MissingBodyStream`] - /// in the handler, which can be logged and investigated. - #[must_use] - pub fn take_body_stream(&self) -> Option { - self.body_stream - .as_ref() - .and_then(|mutex| mutex.lock().ok()) - .and_then(|mut guard| guard.take()) - } -} - -/// Raw payload buffer handed to extractors. -/// -/// Create a `Payload` from a slice using [`Payload::new`] or `into`: -/// -/// ```rust -/// use wireframe::extractor::Payload; -/// -/// let p1 = Payload::new(b"abc"); -/// let p2: Payload<'_> = b"xyz".as_slice().into(); -/// assert_eq!(p1.as_ref(), b"abc" as &[u8]); -/// assert_eq!(p2.as_ref(), b"xyz" as &[u8]); -/// ``` -#[derive(Default)] -pub struct Payload<'a> { - /// Incoming bytes not yet processed. - data: &'a [u8], -} - -impl<'a> Payload<'a> { - /// Creates a new `Payload` from the provided byte slice. - /// - /// # Examples - /// - /// ```rust,no_run - /// use wireframe::extractor::Payload; - /// - /// let payload = Payload::new(b"data"); - /// assert_eq!(payload.as_ref(), b"data" as &[u8]); - /// ``` - #[must_use] - #[inline] - pub fn new(data: &'a [u8]) -> Self { Self { data } } -} - -impl<'a> From<&'a [u8]> for Payload<'a> { - #[inline] - fn from(data: &'a [u8]) -> Self { Self { data } } -} - -impl AsRef<[u8]> for Payload<'_> { - fn as_ref(&self) -> &[u8] { self.data } -} - -impl Payload<'_> { - /// Advances the payload by `count` bytes. - /// - /// Consumes up to `count` bytes from the front of the slice, ensuring we - /// never slice beyond the available buffer. - /// - /// # Examples - /// - /// ```rust,no_run - /// use wireframe::extractor::Payload; - /// - /// let mut payload = Payload::new(b"abcd"); - /// payload.advance(2); - /// assert_eq!(payload.as_ref(), b"cd" as &[u8]); - /// ``` - pub fn advance(&mut self, count: usize) { - let n = count.min(self.data.len()); - self.data = self.data.get(n..).unwrap_or_default(); - } - - /// Returns the number of bytes remaining. - /// - /// # Examples - /// - /// ```rust,no_run - /// use wireframe::extractor::Payload; - /// - /// let mut payload = Payload::new(b"bytes"); - /// assert_eq!(payload.remaining(), 5); - /// payload.advance(2); - /// assert_eq!(payload.remaining(), 3); - /// ``` - #[must_use] - pub fn remaining(&self) -> usize { self.data.len() } -} - -/// Trait for extracting data from a [`MessageRequest`]. -/// -/// Types implementing this trait can be used as parameters to handler -/// functions. When invoked, `wireframe` passes the current request metadata and -/// message payload, allowing the extractor to parse bytes or inspect -/// connection information. This makes it easy to share common parsing and -/// validation logic across handlers. -pub trait FromMessageRequest: Sized { - /// Error type returned when extraction fails. - type Error: std::error::Error + Send + Sync + 'static; - - /// Perform extraction from the request and payload. - /// - /// # Errors - /// - /// Returns an error if extraction fails. - fn from_message_request( - req: &MessageRequest, - payload: &mut Payload<'_>, - ) -> Result; -} - -/// Shared application state accessible to handlers. -#[derive(Clone)] -pub struct SharedState(Arc); - -impl From> for SharedState { - fn from(inner: Arc) -> Self { Self(inner) } -} - -impl From for SharedState { - fn from(inner: T) -> Self { Self(Arc::new(inner)) } -} - -/// Errors that can occur when extracting built-in types. -/// -/// This enum is marked `#[non_exhaustive]` so more variants may be added in -/// the future without breaking changes. -#[derive(Debug, Error)] -#[non_exhaustive] -pub enum ExtractError { - /// No shared state of the requested type was found. - #[error("no shared state registered for {0}")] - MissingState(&'static str), - /// Failed to decode the message payload. - #[error("failed to decode payload: {0}")] - InvalidPayload(#[source] bincode::error::DecodeError), - /// No streaming body was available for this request. - /// - /// This occurs when: - /// - The request was not configured for streaming consumption - /// - The stream was already consumed by another extractor - #[error("no streaming body available for this request")] - MissingBodyStream, -} - -impl FromMessageRequest for SharedState -where - T: Send + Sync + 'static, -{ - type Error = ExtractError; - - fn from_message_request( - req: &MessageRequest, - _payload: &mut Payload<'_>, - ) -> Result { - req.state::() - .ok_or(ExtractError::MissingState(std::any::type_name::())) - } -} - -impl std::ops::Deref for SharedState { - type Target = T; - - /// Returns a reference to the inner shared state value. - /// - /// This allows transparent access to the underlying state managed by `SharedState`. - /// - /// # Examples - /// - /// ```rust,no_run - /// use std::sync::Arc; - /// - /// use wireframe::extractor::SharedState; - /// - /// let state = Arc::new(42); - /// let shared: SharedState = state.clone().into(); - /// assert_eq!(*shared, 42); - /// ``` - fn deref(&self) -> &Self::Target { &self.0 } -} +mod connection_info; +mod error; +mod message; +mod request; +mod shared_state; +mod streaming; +mod trait_def; + +pub use connection_info::ConnectionInfo; +pub use error::ExtractError; +pub use message::Message; +pub use request::{MessageRequest, Payload}; +pub use shared_state::SharedState; +pub use streaming::StreamingBody; +pub use trait_def::FromMessageRequest; diff --git a/src/extractor/request.rs b/src/extractor/request.rs new file mode 100644 index 00000000..e569dee2 --- /dev/null +++ b/src/extractor/request.rs @@ -0,0 +1,235 @@ +//! Request context and payload buffer types for extractors. + +use std::{ + any::{Any, TypeId}, + collections::HashMap, + net::SocketAddr, + sync::{Arc, Mutex}, +}; + +use super::SharedState; +use crate::request::RequestBodyStream; + +/// Request context passed to extractors. +/// +/// This type contains metadata about the current connection and provides +/// access to application state registered with [`crate::app::WireframeApp`]. +#[derive(Default)] +pub struct MessageRequest { + /// Address of the peer that sent the current message. + pub peer_addr: Option, + /// Shared state values registered with the application. + /// + /// Values are keyed by their [`TypeId`]. Registering additional + /// state of the same type will replace the previous entry. + pub app_data: HashMap>, + /// Optional streaming body for handlers that opt into streaming consumption. + /// + /// When present, the [`StreamingBody`](crate::extractor::StreamingBody) + /// extractor can take ownership of this stream. Only one handler/extractor + /// may consume the stream; subsequent extractions will receive + /// [`ExtractError::MissingBodyStream`]. + body_stream: Option>>, +} + +impl MessageRequest { + /// Create a new empty message request. + /// + /// Use [`Self::with_peer_addr`] to configure connection metadata. + #[must_use] + pub fn new() -> Self { Self::default() } + + /// Set the peer address for this request. + /// + /// # Examples + /// + /// ```rust + /// use std::net::SocketAddr; + /// + /// use wireframe::extractor::MessageRequest; + /// + /// let req = MessageRequest::new().with_peer_addr(Some("127.0.0.1:8080".parse().unwrap())); + /// assert!(req.peer_addr.is_some()); + /// ``` + #[must_use] + pub fn with_peer_addr(mut self, addr: Option) -> Self { + self.peer_addr = addr; + self + } + + /// Retrieve shared state of type `T` if available. + /// + /// Returns `None` when no value of type `T` was registered. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::{ + /// app::WireframeApp, + /// extractor::{MessageRequest, SharedState}, + /// }; + /// + /// let _app = WireframeApp::new().unwrap().app_data(5u32); + /// // The framework populates the request with application data. + /// # let mut req = MessageRequest::default(); + /// # req.insert_state(5u32); + /// let val: Option> = req.state(); + /// assert_eq!(*val.unwrap(), 5); + /// ``` + #[must_use] + pub fn state(&self) -> Option> + where + T: Send + Sync + 'static, + { + self.app_data + .get(&TypeId::of::()) + .and_then(|data| data.clone().downcast::().ok()) + .map(SharedState::from) + } + + /// Insert shared state of type `T` into the request. + /// + /// This replaces any existing value of the same type. + /// + /// # Examples + /// + /// ```rust + /// use wireframe::extractor::{MessageRequest, SharedState}; + /// + /// let mut req = MessageRequest::default(); + /// req.insert_state(5u32); + /// let val: Option> = req.state(); + /// assert_eq!(*val.unwrap(), 5); + /// ``` + pub fn insert_state(&mut self, state: T) + where + T: Send + Sync + 'static, + { + self.app_data.insert( + TypeId::of::(), + Arc::new(state) as Arc, + ); + } + + /// Set the streaming body for this request. + /// + /// The framework calls this when a handler opts into streaming consumption. + /// The stream can later be taken by the + /// [`StreamingBody`](crate::extractor::StreamingBody) extractor. + /// + /// # Examples + /// + /// ```rust + /// use wireframe::{extractor::MessageRequest, request::body_channel}; + /// + /// let mut req = MessageRequest::default(); + /// let (_tx, stream) = body_channel(4); + /// req.set_body_stream(stream); + /// ``` + pub fn set_body_stream(&mut self, stream: RequestBodyStream) { + self.body_stream = Some(Mutex::new(Some(stream))); + } + + /// Take the streaming body from this request, if present. + /// + /// Returns `None` if no body stream was set or if it was already taken + /// by a previous extractor. This ensures only one consumer receives the + /// stream. + /// + /// # Mutex poisoning + /// + /// If the internal mutex is poisoned (for example, due to a panic in another + /// thread while holding the lock), this method returns `None` rather than + /// propagating the panic. This behaviour prioritizes availability over + /// strict correctness: a poisoned mutex typically indicates a serious bug + /// elsewhere, but crashing additional handlers would only compound the + /// problem. The missing stream will surface as an + /// [`ExtractError::MissingBodyStream`](crate::extractor::ExtractError::MissingBodyStream) + /// in the handler, which can be logged and investigated. + #[must_use] + pub fn take_body_stream(&self) -> Option { + self.body_stream + .as_ref() + .and_then(|mutex| mutex.lock().ok()) + .and_then(|mut guard| guard.take()) + } +} + +/// Raw payload buffer handed to extractors. +/// +/// Create a `Payload` from a slice using [`Payload::new`] or `into`: +/// +/// ```rust +/// use wireframe::extractor::Payload; +/// +/// let p1 = Payload::new(b"abc"); +/// let p2: Payload<'_> = b"xyz".as_slice().into(); +/// assert_eq!(p1.as_ref(), b"abc" as &[u8]); +/// assert_eq!(p2.as_ref(), b"xyz" as &[u8]); +/// ``` +#[derive(Default)] +pub struct Payload<'a> { + /// Incoming bytes not yet processed. + pub(super) data: &'a [u8], +} + +impl<'a> Payload<'a> { + /// Creates a new `Payload` from the provided byte slice. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::extractor::Payload; + /// + /// let payload = Payload::new(b"data"); + /// assert_eq!(payload.as_ref(), b"data" as &[u8]); + /// ``` + #[must_use] + #[inline] + pub fn new(data: &'a [u8]) -> Self { Self { data } } +} + +impl<'a> From<&'a [u8]> for Payload<'a> { + #[inline] + fn from(data: &'a [u8]) -> Self { Self { data } } +} + +impl AsRef<[u8]> for Payload<'_> { + fn as_ref(&self) -> &[u8] { self.data } +} + +impl Payload<'_> { + /// Advances the payload by `count` bytes. + /// + /// Consumes up to `count` bytes from the front of the slice, ensuring we + /// never slice beyond the available buffer. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::extractor::Payload; + /// + /// let mut payload = Payload::new(b"abcd"); + /// payload.advance(2); + /// assert_eq!(payload.as_ref(), b"cd" as &[u8]); + /// ``` + pub fn advance(&mut self, count: usize) { + let n = count.min(self.data.len()); + self.data = self.data.get(n..).unwrap_or_default(); + } + + /// Returns the number of bytes remaining. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::extractor::Payload; + /// + /// let mut payload = Payload::new(b"bytes"); + /// assert_eq!(payload.remaining(), 5); + /// payload.advance(2); + /// assert_eq!(payload.remaining(), 3); + /// ``` + #[must_use] + pub fn remaining(&self) -> usize { self.data.len() } +} diff --git a/src/extractor/shared_state.rs b/src/extractor/shared_state.rs new file mode 100644 index 00000000..fdb41b9d --- /dev/null +++ b/src/extractor/shared_state.rs @@ -0,0 +1,53 @@ +//! Shared application state extractor. + +use std::sync::Arc; + +use super::{ExtractError, FromMessageRequest, MessageRequest, Payload}; + +/// Shared application state accessible to handlers. +#[derive(Clone)] +pub struct SharedState(Arc); + +impl From> for SharedState { + fn from(inner: Arc) -> Self { Self(inner) } +} + +impl From for SharedState { + fn from(inner: T) -> Self { Self(Arc::new(inner)) } +} + +impl FromMessageRequest for SharedState +where + T: Send + Sync + 'static, +{ + type Error = ExtractError; + + fn from_message_request( + req: &MessageRequest, + _payload: &mut Payload<'_>, + ) -> Result { + req.state::() + .ok_or(ExtractError::MissingState(std::any::type_name::())) + } +} + +impl std::ops::Deref for SharedState { + type Target = T; + + /// Returns a reference to the inner shared state value. + /// + /// This allows transparent access to the underlying state managed by `SharedState`. + /// + /// # Examples + /// + /// ```rust,no_run + /// use std::sync::Arc; + /// + /// use wireframe::extractor::SharedState; + /// + /// let state = Arc::new(42); + /// let shared: SharedState = state.clone().into(); + /// assert_eq!(*shared, 42); + /// ``` + fn deref(&self) -> &Self::Target { &self.0 } +} diff --git a/src/extractor/streaming.rs b/src/extractor/streaming.rs new file mode 100644 index 00000000..6bcd70db --- /dev/null +++ b/src/extractor/streaming.rs @@ -0,0 +1,92 @@ +//! Streaming body extractor. + +use super::{ExtractError, FromMessageRequest, MessageRequest, Payload}; + +/// Extractor providing streaming access to the request body. +/// +/// Unlike [`Payload`] which borrows buffered bytes, this extractor +/// takes ownership of a streaming body channel. Handlers opting into +/// streaming receive chunks incrementally via a [`RequestBodyStream`]. +/// +/// This type is the inbound counterpart to [`crate::Response::Stream`]. +/// +/// # Examples +/// +/// ``` +/// use bytes::Bytes; +/// use tokio::io::AsyncReadExt; +/// use wireframe::{extractor::StreamingBody, request::body_channel}; +/// +/// # #[tokio::main] +/// # async fn main() { +/// let (tx, stream) = body_channel(4); +/// +/// tokio::spawn(async move { +/// let _ = tx.send(Ok(Bytes::from_static(b"payload"))).await; +/// }); +/// +/// let body = StreamingBody::new(stream); +/// let mut reader = body.into_reader(); +/// let mut buf = Vec::new(); +/// reader.read_to_end(&mut buf).await.expect("read body"); +/// assert_eq!(buf, b"payload"); +/// # } +/// ``` +/// +/// [`RequestBodyStream`]: crate::request::RequestBodyStream +pub struct StreamingBody { + stream: crate::request::RequestBodyStream, +} + +impl StreamingBody { + /// Create a streaming body from the given stream. + /// + /// Typically constructed by the framework when a handler opts into + /// streaming request consumption. + #[must_use] + pub fn new(stream: crate::request::RequestBodyStream) -> Self { Self { stream } } + + /// Consume the extractor and return the underlying stream. + /// + /// Use this when you need direct access to the stream for custom + /// processing with [`futures::StreamExt`] methods. + #[must_use] + pub fn into_stream(self) -> crate::request::RequestBodyStream { self.stream } + + /// Convert to an [`AsyncRead`] adapter. + /// + /// Protocol crates can use this to feed streaming bytes into parsers + /// that operate on readers rather than streams. + /// + /// [`AsyncRead`]: tokio::io::AsyncRead + #[must_use] + pub fn into_reader(self) -> crate::request::RequestBodyReader { + crate::request::RequestBodyReader::new(self.stream) + } +} + +impl std::fmt::Debug for StreamingBody { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StreamingBody").finish_non_exhaustive() + } +} + +impl FromMessageRequest for StreamingBody { + type Error = ExtractError; + + /// Extract the streaming body from the request. + /// + /// # Errors + /// + /// Returns [`ExtractError::MissingBodyStream`] if: + /// - The request was not configured for streaming consumption + /// - The stream was already consumed by another extractor + fn from_message_request( + req: &MessageRequest, + _payload: &mut Payload<'_>, + ) -> Result { + req.take_body_stream() + .map(Self::new) + .ok_or(ExtractError::MissingBodyStream) + } +} diff --git a/src/extractor/trait_def.rs b/src/extractor/trait_def.rs new file mode 100644 index 00000000..870faf93 --- /dev/null +++ b/src/extractor/trait_def.rs @@ -0,0 +1,25 @@ +//! Trait definition for extractor types. + +use super::{MessageRequest, Payload}; + +/// Trait for extracting data from a [`MessageRequest`]. +/// +/// Types implementing this trait can be used as parameters to handler +/// functions. When invoked, `wireframe` passes the current request metadata and +/// message payload, allowing the extractor to parse bytes or inspect +/// connection information. This makes it easy to share common parsing and +/// validation logic across handlers. +pub trait FromMessageRequest: Sized { + /// Error type returned when extraction fails. + type Error: std::error::Error + Send + Sync + 'static; + + /// Perform extraction from the request and payload. + /// + /// # Errors + /// + /// Returns an error if extraction fails. + fn from_message_request( + req: &MessageRequest, + payload: &mut Payload<'_>, + ) -> Result; +} diff --git a/src/server/config/mod.rs b/src/server/config/mod.rs index 31480b75..fdc822a5 100644 --- a/src/server/config/mod.rs +++ b/src/server/config/mod.rs @@ -52,6 +52,8 @@ pub mod preamble; fn default_worker_count() -> usize { std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get) } +#[cfg(test)] +mod tests; impl WireframeServer where diff --git a/src/server/config/tests.rs b/src/server/config/tests.rs deleted file mode 100644 index 5e32879e..00000000 --- a/src/server/config/tests.rs +++ /dev/null @@ -1,406 +0,0 @@ -//! Tests for server configuration utilities. -//! -//! This module exercises the `WireframeServer` builder, covering worker counts, -//! binding behaviour, preamble handling, handler registration, and method -//! chaining. Fixtures from `test_util` provide shared setup and parameterised -//! cases via `rstest`. - -use std::{ - io, - sync::{ - Arc, - atomic::{AtomicUsize, Ordering}, - }, - time::Duration, -}; - -use bincode::error::DecodeError; -use rstest::{fixture, rstest}; -use tokio::net::{TcpListener, TcpStream}; - -use super::*; -use crate::server::{ - test_util::{ - TestPreamble, - bind_server, - factory, - free_listener, - listener_addr, - server_with_preamble, - }, - BackoffConfig, -}; - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum PreambleHandlerKind { - Success, - Failure, -} - -fn assert_local_addr_matches_listener( - server: WireframeServer, - expected: std::net::SocketAddr, -) where - F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, - T: crate::preamble::Preamble, - S: crate::server::ServerState, - Ser: crate::serializer::Serializer + Send + Sync, - Ctx: Send + 'static, - E: crate::app::Packet, - Codec: crate::codec::FrameCodec, -{ - let local_addr = server.local_addr().expect("local address missing"); - assert_eq!(local_addr, expected); -} - -#[fixture] -async fn connected_streams() -> io::Result<(TcpStream, TcpStream)> { - let listener = TcpListener::bind("127.0.0.1:0").await?; - let addr = listener.local_addr()?; - let client = TcpStream::connect(addr).await?; - let (server, _) = listener.accept().await?; - Ok((client, server)) -} - -#[rstest] -fn test_new_server_creation(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let server = WireframeServer::new(factory); - assert!(server.worker_count() >= 1); - assert!(server.local_addr().is_none()); -} - -#[rstest] -fn test_new_server_default_worker_count( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let server = WireframeServer::new(factory); - assert_eq!(server.worker_count(), default_worker_count()); -} - -#[rstest] -fn test_workers_configuration(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let server = WireframeServer::new(factory); - let server = server.workers(4); - assert_eq!(server.worker_count(), 4); - let server = server.workers(100); - assert_eq!(server.worker_count(), 100); - assert_eq!(server.workers(0).worker_count(), 1); -} - -#[rstest] -fn test_with_preamble_type_conversion( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let server = WireframeServer::new(factory).with_preamble::(); - assert_eq!(server.worker_count(), default_worker_count()); -} - -#[rstest] -fn test_preamble_timeout_configuration( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let timeout = Duration::from_millis(25); - let server = WireframeServer::new(factory).preamble_timeout(timeout); - assert_eq!(server.preamble_timeout, Some(timeout)); - - let clamped = WireframeServer::new(factory).preamble_timeout(Duration::ZERO); - assert_eq!(clamped.preamble_timeout, Some(Duration::from_millis(1))); -} - -#[rstest] -#[tokio::test] -async fn test_bind_success( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - free_listener: std::net::TcpListener, -) { - let expected = listener_addr(&free_listener); - let server = WireframeServer::new(factory) - .bind_existing_listener(free_listener) - .expect("Failed to bind"); - assert_local_addr_matches_listener(server, expected); -} - -#[rstest] -fn test_local_addr_before_bind(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - assert!(WireframeServer::new(factory).local_addr().is_none()); -} - -#[rstest] -#[tokio::test] -async fn test_local_addr_after_bind( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - free_listener: std::net::TcpListener, -) { - let expected = listener_addr(&free_listener); - let server = bind_server(factory, free_listener); - assert_local_addr_matches_listener(server, expected); -} - -#[rstest] -#[case::success(PreambleHandlerKind::Success)] -#[case::failure(PreambleHandlerKind::Failure)] -#[tokio::test] -async fn test_preamble_handler_registration( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - #[case] handler: PreambleHandlerKind, - connected_streams: io::Result<(TcpStream, TcpStream)>, -) -> io::Result<()> { - let counter = Arc::new(AtomicUsize::new(0)); - let c = counter.clone(); - - let server = server_with_preamble(factory); - let server = match handler { - PreambleHandlerKind::Success => server.on_preamble_decode_success(move |_p: &TestPreamble, _| { - let c = c.clone(); - Box::pin(async move { - c.fetch_add(1, Ordering::SeqCst); - Ok(()) - }) - }), - PreambleHandlerKind::Failure => server.on_preamble_decode_failure( - move |_err: &DecodeError, _stream| { - let c = c.clone(); - Box::pin(async move { - c.fetch_add(1, Ordering::SeqCst); - Ok::<(), io::Error>(()) - }) - }, - ), - }; - - assert_eq!(counter.load(Ordering::SeqCst), 0); - match handler { - PreambleHandlerKind::Success => { - let handler = server - .on_preamble_success - .as_ref() - .expect("success handler missing"); - let (_client, mut stream) = connected_streams?; - let preamble = TestPreamble { id: 0, message: String::new() }; - handler(&preamble, &mut stream).await?; - } - PreambleHandlerKind::Failure => { - let handler = server - .on_preamble_failure - .as_ref() - .expect("failure handler missing"); - let (_client, mut stream) = connected_streams?; - handler(&DecodeError::UnexpectedEnd, &mut stream).await?; - } - } - assert_eq!(counter.load(Ordering::SeqCst), 1); - Ok(()) -} - -#[rstest] -#[tokio::test] -async fn test_method_chaining( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - free_listener: std::net::TcpListener, -) { - let handler_invoked = Arc::new(AtomicUsize::new(0)); - let counter = handler_invoked.clone(); - let server = WireframeServer::new(factory) - .workers(2) - .with_preamble::() - .on_preamble_decode_success(move |_p: &TestPreamble, _| { - let c = counter.clone(); - Box::pin(async move { - c.fetch_add(1, Ordering::SeqCst); - Ok(()) - }) - }) - .on_preamble_decode_failure(|_: &DecodeError, _| Box::pin(async { Ok::<(), io::Error>(()) })) - .bind_existing_listener(free_listener) - .expect("Failed to bind"); - assert_eq!(server.worker_count(), 2); - assert!(server.local_addr().is_some()); - assert_eq!(handler_invoked.load(Ordering::SeqCst), 0); -} - -#[rstest] -#[tokio::test] -async fn test_server_configuration_persistence( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - free_listener: std::net::TcpListener, -) { - let server = WireframeServer::new(factory) - .workers(5) - .bind_existing_listener(free_listener) - .expect("Failed to bind"); - assert_eq!(server.worker_count(), 5); - assert!(server.local_addr().is_some()); -} - -#[rstest] -fn test_extreme_worker_counts(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let server = WireframeServer::new(factory); - let server = server.workers(usize::MAX); - assert_eq!(server.worker_count(), usize::MAX); - assert_eq!(server.workers(0).worker_count(), 1); -} - -#[rstest] -#[tokio::test] -async fn test_bind_to_multiple_addresses( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - free_listener: std::net::TcpListener, -) { - let addr1 = listener_addr(&free_listener); - - let server = WireframeServer::new(factory); - let server = server - .bind_existing_listener(free_listener) - .expect("Failed to bind first address"); - let first = server.local_addr().expect("first bound address missing"); - assert_eq!(first, addr1); - - let server = server - .bind(std::net::SocketAddr::new(addr1.ip(), 0)) - .expect("Failed to bind second address"); - let second = server.local_addr().expect("second bound address missing"); - assert_eq!(second.ip(), addr1.ip()); - assert_ne!(first.port(), second.port()); -} - -#[derive(Debug)] -struct BackoffScenario { - description: &'static str, - config: BackoffConfig, - expected_initial: Duration, - expected_max: Duration, -} - -#[rstest] -#[case::accept_config(BackoffScenario { - description: "accepts explicit delays", - config: BackoffConfig { - initial_delay: Duration::from_millis(5), - max_delay: Duration::from_millis(500), - }, - expected_initial: Duration::from_millis(5), - expected_max: Duration::from_millis(500), -})] -#[case::accept_initial_delay(BackoffScenario { - description: "accepts initial delay with default max", - config: BackoffConfig { - initial_delay: Duration::from_millis(20), - ..BackoffConfig::default() - }, - expected_initial: Duration::from_millis(20), - expected_max: BackoffConfig::default().max_delay, -})] -#[case::accept_max_delay(BackoffScenario { - description: "accepts max delay with default initial", - config: BackoffConfig { - max_delay: Duration::from_millis(2000), - ..BackoffConfig::default() - }, - expected_initial: BackoffConfig::default().initial_delay, - expected_max: Duration::from_millis(2000), -})] -#[case::clamp_zero_initial(BackoffScenario { - description: "clamps zero initial delay", - config: BackoffConfig { - initial_delay: Duration::ZERO, - ..BackoffConfig::default() - }, - expected_initial: Duration::from_millis(1), - expected_max: BackoffConfig::default().max_delay, -})] -#[case::swap_initial_gt_max(BackoffScenario { - description: "swaps initial and max delays when inverted", - config: BackoffConfig { - initial_delay: Duration::from_millis(100), - max_delay: Duration::from_millis(50), - }, - expected_initial: Duration::from_millis(50), - expected_max: Duration::from_millis(100), -})] -#[case::swap_initial_over_default_max(BackoffScenario { - description: "swaps initial and max delays when initial exceeds default max", - config: BackoffConfig { - initial_delay: Duration::from_secs(2), - max_delay: Duration::from_secs(1), - }, - expected_initial: Duration::from_secs(1), - expected_max: Duration::from_secs(2), -})] -#[case::swap_small_values(BackoffScenario { - description: "swaps small inverted delays", - config: BackoffConfig { - initial_delay: Duration::from_millis(5), - max_delay: Duration::from_millis(1), - }, - expected_initial: Duration::from_millis(1), - expected_max: Duration::from_millis(5), -})] -#[case::clamp_zero_both(BackoffScenario { - description: "clamps zero initial and max delays", - config: BackoffConfig { - initial_delay: Duration::ZERO, - max_delay: Duration::ZERO, - }, - expected_initial: Duration::from_millis(1), - expected_max: Duration::from_millis(1), -})] -fn test_accept_backoff_scenarios( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - #[case] scenario: BackoffScenario, -) { - let server = WireframeServer::new(factory).accept_backoff(scenario.config); - assert_eq!( - server.backoff_config.initial_delay, - scenario.expected_initial, - "scenario: {}", - scenario.description - ); - assert_eq!( - server.backoff_config.max_delay, - scenario.expected_max, - "scenario: {}", - scenario.description - ); -} - -/// Behaviour test verifying exponential delay doubling and capping. -#[test] -fn test_accept_exponential_backoff_doubles_and_caps() { - let initial = Duration::from_millis(10); - let max = Duration::from_millis(80); - let attempts = 5; - - let sequence = backoff_sequence(initial, max, attempts); - - let expected_delays = [ - initial, - std::cmp::min(initial.saturating_mul(2), max), - std::cmp::min(initial.saturating_mul(4), max), - std::cmp::min(initial.saturating_mul(8), max), - max, - ]; - - assert_eq!(&sequence[..], &expected_delays); -} - -fn backoff_sequence(initial: Duration, max: Duration, attempts: usize) -> Vec { - let mut sequence = Vec::with_capacity(attempts); - let mut backoff = initial; - - for _ in 0..attempts { - sequence.push(backoff); - backoff = std::cmp::min(backoff * 2, max); - } - - sequence -} - -#[rstest] -fn test_backoff_default_values(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let server = WireframeServer::new(factory); - assert_eq!( - server.backoff_config.initial_delay, - Duration::from_millis(10) - ); - assert_eq!(server.backoff_config.max_delay, Duration::from_secs(1)); -} diff --git a/src/server/config/tests/mod.rs b/src/server/config/tests/mod.rs new file mode 100644 index 00000000..7a911b42 --- /dev/null +++ b/src/server/config/tests/mod.rs @@ -0,0 +1,23 @@ +//! Tests for server configuration utilities. +//! +//! This module exercises the `WireframeServer` builder, covering worker counts, +//! binding behaviour, preamble handling, handler registration, and method +//! chaining. Fixtures from `test_util` provide shared setup and parameterised +//! cases via `rstest`. + +mod tests_backoff; +mod tests_basic; +mod tests_binding; +mod tests_integration; +mod tests_preamble; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum PreambleHandlerKind { + Success, + Failure, +} + +fn expected_default_worker_count() -> usize { + // Mirror the default worker logic to keep tests aligned with `WireframeServer::new`. + std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get) +} diff --git a/src/server/config/tests/tests_backoff.rs b/src/server/config/tests/tests_backoff.rs new file mode 100644 index 00000000..bb3fb50d --- /dev/null +++ b/src/server/config/tests/tests_backoff.rs @@ -0,0 +1,166 @@ +//! Backoff configuration tests for `WireframeServer`. + +use std::time::Duration; + +use rstest::rstest; + +use crate::{ + app::WireframeApp, + server::{BackoffConfig, WireframeServer, test_util::factory}, +}; + +#[rstest] +fn test_accept_backoff_configuration( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) { + let cfg = BackoffConfig { + initial_delay: Duration::from_millis(5), + max_delay: Duration::from_millis(500), + }; + let server = WireframeServer::new(factory).accept_backoff(cfg); + assert_eq!(server.backoff_config, cfg); +} + +/// Behaviour test verifying exponential delay doubling and capping. +#[test] +fn test_accept_exponential_backoff_doubles_and_caps() { + use std::{ + thread, + time::{Duration, Instant}, + }; + + let initial = Duration::from_millis(10); + let max = Duration::from_millis(80); + let mut backoff = initial; + let mut delays = Vec::new(); + let attempts = 5; + + let start = Instant::now(); + let mut last = start; + + for _i in 0..attempts { + thread::sleep(backoff); + let now = Instant::now(); + let elapsed = now.duration_since(last); + delays.push(elapsed); + last = now; + + backoff = std::cmp::min(backoff * 2, max); + } + + let expected_delays = [ + initial, + std::cmp::min(initial * 2, max), + std::cmp::min(initial * 4, max), + std::cmp::min(initial * 8, max), + max, + ]; + + for (i, (actual, expected)) in delays.iter().zip(expected_delays.iter()).enumerate() { + assert!( + *actual >= *expected, + "Delay {i} was {actual:?}, expected at least {expected:?}" + ); + let max_expected = *expected + Duration::from_millis(20); + assert!( + *actual < max_expected, + "Delay {i} was {actual:?}, expected less than {max_expected:?}" + ); + } +} + +#[rstest] +fn test_accept_initial_delay_configuration( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) { + let delay = Duration::from_millis(20); + let cfg = BackoffConfig { + initial_delay: delay, + ..BackoffConfig::default() + }; + let server = WireframeServer::new(factory).accept_backoff(cfg); + assert_eq!(server.backoff_config.initial_delay, delay); +} + +#[rstest] +fn test_accept_max_delay_configuration( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) { + let delay = Duration::from_millis(2000); + let cfg = BackoffConfig { + max_delay: delay, + ..BackoffConfig::default() + }; + let server = WireframeServer::new(factory).accept_backoff(cfg); + assert_eq!(server.backoff_config.max_delay, delay); +} + +#[rstest] +fn test_backoff_validation(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { + let server = WireframeServer::new(factory.clone()).accept_backoff(BackoffConfig { + initial_delay: Duration::ZERO, + ..BackoffConfig::default() + }); + assert_eq!( + server.backoff_config.initial_delay, + Duration::from_millis(1) + ); + + let server = WireframeServer::new(factory).accept_backoff(BackoffConfig { + initial_delay: Duration::from_millis(100), + max_delay: Duration::from_millis(50), + }); + assert_eq!( + server.backoff_config.initial_delay, + Duration::from_millis(50) + ); + assert_eq!(server.backoff_config.max_delay, Duration::from_millis(100)); +} + +#[rstest] +fn test_backoff_default_values(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { + let server = WireframeServer::new(factory); + assert_eq!( + server.backoff_config.initial_delay, + Duration::from_millis(10) + ); + assert_eq!(server.backoff_config.max_delay, Duration::from_secs(1)); +} + +#[rstest] +fn test_initial_delay_exceeds_default_max( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) { + let cfg = BackoffConfig { + initial_delay: Duration::from_secs(2), + max_delay: Duration::from_secs(1), + }; + let server = WireframeServer::new(factory).accept_backoff(cfg); + assert_eq!(server.backoff_config.initial_delay, Duration::from_secs(1)); + assert_eq!(server.backoff_config.max_delay, Duration::from_secs(2)); +} + +#[rstest] +fn test_accept_backoff_parameter_swapping( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) { + let server = WireframeServer::new(factory.clone()).accept_backoff(BackoffConfig { + initial_delay: Duration::from_millis(5), + max_delay: Duration::from_millis(1), + }); + assert_eq!( + server.backoff_config.initial_delay, + Duration::from_millis(1) + ); + assert_eq!(server.backoff_config.max_delay, Duration::from_millis(5)); + + let server = WireframeServer::new(factory).accept_backoff(BackoffConfig { + initial_delay: Duration::ZERO, + max_delay: Duration::ZERO, + }); + assert_eq!( + server.backoff_config.initial_delay, + Duration::from_millis(1) + ); + assert_eq!(server.backoff_config.max_delay, Duration::from_millis(1)); +} diff --git a/src/server/config/tests/tests_basic.rs b/src/server/config/tests/tests_basic.rs new file mode 100644 index 00000000..f7b7a2ca --- /dev/null +++ b/src/server/config/tests/tests_basic.rs @@ -0,0 +1,54 @@ +//! Basic configuration tests for `WireframeServer`. + +use rstest::rstest; + +use super::expected_default_worker_count; +use crate::{ + app::WireframeApp, + server::{ + WireframeServer, + test_util::{bind_server, factory, free_listener, listener_addr}, + }, +}; + +#[rstest] +fn test_new_server_creation(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { + let server = WireframeServer::new(factory); + assert!(server.worker_count() >= 1 && server.local_addr().is_none()); +} + +#[rstest] +fn test_new_server_default_worker_count( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) { + let server = WireframeServer::new(factory); + assert_eq!(server.worker_count(), expected_default_worker_count()); +} + +#[rstest] +fn test_workers_configuration(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { + let mut server = WireframeServer::new(factory); + server = server.workers(4); + assert_eq!(server.worker_count(), 4); + server = server.workers(100); + assert_eq!(server.worker_count(), 100); + assert_eq!(server.workers(0).worker_count(), 1); +} + +#[rstest] +fn test_local_addr_before_bind(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { + assert!(WireframeServer::new(factory).local_addr().is_none()); +} + +#[rstest] +#[tokio::test] +async fn test_local_addr_after_bind( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_listener: std::net::TcpListener, +) { + let expected = listener_addr(&free_listener); + let local_addr = bind_server(factory, free_listener) + .local_addr() + .expect("local address missing"); + assert_eq!(local_addr, expected); +} diff --git a/src/server/config/tests/tests_binding.rs b/src/server/config/tests/tests_binding.rs new file mode 100644 index 00000000..5b624b2c --- /dev/null +++ b/src/server/config/tests/tests_binding.rs @@ -0,0 +1,49 @@ +//! Binding behaviour tests for `WireframeServer`. + +use rstest::rstest; + +use crate::{ + app::WireframeApp, + server::{ + WireframeServer, + test_util::{factory, free_listener, listener_addr}, + }, +}; + +#[rstest] +#[tokio::test] +async fn test_bind_success( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_listener: std::net::TcpListener, +) { + let expected = listener_addr(&free_listener); + let local_addr = WireframeServer::new(factory) + .bind_existing_listener(free_listener) + .expect("Failed to bind") + .local_addr() + .expect("local address missing"); + assert_eq!(local_addr, expected); +} + +#[rstest] +#[tokio::test] +async fn test_bind_to_multiple_addresses( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_listener: std::net::TcpListener, +) { + let addr1 = listener_addr(&free_listener); + + let server = WireframeServer::new(factory); + let server = server + .bind_existing_listener(free_listener) + .expect("Failed to bind first address"); + let first = server.local_addr().expect("first bound address missing"); + assert_eq!(first, addr1); + + let server = server + .bind(std::net::SocketAddr::new(addr1.ip(), 0)) + .expect("Failed to bind second address"); + let second = server.local_addr().expect("second bound address missing"); + assert_eq!(second.ip(), addr1.ip()); + assert_ne!(first.port(), second.port()); +} diff --git a/src/server/config/tests/tests_integration.rs b/src/server/config/tests/tests_integration.rs new file mode 100644 index 00000000..671bbd8d --- /dev/null +++ b/src/server/config/tests/tests_integration.rs @@ -0,0 +1,70 @@ +//! Integration-style tests for server builder composition. + +use std::{ + io, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, +}; + +use bincode::error::DecodeError; +use rstest::rstest; + +use crate::{ + app::WireframeApp, + server::{ + WireframeServer, + test_util::{TestPreamble, factory, free_listener}, + }, +}; + +#[rstest] +#[tokio::test] +async fn test_method_chaining( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_listener: std::net::TcpListener, +) { + let handler_invoked = Arc::new(AtomicUsize::new(0)); + let counter = handler_invoked.clone(); + let server = WireframeServer::new(factory) + .workers(2) + .with_preamble::() + .on_preamble_decode_success(move |_p: &TestPreamble, _| { + let c = counter.clone(); + Box::pin(async move { + c.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + }) + .on_preamble_decode_failure(|_: &DecodeError, _| { + Box::pin(async { Ok::<(), io::Error>(()) }) + }) + .bind_existing_listener(free_listener) + .expect("Failed to bind"); + assert_eq!(server.worker_count(), 2); + assert!(server.local_addr().is_some()); + assert_eq!(handler_invoked.load(Ordering::SeqCst), 0); +} + +#[rstest] +#[tokio::test] +async fn test_server_configuration_persistence( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_listener: std::net::TcpListener, +) { + let server = WireframeServer::new(factory) + .workers(5) + .bind_existing_listener(free_listener) + .expect("Failed to bind"); + assert_eq!(server.worker_count(), 5); + assert!(server.local_addr().is_some()); +} + +#[rstest] +fn test_extreme_worker_counts(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { + let mut server = WireframeServer::new(factory); + server = server.workers(usize::MAX); + assert_eq!(server.worker_count(), usize::MAX); + assert_eq!(server.workers(0).worker_count(), 1); +} diff --git a/src/server/config/tests/tests_preamble.rs b/src/server/config/tests/tests_preamble.rs new file mode 100644 index 00000000..5cf76ed6 --- /dev/null +++ b/src/server/config/tests/tests_preamble.rs @@ -0,0 +1,125 @@ +//! Preamble configuration tests for `WireframeServer`. + +use std::{ + io, + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + time::Duration, +}; + +use bincode::error::DecodeError; +use rstest::rstest; +use tokio::net::{TcpListener, TcpStream}; + +use super::PreambleHandlerKind; +use crate::{ + app::WireframeApp, + server::{ + WireframeServer, + test_util::{TestPreamble, factory, server_with_preamble}, + }, +}; + +#[rstest] +fn test_with_preamble_type_conversion( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) { + let server = WireframeServer::new(factory).with_preamble::(); + assert_eq!( + server.worker_count(), + super::expected_default_worker_count() + ); +} + +#[rstest] +fn test_preamble_timeout_configuration( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, +) { + let timeout = Duration::from_millis(25); + let server = WireframeServer::new(factory.clone()).preamble_timeout(timeout); + assert_eq!(server.preamble_timeout, Some(timeout)); + + let clamped = WireframeServer::new(factory).preamble_timeout(Duration::ZERO); + assert_eq!(clamped.preamble_timeout, Some(Duration::from_millis(1))); +} + +#[rstest] +#[case::success(PreambleHandlerKind::Success)] +#[case::failure(PreambleHandlerKind::Failure)] +#[tokio::test] +async fn test_preamble_handler_registration( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + #[case] handler: PreambleHandlerKind, +) { + let counter = Arc::new(AtomicUsize::new(0)); + let c = counter.clone(); + + let server = server_with_preamble(factory); + let server = match handler { + PreambleHandlerKind::Success => { + server.on_preamble_decode_success(move |_p: &TestPreamble, _| { + let c = c.clone(); + Box::pin(async move { + c.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + }) + } + PreambleHandlerKind::Failure => { + server.on_preamble_decode_failure(move |_err: &DecodeError, _stream| { + let c = c.clone(); + Box::pin(async move { + c.fetch_add(1, Ordering::SeqCst); + Ok::<(), io::Error>(()) + }) + }) + } + }; + + assert_eq!(counter.load(Ordering::SeqCst), 0); + match handler { + PreambleHandlerKind::Success => { + assert!(server.on_preamble_success.is_some()); + let handler = server + .on_preamble_success + .as_ref() + .expect("success handler missing"); + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind listener"); + let addr = listener.local_addr().expect("listener addr"); + let _client = TcpStream::connect(addr) + .await + .expect("client connect failed"); + let (mut stream, _) = listener.accept().await.expect("accept stream"); + let preamble = TestPreamble { + id: 0, + message: String::new(), + }; + handler(&preamble, &mut stream) + .await + .expect("handler failed"); + } + PreambleHandlerKind::Failure => { + assert!(server.on_preamble_failure.is_some()); + let handler = server + .on_preamble_failure + .as_ref() + .expect("failure handler missing"); + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind listener"); + let addr = listener.local_addr().expect("listener addr"); + let _client = TcpStream::connect(addr) + .await + .expect("client connect failed"); + let (mut stream, _) = listener.accept().await.expect("accept stream"); + handler(&DecodeError::UnexpectedEnd { additional: 0 }, &mut stream) + .await + .expect("handler failed"); + } + } + assert_eq!(counter.load(Ordering::SeqCst), 1); +} From 3416e55c1cc606ee8d7cdbbddab5847dbf21192b Mon Sep 17 00:00:00 2001 From: Leynos Date: Thu, 5 Feb 2026 12:26:31 +0000 Subject: [PATCH 2/9] test(server/config): refactor backoff tests to use parameterized cases and add tcp connection setup helper - Replace multiple individual backoff config tests with parameterized test cases for better coverage and conciseness. - Add detailed cases including clamp, swap, and default value scenarios for backoff configuration validation. - Extract and reuse async TCP connection setup in preamble tests to reduce redundancy and improve readability. Co-authored-by: devboxerhub[bot] --- src/server/config/tests/tests_backoff.rs | 172 +++++++++++----------- src/server/config/tests/tests_preamble.rs | 29 ++-- 2 files changed, 99 insertions(+), 102 deletions(-) diff --git a/src/server/config/tests/tests_backoff.rs b/src/server/config/tests/tests_backoff.rs index bb3fb50d..7690649f 100644 --- a/src/server/config/tests/tests_backoff.rs +++ b/src/server/config/tests/tests_backoff.rs @@ -10,15 +10,43 @@ use crate::{ }; #[rstest] -fn test_accept_backoff_configuration( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let cfg = BackoffConfig { +#[case::custom_backoff( + BackoffConfig { initial_delay: Duration::from_millis(5), max_delay: Duration::from_millis(500), - }; - let server = WireframeServer::new(factory).accept_backoff(cfg); - assert_eq!(server.backoff_config, cfg); + }, + Duration::from_millis(5), + Duration::from_millis(500), + "custom backoff" +)] +#[case::custom_initial_delay( + BackoffConfig { + initial_delay: Duration::from_millis(20), + ..BackoffConfig::default() + }, + Duration::from_millis(20), + Duration::from_secs(1), + "custom initial delay" +)] +#[case::custom_max_delay( + BackoffConfig { + max_delay: Duration::from_millis(2000), + ..BackoffConfig::default() + }, + Duration::from_millis(10), + Duration::from_millis(2000), + "custom max delay" +)] +fn test_accept_backoff_configuration_scenarios( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + #[case] config: BackoffConfig, + #[case] expected_initial: Duration, + #[case] expected_max: Duration, + #[case] _description: &'static str, +) { + let server = WireframeServer::new(factory).accept_backoff(config); + assert_eq!(server.backoff_config.initial_delay, expected_initial); + assert_eq!(server.backoff_config.max_delay, expected_max); } /// Behaviour test verifying exponential delay doubling and capping. @@ -70,51 +98,61 @@ fn test_accept_exponential_backoff_doubles_and_caps() { } #[rstest] -fn test_accept_initial_delay_configuration( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let delay = Duration::from_millis(20); - let cfg = BackoffConfig { - initial_delay: delay, - ..BackoffConfig::default() - }; - let server = WireframeServer::new(factory).accept_backoff(cfg); - assert_eq!(server.backoff_config.initial_delay, delay); -} - -#[rstest] -fn test_accept_max_delay_configuration( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let delay = Duration::from_millis(2000); - let cfg = BackoffConfig { - max_delay: delay, - ..BackoffConfig::default() - }; - let server = WireframeServer::new(factory).accept_backoff(cfg); - assert_eq!(server.backoff_config.max_delay, delay); -} - -#[rstest] -fn test_backoff_validation(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let server = WireframeServer::new(factory.clone()).accept_backoff(BackoffConfig { +#[case::clamp_initial_delay( + BackoffConfig { initial_delay: Duration::ZERO, ..BackoffConfig::default() - }); - assert_eq!( - server.backoff_config.initial_delay, - Duration::from_millis(1) - ); - - let server = WireframeServer::new(factory).accept_backoff(BackoffConfig { + }, + Duration::from_millis(1), + Duration::from_secs(1), + "clamp initial delay" +)] +#[case::swap_shorter_max( + BackoffConfig { initial_delay: Duration::from_millis(100), max_delay: Duration::from_millis(50), - }); - assert_eq!( - server.backoff_config.initial_delay, - Duration::from_millis(50) - ); - assert_eq!(server.backoff_config.max_delay, Duration::from_millis(100)); + }, + Duration::from_millis(50), + Duration::from_millis(100), + "swap shorter max" +)] +#[case::swap_with_default_max( + BackoffConfig { + initial_delay: Duration::from_secs(2), + max_delay: Duration::from_secs(1), + }, + Duration::from_secs(1), + Duration::from_secs(2), + "swap with default max" +)] +#[case::swap_small_values( + BackoffConfig { + initial_delay: Duration::from_millis(5), + max_delay: Duration::from_millis(1), + }, + Duration::from_millis(1), + Duration::from_millis(5), + "swap small values" +)] +#[case::clamp_both_zero( + BackoffConfig { + initial_delay: Duration::ZERO, + max_delay: Duration::ZERO, + }, + Duration::from_millis(1), + Duration::from_millis(1), + "clamp both zero" +)] +fn test_backoff_validation_scenarios( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + #[case] config: BackoffConfig, + #[case] expected_initial: Duration, + #[case] expected_max: Duration, + #[case] _description: &'static str, +) { + let server = WireframeServer::new(factory).accept_backoff(config); + assert_eq!(server.backoff_config.initial_delay, expected_initial); + assert_eq!(server.backoff_config.max_delay, expected_max); } #[rstest] @@ -126,41 +164,3 @@ fn test_backoff_default_values(factory: impl Fn() -> WireframeApp + Send + Sync ); assert_eq!(server.backoff_config.max_delay, Duration::from_secs(1)); } - -#[rstest] -fn test_initial_delay_exceeds_default_max( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let cfg = BackoffConfig { - initial_delay: Duration::from_secs(2), - max_delay: Duration::from_secs(1), - }; - let server = WireframeServer::new(factory).accept_backoff(cfg); - assert_eq!(server.backoff_config.initial_delay, Duration::from_secs(1)); - assert_eq!(server.backoff_config.max_delay, Duration::from_secs(2)); -} - -#[rstest] -fn test_accept_backoff_parameter_swapping( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, -) { - let server = WireframeServer::new(factory.clone()).accept_backoff(BackoffConfig { - initial_delay: Duration::from_millis(5), - max_delay: Duration::from_millis(1), - }); - assert_eq!( - server.backoff_config.initial_delay, - Duration::from_millis(1) - ); - assert_eq!(server.backoff_config.max_delay, Duration::from_millis(5)); - - let server = WireframeServer::new(factory).accept_backoff(BackoffConfig { - initial_delay: Duration::ZERO, - max_delay: Duration::ZERO, - }); - assert_eq!( - server.backoff_config.initial_delay, - Duration::from_millis(1) - ); - assert_eq!(server.backoff_config.max_delay, Duration::from_millis(1)); -} diff --git a/src/server/config/tests/tests_preamble.rs b/src/server/config/tests/tests_preamble.rs index 5cf76ed6..79dbee80 100644 --- a/src/server/config/tests/tests_preamble.rs +++ b/src/server/config/tests/tests_preamble.rs @@ -22,6 +22,18 @@ use crate::{ }, }; +async fn setup_tcp_connection() -> TcpStream { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("bind listener"); + let addr = listener.local_addr().expect("listener addr"); + let _client = TcpStream::connect(addr) + .await + .expect("client connect failed"); + let (stream, _) = listener.accept().await.expect("accept stream"); + stream +} + #[rstest] fn test_with_preamble_type_conversion( factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, @@ -79,6 +91,7 @@ async fn test_preamble_handler_registration( }; assert_eq!(counter.load(Ordering::SeqCst), 0); + let mut stream = setup_tcp_connection().await; match handler { PreambleHandlerKind::Success => { assert!(server.on_preamble_success.is_some()); @@ -86,14 +99,6 @@ async fn test_preamble_handler_registration( .on_preamble_success .as_ref() .expect("success handler missing"); - let listener = TcpListener::bind("127.0.0.1:0") - .await - .expect("bind listener"); - let addr = listener.local_addr().expect("listener addr"); - let _client = TcpStream::connect(addr) - .await - .expect("client connect failed"); - let (mut stream, _) = listener.accept().await.expect("accept stream"); let preamble = TestPreamble { id: 0, message: String::new(), @@ -108,14 +113,6 @@ async fn test_preamble_handler_registration( .on_preamble_failure .as_ref() .expect("failure handler missing"); - let listener = TcpListener::bind("127.0.0.1:0") - .await - .expect("bind listener"); - let addr = listener.local_addr().expect("listener addr"); - let _client = TcpStream::connect(addr) - .await - .expect("client connect failed"); - let (mut stream, _) = listener.accept().await.expect("accept stream"); handler(&DecodeError::UnexpectedEnd { additional: 0 }, &mut stream) .await .expect("handler failed"); From 488cf83971db15f2b013b248d711462e34aeca98 Mon Sep 17 00:00:00 2001 From: Leynos Date: Thu, 5 Feb 2026 12:36:37 +0000 Subject: [PATCH 3/9] test(server/config): refactor backoff tests to use BackoffScenario struct - Introduced BackoffScenario struct to encapsulate test parameters - Updated test cases to use BackoffScenario instances for clarity - Simplified assertions by using scenario description messages - Improved maintainability and readability of backoff config tests Co-authored-by: devboxerhub[bot] --- src/server/config/tests/tests_backoff.rs | 178 ++++++++++++++--------- 1 file changed, 108 insertions(+), 70 deletions(-) diff --git a/src/server/config/tests/tests_backoff.rs b/src/server/config/tests/tests_backoff.rs index 7690649f..c6dd84e4 100644 --- a/src/server/config/tests/tests_backoff.rs +++ b/src/server/config/tests/tests_backoff.rs @@ -9,44 +9,65 @@ use crate::{ server::{BackoffConfig, WireframeServer, test_util::factory}, }; +#[derive(Debug)] +struct BackoffScenario { + config: BackoffConfig, + expected_initial: Duration, + expected_max: Duration, + description: &'static str, +} + #[rstest] #[case::custom_backoff( - BackoffConfig { - initial_delay: Duration::from_millis(5), - max_delay: Duration::from_millis(500), - }, - Duration::from_millis(5), - Duration::from_millis(500), - "custom backoff" + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::from_millis(5), + max_delay: Duration::from_millis(500), + }, + expected_initial: Duration::from_millis(5), + expected_max: Duration::from_millis(500), + description: "custom backoff", + } )] #[case::custom_initial_delay( - BackoffConfig { - initial_delay: Duration::from_millis(20), - ..BackoffConfig::default() - }, - Duration::from_millis(20), - Duration::from_secs(1), - "custom initial delay" + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::from_millis(20), + ..BackoffConfig::default() + }, + expected_initial: Duration::from_millis(20), + expected_max: Duration::from_secs(1), + description: "custom initial delay", + } )] #[case::custom_max_delay( - BackoffConfig { - max_delay: Duration::from_millis(2000), - ..BackoffConfig::default() - }, - Duration::from_millis(10), - Duration::from_millis(2000), - "custom max delay" + BackoffScenario { + config: BackoffConfig { + max_delay: Duration::from_millis(2000), + ..BackoffConfig::default() + }, + expected_initial: Duration::from_millis(10), + expected_max: Duration::from_millis(2000), + description: "custom max delay", + } )] fn test_accept_backoff_configuration_scenarios( factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - #[case] config: BackoffConfig, - #[case] expected_initial: Duration, - #[case] expected_max: Duration, - #[case] _description: &'static str, + #[case] scenario: BackoffScenario, ) { - let server = WireframeServer::new(factory).accept_backoff(config); - assert_eq!(server.backoff_config.initial_delay, expected_initial); - assert_eq!(server.backoff_config.max_delay, expected_max); + let server = WireframeServer::new(factory).accept_backoff(scenario.config); + assert_eq!( + server.backoff_config.initial_delay, + scenario.expected_initial, + "scenario: {}", + scenario.description + ); + assert_eq!( + server.backoff_config.max_delay, + scenario.expected_max, + "scenario: {}", + scenario.description + ); } /// Behaviour test verifying exponential delay doubling and capping. @@ -99,60 +120,77 @@ fn test_accept_exponential_backoff_doubles_and_caps() { #[rstest] #[case::clamp_initial_delay( - BackoffConfig { - initial_delay: Duration::ZERO, - ..BackoffConfig::default() - }, - Duration::from_millis(1), - Duration::from_secs(1), - "clamp initial delay" + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::ZERO, + ..BackoffConfig::default() + }, + expected_initial: Duration::from_millis(1), + expected_max: Duration::from_secs(1), + description: "clamp initial delay", + } )] #[case::swap_shorter_max( - BackoffConfig { - initial_delay: Duration::from_millis(100), - max_delay: Duration::from_millis(50), - }, - Duration::from_millis(50), - Duration::from_millis(100), - "swap shorter max" + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::from_millis(100), + max_delay: Duration::from_millis(50), + }, + expected_initial: Duration::from_millis(50), + expected_max: Duration::from_millis(100), + description: "swap shorter max", + } )] #[case::swap_with_default_max( - BackoffConfig { - initial_delay: Duration::from_secs(2), - max_delay: Duration::from_secs(1), - }, - Duration::from_secs(1), - Duration::from_secs(2), - "swap with default max" + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::from_secs(2), + max_delay: Duration::from_secs(1), + }, + expected_initial: Duration::from_secs(1), + expected_max: Duration::from_secs(2), + description: "swap with default max", + } )] #[case::swap_small_values( - BackoffConfig { - initial_delay: Duration::from_millis(5), - max_delay: Duration::from_millis(1), - }, - Duration::from_millis(1), - Duration::from_millis(5), - "swap small values" + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::from_millis(5), + max_delay: Duration::from_millis(1), + }, + expected_initial: Duration::from_millis(1), + expected_max: Duration::from_millis(5), + description: "swap small values", + } )] #[case::clamp_both_zero( - BackoffConfig { - initial_delay: Duration::ZERO, - max_delay: Duration::ZERO, - }, - Duration::from_millis(1), - Duration::from_millis(1), - "clamp both zero" + BackoffScenario { + config: BackoffConfig { + initial_delay: Duration::ZERO, + max_delay: Duration::ZERO, + }, + expected_initial: Duration::from_millis(1), + expected_max: Duration::from_millis(1), + description: "clamp both zero", + } )] fn test_backoff_validation_scenarios( factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - #[case] config: BackoffConfig, - #[case] expected_initial: Duration, - #[case] expected_max: Duration, - #[case] _description: &'static str, + #[case] scenario: BackoffScenario, ) { - let server = WireframeServer::new(factory).accept_backoff(config); - assert_eq!(server.backoff_config.initial_delay, expected_initial); - assert_eq!(server.backoff_config.max_delay, expected_max); + let server = WireframeServer::new(factory).accept_backoff(scenario.config); + assert_eq!( + server.backoff_config.initial_delay, + scenario.expected_initial, + "scenario: {}", + scenario.description + ); + assert_eq!( + server.backoff_config.max_delay, + scenario.expected_max, + "scenario: {}", + scenario.description + ); } #[rstest] From 27b0aa62d5847f968f9633034eb40edb4ef03e0f Mon Sep 17 00:00:00 2001 From: Leynos Date: Thu, 5 Feb 2026 13:46:49 +0000 Subject: [PATCH 4/9] style: format backoff tests Run cargo fmt to align parameterised test assertions with project formatting rules. --- src/server/config/tests/tests_backoff.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/server/config/tests/tests_backoff.rs b/src/server/config/tests/tests_backoff.rs index c6dd84e4..ccffc3c5 100644 --- a/src/server/config/tests/tests_backoff.rs +++ b/src/server/config/tests/tests_backoff.rs @@ -57,14 +57,12 @@ fn test_accept_backoff_configuration_scenarios( ) { let server = WireframeServer::new(factory).accept_backoff(scenario.config); assert_eq!( - server.backoff_config.initial_delay, - scenario.expected_initial, + server.backoff_config.initial_delay, scenario.expected_initial, "scenario: {}", scenario.description ); assert_eq!( - server.backoff_config.max_delay, - scenario.expected_max, + server.backoff_config.max_delay, scenario.expected_max, "scenario: {}", scenario.description ); @@ -180,14 +178,12 @@ fn test_backoff_validation_scenarios( ) { let server = WireframeServer::new(factory).accept_backoff(scenario.config); assert_eq!( - server.backoff_config.initial_delay, - scenario.expected_initial, + server.backoff_config.initial_delay, scenario.expected_initial, "scenario: {}", scenario.description ); assert_eq!( - server.backoff_config.max_delay, - scenario.expected_max, + server.backoff_config.max_delay, scenario.expected_max, "scenario: {}", scenario.description ); From 053c82d7259ec90a07b61d1144e73c768c791521 Mon Sep 17 00:00:00 2001 From: Leynos Date: Thu, 5 Feb 2026 11:19:15 +0000 Subject: [PATCH 5/9] refactor(app builder): modularize WireframeApp and remove deprecated client builder Removed the monolithic src/app/builder.rs file and replaced it with a modular builder implementation split across multiple files (codec.rs, config.rs, core.rs, lifecycle.rs, mod.rs, protocol.rs, routing.rs, state.rs). This improves maintainability and readability by organizing related functionality into dedicated modules. Deleted the deprecated src/client/builder.rs and related client builder modules, cleaning up the codebase. Additionally, extracted frame handling helpers into a frame_handling module with submodules for core, reassembly, and response forwarding logic, keeping connection.rs smaller and more focused. Overall, this commit significantly improves project structure by organizing app builder and frame handling code into well-defined modules, and removing outdated client builder code. Co-authored-by: devboxerhub[bot] --- src/server/config/tests/tests_backoff.rs | 98 +++++++---------------- src/server/config/tests/tests_preamble.rs | 53 ++++++------ 2 files changed, 56 insertions(+), 95 deletions(-) diff --git a/src/server/config/tests/tests_backoff.rs b/src/server/config/tests/tests_backoff.rs index ccffc3c5..429e791f 100644 --- a/src/server/config/tests/tests_backoff.rs +++ b/src/server/config/tests/tests_backoff.rs @@ -51,72 +51,6 @@ struct BackoffScenario { description: "custom max delay", } )] -fn test_accept_backoff_configuration_scenarios( - factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, - #[case] scenario: BackoffScenario, -) { - let server = WireframeServer::new(factory).accept_backoff(scenario.config); - assert_eq!( - server.backoff_config.initial_delay, scenario.expected_initial, - "scenario: {}", - scenario.description - ); - assert_eq!( - server.backoff_config.max_delay, scenario.expected_max, - "scenario: {}", - scenario.description - ); -} - -/// Behaviour test verifying exponential delay doubling and capping. -#[test] -fn test_accept_exponential_backoff_doubles_and_caps() { - use std::{ - thread, - time::{Duration, Instant}, - }; - - let initial = Duration::from_millis(10); - let max = Duration::from_millis(80); - let mut backoff = initial; - let mut delays = Vec::new(); - let attempts = 5; - - let start = Instant::now(); - let mut last = start; - - for _i in 0..attempts { - thread::sleep(backoff); - let now = Instant::now(); - let elapsed = now.duration_since(last); - delays.push(elapsed); - last = now; - - backoff = std::cmp::min(backoff * 2, max); - } - - let expected_delays = [ - initial, - std::cmp::min(initial * 2, max), - std::cmp::min(initial * 4, max), - std::cmp::min(initial * 8, max), - max, - ]; - - for (i, (actual, expected)) in delays.iter().zip(expected_delays.iter()).enumerate() { - assert!( - *actual >= *expected, - "Delay {i} was {actual:?}, expected at least {expected:?}" - ); - let max_expected = *expected + Duration::from_millis(20); - assert!( - *actual < max_expected, - "Delay {i} was {actual:?}, expected less than {max_expected:?}" - ); - } -} - -#[rstest] #[case::clamp_initial_delay( BackoffScenario { config: BackoffConfig { @@ -172,7 +106,7 @@ fn test_accept_exponential_backoff_doubles_and_caps() { description: "clamp both zero", } )] -fn test_backoff_validation_scenarios( +fn test_backoff_configuration_scenarios( factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, #[case] scenario: BackoffScenario, ) { @@ -189,6 +123,36 @@ fn test_backoff_validation_scenarios( ); } +/// Behaviour test verifying exponential delay doubling and capping. +#[test] +fn test_accept_exponential_backoff_doubles_and_caps() { + let initial = Duration::from_millis(10); + let max = Duration::from_millis(80); + let attempts = 5; + let delays = backoff_sequence(initial, max, attempts); + let expected_delays = vec![ + initial, + std::cmp::min(initial * 2, max), + std::cmp::min(initial * 4, max), + std::cmp::min(initial * 8, max), + max, + ]; + + assert_eq!(delays, expected_delays); +} + +fn backoff_sequence(initial: Duration, max: Duration, attempts: usize) -> Vec { + let mut backoff = initial; + let mut delays = Vec::with_capacity(attempts); + + for _ in 0..attempts { + delays.push(backoff); + backoff = std::cmp::min(backoff * 2, max); + } + + delays +} + #[rstest] fn test_backoff_default_values(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { let server = WireframeServer::new(factory); diff --git a/src/server/config/tests/tests_preamble.rs b/src/server/config/tests/tests_preamble.rs index 79dbee80..16c459f1 100644 --- a/src/server/config/tests/tests_preamble.rs +++ b/src/server/config/tests/tests_preamble.rs @@ -2,6 +2,7 @@ use std::{ io, + net::TcpListener as StdTcpListener, sync::{ Arc, atomic::{AtomicUsize, Ordering}, @@ -10,7 +11,7 @@ use std::{ }; use bincode::error::DecodeError; -use rstest::rstest; +use rstest::{fixture, rstest}; use tokio::net::{TcpListener, TcpStream}; use super::PreambleHandlerKind; @@ -18,31 +19,16 @@ use crate::{ app::WireframeApp, server::{ WireframeServer, + default_worker_count, test_util::{TestPreamble, factory, server_with_preamble}, }, }; - -async fn setup_tcp_connection() -> TcpStream { - let listener = TcpListener::bind("127.0.0.1:0") - .await - .expect("bind listener"); - let addr = listener.local_addr().expect("listener addr"); - let _client = TcpStream::connect(addr) - .await - .expect("client connect failed"); - let (stream, _) = listener.accept().await.expect("accept stream"); - stream -} - #[rstest] fn test_with_preamble_type_conversion( factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, ) { let server = WireframeServer::new(factory).with_preamble::(); - assert_eq!( - server.worker_count(), - super::expected_default_worker_count() - ); + assert_eq!(server.worker_count(), default_worker_count()); } #[rstest] @@ -64,7 +50,8 @@ fn test_preamble_timeout_configuration( async fn test_preamble_handler_registration( factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, #[case] handler: PreambleHandlerKind, -) { + bound_listener: io::Result, +) -> io::Result<()> { let counter = Arc::new(AtomicUsize::new(0)); let c = counter.clone(); @@ -91,10 +78,9 @@ async fn test_preamble_handler_registration( }; assert_eq!(counter.load(Ordering::SeqCst), 0); - let mut stream = setup_tcp_connection().await; + let mut stream = accept_stream(bound_listener?).await?; match handler { PreambleHandlerKind::Success => { - assert!(server.on_preamble_success.is_some()); let handler = server .on_preamble_success .as_ref() @@ -103,20 +89,31 @@ async fn test_preamble_handler_registration( id: 0, message: String::new(), }; - handler(&preamble, &mut stream) - .await - .expect("handler failed"); + handler(&preamble, &mut stream).await?; } PreambleHandlerKind::Failure => { - assert!(server.on_preamble_failure.is_some()); let handler = server .on_preamble_failure .as_ref() .expect("failure handler missing"); - handler(&DecodeError::UnexpectedEnd { additional: 0 }, &mut stream) - .await - .expect("handler failed"); + handler(&DecodeError::UnexpectedEnd { additional: 0 }, &mut stream).await?; } } assert_eq!(counter.load(Ordering::SeqCst), 1); + Ok(()) +} + +#[fixture] +fn bound_listener() -> io::Result { + let addr = "127.0.0.1:0"; + StdTcpListener::bind(addr) +} + +async fn accept_stream(listener: StdTcpListener) -> io::Result { + let addr = listener.local_addr()?; + listener.set_nonblocking(true)?; + let listener = TcpListener::from_std(listener)?; + let _client = TcpStream::connect(addr).await?; + let (stream, _) = listener.accept().await?; + Ok(stream) } From 5c7b542062342f0871a6727eb0733775fbc56aa4 Mon Sep 17 00:00:00 2001 From: Leynos Date: Mon, 9 Feb 2026 11:58:06 +0000 Subject: [PATCH 6/9] refactor(frame_handling): streamline response handling with ? operator and improve logging - Replace nested matches with `?` operator in response serialization and sending - Enhance logging messages with consistent formatting and variable interpolation - Add documentation comments for DeserFailureTracker and ResponseContext - Improve error propagation and metric increments - Add and enhance tests with rstest fixtures and parameterization - General code cleanup for readability and safety (e.g., saturating_add) Co-authored-by: devboxerhub[bot] --- src/app/builder/codec.rs | 2 +- src/app/builder/core.rs | 4 +- src/app/frame_handling/core.rs | 4 +- src/app/frame_handling/response.rs | 40 +++++------ src/app/frame_handling/tests.rs | 72 +++++++++++++++----- src/client/builder/connect.rs | 8 ++- src/client/builder/lifecycle.rs | 7 +- src/codec/recovery/config.rs | 2 +- src/codec/recovery/hook.rs | 3 +- src/codec/recovery/tests.rs | 50 ++++++++++---- src/extractor/connection_info.rs | 5 +- src/extractor/error.rs | 37 ++-------- src/extractor/message.rs | 4 +- src/extractor/request.rs | 14 ++-- src/server/config/mod.rs | 2 +- src/server/config/tests/mod.rs | 5 -- src/server/config/tests/tests_basic.rs | 13 ++-- src/server/config/tests/tests_integration.rs | 4 +- src/server/mod.rs | 4 ++ 19 files changed, 167 insertions(+), 113 deletions(-) diff --git a/src/app/builder/codec.rs b/src/app/builder/codec.rs index f91ebc09..934c65dd 100644 --- a/src/app/builder/codec.rs +++ b/src/app/builder/codec.rs @@ -53,7 +53,7 @@ where impl WireframeApp where - S: Serializer + Default + Send + Sync, + S: Serializer + Send + Sync, C: Send + 'static, E: Packet, { diff --git a/src/app/builder/core.rs b/src/app/builder/core.rs index cb4a03ef..8ebc6795 100644 --- a/src/app/builder/core.rs +++ b/src/app/builder/core.rs @@ -10,7 +10,7 @@ use tokio::sync::{OnceCell, mpsc}; use crate::{ app::{ - builder_defaults::default_fragmentation, + builder_defaults::{DEFAULT_READ_TIMEOUT_MS, default_fragmentation}, envelope::{Envelope, Packet}, error::Result, lifecycle::{ConnectionSetup, ConnectionTeardown}, @@ -73,7 +73,7 @@ where protocol: None, push_dlq: None, codec, - read_timeout_ms: 100, + read_timeout_ms: DEFAULT_READ_TIMEOUT_MS, fragmentation: default_fragmentation(max_frame_length), message_assembler: None, } diff --git a/src/app/frame_handling/core.rs b/src/app/frame_handling/core.rs index 4b0be8bf..2fff9b7e 100644 --- a/src/app/frame_handling/core.rs +++ b/src/app/frame_handling/core.rs @@ -12,6 +12,7 @@ use crate::{ serializer::Serializer, }; +/// Tracks deserialization failures and enforces a maximum error threshold. pub(super) struct DeserFailureTracker<'a> { count: &'a mut u32, limit: u32, @@ -26,7 +27,7 @@ impl<'a> DeserFailureTracker<'a> { context: &str, err: impl std::fmt::Debug, ) -> io::Result<()> { - *self.count += 1; + *self.count = (*self.count).saturating_add(1); warn!("{context}: correlation_id={correlation_id:?}, error={err:?}"); crate::metrics::inc_deser_errors(); if *self.count >= self.limit { @@ -39,6 +40,7 @@ impl<'a> DeserFailureTracker<'a> { } } +/// Bundles shared dependencies for response forwarding. pub(crate) struct ResponseContext<'a, S, W, F> where S: Serializer + Send + Sync, diff --git a/src/app/frame_handling/response.rs b/src/app/frame_handling/response.rs index a8ef2c02..d135c489 100644 --- a/src/app/frame_handling/response.rs +++ b/src/app/frame_handling/response.rs @@ -38,8 +38,9 @@ where Ok(resp) => resp, Err(e) => { warn!( - "handler error: id={}, correlation_id={:?}, error={e:?}", - env.id, env.correlation_id + "handler error: id={id}, correlation_id={correlation_id:?}, error={e:?}", + id = env.id, + correlation_id = env.correlation_id ); crate::metrics::inc_handler_errors(); return Ok(()); @@ -49,22 +50,11 @@ where let parts = PacketParts::new(env.id, resp.correlation_id(), resp.into_inner()) .inherit_correlation(env.correlation_id); let correlation_id = parts.correlation_id(); - let Ok(responses) = fragment_responses(ctx.fragmentation, parts, env.id, correlation_id) else { - return Ok(()); // already logged - }; + let responses = fragment_responses(ctx.fragmentation, parts, env.id, correlation_id)?; for response in responses { - let Ok(bytes) = serialize_response(ctx.serializer, &response, env.id, correlation_id) - else { - break; // already logged - }; - - if send_response_payload::(ctx.codec, ctx.framed, Bytes::from(bytes), &response) - .await - .is_err() - { - break; - } + let bytes = serialize_response(ctx.serializer, &response, env.id, correlation_id)?; + send_response_payload::(ctx.codec, ctx.framed, Bytes::from(bytes), &response).await?; } Ok(()) @@ -82,8 +72,13 @@ fn fragment_responses( Ok(fragmented) => Ok(fragmented), Err(err) => { warn!( - "failed to fragment response: id={id}, correlation_id={correlation_id:?}, \ - error={err:?}" + concat!( + "failed to fragment response: id={id}, correlation_id={correlation_id:?}, ", + "error={err:?}" + ), + id = id, + correlation_id = correlation_id, + err = err ); crate::metrics::inc_handler_errors(); Err(io::Error::other("fragmentation failed")) @@ -103,8 +98,13 @@ fn serialize_response( Ok(bytes) => Ok(bytes), Err(e) => { warn!( - "failed to serialize response: id={id}, correlation_id={correlation_id:?}, \ - error={e:?}" + concat!( + "failed to serialize response: id={id}, correlation_id={correlation_id:?}, ", + "error={e:?}" + ), + id = id, + correlation_id = correlation_id, + e = e ); crate::metrics::inc_handler_errors(); Err(io::Error::other("serialization failed")) diff --git a/src/app/frame_handling/tests.rs b/src/app/frame_handling/tests.rs index abaebe1c..e693f68b 100644 --- a/src/app/frame_handling/tests.rs +++ b/src/app/frame_handling/tests.rs @@ -7,6 +7,8 @@ use std::sync::{ use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures::StreamExt; +use rstest::{fixture, rstest}; +use tokio::io::DuplexStream; use tokio_util::codec::{Decoder, Encoder}; use super::{ResponseContext, response::send_response_payload}; @@ -15,13 +17,14 @@ use crate::{ codec::FrameCodec, }; -/// Test codec that wraps payloads with a distinctive tag byte. +/// Test frame carrying a tag byte and payload. #[derive(Clone, Debug)] struct TestFrame { tag: u8, payload: Vec, } +/// Test codec that wraps payloads with a distinctive tag byte. #[derive(Clone, Debug)] struct TestCodec { max_frame_length: usize, @@ -126,13 +129,46 @@ impl FrameCodec for TestCodec { fn max_frame_length(&self) -> usize { self.max_frame_length } } -/// Verify `send_response_payload` uses `F::wrap_payload` to frame responses. -#[tokio::test] -async fn send_response_payload_wraps_with_codec() { - let codec = TestCodec::new(64); +struct TestHarness { + codec: TestCodec, + framed: tokio_util::codec::Framed>, + client: DuplexStream, +} + +#[fixture] +fn harness() -> TestHarness { + let max_frame_length = 64; + build_harness(max_frame_length) +} + +#[fixture] +fn small_harness() -> TestHarness { + let max_frame_length = 4; + build_harness(max_frame_length) +} + +fn build_harness(max_frame_length: usize) -> TestHarness { + let codec = TestCodec::new(max_frame_length); let (client, server) = tokio::io::duplex(256); let combined = CombinedCodec::new(codec.decoder(), codec.encoder()); - let mut framed = tokio_util::codec::Framed::new(server, combined); + let framed = tokio_util::codec::Framed::new(server, combined); + + TestHarness { + codec, + framed, + client, + } +} + +/// Verify `send_response_payload` uses `F::wrap_payload` to frame responses. +#[rstest] +#[tokio::test] +async fn send_response_payload_wraps_with_codec(harness: TestHarness) { + let TestHarness { + codec, + mut framed, + client, + } = harness; let payload = vec![1, 2, 3, 4]; let response = Envelope::new(1, Some(99), payload.clone()); @@ -161,14 +197,16 @@ async fn send_response_payload_wraps_with_codec() { } /// Verify `ResponseContext` fields are accessible and usable. +#[rstest] #[tokio::test] -async fn response_context_holds_references() { +async fn response_context_holds_references(harness: TestHarness) { use crate::serializer::BincodeSerializer; - let codec = TestCodec::new(64); - let (_client, server) = tokio::io::duplex(256); - let combined = CombinedCodec::new(codec.decoder(), codec.encoder()); - let mut framed = tokio_util::codec::Framed::new(server, combined); + let TestHarness { + codec, + mut framed, + client: _client, + } = harness; let serializer = BincodeSerializer; let mut fragmentation: Option = None; @@ -184,12 +222,14 @@ async fn response_context_holds_references() { } /// Verify `send_response_payload` returns error on send failure. +#[rstest] #[tokio::test] -async fn send_response_payload_returns_error_on_failure() { - let codec = TestCodec::new(4); // Small limit to trigger failure - let (_client, server) = tokio::io::duplex(256); - let combined = CombinedCodec::new(codec.decoder(), codec.encoder()); - let mut framed = tokio_util::codec::Framed::new(server, combined); +async fn send_response_payload_returns_error_on_failure(small_harness: TestHarness) { + let TestHarness { + codec, + mut framed, + client: _client, + } = small_harness; // Payload exceeds max_frame_length, so encode will fail let oversized_payload = vec![0u8; 100]; diff --git a/src/client/builder/connect.rs b/src/client/builder/connect.rs index ffb12c35..f72ed4b8 100644 --- a/src/client/builder/connect.rs +++ b/src/client/builder/connect.rs @@ -13,6 +13,8 @@ use crate::{ serializer::Serializer, }; +const INITIAL_READ_BUFFER_CAPACITY: usize = 64 * 1024; + impl WireframeClientBuilder where S: Serializer + Send + Sync, @@ -70,8 +72,10 @@ where let codec_config = self.codec_config; let codec = codec_config.build_codec(); let mut framed = Framed::new(RewindStream::new(leftover, stream), codec); - let initial_read_buffer_capacity = - core::cmp::min(64 * 1024, codec_config.max_frame_length_value()); + let initial_read_buffer_capacity = core::cmp::min( + INITIAL_READ_BUFFER_CAPACITY, + codec_config.max_frame_length_value(), + ); framed .read_buffer_mut() .reserve(initial_read_buffer_capacity); diff --git a/src/client/builder/lifecycle.rs b/src/client/builder/lifecycle.rs index a9269f2f..1ffcbe86 100644 --- a/src/client/builder/lifecycle.rs +++ b/src/client/builder/lifecycle.rs @@ -110,8 +110,11 @@ where /// ``` /// use wireframe::client::WireframeClientBuilder; /// - /// let builder = WireframeClientBuilder::new().on_error(|err| async move { - /// eprintln!("Client error: {err}"); + /// let builder = WireframeClientBuilder::new().on_error(|err| { + /// let message = err.to_string(); + /// async move { + /// eprintln!("Client error: {message}"); + /// } /// }); /// let _ = builder; /// ``` diff --git a/src/codec/recovery/config.rs b/src/codec/recovery/config.rs index 5526fa3b..69e75958 100644 --- a/src/codec/recovery/config.rs +++ b/src/codec/recovery/config.rs @@ -20,7 +20,7 @@ use std::time::Duration; /// /// assert_eq!(config.max_consecutive_drops, 5); /// ``` -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct RecoveryConfig { /// Maximum consecutive dropped frames before escalating to disconnect. /// diff --git a/src/codec/recovery/hook.rs b/src/codec/recovery/hook.rs index 80bce25b..aec5d1e6 100644 --- a/src/codec/recovery/hook.rs +++ b/src/codec/recovery/hook.rs @@ -49,8 +49,7 @@ pub trait RecoveryPolicyHook: Send + Sync { /// /// The default implementation delegates to /// [`CodecError::default_recovery_policy`]. - fn recovery_policy(&self, error: &CodecError, ctx: &CodecErrorContext) -> RecoveryPolicy { - let _ = ctx; + fn recovery_policy(&self, error: &CodecError, _ctx: &CodecErrorContext) -> RecoveryPolicy { error.default_recovery_policy() } diff --git a/src/codec/recovery/tests.rs b/src/codec/recovery/tests.rs index 6b59e328..4e7cc3f5 100644 --- a/src/codec/recovery/tests.rs +++ b/src/codec/recovery/tests.rs @@ -2,9 +2,23 @@ use std::{io, time::Duration}; +use rstest::{fixture, rstest}; + use super::*; use crate::codec::CodecError; +#[fixture] +fn default_hook() -> DefaultRecoveryPolicy { + std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst); + DefaultRecoveryPolicy +} + +#[fixture] +fn default_ctx() -> CodecErrorContext { + std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst); + CodecErrorContext::new() +} + #[test] fn recovery_policy_default_is_drop() { assert_eq!(RecoveryPolicy::default(), RecoveryPolicy::Drop); @@ -29,32 +43,42 @@ fn context_with_peer_address() { assert_eq!(ctx.peer_address, Some(addr)); } -#[test] -fn default_recovery_policy_delegates_to_error() { +#[rstest] +fn default_recovery_policy_delegates_to_error( + default_hook: DefaultRecoveryPolicy, + default_ctx: CodecErrorContext, +) { use crate::codec::error::{EofError, FramingError}; - let hook = DefaultRecoveryPolicy; - let ctx = CodecErrorContext::new(); - // Check various error types let err = CodecError::Framing(FramingError::OversizedFrame { size: 100, max: 50 }); - assert_eq!(hook.recovery_policy(&err, &ctx), RecoveryPolicy::Drop); + assert_eq!( + default_hook.recovery_policy(&err, &default_ctx), + RecoveryPolicy::Drop + ); let err = CodecError::Io(io::Error::other("test")); - assert_eq!(hook.recovery_policy(&err, &ctx), RecoveryPolicy::Disconnect); + assert_eq!( + default_hook.recovery_policy(&err, &default_ctx), + RecoveryPolicy::Disconnect + ); let err = CodecError::Eof(EofError::CleanClose); - assert_eq!(hook.recovery_policy(&err, &ctx), RecoveryPolicy::Disconnect); + assert_eq!( + default_hook.recovery_policy(&err, &default_ctx), + RecoveryPolicy::Disconnect + ); } -#[test] -fn default_quarantine_duration_is_30_seconds() { - let hook = DefaultRecoveryPolicy; - let ctx = CodecErrorContext::new(); +#[rstest] +fn default_quarantine_duration_is_30_seconds( + default_hook: DefaultRecoveryPolicy, + default_ctx: CodecErrorContext, +) { let err = CodecError::Io(io::Error::other("test")); assert_eq!( - hook.quarantine_duration(&err, &ctx), + default_hook.quarantine_duration(&err, &default_ctx), Duration::from_secs(30) ); } diff --git a/src/extractor/connection_info.rs b/src/extractor/connection_info.rs index f3879d81..9ac30659 100644 --- a/src/extractor/connection_info.rs +++ b/src/extractor/connection_info.rs @@ -20,9 +20,10 @@ impl ConnectionInfo { /// /// use wireframe::extractor::{ConnectionInfo, FromMessageRequest, MessageRequest, Payload}; /// - /// let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap(); + /// let addr: SocketAddr = "127.0.0.1:8080".parse().expect("valid socket address"); /// let req = MessageRequest::new().with_peer_addr(Some(addr)); - /// let info = ConnectionInfo::from_message_request(&req, &mut Payload::default()).unwrap(); + /// let info = ConnectionInfo::from_message_request(&req, &mut Payload::default()) + /// .expect("connection info extraction should succeed"); /// assert_eq!(info.peer_addr(), Some(addr)); /// ``` #[must_use] diff --git a/src/extractor/error.rs b/src/extractor/error.rs index c7d0a663..6d04be73 100644 --- a/src/extractor/error.rs +++ b/src/extractor/error.rs @@ -1,48 +1,25 @@ //! Error types for built-in extractors. +use thiserror::Error; + /// Errors that can occur when extracting built-in types. /// /// This enum is marked `#[non_exhaustive]` so more variants may be added in /// the future without breaking changes. -#[derive(Debug)] +#[derive(Debug, Error)] #[non_exhaustive] pub enum ExtractError { /// No shared state of the requested type was found. + #[error("no shared state registered for {0}")] MissingState(&'static str), /// Failed to decode the message payload. - InvalidPayload(bincode::error::DecodeError), + #[error("failed to decode payload: {0}")] + InvalidPayload(#[source] bincode::error::DecodeError), /// No streaming body was available for this request. /// /// This occurs when: /// - The request was not configured for streaming consumption /// - The stream was already consumed by another extractor + #[error("no streaming body available for this request")] MissingBodyStream, } - -impl std::fmt::Display for ExtractError { - /// Formats the `ExtractError` for display purposes. - /// - /// Displays a descriptive message for missing shared state or payload decoding errors. - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::MissingState(ty) => write!(f, "no shared state registered for {ty}"), - Self::InvalidPayload(e) => write!(f, "failed to decode payload: {e}"), - Self::MissingBodyStream => { - write!(f, "no streaming body available for this request") - } - } - } -} - -impl std::error::Error for ExtractError { - /// Returns the underlying error if this is an `InvalidPayload` variant. - /// - /// # Returns - /// An optional reference to the underlying decode error, or `None` if not applicable. - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Self::InvalidPayload(e) => Some(e), - _ => None, - } - } -} diff --git a/src/extractor/message.rs b/src/extractor/message.rs index 9e818058..153a3204 100644 --- a/src/extractor/message.rs +++ b/src/extractor/message.rs @@ -1,4 +1,4 @@ -//! Message extractor for deserialised payloads. +//! Message extractor for deserialized payloads. use super::{ExtractError, FromMessageRequest, MessageRequest, Payload}; use crate::message::Message as WireMessage; @@ -8,7 +8,7 @@ use crate::message::Message as WireMessage; pub struct Message(T); impl Message { - /// Consumes the extractor and returns the inner deserialised message value. + /// Consumes the extractor and returns the inner deserialized message value. #[must_use] pub fn into_inner(self) -> T { self.0 } } diff --git a/src/extractor/request.rs b/src/extractor/request.rs index e569dee2..1b372595 100644 --- a/src/extractor/request.rs +++ b/src/extractor/request.rs @@ -22,7 +22,7 @@ pub struct MessageRequest { /// /// Values are keyed by their [`TypeId`]. Registering additional /// state of the same type will replace the previous entry. - pub app_data: HashMap>, + pub(crate) app_data: HashMap>, /// Optional streaming body for handlers that opt into streaming consumption. /// /// When present, the [`StreamingBody`](crate::extractor::StreamingBody) @@ -48,7 +48,9 @@ impl MessageRequest { /// /// use wireframe::extractor::MessageRequest; /// - /// let req = MessageRequest::new().with_peer_addr(Some("127.0.0.1:8080".parse().unwrap())); + /// let req = MessageRequest::new().with_peer_addr(Some( + /// "127.0.0.1:8080".parse().expect("valid socket address"), + /// )); /// assert!(req.peer_addr.is_some()); /// ``` #[must_use] @@ -69,12 +71,14 @@ impl MessageRequest { /// extractor::{MessageRequest, SharedState}, /// }; /// - /// let _app = WireframeApp::new().unwrap().app_data(5u32); + /// let _app = WireframeApp::new() + /// .expect("failed to initialize app") + /// .app_data(5u32); /// // The framework populates the request with application data. /// # let mut req = MessageRequest::default(); /// # req.insert_state(5u32); /// let val: Option> = req.state(); - /// assert_eq!(*val.unwrap(), 5); + /// assert_eq!(*val.expect("shared state missing"), 5); /// ``` #[must_use] pub fn state(&self) -> Option> @@ -99,7 +103,7 @@ impl MessageRequest { /// let mut req = MessageRequest::default(); /// req.insert_state(5u32); /// let val: Option> = req.state(); - /// assert_eq!(*val.unwrap(), 5); + /// assert_eq!(*val.expect("shared state missing"), 5); /// ``` pub fn insert_state(&mut self, state: T) where diff --git a/src/server/config/mod.rs b/src/server/config/mod.rs index fdc822a5..422dc98b 100644 --- a/src/server/config/mod.rs +++ b/src/server/config/mod.rs @@ -50,7 +50,7 @@ pub mod binding; pub mod preamble; fn default_worker_count() -> usize { - std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get) + super::default_worker_count() } #[cfg(test)] mod tests; diff --git a/src/server/config/tests/mod.rs b/src/server/config/tests/mod.rs index 7a911b42..f34ff9e6 100644 --- a/src/server/config/tests/mod.rs +++ b/src/server/config/tests/mod.rs @@ -16,8 +16,3 @@ enum PreambleHandlerKind { Success, Failure, } - -fn expected_default_worker_count() -> usize { - // Mirror the default worker logic to keep tests aligned with `WireframeServer::new`. - std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get) -} diff --git a/src/server/config/tests/tests_basic.rs b/src/server/config/tests/tests_basic.rs index f7b7a2ca..b41ca92f 100644 --- a/src/server/config/tests/tests_basic.rs +++ b/src/server/config/tests/tests_basic.rs @@ -2,11 +2,11 @@ use rstest::rstest; -use super::expected_default_worker_count; use crate::{ app::WireframeApp, server::{ WireframeServer, + default_worker_count, test_util::{bind_server, factory, free_listener, listener_addr}, }, }; @@ -14,7 +14,8 @@ use crate::{ #[rstest] fn test_new_server_creation(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { let server = WireframeServer::new(factory); - assert!(server.worker_count() >= 1 && server.local_addr().is_none()); + assert!(server.worker_count() >= 1); + assert!(server.local_addr().is_none()); } #[rstest] @@ -22,15 +23,15 @@ fn test_new_server_default_worker_count( factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, ) { let server = WireframeServer::new(factory); - assert_eq!(server.worker_count(), expected_default_worker_count()); + assert_eq!(server.worker_count(), default_worker_count()); } #[rstest] fn test_workers_configuration(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let mut server = WireframeServer::new(factory); - server = server.workers(4); + let server = WireframeServer::new(factory); + let server = server.workers(4); assert_eq!(server.worker_count(), 4); - server = server.workers(100); + let server = server.workers(100); assert_eq!(server.worker_count(), 100); assert_eq!(server.workers(0).worker_count(), 1); } diff --git a/src/server/config/tests/tests_integration.rs b/src/server/config/tests/tests_integration.rs index 671bbd8d..f9ce7457 100644 --- a/src/server/config/tests/tests_integration.rs +++ b/src/server/config/tests/tests_integration.rs @@ -63,8 +63,8 @@ async fn test_server_configuration_persistence( #[rstest] fn test_extreme_worker_counts(factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static) { - let mut server = WireframeServer::new(factory); - server = server.workers(usize::MAX); + let server = WireframeServer::new(factory); + let server = server.workers(usize::MAX); assert_eq!(server.worker_count(), usize::MAX); assert_eq!(server.workers(0).worker_count(), 1); } diff --git a/src/server/mod.rs b/src/server/mod.rs index 28a0ac78..f286ce4b 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -17,6 +17,10 @@ use crate::{ serializer::{BincodeSerializer, Serializer}, }; +pub(crate) fn default_worker_count() -> usize { + std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get) +} + /// Handler invoked when a connection preamble decodes successfully. /// /// Implementors may perform asynchronous I/O on the provided stream before the From b14fbca9804aec27c5413d55f93725a393c53645 Mon Sep 17 00:00:00 2001 From: Leynos Date: Mon, 9 Feb 2026 12:39:40 +0000 Subject: [PATCH 7/9] refactor(codec,recovery,app,extractor,server): improve error handling, tests, and code clarity - Propagate underlying I/O errors in response handling instead of generic messages - Simplify recovery policy tests by removing rstest fixtures - Add Debug impl for SharedState with non-exhaustive output - Document default worker count computation in server module - Minor test assertion cleanup and added doc comment for default read timeout constant Co-authored-by: devboxerhub[bot] --- src/app/builder_defaults.rs | 1 + src/app/frame_handling/response.rs | 6 ++-- src/codec/recovery/tests.rs | 38 +++++++---------------- src/extractor/shared_state.rs | 6 ++++ src/server/config/tests/tests_preamble.rs | 1 + src/server/mod.rs | 1 + tests/app_data.rs | 3 +- 7 files changed, 25 insertions(+), 31 deletions(-) diff --git a/src/app/builder_defaults.rs b/src/app/builder_defaults.rs index fa539a29..b1c7c280 100644 --- a/src/app/builder_defaults.rs +++ b/src/app/builder_defaults.rs @@ -6,6 +6,7 @@ use crate::{codec::clamp_frame_length, fragment::FragmentationConfig}; pub(super) const MIN_READ_TIMEOUT_MS: u64 = 1; pub(super) const MAX_READ_TIMEOUT_MS: u64 = 86_400_000; +/// Default preamble read timeout in milliseconds. pub(super) const DEFAULT_READ_TIMEOUT_MS: u64 = 100; const DEFAULT_FRAGMENT_TIMEOUT: Duration = Duration::from_secs(30); const DEFAULT_MESSAGE_SIZE_MULTIPLIER: usize = 16; diff --git a/src/app/frame_handling/response.rs b/src/app/frame_handling/response.rs index d135c489..c827d787 100644 --- a/src/app/frame_handling/response.rs +++ b/src/app/frame_handling/response.rs @@ -81,7 +81,7 @@ fn fragment_responses( err = err ); crate::metrics::inc_handler_errors(); - Err(io::Error::other("fragmentation failed")) + Err(io::Error::other(err)) } }, None => Ok(vec![envelope]), @@ -107,7 +107,7 @@ fn serialize_response( e = e ); crate::metrics::inc_handler_errors(); - Err(io::Error::other("serialization failed")) + Err(io::Error::other(e)) } } } @@ -133,7 +133,7 @@ where let correlation_id = response.correlation_id; warn!("failed to send response: id={id}, correlation_id={correlation_id:?}, error={e:?}"); crate::metrics::inc_handler_errors(); - return Err(io::Error::other("send failed")); + return Err(io::Error::other(e)); } Ok(()) } diff --git a/src/codec/recovery/tests.rs b/src/codec/recovery/tests.rs index 4e7cc3f5..57e0daf3 100644 --- a/src/codec/recovery/tests.rs +++ b/src/codec/recovery/tests.rs @@ -2,22 +2,11 @@ use std::{io, time::Duration}; -use rstest::{fixture, rstest}; - use super::*; -use crate::codec::CodecError; - -#[fixture] -fn default_hook() -> DefaultRecoveryPolicy { - std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst); - DefaultRecoveryPolicy -} - -#[fixture] -fn default_ctx() -> CodecErrorContext { - std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst); - CodecErrorContext::new() -} +use crate::codec::{ + CodecError, + error::{EofError, FramingError}, +}; #[test] fn recovery_policy_default_is_drop() { @@ -43,12 +32,10 @@ fn context_with_peer_address() { assert_eq!(ctx.peer_address, Some(addr)); } -#[rstest] -fn default_recovery_policy_delegates_to_error( - default_hook: DefaultRecoveryPolicy, - default_ctx: CodecErrorContext, -) { - use crate::codec::error::{EofError, FramingError}; +#[test] +fn default_recovery_policy_delegates_to_error() { + let default_hook = DefaultRecoveryPolicy; + let default_ctx = CodecErrorContext::new(); // Check various error types let err = CodecError::Framing(FramingError::OversizedFrame { size: 100, max: 50 }); @@ -70,11 +57,10 @@ fn default_recovery_policy_delegates_to_error( ); } -#[rstest] -fn default_quarantine_duration_is_30_seconds( - default_hook: DefaultRecoveryPolicy, - default_ctx: CodecErrorContext, -) { +#[test] +fn default_quarantine_duration_is_30_seconds() { + let default_hook = DefaultRecoveryPolicy; + let default_ctx = CodecErrorContext::new(); let err = CodecError::Io(io::Error::other("test")); assert_eq!( diff --git a/src/extractor/shared_state.rs b/src/extractor/shared_state.rs index fdb41b9d..1589c814 100644 --- a/src/extractor/shared_state.rs +++ b/src/extractor/shared_state.rs @@ -8,6 +8,12 @@ use super::{ExtractError, FromMessageRequest, MessageRequest, Payload}; #[derive(Clone)] pub struct SharedState(Arc); +impl std::fmt::Debug for SharedState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SharedState").finish_non_exhaustive() + } +} + impl From> for SharedState { fn from(inner: Arc) -> Self { Self(inner) } } diff --git a/src/server/config/tests/tests_preamble.rs b/src/server/config/tests/tests_preamble.rs index 16c459f1..5360dfd6 100644 --- a/src/server/config/tests/tests_preamble.rs +++ b/src/server/config/tests/tests_preamble.rs @@ -23,6 +23,7 @@ use crate::{ test_util::{TestPreamble, factory, server_with_preamble}, }, }; + #[rstest] fn test_with_preamble_type_conversion( factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, diff --git a/src/server/mod.rs b/src/server/mod.rs index f286ce4b..d4e8fc4b 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -17,6 +17,7 @@ use crate::{ serializer::{BincodeSerializer, Serializer}, }; +/// Compute the default worker count from available CPU parallelism. pub(crate) fn default_worker_count() -> usize { std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get) } diff --git a/tests/app_data.rs b/tests/app_data.rs index 7dbd533a..0b85688b 100644 --- a/tests/app_data.rs +++ b/tests/app_data.rs @@ -41,7 +41,6 @@ fn missing_shared_state_returns_error( mut empty_payload: Payload<'static>, ) { let err = SharedState::::from_message_request(&request, &mut empty_payload) - .err() - .expect("missing state error expected"); + .expect_err("missing state error expected"); assert!(matches!(err, ExtractError::MissingState(_))); } From 56b1e1f1e3c2a0f16f480751f2276dbac16cd607 Mon Sep 17 00:00:00 2001 From: Leynos Date: Mon, 9 Feb 2026 15:41:13 +0000 Subject: [PATCH 8/9] style: format server config module after rebase Apply rustfmt output produced during post-rebase validation so working tree is clean and quality gates remain reproducible. --- src/server/config/mod.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/server/config/mod.rs b/src/server/config/mod.rs index 422dc98b..06d87a62 100644 --- a/src/server/config/mod.rs +++ b/src/server/config/mod.rs @@ -49,9 +49,7 @@ macro_rules! builder_callback { pub mod binding; pub mod preamble; -fn default_worker_count() -> usize { - super::default_worker_count() -} +fn default_worker_count() -> usize { super::default_worker_count() } #[cfg(test)] mod tests; From 96a141e7bc047e8a8437d2dade1aa60caab43d84 Mon Sep 17 00:00:00 2001 From: Leynos Date: Mon, 9 Feb 2026 16:28:36 +0000 Subject: [PATCH 9/9] docs(frame_handling): add detailed docs to forward_response function Extended the documentation of `forward_response` to clarify its behavior when handling errors from the service call. The docs now explain the difference between application-level errors, which are logged and counted but do not propagate an error, and transport-level errors, which are propagated as `io::Error`. This improves code understandability and maintenance. Co-authored-by: devboxerhub[bot] --- src/app/frame_handling/response.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/app/frame_handling/response.rs b/src/app/frame_handling/response.rs index c827d787..edcb3d33 100644 --- a/src/app/frame_handling/response.rs +++ b/src/app/frame_handling/response.rs @@ -22,6 +22,14 @@ use crate::{ }; /// Forward a handler response, fragmenting if required, and write to the framed stream. +/// +/// `forward_response` accepts an [`Envelope`], builds a [`ServiceRequest`], and +/// invokes `service.call(request)`. If the handler returns `Err(e)`, this is +/// treated as an application-level failure: the error is logged, +/// [`crate::metrics::inc_handler_errors()`] is incremented, and the function +/// returns `Ok(())` (intentional log-and-continue behaviour). Transport-level +/// I/O failures (for example during fragmentation, serialization, or frame +/// send) still return `io::Error` and are propagated to the caller. pub(crate) async fn forward_response( env: Envelope, service: &HandlerService,