diff --git a/docs/asynchronous-outbound-messaging-design.md b/docs/asynchronous-outbound-messaging-design.md index dab955b8..73f010c2 100644 --- a/docs/asynchronous-outbound-messaging-design.md +++ b/docs/asynchronous-outbound-messaging-design.md @@ -466,6 +466,51 @@ pub trait WireframeProtocol: Send + Sync + 'static { WireframeApp::new().with_protocol(MySqlProtocolImpl); ``` +```mermaid +classDiagram + class WireframeProtocol { + <> + +Frame: FrameLike + +ProtocolError + +on_connection_setup(PushHandle, &mut ConnectionContext) + +before_send(&mut Frame, &mut ConnectionContext) + +on_command_end(&mut ConnectionContext) + } + class ProtocolHooks { + -before_send: Option> + -on_command_end: Option + +before_send(&mut self, &mut F, &mut ConnectionContext) + +on_command_end(&mut self, &mut ConnectionContext) + +from_protocol(protocol: Arc

) + } + class ConnectionContext { + <> + } + class WireframeApp { + -protocol: Option, ProtocolError=()>>> + +with_protocol(protocol) + +protocol() + +protocol_hooks() + } + class ConnectionActor { + -hooks: ProtocolHooks + -ctx: ConnectionContext + } + WireframeApp --> "1" WireframeProtocol : uses + WireframeApp --> "1" ProtocolHooks : creates + ProtocolHooks --> "1" WireframeProtocol : from_protocol + ConnectionActor --> "1" ProtocolHooks : uses + ConnectionActor --> "1" ConnectionContext : owns + ProtocolHooks --> "1" ConnectionContext : passes to hooks + WireframeProtocol --> "1" ConnectionContext : uses + WireframeProtocol --> "1" PushHandle : uses + WireframeProtocol <|.. ProtocolHooks : implemented by +``` + +`ConnectionContext` is intentionally empty today. It offers a stable extension +point for per-connection data without breaking existing protocol +implementations. + ## 5. Error Handling & Resilience ### 5.1 `BrokenPipe` on Connection Loss diff --git a/src/app.rs b/src/app.rs index 4dd31b83..3905d0f6 100644 --- a/src/app.rs +++ b/src/app.rs @@ -18,6 +18,7 @@ use tokio::io::{self, AsyncWrite, AsyncWriteExt}; use crate::{ frame::{FrameProcessor, LengthFormat, LengthPrefixedProcessor}, + hooks::{ProtocolHooks, WireframeProtocol}, message::Message, middleware::{HandlerService, Service, ServiceRequest, Transform}, serializer::{BincodeSerializer, Serializer}, @@ -80,6 +81,7 @@ pub struct WireframeApp< app_data: HashMap>, on_connect: Option>>, on_disconnect: Option>>, + protocol: Option, ProtocolError = ()>>>, } /// Alias for asynchronous route handlers. @@ -235,6 +237,7 @@ where app_data: HashMap::new(), on_connect: None, on_disconnect: None, + protocol: None, } } } @@ -360,6 +363,7 @@ where app_data: self.app_data, on_connect: Some(Arc::new(move || Box::pin(f()))), on_disconnect: None, + protocol: self.protocol, }) } @@ -381,6 +385,41 @@ where Ok(self) } + /// Install a [`WireframeProtocol`] implementation. + /// + /// The protocol defines hooks for connection setup, frame modification, and + /// command completion. It is wrapped in an [`Arc`] and stored for later use + /// by the connection actor. + #[must_use] + pub fn with_protocol

(mut self, protocol: P) -> Self + where + P: WireframeProtocol, ProtocolError = ()> + 'static, + { + self.protocol = Some(Arc::new(protocol)); + self + } + + /// Get a clone of the configured protocol, if any. + /// + /// Returns `None` if no protocol was installed via [`with_protocol`](Self::with_protocol). + #[must_use] + pub fn protocol( + &self, + ) -> Option, ProtocolError = ()>>> { + self.protocol.as_ref().map(Arc::clone) + } + + /// Return protocol hooks derived from the installed protocol. + /// + /// If no protocol is installed, returns default (no-op) hooks. + #[must_use] + pub fn protocol_hooks(&self) -> ProtocolHooks> { + self.protocol + .as_ref() + .map(|p| ProtocolHooks::from_protocol(&Arc::clone(p))) + .unwrap_or_default() + } + /// Set the frame processor used for encoding and decoding frames. #[must_use] pub fn frame_processor

(mut self, processor: P) -> Self @@ -406,6 +445,7 @@ where app_data: self.app_data, on_connect: self.on_connect, on_disconnect: self.on_disconnect, + protocol: self.protocol, } } diff --git a/src/connection.rs b/src/connection.rs index 20fd423a..f2460bdb 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -13,8 +13,8 @@ use tokio::{ use tokio_util::sync::CancellationToken; use crate::{ - hooks::ProtocolHooks, - push::{FrameLike, PushQueues}, + hooks::{ConnectionContext, ProtocolHooks}, + push::{FrameLike, PushHandle, PushQueues}, response::{FrameStream, WireframeError}, }; @@ -48,9 +48,9 @@ impl Default for FairnessConfig { /// use tokio_util::sync::CancellationToken; /// use wireframe::{connection::ConnectionActor, push::PushQueues}; /// -/// let (queues, _handle) = PushQueues::::bounded(8, 8); +/// let (queues, handle) = PushQueues::::bounded(8, 8); /// let shutdown = CancellationToken::new(); -/// let mut actor: ConnectionActor<_, ()> = ConnectionActor::new(queues, None, shutdown); +/// let mut actor: ConnectionActor<_, ()> = ConnectionActor::new(queues, handle, None, shutdown); /// # drop(actor); /// ``` pub struct ConnectionActor { @@ -59,6 +59,7 @@ pub struct ConnectionActor { response: Option>, // current streaming response shutdown: CancellationToken, hooks: ProtocolHooks, + ctx: ConnectionContext, fairness: FairnessConfig, high_counter: usize, high_start: Option, @@ -76,34 +77,39 @@ where /// use tokio_util::sync::CancellationToken; /// use wireframe::{connection::ConnectionActor, push::PushQueues}; /// - /// let (queues, _handle) = PushQueues::::bounded(4, 4); + /// let (queues, handle) = PushQueues::::bounded(4, 4); /// let token = CancellationToken::new(); - /// let mut actor: ConnectionActor<_, ()> = ConnectionActor::new(queues, None, token); + /// let mut actor: ConnectionActor<_, ()> = ConnectionActor::new(queues, handle, None, token); /// # drop(actor); /// ``` #[must_use] pub fn new( queues: PushQueues, + handle: PushHandle, response: Option>, shutdown: CancellationToken, ) -> Self { - Self::with_hooks(queues, response, shutdown, ProtocolHooks::default()) + Self::with_hooks(queues, handle, response, shutdown, ProtocolHooks::default()) } /// Create a new `ConnectionActor` with custom protocol hooks. #[must_use] pub fn with_hooks( queues: PushQueues, + handle: PushHandle, response: Option>, shutdown: CancellationToken, - hooks: ProtocolHooks, + mut hooks: ProtocolHooks, ) -> Self { + let mut ctx = ConnectionContext; + hooks.on_connection_setup(handle, &mut ctx); Self { high_rx: Some(queues.high_priority_rx), low_rx: Some(queues.low_priority_rx), response, shutdown, hooks, + ctx, fairness: FairnessConfig::default(), high_counter: 0, high_start: None, @@ -208,7 +214,7 @@ where /// Handle the result of polling the high-priority queue. fn process_high(&mut self, res: Option, state: &mut ActorState, out: &mut Vec) { if let Some(mut frame) = res { - self.hooks.before_send(&mut frame); + self.hooks.before_send(&mut frame, &mut self.ctx); out.push(frame); self.after_high(out, state); } else { @@ -221,7 +227,7 @@ where /// Handle the result of polling the low-priority queue. fn process_low(&mut self, res: Option, state: &mut ActorState, out: &mut Vec) { if let Some(mut frame) = res { - self.hooks.before_send(&mut frame); + self.hooks.before_send(&mut frame, &mut self.ctx); out.push(frame); self.after_low(); } else { @@ -274,7 +280,7 @@ where { match rx.try_recv() { Ok(mut frame) => { - self.hooks.before_send(&mut frame); + self.hooks.before_send(&mut frame, &mut self.ctx); out.push(frame); self.after_low(); } @@ -317,13 +323,13 @@ where ) -> Result<(), WireframeError> { match res { Some(Ok(mut frame)) => { - self.hooks.before_send(&mut frame); + self.hooks.before_send(&mut frame, &mut self.ctx); out.push(frame); } Some(Err(e)) => return Err(e), None => { state.mark_closed(); - self.hooks.on_command_end(); + self.hooks.on_command_end(&mut self.ctx); } } diff --git a/src/hooks.rs b/src/hooks.rs index 825e1a92..25b71a18 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -1,17 +1,54 @@ //! Internal protocol hooks called by the connection actor. //! -//! This module defines [`ProtocolHooks`], a container for optional callback -//! functions invoked during connection output. The hooks are placeholders for -//! the future `WireframeProtocol` trait described in the design documents. +//! This module defines [`ProtocolHooks`] along with the public +//! [`WireframeProtocol`] trait. `ProtocolHooks` stores optional callbacks +//! invoked during connection output. Applications configure these callbacks via +//! an implementation of [`WireframeProtocol`]. + +use std::sync::Arc; + +use crate::push::{FrameLike, PushHandle}; + +/// Per-connection state passed to protocol callbacks. +/// +/// This empty struct is intentionally extensible. Future protocol features may +/// require storing connection-local data without breaking existing APIs. +#[derive(Default)] +pub struct ConnectionContext; + +/// Trait encapsulating protocol-specific logic and callbacks. +pub trait WireframeProtocol: Send + Sync + 'static { + /// Frame type written to the socket. + type Frame: FrameLike; + /// Custom error type for protocol operations. + type ProtocolError; + + /// Called once when a new connection is established. The provided + /// [`PushHandle`] may be stored by the implementation to enable + /// asynchronous server pushes. + fn on_connection_setup(&self, _handle: PushHandle, _ctx: &mut ConnectionContext) {} + + /// Invoked before any frame (push or response) is written to the socket. + fn before_send(&self, _frame: &mut Self::Frame, _ctx: &mut ConnectionContext) {} + + /// Invoked when a request/response cycle completes. + fn on_command_end(&self, _ctx: &mut ConnectionContext) {} +} /// Type alias for the `before_send` callback. -type BeforeSendHook = Box; +type BeforeSendHook = Box; + +/// Type alias for the `on_connection_setup` callback. +type OnConnectionSetupHook = + Box, &mut ConnectionContext) + Send + 'static>; /// Type alias for the `on_command_end` callback. -type OnCommandEndHook = Box; +type OnCommandEndHook = Box; /// Callbacks used by the connection actor. pub struct ProtocolHooks { + /// Invoked when a connection is established. + pub on_connection_setup: Option>, /// Invoked before a frame is written to the socket. pub before_send: Option>, /// Invoked once a command completes. @@ -21,6 +58,7 @@ pub struct ProtocolHooks { impl Default for ProtocolHooks { fn default() -> Self { Self { + on_connection_setup: None, before_send: None, on_command_end: None, } @@ -28,17 +66,50 @@ impl Default for ProtocolHooks { } impl ProtocolHooks { + /// Run the `on_connection_setup` hook if registered. + pub fn on_connection_setup(&mut self, handle: PushHandle, ctx: &mut ConnectionContext) { + if let Some(hook) = self.on_connection_setup.take() { + hook(handle, ctx); + } + } /// Run the `before_send` hook if registered. - pub fn before_send(&mut self, frame: &mut F) { + pub fn before_send(&mut self, frame: &mut F, ctx: &mut ConnectionContext) { if let Some(hook) = &mut self.before_send { - hook(frame); + hook(frame, ctx); } } /// Run the `on_command_end` hook if registered. - pub fn on_command_end(&mut self) { + pub fn on_command_end(&mut self, ctx: &mut ConnectionContext) { if let Some(hook) = &mut self.on_command_end { - hook(); + hook(ctx); + } + } + + /// Construct hooks from a [`WireframeProtocol`] implementation. + pub fn from_protocol

(protocol: &Arc

) -> Self + where + P: WireframeProtocol + ?Sized, + { + let protocol_before = Arc::clone(protocol); + let before = Box::new(move |frame: &mut F, ctx: &mut ConnectionContext| { + protocol_before.before_send(frame, ctx); + }) as BeforeSendHook; + + let protocol_end = Arc::clone(protocol); + let end = Box::new(move |ctx: &mut ConnectionContext| { + protocol_end.on_command_end(ctx); + }) as OnCommandEndHook; + + let protocol_setup = Arc::clone(protocol); + let setup = Box::new(move |handle: PushHandle, ctx: &mut ConnectionContext| { + protocol_setup.on_connection_setup(handle, ctx); + }) as OnConnectionSetupHook; + + Self { + on_connection_setup: Some(setup), + before_send: Some(before), + on_command_end: Some(end), } } } diff --git a/src/lib.rs b/src/lib.rs index ddb652e0..a553add2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,5 +20,5 @@ pub mod rewind_stream; pub mod server; pub use connection::ConnectionActor; -pub use hooks::ProtocolHooks; +pub use hooks::{ConnectionContext, ProtocolHooks, WireframeProtocol}; pub use response::{FrameStream, Response, WireframeError}; diff --git a/tests/connection_actor.rs b/tests/connection_actor.rs index 7cbd0f17..5e034f79 100644 --- a/tests/connection_actor.rs +++ b/tests/connection_actor.rs @@ -34,11 +34,10 @@ async fn strict_priority_order( let (queues, handle) = queues; handle.push_low_priority(2).await.unwrap(); handle.push_high_priority(1).await.unwrap(); - drop(handle); let stream = stream::iter(vec![Ok(3u8)]); let mut actor: ConnectionActor<_, ()> = - ConnectionActor::new(queues, Some(Box::pin(stream)), shutdown_token); + ConnectionActor::new(queues, handle, Some(Box::pin(stream)), shutdown_token); let mut out = Vec::new(); actor.run(&mut out).await.unwrap(); assert_eq!(out, vec![1, 2, 3]); @@ -60,9 +59,9 @@ async fn fairness_yields_low_after_burst( handle.push_high_priority(n).await.unwrap(); } handle.push_low_priority(99).await.unwrap(); - drop(handle); - let mut actor: ConnectionActor<_, ()> = ConnectionActor::new(queues, None, shutdown_token); + let mut actor: ConnectionActor<_, ()> = + ConnectionActor::new(queues, handle, None, shutdown_token); actor.set_fairness(fairness); let mut out = Vec::new(); actor.run(&mut out).await.unwrap(); @@ -76,9 +75,10 @@ async fn shutdown_signal_precedence( shutdown_token: CancellationToken, ) { let (queues, handle) = queues; - drop(handle); shutdown_token.cancel(); - let mut actor: ConnectionActor<_, ()> = ConnectionActor::new(queues, None, shutdown_token); + let mut actor: ConnectionActor<_, ()> = + ConnectionActor::new(queues, handle, None, shutdown_token); + // drop the handle after actor creation to mimic early disconnection let mut out = Vec::new(); actor.run(&mut out).await.unwrap(); assert!(out.is_empty()); @@ -92,11 +92,11 @@ async fn complete_draining_of_sources( ) { let (queues, handle) = queues; handle.push_high_priority(1).await.unwrap(); - drop(handle); let stream = stream::iter(vec![Ok(2u8), Ok(3u8)]); let mut actor: ConnectionActor<_, ()> = - ConnectionActor::new(queues, Some(Box::pin(stream)), shutdown_token); + ConnectionActor::new(queues, handle, Some(Box::pin(stream)), shutdown_token); + // drop handle after actor setup let mut out = Vec::new(); actor.run(&mut out).await.unwrap(); assert_eq!(out, vec![1, 2, 3]); @@ -114,14 +114,13 @@ async fn error_propagation_from_stream( shutdown_token: CancellationToken, ) { let (queues, handle) = queues; - drop(handle); let stream = stream::iter(vec![ Ok(1u8), Ok(2u8), Err(WireframeError::Protocol(TestError::Kaboom)), ]); let mut actor: ConnectionActor<_, TestError> = - ConnectionActor::new(queues, Some(Box::pin(stream)), shutdown_token); + ConnectionActor::new(queues, handle, Some(Box::pin(stream)), shutdown_token); let mut out = Vec::new(); let result = actor.run(&mut out).await; assert!(matches!( @@ -138,7 +137,6 @@ async fn interleaved_shutdown_during_stream( shutdown_token: CancellationToken, ) { let (queues, handle) = queues; - drop(handle); let token = shutdown_token.clone(); tokio::spawn(async move { sleep(Duration::from_millis(50)).await; @@ -154,7 +152,7 @@ async fn interleaved_shutdown_during_stream( } }); let mut actor: ConnectionActor<_, ()> = - ConnectionActor::new(queues, Some(Box::pin(stream)), shutdown_token); + ConnectionActor::new(queues, handle, Some(Box::pin(stream)), shutdown_token); let mut out = Vec::new(); actor.run(&mut out).await.unwrap(); assert!(!out.is_empty() && out.len() < 5); @@ -178,7 +176,7 @@ use std::sync::{ atomic::{AtomicUsize, Ordering}, }; -use wireframe::ProtocolHooks; +use wireframe::{ConnectionContext, ProtocolHooks}; #[rstest] #[tokio::test] @@ -188,16 +186,20 @@ async fn before_send_hook_modifies_frames( ) { let (queues, handle) = queues; handle.push_high_priority(1).await.unwrap(); - drop(handle); let stream = stream::iter(vec![Ok(2u8)]); let hooks = ProtocolHooks { - before_send: Some(Box::new(|f: &mut u8| *f += 1)), + before_send: Some(Box::new(|f: &mut u8, _ctx: &mut ConnectionContext| *f += 1)), ..ProtocolHooks::default() }; - let mut actor: ConnectionActor<_, ()> = - ConnectionActor::with_hooks(queues, Some(Box::pin(stream)), shutdown_token, hooks); + let mut actor: ConnectionActor<_, ()> = ConnectionActor::with_hooks( + queues, + handle, + Some(Box::pin(stream)), + shutdown_token, + hooks, + ); let mut out = Vec::new(); actor.run(&mut out).await.unwrap(); assert_eq!(out, vec![2, 3]); @@ -210,20 +212,24 @@ async fn on_command_end_hook_runs( shutdown_token: CancellationToken, ) { let (queues, handle) = queues; - drop(handle); let stream = stream::iter(vec![Ok(1u8)]); let counter = Arc::new(AtomicUsize::new(0)); let c = counter.clone(); let hooks = ProtocolHooks { - on_command_end: Some(Box::new(move || { + on_command_end: Some(Box::new(move |_ctx: &mut ConnectionContext| { c.fetch_add(1, Ordering::SeqCst); })), ..ProtocolHooks::default() }; - let mut actor: ConnectionActor<_, ()> = - ConnectionActor::with_hooks(queues, Some(Box::pin(stream)), shutdown_token, hooks); + let mut actor: ConnectionActor<_, ()> = ConnectionActor::with_hooks( + queues, + handle, + Some(Box::pin(stream)), + shutdown_token, + hooks, + ); let mut out = Vec::new(); actor.run(&mut out).await.unwrap(); assert_eq!(counter.load(Ordering::SeqCst), 1); diff --git a/tests/wireframe_protocol.rs b/tests/wireframe_protocol.rs new file mode 100644 index 00000000..e13e7934 --- /dev/null +++ b/tests/wireframe_protocol.rs @@ -0,0 +1,94 @@ +//! Integration tests for the `WireframeProtocol` trait. +//! +//! These tests ensure that protocol implementations integrate correctly with +//! [`WireframeApp`] and [`ConnectionActor`]. They verify that hooks are invoked +//! with the expected connection context and that frame mutations occur as +//! intended. + +use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, +}; + +use futures::stream; +use rstest::rstest; +use tokio_util::sync::CancellationToken; +use wireframe::{ + ConnectionContext, + WireframeProtocol, + app::WireframeApp, + connection::ConnectionActor, + push::PushQueues, +}; + +struct TestProtocol { + counter: Arc, +} + +impl WireframeProtocol for TestProtocol { + type Frame = Vec; + type ProtocolError = (); + + fn on_connection_setup( + &self, + _handle: wireframe::push::PushHandle, + _ctx: &mut ConnectionContext, + ) { + self.counter.fetch_add(1, Ordering::SeqCst); + } + + fn before_send(&self, frame: &mut Self::Frame, _ctx: &mut ConnectionContext) { frame.push(1); } + + fn on_command_end(&self, _ctx: &mut ConnectionContext) { + self.counter.fetch_add(1, Ordering::SeqCst); + } +} + +#[rstest] +#[tokio::test] +async fn builder_produces_protocol_hooks() { + let counter = Arc::new(AtomicUsize::new(0)); + let protocol = TestProtocol { + counter: counter.clone(), + }; + let app = WireframeApp::new().unwrap().with_protocol(protocol); + let mut hooks = app.protocol_hooks(); + + let (queues, handle) = PushQueues::bounded(1, 1); + hooks.on_connection_setup(handle, &mut ConnectionContext); + drop(queues); // silence unused warnings + + let mut frame = vec![1u8]; + hooks.before_send(&mut frame, &mut ConnectionContext); + hooks.on_command_end(&mut ConnectionContext); + + assert_eq!(frame, vec![1, 1]); + assert_eq!(counter.load(Ordering::SeqCst), 2); +} + +#[rstest] +#[tokio::test] +async fn connection_actor_uses_protocol_from_builder() { + let counter = Arc::new(AtomicUsize::new(0)); + let protocol = TestProtocol { + counter: counter.clone(), + }; + let app = WireframeApp::new().unwrap().with_protocol(protocol); + + let hooks = app.protocol_hooks(); + let (queues, handle) = PushQueues::bounded(8, 8); + handle.push_high_priority(vec![1]).await.unwrap(); + let stream = stream::iter(vec![Ok(vec![2u8])]); + let mut actor: ConnectionActor<_, ()> = ConnectionActor::with_hooks( + queues, + handle, + Some(Box::pin(stream)), + CancellationToken::new(), + hooks, + ); + let mut out = Vec::new(); + actor.run(&mut out).await.unwrap(); + + assert_eq!(out, vec![vec![1, 1], vec![2, 1]]); + assert_eq!(counter.load(Ordering::SeqCst), 2); +}