diff --git a/.markdownlint-cli2.jsonc b/.markdownlint-cli2.jsonc index b1ed2065..a822f004 100644 --- a/.markdownlint-cli2.jsonc +++ b/.markdownlint-cli2.jsonc @@ -6,6 +6,7 @@ "line_length": 80, "code_block_line_length": 120, "tables": false - } + }, + "MD040": false } } diff --git a/Cargo.lock b/Cargo.lock index 64626b03..5d495cbd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,15 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "async-trait" version = "0.1.88" @@ -152,6 +161,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -176,6 +191,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "libc" version = "0.2.173" @@ -253,12 +274,91 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + +[[package]] +name = "rstest" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97eeab2f3c0a199bc4be135c36c924b6590b88c377d416494288c14f2db30199" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d428f8247852f894ee1be110b375111b586d4fa431f6c46e64ba5a0dcccbe605" +dependencies = [ + "cfg-if", + "glob", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn", + "unicode-ident", +] + [[package]] name = "rustc-demangle" version = "0.1.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "semver" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" + [[package]] name = "serde" version = "1.0.219" @@ -461,6 +561,7 @@ dependencies = [ "bytes", "futures", "log", + "rstest", "serde", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 821b2e32..71f202d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,5 +12,8 @@ async-trait = "0.1" bytes = "1" log = "0.4" +[dev-dependencies] +rstest = "0.18.2" + [lints.clippy] pedantic = "warn" diff --git a/src/server.rs b/src/server.rs index 12cca0ef..88d1347f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -283,19 +283,24 @@ where S: futures::Future + Send, { let listener = self.listener.expect("`bind` must be called before `run`"); - let (shutdown_tx, _) = broadcast::channel(16); + // Reserve one slot per worker so lagged messages remain visible during + // debugging. + let (shutdown_tx, _) = broadcast::channel(self.workers.max(1)); - // Spawn worker tasks. + // Spawn worker tasks, giving each its own shutdown receiver. let mut handles = Vec::with_capacity(self.workers); for _ in 0..self.workers { let listener = Arc::clone(&listener); let factory = self.factory.clone(); let on_success = self.on_preamble_success.clone(); let on_failure = self.on_preamble_failure.clone(); - let mut shutdown_rx = shutdown_tx.subscribe(); - handles.push(tokio::spawn(async move { - worker_task(listener, factory, on_success, on_failure, &mut shutdown_rx).await; - })); + handles.push(tokio::spawn(worker_task( + listener, + factory, + on_success, + on_failure, + shutdown_tx.subscribe(), + ))); } let join_all = futures::future::join_all(handles); @@ -326,7 +331,8 @@ async fn worker_task( factory: F, on_success: Option>, on_failure: Option>, - shutdown_rx: &mut broadcast::Receiver<()>, + // Each worker owns its shutdown receiver. + mut shutdown_rx: broadcast::Receiver<()>, ) where F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, // `Preamble` ensures `T` supports borrowed decoding. @@ -406,3 +412,407 @@ async fn process_stream( } } } + +#[cfg(test)] +mod tests { + use super::*; + use bincode::{Decode, Encode}; + use rstest::{fixture, rstest}; + use std::net::{Ipv4Addr, SocketAddr}; + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + use tokio::net::TcpListener; + use tokio::sync::broadcast; + use tokio::time::{Duration, timeout}; + + #[derive(Debug, Clone, PartialEq, Encode, Decode)] + struct TestPreamble { + id: u32, + message: String, + } + + #[derive(Debug, Clone, PartialEq, Encode, Decode)] + struct EmptyPreamble; + + #[fixture] + fn factory() -> impl Fn() -> WireframeApp + Send + Sync + Clone + 'static { + || WireframeApp::default() + } + + #[fixture] + fn free_port() -> SocketAddr { + let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0); + let listener = std::net::TcpListener::bind(addr).unwrap(); + listener.local_addr().unwrap() + } + + fn bind_server(factory: F, addr: SocketAddr) -> WireframeServer + where + F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, + { + WireframeServer::new(factory) + .bind(addr) + .expect("Failed to bind") + } + + fn server_with_preamble(factory: F) -> WireframeServer + where + F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, + { + WireframeServer::new(factory).with_preamble::() + } + + #[rstest] + fn test_new_server_creation( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let server = WireframeServer::new(factory); + assert!(server.worker_count() >= 1); + assert!(server.local_addr().is_none()); + } + + #[rstest] + fn test_new_server_default_worker_count( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let server = WireframeServer::new(factory); + let expected_workers = std::thread::available_parallelism() + .map_or(1, std::num::NonZeroUsize::get) + .max(1); + assert_eq!(server.worker_count(), expected_workers); + } + + #[rstest] + fn test_workers_configuration( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let server = WireframeServer::new(factory); + + let server = server.workers(4); + assert_eq!(server.worker_count(), 4); + + let server = server.workers(100); + assert_eq!(server.worker_count(), 100); + + let server = server.workers(0); + assert_eq!(server.worker_count(), 1); + } + + #[rstest] + fn test_with_preamble_type_conversion( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let server = WireframeServer::new(factory); + let server_with_preamble = server.with_preamble::(); + assert_eq!( + server_with_preamble.worker_count(), + std::thread::available_parallelism() + .map_or(1, std::num::NonZeroUsize::get) + .max(1) + ); + } + + #[rstest] + #[tokio::test] + async fn test_bind_success( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_port: SocketAddr, + ) { + let server = bind_server(factory, free_port); + let bound_addr = server.local_addr().unwrap(); + assert_eq!(bound_addr.ip(), free_port.ip()); + } + + #[rstest] + #[tokio::test] + async fn test_bind_invalid_address( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let server = WireframeServer::new(factory); + let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 1); + let result = server.bind(addr); + assert!(result.is_ok() || result.is_err()); + } + + #[rstest] + fn test_local_addr_before_bind( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let server = WireframeServer::new(factory); + assert!(server.local_addr().is_none()); + } + + #[rstest] + #[tokio::test] + async fn test_local_addr_after_bind( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_port: SocketAddr, + ) { + let server = bind_server(factory, free_port); + let local_addr = server.local_addr(); + assert!(local_addr.is_some()); + assert_eq!(local_addr.unwrap().ip(), free_port.ip()); + } + + #[rstest] + #[tokio::test] + async fn test_preamble_success_callback( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let callback_counter = Arc::new(AtomicUsize::new(0)); + let counter_clone = callback_counter.clone(); + + let server = server_with_preamble(factory).on_preamble_decode_success( + move |_preamble: &TestPreamble| { + counter_clone.fetch_add(1, Ordering::SeqCst); + }, + ); + + assert_eq!(callback_counter.load(Ordering::SeqCst), 0); + assert!(server.on_preamble_success.is_some()); + } + + #[rstest] + #[tokio::test] + async fn test_preamble_failure_callback( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let callback_counter = Arc::new(AtomicUsize::new(0)); + let counter_clone = callback_counter.clone(); + + let server = server_with_preamble(factory).on_preamble_decode_failure( + move |_error: &DecodeError| { + counter_clone.fetch_add(1, Ordering::SeqCst); + }, + ); + + assert_eq!(callback_counter.load(Ordering::SeqCst), 0); + assert!(server.on_preamble_failure.is_some()); + } + + #[rstest] + #[tokio::test] + async fn test_method_chaining( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_port: SocketAddr, + ) { + let callback_invoked = Arc::new(AtomicUsize::new(0)); + let counter_clone = callback_invoked.clone(); + + let server = WireframeServer::new(factory) + .workers(2) + .with_preamble::() + .on_preamble_decode_success(move |_: &TestPreamble| { + counter_clone.fetch_add(1, Ordering::SeqCst); + }) + .on_preamble_decode_failure(|_: &DecodeError| { + eprintln!("Preamble decode failed"); + }) + .bind(free_port) + .expect("Failed to bind"); + + assert_eq!(server.worker_count(), 2); + assert!(server.local_addr().is_some()); + assert!(server.on_preamble_success.is_some()); + assert!(server.on_preamble_failure.is_some()); + } + + #[rstest] + #[tokio::test] + #[should_panic(expected = "`bind` must be called before `run`")] + async fn test_run_without_bind_panics( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let server = WireframeServer::new(factory); + let _ = timeout(Duration::from_millis(100), server.run()).await; + } + + #[rstest] + #[tokio::test] + #[should_panic(expected = "`bind` must be called before `run`")] + async fn test_run_with_shutdown_without_bind_panics( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let server = WireframeServer::new(factory); + let shutdown_future = async { tokio::time::sleep(Duration::from_millis(10)).await }; + let _ = timeout( + Duration::from_millis(100), + server.run_with_shutdown(shutdown_future), + ) + .await; + } + + #[rstest] + #[tokio::test] + async fn test_run_with_immediate_shutdown( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_port: SocketAddr, + ) { + let server = WireframeServer::new(factory) + .workers(1) + .bind(free_port) + .expect("Failed to bind"); + + let shutdown_future = async {}; + + let result = timeout( + Duration::from_millis(1000), + server.run_with_shutdown(shutdown_future), + ) + .await; + + assert!(result.is_ok()); + assert!(result.unwrap().is_ok()); + } + + #[rstest] + #[tokio::test] + async fn test_server_graceful_shutdown_with_ctrl_c_simulation( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_port: SocketAddr, + ) { + let server = WireframeServer::new(factory) + .workers(2) + .bind(free_port) + .expect("Failed to bind"); + + let shutdown_future = async { + tokio::time::sleep(Duration::from_millis(50)).await; + }; + + let start = std::time::Instant::now(); + let result = timeout( + Duration::from_millis(1000), + server.run_with_shutdown(shutdown_future), + ) + .await; + let elapsed = start.elapsed(); + + assert!(result.is_ok()); + assert!(result.unwrap().is_ok()); + assert!(elapsed < Duration::from_millis(500)); + } + + #[rstest] + #[tokio::test] + async fn test_multiple_worker_creation( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_port: SocketAddr, + ) { + let _ = &factory; + let call_count = Arc::new(AtomicUsize::new(0)); + let call_count_clone = call_count.clone(); + + let factory = move || { + call_count_clone.fetch_add(1, Ordering::SeqCst); + WireframeApp::default() + }; + + let server = WireframeServer::new(factory) + .workers(3) + .bind(free_port) + .expect("Failed to bind"); + + let shutdown_future = async { + tokio::time::sleep(Duration::from_millis(10)).await; + }; + + let result = timeout( + Duration::from_millis(1000), + server.run_with_shutdown(shutdown_future), + ) + .await; + + assert!(result.is_ok()); + assert!(result.unwrap().is_ok()); + } + + #[rstest] + #[tokio::test] + async fn test_server_configuration_persistence( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_port: SocketAddr, + ) { + let server = WireframeServer::new(factory).workers(5); + + assert_eq!(server.worker_count(), 5); + + let server = server.bind(free_port).expect("Failed to bind"); + assert_eq!(server.worker_count(), 5); + assert!(server.local_addr().is_some()); + } + + #[rstest] + fn test_preamble_callbacks_reset_on_type_change( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let server = WireframeServer::new(factory) + .on_preamble_decode_success(|&()| {}) + .on_preamble_decode_failure(|_: &DecodeError| {}); + + assert!(server.on_preamble_success.is_some()); + assert!(server.on_preamble_failure.is_some()); + + let server = server.with_preamble::(); + assert!(server.on_preamble_success.is_none()); + assert!(server.on_preamble_failure.is_none()); + } + + #[rstest] + #[tokio::test] + async fn test_worker_task_shutdown_signal( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let (tx, rx) = broadcast::channel(1); + let listener = Arc::new(TcpListener::bind("127.0.0.1:0").await.unwrap()); + + let _ = tx.send(()); + + let task = tokio::spawn(worker_task::<_, ()>(listener, factory, None, None, rx)); + + let result = timeout(Duration::from_millis(100), task).await; + assert!(result.is_ok()); + } + + #[rstest] + fn test_extreme_worker_counts( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + ) { + let server = WireframeServer::new(factory); + + let server = server.workers(usize::MAX); + assert_eq!(server.worker_count(), usize::MAX); + + let server = server.workers(0); + assert_eq!(server.worker_count(), 1); + } + + #[rstest] + #[tokio::test] + async fn test_bind_to_multiple_addresses( + factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, + free_port: SocketAddr, + ) { + let server = WireframeServer::new(factory); + let addr1 = free_port; + let addr2 = { + let addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 0); + let listener = std::net::TcpListener::bind(addr).unwrap(); + listener.local_addr().unwrap() + }; + + let server = server.bind(addr1).expect("Failed to bind first address"); + let first_local_addr = server.local_addr().unwrap(); + + let server = server.bind(addr2).expect("Failed to bind second address"); + let second_local_addr = server.local_addr().unwrap(); + + assert_ne!(first_local_addr.port(), second_local_addr.port()); + assert_eq!(second_local_addr.ip(), addr2.ip()); + } + + #[test] + fn test_server_debug_compilation_guard() { + assert!(cfg!(debug_assertions)); + } +}