diff --git a/Cargo.lock b/Cargo.lock index eba1e08b..df99e98a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -149,7 +149,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ "cfg-if", - "hashbrown", + "hashbrown 0.14.5", "lock_api", "once_cell", "parking_lot_core", @@ -316,6 +316,12 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "hashbrown" +version = "0.15.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" + [[package]] name = "lazy_static" version = "1.5.0" @@ -864,6 +870,8 @@ dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", + "hashbrown 0.15.4", "pin-project-lite", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 6e879628..8cb4cd46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ edition = "2024" serde = { version = "1", features = ["derive"] } bincode = "2" tokio = { version = "1", default-features = false, features = ["net", "signal", "rt-multi-thread", "macros", "sync", "time", "io-util"] } -tokio-util = "0.7" +tokio-util = { version = "0.7", features = ["rt"] } futures = "0.3" async-trait = "0.1" bytes = "1" diff --git a/docs/asynchronous-outbound-messaging-roadmap.md b/docs/asynchronous-outbound-messaging-roadmap.md index 488b5437..b0059dc2 100644 --- a/docs/asynchronous-outbound-messaging-roadmap.md +++ b/docs/asynchronous-outbound-messaging-roadmap.md @@ -38,7 +38,7 @@ design documents. ## 3. Production Hardening -- [ ] **Graceful shutdown** using `CancellationToken` and `TaskTracker` +- [x] **Graceful shutdown** using `CancellationToken` and `TaskTracker` ([Resilience Guide §2][resilience-shutdown]). 
- [ ] **Typed `WireframeError`** for recoverable protocol errors ([Design §5][design-errors]). diff --git a/docs/hardening-wireframe-a-guide-to-production-resilience.md b/docs/hardening-wireframe-a-guide-to-production-resilience.md index dd355bd6..3135e1fa 100644 --- a/docs/hardening-wireframe-a-guide-to-production-resilience.md +++ b/docs/hardening-wireframe-a-guide-to-production-resilience.md @@ -92,6 +92,61 @@ pub async fn run_connection( This pattern ensures that whether a connection ends due to a client disconnect, an error, or a server-wide shutdown, the task terminates cleanly and reliably. +#### Sequence Diagram + +```mermaid +sequenceDiagram + participant MainServer + participant TaskTracker + participant CancellationToken + participant WorkerTask + + MainServer->>TaskTracker: spawn(worker_task(..., token, tracker)) + loop For each worker + TaskTracker->>WorkerTask: Start worker_task + end + MainServer->>MainServer: Await shutdown signal + MainServer->>CancellationToken: cancel() + CancellationToken-->>WorkerTask: Signal cancellation + WorkerTask-->>WorkerTask: Stop accepting connections + WorkerTask-->>TaskTracker: Complete and notify + MainServer->>TaskTracker: close() + MainServer->>TaskTracker: tracker.wait().await + TaskTracker-->>MainServer: All tasks complete + Note over MainServer: Shutdown complete +``` + +#### Class Diagram + +```mermaid +classDiagram + class Server { + +listener + +factory + +on_preamble_success + +on_preamble_failure + +workers + +run(shutdown) + } + class TaskTracker { + +spawn(task) + +wait() + +close() + } + class CancellationToken { + +cancel() + +cancelled() + } + class WorkerTask { + +worker_task(listener, factory, on_success, on_failure, shutdown, tracker) + } + Server --> TaskTracker : uses + Server --> CancellationToken : uses + TaskTracker --> WorkerTask : spawns + WorkerTask --> CancellationToken : checks + WorkerTask --> TaskTracker : notifies +``` + ## 3. 
Meticulous Resource Management Long-running servers are exquisitely sensitive to resource leaks. `wireframe` diff --git a/src/server.rs b/src/server.rs index e9e515a1..82f911fa 100644 --- a/src/server.rs +++ b/src/server.rs @@ -28,9 +28,9 @@ pub type PreambleCallback = Arc< pub type PreambleErrorCallback = Arc; use tokio::{ net::TcpListener, - sync::broadcast, time::{Duration, sleep}, }; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; use crate::{ app::WireframeApp, @@ -311,41 +311,29 @@ where S: futures::Future + Send, { let listener = self.listener.expect("`bind` must be called before `run`"); - // Reserve one slot per worker so lagged messages remain visible during - // debugging. - let (shutdown_tx, _) = broadcast::channel(self.workers.max(1)); - // Spawn worker tasks, giving each its own shutdown receiver. - let mut handles = Vec::with_capacity(self.workers); + let shutdown_token = CancellationToken::new(); + let tracker = TaskTracker::new(); + for _ in 0..self.workers { let listener = Arc::clone(&listener); let factory = self.factory.clone(); let on_success = self.on_preamble_success.clone(); let on_failure = self.on_preamble_failure.clone(); - handles.push(tokio::spawn(worker_task( - listener, - factory, - on_success, - on_failure, - shutdown_tx.subscribe(), - ))); + let token = shutdown_token.clone(); + let t = tracker.clone(); + tracker.spawn(worker_task( + listener, factory, on_success, on_failure, token, t, + )); } - let join_all = futures::future::join_all(handles); - tokio::pin!(join_all); - tokio::select! 
{ - () = shutdown => { - let _ = shutdown_tx.send(()); - } - _ = &mut join_all => {} + () = shutdown => shutdown_token.cancel(), + () = tracker.wait() => {} } - for res in join_all.await { - if let Err(e) = res { - eprintln!("worker task failed: {e}"); - } - } + tracker.close(); + tracker.wait().await; Ok(()) } } @@ -360,8 +348,8 @@ async fn worker_task( factory: F, on_success: Option>, on_failure: Option, - // Each worker owns its shutdown receiver. - mut shutdown_rx: broadcast::Receiver<()>, + shutdown: CancellationToken, + tracker: TaskTracker, ) where F: Fn() -> WireframeApp + Send + Sync + Clone + 'static, // `Preamble` ensures `T` supports borrowed decoding. @@ -370,12 +358,17 @@ async fn worker_task( let mut delay = Duration::from_millis(10); loop { tokio::select! { + biased; + + () = shutdown.cancelled() => break, + res = listener.accept() => match res { Ok((stream, _)) => { let success = on_success.clone(); let failure = on_failure.clone(); let factory = factory.clone(); - tokio::spawn(process_stream(stream, factory, success, failure)); + let t = tracker.clone(); + t.spawn(process_stream(stream, factory, success, failure)); delay = Duration::from_millis(10); } Err(e) => { @@ -384,7 +377,6 @@ async fn worker_task( delay = (delay * 2).min(Duration::from_secs(1)); } }, - _ = shutdown_rx.recv() => break, } } } @@ -461,9 +453,9 @@ mod tests { use rstest::{fixture, rstest}; use tokio::{ net::TcpListener, - sync::broadcast, time::{Duration, timeout}, }; + use tokio_util::{sync::CancellationToken, task::TaskTracker}; use super::*; @@ -813,14 +805,23 @@ mod tests { async fn test_worker_task_shutdown_signal( factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static, ) { - let (tx, rx) = broadcast::channel(1); + let token = CancellationToken::new(); + let tracker = TaskTracker::new(); let listener = Arc::new(TcpListener::bind("127.0.0.1:0").await.unwrap()); - let _ = tx.send(()); + tracker.spawn(worker_task::<_, ()>( + listener, + factory, + None, + None, + 
token.clone(), + tracker.clone(), + )); - let task = tokio::spawn(worker_task::<_, ()>(listener, factory, None, None, rx)); + token.cancel(); + tracker.close(); - let result = timeout(Duration::from_millis(100), task).await; + let result = timeout(Duration::from_millis(100), tracker.wait()).await; assert!(result.is_ok()); } diff --git a/tests/connection_actor.rs b/tests/connection_actor.rs index 5e034f79..6d9a0c80 100644 --- a/tests/connection_actor.rs +++ b/tests/connection_actor.rs @@ -6,7 +6,7 @@ use futures::stream; use rstest::{fixture, rstest}; use tokio::time::{Duration, sleep, timeout}; -use tokio_util::sync::CancellationToken; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; use wireframe::{ connection::{ConnectionActor, FairnessConfig}, push::PushQueues, @@ -234,3 +234,31 @@ async fn on_command_end_hook_runs( actor.run(&mut out).await.unwrap(); assert_eq!(counter.load(Ordering::SeqCst), 1); } + +#[rstest] +#[tokio::test] +async fn graceful_shutdown_waits_for_tasks() { + let tracker = TaskTracker::new(); + let token = CancellationToken::new(); + + let mut handles = Vec::new(); + for _ in 0..5 { + let (queues, handle) = PushQueues::::bounded(1, 1); + let mut actor: ConnectionActor<_, ()> = + ConnectionActor::new(queues, handle.clone(), None, token.clone()); + handles.push(handle); + tracker.spawn(async move { + let mut out = Vec::new(); + let _ = actor.run(&mut out).await; + }); + } + + token.cancel(); + tracker.close(); + + assert!( + timeout(Duration::from_millis(500), tracker.wait()) + .await + .is_ok() + ); +}