diff --git a/Cargo.lock b/Cargo.lock index a78fc8f6..f2310555 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,17 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "async-trait" +version = "0.1.88" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -58,6 +69,12 @@ dependencies = [ "virtue", ] +[[package]] +name = "bytes" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" + [[package]] name = "cfg-if" version = "1.0.1" @@ -448,7 +465,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" name = "wireframe" version = "0.1.0" dependencies = [ + "async-trait", "bincode", + "bytes", "futures", "num_cpus", "serde", diff --git a/Cargo.toml b/Cargo.toml index 0f795c61..50b2132b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,8 @@ bincode = "2" tokio = { version = "1", default-features = false, features = ["net", "signal", "rt-multi-thread", "macros", "sync", "time"] } futures = "0.3" num_cpus = "^1" +async-trait = "0.1" +bytes = "1" [lints.clippy] pedantic = "warn" diff --git a/docs/roadmap.md b/docs/roadmap.md index 3c943a5e..f0296d59 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -22,7 +22,7 @@ after formatting. Line numbers below refer to that file. `WireframeApp` instance from a factory closure. A Ctrl+C signal triggers graceful shutdown, notifying all workers to stop accepting new connections. - - [ ] Standardise supporting trait definitions. + - [x] Standardize supporting trait definitions. Provide naming conventions and generic bounds for the `FrameProcessor` trait, state extractors and middleware via `async_trait` and associated types. diff --git a/docs/rust-binary-router-library-design.md b/docs/rust-binary-router-library-design.md index 3793a674..259408a2 100644 --- a/docs/rust-binary-router-library-design.md +++ b/docs/rust-binary-router-library-design.md @@ -769,13 +769,14 @@ within handlers. ````rustrust use wireframe::dev::{MessageRequest, Payload}; // Hypothetical types - use std::future::Future; pub trait FromMessageRequest: Sized { type Error: Into; // Error type if extraction fails - type Future: Future>; - fn from_message_request(req: &MessageRequest, payload: &mut Payload) -> Self::Future; + fn from_message_request( + req: &MessageRequest, + payload: &mut Payload, + ) -> Result; } ```rust diff --git a/src/extractor.rs b/src/extractor.rs new file mode 100644 index 00000000..6f31a4bd --- /dev/null +++ b/src/extractor.rs @@ -0,0 +1,127 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +/// Request context passed to extractors. +/// +/// This type contains metadata about the current connection and provides +/// access to application state registered with [`WireframeApp`]. +#[derive(Default)] +pub struct MessageRequest { + /// Address of the peer that sent the current message. + pub peer_addr: Option, +} + +/// Raw payload buffer handed to extractors. +#[derive(Default)] +pub struct Payload<'a> { + /// Incoming bytes not yet processed. + pub data: &'a [u8], +} + +impl Payload<'_> { + /// Advances the payload by `count` bytes. + /// + /// Consumes up to `count` bytes from the front of the slice, ensuring we + /// never slice beyond the available buffer. + pub fn advance(&mut self, count: usize) { + let n = count.min(self.data.len()); + self.data = &self.data[n..]; + } + + /// Returns the number of bytes remaining. + #[must_use] + pub fn remaining(&self) -> usize { + self.data.len() + } +} + +/// Trait for extracting data from a [`MessageRequest`]. +pub trait FromMessageRequest: Sized { + /// Error type returned when extraction fails. + type Error: std::error::Error + Send + Sync + 'static; + + /// Perform extraction from the request and payload. + /// + /// # Errors + /// + /// Returns an error if extraction fails. + fn from_message_request( + req: &MessageRequest, + payload: &mut Payload<'_>, + ) -> Result; +} + +/// Shared application state accessible to handlers. +#[derive(Clone)] +pub struct SharedState(Arc); + +impl SharedState { + /// Construct a new [`SharedState`]. + /// + /// # Examples + /// + /// ```ignore + /// use std::sync::Arc; + /// use wireframe::extractor::SharedState; + /// + /// let data = Arc::new(5u32); + /// let state = SharedState::new(Arc::clone(&data)); + /// assert_eq!(*state, 5); + /// ``` + #[must_use] + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn advance_consumes_bytes() { + let mut payload = Payload { data: b"hello" }; + payload.advance(2); + assert_eq!(payload.data, b"llo"); + payload.advance(10); + assert!(payload.data.is_empty()); + } + + #[test] + fn remaining_reports_length() { + let mut payload = Payload { data: b"abc" }; + assert_eq!(payload.remaining(), 3); + payload.advance(1); + assert_eq!(payload.remaining(), 2); + } +} + /// Creates a new `SharedState` instance wrapping the provided `Arc`. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// let state = Arc::new(42); + /// let shared = SharedState::new(state.clone()); + /// assert_eq!(*shared, 42); + /// ``` + pub fn new(inner: Arc) -> Self { + Self(inner) + } +} + +impl std::ops::Deref for SharedState { + type Target = T; + + /// Returns a reference to the inner shared state value. + /// + /// This allows transparent access to the underlying state managed by `SharedState`. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// let state = Arc::new(42); + /// let shared = SharedState::new(state.clone()); + /// assert_eq!(*shared, 42); + /// ``` + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/frame.rs b/src/frame.rs new file mode 100644 index 00000000..c15c3a66 --- /dev/null +++ b/src/frame.rs @@ -0,0 +1,23 @@ +use async_trait::async_trait; +use bytes::BytesMut; + +/// Trait defining how raw bytes are decoded into frames and how frames are +/// encoded back into bytes for transmission. +/// +/// The `Frame` associated type represents a logical unit extracted from or +/// written to the wire. Errors are represented by the `Error` associated type, +/// which must implement [`std::error::Error`]. +#[async_trait] +pub trait FrameProcessor: Send + Sync { + /// Logical frame type extracted from the stream. + type Frame; + + /// Error type returned by `decode` and `encode`. + type Error: std::error::Error + Send + Sync + 'static; + + /// Attempt to decode the next frame from `src`. + async fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error>; + + /// Encode `frame` and append the bytes to `dst`. + async fn encode(&mut self, frame: &Self::Frame, dst: &mut BytesMut) -> Result<(), Self::Error>; +} diff --git a/src/lib.rs b/src/lib.rs index 9c88284c..d93632f1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,6 @@ pub mod app; +pub mod extractor; +pub mod frame; pub mod message; +pub mod middleware; pub mod server; diff --git a/src/middleware.rs b/src/middleware.rs new file mode 100644 index 00000000..ba5ab521 --- /dev/null +++ b/src/middleware.rs @@ -0,0 +1,91 @@ +use async_trait::async_trait; + +/// Incoming request wrapper passed through middleware. +#[derive(Debug)] +pub struct ServiceRequest; + +/// Response produced by a handler or middleware. +#[derive(Debug, Default)] +pub struct ServiceResponse; + +/// Continuation used by middleware to call the next service in the chain. +pub struct Next<'a, S> +where + S: Service + ?Sized, +{ + service: &'a S, +} + +impl<'a, S> Next<'a, S> +where + S: Service + ?Sized, +{ + /// Creates a new `Next` instance wrapping a reference to the given service. + /// +/// +/// ```ignore +/// use wireframe::middleware::{ServiceRequest, ServiceResponse, Next, Service}; +/// ``` + /// Service produced by the middleware. + type Wrapped: Service; + async fn transform(&self, service: S) -> Self::Wrapped; + /// let service = MyService::default(); + /// let next = Next::new(&service); + type Wrapped: Service; + async fn transform(&self, service: S) -> Self::Wrapped; + Self { service } + } + + /// Call the next service with the given request. + /// + /// # Errors + /// + /// Asynchronously invokes the next service in the middleware chain with the given request. + /// + /// Returns the response from the wrapped service, or propagates any error produced. + /// + /// # Examples + /// + /// ``` + /// # use your_crate::{ServiceRequest, ServiceResponse, Next, Service}; + /// # struct DummyService; + /// # #[async_trait::async_trait] + /// # impl Service for DummyService { + /// # type Error = std::convert::Infallible; + /// # async fn call(&self, _req: ServiceRequest) -> Result { + /// # Ok(ServiceResponse::default()) + /// # } + /// # } + /// # let service = DummyService; + /// let next = Next::new(&service); + /// let req = ServiceRequest {}; + /// let res = tokio_test::block_on(next.call(req)); + /// assert!(res.is_ok()); + /// ``` + pub async fn call(&self, req: ServiceRequest) -> Result { + self.service.call(req).await + } +} + +/// Trait representing an asynchronous service. +#[async_trait] +pub trait Service: Send + Sync { + /// Error type returned by the service. + type Error: std::error::Error + Send + Sync + 'static; + + /// Handle the incoming request and produce a response. + async fn call(&self, req: ServiceRequest) -> Result; +} + +/// Factory for wrapping services with middleware. +#[async_trait] +pub trait Transform: Send + Sync +where + S: Service, +{ + /// Wrapped service produced by the middleware. + type Output: Service; + + /// Create a new middleware service wrapping `service`. + async fn transform(&self, service: S) -> Self::Output; +} diff --git a/src/server.rs b/src/server.rs index 6ac1fa3c..66219a5a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -32,34 +32,36 @@ where /// /// The server is initialised with a default worker count equal to the number of CPU cores. /// - /// # Examples + /// ```no_run + /// use wireframe::{app::WireframeApp, server::WireframeServer}; /// + /// let factory = || WireframeApp::new().unwrap(); + /// let server = WireframeServer::new(factory); /// ``` - /// let server = WireframeServer::new(|| WireframeApp::default()); - /// ``` - pub fn new(factory: F) -> Self { - Self { - factory, - listener: None, - workers: num_cpus::get(), - } - } + workers: num_cpus::get().max(1), - /// Set the number of worker tasks to spawn. - #[must_use] - /// Sets the number of worker tasks to spawn for the server. - /// - /// Ensures that at least one worker is configured, even if a lower value is provided. + /// Set the number of worker tasks to spawn for the server. + /// + /// #[tokio::main] + /// async fn main() -> std::io::Result<()> { + /// let factory = || WireframeApp::new().unwrap(); + /// WireframeServer::new(factory) + /// .workers(4) + /// .bind("127.0.0.1:0".parse().unwrap())? + /// .run() + /// .await + /// } + /// A new `WireframeServer` instance with the updated worker count. /// - /// # Parameters - /// - `count`: Desired number of worker tasks. + /// # Examples /// - /// # Returns - /// A new `WireframeServer` instance with the updated worker count. + /// ```ignore + /// let server = WireframeServer::new(factory).workers(4); + /// Sets the number of worker tasks for the server, ensuring at least one worker. /// /// # Examples /// - /// ``` + /// ```ignore /// let server = WireframeServer::new(factory).workers(4); /// ``` pub fn workers(mut self, count: usize) -> Self { @@ -85,7 +87,7 @@ where /// /// # Examples /// - /// ``` + /// ```ignore /// use std::net::SocketAddr; /// let server = WireframeServer::new(factory); /// let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap(); @@ -125,7 +127,7 @@ where /// /// # Examples /// - /// ``` + /// ```ignore /// # use std::net::SocketAddr; /// # use mycrate::{WireframeServer, WireframeApp}; /// # async fn run_server() -> std::io::Result<()> {