Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 63 additions & 27 deletions codex-rs/core/src/tools/handlers/agent_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use async_trait::async_trait;
use codex_protocol::ThreadId;
use codex_protocol::protocol::AgentStatus;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::user_input::UserInput;
use futures::StreamExt;
use futures::stream::FuturesUnordered;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
Expand All @@ -26,8 +29,10 @@ use std::collections::HashSet;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::watch::Receiver;
use tokio::time::Duration;
use tokio::time::Instant;
use tokio::time::timeout;
use uuid::Uuid;

pub struct BatchJobHandler;
Expand Down Expand Up @@ -103,6 +108,7 @@ struct JobRunnerOptions {
struct ActiveJobItem {
item_id: String,
started_at: Instant,
status_rx: Option<Receiver<AgentStatus>>,
}

struct JobProgressEmitter {
Expand Down Expand Up @@ -670,6 +676,12 @@ async fn run_agent_job_loop(
ActiveJobItem {
item_id: item.item_id.clone(),
started_at: Instant::now(),
status_rx: session
.services
.agent_control
.subscribe_status(thread_id)
.await
.ok(),
},
);
progressed = true;
Expand Down Expand Up @@ -702,7 +714,7 @@ async fn run_agent_job_loop(
break;
}
if !progressed {
tokio::time::sleep(STATUS_POLL_INTERVAL).await;
wait_for_status_change(&active_items).await;
}
continue;
}
Expand Down Expand Up @@ -833,6 +845,12 @@ async fn recover_running_items(
ActiveJobItem {
item_id: item.item_id.clone(),
started_at: started_at_from_item(&item),
status_rx: session
.services
.agent_control
.subscribe_status(thread_id)
.await
.ok(),
},
);
}
Expand All @@ -846,13 +864,44 @@ async fn find_finished_threads(
) -> Vec<(ThreadId, String)> {
let mut finished = Vec::new();
for (thread_id, item) in active_items {
if is_final(&session.services.agent_control.get_status(*thread_id).await) {
let status = active_item_status(session.as_ref(), *thread_id, item).await;
if is_final(&status) {
finished.push((*thread_id, item.item_id.clone()));
}
}
finished
}

async fn active_item_status(
session: &Session,
thread_id: ThreadId,
item: &ActiveJobItem,
) -> AgentStatus {
if let Some(status_rx) = item.status_rx.as_ref()
&& status_rx.has_changed().is_ok()
{
return status_rx.borrow().clone();
}
session.services.agent_control.get_status(thread_id).await
}

async fn wait_for_status_change(active_items: &HashMap<ThreadId, ActiveJobItem>) {
let mut waiters = FuturesUnordered::new();
for item in active_items.values() {
if let Some(status_rx) = item.status_rx.as_ref() {
let mut status_rx = status_rx.clone();
waiters.push(async move {
let _ = status_rx.changed().await;
});
}
}
if waiters.is_empty() {
tokio::time::sleep(STATUS_POLL_INTERVAL).await;
return;
}
let _ = timeout(STATUS_POLL_INTERVAL, waiters.next()).await;
}

async fn reap_stale_active_items(
session: Arc<Session>,
db: Arc<codex_state::StateRuntime>,
Expand Down Expand Up @@ -890,37 +939,24 @@ async fn finalize_finished_item(
item_id: &str,
thread_id: ThreadId,
) -> anyhow::Result<()> {
let mut item = db
let item = db
.get_agent_job_item(job_id, item_id)
.await?
.ok_or_else(|| {
anyhow::anyhow!("job item not found for finalization: {job_id}/{item_id}")
})?;
if item.result_json.is_none() {
tokio::time::sleep(Duration::from_millis(250)).await;
item = db
.get_agent_job_item(job_id, item_id)
.await?
.ok_or_else(|| {
anyhow::anyhow!("job item not found after grace period: {job_id}/{item_id}")
})?;
}
if item.result_json.is_some() {
if !db.mark_agent_job_item_completed(job_id, item_id).await? {
db.mark_agent_job_item_failed(
job_id,
item_id,
"worker reported result but item could not transition to completed",
)
.await?;
if matches!(item.status, codex_state::AgentJobItemStatus::Running) {
if item.result_json.is_some() {
let _ = db.mark_agent_job_item_completed(job_id, item_id).await?;
} else {
let _ = db
.mark_agent_job_item_failed(
job_id,
item_id,
"worker finished without calling report_agent_job_result",
)
.await?;
}
} else {
db.mark_agent_job_item_failed(
job_id,
item_id,
"worker finished without calling report_agent_job_result",
)
.await?;
}
let _ = session
.services
Expand Down
124 changes: 123 additions & 1 deletion codex-rs/state/src/runtime/agent_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,20 +435,25 @@ WHERE job_id = ? AND item_id = ? AND status = ?
r#"
UPDATE agent_job_items
SET
status = ?,
result_json = ?,
reported_at = ?,
completed_at = ?,
updated_at = ?,
last_error = NULL
last_error = NULL,
assigned_thread_id = NULL
WHERE
job_id = ?
AND item_id = ?
AND status = ?
AND assigned_thread_id = ?
"#,
)
.bind(AgentJobItemStatus::Completed.as_str())
.bind(serialized)
.bind(now)
.bind(now)
.bind(now)
.bind(job_id)
.bind(item_id)
.bind(AgentJobItemStatus::Running.as_str())
Expand Down Expand Up @@ -560,3 +565,120 @@ WHERE job_id = ?
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::runtime::test_support::unique_temp_dir;
use pretty_assertions::assert_eq;
use serde_json::json;

async fn create_running_single_item_job(
runtime: &StateRuntime,
) -> anyhow::Result<(String, String, String)> {
let job_id = "job-1".to_string();
let item_id = "item-1".to_string();
let thread_id = "thread-1".to_string();
runtime
.create_agent_job(
&AgentJobCreateParams {
id: job_id.clone(),
name: "test-job".to_string(),
instruction: "Return a result".to_string(),
auto_export: true,
max_runtime_seconds: None,
output_schema_json: None,
input_headers: vec!["path".to_string()],
input_csv_path: "/tmp/in.csv".to_string(),
output_csv_path: "/tmp/out.csv".to_string(),
},
&[AgentJobItemCreateParams {
item_id: item_id.clone(),
row_index: 0,
source_id: None,
row_json: json!({"path":"file-1"}),
}],
)
.await?;
runtime.mark_agent_job_running(job_id.as_str()).await?;
let marked_running = runtime
.mark_agent_job_item_running_with_thread(
job_id.as_str(),
item_id.as_str(),
thread_id.as_str(),
)
.await?;
assert!(marked_running);
Ok((job_id, item_id, thread_id))
}

#[tokio::test]
async fn report_agent_job_item_result_completes_item_atomically() -> anyhow::Result<()> {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home, "test-provider".to_string()).await?;
let (job_id, item_id, thread_id) = create_running_single_item_job(runtime.as_ref()).await?;

let accepted = runtime
.report_agent_job_item_result(
job_id.as_str(),
item_id.as_str(),
thread_id.as_str(),
&json!({"ok": true}),
)
.await?;
assert!(accepted);

let item = runtime
.get_agent_job_item(job_id.as_str(), item_id.as_str())
.await?
.expect("job item should exist");
assert_eq!(item.status, AgentJobItemStatus::Completed);
assert_eq!(item.result_json, Some(json!({"ok": true})));
assert_eq!(item.assigned_thread_id, None);
assert_eq!(item.last_error, None);
assert!(item.reported_at.is_some());
assert!(item.completed_at.is_some());
let progress = runtime.get_agent_job_progress(job_id.as_str()).await?;
assert_eq!(
progress,
AgentJobProgress {
total_items: 1,
pending_items: 0,
running_items: 0,
completed_items: 1,
failed_items: 0,
}
);
Ok(())
}

#[tokio::test]
async fn report_agent_job_item_result_rejects_late_reports() -> anyhow::Result<()> {
let codex_home = unique_temp_dir();
let runtime = StateRuntime::init(codex_home, "test-provider".to_string()).await?;
let (job_id, item_id, thread_id) = create_running_single_item_job(runtime.as_ref()).await?;

let marked_failed = runtime
.mark_agent_job_item_failed(job_id.as_str(), item_id.as_str(), "missing report")
.await?;
assert!(marked_failed);
let accepted = runtime
.report_agent_job_item_result(
job_id.as_str(),
item_id.as_str(),
thread_id.as_str(),
&json!({"late": true}),
)
.await?;
assert!(!accepted);

let item = runtime
.get_agent_job_item(job_id.as_str(), item_id.as_str())
.await?
.expect("job item should exist");
assert_eq!(item.status, AgentJobItemStatus::Failed);
assert_eq!(item.result_json, None);
assert_eq!(item.last_error, Some("missing report".to_string()));
Ok(())
}
}
Loading