Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions src/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ pub trait FromMessageRequest: Sized {
pub struct SharedState<T: Send + Sync>(Arc<T>);

impl<T: Send + Sync> SharedState<T> {
/// Construct a new [`SharedState`].
/// Creates a new [`SharedState`] instance wrapping the provided `Arc<T>`.
///
/// # Examples
///
Expand All @@ -109,19 +109,6 @@ impl<T: Send + Sync> SharedState<T> {
/// assert_eq!(*state, 5);
/// ```
#[must_use]
/// Creates a new `SharedState` instance wrapping the provided `Arc<T>`.
///
/// # Examples
///
/// ```no_run
/// use std::sync::Arc;
///
/// use wireframe::extractor::SharedState;
///
/// let state = Arc::new(42);
/// let shared: SharedState<u32> = state.clone().into();
/// assert_eq!(*shared, 42);
/// ```
#[deprecated(since = "0.2.0", note = "construct via `inner.into()` instead")]
pub fn new(inner: Arc<T>) -> Self { Self(inner) }
}
Expand All @@ -148,6 +135,9 @@ pub enum ExtractError {
}

impl std::fmt::Display for ExtractError {
/// Formats the `ExtractError` for display purposes.
///
/// Displays a descriptive message for missing shared state or payload decoding errors.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::MissingState(ty) => write!(f, "no shared state registered for {ty}"),
Expand All @@ -157,6 +147,10 @@ impl std::fmt::Display for ExtractError {
}

impl std::error::Error for ExtractError {
/// Returns the underlying error if this is an `InvalidPayload` variant.
///
/// # Returns
/// An optional reference to the underlying decode error, or `None` if not applicable.
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::InvalidPayload(e) => Some(e),
Expand Down Expand Up @@ -206,14 +200,17 @@ impl<T: Send + Sync> std::ops::Deref for SharedState<T> {
pub struct Message<T>(T);

impl<T> Message<T> {
/// Consume the extractor, returning the inner message.
/// Consumes the extractor and returns the inner deserialised message value.
#[must_use]
pub fn into_inner(self) -> T { self.0 }
}

impl<T> std::ops::Deref for Message<T> {
type Target = T;

/// Returns a reference to the inner value.
///
/// This enables transparent access to the wrapped type via dereferencing.
fn deref(&self) -> &Self::Target { &self.0 }
}

Expand All @@ -223,6 +220,14 @@ where
{
type Error = ExtractError;

/// Attempts to extract and deserialize a message of type `T` from the payload.
///
/// Advances the payload by the number of bytes consumed during deserialization.
/// Returns an error if the payload cannot be decoded into the target type.
///
/// # Returns
/// - `Ok(Self)`: The successfully extracted and deserialized message.
/// - `Err(ExtractError::InvalidPayload)`: If deserialization fails.
fn from_message_request(
_req: &MessageRequest,
payload: &mut Payload<'_>,
Expand All @@ -240,14 +245,18 @@ pub struct ConnectionInfo {
}

impl ConnectionInfo {
/// Returns the peer's socket address, if known.
/// Returns the peer's socket address for the current connection, if available.
#[must_use]
pub fn peer_addr(&self) -> Option<SocketAddr> { self.peer_addr }
}

impl FromMessageRequest for ConnectionInfo {
type Error = std::convert::Infallible;

/// Extracts connection metadata from the message request.
///
/// Returns a `ConnectionInfo` containing the peer's socket address, if available. This
/// extraction is infallible.
fn from_message_request(
req: &MessageRequest,
_payload: &mut Payload<'_>,
Expand Down
7 changes: 3 additions & 4 deletions src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@ impl<'a, S> Next<'a, S>
where
S: Service + ?Sized,
{
/// Create a new [`Next`] wrapping the given service.
#[inline]
#[must_use]
/// Creates a new `Next` instance wrapping a reference to the given service.
/// Creates a new [`Next`] instance wrapping a reference to the given service.
///
/// # Examples
///
Expand All @@ -46,6 +43,8 @@ where
/// let service = MyService;
/// let next = Next::new(&service);
/// ```
#[inline]
#[must_use]
pub fn new(service: &'a S) -> Self { Self { service } }

/// Call the next service with the provided request.
Expand Down
5 changes: 2 additions & 3 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,6 @@ where
self
}

/// Get the configured worker count.
#[inline]
#[must_use]
/// Returns the configured number of worker tasks for the server.
///
/// # Examples
Expand All @@ -197,6 +194,8 @@ where
/// let server = WireframeServer::new(factory);
/// assert!(server.worker_count() >= 1);
/// ```
#[inline]
#[must_use]
pub const fn worker_count(&self) -> usize { self.workers }

/// Get the socket address the server is bound to, if available.
Expand Down
17 changes: 17 additions & 0 deletions tests/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ use wireframe::{
struct TestMsg(u8);

#[test]
/// Tests that a message can be extracted from a payload and that the payload cursor advances fully.
///
/// Verifies that a `TestMsg` instance serialised into bytes can be correctly extracted from a
/// `Payload` using `Message::<TestMsg>::from_message_request`, and asserts that the payload has no
/// remaining unread data after extraction.
fn message_extractor_parses_and_advances() {
let msg = TestMsg(42);
let bytes = msg.to_bytes().unwrap();
Expand All @@ -23,6 +28,8 @@ fn message_extractor_parses_and_advances() {
}

#[test]
/// Tests that `ConnectionInfo` correctly reports the peer socket address extracted from a
/// `MessageRequest`.
fn connection_info_reports_peer() {
let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();
let req = MessageRequest {
Expand All @@ -35,6 +42,11 @@ fn connection_info_reports_peer() {
}

#[test]
/// Tests that shared state of type `u8` can be successfully extracted from a `MessageRequest`'s
/// `app_data`.
///
/// Inserts an `Arc<u8>` into the request's shared state, extracts it using the `SharedState`
/// extractor, and asserts that the extracted value matches the original.
fn shared_state_extractor() {
let mut data = HashMap::default();
data.insert(
Expand All @@ -53,6 +65,11 @@ fn shared_state_extractor() {
}

#[test]
/// Tests that extracting a missing shared state from a `MessageRequest`
/// returns an `ExtractError::MissingState` containing the type name.
///
/// Ensures that when no shared state of the requested type is present,
/// the correct error is produced and includes the expected type information.
fn shared_state_missing_error() {
let req = MessageRequest::default();
let mut payload = Payload::default();
Expand Down