diff --git a/Cargo.lock b/Cargo.lock index a80969b2..cf39b728 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -178,9 +178,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "libc" -version = "0.2.172" +version = "0.2.173" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +checksum = "d8cfeafaffdbc32176b64fb251369d52ea9f0a8fbc6f8759edffef7b525d64bb" [[package]] name = "memchr" @@ -319,6 +319,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" dependencies = [ "backtrace", + "bytes", "libc", "mio", "pin-project-lite", diff --git a/Cargo.toml b/Cargo.toml index ce7ff1e9..17dd0710 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2024" [dependencies] serde = { version = "1", features = ["derive"] } bincode = "2" -tokio = { version = "1", default-features = false, features = ["net", "signal", "rt-multi-thread", "macros", "sync", "time"] } +tokio = { version = "1", default-features = false, features = ["net", "signal", "rt-multi-thread", "macros", "sync", "time", "io-util"] } futures = "0.3" async-trait = "0.1" bytes = "1" diff --git a/README.md b/README.md index cd6fea42..f22072f0 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ reduce this boilerplate through layered abstractions: - **Transport adapter** built on Tokio I/O - **Framing layer** for length‑prefixed or custom frames +- **Connection preamble** with customizable validation callbacks [[docs](docs/preamble-validator.md)] - **Serialization engine** using `bincode` or a `wire-rs` wrapper - **Routing engine** that dispatches messages by ID - **Handler invocation** with extractor support diff --git a/docs/preamble-validator.md b/docs/preamble-validator.md new file mode 100644 index 00000000..f0905fd2 --- /dev/null +++ b/docs/preamble-validator.md @@ -0,0 +1,33 @@ +# Connection Preamble Validation + +`wireframe` supports an optional connection preamble that is read as soon as a +client connects. The server decodes the preamble with +[`read_preamble`](../src/preamble.rs) and can invoke user-supplied callbacks on +success or failure. The helper uses `bincode` to decode any type implementing +`bincode::Decode` and reads exactly the number of bytes required. + +The flow is summarized below: + +```mermaid +sequenceDiagram + participant Client + participant Server + participant PreambleDecoder + participant SuccessCallback + participant FailureCallback + + Client->>Server: Connects and sends preamble bytes + Server->>PreambleDecoder: Reads and decodes preamble + alt Decode success + PreambleDecoder-->>Server: Decoded preamble (T) + Server->>SuccessCallback: Invoke with preamble data + else Decode failure + PreambleDecoder-->>Server: DecodeError + Server->>FailureCallback: Invoke with error + end + Server-->>Client: (Continues or closes connection) +``` + +In the tests, a `HotlinePreamble` struct illustrates the pattern, but any +preamble type may be used. Register callbacks via `on_preamble_decode_success` +and `on_preamble_decode_failure` on `WireframeServer`. diff --git a/docs/roadmap.md b/docs/roadmap.md index 5c6fd1ed..60919ecb 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -55,9 +55,10 @@ after formatting. Line numbers below refer to that file. } ``` -- [ ] Add connection preamble support. - Provide built-in parsing of a handshake preamble (e.g., Hotline's "TRTP") - and invoke a user-configured handler on success or failure. +- [x] Add connection preamble support. + Provide generic parsing of connection preambles with a Hotline handshake + example in the tests. Invoke user-configured callbacks on decode success + or failure. See [preamble-validator](preamble-validator.md). - [ ] Add response serialization and transmission. Encode handler responses using the selected serialization format and write them back through the framing layer. diff --git a/src/app.rs b/src/app.rs index 291e654d..122a1316 100644 --- a/src/app.rs +++ b/src/app.rs @@ -82,4 +82,16 @@ impl WireframeApp { self.middleware.push(Box::new(mw)); Ok(self) } + + /// Handle an accepted connection. + /// + /// This placeholder simply drops the stream. Future implementations + /// will decode frames and dispatch them to registered handlers. + pub async fn handle_connection(&self, _stream: S) + where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static, + { + // Connection handling will be implemented later. + tokio::task::yield_now().await; + } } diff --git a/src/lib.rs b/src/lib.rs index d93632f1..d9acc4e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,4 +3,6 @@ pub mod extractor; pub mod frame; pub mod message; pub mod middleware; +pub mod preamble; +pub mod rewind_stream; pub mod server; diff --git a/src/preamble.rs b/src/preamble.rs new file mode 100644 index 00000000..81c9bc0c --- /dev/null +++ b/src/preamble.rs @@ -0,0 +1,79 @@ +use bincode::error::DecodeError; +use bincode::{Decode, config, decode_from_slice}; +use tokio::io::{self, AsyncRead, AsyncReadExt}; + +const MAX_PREAMBLE_LEN: usize = 1024; + +async fn read_more( + reader: &mut R, + buf: &mut Vec, + additional: usize, +) -> Result<(), DecodeError> +where + R: AsyncRead + Unpin, +{ + let start = buf.len(); + if start + additional > MAX_PREAMBLE_LEN { + return Err(DecodeError::Other("preamble too long")); + } + buf.resize(start + additional, 0); + let mut read = 0; + while read < additional { + match reader + .read(&mut buf[start + read..start + additional]) + .await + { + Ok(0) => { + return Err(DecodeError::Io { + inner: io::Error::from(io::ErrorKind::UnexpectedEof), + additional: additional - read, + }); + } + Ok(n) => read += n, + Err(inner) => { + return Err(DecodeError::Io { + inner, + additional: additional - read, + }); + } + } + } + Ok(()) +} + +/// Read and decode a connection preamble using bincode. +/// +/// This helper reads the exact number of bytes required by `T`, as +/// indicated by [`DecodeError::UnexpectedEnd`]. Additional bytes are +/// requested from the reader until decoding succeeds or fails for some +/// other reason. +/// +/// # Errors +/// +/// Returns a [`DecodeError`] if decoding the preamble fails or an +/// underlying I/O error occurs while reading from `reader`. +pub async fn read_preamble(reader: &mut R) -> Result<(T, Vec), DecodeError> +where + R: AsyncRead + Unpin, + // `Decode` expects a decoding context type, not a lifetime. Most callers + // use the unit type as the context, which requires no external state. + T: Decode<()>, +{ + let mut buf = Vec::new(); + // Build the decoder configuration once to avoid repeated allocations. + let config = config::standard() + .with_big_endian() + .with_fixed_int_encoding(); + loop { + match decode_from_slice::(&buf, config) { + Ok((value, consumed)) => { + let leftover = buf.split_off(consumed); + return Ok((value, leftover)); + } + Err(DecodeError::UnexpectedEnd { additional }) => { + read_more(reader, &mut buf, additional).await?; + } + Err(e) => return Err(e), + } + } +} diff --git a/src/rewind_stream.rs b/src/rewind_stream.rs new file mode 100644 index 00000000..65a5ff76 --- /dev/null +++ b/src/rewind_stream.rs @@ -0,0 +1,72 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// A stream adapter that replays buffered bytes before reading +/// from the underlying stream. +pub struct RewindStream { + leftover: Vec, + pos: usize, + inner: S, +} + +impl RewindStream { + /// Create a new `RewindStream` that will yield `leftover` before + /// delegating to `inner`. + pub fn new(leftover: Vec, inner: S) -> Self { + Self { + leftover, + pos: 0, + inner, + } + } +} + +impl AsyncRead for RewindStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.pos < self.leftover.len() { + let remaining = self.leftover.len() - self.pos; + let to_copy = remaining.min(buf.remaining()); + let start = self.pos; + let end = start + to_copy; + buf.put_slice(&self.leftover[start..end]); + self.pos += to_copy; + if self.pos < self.leftover.len() || to_copy > 0 { + return Poll::Ready(Ok(())); + } + } + + if self.pos >= self.leftover.len() && !self.leftover.is_empty() { + self.leftover.clear(); + self.pos = 0; + } + + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for RewindStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +impl Unpin for RewindStream {} diff --git a/src/server.rs b/src/server.rs index e1e850f0..46c21f64 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,6 +6,12 @@ use tokio::net::TcpListener; use tokio::sync::broadcast; use tokio::time::{Duration, sleep}; +use core::marker::PhantomData; + +use crate::preamble::read_preamble; +use crate::rewind_stream::RewindStream; +use bincode::error::DecodeError; + use crate::app::WireframeApp; /// Tokio-based server for `WireframeApp` instances. @@ -15,16 +21,23 @@ use crate::app::WireframeApp; /// closure. The server listens for a shutdown signal using /// `tokio::signal::ctrl_c` and notifies all workers to stop /// accepting new connections. -pub struct WireframeServer +#[allow(clippy::type_complexity)] +pub struct WireframeServer where F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, + // `Decode`'s type parameter represents a decoding context. + // The unit type signals that no context is required. + T: bincode::Decode<()> + Send + 'static, { factory: F, listener: Option>, workers: usize, + on_preamble_success: Option>, + on_preamble_failure: Option>, + _preamble: PhantomData, } -impl WireframeServer +impl WireframeServer where F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, { @@ -66,9 +79,39 @@ where factory, listener: None, workers, + on_preamble_success: None, + on_preamble_failure: None, + _preamble: PhantomData, + } + } + + /// Convert this server to parse a custom preamble `T`. + /// + /// Call this before registering preamble handlers, otherwise any + /// previously configured callbacks will be dropped. + #[must_use] + pub fn with_preamble(self) -> WireframeServer + where + // Unit context indicates no external state is required when decoding. + T: bincode::Decode<()> + Send + 'static, + { + WireframeServer { + factory: self.factory, + listener: self.listener, + workers: self.workers, + on_preamble_success: None, + on_preamble_failure: None, + _preamble: PhantomData, } } +} +impl WireframeServer +where + F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, + // `Decode` is generic over a context type; we use `()` here. + T: bincode::Decode<()> + Send + 'static, +{ /// Set the number of worker tasks to spawn for the server. /// /// Ensures at least one worker is configured. @@ -102,6 +145,27 @@ where self } + /// Register a callback invoked when the connection preamble + /// decodes successfully. + #[must_use] + pub fn on_preamble_decode_success(mut self, handler: H) -> Self + where + H: Fn(&T) + Send + Sync + 'static, + { + self.on_preamble_success = Some(Arc::new(handler)); + self + } + + /// Register a callback invoked when the connection preamble fails to decode. + #[must_use] + pub fn on_preamble_decode_failure(mut self, handler: H) -> Self + where + H: Fn(&DecodeError) + Send + Sync + 'static, + { + self.on_preamble_failure = Some(Arc::new(handler)); + self + } + /// Get the configured worker count. #[inline] #[must_use] @@ -120,6 +184,12 @@ where self.workers } + /// Get the socket address the server is bound to, if available. + #[must_use] + pub fn local_addr(&self) -> Option { + self.listener.as_ref().and_then(|l| l.local_addr().ok()) + } + /// Bind the server to the given address and create a listener. /// /// # Errors @@ -193,51 +263,123 @@ where /// } /// ``` pub async fn run(self) -> io::Result<()> { + self.run_with_shutdown(async { + let _ = tokio::signal::ctrl_c().await; + }) + .await + } + + /// Run the server until the `shutdown` future resolves. + /// + /// # Errors + /// + /// Returns an [`io::Error`] if accepting a connection fails during + /// runtime. + /// + /// # Panics + /// + /// Panics if [`bind`](Self::bind) was not called beforehand. + pub async fn run_with_shutdown(self, shutdown: S) -> io::Result<()> + where + S: futures::Future + Send, + { let listener = self.listener.expect("`bind` must be called before `run`"); let (shutdown_tx, _) = broadcast::channel(16); - // Spawn worker tasks using Tokio's runtime. + // Spawn worker tasks. let mut handles = Vec::with_capacity(self.workers); for _ in 0..self.workers { - let mut shutdown_rx = shutdown_tx.subscribe(); let listener = Arc::clone(&listener); let factory = self.factory.clone(); + let on_success = self.on_preamble_success.clone(); + let on_failure = self.on_preamble_failure.clone(); + let mut shutdown_rx = shutdown_tx.subscribe(); handles.push(tokio::spawn(async move { - let app = (factory)(); - let mut delay = Duration::from_millis(10); - loop { - tokio::select! { - res = listener.accept() => match res { - Ok((_stream, _)) => { - // TODO: hand off stream to `app` - delay = Duration::from_millis(10); - } - Err(e) => { - eprintln!("accept error: {e}"); - sleep(delay).await; - delay = (delay * 2).min(Duration::from_secs(1)); - } - }, - _ = shutdown_rx.recv() => break, - } - } - drop(app); + worker_task(listener, factory, on_success, on_failure, &mut shutdown_rx).await; })); } - // Wait for Ctrl+C or workers finishing. let join_all = futures::future::join_all(handles); tokio::pin!(join_all); tokio::select! { - _ = tokio::signal::ctrl_c() => { + () = shutdown => { let _ = shutdown_tx.send(()); } _ = &mut join_all => {} } - // Ensure all workers have exited before returning. - join_all.await; + for res in join_all.await { + if let Err(e) = res { + eprintln!("worker task failed: {e}"); + } + } Ok(()) } } + +#[allow(clippy::type_complexity)] +async fn worker_task( + listener: Arc, + factory: F, + on_success: Option>, + on_failure: Option>, + shutdown_rx: &mut broadcast::Receiver<()>, +) where + F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, + // The unit context indicates no additional state is needed to decode `T`. + T: bincode::Decode<()> + Send + 'static, +{ + let mut delay = Duration::from_millis(10); + loop { + tokio::select! { + res = listener.accept() => match res { + Ok((stream, _)) => { + let success = on_success.clone(); + let failure = on_failure.clone(); + let factory = factory.clone(); + tokio::spawn(process_stream(stream, factory, success, failure)); + delay = Duration::from_millis(10); + } + Err(e) => { + eprintln!("accept error: {e}"); + sleep(delay).await; + delay = (delay * 2).min(Duration::from_secs(1)); + } + }, + _ = shutdown_rx.recv() => break, + } + } +} + +#[allow(clippy::type_complexity)] +async fn process_stream( + mut stream: tokio::net::TcpStream, + factory: F, + on_success: Option>, + on_failure: Option>, +) where + F: Fn() -> WireframeApp + Send + Sync + 'static, + // The decoding context parameter is `()`; no external state is needed. + T: bincode::Decode<()> + Send + 'static, +{ + match read_preamble::<_, T>(&mut stream).await { + Ok((preamble, leftover)) => { + if let Some(handler) = on_success.as_ref() { + handler(&preamble); + } + let stream = RewindStream::new(leftover, stream); + // Hand the connection to the application for processing. + let app = (factory)(); + tokio::spawn(async move { + app.handle_connection(stream).await; + }); + } + Err(err) => { + if let Some(handler) = on_failure.as_ref() { + handler(&err); + } + // drop stream on failure + } + } +} diff --git a/tests/preamble.rs b/tests/preamble.rs new file mode 100644 index 00000000..4446ea6f --- /dev/null +++ b/tests/preamble.rs @@ -0,0 +1,171 @@ +use bincode::error::DecodeError; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::net::TcpStream; +use tokio::sync::oneshot; +use tokio::time::{Duration, timeout}; +use wireframe::preamble::read_preamble; +use wireframe::{app::WireframeApp, server::WireframeServer}; + +#[derive(Debug, Clone, PartialEq, Eq, bincode::Encode, bincode::Decode)] +struct HotlinePreamble { + /// Should always be `b"TRTPHOTL"`. + magic: [u8; 8], + /// Minimum server version this client supports. + min_version: u16, + /// Client version. + client_version: u16, +} + +impl HotlinePreamble { + const MAGIC: [u8; 8] = *b"TRTPHOTL"; + + fn validate(&self) -> Result<(), DecodeError> { + if self.magic != Self::MAGIC { + return Err(DecodeError::Other("invalid hotline preamble")); + } + Ok(()) + } +} + +#[tokio::test] +async fn parse_valid_preamble() { + let (mut client, mut server) = duplex(64); + let bytes = b"TRTPHOTL\x00\x01\x00\x02"; + client.write_all(bytes).await.unwrap(); + client.shutdown().await.unwrap(); + let (p, _) = read_preamble::<_, HotlinePreamble>(&mut server) + .await + .expect("valid preamble"); + eprintln!("decoded: {:?}", p); + p.validate().unwrap(); + assert_eq!(p.magic, HotlinePreamble::MAGIC); + assert_eq!(p.min_version, 1); + assert_eq!(p.client_version, 2); +} + +#[tokio::test] +async fn invalid_magic_is_error() { + let (mut client, mut server) = duplex(64); + let bytes = b"WRONGMAG\x00\x01\x00\x02"; + client.write_all(bytes).await.unwrap(); + client.shutdown().await.unwrap(); + let (preamble, _) = read_preamble::<_, HotlinePreamble>(&mut server) + .await + .expect("decoded"); + assert!(preamble.validate().is_err()); +} + +#[tokio::test] +async fn server_triggers_success_callback() { + let factory = || WireframeApp::new().expect("WireframeApp::new failed"); + let (success_tx, success_rx) = tokio::sync::oneshot::channel::(); + let (failure_tx, failure_rx) = tokio::sync::oneshot::channel::<()>(); + let success_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(success_tx))); + let failure_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(failure_tx))); + let server = WireframeServer::new(factory) + .workers(1) + .with_preamble::() + .on_preamble_decode_success({ + let success_tx = success_tx.clone(); + move |p| { + if let Some(tx) = success_tx.lock().unwrap().take() { + let _ = tx.send(p.clone()); + } + } + }) + .on_preamble_decode_failure({ + let failure_tx = failure_tx.clone(); + move |_| { + if let Some(tx) = failure_tx.lock().unwrap().take() { + let _ = tx.send(()); + } + } + }); + let server = server.bind("127.0.0.1:0".parse().unwrap()).expect("bind"); + let addr = server.local_addr().expect("addr"); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let handle = tokio::spawn(async move { + server + .run_with_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + + let mut stream = TcpStream::connect(addr).await.unwrap(); + let bytes = b"TRTPHOTL\x00\x01\x00\x02"; + stream.write_all(bytes).await.unwrap(); + stream.shutdown().await.unwrap(); + + let preamble = timeout(Duration::from_secs(1), success_rx) + .await + .expect("timeout waiting for success") + .expect("success send"); + assert_eq!(preamble.magic, HotlinePreamble::MAGIC); + assert!( + timeout(Duration::from_millis(100), failure_rx) + .await + .is_err() + ); + + let _ = shutdown_tx.send(()); + handle.await.unwrap(); +} + +#[tokio::test] +async fn server_triggers_failure_callback() { + let factory = || WireframeApp::new().expect("WireframeApp::new failed"); + let (success_tx, success_rx) = tokio::sync::oneshot::channel::(); + let (failure_tx, failure_rx) = tokio::sync::oneshot::channel::<()>(); + let success_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(success_tx))); + let failure_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(failure_tx))); + let server = WireframeServer::new(factory) + .workers(1) + .with_preamble::() + .on_preamble_decode_success({ + let success_tx = success_tx.clone(); + move |p| { + if let Some(tx) = success_tx.lock().unwrap().take() { + let _ = tx.send(p.clone()); + } + } + }) + .on_preamble_decode_failure({ + let failure_tx = failure_tx.clone(); + move |_| { + if let Some(tx) = failure_tx.lock().unwrap().take() { + let _ = tx.send(()); + } + } + }); + let server = server.bind("127.0.0.1:0".parse().unwrap()).expect("bind"); + let addr = server.local_addr().expect("addr"); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let handle = tokio::spawn(async move { + server + .run_with_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + + let mut stream = TcpStream::connect(addr).await.unwrap(); + let bytes = b"TRTPHOT"; // truncated + stream.write_all(bytes).await.unwrap(); + stream.shutdown().await.unwrap(); + + timeout(Duration::from_secs(1), failure_rx) + .await + .expect("timeout waiting for failure") + .expect("failure send"); + assert!( + timeout(Duration::from_millis(100), success_rx) + .await + .is_err() + ); + + let _ = shutdown_tx.send(()); + handle.await.unwrap(); +}