diff --git a/README.md b/README.md index acf8a719..f576f9f6 100644 --- a/README.md +++ b/README.md @@ -234,8 +234,8 @@ impl FromMessageRequest for SessionToken { _req: &MessageRequest, payload: &mut Payload<'_>, ) -> Result { - let len = payload.data[0] as usize; - let token = std::str::from_utf8(&payload.data[1..=len]).unwrap().to_string(); + let len = payload.as_ref()[0] as usize; + let token = std::str::from_utf8(&payload.as_ref()[1..=len]).unwrap().to_string(); payload.advance(1 + len); Ok(Self(token)) } diff --git a/src/extractor.rs b/src/extractor.rs index f315c46c..a0095e03 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -34,6 +34,22 @@ impl MessageRequest { /// Retrieve shared state of type `T` if available. /// /// Returns `None` when no value of type `T` was registered. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::{ + /// app::WireframeApp, + /// extractor::{MessageRequest, SharedState}, + /// }; + /// + /// let _app = WireframeApp::new().unwrap().app_data(5u32); + /// // The framework populates the request with application data. + /// # let mut req = MessageRequest::default(); + /// # req.insert_state(5u32); + /// let val: Option> = req.state(); + /// assert_eq!(*val.unwrap(), 5); + /// ``` #[must_use] pub fn state(&self) -> Option> where @@ -44,13 +60,73 @@ impl MessageRequest { .and_then(|data| data.clone().downcast::().ok()) .map(SharedState) } + + /// Insert shared state of type `T` into the request. + /// + /// This replaces any existing value of the same type. + /// + /// # Examples + /// + /// ```rust + /// use wireframe::extractor::{MessageRequest, SharedState}; + /// + /// let mut req = MessageRequest::default(); + /// req.insert_state(5u32); + /// let val: Option> = req.state(); + /// assert_eq!(*val.unwrap(), 5); + /// ``` + pub fn insert_state(&mut self, state: T) + where + T: Send + Sync + 'static, + { + self.app_data.insert( + TypeId::of::(), + Arc::new(state) as Arc, + ); + } } /// Raw payload buffer handed to extractors. +/// +/// Create a `Payload` from a slice using [`Payload::new`] or `into`: +/// +/// ```rust +/// use wireframe::extractor::Payload; +/// +/// let p1 = Payload::new(b"abc"); +/// let p2: Payload<'_> = b"xyz".as_slice().into(); +/// assert_eq!(p1.as_ref(), b"abc" as &[u8]); +/// assert_eq!(p2.as_ref(), b"xyz" as &[u8]); +/// ``` #[derive(Default)] pub struct Payload<'a> { /// Incoming bytes not yet processed. - pub data: &'a [u8], + data: &'a [u8], +} + +impl<'a> Payload<'a> { + /// Creates a new `Payload` from the provided byte slice. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::extractor::Payload; + /// + /// let payload = Payload::new(b"data"); + /// assert_eq!(payload.as_ref(), b"data" as &[u8]); + /// ``` + #[must_use] + #[inline] + pub fn new(data: &'a [u8]) -> Self { Self { data } } +} + +impl<'a> From<&'a [u8]> for Payload<'a> { + #[inline] + fn from(data: &'a [u8]) -> Self { Self { data } } +} + +impl AsRef<[u8]> for Payload<'_> { + fn as_ref(&self) -> &[u8] { self.data } } impl Payload<'_> { @@ -58,12 +134,33 @@ impl Payload<'_> { /// /// Consumes up to `count` bytes from the front of the slice, ensuring we /// never slice beyond the available buffer. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::extractor::Payload; + /// + /// let mut payload = Payload::new(b"abcd"); + /// payload.advance(2); + /// assert_eq!(payload.as_ref(), b"cd" as &[u8]); + /// ``` pub fn advance(&mut self, count: usize) { let n = count.min(self.data.len()); self.data = &self.data[n..]; } /// Returns the number of bytes remaining. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::extractor::Payload; + /// + /// let mut payload = Payload::new(b"bytes"); + /// assert_eq!(payload.remaining(), 5); + /// payload.advance(2); + /// assert_eq!(payload.remaining(), 3); + /// ``` #[must_use] pub fn remaining(&self) -> usize { self.data.len() } } @@ -99,7 +196,7 @@ impl SharedState { /// /// # Examples /// - /// ```no_run + /// ```rust,no_run /// use std::sync::Arc; /// /// use wireframe::extractor::SharedState; @@ -183,7 +280,7 @@ impl std::ops::Deref for SharedState { /// /// # Examples /// - /// ```no_run + /// ```rust,no_run /// use std::sync::Arc; /// /// use wireframe::extractor::SharedState; @@ -246,6 +343,21 @@ pub struct ConnectionInfo { impl ConnectionInfo { /// Returns the peer's socket address for the current connection, if available. + /// + /// # Examples + /// + /// ```rust,no_run + /// use std::net::SocketAddr; + /// + /// use wireframe::extractor::{ConnectionInfo, FromMessageRequest, MessageRequest, Payload}; + /// + /// let req = MessageRequest { + /// peer_addr: Some("127.0.0.1:8080".parse::().unwrap()), + /// ..Default::default() + /// }; + /// let info = ConnectionInfo::from_message_request(&req, &mut Payload::default()).unwrap(); + /// assert_eq!(info.peer_addr(), req.peer_addr); + /// ``` #[must_use] pub fn peer_addr(&self) -> Option { self.peer_addr } } diff --git a/src/push.rs b/src/push.rs index aaeedef2..1d7d75c7 100644 --- a/src/push.rs +++ b/src/push.rs @@ -63,6 +63,21 @@ impl PushHandle { /// # Errors /// /// Returns [`PushError::Closed`] if the receiving end has been dropped. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::push::{PushPriority, PushQueues}; + /// + /// #[tokio::test] + /// async fn example() { + /// let (mut queues, handle) = PushQueues::bounded(1, 1); + /// handle.push_high_priority(42u8).await.unwrap(); + /// let (priority, frame) = queues.recv().await.unwrap(); + /// assert_eq!(priority, PushPriority::High); + /// assert_eq!(frame, 42); + /// } + /// ``` pub async fn push_high_priority(&self, frame: F) -> Result<(), PushError> { self.0 .high_prio_tx @@ -78,6 +93,21 @@ impl PushHandle { /// # Errors /// /// Returns [`PushError::Closed`] if the receiving end has been dropped. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::push::{PushPriority, PushQueues}; + /// + /// #[tokio::test] + /// async fn example() { + /// let (mut queues, handle) = PushQueues::bounded(1, 1); + /// handle.push_low_priority(10u8).await.unwrap(); + /// let (priority, frame) = queues.recv().await.unwrap(); + /// assert_eq!(priority, PushPriority::Low); + /// assert_eq!(frame, 10); + /// } + /// ``` pub async fn push_low_priority(&self, frame: F) -> Result<(), PushError> { self.0 .low_prio_tx @@ -93,6 +123,22 @@ impl PushHandle { /// Returns [`PushError::QueueFull`] if the queue is full and the policy is /// [`PushPolicy::ReturnErrorIfFull`]. Returns [`PushError::Closed`] if the /// receiving end has been dropped. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::push::{PushError, PushPolicy, PushPriority, PushQueues}; + /// + /// #[tokio::test] + /// async fn example() { + /// let (mut queues, handle) = PushQueues::bounded(1, 1); + /// handle.push_high_priority(1u8).await.unwrap(); + /// + /// let result = handle.try_push(2u8, PushPriority::High, PushPolicy::ReturnErrorIfFull); + /// assert!(matches!(result, Err(PushError::QueueFull))); + /// let _ = queues.recv().await; + /// } + /// ``` pub fn try_push( &self, frame: F, @@ -131,6 +177,21 @@ pub struct PushQueues { impl PushQueues { /// Create a new set of queues with the specified bounds for each priority /// and return them along with a [`PushHandle`] for producers. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::push::{PushPriority, PushQueues}; + /// + /// #[tokio::test] + /// async fn example() { + /// let (mut queues, handle) = PushQueues::::bounded(1, 1); + /// handle.push_high_priority(7u8).await.unwrap(); + /// let (priority, frame) = queues.recv().await.unwrap(); + /// assert_eq!(priority, PushPriority::High); + /// assert_eq!(frame, 7); + /// } + /// ``` #[must_use] pub fn bounded(high_capacity: usize, low_capacity: usize) -> (Self, PushHandle) { let (high_tx, high_rx) = mpsc::channel(high_capacity); @@ -151,6 +212,21 @@ impl PushQueues { /// Receive the next frame, preferring high priority frames when available. /// /// Returns `None` when both queues are closed and empty. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::push::{PushPriority, PushQueues}; + /// + /// #[tokio::test] + /// async fn example() { + /// let (mut queues, handle) = PushQueues::bounded(1, 1); + /// handle.push_high_priority(2u8).await.unwrap(); + /// let (priority, frame) = queues.recv().await.unwrap(); + /// assert_eq!(priority, PushPriority::High); + /// assert_eq!(frame, 2); + /// } + /// ``` pub async fn recv(&mut self) -> Option<(PushPriority, F)> { tokio::select! { biased; @@ -163,6 +239,15 @@ impl PushQueues { /// /// This is primarily used in tests to release resources when no actor is /// draining the queues. + /// + /// # Examples + /// + /// ```rust,no_run + /// use wireframe::push::PushQueues; + /// + /// let (mut queues, _handle) = PushQueues::::bounded(1, 1); + /// queues.close(); + /// ``` pub fn close(&mut self) { self.high_priority_rx.close(); self.low_priority_rx.close(); diff --git a/tests/app_data.rs b/tests/app_data.rs index 2000bf91..98900995 100644 --- a/tests/app_data.rs +++ b/tests/app_data.rs @@ -2,8 +2,6 @@ //! //! They verify successful extraction and error handling when state is missing. -use std::{any::TypeId, collections::HashMap, sync::Arc}; - use wireframe::extractor::{ ExtractError, FromMessageRequest, @@ -14,15 +12,8 @@ use wireframe::extractor::{ #[test] fn shared_state_extractor_returns_data() { - let mut map = HashMap::new(); - map.insert( - TypeId::of::(), - Arc::new(5u32) as Arc, - ); - let req = MessageRequest { - peer_addr: None, - app_data: map, - }; + let mut req = MessageRequest::default(); + req.insert_state(5u32); let mut payload = Payload::default(); let extracted = SharedState::::from_message_request(&req, &mut payload).unwrap(); assert_eq!(*extracted, 5); diff --git a/tests/extractor.rs b/tests/extractor.rs index 8b49774f..2f1564c3 100644 --- a/tests/extractor.rs +++ b/tests/extractor.rs @@ -2,7 +2,7 @@ //! //! Validate message parsing, connection info, and shared state behaviour. -use std::{collections::HashMap, net::SocketAddr}; +use std::net::SocketAddr; use wireframe::{ extractor::{ConnectionInfo, FromMessageRequest, Message, MessageRequest, Payload}, @@ -21,9 +21,7 @@ struct TestMsg(u8); fn message_extractor_parses_and_advances() { let msg = TestMsg(42); let bytes = msg.to_bytes().unwrap(); - let mut payload = Payload { - data: bytes.as_slice(), - }; + let mut payload = Payload::new(bytes.as_slice()); let req = MessageRequest::default(); let extracted = Message::::from_message_request(&req, &mut payload).unwrap(); @@ -35,10 +33,12 @@ fn message_extractor_parses_and_advances() { /// Tests that `ConnectionInfo` correctly reports the peer socket address extracted from a /// `MessageRequest`. fn connection_info_reports_peer() { - let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap(); + let addr: SocketAddr = "127.0.0.1:12345" + .parse() + .expect("hard-coded socket address must be valid"); let req = MessageRequest { peer_addr: Some(addr), - app_data: HashMap::default(), + ..Default::default() }; let mut payload = Payload::default(); let info = ConnectionInfo::from_message_request(&req, &mut payload).unwrap(); @@ -52,15 +52,8 @@ fn connection_info_reports_peer() { /// Inserts an `Arc` into the request's shared state, extracts it using the `SharedState` /// extractor, and asserts that the extracted value matches the original. fn shared_state_extractor() { - let mut data = HashMap::default(); - data.insert( - std::any::TypeId::of::(), - std::sync::Arc::new(42u8) as std::sync::Arc, - ); - let req = MessageRequest { - peer_addr: None, - app_data: data, - }; + let mut req = MessageRequest::default(); + req.insert_state(42u8); let mut payload = Payload::default(); let state =