diff --git a/Cargo.lock b/Cargo.lock index 6059b6c..739c877 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -309,6 +309,7 @@ dependencies = [ "serde_json", "tempfile", "tokio", + "yaque", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index d88cede..e92de1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ comenq = { path = "crates/comenq" } comenqd = { path = "crates/comenqd" } ortho_config = { git = "https://github.com/leynos/ortho-config.git", tag = "v0.4.0" } tempfile = "3.10" # latest 3.x at time of writing; update as new patch versions release +yaque = { workspace = true } [[test]] name = "cucumber" diff --git a/crates/comenqd/src/daemon.rs b/crates/comenqd/src/daemon.rs index b13983a..e5e491a 100644 --- a/crates/comenqd/src/daemon.rs +++ b/crates/comenqd/src/daemon.rs @@ -15,6 +15,7 @@ use std::time::Duration; use tokio::fs; use tokio::io::AsyncReadExt; use tokio::net::{UnixListener, UnixStream}; +use tokio::sync::{mpsc, watch}; use yaque::{Receiver, Sender, channel}; fn build_octocrab(token: &str) -> Result { @@ -37,15 +38,30 @@ async fn ensure_queue_dir(path: &Path) -> Result<()> { Ok(()) } +pub async fn queue_writer( + mut sender: Sender, + mut rx: mpsc::UnboundedReceiver>, +) -> Result<()> { + while let Some(bytes) = rx.recv().await { + if let Err(e) = sender.send(bytes).await { + tracing::error!(error = %e, "Queue enqueue failed"); + } + } + Ok(()) +} + /// Start the daemon with the provided configuration. pub async fn run(config: Config) -> Result<()> { ensure_queue_dir(&config.queue_path).await?; tracing::info!(queue = %config.queue_path.display(), "Queue directory prepared"); let octocrab = Arc::new(build_octocrab(&config.github_token)?); - let (tx, rx) = channel(&config.queue_path)?; + let (queue_tx, rx) = channel(&config.queue_path)?; + let (client_tx, client_rx) = mpsc::unbounded_channel(); let cfg = Arc::new(config); + let (shutdown_tx, shutdown_rx) = watch::channel(()); - let listener = tokio::spawn(run_listener(cfg.clone(), tx)); + let writer = tokio::spawn(queue_writer(queue_tx, client_rx)); + let listener = tokio::spawn(run_listener(cfg.clone(), client_tx, shutdown_rx)); let worker = tokio::spawn(run_worker(cfg.clone(), rx, octocrab)); tokio::select! { @@ -59,26 +75,50 @@ pub async fn run(config: Config) -> Result<()> { }, } + let _ = shutdown_tx.send(()); + writer.await??; + Ok(()) } -async fn run_listener(config: Arc, mut tx: Sender) -> Result<()> { +pub async fn run_listener( + config: Arc, + tx: mpsc::UnboundedSender>, + mut shutdown: watch::Receiver<()>, +) -> Result<()> { let listener = prepare_listener(&config.socket_path)?; loop { - let (stream, _) = listener.accept().await?; - if let Err(e) = handle_client(stream, &mut tx).await { - tracing::warn!(error = %e, "Client handling failed"); + tokio::select! { + res = listener.accept() => match res { + Ok((stream, _)) => { + let tx_clone = tx.clone(); + tokio::spawn(async move { + if let Err(e) = handle_client(stream, tx_clone).await { + tracing::warn!(error = %e, "Client handling failed"); + } + }); + } + Err(e) => { + tracing::error!(error = %e, "Failed to accept client connection"); + tokio::time::sleep(Duration::from_millis(100)).await; + } + }, + _ = shutdown.changed() => { + break; + } } } + Ok(()) } -async fn handle_client(mut stream: UnixStream, tx: &mut Sender) -> Result<()> { +async fn handle_client(mut stream: UnixStream, tx: mpsc::UnboundedSender>) -> Result<()> { let mut buffer = Vec::new(); stream.read_to_end(&mut buffer).await?; let request: CommentRequest = serde_json::from_slice(&buffer)?; let bytes = serde_json::to_vec(&request)?; - tx.send(bytes).await?; + tx.send(bytes) + .map_err(|_| anyhow::anyhow!("queue writer dropped"))?; Ok(()) } @@ -107,8 +147,12 @@ async fn run_worker(config: Arc, mut rx: Receiver, octocrab: Arc, + cfg: Option>, + receiver: Option, + shutdown: Option>, + writer: Option>, + handle: Option>, +} + +impl std::fmt::Debug for ListenerWorld { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ListenerWorld").finish() + } +} + +#[given("a running listener task")] +async fn running_listener(world: &mut ListenerWorld) { + let dir = TempDir::new().expect("tempdir"); + let cfg = Arc::new(Config { + github_token: String::from("t"), + socket_path: dir.path().join("sock"), + queue_path: dir.path().join("q"), + cooldown_period_seconds: 1, + }); + let (sender, receiver) = channel(&cfg.queue_path).expect("channel"); + let (client_tx, writer_rx) = mpsc::unbounded_channel(); + let (shutdown_tx, shutdown_rx) = watch::channel(()); + let cfg_clone = cfg.clone(); + let writer = tokio::spawn(async move { + queue_writer(sender, writer_rx).await.unwrap(); + }); + let handle = tokio::spawn(async move { + run_listener(cfg_clone, client_tx, shutdown_rx) + .await + .unwrap(); + }); + world.dir = Some(dir); + world.cfg = Some(cfg); + world.shutdown = Some(shutdown_tx); + world.writer = Some(writer); + world.receiver = Some(receiver); + world.handle = Some(handle); + // wait for socket create + for _ in 0..10 { + if world.cfg.as_ref().unwrap().socket_path.exists() { + break; + } + sleep(Duration::from_millis(10)).await; + } +} + +#[when("a client sends a valid request")] +async fn client_sends_valid(world: &mut ListenerWorld) { + let cfg = world.cfg.as_ref().unwrap(); + let mut stream = UnixStream::connect(&cfg.socket_path) + .await + .expect("connect"); + let req = CommentRequest { + owner: "o".into(), + repo: "r".into(), + pr_number: 1, + body: "b".into(), + }; + let data = serde_json::to_vec(&req).unwrap(); + stream.write_all(&data).await.unwrap(); + stream.shutdown().await.expect("shutdown"); +} + +#[when("a client sends invalid JSON")] +async fn client_sends_invalid(world: &mut ListenerWorld) { + let cfg = world.cfg.as_ref().unwrap(); + let mut stream = UnixStream::connect(&cfg.socket_path) + .await + .expect("connect"); + stream.write_all(b"not json").await.unwrap(); + stream.shutdown().await.expect("shutdown"); +} + +#[then("the request is enqueued")] +async fn request_enqueued(world: &mut ListenerWorld) { + let receiver = world.receiver.as_mut().unwrap(); + let guard = receiver.recv().await.expect("recv"); + let req: CommentRequest = serde_json::from_slice(&guard).unwrap(); + assert_eq!(req.owner, "o"); +} + +#[then("the request is rejected")] +async fn request_rejected(world: &mut ListenerWorld) { + let receiver = world.receiver.as_mut().unwrap(); + let res = tokio::time::timeout(Duration::from_millis(100), receiver.recv()).await; + assert!(res.is_err()); +} + +impl Drop for ListenerWorld { + fn drop(&mut self) { + if let Some(shutdown) = self.shutdown.take() { + let _ = shutdown.send(()); + } + if let Some(writer) = self.writer.take() { + writer.abort(); + } + if let Some(handle) = self.handle.take() { + handle.abort(); + } + } +} diff --git a/tests/steps/mod.rs b/tests/steps/mod.rs index 3e68047..455cf76 100644 --- a/tests/steps/mod.rs +++ b/tests/steps/mod.rs @@ -6,3 +6,5 @@ pub mod comment_steps; pub use comment_steps::CommentWorld; pub mod config_steps; pub use config_steps::ConfigWorld; +pub mod listener_steps; +pub use listener_steps::ListenerWorld;