Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions codex-rs/core/src/codex_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use crate::tools::ToolRouter;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::handlers::ShellHandler;
use crate::tools::handlers::UnifiedExecHandler;
use crate::tools::registry::AnyToolResult;
use crate::tools::registry::ToolHandler;
use crate::tools::router::ToolCallSource;
use crate::turn_diff_tracker::TurnDiffTracker;
Expand Down Expand Up @@ -119,10 +119,14 @@ use std::time::Duration as StdDuration;
#[path = "codex_tests_guardian.rs"]
mod guardian_tests;

use codex_protocol::models::function_call_output_content_items_to_text;

fn expect_text_tool_output(output: &FunctionToolOutput) -> String {
function_call_output_content_items_to_text(&output.body).unwrap_or_default()
fn expect_text_tool_output(output: &AnyToolResult) -> String {
let ResponseInputItem::FunctionCallOutput { output, .. } = output
.result
.to_response_item(&output.call_id, &output.payload)
else {
panic!("expected function call output");
};
output.body.to_text().unwrap_or_default()
}

struct InstructionsTestCase {
Expand Down Expand Up @@ -5115,8 +5119,7 @@ async fn fatal_tool_error_stops_turn_and_reports_error() {
ToolCallSource::Direct,
)
.await
.err()
.expect("expected fatal error");
.expect_err("expected fatal error");

match err {
FunctionCallError::Fatal(message) => {
Expand Down
13 changes: 9 additions & 4 deletions codex-rs/core/src/codex_tests_guardian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::exec::ExecParams;
use crate::exec_policy::ExecPolicyManager;
use crate::guardian::GUARDIAN_REVIEWER_NAME;
use crate::sandboxing::SandboxPermissions;
use crate::tools::context::FunctionToolOutput;
use crate::tools::registry::AnyToolResult;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_app_server_protocol::ConfigLayerSource;
use codex_exec_server::EnvironmentManager;
Expand All @@ -20,7 +20,6 @@ use codex_protocol::models::ContentItem;
use codex_protocol::models::NetworkPermissions;
use codex_protocol::models::PermissionProfile;
use codex_protocol::models::ResponseItem;
use codex_protocol::models::function_call_output_content_items_to_text;
use codex_protocol::permissions::FileSystemSandboxPolicy;
use codex_protocol::permissions::NetworkSandboxPolicy;
use codex_protocol::protocol::AskForApproval;
Expand All @@ -40,8 +39,14 @@ use std::fs;
use std::sync::Arc;
use tempfile::tempdir;

fn expect_text_output(output: &FunctionToolOutput) -> String {
function_call_output_content_items_to_text(&output.body).unwrap_or_default()
fn expect_text_output(output: &AnyToolResult) -> String {
let ResponseInputItem::FunctionCallOutput { output, .. } = output
.result
.to_response_item(&output.call_id, &output.payload)
else {
panic!("expected function call output");
};
output.body.to_text().unwrap_or_default()
}

#[tokio::test]
Expand Down
53 changes: 31 additions & 22 deletions codex-rs/core/src/tools/code_mode/execute_handler.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use async_trait::async_trait;

use crate::function_tool::FunctionCallError;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::registry::AnyToolResult;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use futures::future::BoxFuture;

use super::ExecContext;
use super::PUBLIC_TOOL_NAME;
Expand Down Expand Up @@ -53,10 +53,7 @@ impl CodeModeExecuteHandler {
}
}

#[async_trait]
impl ToolHandler for CodeModeExecuteHandler {
type Output = FunctionToolOutput;

fn kind(&self) -> ToolKind {
ToolKind::Function
}
Expand All @@ -65,23 +62,35 @@ impl ToolHandler for CodeModeExecuteHandler {
matches!(payload, ToolPayload::Custom { .. })
}

async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
let ToolInvocation {
session,
turn,
call_id,
tool_name,
payload,
..
} = invocation;
fn handle(
&self,
invocation: ToolInvocation,
) -> BoxFuture<'_, Result<AnyToolResult, FunctionCallError>> {
Box::pin(async move {
let ToolInvocation {
session,
turn,
call_id,
tool_name,
payload,
..
} = invocation;
let payload_for_result = payload.clone();

match payload {
ToolPayload::Custom { input } if tool_name == PUBLIC_TOOL_NAME => {
self.execute(session, turn, call_id, input).await
}
_ => Err(FunctionCallError::RespondToModel(format!(
"{PUBLIC_TOOL_NAME} expects raw JavaScript source text"
))),
}
let result = match payload {
ToolPayload::Custom { input } if tool_name == PUBLIC_TOOL_NAME => {
self.execute(session, turn, call_id.clone(), input).await
}
_ => Err(FunctionCallError::RespondToModel(format!(
"{PUBLIC_TOOL_NAME} expects raw JavaScript source text"
))),
}?;

Ok(AnyToolResult {
call_id,
payload: payload_for_result,
result: Box::new(result),
})
})
}
}
84 changes: 47 additions & 37 deletions codex-rs/core/src/tools/code_mode/wait_handler.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use async_trait::async_trait;
use futures::future::BoxFuture;
use serde::Deserialize;

use crate::function_tool::FunctionCallError;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::registry::AnyToolResult;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

Expand Down Expand Up @@ -39,46 +39,56 @@ where
})
}

