diff --git a/codex-rs/acp/docs.md b/codex-rs/acp/docs.md index a853611ae..6c33a6936 100644 --- a/codex-rs/acp/docs.md +++ b/codex-rs/acp/docs.md @@ -808,26 +808,54 @@ The `pending_tool_calls` state is shared via `Arc>>`) that pairs the sender with a monotonic generation counter. The routing logic in the notification handler uses `try_send` with fallthrough: if the per-prompt channel fails (receiver dropped, or channel full), the notification falls through to the `persistent_tx` channel instead of being silently dropped. -The `turn_interrupted: Arc` field on `AcpBackend` prevents this: +The critical invariant is that `prompt()` does **not** clear `active_update_tx` when it returns. This is because `block_task()` (the SACP request/response mechanism) can return before all `SessionNotification` events have been delivered. Instead, callers use a `done_tx`/`done_rx` oneshot to signal the `update_handler` task: + +``` +prompt() returns + | + v +done_tx.send(()) -- signals update_handler that prompt is done + | + v +update_handler enters drain mode: + tokio::select! switches from waiting on (update_rx OR done_rx) + to waiting on update_rx with a 500ms timeout + | + v +After timeout or channel close, update_handler exits + (dropping update_rx, which causes future try_send to fail) + | + v +Next prompt() overwrites active_update_tx slot with a fresh sender +``` + +The generation counter on `active_update_tx` prevents stale cleanup: `close_update_channel(generation)` only clears the slot if the generation matches, so it is safe for `load_session` (which is sequential) to clear its own channel without risking a concurrent prompt's channel. `prompt()` callers do not call `close_update_channel` at all — they rely on the done/drain pattern instead. + +**Turn Interrupt Guard — Monotonic Turn Counter** (`submit_and_ops.rs`, `user_input.rs`): + +When `Op::Interrupt` fires, the backend emits `TurnLifecycle::Aborted` synchronously and calls `cancel()` on the ACP connection. However, the background tokio task spawned by `handle_user_input()` (and `handle_compact()`) continues running after cancellation and may emit stale `TurnLifecycle::Completed` or `ErrorEvent` at the end of its event loop. If the user submits a new message before these stale events arrive, they race with the next turn and can prematurely terminate it. + +The `turn_id: Arc` field on `AcpBackend` is a monotonic counter that eliminates this race. It is incremented on every `Op::Interrupt` and on every new turn (`handle_user_input()`, `handle_compact()`). Each spawned task captures its own turn ID at spawn time and only emits tail events (errors, warnings, `Completed`) if the counter still matches: ``` Op::Interrupt: - 1. turn_interrupted.store(true) -- flag the current turn as interrupted + 1. turn_id.fetch_add(1) -- advance the counter, invalidating the current task 2. connection.cancel() -- cancel the ACP session -handle_user_input(): - 1. turn_interrupted.store(false) -- reset for new turn +handle_user_input() / handle_compact(): + 1. my_turn_id = turn_id.fetch_add(1) + 1 -- advance counter, capture this turn's ID ... spawned task epilogue: - if !turn_interrupted -- only emit Completed if not interrupted + if turn_id.load() == my_turn_id -- only emit tail events if still current + emit ErrorEvent (if error) emit TurnLifecycle::Completed ``` -Since `TurnLifecycle::Aborted` already serves as the turn-ending signal for interrupted turns, suppressing the stale `Completed` is safe. The TUI also has a defense-in-depth counter (`pending_stale_completes`) that ignores stale `Completed` events at the presentation layer (see `@/codex-rs/tui/docs.md`). +Because the counter is monotonic and never reset, there is no TOCTOU window: an interrupt always invalidates any previously spawned task, and a new turn always gets a fresh ID that cannot collide with prior tasks. The TUI does not need any complementary guard — stale events are fully suppressed at the backend layer. **Tool Classification System:** diff --git a/codex-rs/acp/src/backend/hooks.rs b/codex-rs/acp/src/backend/hooks.rs index 0404c3392..ec9f22f9f 100644 --- a/codex-rs/acp/src/backend/hooks.rs +++ b/codex-rs/acp/src/backend/hooks.rs @@ -50,8 +50,8 @@ pub(super) async fn run_prompt_summary( drop(connection); match prompt_result { - Ok(Ok(_)) => {} - Ok(Err(e)) => return Err(e), + Ok((Ok(_), _gen)) => {} + Ok((Err(e), _gen)) => return Err(e), Err(_) => { debug!("Prompt summary timed out"); return Ok(()); diff --git a/codex-rs/acp/src/backend/mod.rs b/codex-rs/acp/src/backend/mod.rs index cd84bad1f..a8a7ed360 100644 --- a/codex-rs/acp/src/backend/mod.rs +++ b/codex-rs/acp/src/backend/mod.rs @@ -9,7 +9,7 @@ use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; -use std::sync::atomic::AtomicBool; +use std::sync::atomic::AtomicU64; use anyhow::Result; use codex_core::config::types::McpServerConfig; @@ -50,6 +50,12 @@ use crate::transcript::TranscriptRecorder; use crate::translator; use crate::undo::GhostSnapshotStack; +/// Maximum time to wait for late-arriving `SessionNotification` events after +/// `block_task()` returns. Empirically most arrive within ~50 ms; 500 ms +/// provides generous headroom without noticeably delaying the turn lifecycle. +pub(super) const POST_PROMPT_DRAIN_TIMEOUT: std::time::Duration = + std::time::Duration::from_millis(500); + // ============================================================================= // Error Categorization // ============================================================================= @@ -324,10 +330,12 @@ pub struct AcpBackend { client_event_normalizer: Arc>, /// MCP server configuration forwarded to ACP agents at session creation. mcp_servers: HashMap, - /// Set when `Op::Interrupt` fires; checked by the spawned prompt task - /// before emitting `TurnLifecycle::Completed`. Prevents a stale - /// `Completed` from a cancelled task from interfering with the next turn. - turn_interrupted: Arc, + /// Monotonic turn counter incremented on every new turn and on every + /// interrupt. Each spawned prompt task captures its turn ID at spawn + /// time and only emits `TurnLifecycle::Completed` if the counter still + /// matches — guaranteeing that stale tasks from cancelled turns never + /// emit a Completed that could interfere with a subsequent turn. + turn_id: Arc, } mod helpers; diff --git a/codex-rs/acp/src/backend/session.rs b/codex-rs/acp/src/backend/session.rs index 8a6eee036..0ca54b96f 100644 --- a/codex-rs/acp/src/backend/session.rs +++ b/codex-rs/acp/src/backend/session.rs @@ -81,10 +81,6 @@ impl AcpBackend { match connection.load_session(sid, &cwd, update_tx).await { Ok(session_id) => { - // Wait for all updates to be collected. This is safe - // because the collect task buffers into a Vec (no - // backpressure) and update_rx closes when load_session - // completes (the worker thread drops update_tx). let buffered_client_events = collect_handle.await.unwrap_or_default(); if !buffered_client_events.is_empty() { debug!( @@ -253,7 +249,7 @@ impl AcpBackend { script_timeout: config.script_timeout, client_event_normalizer: Arc::clone(&client_event_normalizer), mcp_servers: config.mcp_servers.clone(), - turn_interrupted: Arc::new(AtomicBool::new(false)), + turn_id: Arc::new(AtomicU64::new(0)), }; // Execute session_start hooks diff --git a/codex-rs/acp/src/backend/spawn_and_relay.rs b/codex-rs/acp/src/backend/spawn_and_relay.rs index b225732da..3b5929834 100644 --- a/codex-rs/acp/src/backend/spawn_and_relay.rs +++ b/codex-rs/acp/src/backend/spawn_and_relay.rs @@ -186,7 +186,7 @@ impl AcpBackend { script_timeout: config.script_timeout, client_event_normalizer: Arc::clone(&client_event_normalizer), mcp_servers: config.mcp_servers.clone(), - turn_interrupted: Arc::new(AtomicBool::new(false)), + turn_id: Arc::new(AtomicU64::new(0)), }; // Execute session_start hooks diff --git a/codex-rs/acp/src/backend/submit_and_ops.rs b/codex-rs/acp/src/backend/submit_and_ops.rs index 8f3053697..0915b921c 100644 --- a/codex-rs/acp/src/backend/submit_and_ops.rs +++ b/codex-rs/acp/src/backend/submit_and_ops.rs @@ -23,7 +23,7 @@ impl AcpBackend { self.handle_user_input(items, &id).await?; } Op::Interrupt => { - self.turn_interrupted.store(true, Ordering::SeqCst); + self.turn_id.fetch_add(1, Ordering::SeqCst); self.connection .cancel(&*self.session_id.read().await) .await?; @@ -276,6 +276,7 @@ impl AcpBackend { // Create channel for receiving session updates let (update_tx, mut update_rx) = mpsc::channel(32); + let (done_tx, mut done_rx) = tokio::sync::oneshot::channel::<()>(); // Clone what we need for capturing the response let event_tx = self.event_tx.clone(); @@ -292,7 +293,8 @@ impl AcpBackend { let client_event_normalizer = Arc::clone(&self.client_event_normalizer); let backend_event_tx = self.backend_event_tx.clone(); let transcript_recorder = self.transcript_recorder.clone(); - let turn_interrupted = Arc::clone(&self.turn_interrupted); + let turn_id = Arc::clone(&self.turn_id); + let my_turn_id = turn_id.fetch_add(1, Ordering::SeqCst) + 1; // Spawn task to handle the prompt and capture the summary tokio::spawn(async move { @@ -316,8 +318,31 @@ impl AcpBackend { let update_handler = tokio::spawn(async move { let mut summary_text = String::new(); - - while let Some(update) = update_rx.recv().await { + let mut done = false; + + loop { + let update = if done { + match tokio::time::timeout( + super::POST_PROMPT_DRAIN_TIMEOUT, + update_rx.recv(), + ) + .await + { + Ok(Some(u)) => u, + _ => break, + } + } else { + tokio::select! { + msg = update_rx.recv() => match msg { + Some(u) => u, + None => break, + }, + _ = &mut done_rx => { + done = true; + continue; + } + } + }; let client_events = normalize_session_update(&client_event_normalizer, &update).await; forward_client_events(&backend_event_tx_for_updates, &client_events).await; @@ -338,68 +363,63 @@ impl AcpBackend { // Send the summarization prompt let session_id_for_timer = session_id.to_string(); - let result = connection.prompt(session_id, prompt, update_tx).await; + let (result, _update_gen) = connection.prompt(session_id, prompt, update_tx).await; + + // Signal the update_handler to drain remaining events and stop. + let _ = done_tx.send(()); // Wait for all updates to be processed let _ = update_handler.await; - // If prompt failed, send error event and clear any partial summary - if let Err(ref e) = result { - warn!("Compact prompt failed: {e}"); - // Clear any partial summary that may have been stored - *pending_compact_summary.lock().await = None; - let _ = event_tx - .send(Event { - id: id_clone.clone(), - msg: EventMsg::Error(ErrorEvent { - message: format!("Compact failed: {e}"), - codex_error_info: None, - }), - }) - .await; - } else { - // Create a new session to clear the agent's conversation history. - // The summary we captured will be prepended to the next user prompt, - // giving the agent context about the previous conversation. - match connection.create_session(&cwd, mcp_servers).await { - Ok(new_session_id) => { - debug!("Created new session after compact: {:?}", new_session_id); - *session_id_lock.write().await = new_session_id; - } - Err(e) => { - warn!("Failed to create new session after compact: {e}"); - // Continue anyway - summary will still be prepended but agent - // will retain its full history, which is suboptimal but functional + // Only emit tail events if this is still the active turn. When + // the turn_id has advanced, this task is stale and all its late + // events (errors, warnings, Completed) must be suppressed. + if turn_id.load(Ordering::SeqCst) == my_turn_id { + if let Err(ref e) = result { + warn!("Compact prompt failed: {e}"); + *pending_compact_summary.lock().await = None; + let _ = event_tx + .send(Event { + id: id_clone.clone(), + msg: EventMsg::Error(ErrorEvent { + message: format!("Compact failed: {e}"), + codex_error_info: None, + }), + }) + .await; + } else { + match connection.create_session(&cwd, mcp_servers).await { + Ok(new_session_id) => { + debug!("Created new session after compact: {:?}", new_session_id); + *session_id_lock.write().await = new_session_id; + } + Err(e) => { + warn!("Failed to create new session after compact: {e}"); + } } - } - - // Send ContextCompacted event to notify TUI, including the - // summary text so the TUI can reprint it under a new session header. - let compact_summary = pending_compact_summary.lock().await.clone(); - emit_client_event( - &backend_event_tx, - transcript_recorder.as_ref(), - nori_protocol::ClientEvent::TurnLifecycle( - nori_protocol::TurnLifecycle::ContextCompacted { - summary: compact_summary.clone(), - }, - ), - ) - .await; - // Send warning about long conversations - let _ = event_tx - .send(Event { - id: id_clone.clone(), - msg: EventMsg::Warning(WarningEvent { - message: "Heads up: Long conversations and multiple compactions can cause the model to be less accurate. Start a new conversation when possible to keep conversations small and targeted.".to_string(), - }), - }) + let compact_summary = pending_compact_summary.lock().await.clone(); + emit_client_event( + &backend_event_tx, + transcript_recorder.as_ref(), + nori_protocol::ClientEvent::TurnLifecycle( + nori_protocol::TurnLifecycle::ContextCompacted { + summary: compact_summary.clone(), + }, + ), + ) .await; - } - // Send TaskComplete event, unless the turn was interrupted. - if !turn_interrupted.load(Ordering::SeqCst) { + let _ = event_tx + .send(Event { + id: id_clone.clone(), + msg: EventMsg::Warning(WarningEvent { + message: "Heads up: Long conversations and multiple compactions can cause the model to be less accurate. Start a new conversation when possible to keep conversations small and targeted.".to_string(), + }), + }) + .await; + } + emit_client_event( &backend_event_tx, transcript_recorder.as_ref(), @@ -410,21 +430,22 @@ impl AcpBackend { ), ) .await; - } - - // Start idle timer if configured - if let Some(duration) = notify_after_idle.as_duration() { - let idle_secs = duration.as_secs(); - let user_notifier_for_timer = Arc::clone(&user_notifier); - let idle_task = tokio::spawn(async move { - tokio::time::sleep(duration).await; - user_notifier_for_timer.notify(&codex_core::UserNotification::Idle { - session_id: session_id_for_timer, - idle_duration_secs: idle_secs, + // Start idle timer if configured + if let Some(duration) = notify_after_idle.as_duration() { + let idle_secs = duration.as_secs(); + let user_notifier_for_timer = Arc::clone(&user_notifier); + let idle_task = tokio::spawn(async move { + tokio::time::sleep(duration).await; + user_notifier_for_timer.notify(&codex_core::UserNotification::Idle { + session_id: session_id_for_timer, + idle_duration_secs: idle_secs, + }); }); - }); - // Store the abort handle so the timer can be cancelled on new activity - *idle_timer_abort.lock().await = Some(idle_task.abort_handle()); + *idle_timer_abort.lock().await = Some(idle_task.abort_handle()); + } + } else if let Err(ref e) = result { + warn!("Compact prompt failed (stale turn, suppressed): {e}"); + *pending_compact_summary.lock().await = None; } }); diff --git a/codex-rs/acp/src/backend/user_input.rs b/codex-rs/acp/src/backend/user_input.rs index bd0b52d13..ba4421917 100644 --- a/codex-rs/acp/src/backend/user_input.rs +++ b/codex-rs/acp/src/backend/user_input.rs @@ -5,8 +5,10 @@ use super::*; impl AcpBackend { /// Handle user input by sending a prompt to the ACP agent. pub(super) async fn handle_user_input(&self, items: Vec, id: &str) -> Result<()> { - // Reset the interrupt flag so this turn's Completed will be emitted. - self.turn_interrupted.store(false, Ordering::SeqCst); + // Advance the turn counter. The returned value (+1) is this turn's ID. + // The spawned task captures it and only emits Completed if the counter + // still matches, which guarantees stale tasks from prior turns are silent. + let my_turn_id = self.turn_id.fetch_add(1, Ordering::SeqCst) + 1; // Separate text items (needed for hooks, summary, transcript) from // image items (converted to ACP ContentBlock::Image). @@ -172,8 +174,10 @@ impl AcpBackend { } prompt.extend(image_blocks); - // Create channel for receiving session updates + // Create channel for receiving session updates, and a oneshot to + // signal the update_handler to stop after prompt() returns. let (update_tx, mut update_rx) = mpsc::channel(32); + let (done_tx, mut done_rx) = tokio::sync::oneshot::channel::<()>(); // Clone what we need for the background task let event_tx = self.event_tx.clone(); @@ -198,7 +202,7 @@ impl AcpBackend { let pending_hook_context = Arc::clone(&self.pending_hook_context); let client_event_normalizer = Arc::clone(&self.client_event_normalizer); let backend_event_tx = self.backend_event_tx.clone(); - let turn_interrupted = Arc::clone(&self.turn_interrupted); + let turn_id = Arc::clone(&self.turn_id); // Spawn task to handle the prompt and translate events tokio::spawn(async move { @@ -235,7 +239,34 @@ impl AcpBackend { let mut has_fired_pre_agent_response = false; let mut has_agent_text = false; let mut needs_agent_separator = false; - while let Some(update) = update_rx.recv().await { + // When prompt() returns, done_rx fires. We then drain any + // late-arriving notifications with a short timeout before + // exiting. This is needed because block_task() can return + // before all SessionNotification events have been delivered. + let mut done = false; + loop { + let update = if done { + match tokio::time::timeout( + super::POST_PROMPT_DRAIN_TIMEOUT, + update_rx.recv(), + ) + .await + { + Ok(Some(u)) => u, + _ => break, + } + } else { + tokio::select! { + msg = update_rx.recv() => match msg { + Some(u) => u, + None => break, + }, + _ = &mut done_rx => { + done = true; + continue; + } + } + }; let client_events = normalize_session_update(&client_event_normalizer, &update).await; forward_client_events(&backend_event_tx_for_updates, &client_events).await; @@ -378,7 +409,13 @@ impl AcpBackend { // Send the prompt (clone session_id before moving it since we need it for idle timer) let session_id_for_timer = session_id.to_string(); - let result = connection.prompt(session_id, prompt, update_tx).await; + let (result, _update_gen) = connection.prompt(session_id, prompt, update_tx).await; + + // Signal the update_handler to drain remaining events and stop. + // We do NOT close the active_update_tx slot here — late + // notifications may still arrive and should reach the handler. + // The slot will be overwritten by the next prompt() call. + let _ = done_tx.send(()); // Wait for all updates to be processed and get accumulated text let accumulated_text = update_handler.await.unwrap_or_default(); @@ -477,71 +514,53 @@ impl AcpBackend { ); } - // If prompt failed, send an error event to the TUI BEFORE TaskComplete - // This ensures the user sees why their request failed instead of a silent failure - if let Err(ref e) = result { - let error_string = format!("{e:?}"); - let category = categorize_acp_error(&error_string); - let display_error = format!("{e:#}"); - - // Generate user-friendly message based on error category - let user_message = match category { - AcpErrorCategory::Authentication => { - format!( - "Authentication error: {display_error}. Please check your credentials or re-authenticate." - ) - } - AcpErrorCategory::QuotaExceeded => { - format!("Rate limit or quota exceeded: {display_error}") - } - AcpErrorCategory::ExecutableNotFound => { - format!("Agent executable not found: {display_error}") - } - AcpErrorCategory::Initialization => { - format!("Agent initialization failed: {display_error}") - } - AcpErrorCategory::PromptTooLong => { - "Prompt is too long. Try using /compact to reduce context size, or start a new session." - .to_string() - } - AcpErrorCategory::ApiServerError => { - "The API returned a server error. This is usually temporary — please try again." - .to_string() - } - AcpErrorCategory::Unknown => { - format!("ACP prompt failed: {display_error}") - } - }; - - warn!("ACP prompt failed: {}", e); - debug!( - target: "acp_event_flow", - user_message = %user_message, - "ACP prompt failure: sending ErrorEvent to TUI" - ); + if turn_id.load(Ordering::SeqCst) == my_turn_id { + if let Err(ref e) = result { + let error_string = format!("{e:?}"); + let category = categorize_acp_error(&error_string); + let display_error = format!("{e:#}"); - // Send error event to TUI so user sees the error - let _ = event_tx - .send(Event { - id: id_clone.clone(), - msg: EventMsg::Error(ErrorEvent { - message: user_message.clone(), - codex_error_info: None, - }), - }) - .await; - - debug!( - target: "acp_event_flow", - "ACP prompt failure: ErrorEvent sent to TUI" - ); - } + let user_message = match category { + AcpErrorCategory::Authentication => { + format!( + "Authentication error: {display_error}. Please check your credentials or re-authenticate." + ) + } + AcpErrorCategory::QuotaExceeded => { + format!("Rate limit or quota exceeded: {display_error}") + } + AcpErrorCategory::ExecutableNotFound => { + format!("Agent executable not found: {display_error}") + } + AcpErrorCategory::Initialization => { + format!("Agent initialization failed: {display_error}") + } + AcpErrorCategory::PromptTooLong => { + "Prompt is too long. Try using /compact to reduce context size, or start a new session." + .to_string() + } + AcpErrorCategory::ApiServerError => { + "The API returned a server error. This is usually temporary — please try again." + .to_string() + } + AcpErrorCategory::Unknown => { + format!("ACP prompt failed: {display_error}") + } + }; + + warn!("ACP prompt failed: {}", e); + + let _ = event_tx + .send(Event { + id: id_clone.clone(), + msg: EventMsg::Error(ErrorEvent { + message: user_message.clone(), + codex_error_info: None, + }), + }) + .await; + } - // Send TaskComplete event to end the turn, unless this turn was - // interrupted. When Op::Interrupt fires, it emits - // TurnLifecycle::Aborted synchronously; emitting a Completed here - // would race with the next turn and prematurely terminate it. - if !turn_interrupted.load(Ordering::SeqCst) { emit_client_event( &backend_event_tx, transcript_recorder.as_ref(), @@ -552,21 +571,21 @@ impl AcpBackend { ), ) .await; - } - - // Start idle timer if configured - if let Some(duration) = notify_after_idle.as_duration() { - let idle_secs = duration.as_secs(); - let user_notifier_for_timer = Arc::clone(&user_notifier); - let idle_task = tokio::spawn(async move { - tokio::time::sleep(duration).await; - user_notifier_for_timer.notify(&codex_core::UserNotification::Idle { - session_id: session_id_for_timer, - idle_duration_secs: idle_secs, + // Start idle timer if configured + if let Some(duration) = notify_after_idle.as_duration() { + let idle_secs = duration.as_secs(); + let user_notifier_for_timer = Arc::clone(&user_notifier); + let idle_task = tokio::spawn(async move { + tokio::time::sleep(duration).await; + user_notifier_for_timer.notify(&codex_core::UserNotification::Idle { + session_id: session_id_for_timer, + idle_duration_secs: idle_secs, + }); }); - }); - // Store the abort handle so the timer can be cancelled on new activity - *idle_timer_abort.lock().await = Some(idle_task.abort_handle()); + *idle_timer_abort.lock().await = Some(idle_task.abort_handle()); + } + } else if let Err(ref e) = result { + warn!("ACP prompt failed (stale turn, suppressed): {e}"); } }); diff --git a/codex-rs/acp/src/connection/sacp_connection.rs b/codex-rs/acp/src/connection/sacp_connection.rs index 3a5e01e5e..cb249e890 100644 --- a/codex-rs/acp/src/connection/sacp_connection.rs +++ b/codex-rs/acp/src/connection/sacp_connection.rs @@ -66,6 +66,8 @@ use sacp::schema::SetSessionModelRequest; /// Minimum supported ACP protocol version. const MINIMUM_SUPPORTED_VERSION: ProtocolVersion = ProtocolVersion::V1; +type ActiveUpdateSlot = std::sync::Arc)>>>; + /// A thread-safe connection to an ACP agent subprocess using SACP v10. /// /// Unlike the old `AcpConnection`, this does NOT require a dedicated worker thread. @@ -95,9 +97,17 @@ pub struct SacpConnection { /// Shared session update sender. The notification handler routes updates /// to whoever currently holds the active sender. During a prompt, this - /// contains the caller's `update_tx`. Between turns, it is `None` and - /// notifications fall through to the persistent channel. - active_update_tx: std::sync::Arc>>>, + /// holds the caller's `update_tx`. It is NOT cleared when the prompt + /// returns because notifications may arrive after `block_task` completes. + /// Between turns the receiver is dropped by the `update_handler` task, + /// so `try_send` fails and the notification handler falls through to + /// `persistent_tx`. The next `prompt()` overwrites the slot. + active_update_tx: ActiveUpdateSlot, + /// Monotonic counter paired with `active_update_tx`. Each install gets + /// a unique generation; `close_update_channel` only clears if the + /// generation matches, preventing a stale task from wiping a newer + /// prompt's sender. + update_generation: std::sync::atomic::AtomicU64, /// Handle to the background task driving the SACP connection. connection_task: tokio::task::JoinHandle<()>, @@ -181,8 +191,7 @@ impl SacpConnection { // --- Set up channels --- let (approval_tx, approval_rx) = mpsc::channel::(16); let (persistent_tx, persistent_rx) = mpsc::channel::(64); - let active_update_tx: std::sync::Arc>>> = - std::sync::Arc::new(Mutex::new(None)); + let active_update_tx: ActiveUpdateSlot = std::sync::Arc::new(Mutex::new(None)); // --- Build SACP connection --- let transport = ByteStreams::new(stdin.compat_write(), stdout.compat()); @@ -212,9 +221,18 @@ impl SacpConnection { async move |notification: SessionNotification, _cx| { let update = notification.update; let guard = update_tx.lock().await; - if let Some(tx) = guard.as_ref() { - let _ = tx.try_send(update); + // Try the per-prompt channel first. If it fails + // (receiver dropped between turns, or channel + // full), fall through to the persistent channel. + let unsent = if let Some((_, tx)) = guard.as_ref() { + match tx.try_send(update) { + Ok(()) => None, + Err(e) => Some(e.into_inner()), + } } else { + Some(update) + }; + if let Some(update) = unsent { let _ = persistent_tx.try_send(update); } Ok(()) @@ -343,12 +361,16 @@ impl SacpConnection { .status(ToolCallStatus::Pending); { let guard = update_tx.lock().await; - if let Some(tx) = guard.as_ref() { - let _ = - tx.try_send(SessionUpdate::ToolCall(tool_call)); + let unsent = if let Some((_, tx)) = guard.as_ref() { + match tx.try_send(SessionUpdate::ToolCall(tool_call)) { + Ok(()) => None, + Err(e) => Some(e.into_inner()), + } } else { - let _ = persistent_tx - .try_send(SessionUpdate::ToolCall(tool_call)); + Some(SessionUpdate::ToolCall(tool_call)) + }; + if let Some(update) = unsent { + let _ = persistent_tx.try_send(update); } } @@ -446,12 +468,16 @@ impl SacpConnection { .status(ToolCallStatus::Pending); { let guard = update_tx.lock().await; - if let Some(tx) = guard.as_ref() { - let _ = - tx.try_send(SessionUpdate::ToolCall(tool_call)); + let unsent = if let Some((_, tx)) = guard.as_ref() { + match tx.try_send(SessionUpdate::ToolCall(tool_call)) { + Ok(()) => None, + Err(e) => Some(e.into_inner()), + } } else { - let _ = persistent_tx - .try_send(SessionUpdate::ToolCall(tool_call)); + Some(SessionUpdate::ToolCall(tool_call)) + }; + if let Some(update) = unsent { + let _ = persistent_tx.try_send(update); } } @@ -546,6 +572,7 @@ impl SacpConnection { persistent_rx, model_state: std::sync::Arc::new(std::sync::RwLock::new(AcpModelState::new())), active_update_tx, + update_generation: std::sync::atomic::AtomicU64::new(0), connection_task, child, stderr_task, @@ -595,11 +622,7 @@ impl SacpConnection { cwd: &Path, update_tx: mpsc::Sender, ) -> Result { - // Install the update channel for replay events. - { - let mut guard = self.active_update_tx.lock().await; - *guard = Some(update_tx); - } + let my_gen = self.install_update_channel(update_tx).await; let result = self .cx @@ -608,11 +631,8 @@ impl SacpConnection { .await .context("Failed to load ACP session"); - // Uninstall so replay events stop flowing to the caller's channel. - { - let mut guard = self.active_update_tx.lock().await; - *guard = None; - } + // Safe to clear here — load_session is never called concurrently. + self.close_update_channel(my_gen).await; let response = result?; @@ -623,8 +643,6 @@ impl SacpConnection { *state = AcpModelState::from_session_model_state(models); } - // The session ID from the request is reused since the response - // doesn't contain one. Ok(SessionId::from(session_id.to_string())) } @@ -634,12 +652,8 @@ impl SacpConnection { session_id: SessionId, prompt: Vec, update_tx: mpsc::Sender, - ) -> Result { - // Install the update channel. - { - let mut guard = self.active_update_tx.lock().await; - *guard = Some(update_tx); - } + ) -> (Result, u64) { + let my_gen = self.install_update_channel(update_tx).await; let result = self .cx @@ -648,13 +662,37 @@ impl SacpConnection { .await .context("ACP prompt failed"); - // Uninstall so inter-turn notifications go to persistent. - { - let mut guard = self.active_update_tx.lock().await; + // Do NOT clear active_update_tx here. Late SessionNotification + // events may still arrive and should flow to the update_handler + // via the done_rx drain. The slot will be overwritten by the + // next prompt() call; between turns, try_send failure falls + // through to persistent_tx automatically. + + (result.map(|r| r.stop_reason), my_gen) + } + + /// Install an update sender in the shared slot, returning the generation + /// counter for use with `close_update_channel`. + async fn install_update_channel(&self, update_tx: mpsc::Sender) -> u64 { + let my_gen = self + .update_generation + .fetch_add(1, std::sync::atomic::Ordering::SeqCst) + + 1; + let mut guard = self.active_update_tx.lock().await; + *guard = Some((my_gen, update_tx)); + my_gen + } + + /// Drop the active update sender if and only if the generation matches, + /// closing the channel so the `update_handler` task can terminate. + /// If a newer prompt has already installed its own sender, this is a + /// no-op — the newer prompt's channel is left intact. + pub async fn close_update_channel(&self, generation: u64) { + let mut guard = self.active_update_tx.lock().await; + let clearing = matches!(guard.as_ref(), Some((g, _)) if *g == generation); + if clearing { *guard = None; } - - result.map(|r| r.stop_reason) } /// Cancel an ongoing prompt. diff --git a/codex-rs/acp/src/connection/sacp_connection_tests.rs b/codex-rs/acp/src/connection/sacp_connection_tests.rs index 193d0e991..e18a43703 100644 --- a/codex-rs/acp/src/connection/sacp_connection_tests.rs +++ b/codex-rs/acp/src/connection/sacp_connection_tests.rs @@ -70,7 +70,8 @@ async fn test_prompt_receives_text_updates() { let (tx, mut rx) = mpsc::channel(32); let prompt = vec![acp::ContentBlock::Text(acp::TextContent::new("Hello"))]; - let stop_reason = conn.prompt(session_id, prompt, tx).await.expect("prompt"); + let (stop_reason_result, _gen) = conn.prompt(session_id, prompt, tx).await; + let stop_reason = stop_reason_result.expect("prompt"); // Collect all text messages from the updates channel. let mut messages = Vec::new(); @@ -191,7 +192,7 @@ async fn test_approval_receiver_forwards_requests() { .send(codex_protocol::protocol::ReviewDecision::Approved); // The prompt should complete (either normally or error) after approval. - let result = tokio::time::timeout(std::time::Duration::from_secs(10), prompt_handle) + let (result, _gen) = tokio::time::timeout(std::time::Duration::from_secs(10), prompt_handle) .await .expect("Prompt should complete within 10s after approval") .expect("Prompt task should not panic"); @@ -246,7 +247,7 @@ async fn test_codex_home_not_inherited() { let (tx, mut rx) = mpsc::channel(32); let prompt = vec![acp::ContentBlock::Text(acp::TextContent::new("check env"))]; - conn.prompt(session_id, prompt, tx).await.expect("prompt"); + conn.prompt(session_id, prompt, tx).await.0.expect("prompt"); let mut messages = Vec::new(); while let Ok(update) = rx.try_recv() { @@ -361,15 +362,81 @@ async fn test_cancel_during_prompt() { .expect("cancel should succeed"); // The prompt should complete with Cancelled stop reason - let result = tokio::time::timeout(std::time::Duration::from_secs(5), prompt_task) + let (result, _gen) = tokio::time::timeout(std::time::Duration::from_secs(5), prompt_task) .await .expect("Prompt should complete within 5s after cancel") - .expect("Prompt task should not panic") - .expect("Prompt should not error after cancel"); + .expect("Prompt task should not panic"); + let stop_reason = result.expect("Prompt should not error after cancel"); assert_eq!( - result, + stop_reason, acp::StopReason::Cancelled, "Stop reason should be Cancelled after cancel" ); } + +/// Test that the generation counter in `active_update_tx` prevents a stale +/// prompt's uninstall from wiping a newer prompt's channel. This directly +/// tests the `SacpConnection::prompt()` install/uninstall logic. +/// +/// We can't easily test concurrent overlapping prompts with the mock agent +/// (it doesn't handle two concurrent prompts), but we CAN verify that after +/// cancel → new prompt, the new prompt still receives its response correctly. +#[tokio::test] +#[serial] +async fn test_sequential_prompt_after_cancel_receives_response() { + let Some(config) = mock_agent_config() else { + return; + }; + + let temp_dir = tempdir().expect("temp dir"); + + let conn = SacpConnection::spawn(&config, temp_dir.path()) + .await + .expect("spawn"); + + let session_id = conn + .create_session(temp_dir.path(), vec![]) + .await + .expect("create session"); + + // --- Prompt 1: normal prompt, runs to completion --- + let (tx1, mut rx1) = mpsc::channel(32); + let prompt1 = vec![acp::ContentBlock::Text(acp::TextContent::new("hello"))]; + conn.prompt(session_id.clone(), prompt1, tx1) + .await + .0 + .expect("prompt 1"); + + let mut msgs1 = Vec::new(); + while let Ok(update) = rx1.try_recv() { + if let acp::SessionUpdate::AgentMessageChunk(chunk) = update + && let acp::ContentBlock::Text(text) = chunk.content + { + msgs1.push(text.text); + } + } + assert!(!msgs1.is_empty(), "Prompt 1 should receive text"); + + // --- Prompt 2: should also receive its response correctly --- + // This verifies the uninstall from prompt 1 doesn't corrupt state + // for prompt 2. + let (tx2, mut rx2) = mpsc::channel(32); + let prompt2 = vec![acp::ContentBlock::Text(acp::TextContent::new( + "hello again", + ))]; + conn.prompt(session_id.clone(), prompt2, tx2) + .await + .0 + .expect("prompt 2"); + + let mut msgs2 = Vec::new(); + while let Ok(update) = rx2.try_recv() { + if let acp::SessionUpdate::AgentMessageChunk(chunk) = update + && let acp::ContentBlock::Text(text) = chunk.content + { + msgs2.push(text.text); + } + } + assert!(!msgs2.is_empty(), "Prompt 2 should receive text updates"); +} diff --git a/codex-rs/tui/docs.md b/codex-rs/tui/docs.md index 292c80445..973c18f88 100644 --- a/codex-rs/tui/docs.md +++ b/codex-rs/tui/docs.md @@ -142,17 +142,9 @@ The ACP protocol has no end-of-turn synchronization guarantee. Answer deltas, re The gate is checked both in the legacy exec/mcp handlers and in the normalized ACP tool-snapshot handlers. When `turn_finished` is true, those methods return immediately without rendering any UI. This is complementary to the interrupt queue: the queue handles deferral during streaming within a turn, while `turn_finished` handles events that arrive after the turn ends entirely. -**Stale Completed Guard** (`chatwidget/mod.rs`, `chatwidget/event_handlers.rs`): +**Stale Event Suppression:** -When a turn is interrupted (ESC), the ACP backend emits `TurnLifecycle::Aborted` synchronously, but the background task may still emit a stale `TurnLifecycle::Completed` later. If that stale `Completed` arrives after the next turn has started, `on_task_complete()` would set `turn_finished = true` and discard all subsequent tool events for the new turn. The `pending_stale_completes: i32` counter on `ChatWidget` acts as defense-in-depth against this race: - -| Action | Method | Effect | -|--------|--------|--------| -| Interrupt received | `on_interrupted_turn()` | Increments `pending_stale_completes` | -| New turn starts | `on_task_started()` | Resets `pending_stale_completes` to 0 (drains orphaned counters from backend-suppressed Completeds) | -| Stale Completed arrives | `on_task_complete()` | If counter > 0, decrements and returns early (skips turn finalization) | - -This is complementary to the ACP backend's `turn_interrupted` flag (`@/codex-rs/acp/docs.md`), which suppresses the stale `Completed` at the source. In the common case the backend suppresses the stale event and the counter is never drained; the `on_task_started` reset ensures those orphaned counters don't consume the next turn's real Completed. The counter still provides defense-in-depth for the rare race where a stale Completed slips past the backend guard. +Stale `TurnLifecycle::Completed` and `ErrorEvent` from cancelled turns are suppressed entirely at the ACP backend layer via a monotonic turn counter (`turn_id: Arc`). Each spawned backend task captures its turn ID and only emits tail events if the counter still matches. The TUI does not need any complementary guard for this race — see `@/codex-rs/acp/docs.md` for details. **Turn-Boundary Cleanup of Incomplete Tool Cells** (`chatwidget/event_handlers.rs`): diff --git a/codex-rs/tui/src/chatwidget/constructors.rs b/codex-rs/tui/src/chatwidget/constructors.rs index e183c05be..5a2ae9239 100644 --- a/codex-rs/tui/src/chatwidget/constructors.rs +++ b/codex-rs/tui/src/chatwidget/constructors.rs @@ -101,7 +101,6 @@ impl ChatWidget { #[cfg(feature = "nori-config")] loop_count_override: None, turn_finished: false, - pending_stale_completes: 0, plan_drawer_mode: PlanDrawerMode::Off, pinned_plan: None, terminal_title_animation_origin: std::time::Instant::now(), @@ -212,7 +211,6 @@ impl ChatWidget { #[cfg(feature = "nori-config")] loop_count_override: None, turn_finished: false, - pending_stale_completes: 0, plan_drawer_mode: PlanDrawerMode::Off, pinned_plan: None, terminal_title_animation_origin: std::time::Instant::now(), diff --git a/codex-rs/tui/src/chatwidget/event_handlers.rs b/codex-rs/tui/src/chatwidget/event_handlers.rs index e8122cdb0..c8ba22839 100644 --- a/codex-rs/tui/src/chatwidget/event_handlers.rs +++ b/codex-rs/tui/src/chatwidget/event_handlers.rs @@ -182,12 +182,6 @@ impl ChatWidget { // Raw reasoning uses the same flow as summarized reasoning pub(super) fn on_task_started(&mut self) { - // When the ACP backend suppresses a stale Completed (via the - // turn_interrupted flag), the pending_stale_completes counter is - // never drained. Reset it here so leftover counters from previous - // interrupts don't consume this turn's real Completed. - self.pending_stale_completes = 0; - self.bottom_pane.clear_ctrl_c_quit_hint(); self.bottom_pane.set_task_running(true); self.retry_status_header = None; @@ -202,15 +196,6 @@ impl ChatWidget { } pub(super) fn on_task_complete(&mut self, last_agent_message: Option) { - // If this Completed is a stale leftover from a cancelled turn, skip it. - // Each on_interrupted_turn increments pending_stale_completes; the - // matching background task will eventually emit Completed which we - // must ignore to avoid prematurely ending the current turn. - if self.pending_stale_completes > 0 { - self.pending_stale_completes -= 1; - return; - } - // If a stream is currently active, finalize it. self.flush_answer_stream_with_separator(); @@ -443,12 +428,6 @@ impl ChatWidget { /// When there are queued user messages, restore them into the composer /// separated by newlines rather than auto‑submitting the next one. pub(super) fn on_interrupted_turn(&mut self, _reason: TurnAbortReason) { - // The ACP backend usually suppresses the stale Completed via - // turn_interrupted, but if it races through, on_task_complete - // can use this counter to ignore it. The counter is reset by - // the next on_task_started as a safety net. - self.pending_stale_completes += 1; - // Finalize, log a gentle prompt, and clear running state. self.finalize_turn(); self.cancel_loop(); diff --git a/codex-rs/tui/src/chatwidget/mod.rs b/codex-rs/tui/src/chatwidget/mod.rs index 626ca173a..874c9a77b 100644 --- a/codex-rs/tui/src/chatwidget/mod.rs +++ b/codex-rs/tui/src/chatwidget/mod.rs @@ -425,11 +425,6 @@ pub(crate) struct ChatWidget { // Gate: set when AgentMessage is received, cleared on next TaskStarted. // While true, late-arriving tool events are silently discarded. turn_finished: bool, - // Defense-in-depth counter for stale TurnLifecycle::Completed events - // after interrupts. Incremented by on_interrupted_turn, decremented by - // on_task_complete, and reset to 0 by on_task_started (to drain orphaned - // counters when the ACP backend suppresses the stale Completed). - pending_stale_completes: i32, /// Whether and how plan updates are rendered in a pinned drawer instead of /// history cells. plan_drawer_mode: PlanDrawerMode, diff --git a/codex-rs/tui/src/chatwidget/tests/mod.rs b/codex-rs/tui/src/chatwidget/tests/mod.rs index c667f0178..a9c74be1d 100644 --- a/codex-rs/tui/src/chatwidget/tests/mod.rs +++ b/codex-rs/tui/src/chatwidget/tests/mod.rs @@ -314,7 +314,6 @@ pub(crate) fn make_chatwidget_manual() -> ( #[cfg(feature = "nori-config")] loop_count_override: None, turn_finished: false, - pending_stale_completes: 0, plan_drawer_mode: PlanDrawerMode::Off, pinned_plan: None, terminal_title_animation_origin: std::time::Instant::now(), diff --git a/codex-rs/tui/src/chatwidget/tests/part7.rs b/codex-rs/tui/src/chatwidget/tests/part7.rs index b9498619d..88cfe617d 100644 --- a/codex-rs/tui/src/chatwidget/tests/part7.rs +++ b/codex-rs/tui/src/chatwidget/tests/part7.rs @@ -1,19 +1,16 @@ use super::*; -/// When the ACP backend suppresses the stale Completed (common case), the -/// next turn's real Completed must not be consumed as stale. +/// After an interrupt, the ACP backend's monotonic turn counter guarantees +/// that the stale Completed from the cancelled task is never emitted. The +/// TUI should handle the normal sequence without issues. /// /// Sequence: /// 1. Started(A) → task running -/// 2. Aborted(A) → task stopped (user pressed ESC), counter = 1 -/// 3. Started(B) → counter reset to 0 +/// 2. Aborted(A) → task stopped (user pressed ESC) +/// 3. Started(B) → new turn begins /// 4. Completed(B) → should finalize turn B normally -/// -/// Before the fix, the counter from step 2 was never drained (because the -/// ACP backend suppressed the stale Completed), so the real Completed in -/// step 4 was consumed as stale, leaving the spinner running forever. #[test] -fn acp_suppressed_stale_should_not_block_next_turn_completion() { +fn interrupt_then_new_turn_completes_normally() { let (mut chat, mut rx, _op_rx) = make_chatwidget_manual(); // Start and interrupt turn A @@ -43,10 +40,11 @@ fn acp_suppressed_stale_should_not_block_next_turn_completion() { ); } -/// Multiple consecutive interrupts where ACP suppresses all stale Completeds. -/// The final real turn's Completed must still finalize normally. +/// Multiple consecutive interrupts followed by a real turn. The ACP backend's +/// monotonic turn counter suppresses all stale Completeds, so the final real +/// turn's Completed must still finalize normally. #[test] -fn multiple_interrupts_with_acp_suppression_should_not_hang() { +fn multiple_interrupts_then_real_turn_completes_normally() { let (mut chat, mut rx, _op_rx) = make_chatwidget_manual(); // Interrupt twice in a row @@ -70,7 +68,7 @@ fn multiple_interrupts_with_acp_suppression_should_not_hang() { begin_exec(&mut chat, "real-call", "echo real"); assert!( chat.active_cell.is_some(), - "ExecCell should be created - counter was reset by on_task_started" + "ExecCell should be created during real turn" ); // Real Completed should finalize the turn