diff --git a/README.md b/README.md index d8bc453b..0d8a0440 100644 --- a/README.md +++ b/README.md @@ -99,19 +99,59 @@ produced by `on_connection_setup` is passed to `on_connection_teardown` when the connection ends. ```rust -let app = WireframeApp::new() - .on_connection_setup(|| async { 42u32 }) - .on_connection_teardown(|state| async move { - println!("closing with {state}"); - }); + let app = WireframeApp::new() + .on_connection_setup(|| async { 42u32 }) + .on_connection_teardown(|state| async move { + println!("closing with {state}"); + }); +``` + +## Custom Extractors + +Extractors are types that implement `FromMessageRequest`. When a handler lists +an extractor as a parameter, `wireframe` automatically constructs it using the +incoming \[`MessageRequest`\] and remaining \[`Payload`\]. Built‑in extractors like +`Message`, `SharedState` and `ConnectionInfo` decode the payload, access +app state or expose peer information. + +Custom extractors let you centralize parsing and validation logic that would +otherwise be duplicated across handlers. A session token parser, for example, +can verify the token before any route-specific code executes +[Design Guide: Data Extraction and Type Safety](docs/rust-binary-router-library-design.md#53-data-extraction-and-type-safety). + +```rust +use wireframe::extractor::{ConnectionInfo, FromMessageRequest, MessageRequest, Payload}; + +pub struct SessionToken(String); + +impl FromMessageRequest for SessionToken { + type Error = std::convert::Infallible; + + fn from_message_request( + _req: &MessageRequest, + payload: &mut Payload<'_>, + ) -> Result { + let len = payload.data[0] as usize; + let token = std::str::from_utf8(&payload.data[1..=len]).unwrap().to_string(); + payload.advance(1 + len); + Ok(Self(token)) + } +} +``` + +Custom extractors integrate seamlessly with other parameters: + +```rust +async fn handle_ping(token: SessionToken, info: ConnectionInfo) { + println!("{} from {:?}", token.0, info.peer_addr()); +} ``` ## Current Limitations -Connection processing is not implemented yet. After the optional -preamble is read, the server logs a warning and immediately closes the -stream. Release builds fail to compile to prevent accidental production -use. +Connection handling now processes frames and routes messages, but the +server is still experimental. Release builds fail to compile, so the +library cannot be used accidentally in production. ## Roadmap diff --git a/docs/rust-binary-router-library-design.md b/docs/rust-binary-router-library-design.md index da3e840b..d3a27ba9 100644 --- a/docs/rust-binary-router-library-design.md +++ b/docs/rust-binary-router-library-design.md @@ -855,17 +855,51 @@ its context. For example, a custom extractor could parse a session token from a specific field in all messages, validate it, and provide a `UserSession` object to the handler. + This extractor system, backed by Rust's strong type system, ensures that handlers receive correctly typed and validated data, significantly reducing the likelihood of runtime errors and boilerplate parsing code within the handler logic itself. Custom extractors are particularly valuable as they allow common, -protocol-specific data extraction and validation logic (e.g., extracting and -verifying a session token from a custom frame header) to be encapsulated into -reusable components. This further reduces code duplication across multiple +protocol-specific data extraction and validation logic (for example extracting +and verifying a session token from a custom frame header) to be encapsulated +into reusable components. This further reduces code duplication across multiple handlers and keeps the handler functions lean and focused on their specific business tasks, mirroring the benefits seen with Actix Web's `FromRequest` trait.24 +```mermaid +classDiagram + class FromMessageRequest { + <> + +from_message_request(req: &MessageRequest, payload: &mut Payload) Result + +Error + } + class Message~T~ { + +Message(T) + +into_inner() T + +deref() &T + } + class ConnectionInfo { + +peer_addr: Option + +peer_addr() Option + } + class SharedState~T~ { + +deref() &T + } + class ExtractError { + +MissingState(&'static str) + +InvalidPayload(DecodeError) + } + FromMessageRequest <|.. Message + FromMessageRequest <|.. ConnectionInfo + FromMessageRequest <|.. SharedState + SharedState --> ExtractError + ExtractError o-- DecodeError + Message o-- T + SharedState o-- T + ConnectionInfo o-- SocketAddr +``` + ### 5.4. Middleware and Extensibility "wireframe" will incorporate a middleware system conceptually similar to Actix diff --git a/src/extractor.rs b/src/extractor.rs index 0351fd1c..f315c46c 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -1,8 +1,10 @@ -//! Request context types and extractor traits. +//! Extractor and request context definitions. //! -//! The `MessageRequest` struct carries connection metadata and shared -//! application state. Implement [`FromMessageRequest`] to extract data -//! for handlers. +//! This module provides [`MessageRequest`], which carries connection +//! metadata and shared application state, along with a set of extractor +//! types. Implement [`FromMessageRequest`] for custom extractors to +//! parse payload bytes or inspect connection info before your handler +//! runs. use std::{ any::{Any, TypeId}, @@ -11,6 +13,8 @@ use std::{ sync::Arc, }; +use crate::message::Message as WireMessage; + /// Request context passed to extractors. /// /// This type contains metadata about the current connection and provides @@ -65,6 +69,12 @@ impl Payload<'_> { } /// Trait for extracting data from a [`MessageRequest`]. +/// +/// Types implementing this trait can be used as parameters to handler +/// functions. When invoked, `wireframe` passes the current request metadata and +/// message payload, allowing the extractor to parse bytes or inspect +/// connection information. This makes it easy to share common parsing and +/// validation logic across handlers. pub trait FromMessageRequest: Sized { /// Error type returned when extraction fails. type Error: std::error::Error + Send + Sync + 'static; @@ -85,7 +95,7 @@ pub trait FromMessageRequest: Sized { pub struct SharedState(Arc); impl SharedState { - /// Construct a new [`SharedState`]. + /// Creates a new [`SharedState`] instance wrapping the provided `Arc`. /// /// # Examples /// @@ -99,19 +109,6 @@ impl SharedState { /// assert_eq!(*state, 5); /// ``` #[must_use] - /// Creates a new `SharedState` instance wrapping the provided `Arc`. - /// - /// # Examples - /// - /// ```no_run - /// use std::sync::Arc; - /// - /// use wireframe::extractor::SharedState; - /// - /// let state = Arc::new(42); - /// let shared: SharedState = state.clone().into(); - /// assert_eq!(*shared, 42); - /// ``` #[deprecated(since = "0.2.0", note = "construct via `inner.into()` instead")] pub fn new(inner: Arc) -> Self { Self(inner) } } @@ -133,17 +130,34 @@ impl From for SharedState { pub enum ExtractError { /// No shared state of the requested type was found. MissingState(&'static str), + /// Failed to decode the message payload. + InvalidPayload(bincode::error::DecodeError), } impl std::fmt::Display for ExtractError { + /// Formats the `ExtractError` for display purposes. + /// + /// Displays a descriptive message for missing shared state or payload decoding errors. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::MissingState(ty) => write!(f, "no shared state registered for {ty}"), + Self::InvalidPayload(e) => write!(f, "failed to decode payload: {e}"), } } } -impl std::error::Error for ExtractError {} +impl std::error::Error for ExtractError { + /// Returns the underlying error if this is an `InvalidPayload` variant. + /// + /// # Returns + /// An optional reference to the underlying decode error, or `None` if not applicable. + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::InvalidPayload(e) => Some(e), + _ => None, + } + } +} impl FromMessageRequest for SharedState where @@ -180,3 +194,75 @@ impl std::ops::Deref for SharedState { /// ``` fn deref(&self) -> &Self::Target { &self.0 } } + +/// Extractor that deserializes the message payload into `T`. +#[derive(Debug, Clone)] +pub struct Message(T); + +impl Message { + /// Consumes the extractor and returns the inner deserialised message value. + #[must_use] + pub fn into_inner(self) -> T { self.0 } +} + +impl std::ops::Deref for Message { + type Target = T; + + /// Returns a reference to the inner value. + /// + /// This enables transparent access to the wrapped type via dereferencing. + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl FromMessageRequest for Message +where + T: WireMessage, +{ + type Error = ExtractError; + + /// Attempts to extract and deserialize a message of type `T` from the payload. + /// + /// Advances the payload by the number of bytes consumed during deserialization. + /// Returns an error if the payload cannot be decoded into the target type. + /// + /// # Returns + /// - `Ok(Self)`: The successfully extracted and deserialized message. + /// - `Err(ExtractError::InvalidPayload)`: If deserialization fails. + fn from_message_request( + _req: &MessageRequest, + payload: &mut Payload<'_>, + ) -> Result { + let (msg, consumed) = T::from_bytes(payload.data).map_err(ExtractError::InvalidPayload)?; + payload.advance(consumed); + Ok(Self(msg)) + } +} + +/// Extractor providing peer connection metadata. +#[derive(Debug, Clone, Copy)] +pub struct ConnectionInfo { + peer_addr: Option, +} + +impl ConnectionInfo { + /// Returns the peer's socket address for the current connection, if available. + #[must_use] + pub fn peer_addr(&self) -> Option { self.peer_addr } +} + +impl FromMessageRequest for ConnectionInfo { + type Error = std::convert::Infallible; + + /// Extracts connection metadata from the message request. + /// + /// Returns a `ConnectionInfo` containing the peer's socket address, if available. This + /// extraction is infallible. + fn from_message_request( + req: &MessageRequest, + _payload: &mut Payload<'_>, + ) -> Result { + Ok(Self { + peer_addr: req.peer_addr, + }) + } +} diff --git a/src/middleware.rs b/src/middleware.rs index 927648c0..10e10991 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -25,10 +25,7 @@ impl<'a, S> Next<'a, S> where S: Service + ?Sized, { - /// Create a new [`Next`] wrapping the given service. - #[inline] - #[must_use] - /// Creates a new `Next` instance wrapping a reference to the given service. + /// Creates a new [`Next`] instance wrapping a reference to the given service. /// /// # Examples /// @@ -46,6 +43,8 @@ where /// let service = MyService; /// let next = Next::new(&service); /// ``` + #[inline] + #[must_use] pub fn new(service: &'a S) -> Self { Self { service } } /// Call the next service with the provided request. diff --git a/src/server.rs b/src/server.rs index c0e0bb7b..cdf8e76c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -183,9 +183,6 @@ where self } - /// Get the configured worker count. - #[inline] - #[must_use] /// Returns the configured number of worker tasks for the server. /// /// # Examples @@ -197,6 +194,8 @@ where /// let server = WireframeServer::new(factory); /// assert!(server.worker_count() >= 1); /// ``` + #[inline] + #[must_use] pub const fn worker_count(&self) -> usize { self.workers } /// Get the socket address the server is bound to, if available. diff --git a/tests/extractor.rs b/tests/extractor.rs new file mode 100644 index 00000000..488f2a68 --- /dev/null +++ b/tests/extractor.rs @@ -0,0 +1,87 @@ +use std::{collections::HashMap, net::SocketAddr}; + +use wireframe::{ + extractor::{ConnectionInfo, FromMessageRequest, Message, MessageRequest, Payload}, + message::Message as MessageTrait, +}; + +#[derive(bincode::Encode, bincode::BorrowDecode, PartialEq, Debug)] +struct TestMsg(u8); + +#[test] +/// Tests that a message can be extracted from a payload and that the payload cursor advances fully. +/// +/// Verifies that a `TestMsg` instance serialised into bytes can be correctly extracted from a +/// `Payload` using `Message::::from_message_request`, and asserts that the payload has no +/// remaining unread data after extraction. +fn message_extractor_parses_and_advances() { + let msg = TestMsg(42); + let bytes = msg.to_bytes().unwrap(); + let mut payload = Payload { + data: bytes.as_slice(), + }; + let req = MessageRequest::default(); + + let extracted = Message::::from_message_request(&req, &mut payload).unwrap(); + assert_eq!(*extracted, msg); + assert_eq!(payload.remaining(), 0); +} + +#[test] +/// Tests that `ConnectionInfo` correctly reports the peer socket address extracted from a +/// `MessageRequest`. +fn connection_info_reports_peer() { + let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap(); + let req = MessageRequest { + peer_addr: Some(addr), + app_data: HashMap::default(), + }; + let mut payload = Payload::default(); + let info = ConnectionInfo::from_message_request(&req, &mut payload).unwrap(); + assert_eq!(info.peer_addr(), Some(addr)); +} + +#[test] +/// Tests that shared state of type `u8` can be successfully extracted from a `MessageRequest`'s +/// `app_data`. +/// +/// Inserts an `Arc` into the request's shared state, extracts it using the `SharedState` +/// extractor, and asserts that the extracted value matches the original. +fn shared_state_extractor() { + let mut data = HashMap::default(); + data.insert( + std::any::TypeId::of::(), + std::sync::Arc::new(42u8) as std::sync::Arc, + ); + let req = MessageRequest { + peer_addr: None, + app_data: data, + }; + let mut payload = Payload::default(); + + let state = + wireframe::extractor::SharedState::::from_message_request(&req, &mut payload).unwrap(); + assert_eq!(*state, 42); +} + +#[test] +/// Tests that extracting a missing shared state from a `MessageRequest` +/// returns an `ExtractError::MissingState` containing the type name. +/// +/// Ensures that when no shared state of the requested type is present, +/// the correct error is produced and includes the expected type information. +fn shared_state_missing_error() { + let req = MessageRequest::default(); + let mut payload = Payload::default(); + let Err(err) = + wireframe::extractor::SharedState::::from_message_request(&req, &mut payload) + else { + panic!("expected error"); + }; + match err { + wireframe::extractor::ExtractError::MissingState(name) => { + assert!(name.contains("u8")); + } + _ => panic!("unexpected error"), + } +}