#[async_trait]
impl ToolHandler for CodeModeWaitHandler {
type Output = FunctionToolOutput;

fn kind(&self) -> ToolKind {
ToolKind::Function
}

async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
let ToolInvocation {
session,
turn,
tool_name,
payload,
..
} = invocation;
fn handle(
&self,
invocation: ToolInvocation,
) -> BoxFuture<'_, Result<AnyToolResult, FunctionCallError>> {
Box::pin(async move {
let ToolInvocation {
session,
turn,
call_id,
tool_name,
payload,
..
} = invocation;
let payload_for_result = payload.clone();

let result = match payload {
ToolPayload::Function { arguments } if tool_name == WAIT_TOOL_NAME => {
let args: ExecWaitArgs = parse_arguments(&arguments)?;
let exec = ExecContext { session, turn };
let started_at = std::time::Instant::now();
let response = exec
.session
.services
.code_mode_service
.wait(codex_code_mode::WaitRequest {
cell_id: args.cell_id,
yield_time_ms: args.yield_time_ms,
terminate: args.terminate,
})
.await
.map_err(FunctionCallError::RespondToModel)?;
handle_runtime_response(&exec, response, args.max_tokens, started_at)
.await
.map_err(FunctionCallError::RespondToModel)
}
_ => Err(FunctionCallError::RespondToModel(format!(
"{WAIT_TOOL_NAME} expects JSON arguments"
))),
}?;

match payload {
ToolPayload::Function { arguments } if tool_name == WAIT_TOOL_NAME => {
let args: ExecWaitArgs = parse_arguments(&arguments)?;
let exec = ExecContext { session, turn };
let started_at = std::time::Instant::now();
let response = exec
.session
.services
.code_mode_service
.wait(codex_code_mode::WaitRequest {
cell_id: args.cell_id,
yield_time_ms: args.yield_time_ms,
terminate: args.terminate,
})
.await
.map_err(FunctionCallError::RespondToModel)?;
handle_runtime_response(&exec, response, args.max_tokens, started_at)
.await
.map_err(FunctionCallError::RespondToModel)
}
_ => Err(FunctionCallError::RespondToModel(format!(
"{WAIT_TOOL_NAME} expects JSON arguments"
))),
}
Ok(AnyToolResult {
call_id,
payload: payload_for_result,
result: Box::new(result),
})
})
}
}
69 changes: 42 additions & 27 deletions codex-rs/core/src/tools/handlers/agent_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::handlers::multi_agents::build_agent_spawn_config;
use crate::tools::handlers::parse_arguments;
use crate::tools::registry::AnyToolResult;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use async_trait::async_trait;
use codex_protocol::ThreadId;
use codex_protocol::protocol::AgentStatus;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::user_input::UserInput;
use futures::StreamExt;
use futures::future::BoxFuture;
use futures::stream::FuturesUnordered;
use serde::Deserialize;
use serde::Serialize;
Expand Down Expand Up @@ -178,10 +179,7 @@ impl JobProgressEmitter {
}
}

#[async_trait]
impl ToolHandler for BatchJobHandler {
type Output = FunctionToolOutput;

fn kind(&self) -> ToolKind {
ToolKind::Function
}
Expand All @@ -190,31 +188,48 @@ impl ToolHandler for BatchJobHandler {
matches!(payload, ToolPayload::Function { .. })
}

async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
let ToolInvocation {
session,
turn,
tool_name,
payload,
..
} = invocation;
fn handle(
&self,
invocation: ToolInvocation,
) -> BoxFuture<'_, Result<AnyToolResult, FunctionCallError>> {
Box::pin(async move {
let ToolInvocation {
session,
turn,
call_id,
tool_name,
payload,
..
} = invocation;
let payload_for_result = payload.clone();

let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"agent jobs handler received unsupported payload".to_string(),
));
}
};
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"agent jobs handler received unsupported payload".to_string(),
));
}
};

match tool_name.as_str() {
"spawn_agents_on_csv" => spawn_agents_on_csv::handle(session, turn, arguments).await,
"report_agent_job_result" => report_agent_job_result::handle(session, arguments).await,
other => Err(FunctionCallError::RespondToModel(format!(
"unsupported agent job tool {other}"
))),
}
let result = match tool_name.as_str() {
"spawn_agents_on_csv" => {
spawn_agents_on_csv::handle(session, turn, arguments).await
}
"report_agent_job_result" => {
report_agent_job_result::handle(session, arguments).await
}
other => Err(FunctionCallError::RespondToModel(format!(
"unsupported agent job tool {other}"
))),
}?;

Ok(AnyToolResult {
call_id,
payload: payload_for_result,
result: Box::new(result),
})
})
}
}

Expand Down
Loading
Loading