From b837fc70f251cf9b734941835bf38c033fddd726 Mon Sep 17 00:00:00 2001 From: Leynos Date: Sat, 14 Jun 2025 04:08:16 +0100 Subject: [PATCH 01/10] Add tests for preamble callbacks --- Cargo.lock | 2 + Cargo.toml | 5 +- README.md | 1 + docs/roadmap.md | 7 ++- src/lib.rs | 1 + src/preamble.rs | 39 ++++++++++++ src/server.rs | 83 ++++++++++++++++++++++++- tests/preamble.rs | 151 ++++++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 282 insertions(+), 7 deletions(-) create mode 100644 src/preamble.rs create mode 100644 tests/preamble.rs diff --git a/Cargo.lock b/Cargo.lock index a80969b2..0a3e9bfd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -319,6 +319,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" dependencies = [ "backtrace", + "bytes", "libc", "mio", "pin-project-lite", @@ -453,6 +454,7 @@ dependencies = [ "bincode", "bytes", "futures", + "libc", "serde", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index ce7ff1e9..79e6afcf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,10 +6,13 @@ edition = "2024" [dependencies] serde = { version = "1", features = ["derive"] } bincode = "2" -tokio = { version = "1", default-features = false, features = ["net", "signal", "rt-multi-thread", "macros", "sync", "time"] } +tokio = { version = "1", default-features = false, features = ["net", "signal", "rt-multi-thread", "macros", "sync", "time", "io-util"] } futures = "0.3" async-trait = "0.1" bytes = "1" +[dev-dependencies] +libc = "0.2" + [lints.clippy] pedantic = "warn" diff --git a/README.md b/README.md index cd6fea42..c5daef0d 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ reduce this boilerplate through layered abstractions: - **Routing engine** that dispatches messages by ID - **Handler invocation** with extractor support - **Middleware chain** for request/response processing +- **Connection preamble** with customizable validation callbacks These layers correspond to the architecture outlined in the design document【F:docs/rust-binary-router-library-design.md†L292-L344】. diff --git a/docs/roadmap.md b/docs/roadmap.md index 5c6fd1ed..8b22212c 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -55,9 +55,10 @@ after formatting. Line numbers below refer to that file. } ``` -- [ ] Add connection preamble support. - Provide built-in parsing of a handshake preamble (e.g., Hotline's "TRTP") - and invoke a user-configured handler on success or failure. +- [x] Add connection preamble support. + Provide generic parsing of connection preambles with a Hotline handshake + example in the tests. Invoke user-configured callbacks on decode success + or failure. - [ ] Add response serialization and transmission. Encode handler responses using the selected serialization format and write them back through the framing layer. diff --git a/src/lib.rs b/src/lib.rs index d93632f1..b0436d66 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,4 +3,5 @@ pub mod extractor; pub mod frame; pub mod message; pub mod middleware; +pub mod preamble; pub mod server; diff --git a/src/preamble.rs b/src/preamble.rs new file mode 100644 index 00000000..05140df5 --- /dev/null +++ b/src/preamble.rs @@ -0,0 +1,39 @@ +use bincode::error::DecodeError; +use bincode::{Decode, config, decode_from_slice}; +use tokio::io::{AsyncRead, AsyncReadExt}; + +/// Read and decode a connection preamble using bincode. +/// +/// This helper reads the exact number of bytes required by `T`, as +/// indicated by [`DecodeError::UnexpectedEnd`]. Additional bytes are +/// requested from the reader until decoding succeeds or fails for some +/// other reason. +/// +/// # Errors +/// +/// Returns a [`DecodeError`] if decoding the preamble fails or an +/// underlying I/O error occurs while reading from `reader`. +pub async fn read_preamble(reader: &mut R) -> Result +where + R: AsyncRead + Unpin, + T: Decode<()>, +{ + let mut buf = Vec::new(); + let config = config::standard() + .with_big_endian() + .with_fixed_int_encoding(); + loop { + match decode_from_slice::(&buf, config) { + Ok((value, _)) => return Ok(value), + Err(DecodeError::UnexpectedEnd { additional }) => { + let start = buf.len(); + buf.resize(start + additional, 0); + reader + .read_exact(&mut buf[start..]) + .await + .map_err(|inner| DecodeError::Io { inner, additional })?; + } + Err(e) => return Err(e), + } + } +} diff --git a/src/server.rs b/src/server.rs index e1e850f0..d8685aaa 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,6 +6,11 @@ use tokio::net::TcpListener; use tokio::sync::broadcast; use tokio::time::{Duration, sleep}; +use core::marker::PhantomData; + +use crate::preamble::read_preamble; +use bincode::error::DecodeError; + use crate::app::WireframeApp; /// Tokio-based server for `WireframeApp` instances. @@ -15,16 +20,21 @@ use crate::app::WireframeApp; /// closure. The server listens for a shutdown signal using /// `tokio::signal::ctrl_c` and notifies all workers to stop /// accepting new connections. -pub struct WireframeServer +#[allow(clippy::type_complexity)] +pub struct WireframeServer where F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, + T: bincode::Decode<()> + Send + 'static, { factory: F, listener: Option>, workers: usize, + on_preamble_success: Option>, + on_preamble_failure: Option>, + _preamble: PhantomData, } -impl WireframeServer +impl WireframeServer where F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, { @@ -66,9 +76,34 @@ where factory, listener: None, workers, + on_preamble_success: None, + on_preamble_failure: None, + _preamble: PhantomData, } } + /// Convert this server to parse a custom preamble `T`. + #[must_use] + pub fn with_preamble(self) -> WireframeServer + where + T: bincode::Decode<()> + Send + 'static, + { + WireframeServer { + factory: self.factory, + listener: self.listener, + workers: self.workers, + on_preamble_success: None, + on_preamble_failure: None, + _preamble: PhantomData, + } + } +} + +impl WireframeServer +where + F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, + T: bincode::Decode<()> + Send + 'static, +{ /// Set the number of worker tasks to spawn for the server. /// /// Ensures at least one worker is configured. @@ -102,6 +137,27 @@ where self } + /// Register a callback invoked when the connection preamble + /// decodes successfully. + #[must_use] + pub fn on_preamble_decode_success(mut self, handler: H) -> Self + where + H: Fn(&T) + Send + Sync + 'static, + { + self.on_preamble_success = Some(Arc::new(handler)); + self + } + + /// Register a callback invoked when the connection preamble fails to decode. + #[must_use] + pub fn on_preamble_decode_failure(mut self, handler: H) -> Self + where + H: Fn(DecodeError) + Send + Sync + 'static, + { + self.on_preamble_failure = Some(Arc::new(handler)); + self + } + /// Get the configured worker count. #[inline] #[must_use] @@ -120,6 +176,12 @@ where self.workers } + /// Get the socket address the server is bound to, if available. + #[must_use] + pub fn local_addr(&self) -> Option { + self.listener.as_ref().and_then(|l| l.local_addr().ok()) + } + /// Bind the server to the given address and create a listener. /// /// # Errors @@ -202,13 +264,28 @@ where let mut shutdown_rx = shutdown_tx.subscribe(); let listener = Arc::clone(&listener); let factory = self.factory.clone(); + let on_success = self.on_preamble_success.clone(); + let on_failure = self.on_preamble_failure.clone(); handles.push(tokio::spawn(async move { let app = (factory)(); let mut delay = Duration::from_millis(10); loop { tokio::select! { res = listener.accept() => match res { - Ok((_stream, _)) => { + Ok((mut stream, _)) => { + let result = read_preamble::<_, T>(&mut stream).await; + match result { + Ok(preamble) => { + if let Some(handler) = on_success.as_ref() { + handler(&preamble); + } + } + Err(err) => { + if let Some(handler) = on_failure.as_ref() { + handler(err); + } + } + } // TODO: hand off stream to `app` delay = Duration::from_millis(10); } diff --git a/tests/preamble.rs b/tests/preamble.rs new file mode 100644 index 00000000..c7f39057 --- /dev/null +++ b/tests/preamble.rs @@ -0,0 +1,151 @@ +use bincode::error::DecodeError; +use libc; +use tokio::io::{AsyncWriteExt, duplex}; +use tokio::net::TcpStream; +use tokio::time::{Duration, timeout}; +use wireframe::preamble::read_preamble; +use wireframe::{app::WireframeApp, server::WireframeServer}; + +#[derive(Debug, Clone, PartialEq, Eq, bincode::Encode, bincode::Decode)] +struct HotlinePreamble { + /// Should always be `b"TRTPHOTL"`. + magic: [u8; 8], + /// Minimum server version this client supports. + min_version: u16, + /// Client version. + client_version: u16, +} + +impl HotlinePreamble { + const MAGIC: [u8; 8] = *b"TRTPHOTL"; + + fn validate(&self) -> Result<(), DecodeError> { + if self.magic != Self::MAGIC { + return Err(DecodeError::Other("invalid hotline preamble")); + } + Ok(()) + } +} + +#[tokio::test] +async fn parse_valid_preamble() { + let (mut client, mut server) = duplex(64); + let bytes = b"TRTPHOTL\x00\x01\x00\x02"; + client.write_all(bytes).await.unwrap(); + client.shutdown().await.unwrap(); + let p: HotlinePreamble = read_preamble(&mut server).await.expect("valid preamble"); + eprintln!("decoded: {:?}", p); + p.validate().unwrap(); + assert_eq!(p.magic, HotlinePreamble::MAGIC); + assert_eq!(p.min_version, 1); + assert_eq!(p.client_version, 2); +} + +#[tokio::test] +async fn invalid_magic_is_error() { + let (mut client, mut server) = duplex(64); + let bytes = b"WRONGMAG\x00\x01\x00\x02"; + client.write_all(bytes).await.unwrap(); + client.shutdown().await.unwrap(); + let preamble: HotlinePreamble = read_preamble(&mut server).await.expect("decoded"); + assert!(preamble.validate().is_err()); +} + +#[tokio::test] +async fn server_triggers_success_callback() { + let factory = || WireframeApp::new().expect("WireframeApp::new failed"); + let (success_tx, success_rx) = tokio::sync::oneshot::channel::(); + let (failure_tx, failure_rx) = tokio::sync::oneshot::channel::<()>(); + let success_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(success_tx))); + let failure_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(failure_tx))); + let server = WireframeServer::new(factory) + .workers(1) + .with_preamble::() + .on_preamble_decode_success({ + let success_tx = success_tx.clone(); + move |p| { + if let Some(tx) = success_tx.lock().unwrap().take() { + let _ = tx.send(p.clone()); + } + } + }) + .on_preamble_decode_failure({ + let failure_tx = failure_tx.clone(); + move |_| { + if let Some(tx) = failure_tx.lock().unwrap().take() { + let _ = tx.send(()); + } + } + }); + let server = server.bind("127.0.0.1:0".parse().unwrap()).expect("bind"); + let addr = server.local_addr().expect("addr"); + let handle = tokio::spawn(async move { server.run().await.unwrap() }); + + let mut stream = TcpStream::connect(addr).await.unwrap(); + let bytes = b"TRTPHOTL\x00\x01\x00\x02"; + stream.write_all(bytes).await.unwrap(); + stream.shutdown().await.unwrap(); + + let preamble = timeout(Duration::from_secs(1), success_rx) + .await + .expect("timeout waiting for success") + .expect("success send"); + assert_eq!(preamble.magic, HotlinePreamble::MAGIC); + assert!( + timeout(Duration::from_millis(100), failure_rx) + .await + .is_err() + ); + + unsafe { libc::raise(libc::SIGINT) }; + handle.await.unwrap(); +} + +#[tokio::test] +async fn server_triggers_failure_callback() { + let factory = || WireframeApp::new().expect("WireframeApp::new failed"); + let (success_tx, success_rx) = tokio::sync::oneshot::channel::(); + let (failure_tx, failure_rx) = tokio::sync::oneshot::channel::<()>(); + let success_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(success_tx))); + let failure_tx = std::sync::Arc::new(std::sync::Mutex::new(Some(failure_tx))); + let server = WireframeServer::new(factory) + .workers(1) + .with_preamble::() + .on_preamble_decode_success({ + let success_tx = success_tx.clone(); + move |p| { + if let Some(tx) = success_tx.lock().unwrap().take() { + let _ = tx.send(p.clone()); + } + } + }) + .on_preamble_decode_failure({ + let failure_tx = failure_tx.clone(); + move |_| { + if let Some(tx) = failure_tx.lock().unwrap().take() { + let _ = tx.send(()); + } + } + }); + let server = server.bind("127.0.0.1:0".parse().unwrap()).expect("bind"); + let addr = server.local_addr().expect("addr"); + let handle = tokio::spawn(async move { server.run().await.unwrap() }); + + let mut stream = TcpStream::connect(addr).await.unwrap(); + let bytes = b"TRTPHOT"; // truncated + stream.write_all(bytes).await.unwrap(); + stream.shutdown().await.unwrap(); + + timeout(Duration::from_secs(1), failure_rx) + .await + .expect("timeout waiting for failure") + .expect("failure send"); + assert!( + timeout(Duration::from_millis(100), success_rx) + .await + .is_err() + ); + + unsafe { libc::raise(libc::SIGINT) }; + handle.await.unwrap(); +} From 5e5c3ec3e407362684c49fb0c33f1807ed325ef6 Mon Sep 17 00:00:00 2001 From: Leynos Date: Sat, 14 Jun 2025 04:39:33 +0100 Subject: [PATCH 02/10] Document preamble callbacks --- README.md | 2 +- docs/preamble-validator.md | 33 +++++++++++++++++++++++++++++++++ docs/roadmap.md | 2 +- 3 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 docs/preamble-validator.md diff --git a/README.md b/README.md index c5daef0d..e84b5cc2 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ reduce this boilerplate through layered abstractions: - **Routing engine** that dispatches messages by ID - **Handler invocation** with extractor support - **Middleware chain** for request/response processing -- **Connection preamble** with customizable validation callbacks +- **Connection preamble** with customizable validation callbacks [[docs](docs/preamble-validator.md)] These layers correspond to the architecture outlined in the design document【F:docs/rust-binary-router-library-design.md†L292-L344】. diff --git a/docs/preamble-validator.md b/docs/preamble-validator.md new file mode 100644 index 00000000..3ec99a54 --- /dev/null +++ b/docs/preamble-validator.md @@ -0,0 +1,33 @@ +# Connection Preamble Validation + +`wireframe` supports an optional connection preamble that is read as soon as a +client connects. The server decodes the preamble with +[`read_preamble`](../src/preamble.rs) and can invoke user-supplied callbacks on +success or failure. The helper uses `bincode` to decode any type implementing +`bincode::Decode` and reads exactly the number of bytes required. + +The flow is summarised below: + +```mermaid +sequenceDiagram + participant Client + participant Server + participant PreambleDecoder + participant SuccessCallback + participant FailureCallback + + Client->>Server: Connects and sends preamble bytes + Server->>PreambleDecoder: Reads and decodes preamble + alt Decode success + PreambleDecoder-->>Server: Decoded preamble (T) + Server->>SuccessCallback: Invoke with preamble data + else Decode failure + PreambleDecoder-->>Server: DecodeError + Server->>FailureCallback: Invoke with error + end + Server-->>Client: (Continues or closes connection) +``` + +In the tests a `HotlinePreamble` struct illustrates the pattern, but any +preamble type may be used. Register callbacks via `on_preamble_decode_success` +and `on_preamble_decode_failure` on `WireframeServer`. diff --git a/docs/roadmap.md b/docs/roadmap.md index 8b22212c..60919ecb 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -58,7 +58,7 @@ after formatting. Line numbers below refer to that file. - [x] Add connection preamble support. Provide generic parsing of connection preambles with a Hotline handshake example in the tests. Invoke user-configured callbacks on decode success - or failure. + or failure. See [preamble-validator](preamble-validator.md). - [ ] Add response serialization and transmission. Encode handler responses using the selected serialization format and write them back through the framing layer. From 1ef9d8227c1bd1b0888128d09b58d452c0209758 Mon Sep 17 00:00:00 2001 From: Leynos Date: Sat, 14 Jun 2025 04:39:39 +0100 Subject: [PATCH 03/10] Refine preamble handling and tests --- Cargo.lock | 5 +- Cargo.toml | 3 -- README.md | 2 +- src/preamble.rs | 39 +++++++++++--- src/server.rs | 131 ++++++++++++++++++++++++++++++++-------------- tests/preamble.rs | 34 +++++++++--- 6 files changed, 154 insertions(+), 60 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0a3e9bfd..cf39b728 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -178,9 +178,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "libc" -version = "0.2.172" +version = "0.2.173" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +checksum = "d8cfeafaffdbc32176b64fb251369d52ea9f0a8fbc6f8759edffef7b525d64bb" [[package]] name = "memchr" @@ -454,7 +454,6 @@ dependencies = [ "bincode", "bytes", "futures", - "libc", "serde", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 79e6afcf..17dd0710 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,8 +11,5 @@ futures = "0.3" async-trait = "0.1" bytes = "1" -[dev-dependencies] -libc = "0.2" - [lints.clippy] pedantic = "warn" diff --git a/README.md b/README.md index e84b5cc2..f22072f0 100644 --- a/README.md +++ b/README.md @@ -13,11 +13,11 @@ reduce this boilerplate through layered abstractions: - **Transport adapter** built on Tokio I/O - **Framing layer** for length‑prefixed or custom frames +- **Connection preamble** with customizable validation callbacks [[docs](docs/preamble-validator.md)] - **Serialization engine** using `bincode` or a `wire-rs` wrapper - **Routing engine** that dispatches messages by ID - **Handler invocation** with extractor support - **Middleware chain** for request/response processing -- **Connection preamble** with customizable validation callbacks [[docs](docs/preamble-validator.md)] These layers correspond to the architecture outlined in the design document【F:docs/rust-binary-router-library-design.md†L292-L344】. diff --git a/src/preamble.rs b/src/preamble.rs index 05140df5..57cad35c 100644 --- a/src/preamble.rs +++ b/src/preamble.rs @@ -1,6 +1,8 @@ use bincode::error::DecodeError; use bincode::{Decode, config, decode_from_slice}; -use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::io::{self, AsyncRead, AsyncReadExt}; + +const MAX_PREAMBLE_LEN: usize = 1024; /// Read and decode a connection preamble using bincode. /// @@ -13,7 +15,7 @@ use tokio::io::{AsyncRead, AsyncReadExt}; /// /// Returns a [`DecodeError`] if decoding the preamble fails or an /// underlying I/O error occurs while reading from `reader`. -pub async fn read_preamble(reader: &mut R) -> Result +pub async fn read_preamble(reader: &mut R) -> Result<(T, Vec), DecodeError> where R: AsyncRead + Unpin, T: Decode<()>, @@ -24,14 +26,37 @@ where .with_fixed_int_encoding(); loop { match decode_from_slice::(&buf, config) { - Ok((value, _)) => return Ok(value), + Ok((value, consumed)) => { + let leftover = buf.split_off(consumed); + return Ok((value, leftover)); + } Err(DecodeError::UnexpectedEnd { additional }) => { let start = buf.len(); + if start + additional > MAX_PREAMBLE_LEN { + return Err(DecodeError::Other("preamble too long")); + } buf.resize(start + additional, 0); - reader - .read_exact(&mut buf[start..]) - .await - .map_err(|inner| DecodeError::Io { inner, additional })?; + let mut read = 0; + while read < additional { + match reader + .read(&mut buf[start + read..start + additional]) + .await + { + Ok(0) => { + return Err(DecodeError::Io { + inner: io::Error::from(io::ErrorKind::UnexpectedEof), + additional: additional - read, + }); + } + Ok(n) => read += n, + Err(inner) => { + return Err(DecodeError::Io { + inner, + additional: additional - read, + }); + } + } + } } Err(e) => return Err(e), } diff --git a/src/server.rs b/src/server.rs index d8685aaa..b42e8389 100644 --- a/src/server.rs +++ b/src/server.rs @@ -30,7 +30,7 @@ where listener: Option>, workers: usize, on_preamble_success: Option>, - on_preamble_failure: Option>, + on_preamble_failure: Option>, _preamble: PhantomData, } @@ -83,6 +83,9 @@ where } /// Convert this server to parse a custom preamble `T`. + /// + /// Call this before registering preamble handlers, otherwise any + /// previously configured callbacks will be dropped. #[must_use] pub fn with_preamble(self) -> WireframeServer where @@ -152,7 +155,7 @@ where #[must_use] pub fn on_preamble_decode_failure(mut self, handler: H) -> Self where - H: Fn(DecodeError) + Send + Sync + 'static, + H: Fn(&DecodeError) + Send + Sync + 'static, { self.on_preamble_failure = Some(Arc::new(handler)); self @@ -255,66 +258,116 @@ where /// } /// ``` pub async fn run(self) -> io::Result<()> { + self.run_with_shutdown(async { + let _ = tokio::signal::ctrl_c().await; + }) + .await + } + + /// Run the server until the `shutdown` future resolves. + /// + /// # Errors + /// + /// Returns an [`io::Error`] if accepting a connection fails during + /// runtime. + /// + /// # Panics + /// + /// Panics if [`bind`](Self::bind) was not called beforehand. + pub async fn run_with_shutdown(self, shutdown: S) -> io::Result<()> + where + S: futures::Future + Send, + { let listener = self.listener.expect("`bind` must be called before `run`"); let (shutdown_tx, _) = broadcast::channel(16); - // Spawn worker tasks using Tokio's runtime. + // Spawn worker tasks. let mut handles = Vec::with_capacity(self.workers); for _ in 0..self.workers { - let mut shutdown_rx = shutdown_tx.subscribe(); let listener = Arc::clone(&listener); let factory = self.factory.clone(); let on_success = self.on_preamble_success.clone(); let on_failure = self.on_preamble_failure.clone(); + let mut shutdown_rx = shutdown_tx.subscribe(); handles.push(tokio::spawn(async move { - let app = (factory)(); - let mut delay = Duration::from_millis(10); - loop { - tokio::select! { - res = listener.accept() => match res { - Ok((mut stream, _)) => { - let result = read_preamble::<_, T>(&mut stream).await; - match result { - Ok(preamble) => { - if let Some(handler) = on_success.as_ref() { - handler(&preamble); - } - } - Err(err) => { - if let Some(handler) = on_failure.as_ref() { - handler(err); - } - } - } - // TODO: hand off stream to `app` - delay = Duration::from_millis(10); - } - Err(e) => { - eprintln!("accept error: {e}"); - sleep(delay).await; - delay = (delay * 2).min(Duration::from_secs(1)); - } - }, - _ = shutdown_rx.recv() => break, - } - } - drop(app); + worker_task( + listener, + factory, + on_success.as_ref(), + on_failure.as_ref(), + &mut shutdown_rx, + ) + .await; })); } - // Wait for Ctrl+C or workers finishing. let join_all = futures::future::join_all(handles); tokio::pin!(join_all); tokio::select! { - _ = tokio::signal::ctrl_c() => { + () = shutdown => { let _ = shutdown_tx.send(()); } _ = &mut join_all => {} } - // Ensure all workers have exited before returning. join_all.await; Ok(()) } } + +#[allow(clippy::type_complexity)] +async fn worker_task( + listener: Arc, + factory: F, + on_success: Option<&Arc>, + on_failure: Option<&Arc>, + shutdown_rx: &mut broadcast::Receiver<()>, +) where + F: Fn() -> WireframeApp, + T: bincode::Decode<()> + Send + 'static, +{ + let app = (factory)(); + let mut delay = Duration::from_millis(10); + loop { + tokio::select! { + res = listener.accept() => match res { + Ok((stream, _)) => { + process_stream(stream, on_success, on_failure).await; + delay = Duration::from_millis(10); + } + Err(e) => { + eprintln!("accept error: {e}"); + sleep(delay).await; + delay = (delay * 2).min(Duration::from_secs(1)); + } + }, + _ = shutdown_rx.recv() => break, + } + } + drop(app); +} + +#[allow(clippy::type_complexity)] +async fn process_stream( + mut stream: tokio::net::TcpStream, + on_success: Option<&Arc>, + on_failure: Option<&Arc>, +) where + T: bincode::Decode<()> + Send + 'static, +{ + match read_preamble::<_, T>(&mut stream).await { + Ok((preamble, _)) => { + if let Some(handler) = on_success.as_ref() { + handler(&preamble); + } + // TODO: hand off stream to application + } + Err(err) => { + if let Some(handler) = on_failure.as_ref() { + handler(&err); + } + // drop stream on failure + } + } +} diff --git a/tests/preamble.rs b/tests/preamble.rs index c7f39057..4446ea6f 100644 --- a/tests/preamble.rs +++ b/tests/preamble.rs @@ -1,7 +1,7 @@ use bincode::error::DecodeError; -use libc; use tokio::io::{AsyncWriteExt, duplex}; use tokio::net::TcpStream; +use tokio::sync::oneshot; use tokio::time::{Duration, timeout}; use wireframe::preamble::read_preamble; use wireframe::{app::WireframeApp, server::WireframeServer}; @@ -33,7 +33,9 @@ async fn parse_valid_preamble() { let bytes = b"TRTPHOTL\x00\x01\x00\x02"; client.write_all(bytes).await.unwrap(); client.shutdown().await.unwrap(); - let p: HotlinePreamble = read_preamble(&mut server).await.expect("valid preamble"); + let (p, _) = read_preamble::<_, HotlinePreamble>(&mut server) + .await + .expect("valid preamble"); eprintln!("decoded: {:?}", p); p.validate().unwrap(); assert_eq!(p.magic, HotlinePreamble::MAGIC); @@ -47,7 +49,9 @@ async fn invalid_magic_is_error() { let bytes = b"WRONGMAG\x00\x01\x00\x02"; client.write_all(bytes).await.unwrap(); client.shutdown().await.unwrap(); - let preamble: HotlinePreamble = read_preamble(&mut server).await.expect("decoded"); + let (preamble, _) = read_preamble::<_, HotlinePreamble>(&mut server) + .await + .expect("decoded"); assert!(preamble.validate().is_err()); } @@ -79,7 +83,15 @@ async fn server_triggers_success_callback() { }); let server = server.bind("127.0.0.1:0".parse().unwrap()).expect("bind"); let addr = server.local_addr().expect("addr"); - let handle = tokio::spawn(async move { server.run().await.unwrap() }); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let handle = tokio::spawn(async move { + server + .run_with_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); let mut stream = TcpStream::connect(addr).await.unwrap(); let bytes = b"TRTPHOTL\x00\x01\x00\x02"; @@ -97,7 +109,7 @@ async fn server_triggers_success_callback() { .is_err() ); - unsafe { libc::raise(libc::SIGINT) }; + let _ = shutdown_tx.send(()); handle.await.unwrap(); } @@ -129,7 +141,15 @@ async fn server_triggers_failure_callback() { }); let server = server.bind("127.0.0.1:0".parse().unwrap()).expect("bind"); let addr = server.local_addr().expect("addr"); - let handle = tokio::spawn(async move { server.run().await.unwrap() }); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let handle = tokio::spawn(async move { + server + .run_with_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); let mut stream = TcpStream::connect(addr).await.unwrap(); let bytes = b"TRTPHOT"; // truncated @@ -146,6 +166,6 @@ async fn server_triggers_failure_callback() { .is_err() ); - unsafe { libc::raise(libc::SIGINT) }; + let _ = shutdown_tx.send(()); handle.await.unwrap(); } From dd75bce62b98416772ab207d12084ad9532c192e Mon Sep 17 00:00:00 2001 From: Leynos Date: Sat, 14 Jun 2025 12:44:11 +0100 Subject: [PATCH 04/10] Refactor preamble reader --- src/preamble.rs | 64 +++++++++++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/src/preamble.rs b/src/preamble.rs index 57cad35c..633f4ee3 100644 --- a/src/preamble.rs +++ b/src/preamble.rs @@ -4,6 +4,43 @@ use tokio::io::{self, AsyncRead, AsyncReadExt}; const MAX_PREAMBLE_LEN: usize = 1024; +async fn read_more( + reader: &mut R, + buf: &mut Vec, + additional: usize, +) -> Result<(), DecodeError> +where + R: AsyncRead + Unpin, +{ + let start = buf.len(); + if start + additional > MAX_PREAMBLE_LEN { + return Err(DecodeError::Other("preamble too long")); + } + buf.resize(start + additional, 0); + let mut read = 0; + while read < additional { + match reader + .read(&mut buf[start + read..start + additional]) + .await + { + Ok(0) => { + return Err(DecodeError::Io { + inner: io::Error::from(io::ErrorKind::UnexpectedEof), + additional: additional - read, + }); + } + Ok(n) => read += n, + Err(inner) => { + return Err(DecodeError::Io { + inner, + additional: additional - read, + }); + } + } + } + Ok(()) +} + /// Read and decode a connection preamble using bincode. /// /// This helper reads the exact number of bytes required by `T`, as @@ -31,32 +68,7 @@ where return Ok((value, leftover)); } Err(DecodeError::UnexpectedEnd { additional }) => { - let start = buf.len(); - if start + additional > MAX_PREAMBLE_LEN { - return Err(DecodeError::Other("preamble too long")); - } - buf.resize(start + additional, 0); - let mut read = 0; - while read < additional { - match reader - .read(&mut buf[start + read..start + additional]) - .await - { - Ok(0) => { - return Err(DecodeError::Io { - inner: io::Error::from(io::ErrorKind::UnexpectedEof), - additional: additional - read, - }); - } - Ok(n) => read += n, - Err(inner) => { - return Err(DecodeError::Io { - inner, - additional: additional - read, - }); - } - } - } + read_more(reader, &mut buf, additional).await?; } Err(e) => return Err(e), } From 4c57460b28e33afb6a8809e2dd4bb361e8c8468e Mon Sep 17 00:00:00 2001 From: Leynos Date: Sat, 14 Jun 2025 14:12:25 +0100 Subject: [PATCH 05/10] Spawn per-connection tasks --- docs/preamble-validator.md | 4 ++-- src/preamble.rs | 1 + src/server.rs | 21 ++++++++------------- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/docs/preamble-validator.md b/docs/preamble-validator.md index 3ec99a54..f0905fd2 100644 --- a/docs/preamble-validator.md +++ b/docs/preamble-validator.md @@ -6,7 +6,7 @@ client connects. The server decodes the preamble with success or failure. The helper uses `bincode` to decode any type implementing `bincode::Decode` and reads exactly the number of bytes required. -The flow is summarised below: +The flow is summarized below: ```mermaid sequenceDiagram @@ -28,6 +28,6 @@ sequenceDiagram Server-->>Client: (Continues or closes connection) ``` -In the tests a `HotlinePreamble` struct illustrates the pattern, but any +In the tests, a `HotlinePreamble` struct illustrates the pattern, but any preamble type may be used. Register callbacks via `on_preamble_decode_success` and `on_preamble_decode_failure` on `WireframeServer`. diff --git a/src/preamble.rs b/src/preamble.rs index 633f4ee3..84164fbb 100644 --- a/src/preamble.rs +++ b/src/preamble.rs @@ -58,6 +58,7 @@ where T: Decode<()>, { let mut buf = Vec::new(); + // Build the decoder configuration once to avoid repeated allocations. let config = config::standard() .with_big_endian() .with_fixed_int_encoding(); diff --git a/src/server.rs b/src/server.rs index b42e8389..bb69ce7a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -290,14 +290,7 @@ where let on_failure = self.on_preamble_failure.clone(); let mut shutdown_rx = shutdown_tx.subscribe(); handles.push(tokio::spawn(async move { - worker_task( - listener, - factory, - on_success.as_ref(), - on_failure.as_ref(), - &mut shutdown_rx, - ) - .await; + worker_task(listener, factory, on_success, on_failure, &mut shutdown_rx).await; })); } @@ -320,8 +313,8 @@ where async fn worker_task( listener: Arc, factory: F, - on_success: Option<&Arc>, - on_failure: Option<&Arc>, + on_success: Option>, + on_failure: Option>, shutdown_rx: &mut broadcast::Receiver<()>, ) where F: Fn() -> WireframeApp, @@ -333,7 +326,9 @@ async fn worker_task( tokio::select! { res = listener.accept() => match res { Ok((stream, _)) => { - process_stream(stream, on_success, on_failure).await; + let success = on_success.clone(); + let failure = on_failure.clone(); + tokio::spawn(process_stream(stream, success, failure)); delay = Duration::from_millis(10); } Err(e) => { @@ -351,8 +346,8 @@ async fn worker_task( #[allow(clippy::type_complexity)] async fn process_stream( mut stream: tokio::net::TcpStream, - on_success: Option<&Arc>, - on_failure: Option<&Arc>, + on_success: Option>, + on_failure: Option>, ) where T: bincode::Decode<()> + Send + 'static, { From 429aad4092f0ca4330746c5003bdf68cde79bfa7 Mon Sep 17 00:00:00 2001 From: Leynos Date: Sat, 14 Jun 2025 15:31:33 +0100 Subject: [PATCH 06/10] Document decode context --- src/preamble.rs | 2 ++ src/server.rs | 12 ++++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/preamble.rs b/src/preamble.rs index 84164fbb..81c9bc0c 100644 --- a/src/preamble.rs +++ b/src/preamble.rs @@ -55,6 +55,8 @@ where pub async fn read_preamble(reader: &mut R) -> Result<(T, Vec), DecodeError> where R: AsyncRead + Unpin, + // `Decode` expects a decoding context type, not a lifetime. Most callers + // use the unit type as the context, which requires no external state. T: Decode<()>, { let mut buf = Vec::new(); diff --git a/src/server.rs b/src/server.rs index bb69ce7a..d0313158 100644 --- a/src/server.rs +++ b/src/server.rs @@ -24,6 +24,8 @@ use crate::app::WireframeApp; pub struct WireframeServer where F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, + // `Decode`'s type parameter represents a decoding context. + // The unit type signals that no context is required. T: bincode::Decode<()> + Send + 'static, { factory: F, @@ -89,6 +91,7 @@ where #[must_use] pub fn with_preamble(self) -> WireframeServer where + // Unit context indicates no external state is required when decoding. T: bincode::Decode<()> + Send + 'static, { WireframeServer { @@ -105,6 +108,7 @@ where impl WireframeServer where F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, + // `Decode` is generic over a context type; we use `()` here. T: bincode::Decode<()> + Send + 'static, { /// Set the number of worker tasks to spawn for the server. @@ -318,6 +322,7 @@ async fn worker_task( shutdown_rx: &mut broadcast::Receiver<()>, ) where F: Fn() -> WireframeApp, + // The unit context indicates no additional state is needed to decode `T`. T: bincode::Decode<()> + Send + 'static, { let app = (factory)(); @@ -349,14 +354,17 @@ async fn process_stream( on_success: Option>, on_failure: Option>, ) where + // The decoding context parameter is `()`; no external state is needed. T: bincode::Decode<()> + Send + 'static, { match read_preamble::<_, T>(&mut stream).await { - Ok((preamble, _)) => { + Ok((preamble, _leftover)) => { if let Some(handler) = on_success.as_ref() { handler(&preamble); } - // TODO: hand off stream to application + // TODO: hand off `stream` **and** `leftover` to the application layer, + // e.g. by wrapping the stream in a struct that replays `leftover` + // before delegating to the underlying socket. } Err(err) => { if let Some(handler) = on_failure.as_ref() { From 0be9b537c02c6d280c25b343ca5eb257874a26ae Mon Sep 17 00:00:00 2001 From: Leynos Date: Sat, 14 Jun 2025 16:07:43 +0100 Subject: [PATCH 07/10] Log worker failures and avoid unused app --- src/server.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/server.rs b/src/server.rs index d0313158..d57c0971 100644 --- a/src/server.rs +++ b/src/server.rs @@ -308,7 +308,11 @@ where _ = &mut join_all => {} } - join_all.await; + for res in join_all.await { + if let Err(e) = res { + eprintln!("worker task failed: {e}"); + } + } Ok(()) } } @@ -316,7 +320,7 @@ where #[allow(clippy::type_complexity)] async fn worker_task( listener: Arc, - factory: F, + _factory: F, on_success: Option>, on_failure: Option>, shutdown_rx: &mut broadcast::Receiver<()>, @@ -325,7 +329,6 @@ async fn worker_task( // The unit context indicates no additional state is needed to decode `T`. T: bincode::Decode<()> + Send + 'static, { - let app = (factory)(); let mut delay = Duration::from_millis(10); loop { tokio::select! { @@ -345,7 +348,6 @@ async fn worker_task( _ = shutdown_rx.recv() => break, } } - drop(app); } #[allow(clippy::type_complexity)] From ac74c8fb46c448f473024812cfeb913bda80f6c8 Mon Sep 17 00:00:00 2001 From: Leynos Date: Sat, 14 Jun 2025 17:30:14 +0100 Subject: [PATCH 08/10] Handle leftover bytes in preamble --- src/server.rs | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/server.rs b/src/server.rs index d57c0971..ae199042 100644 --- a/src/server.rs +++ b/src/server.rs @@ -289,12 +289,11 @@ where let mut handles = Vec::with_capacity(self.workers); for _ in 0..self.workers { let listener = Arc::clone(&listener); - let factory = self.factory.clone(); let on_success = self.on_preamble_success.clone(); let on_failure = self.on_preamble_failure.clone(); let mut shutdown_rx = shutdown_tx.subscribe(); handles.push(tokio::spawn(async move { - worker_task(listener, factory, on_success, on_failure, &mut shutdown_rx).await; + worker_task(listener, on_success, on_failure, &mut shutdown_rx).await; })); } @@ -318,14 +317,12 @@ where } #[allow(clippy::type_complexity)] -async fn worker_task( +async fn worker_task( listener: Arc, - _factory: F, on_success: Option>, on_failure: Option>, shutdown_rx: &mut broadcast::Receiver<()>, ) where - F: Fn() -> WireframeApp, // The unit context indicates no additional state is needed to decode `T`. T: bincode::Decode<()> + Send + 'static, { @@ -360,13 +357,13 @@ async fn process_stream( T: bincode::Decode<()> + Send + 'static, { match read_preamble::<_, T>(&mut stream).await { - Ok((preamble, _leftover)) => { + Ok((preamble, leftover)) => { + let _ = &leftover; // retain for future replay logic if let Some(handler) = on_success.as_ref() { handler(&preamble); } - // TODO: hand off `stream` **and** `leftover` to the application layer, - // e.g. by wrapping the stream in a struct that replays `leftover` - // before delegating to the underlying socket. + // TODO: wrap `stream` so that `leftover` is replayed before + // delegating to the underlying socket (e.g. `RewindableStream`). } Err(err) => { if let Some(handler) = on_failure.as_ref() { From 7b0c371c7919fe0b9f8b96f8b50aa513d986a11f Mon Sep 17 00:00:00 2001 From: Leynos Date: Sat, 14 Jun 2025 20:29:01 +0100 Subject: [PATCH 09/10] Add stream wrapper for leftover preamble bytes --- src/lib.rs | 1 + src/rewind_stream.rs | 71 ++++++++++++++++++++++++++++++++++++++++++++ src/server.rs | 6 ++-- 3 files changed, 75 insertions(+), 3 deletions(-) create mode 100644 src/rewind_stream.rs diff --git a/src/lib.rs b/src/lib.rs index b0436d66..d9acc4e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,4 +4,5 @@ pub mod frame; pub mod message; pub mod middleware; pub mod preamble; +pub mod rewind_stream; pub mod server; diff --git a/src/rewind_stream.rs b/src/rewind_stream.rs new file mode 100644 index 00000000..53e0ed3c --- /dev/null +++ b/src/rewind_stream.rs @@ -0,0 +1,71 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// A stream adapter that replays buffered bytes before reading +/// from the underlying stream. +pub struct RewindStream { + leftover: Vec, + pos: usize, + inner: S, +} + +impl RewindStream { + /// Create a new `RewindStream` that will yield `leftover` before + /// delegating to `inner`. + pub fn new(leftover: Vec, inner: S) -> Self { + Self { + leftover, + pos: 0, + inner, + } + } +} + +impl AsyncRead for RewindStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.pos < self.leftover.len() { + let remaining = self.leftover.len() - self.pos; + let to_copy = remaining.min(buf.remaining()); + let start = self.pos; + let end = start + to_copy; + buf.put_slice(&self.leftover[start..end]); + self.pos += to_copy; + if self.pos < self.leftover.len() || to_copy > 0 { + return Poll::Ready(Ok(())); + } + } + + if self.pos >= self.leftover.len() && !self.leftover.is_empty() { + self.leftover.clear(); + } + + Pin::new(&mut self.inner).poll_read(cx, buf) + } +} + +impl AsyncWrite for RewindStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +impl Unpin for RewindStream {} diff --git a/src/server.rs b/src/server.rs index ae199042..4141bfc7 100644 --- a/src/server.rs +++ b/src/server.rs @@ -9,6 +9,7 @@ use tokio::time::{Duration, sleep}; use core::marker::PhantomData; use crate::preamble::read_preamble; +use crate::rewind_stream::RewindStream; use bincode::error::DecodeError; use crate::app::WireframeApp; @@ -358,12 +359,11 @@ async fn process_stream( { match read_preamble::<_, T>(&mut stream).await { Ok((preamble, leftover)) => { - let _ = &leftover; // retain for future replay logic + let _stream = RewindStream::new(leftover, stream); if let Some(handler) = on_success.as_ref() { handler(&preamble); } - // TODO: wrap `stream` so that `leftover` is replayed before - // delegating to the underlying socket (e.g. `RewindableStream`). + // `RewindStream` plays back leftover bytes before using the socket. } Err(err) => { if let Some(handler) = on_failure.as_ref() { From 8c824fcbfb440ba0a62865011ae5004033517565 Mon Sep 17 00:00:00 2001 From: Leynos Date: Sat, 14 Jun 2025 20:58:30 +0100 Subject: [PATCH 10/10] Forward connections to application --- src/app.rs | 12 ++++++++++++ src/rewind_stream.rs | 1 + src/server.rs | 22 ++++++++++++++++------ 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/app.rs b/src/app.rs index 291e654d..122a1316 100644 --- a/src/app.rs +++ b/src/app.rs @@ -82,4 +82,16 @@ impl WireframeApp { self.middleware.push(Box::new(mw)); Ok(self) } + + /// Handle an accepted connection. + /// + /// This placeholder simply drops the stream. Future implementations + /// will decode frames and dispatch them to registered handlers. + pub async fn handle_connection(&self, _stream: S) + where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static, + { + // Connection handling will be implemented later. + tokio::task::yield_now().await; + } } diff --git a/src/rewind_stream.rs b/src/rewind_stream.rs index 53e0ed3c..65a5ff76 100644 --- a/src/rewind_stream.rs +++ b/src/rewind_stream.rs @@ -44,6 +44,7 @@ impl AsyncRead for RewindStream { if self.pos >= self.leftover.len() && !self.leftover.is_empty() { self.leftover.clear(); + self.pos = 0; } Pin::new(&mut self.inner).poll_read(cx, buf) diff --git a/src/server.rs b/src/server.rs index 4141bfc7..46c21f64 100644 --- a/src/server.rs +++ b/src/server.rs @@ -290,11 +290,12 @@ where let mut handles = Vec::with_capacity(self.workers); for _ in 0..self.workers { let listener = Arc::clone(&listener); + let factory = self.factory.clone(); let on_success = self.on_preamble_success.clone(); let on_failure = self.on_preamble_failure.clone(); let mut shutdown_rx = shutdown_tx.subscribe(); handles.push(tokio::spawn(async move { - worker_task(listener, on_success, on_failure, &mut shutdown_rx).await; + worker_task(listener, factory, on_success, on_failure, &mut shutdown_rx).await; })); } @@ -318,12 +319,14 @@ where } #[allow(clippy::type_complexity)] -async fn worker_task( +async fn worker_task( listener: Arc, + factory: F, on_success: Option>, on_failure: Option>, shutdown_rx: &mut broadcast::Receiver<()>, ) where + F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, // The unit context indicates no additional state is needed to decode `T`. T: bincode::Decode<()> + Send + 'static, { @@ -334,7 +337,8 @@ async fn worker_task( Ok((stream, _)) => { let success = on_success.clone(); let failure = on_failure.clone(); - tokio::spawn(process_stream(stream, success, failure)); + let factory = factory.clone(); + tokio::spawn(process_stream(stream, factory, success, failure)); delay = Duration::from_millis(10); } Err(e) => { @@ -349,21 +353,27 @@ async fn worker_task( } #[allow(clippy::type_complexity)] -async fn process_stream( +async fn process_stream( mut stream: tokio::net::TcpStream, + factory: F, on_success: Option>, on_failure: Option>, ) where + F: Fn() -> WireframeApp + Send + Sync + 'static, // The decoding context parameter is `()`; no external state is needed. T: bincode::Decode<()> + Send + 'static, { match read_preamble::<_, T>(&mut stream).await { Ok((preamble, leftover)) => { - let _stream = RewindStream::new(leftover, stream); if let Some(handler) = on_success.as_ref() { handler(&preamble); } - // `RewindStream` plays back leftover bytes before using the socket. + let stream = RewindStream::new(leftover, stream); + // Hand the connection to the application for processing. + let app = (factory)(); + tokio::spawn(async move { + app.handle_connection(stream).await; + }); } Err(err) => { if let Some(handler) = on_failure.as_ref() {