From a64abeb83297c9a39d647da16409f44aeccf026a Mon Sep 17 00:00:00 2001 From: Dave Aitel Date: Mon, 9 Mar 2026 16:08:01 -0400 Subject: [PATCH] Fix agent job finalization and status waiting --- .../core/src/tools/handlers/agent_jobs.rs | 90 +++++++++---- codex-rs/state/src/runtime/agent_jobs.rs | 124 +++++++++++++++++- 2 files changed, 186 insertions(+), 28 deletions(-) diff --git a/codex-rs/core/src/tools/handlers/agent_jobs.rs b/codex-rs/core/src/tools/handlers/agent_jobs.rs index 4e786178f87..2aaa306068d 100644 --- a/codex-rs/core/src/tools/handlers/agent_jobs.rs +++ b/codex-rs/core/src/tools/handlers/agent_jobs.rs @@ -15,9 +15,12 @@ use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; use async_trait::async_trait; use codex_protocol::ThreadId; +use codex_protocol::protocol::AgentStatus; use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::SubAgentSource; use codex_protocol::user_input::UserInput; +use futures::StreamExt; +use futures::stream::FuturesUnordered; use serde::Deserialize; use serde::Serialize; use serde_json::Value; @@ -26,8 +29,10 @@ use std::collections::HashSet; use std::path::Path; use std::path::PathBuf; use std::sync::Arc; +use tokio::sync::watch::Receiver; use tokio::time::Duration; use tokio::time::Instant; +use tokio::time::timeout; use uuid::Uuid; pub struct BatchJobHandler; @@ -103,6 +108,7 @@ struct JobRunnerOptions { struct ActiveJobItem { item_id: String, started_at: Instant, + status_rx: Option>, } struct JobProgressEmitter { @@ -670,6 +676,12 @@ async fn run_agent_job_loop( ActiveJobItem { item_id: item.item_id.clone(), started_at: Instant::now(), + status_rx: session + .services + .agent_control + .subscribe_status(thread_id) + .await + .ok(), }, ); progressed = true; @@ -702,7 +714,7 @@ async fn run_agent_job_loop( break; } if !progressed { - tokio::time::sleep(STATUS_POLL_INTERVAL).await; + wait_for_status_change(&active_items).await; } continue; } @@ -833,6 +845,12 @@ async fn recover_running_items( ActiveJobItem { item_id: item.item_id.clone(), started_at: started_at_from_item(&item), + status_rx: session + .services + .agent_control + .subscribe_status(thread_id) + .await + .ok(), }, ); } @@ -846,13 +864,44 @@ async fn find_finished_threads( ) -> Vec<(ThreadId, String)> { let mut finished = Vec::new(); for (thread_id, item) in active_items { - if is_final(&session.services.agent_control.get_status(*thread_id).await) { + let status = active_item_status(session.as_ref(), *thread_id, item).await; + if is_final(&status) { finished.push((*thread_id, item.item_id.clone())); } } finished } +async fn active_item_status( + session: &Session, + thread_id: ThreadId, + item: &ActiveJobItem, +) -> AgentStatus { + if let Some(status_rx) = item.status_rx.as_ref() + && status_rx.has_changed().is_ok() + { + return status_rx.borrow().clone(); + } + session.services.agent_control.get_status(thread_id).await +} + +async fn wait_for_status_change(active_items: &HashMap) { + let mut waiters = FuturesUnordered::new(); + for item in active_items.values() { + if let Some(status_rx) = item.status_rx.as_ref() { + let mut status_rx = status_rx.clone(); + waiters.push(async move { + let _ = status_rx.changed().await; + }); + } + } + if waiters.is_empty() { + tokio::time::sleep(STATUS_POLL_INTERVAL).await; + return; + } + let _ = timeout(STATUS_POLL_INTERVAL, waiters.next()).await; +} + async fn reap_stale_active_items( session: Arc, db: Arc, @@ -890,37 +939,24 @@ async fn finalize_finished_item( item_id: &str, thread_id: ThreadId, ) -> anyhow::Result<()> { - let mut item = db + let item = db .get_agent_job_item(job_id, item_id) .await? .ok_or_else(|| { anyhow::anyhow!("job item not found for finalization: {job_id}/{item_id}") })?; - if item.result_json.is_none() { - tokio::time::sleep(Duration::from_millis(250)).await; - item = db - .get_agent_job_item(job_id, item_id) - .await? - .ok_or_else(|| { - anyhow::anyhow!("job item not found after grace period: {job_id}/{item_id}") - })?; - } - if item.result_json.is_some() { - if !db.mark_agent_job_item_completed(job_id, item_id).await? { - db.mark_agent_job_item_failed( - job_id, - item_id, - "worker reported result but item could not transition to completed", - ) - .await?; + if matches!(item.status, codex_state::AgentJobItemStatus::Running) { + if item.result_json.is_some() { + let _ = db.mark_agent_job_item_completed(job_id, item_id).await?; + } else { + let _ = db + .mark_agent_job_item_failed( + job_id, + item_id, + "worker finished without calling report_agent_job_result", + ) + .await?; } - } else { - db.mark_agent_job_item_failed( - job_id, - item_id, - "worker finished without calling report_agent_job_result", - ) - .await?; } let _ = session .services diff --git a/codex-rs/state/src/runtime/agent_jobs.rs b/codex-rs/state/src/runtime/agent_jobs.rs index c6856059457..3f5526c58dd 100644 --- a/codex-rs/state/src/runtime/agent_jobs.rs +++ b/codex-rs/state/src/runtime/agent_jobs.rs @@ -435,10 +435,13 @@ WHERE job_id = ? AND item_id = ? AND status = ? r#" UPDATE agent_job_items SET + status = ?, result_json = ?, reported_at = ?, + completed_at = ?, updated_at = ?, - last_error = NULL + last_error = NULL, + assigned_thread_id = NULL WHERE job_id = ? AND item_id = ? @@ -446,9 +449,11 @@ WHERE AND assigned_thread_id = ? "#, ) + .bind(AgentJobItemStatus::Completed.as_str()) .bind(serialized) .bind(now) .bind(now) + .bind(now) .bind(job_id) .bind(item_id) .bind(AgentJobItemStatus::Running.as_str()) @@ -560,3 +565,120 @@ WHERE job_id = ? }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::runtime::test_support::unique_temp_dir; + use pretty_assertions::assert_eq; + use serde_json::json; + + async fn create_running_single_item_job( + runtime: &StateRuntime, + ) -> anyhow::Result<(String, String, String)> { + let job_id = "job-1".to_string(); + let item_id = "item-1".to_string(); + let thread_id = "thread-1".to_string(); + runtime + .create_agent_job( + &AgentJobCreateParams { + id: job_id.clone(), + name: "test-job".to_string(), + instruction: "Return a result".to_string(), + auto_export: true, + max_runtime_seconds: None, + output_schema_json: None, + input_headers: vec!["path".to_string()], + input_csv_path: "/tmp/in.csv".to_string(), + output_csv_path: "/tmp/out.csv".to_string(), + }, + &[AgentJobItemCreateParams { + item_id: item_id.clone(), + row_index: 0, + source_id: None, + row_json: json!({"path":"file-1"}), + }], + ) + .await?; + runtime.mark_agent_job_running(job_id.as_str()).await?; + let marked_running = runtime + .mark_agent_job_item_running_with_thread( + job_id.as_str(), + item_id.as_str(), + thread_id.as_str(), + ) + .await?; + assert!(marked_running); + Ok((job_id, item_id, thread_id)) + } + + #[tokio::test] + async fn report_agent_job_item_result_completes_item_atomically() -> anyhow::Result<()> { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home, "test-provider".to_string()).await?; + let (job_id, item_id, thread_id) = create_running_single_item_job(runtime.as_ref()).await?; + + let accepted = runtime + .report_agent_job_item_result( + job_id.as_str(), + item_id.as_str(), + thread_id.as_str(), + &json!({"ok": true}), + ) + .await?; + assert!(accepted); + + let item = runtime + .get_agent_job_item(job_id.as_str(), item_id.as_str()) + .await? + .expect("job item should exist"); + assert_eq!(item.status, AgentJobItemStatus::Completed); + assert_eq!(item.result_json, Some(json!({"ok": true}))); + assert_eq!(item.assigned_thread_id, None); + assert_eq!(item.last_error, None); + assert!(item.reported_at.is_some()); + assert!(item.completed_at.is_some()); + let progress = runtime.get_agent_job_progress(job_id.as_str()).await?; + assert_eq!( + progress, + AgentJobProgress { + total_items: 1, + pending_items: 0, + running_items: 0, + completed_items: 1, + failed_items: 0, + } + ); + Ok(()) + } + + #[tokio::test] + async fn report_agent_job_item_result_rejects_late_reports() -> anyhow::Result<()> { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home, "test-provider".to_string()).await?; + let (job_id, item_id, thread_id) = create_running_single_item_job(runtime.as_ref()).await?; + + let marked_failed = runtime + .mark_agent_job_item_failed(job_id.as_str(), item_id.as_str(), "missing report") + .await?; + assert!(marked_failed); + let accepted = runtime + .report_agent_job_item_result( + job_id.as_str(), + item_id.as_str(), + thread_id.as_str(), + &json!({"late": true}), + ) + .await?; + assert!(!accepted); + + let item = runtime + .get_agent_job_item(job_id.as_str(), item_id.as_str()) + .await? + .expect("job item should exist"); + assert_eq!(item.status, AgentJobItemStatus::Failed); + assert_eq!(item.result_json, None); + assert_eq!(item.last_error, Some("missing report".to_string())); + Ok(()) + } +}