From 2810b26ffd983179d2ee44e1136b3a3a94f01cf3 Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Thu, 2 Apr 2026 15:16:56 -0700 Subject: [PATCH] core: type-erase ToolHandler outputs --- codex-rs/core/src/codex_tests.rs | 17 +- codex-rs/core/src/codex_tests_guardian.rs | 13 +- .../src/tools/code_mode/execute_handler.rs | 53 ++- .../core/src/tools/code_mode/wait_handler.rs | 84 ++-- .../core/src/tools/handlers/agent_jobs.rs | 69 +-- .../core/src/tools/handlers/apply_patch.rs | 226 +++++----- codex-rs/core/src/tools/handlers/dynamic.rs | 88 ++-- codex-rs/core/src/tools/handlers/js_repl.rs | 177 ++++---- codex-rs/core/src/tools/handlers/list_dir.rs | 104 +++-- codex-rs/core/src/tools/handlers/mcp.rs | 86 ++-- .../core/src/tools/handlers/mcp_resource.rs | 124 +++--- .../core/src/tools/handlers/multi_agents.rs | 3 +- .../handlers/multi_agents/close_agent.rs | 165 +++---- .../handlers/multi_agents/resume_agent.rs | 197 +++++---- .../tools/handlers/multi_agents/send_input.rs | 137 +++--- .../src/tools/handlers/multi_agents/spawn.rs | 287 ++++++------ .../src/tools/handlers/multi_agents/wait.rs | 303 ++++++------- .../src/tools/handlers/multi_agents_v2.rs | 3 +- .../handlers/multi_agents_v2/close_agent.rs | 181 ++++---- .../handlers/multi_agents_v2/followup_task.rs | 39 +- .../handlers/multi_agents_v2/list_agents.rs | 54 ++- .../handlers/multi_agents_v2/send_message.rs | 39 +- .../tools/handlers/multi_agents_v2/spawn.rs | 339 +++++++------- .../tools/handlers/multi_agents_v2/wait.rs | 117 ++--- codex-rs/core/src/tools/handlers/plan.rs | 58 +-- .../src/tools/handlers/request_permissions.rs | 90 ++-- .../src/tools/handlers/request_user_input.rs | 90 ++-- codex-rs/core/src/tools/handlers/shell.rs | 239 +++++----- codex-rs/core/src/tools/handlers/test_sync.rs | 70 +-- .../core/src/tools/handlers/tool_search.rs | 138 +++--- .../core/src/tools/handlers/tool_suggest.rs | 226 +++++----- .../core/src/tools/handlers/unified_exec.rs | 413 +++++++++--------- .../core/src/tools/handlers/view_image.rs | 229 +++++----- codex-rs/core/src/tools/registry.rs | 129 +++--- codex-rs/core/src/tools/registry_tests.rs | 16 +- codex-rs/core/src/tools/router_tests.rs | 6 +- 36 files changed, 2437 insertions(+), 2172 deletions(-) diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs index 4ff31e8bd00f..2f76e31779e5 100644 --- a/codex-rs/core/src/codex_tests.rs +++ b/codex-rs/core/src/codex_tests.rs @@ -41,11 +41,11 @@ use crate::state::TaskKind; use crate::tasks::SessionTask; use crate::tasks::SessionTaskContext; use crate::tools::ToolRouter; -use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; use crate::tools::handlers::ShellHandler; use crate::tools::handlers::UnifiedExecHandler; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::router::ToolCallSource; use crate::turn_diff_tracker::TurnDiffTracker; @@ -119,10 +119,14 @@ use std::time::Duration as StdDuration; #[path = "codex_tests_guardian.rs"] mod guardian_tests; -use codex_protocol::models::function_call_output_content_items_to_text; - -fn expect_text_tool_output(output: &FunctionToolOutput) -> String { - function_call_output_content_items_to_text(&output.body).unwrap_or_default() +fn expect_text_tool_output(output: &AnyToolResult) -> String { + let ResponseInputItem::FunctionCallOutput { output, .. } = output + .result + .to_response_item(&output.call_id, &output.payload) + else { + panic!("expected function call output"); + }; + output.body.to_text().unwrap_or_default() } struct InstructionsTestCase { @@ -5115,8 +5119,7 @@ async fn fatal_tool_error_stops_turn_and_reports_error() { ToolCallSource::Direct, ) .await - .err() - .expect("expected fatal error"); + .expect_err("expected fatal error"); match err { FunctionCallError::Fatal(message) => { diff --git a/codex-rs/core/src/codex_tests_guardian.rs b/codex-rs/core/src/codex_tests_guardian.rs index 4f60c2f28eb0..343e2c10b182 100644 --- a/codex-rs/core/src/codex_tests_guardian.rs +++ b/codex-rs/core/src/codex_tests_guardian.rs @@ -8,7 +8,7 @@ use crate::exec::ExecParams; use crate::exec_policy::ExecPolicyManager; use crate::guardian::GUARDIAN_REVIEWER_NAME; use crate::sandboxing::SandboxPermissions; -use crate::tools::context::FunctionToolOutput; +use crate::tools::registry::AnyToolResult; use crate::turn_diff_tracker::TurnDiffTracker; use codex_app_server_protocol::ConfigLayerSource; use codex_exec_server::EnvironmentManager; @@ -20,7 +20,6 @@ use codex_protocol::models::ContentItem; use codex_protocol::models::NetworkPermissions; use codex_protocol::models::PermissionProfile; use codex_protocol::models::ResponseItem; -use codex_protocol::models::function_call_output_content_items_to_text; use codex_protocol::permissions::FileSystemSandboxPolicy; use codex_protocol::permissions::NetworkSandboxPolicy; use codex_protocol::protocol::AskForApproval; @@ -40,8 +39,14 @@ use std::fs; use std::sync::Arc; use tempfile::tempdir; -fn expect_text_output(output: &FunctionToolOutput) -> String { - function_call_output_content_items_to_text(&output.body).unwrap_or_default() +fn expect_text_output(output: &AnyToolResult) -> String { + let ResponseInputItem::FunctionCallOutput { output, .. } = output + .result + .to_response_item(&output.call_id, &output.payload) + else { + panic!("expected function call output"); + }; + output.body.to_text().unwrap_or_default() } #[tokio::test] diff --git a/codex-rs/core/src/tools/code_mode/execute_handler.rs b/codex-rs/core/src/tools/code_mode/execute_handler.rs index 3f77216c1674..fa0fcca9f792 100644 --- a/codex-rs/core/src/tools/code_mode/execute_handler.rs +++ b/codex-rs/core/src/tools/code_mode/execute_handler.rs @@ -1,11 +1,11 @@ -use async_trait::async_trait; - use crate::function_tool::FunctionCallError; use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; +use futures::future::BoxFuture; use super::ExecContext; use super::PUBLIC_TOOL_NAME; @@ -53,10 +53,7 @@ impl CodeModeExecuteHandler { } } -#[async_trait] impl ToolHandler for CodeModeExecuteHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -65,23 +62,35 @@ impl ToolHandler for CodeModeExecuteHandler { matches!(payload, ToolPayload::Custom { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - call_id, - tool_name, - payload, - .. - } = invocation; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + call_id, + tool_name, + payload, + .. + } = invocation; + let payload_for_result = payload.clone(); - match payload { - ToolPayload::Custom { input } if tool_name == PUBLIC_TOOL_NAME => { - self.execute(session, turn, call_id, input).await - } - _ => Err(FunctionCallError::RespondToModel(format!( - "{PUBLIC_TOOL_NAME} expects raw JavaScript source text" - ))), - } + let result = match payload { + ToolPayload::Custom { input } if tool_name == PUBLIC_TOOL_NAME => { + self.execute(session, turn, call_id.clone(), input).await + } + _ => Err(FunctionCallError::RespondToModel(format!( + "{PUBLIC_TOOL_NAME} expects raw JavaScript source text" + ))), + }?; + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(result), + }) + }) } } diff --git a/codex-rs/core/src/tools/code_mode/wait_handler.rs b/codex-rs/core/src/tools/code_mode/wait_handler.rs index f319985a885e..7b39e1b4e724 100644 --- a/codex-rs/core/src/tools/code_mode/wait_handler.rs +++ b/codex-rs/core/src/tools/code_mode/wait_handler.rs @@ -1,10 +1,10 @@ -use async_trait::async_trait; +use futures::future::BoxFuture; use serde::Deserialize; use crate::function_tool::FunctionCallError; -use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; @@ -39,46 +39,56 @@ where }) } -#[async_trait] impl ToolHandler for CodeModeWaitHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tool_name, - payload, - .. - } = invocation; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + call_id, + tool_name, + payload, + .. + } = invocation; + let payload_for_result = payload.clone(); + + let result = match payload { + ToolPayload::Function { arguments } if tool_name == WAIT_TOOL_NAME => { + let args: ExecWaitArgs = parse_arguments(&arguments)?; + let exec = ExecContext { session, turn }; + let started_at = std::time::Instant::now(); + let response = exec + .session + .services + .code_mode_service + .wait(codex_code_mode::WaitRequest { + cell_id: args.cell_id, + yield_time_ms: args.yield_time_ms, + terminate: args.terminate, + }) + .await + .map_err(FunctionCallError::RespondToModel)?; + handle_runtime_response(&exec, response, args.max_tokens, started_at) + .await + .map_err(FunctionCallError::RespondToModel) + } + _ => Err(FunctionCallError::RespondToModel(format!( + "{WAIT_TOOL_NAME} expects JSON arguments" + ))), + }?; - match payload { - ToolPayload::Function { arguments } if tool_name == WAIT_TOOL_NAME => { - let args: ExecWaitArgs = parse_arguments(&arguments)?; - let exec = ExecContext { session, turn }; - let started_at = std::time::Instant::now(); - let response = exec - .session - .services - .code_mode_service - .wait(codex_code_mode::WaitRequest { - cell_id: args.cell_id, - yield_time_ms: args.yield_time_ms, - terminate: args.terminate, - }) - .await - .map_err(FunctionCallError::RespondToModel)?; - handle_runtime_response(&exec, response, args.max_tokens, started_at) - .await - .map_err(FunctionCallError::RespondToModel) - } - _ => Err(FunctionCallError::RespondToModel(format!( - "{WAIT_TOOL_NAME} expects JSON arguments" - ))), - } + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(result), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/agent_jobs.rs b/codex-rs/core/src/tools/handlers/agent_jobs.rs index c607730b8fd2..949359b45d0c 100644 --- a/codex-rs/core/src/tools/handlers/agent_jobs.rs +++ b/codex-rs/core/src/tools/handlers/agent_jobs.rs @@ -11,15 +11,16 @@ use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; use crate::tools::handlers::multi_agents::build_agent_spawn_config; use crate::tools::handlers::parse_arguments; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; -use async_trait::async_trait; use codex_protocol::ThreadId; use codex_protocol::protocol::AgentStatus; use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::SubAgentSource; use codex_protocol::user_input::UserInput; use futures::StreamExt; +use futures::future::BoxFuture; use futures::stream::FuturesUnordered; use serde::Deserialize; use serde::Serialize; @@ -178,10 +179,7 @@ impl JobProgressEmitter { } } -#[async_trait] impl ToolHandler for BatchJobHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -190,31 +188,48 @@ impl ToolHandler for BatchJobHandler { matches!(payload, ToolPayload::Function { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tool_name, - payload, - .. - } = invocation; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + call_id, + tool_name, + payload, + .. + } = invocation; + let payload_for_result = payload.clone(); - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "agent jobs handler received unsupported payload".to_string(), - )); - } - }; + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "agent jobs handler received unsupported payload".to_string(), + )); + } + }; - match tool_name.as_str() { - "spawn_agents_on_csv" => spawn_agents_on_csv::handle(session, turn, arguments).await, - "report_agent_job_result" => report_agent_job_result::handle(session, arguments).await, - other => Err(FunctionCallError::RespondToModel(format!( - "unsupported agent job tool {other}" - ))), - } + let result = match tool_name.as_str() { + "spawn_agents_on_csv" => { + spawn_agents_on_csv::handle(session, turn, arguments).await + } + "report_agent_job_result" => { + report_agent_job_result::handle(session, arguments).await + } + other => Err(FunctionCallError::RespondToModel(format!( + "unsupported agent job tool {other}" + ))), + }?; + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(result), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/apply_patch.rs b/codex-rs/core/src/tools/handlers/apply_patch.rs index 801cff56f381..a1eb433ec6cd 100644 --- a/codex-rs/core/src/tools/handlers/apply_patch.rs +++ b/codex-rs/core/src/tools/handlers/apply_patch.rs @@ -16,12 +16,12 @@ use crate::tools::events::ToolEventCtx; use crate::tools::handlers::apply_granted_turn_permissions; use crate::tools::handlers::parse_arguments; use crate::tools::orchestrator::ToolOrchestrator; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; use crate::tools::runtimes::apply_patch::ApplyPatchRequest; use crate::tools::runtimes::apply_patch::ApplyPatchRuntime; use crate::tools::sandboxing::ToolCtx; -use async_trait::async_trait; use codex_apply_patch::ApplyPatchAction; use codex_apply_patch::ApplyPatchFileChange; use codex_protocol::models::FileSystemPermissions; @@ -31,6 +31,7 @@ use codex_sandboxing::policy_transforms::merge_permission_profiles; use codex_sandboxing::policy_transforms::normalize_additional_permissions; use codex_tools::ApplyPatchToolArgs; use codex_utils_absolute_path::AbsolutePathBuf; +use futures::future::BoxFuture; use std::collections::BTreeSet; use std::sync::Arc; @@ -122,10 +123,7 @@ async fn effective_patch_permissions( ) } -#[async_trait] impl ToolHandler for ApplyPatchHandler { - type Output = ApplyPatchToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -137,119 +135,137 @@ impl ToolHandler for ApplyPatchHandler { ) } - async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { + fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { true } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tracker, - call_id, - tool_name, - payload, - .. - } = invocation; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + tracker, + call_id, + tool_name, + payload, + .. + } = invocation; + let payload_for_result = payload.clone(); - let patch_input = match payload { - ToolPayload::Function { arguments } => { - let args: ApplyPatchToolArgs = parse_arguments(&arguments)?; - args.input - } - ToolPayload::Custom { input } => input, - _ => { - return Err(FunctionCallError::RespondToModel( - "apply_patch handler received unsupported payload".to_string(), - )); - } - }; + let patch_input = match payload { + ToolPayload::Function { arguments } => { + let args: ApplyPatchToolArgs = parse_arguments(&arguments)?; + args.input + } + ToolPayload::Custom { input } => input, + _ => { + return Err(FunctionCallError::RespondToModel( + "apply_patch handler received unsupported payload".to_string(), + )); + } + }; - // Re-parse and verify the patch so we can compute changes and approval. - // Avoid building temporary ExecParams/command vectors; derive directly from inputs. - let cwd = turn.cwd.clone(); - let command = vec!["apply_patch".to_string(), patch_input.clone()]; - match codex_apply_patch::maybe_parse_apply_patch_verified(&command, &cwd) { - codex_apply_patch::MaybeApplyPatchVerified::Body(changes) => { - let (file_paths, effective_additional_permissions, file_system_sandbox_policy) = - effective_patch_permissions(session.as_ref(), turn.as_ref(), &changes).await; - match apply_patch::apply_patch(turn.as_ref(), &file_system_sandbox_policy, changes) + // Re-parse and verify the patch so we can compute changes and approval. + // Avoid building temporary ExecParams/command vectors; derive directly from inputs. + let cwd = turn.cwd.clone(); + let command = vec!["apply_patch".to_string(), patch_input.clone()]; + match codex_apply_patch::maybe_parse_apply_patch_verified(&command, &cwd) { + codex_apply_patch::MaybeApplyPatchVerified::Body(changes) => { + let (file_paths, effective_additional_permissions, file_system_sandbox_policy) = + effective_patch_permissions(session.as_ref(), turn.as_ref(), &changes).await; + match apply_patch::apply_patch( + turn.as_ref(), + &file_system_sandbox_policy, + changes, + ) .await - { - InternalApplyPatchInvocation::Output(item) => { - let content = item?; - Ok(ApplyPatchToolOutput::from_text(content)) - } - InternalApplyPatchInvocation::DelegateToExec(apply) => { - let changes = convert_apply_patch_to_protocol(&apply.action); - let emitter = - ToolEmitter::apply_patch(changes.clone(), apply.auto_approved); - let event_ctx = ToolEventCtx::new( - session.as_ref(), - turn.as_ref(), - &call_id, - Some(&tracker), - ); - emitter.begin(event_ctx).await; + { + InternalApplyPatchInvocation::Output(item) => { + let content = item?; + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(ApplyPatchToolOutput::from_text(content)), + }) + } + InternalApplyPatchInvocation::DelegateToExec(apply) => { + let changes = convert_apply_patch_to_protocol(&apply.action); + let emitter = + ToolEmitter::apply_patch(changes.clone(), apply.auto_approved); + let event_ctx = ToolEventCtx::new( + session.as_ref(), + turn.as_ref(), + &call_id, + Some(&tracker), + ); + emitter.begin(event_ctx).await; - let req = ApplyPatchRequest { - action: apply.action, - file_paths, - changes, - exec_approval_requirement: apply.exec_approval_requirement, - additional_permissions: effective_additional_permissions - .additional_permissions, - permissions_preapproved: effective_additional_permissions - .permissions_preapproved, - timeout_ms: None, - }; + let req = ApplyPatchRequest { + action: apply.action, + file_paths, + changes, + exec_approval_requirement: apply.exec_approval_requirement, + additional_permissions: effective_additional_permissions + .additional_permissions, + permissions_preapproved: effective_additional_permissions + .permissions_preapproved, + timeout_ms: None, + }; - let mut orchestrator = ToolOrchestrator::new(); - let mut runtime = ApplyPatchRuntime::new(); - let tool_ctx = ToolCtx { - session: session.clone(), - turn: turn.clone(), - call_id: call_id.clone(), - tool_name: tool_name.to_string(), - }; - let out = orchestrator - .run( - &mut runtime, - &req, - &tool_ctx, + let mut orchestrator = ToolOrchestrator::new(); + let mut runtime = ApplyPatchRuntime::new(); + let tool_ctx = ToolCtx { + session: session.clone(), + turn: turn.clone(), + call_id: call_id.clone(), + tool_name: tool_name.to_string(), + }; + let out = orchestrator + .run( + &mut runtime, + &req, + &tool_ctx, + turn.as_ref(), + turn.approval_policy.value(), + ) + .await + .map(|result| result.output); + let event_ctx = ToolEventCtx::new( + session.as_ref(), turn.as_ref(), - turn.approval_policy.value(), - ) - .await - .map(|result| result.output); - let event_ctx = ToolEventCtx::new( - session.as_ref(), - turn.as_ref(), - &call_id, - Some(&tracker), - ); - let content = emitter.finish(event_ctx, out).await?; - Ok(ApplyPatchToolOutput::from_text(content)) + &call_id, + Some(&tracker), + ); + let content = emitter.finish(event_ctx, out).await?; + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(ApplyPatchToolOutput::from_text(content)), + }) + } } } + codex_apply_patch::MaybeApplyPatchVerified::CorrectnessError(parse_error) => { + Err(FunctionCallError::RespondToModel(format!( + "apply_patch verification failed: {parse_error}" + ))) + } + codex_apply_patch::MaybeApplyPatchVerified::ShellParseError(error) => { + tracing::trace!("Failed to parse apply_patch input, {error:?}"); + Err(FunctionCallError::RespondToModel( + "apply_patch handler received invalid patch input".to_string(), + )) + } + codex_apply_patch::MaybeApplyPatchVerified::NotApplyPatch => { + Err(FunctionCallError::RespondToModel( + "apply_patch handler received non-apply_patch input".to_string(), + )) + } } - codex_apply_patch::MaybeApplyPatchVerified::CorrectnessError(parse_error) => { - Err(FunctionCallError::RespondToModel(format!( - "apply_patch verification failed: {parse_error}" - ))) - } - codex_apply_patch::MaybeApplyPatchVerified::ShellParseError(error) => { - tracing::trace!("Failed to parse apply_patch input, {error:?}"); - Err(FunctionCallError::RespondToModel( - "apply_patch handler received invalid patch input".to_string(), - )) - } - codex_apply_patch::MaybeApplyPatchVerified::NotApplyPatch => { - Err(FunctionCallError::RespondToModel( - "apply_patch handler received non-apply_patch input".to_string(), - )) - } - } + }) } } diff --git a/codex-rs/core/src/tools/handlers/dynamic.rs b/codex-rs/core/src/tools/handlers/dynamic.rs index 23285bc9079d..6c01aa0d6702 100644 --- a/codex-rs/core/src/tools/handlers/dynamic.rs +++ b/codex-rs/core/src/tools/handlers/dynamic.rs @@ -5,14 +5,15 @@ use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; use crate::tools::handlers::parse_arguments; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; -use async_trait::async_trait; use codex_protocol::dynamic_tools::DynamicToolCallRequest; use codex_protocol::dynamic_tools::DynamicToolResponse; use codex_protocol::models::FunctionCallOutputContentItem; use codex_protocol::protocol::DynamicToolCallResponseEvent; use codex_protocol::protocol::EventMsg; +use futures::future::BoxFuture; use serde_json::Value; use std::time::Instant; use tokio::sync::oneshot; @@ -20,55 +21,64 @@ use tracing::warn; pub struct DynamicToolHandler; -#[async_trait] impl ToolHandler for DynamicToolHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { + fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { true } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - call_id, - tool_name, - payload, - .. - } = invocation; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + call_id, + tool_name, + payload, + .. + } = invocation; - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "dynamic tool handler received unsupported payload".to_string(), - )); - } - }; + let payload_for_result = payload.clone(); + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "dynamic tool handler received unsupported payload".to_string(), + )); + } + }; - let args: Value = parse_arguments(&arguments)?; - let response = request_dynamic_tool(&session, turn.as_ref(), call_id, tool_name, args) - .await - .ok_or_else(|| { - FunctionCallError::RespondToModel( - "dynamic tool call was cancelled before receiving a response".to_string(), - ) - })?; + let args: Value = parse_arguments(&arguments)?; + let response = + request_dynamic_tool(&session, turn.as_ref(), call_id.clone(), tool_name, args) + .await + .ok_or_else(|| { + FunctionCallError::RespondToModel( + "dynamic tool call was cancelled before receiving a response" + .to_string(), + ) + })?; - let DynamicToolResponse { - content_items, - success, - } = response; - let body = content_items - .into_iter() - .map(FunctionCallOutputContentItem::from) - .collect::>(); - Ok(FunctionToolOutput::from_content(body, Some(success))) + let DynamicToolResponse { + content_items, + success, + } = response; + let body = content_items + .into_iter() + .map(FunctionCallOutputContentItem::from) + .collect::>(); + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(FunctionToolOutput::from_content(body, Some(success))), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/js_repl.rs b/codex-rs/core/src/tools/handlers/js_repl.rs index c5906abafcfb..7ffd9e13dd8d 100644 --- a/codex-rs/core/src/tools/handlers/js_repl.rs +++ b/codex-rs/core/src/tools/handlers/js_repl.rs @@ -1,4 +1,4 @@ -use async_trait::async_trait; +use futures::future::BoxFuture; use serde_json::Value as JsonValue; use std::sync::Arc; use std::time::Duration; @@ -17,6 +17,7 @@ use crate::tools::events::ToolEventStage; use crate::tools::handlers::parse_arguments; use crate::tools::js_repl::JS_REPL_PRAGMA_PREFIX; use crate::tools::js_repl::JsReplArgs; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; use codex_features::Feature; @@ -92,10 +93,7 @@ async fn emit_js_repl_exec_end( }; emitter.emit(ctx, stage).await; } -#[async_trait] impl ToolHandler for JsReplHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -107,101 +105,114 @@ impl ToolHandler for JsReplHandler { ) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tracker, - payload, - call_id, - .. - } = invocation; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + tracker, + payload, + call_id, + .. + } = invocation; + let payload_for_result = payload.clone(); - if !session.features().enabled(Feature::JsRepl) { - return Err(FunctionCallError::RespondToModel( - "js_repl is disabled by feature flag".to_string(), - )); - } - - let args = match payload { - ToolPayload::Function { arguments } => parse_arguments(&arguments)?, - ToolPayload::Custom { input } => parse_freeform_args(&input)?, - _ => { + if !session.features().enabled(Feature::JsRepl) { return Err(FunctionCallError::RespondToModel( - "js_repl expects custom or function payload".to_string(), + "js_repl is disabled by feature flag".to_string(), )); } - }; - let manager = turn.js_repl.manager().await?; - let started_at = Instant::now(); - emit_js_repl_exec_begin(session.as_ref(), turn.as_ref(), &call_id).await; - let result = manager - .execute(Arc::clone(&session), Arc::clone(&turn), tracker, args) - .await; - let result = match result { - Ok(result) => result, - Err(err) => { - let message = err.to_string(); - emit_js_repl_exec_end( - session.as_ref(), - turn.as_ref(), - &call_id, - "", - Some(&message), - started_at.elapsed(), - ) + + let args = match payload { + ToolPayload::Function { arguments } => parse_arguments(&arguments)?, + ToolPayload::Custom { input } => parse_freeform_args(&input)?, + _ => { + return Err(FunctionCallError::RespondToModel( + "js_repl expects custom or function payload".to_string(), + )); + } + }; + let manager = turn.js_repl.manager().await?; + let started_at = Instant::now(); + emit_js_repl_exec_begin(session.as_ref(), turn.as_ref(), &call_id).await; + let result = manager + .execute(Arc::clone(&session), Arc::clone(&turn), tracker, args) .await; - return Err(err); - } - }; + let result = match result { + Ok(result) => result, + Err(err) => { + let message = err.to_string(); + emit_js_repl_exec_end( + session.as_ref(), + turn.as_ref(), + &call_id, + "", + Some(&message), + started_at.elapsed(), + ) + .await; + return Err(err); + } + }; - let content = result.output; - let mut items = Vec::with_capacity(result.content_items.len() + 1); - if !content.is_empty() { - items.push(FunctionCallOutputContentItem::InputText { - text: content.clone(), - }); - } - items.extend(result.content_items); + let content = result.output; + let mut items = Vec::with_capacity(result.content_items.len() + 1); + if !content.is_empty() { + items.push(FunctionCallOutputContentItem::InputText { + text: content.clone(), + }); + } + items.extend(result.content_items); - emit_js_repl_exec_end( - session.as_ref(), - turn.as_ref(), - &call_id, - &content, - /*error*/ None, - started_at.elapsed(), - ) - .await; + emit_js_repl_exec_end( + session.as_ref(), + turn.as_ref(), + &call_id, + &content, + /*error*/ None, + started_at.elapsed(), + ) + .await; - if items.is_empty() { - Ok(FunctionToolOutput::from_text(content, Some(true))) - } else { - Ok(FunctionToolOutput::from_content(items, Some(true))) - } + let output = if items.is_empty() { + FunctionToolOutput::from_text(content, Some(true)) + } else { + FunctionToolOutput::from_content(items, Some(true)) + }; + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(output), + }) + }) } } -#[async_trait] impl ToolHandler for JsReplResetHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn handle(&self, invocation: ToolInvocation) -> Result { - if !invocation.session.features().enabled(Feature::JsRepl) { - return Err(FunctionCallError::RespondToModel( - "js_repl is disabled by feature flag".to_string(), - )); - } - let manager = invocation.turn.js_repl.manager().await?; - manager.reset().await?; - Ok(FunctionToolOutput::from_text( - "js_repl kernel reset".to_string(), - Some(true), - )) + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + if !invocation.session.features().enabled(Feature::JsRepl) { + return Err(FunctionCallError::RespondToModel( + "js_repl is disabled by feature flag".to_string(), + )); + } + let manager = invocation.turn.js_repl.manager().await?; + manager.reset().await?; + Ok(AnyToolResult::new( + &invocation, + FunctionToolOutput::from_text("js_repl kernel reset".to_string(), Some(true)), + )) + }) } } diff --git a/codex-rs/core/src/tools/handlers/list_dir.rs b/codex-rs/core/src/tools/handlers/list_dir.rs index fd461e82e5d4..a0130e968016 100644 --- a/codex-rs/core/src/tools/handlers/list_dir.rs +++ b/codex-rs/core/src/tools/handlers/list_dir.rs @@ -4,8 +4,8 @@ use std::fs::FileType; use std::path::Path; use std::path::PathBuf; -use async_trait::async_trait; use codex_utils_string::take_bytes_at_char_boundary; +use futures::future::BoxFuture; use serde::Deserialize; use tokio::fs; @@ -14,6 +14,7 @@ use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; use crate::tools::handlers::parse_arguments; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; @@ -45,65 +46,74 @@ struct ListDirArgs { depth: usize, } -#[async_trait] impl ToolHandler for ListDirHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { payload, .. } = invocation; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + call_id, payload, .. + } = invocation; + + let payload_for_result = payload.clone(); + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "list_dir handler received unsupported payload".to_string(), + )); + } + }; + + let args: ListDirArgs = parse_arguments(&arguments)?; - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { + let ListDirArgs { + dir_path, + offset, + limit, + depth, + } = args; + + if offset == 0 { return Err(FunctionCallError::RespondToModel( - "list_dir handler received unsupported payload".to_string(), + "offset must be a 1-indexed entry number".to_string(), )); } - }; - - let args: ListDirArgs = parse_arguments(&arguments)?; - - let ListDirArgs { - dir_path, - offset, - limit, - depth, - } = args; - if offset == 0 { - return Err(FunctionCallError::RespondToModel( - "offset must be a 1-indexed entry number".to_string(), - )); - } - - if limit == 0 { - return Err(FunctionCallError::RespondToModel( - "limit must be greater than zero".to_string(), - )); - } + if limit == 0 { + return Err(FunctionCallError::RespondToModel( + "limit must be greater than zero".to_string(), + )); + } - if depth == 0 { - return Err(FunctionCallError::RespondToModel( - "depth must be greater than zero".to_string(), - )); - } + if depth == 0 { + return Err(FunctionCallError::RespondToModel( + "depth must be greater than zero".to_string(), + )); + } - let path = PathBuf::from(&dir_path); - if !path.is_absolute() { - return Err(FunctionCallError::RespondToModel( - "dir_path must be an absolute path".to_string(), - )); - } + let path = PathBuf::from(&dir_path); + if !path.is_absolute() { + return Err(FunctionCallError::RespondToModel( + "dir_path must be an absolute path".to_string(), + )); + } - let entries = list_dir_slice(&path, offset, limit, depth).await?; - let mut output = Vec::with_capacity(entries.len() + 1); - output.push(format!("Absolute path: {}", path.display())); - output.extend(entries); - Ok(FunctionToolOutput::from_text(output.join("\n"), Some(true))) + let entries = list_dir_slice(&path, offset, limit, depth).await?; + let mut output = Vec::with_capacity(entries.len() + 1); + output.push(format!("Absolute path: {}", path.display())); + output.extend(entries); + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(FunctionToolOutput::from_text(output.join("\n"), Some(true))), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/mcp.rs b/codex-rs/core/src/tools/handlers/mcp.rs index 18e0df25c4ad..9830e98df499 100644 --- a/codex-rs/core/src/tools/handlers/mcp.rs +++ b/codex-rs/core/src/tools/handlers/mcp.rs @@ -1,58 +1,66 @@ -use async_trait::async_trait; +use futures::future::BoxFuture; use std::sync::Arc; use crate::function_tool::FunctionCallError; use crate::mcp_tool_call::handle_mcp_tool_call; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; -use codex_protocol::mcp::CallToolResult; pub struct McpHandler; -#[async_trait] -impl ToolHandler for McpHandler { - type Output = CallToolResult; +impl ToolHandler for McpHandler { fn kind(&self) -> ToolKind { ToolKind::Mcp } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - call_id, - payload, - .. - } = invocation; - - let payload = match payload { - ToolPayload::Mcp { + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + call_id, + payload, + .. + } = invocation; + + let payload_for_result = payload.clone(); + let payload = match payload { + ToolPayload::Mcp { + server, + tool, + raw_arguments, + } => (server, tool, raw_arguments), + _ => { + return Err(FunctionCallError::RespondToModel( + "mcp handler received unsupported payload".to_string(), + )); + } + }; + + let (server, tool, raw_arguments) = payload; + let arguments_str = raw_arguments; + + let output = handle_mcp_tool_call( + Arc::clone(&session), + &turn, + call_id.clone(), server, tool, - raw_arguments, - } => (server, tool, raw_arguments), - _ => { - return Err(FunctionCallError::RespondToModel( - "mcp handler received unsupported payload".to_string(), - )); - } - }; - - let (server, tool, raw_arguments) = payload; - let arguments_str = raw_arguments; - - let output = handle_mcp_tool_call( - Arc::clone(&session), - &turn, - call_id.clone(), - server, - tool, - arguments_str, - ) - .await; - - Ok(output) + arguments_str, + ) + .await; + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(output), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/mcp_resource.rs b/codex-rs/core/src/tools/handlers/mcp_resource.rs index 158aee612b56..dd5f441373b7 100644 --- a/codex-rs/core/src/tools/handlers/mcp_resource.rs +++ b/codex-rs/core/src/tools/handlers/mcp_resource.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use std::time::Duration; use std::time::Instant; -use async_trait::async_trait; use codex_protocol::mcp::CallToolResult; use codex_protocol::models::function_call_output_content_items_to_text; use rmcp::model::ListResourceTemplatesResult; @@ -24,12 +23,14 @@ use crate::function_tool::FunctionCallError; use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::McpInvocation; use codex_protocol::protocol::McpToolCallBeginEvent; use codex_protocol::protocol::McpToolCallEndEvent; +use futures::future::BoxFuture; pub struct McpResourceHandler; @@ -178,67 +179,76 @@ struct ReadResourcePayload { result: ReadResourceResult, } -#[async_trait] impl ToolHandler for McpResourceHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - call_id, - tool_name, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "mcp_resource handler received unsupported payload".to_string(), - )); - } - }; - - let arguments_value = parse_arguments(arguments.as_str())?; - - match tool_name.as_str() { - "list_mcp_resources" => { - handle_list_resources( - Arc::clone(&session), - Arc::clone(&turn), - call_id.clone(), - arguments_value.clone(), - ) - .await - } - "list_mcp_resource_templates" => { - handle_list_resource_templates( - Arc::clone(&session), - Arc::clone(&turn), - call_id.clone(), - arguments_value.clone(), - ) - .await - } - "read_mcp_resource" => { - handle_read_resource( - Arc::clone(&session), - Arc::clone(&turn), - call_id, - arguments_value, - ) - .await - } - other => Err(FunctionCallError::RespondToModel(format!( - "unsupported MCP resource tool: {other}" - ))), - } + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + call_id, + tool_name, + payload, + .. + } = invocation; + let payload_for_result = payload.clone(); + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "mcp_resource handler received unsupported payload".to_string(), + )); + } + }; + + let arguments_value = parse_arguments(arguments.as_str())?; + + let result = match tool_name.as_str() { + "list_mcp_resources" => { + handle_list_resources( + Arc::clone(&session), + Arc::clone(&turn), + call_id.clone(), + arguments_value.clone(), + ) + .await + } + "list_mcp_resource_templates" => { + handle_list_resource_templates( + Arc::clone(&session), + Arc::clone(&turn), + call_id.clone(), + arguments_value.clone(), + ) + .await + } + "read_mcp_resource" => { + handle_read_resource( + Arc::clone(&session), + Arc::clone(&turn), + call_id.clone(), + arguments_value, + ) + .await + } + other => Err(FunctionCallError::RespondToModel(format!( + "unsupported MCP resource tool: {other}" + ))), + }?; + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(result), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/multi_agents.rs b/codex-rs/core/src/tools/handlers/multi_agents.rs index 166dbd287d73..8a0f876718b7 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents.rs @@ -15,9 +15,9 @@ use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; pub(crate) use crate::tools::handlers::multi_agents_common::*; use crate::tools::handlers::parse_arguments; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; -use async_trait::async_trait; use codex_protocol::ThreadId; use codex_protocol::models::ResponseInputItem; use codex_protocol::openai_models::ReasoningEffort; @@ -33,6 +33,7 @@ use codex_protocol::protocol::CollabResumeEndEvent; use codex_protocol::protocol::CollabWaitingBeginEvent; use codex_protocol::protocol::CollabWaitingEndEvent; use codex_protocol::user_input::UserInput; +use futures::future::BoxFuture; use serde::Deserialize; use serde::Serialize; use serde_json::Value as JsonValue; diff --git a/codex-rs/core/src/tools/handlers/multi_agents/close_agent.rs b/codex-rs/core/src/tools/handlers/multi_agents/close_agent.rs index efc8ec378e32..0a995d4a67c2 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents/close_agent.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents/close_agent.rs @@ -2,10 +2,7 @@ use super::*; pub(crate) struct Handler; -#[async_trait] impl ToolHandler for Handler { - type Output = CloseAgentResult; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -14,84 +11,94 @@ impl ToolHandler for Handler { matches!(payload, ToolPayload::Function { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - call_id, - .. - } = invocation; - let arguments = function_arguments(payload)?; - let args: CloseAgentArgs = parse_arguments(&arguments)?; - let agent_id = parse_agent_id_target(&args.target)?; - let receiver_agent = session - .services - .agent_control - .get_agent_metadata(agent_id) - .unwrap_or_default(); - session - .send_event( - &turn, - CollabCloseBeginEvent { - call_id: call_id.clone(), - sender_thread_id: session.conversation_id, - receiver_thread_id: agent_id, - } - .into(), - ) - .await; - let status = match session - .services - .agent_control - .subscribe_status(agent_id) - .await - { - Ok(mut status_rx) => status_rx.borrow_and_update().clone(), - Err(err) => { - let status = session.services.agent_control.get_status(agent_id).await; - session - .send_event( - &turn, - CollabCloseEndEvent { - call_id: call_id.clone(), - sender_thread_id: session.conversation_id, - receiver_thread_id: agent_id, - receiver_agent_nickname: receiver_agent.agent_nickname.clone(), - receiver_agent_role: receiver_agent.agent_role.clone(), - status, - } - .into(), - ) - .await; - return Err(collab_agent_error(agent_id, err)); - } - }; - let result = session - .services - .agent_control - .close_agent(agent_id) - .await - .map_err(|err| collab_agent_error(agent_id, err)) - .map(|_| ()); - session - .send_event( - &turn, - CollabCloseEndEvent { - call_id, - sender_thread_id: session.conversation_id, - receiver_thread_id: agent_id, - receiver_agent_nickname: receiver_agent.agent_nickname, - receiver_agent_role: receiver_agent.agent_role, - status: status.clone(), + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + payload, + call_id, + .. + } = invocation; + let payload_for_result = payload.clone(); + let arguments = function_arguments(payload)?; + let args: CloseAgentArgs = parse_arguments(&arguments)?; + let agent_id = parse_agent_id_target(&args.target)?; + let receiver_agent = session + .services + .agent_control + .get_agent_metadata(agent_id) + .unwrap_or_default(); + session + .send_event( + &turn, + CollabCloseBeginEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + receiver_thread_id: agent_id, + } + .into(), + ) + .await; + let status = match session + .services + .agent_control + .subscribe_status(agent_id) + .await + { + Ok(mut status_rx) => status_rx.borrow_and_update().clone(), + Err(err) => { + let status = session.services.agent_control.get_status(agent_id).await; + session + .send_event( + &turn, + CollabCloseEndEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + receiver_thread_id: agent_id, + receiver_agent_nickname: receiver_agent.agent_nickname.clone(), + receiver_agent_role: receiver_agent.agent_role.clone(), + status, + } + .into(), + ) + .await; + return Err(collab_agent_error(agent_id, err)); } - .into(), - ) - .await; - result?; + }; + let result = session + .services + .agent_control + .close_agent(agent_id) + .await + .map_err(|err| collab_agent_error(agent_id, err)) + .map(|_| ()); + session + .send_event( + &turn, + CollabCloseEndEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + receiver_thread_id: agent_id, + receiver_agent_nickname: receiver_agent.agent_nickname, + receiver_agent_role: receiver_agent.agent_role, + status: status.clone(), + } + .into(), + ) + .await; + result?; - Ok(CloseAgentResult { - previous_status: status, + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(CloseAgentResult { + previous_status: status, + }), + }) }) } } diff --git a/codex-rs/core/src/tools/handlers/multi_agents/resume_agent.rs b/codex-rs/core/src/tools/handlers/multi_agents/resume_agent.rs index 09526182f293..e4a18391b6ad 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents/resume_agent.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents/resume_agent.rs @@ -4,10 +4,7 @@ use std::sync::Arc; pub(crate) struct Handler; -#[async_trait] impl ToolHandler for Handler { - type Output = ResumeAgentResult; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -16,102 +13,114 @@ impl ToolHandler for Handler { matches!(payload, ToolPayload::Function { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - call_id, - .. - } = invocation; - let arguments = function_arguments(payload)?; - let args: ResumeAgentArgs = parse_arguments(&arguments)?; - let receiver_thread_id = ThreadId::from_string(&args.id).map_err(|err| { - FunctionCallError::RespondToModel(format!("invalid agent id {}: {err:?}", args.id)) - })?; - let receiver_agent = session - .services - .agent_control - .get_agent_metadata(receiver_thread_id) - .unwrap_or_default(); - let child_depth = next_thread_spawn_depth(&turn.session_source); - let max_depth = turn.config.agent_max_depth; - if exceeds_thread_spawn_depth_limit(child_depth, max_depth) { - return Err(FunctionCallError::RespondToModel( - "Agent depth limit reached. Solve the task yourself.".to_string(), - )); - } - - session - .send_event( - &turn, - CollabResumeBeginEvent { - call_id: call_id.clone(), - sender_thread_id: session.conversation_id, - receiver_thread_id, - receiver_agent_nickname: receiver_agent.agent_nickname.clone(), - receiver_agent_role: receiver_agent.agent_role.clone(), - } - .into(), - ) - .await; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + payload, + call_id, + .. + } = invocation; + let payload_for_result = payload.clone(); + let arguments = function_arguments(payload)?; + let args: ResumeAgentArgs = parse_arguments(&arguments)?; + let receiver_thread_id = ThreadId::from_string(&args.id).map_err(|err| { + FunctionCallError::RespondToModel(format!("invalid agent id {}: {err:?}", args.id)) + })?; + let receiver_agent = session + .services + .agent_control + .get_agent_metadata(receiver_thread_id) + .unwrap_or_default(); + let child_depth = next_thread_spawn_depth(&turn.session_source); + let max_depth = turn.config.agent_max_depth; + if exceeds_thread_spawn_depth_limit(child_depth, max_depth) { + return Err(FunctionCallError::RespondToModel( + "Agent depth limit reached. Solve the task yourself.".to_string(), + )); + } - let mut status = session - .services - .agent_control - .get_status(receiver_thread_id) - .await; - let (receiver_agent, error) = if matches!(status, AgentStatus::NotFound) { - match try_resume_closed_agent(&session, &turn, receiver_thread_id, child_depth).await { - Ok(()) => { - status = session - .services - .agent_control - .get_status(receiver_thread_id) - .await; - ( - session + session + .send_event( + &turn, + CollabResumeBeginEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + receiver_thread_id, + receiver_agent_nickname: receiver_agent.agent_nickname.clone(), + receiver_agent_role: receiver_agent.agent_role.clone(), + } + .into(), + ) + .await; + + let mut status = session + .services + .agent_control + .get_status(receiver_thread_id) + .await; + let (receiver_agent, error) = if matches!(status, AgentStatus::NotFound) { + match try_resume_closed_agent(&session, &turn, receiver_thread_id, child_depth) + .await + { + Ok(()) => { + status = session .services .agent_control - .get_agent_metadata(receiver_thread_id) - .unwrap_or(receiver_agent), - None, - ) - } - Err(err) => { - status = session - .services - .agent_control - .get_status(receiver_thread_id) - .await; - (receiver_agent, Some(err)) + .get_status(receiver_thread_id) + .await; + ( + session + .services + .agent_control + .get_agent_metadata(receiver_thread_id) + .unwrap_or(receiver_agent), + None, + ) + } + Err(err) => { + status = session + .services + .agent_control + .get_status(receiver_thread_id) + .await; + (receiver_agent, Some(err)) + } } + } else { + (receiver_agent, None) + }; + session + .send_event( + &turn, + CollabResumeEndEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + receiver_thread_id, + receiver_agent_nickname: receiver_agent.agent_nickname, + receiver_agent_role: receiver_agent.agent_role, + status: status.clone(), + } + .into(), + ) + .await; + + if let Some(err) = error { + return Err(err); } - } else { - (receiver_agent, None) - }; - session - .send_event( - &turn, - CollabResumeEndEvent { - call_id, - sender_thread_id: session.conversation_id, - receiver_thread_id, - receiver_agent_nickname: receiver_agent.agent_nickname, - receiver_agent_role: receiver_agent.agent_role, - status: status.clone(), - } - .into(), - ) - .await; - - if let Some(err) = error { - return Err(err); - } - turn.session_telemetry - .counter("codex.multi_agent.resume", /*inc*/ 1, &[]); - - Ok(ResumeAgentResult { status }) + turn.session_telemetry + .counter("codex.multi_agent.resume", /*inc*/ 1, &[]); + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(ResumeAgentResult { status }), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/multi_agents/send_input.rs b/codex-rs/core/src/tools/handlers/multi_agents/send_input.rs index 3c8527712b2e..7023e4197753 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents/send_input.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents/send_input.rs @@ -3,10 +3,7 @@ use crate::agent::control::render_input_preview; pub(crate) struct Handler; -#[async_trait] impl ToolHandler for Handler { - type Output = SendInputResult; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -15,72 +12,82 @@ impl ToolHandler for Handler { matches!(payload, ToolPayload::Function { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - call_id, - .. - } = invocation; - let arguments = function_arguments(payload)?; - let args: SendInputArgs = parse_arguments(&arguments)?; - let receiver_thread_id = parse_agent_id_target(&args.target)?; - let input_items = parse_collab_input(args.message, args.items)?; - let prompt = render_input_preview(&input_items); - let receiver_agent = session - .services - .agent_control - .get_agent_metadata(receiver_thread_id) - .unwrap_or_default(); - if args.interrupt { - session + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + payload, + call_id, + .. + } = invocation; + let payload_for_result = payload.clone(); + let arguments = function_arguments(payload)?; + let args: SendInputArgs = parse_arguments(&arguments)?; + let receiver_thread_id = parse_agent_id_target(&args.target)?; + let input_items = parse_collab_input(args.message, args.items)?; + let prompt = render_input_preview(&input_items); + let receiver_agent = session .services .agent_control - .interrupt_agent(receiver_thread_id) + .get_agent_metadata(receiver_thread_id) + .unwrap_or_default(); + if args.interrupt { + session + .services + .agent_control + .interrupt_agent(receiver_thread_id) + .await + .map_err(|err| collab_agent_error(receiver_thread_id, err))?; + } + session + .send_event( + &turn, + CollabAgentInteractionBeginEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + receiver_thread_id, + prompt: prompt.clone(), + } + .into(), + ) + .await; + let agent_control = session.services.agent_control.clone(); + let result = agent_control + .send_input(receiver_thread_id, input_items) .await - .map_err(|err| collab_agent_error(receiver_thread_id, err))?; - } - session - .send_event( - &turn, - CollabAgentInteractionBeginEvent { - call_id: call_id.clone(), - sender_thread_id: session.conversation_id, - receiver_thread_id, - prompt: prompt.clone(), - } - .into(), - ) - .await; - let agent_control = session.services.agent_control.clone(); - let result = agent_control - .send_input(receiver_thread_id, input_items) - .await - .map_err(|err| collab_agent_error(receiver_thread_id, err)); - let status = session - .services - .agent_control - .get_status(receiver_thread_id) - .await; - session - .send_event( - &turn, - CollabAgentInteractionEndEvent { - call_id, - sender_thread_id: session.conversation_id, - receiver_thread_id, - receiver_agent_nickname: receiver_agent.agent_nickname, - receiver_agent_role: receiver_agent.agent_role, - prompt, - status, - } - .into(), - ) - .await; - let submission_id = result?; + .map_err(|err| collab_agent_error(receiver_thread_id, err)); + let status = session + .services + .agent_control + .get_status(receiver_thread_id) + .await; + session + .send_event( + &turn, + CollabAgentInteractionEndEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + receiver_thread_id, + receiver_agent_nickname: receiver_agent.agent_nickname, + receiver_agent_role: receiver_agent.agent_role, + prompt, + status, + } + .into(), + ) + .await; + let submission_id = result?; - Ok(SendInputResult { submission_id }) + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(SendInputResult { submission_id }), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/multi_agents/spawn.rs b/codex-rs/core/src/tools/handlers/multi_agents/spawn.rs index 308ec49d8564..e08f474a563b 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents/spawn.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents/spawn.rs @@ -10,10 +10,7 @@ use crate::agent::next_thread_spawn_depth; pub(crate) struct Handler; -#[async_trait] impl ToolHandler for Handler { - type Output = SpawnAgentResult; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -22,149 +19,159 @@ impl ToolHandler for Handler { matches!(payload, ToolPayload::Function { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - call_id, - .. - } = invocation; - let arguments = function_arguments(payload)?; - let args: SpawnAgentArgs = parse_arguments(&arguments)?; - let role_name = args - .agent_type - .as_deref() - .map(str::trim) - .filter(|role| !role.is_empty()); - let input_items = parse_collab_input(args.message, args.items)?; - let prompt = render_input_preview(&input_items); - let session_source = turn.session_source.clone(); - let child_depth = next_thread_spawn_depth(&session_source); - let max_depth = turn.config.agent_max_depth; - if exceeds_thread_spawn_depth_limit(child_depth, max_depth) { - return Err(FunctionCallError::RespondToModel( - "Agent depth limit reached. Solve the task yourself.".to_string(), - )); - } - session - .send_event( - &turn, - CollabAgentSpawnBeginEvent { - call_id: call_id.clone(), - sender_thread_id: session.conversation_id, - prompt: prompt.clone(), - model: args.model.clone().unwrap_or_default(), - reasoning_effort: args.reasoning_effort.unwrap_or_default(), - } - .into(), + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + payload, + call_id, + .. + } = invocation; + let payload_for_result = payload.clone(); + let arguments = function_arguments(payload)?; + let args: SpawnAgentArgs = parse_arguments(&arguments)?; + let role_name = args + .agent_type + .as_deref() + .map(str::trim) + .filter(|role| !role.is_empty()); + let input_items = parse_collab_input(args.message, args.items)?; + let prompt = render_input_preview(&input_items); + let session_source = turn.session_source.clone(); + let child_depth = next_thread_spawn_depth(&session_source); + let max_depth = turn.config.agent_max_depth; + if exceeds_thread_spawn_depth_limit(child_depth, max_depth) { + return Err(FunctionCallError::RespondToModel( + "Agent depth limit reached. Solve the task yourself.".to_string(), + )); + } + session + .send_event( + &turn, + CollabAgentSpawnBeginEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + prompt: prompt.clone(), + model: args.model.clone().unwrap_or_default(), + reasoning_effort: args.reasoning_effort.unwrap_or_default(), + } + .into(), + ) + .await; + let mut config = + build_agent_spawn_config(&session.get_base_instructions().await, turn.as_ref())?; + apply_requested_spawn_agent_model_overrides( + &session, + turn.as_ref(), + &mut config, + args.model.as_deref(), + args.reasoning_effort, ) - .await; - let mut config = - build_agent_spawn_config(&session.get_base_instructions().await, turn.as_ref())?; - apply_requested_spawn_agent_model_overrides( - &session, - turn.as_ref(), - &mut config, - args.model.as_deref(), - args.reasoning_effort, - ) - .await?; - apply_role_to_config(&mut config, role_name) - .await - .map_err(FunctionCallError::RespondToModel)?; - apply_spawn_agent_runtime_overrides(&mut config, turn.as_ref())?; - apply_spawn_agent_overrides(&mut config, child_depth); + .await?; + apply_role_to_config(&mut config, role_name) + .await + .map_err(FunctionCallError::RespondToModel)?; + apply_spawn_agent_runtime_overrides(&mut config, turn.as_ref())?; + apply_spawn_agent_overrides(&mut config, child_depth); - let result = session - .services - .agent_control - .spawn_agent_with_metadata( - config, - input_items, - Some(thread_spawn_source( - session.conversation_id, - &turn.session_source, - child_depth, - role_name, - /*task_name*/ None, - )?), - SpawnAgentOptions { - fork_parent_spawn_call_id: args.fork_context.then(|| call_id.clone()), - fork_mode: args.fork_context.then_some(SpawnAgentForkMode::FullHistory), - }, - ) - .await - .map_err(collab_spawn_error); - let (new_thread_id, new_agent_metadata, status) = match &result { - Ok(spawned_agent) => ( - Some(spawned_agent.thread_id), - Some(spawned_agent.metadata.clone()), - spawned_agent.status.clone(), - ), - Err(_) => (None, None, AgentStatus::NotFound), - }; - let agent_snapshot = match new_thread_id { - Some(thread_id) => { - session - .services - .agent_control - .get_agent_config_snapshot(thread_id) - .await - } - None => None, - }; - let (_new_agent_path, new_agent_nickname, new_agent_role) = - match (&agent_snapshot, new_agent_metadata) { - (Some(snapshot), _) => ( - snapshot.session_source.get_agent_path().map(String::from), - snapshot.session_source.get_nickname(), - snapshot.session_source.get_agent_role(), - ), - (None, Some(metadata)) => ( - metadata.agent_path.map(String::from), - metadata.agent_nickname, - metadata.agent_role, + let result = session + .services + .agent_control + .spawn_agent_with_metadata( + config, + input_items, + Some(thread_spawn_source( + session.conversation_id, + &turn.session_source, + child_depth, + role_name, + /*task_name*/ None, + )?), + SpawnAgentOptions { + fork_parent_spawn_call_id: args.fork_context.then(|| call_id.clone()), + fork_mode: args.fork_context.then_some(SpawnAgentForkMode::FullHistory), + }, + ) + .await + .map_err(collab_spawn_error); + let (new_thread_id, new_agent_metadata, status) = match &result { + Ok(spawned_agent) => ( + Some(spawned_agent.thread_id), + Some(spawned_agent.metadata.clone()), + spawned_agent.status.clone(), ), - (None, None) => (None, None, None), + Err(_) => (None, None, AgentStatus::NotFound), }; - let effective_model = agent_snapshot - .as_ref() - .map(|snapshot| snapshot.model.clone()) - .unwrap_or_else(|| args.model.clone().unwrap_or_default()); - let effective_reasoning_effort = agent_snapshot - .as_ref() - .and_then(|snapshot| snapshot.reasoning_effort) - .unwrap_or(args.reasoning_effort.unwrap_or_default()); - let nickname = new_agent_nickname.clone(); - session - .send_event( - &turn, - CollabAgentSpawnEndEvent { - call_id, - sender_thread_id: session.conversation_id, - new_thread_id, - new_agent_nickname, - new_agent_role, - prompt, - model: effective_model, - reasoning_effort: effective_reasoning_effort, - status, + let agent_snapshot = match new_thread_id { + Some(thread_id) => { + session + .services + .agent_control + .get_agent_config_snapshot(thread_id) + .await } - .into(), - ) - .await; - let new_thread_id = result?.thread_id; - let role_tag = role_name.unwrap_or(DEFAULT_ROLE_NAME); - turn.session_telemetry.counter( - "codex.multi_agent.spawn", - /*inc*/ 1, - &[("role", role_tag)], - ); + None => None, + }; + let (_new_agent_path, new_agent_nickname, new_agent_role) = + match (&agent_snapshot, new_agent_metadata) { + (Some(snapshot), _) => ( + snapshot.session_source.get_agent_path().map(String::from), + snapshot.session_source.get_nickname(), + snapshot.session_source.get_agent_role(), + ), + (None, Some(metadata)) => ( + metadata.agent_path.map(String::from), + metadata.agent_nickname, + metadata.agent_role, + ), + (None, None) => (None, None, None), + }; + let effective_model = agent_snapshot + .as_ref() + .map(|snapshot| snapshot.model.clone()) + .unwrap_or_else(|| args.model.clone().unwrap_or_default()); + let effective_reasoning_effort = agent_snapshot + .as_ref() + .and_then(|snapshot| snapshot.reasoning_effort) + .unwrap_or(args.reasoning_effort.unwrap_or_default()); + let nickname = new_agent_nickname.clone(); + session + .send_event( + &turn, + CollabAgentSpawnEndEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + new_thread_id, + new_agent_nickname, + new_agent_role, + prompt, + model: effective_model, + reasoning_effort: effective_reasoning_effort, + status, + } + .into(), + ) + .await; + let new_thread_id = result?.thread_id; + let role_tag = role_name.unwrap_or(DEFAULT_ROLE_NAME); + turn.session_telemetry.counter( + "codex.multi_agent.spawn", + /*inc*/ 1, + &[("role", role_tag)], + ); - Ok(SpawnAgentResult { - agent_id: new_thread_id.to_string(), - nickname, + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(SpawnAgentResult { + agent_id: new_thread_id.to_string(), + nickname, + }), + }) }) } } diff --git a/codex-rs/core/src/tools/handlers/multi_agents/wait.rs b/codex-rs/core/src/tools/handlers/multi_agents/wait.rs index 2fe33f1edd3c..873444aaecea 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents/wait.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents/wait.rs @@ -14,10 +14,7 @@ use tokio::time::timeout_at; pub(crate) struct Handler; -#[async_trait] impl ToolHandler for Handler { - type Output = WaitAgentResult; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -26,162 +23,172 @@ impl ToolHandler for Handler { matches!(payload, ToolPayload::Function { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - call_id, - .. - } = invocation; - let arguments = function_arguments(payload)?; - let args: WaitArgs = parse_arguments(&arguments)?; - let receiver_thread_ids = parse_agent_id_targets(args.targets)?; - let mut receiver_agents = Vec::with_capacity(receiver_thread_ids.len()); - let mut target_by_thread_id = HashMap::with_capacity(receiver_thread_ids.len()); - for receiver_thread_id in &receiver_thread_ids { - let agent_metadata = session - .services - .agent_control - .get_agent_metadata(*receiver_thread_id) - .unwrap_or_default(); - target_by_thread_id.insert( - *receiver_thread_id, - agent_metadata - .agent_path - .as_ref() - .map(ToString::to_string) - .unwrap_or_else(|| receiver_thread_id.to_string()), - ); - receiver_agents.push(CollabAgentRef { - thread_id: *receiver_thread_id, - agent_nickname: agent_metadata.agent_nickname, - agent_role: agent_metadata.agent_role, - }); - } - - let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS); - let timeout_ms = match timeout_ms { - ms if ms <= 0 => { - return Err(FunctionCallError::RespondToModel( - "timeout_ms must be greater than zero".to_owned(), - )); + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + payload, + call_id, + .. + } = invocation; + let payload_for_result = payload.clone(); + let arguments = function_arguments(payload)?; + let args: WaitArgs = parse_arguments(&arguments)?; + let receiver_thread_ids = parse_agent_id_targets(args.targets)?; + let mut receiver_agents = Vec::with_capacity(receiver_thread_ids.len()); + let mut target_by_thread_id = HashMap::with_capacity(receiver_thread_ids.len()); + for receiver_thread_id in &receiver_thread_ids { + let agent_metadata = session + .services + .agent_control + .get_agent_metadata(*receiver_thread_id) + .unwrap_or_default(); + target_by_thread_id.insert( + *receiver_thread_id, + agent_metadata + .agent_path + .as_ref() + .map(ToString::to_string) + .unwrap_or_else(|| receiver_thread_id.to_string()), + ); + receiver_agents.push(CollabAgentRef { + thread_id: *receiver_thread_id, + agent_nickname: agent_metadata.agent_nickname, + agent_role: agent_metadata.agent_role, + }); } - ms => ms.clamp(MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS), - }; - - session - .send_event( - &turn, - CollabWaitingBeginEvent { - sender_thread_id: session.conversation_id, - receiver_thread_ids: receiver_thread_ids.clone(), - receiver_agents: receiver_agents.clone(), - call_id: call_id.clone(), + + let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS); + let timeout_ms = match timeout_ms { + ms if ms <= 0 => { + return Err(FunctionCallError::RespondToModel( + "timeout_ms must be greater than zero".to_owned(), + )); } - .into(), - ) - .await; - - let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len()); - let mut initial_final_statuses = Vec::new(); - for id in &receiver_thread_ids { - match session.services.agent_control.subscribe_status(*id).await { - Ok(rx) => { - let status = rx.borrow().clone(); - if is_final(&status) { - initial_final_statuses.push((*id, status)); + ms => ms.clamp(MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS), + }; + + session + .send_event( + &turn, + CollabWaitingBeginEvent { + sender_thread_id: session.conversation_id, + receiver_thread_ids: receiver_thread_ids.clone(), + receiver_agents: receiver_agents.clone(), + call_id: call_id.clone(), + } + .into(), + ) + .await; + + let mut status_rxs = Vec::with_capacity(receiver_thread_ids.len()); + let mut initial_final_statuses = Vec::new(); + for id in &receiver_thread_ids { + match session.services.agent_control.subscribe_status(*id).await { + Ok(rx) => { + let status = rx.borrow().clone(); + if is_final(&status) { + initial_final_statuses.push((*id, status)); + } + status_rxs.push((*id, rx)); + } + Err(CodexErr::ThreadNotFound(_)) => { + initial_final_statuses.push((*id, AgentStatus::NotFound)); + } + Err(err) => { + let mut statuses = HashMap::with_capacity(1); + statuses.insert(*id, session.services.agent_control.get_status(*id).await); + session + .send_event( + &turn, + CollabWaitingEndEvent { + sender_thread_id: session.conversation_id, + call_id: call_id.clone(), + agent_statuses: build_wait_agent_statuses( + &statuses, + &receiver_agents, + ), + statuses, + } + .into(), + ) + .await; + return Err(collab_agent_error(*id, err)); } - status_rxs.push((*id, rx)); - } - Err(CodexErr::ThreadNotFound(_)) => { - initial_final_statuses.push((*id, AgentStatus::NotFound)); - } - Err(err) => { - let mut statuses = HashMap::with_capacity(1); - statuses.insert(*id, session.services.agent_control.get_status(*id).await); - session - .send_event( - &turn, - CollabWaitingEndEvent { - sender_thread_id: session.conversation_id, - call_id: call_id.clone(), - agent_statuses: build_wait_agent_statuses( - &statuses, - &receiver_agents, - ), - statuses, - } - .into(), - ) - .await; - return Err(collab_agent_error(*id, err)); } } - } - let statuses = if !initial_final_statuses.is_empty() { - initial_final_statuses - } else { - let mut futures = FuturesUnordered::new(); - for (id, rx) in status_rxs.into_iter() { - let session = session.clone(); - futures.push(wait_for_final_status(session, id, rx)); - } - let mut results = Vec::new(); - let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64); - loop { - match timeout_at(deadline, futures.next()).await { - Ok(Some(Some(result))) => { - results.push(result); - break; - } - Ok(Some(None)) => continue, - Ok(None) | Err(_) => break, + let statuses = if !initial_final_statuses.is_empty() { + initial_final_statuses + } else { + let mut futures = FuturesUnordered::new(); + for (id, rx) in status_rxs.into_iter() { + let session = session.clone(); + futures.push(wait_for_final_status(session, id, rx)); } - } - if !results.is_empty() { + let mut results = Vec::new(); + let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64); loop { - match futures.next().now_or_never() { - Some(Some(Some(result))) => results.push(result), - Some(Some(None)) => continue, - Some(None) | None => break, + match timeout_at(deadline, futures.next()).await { + Ok(Some(Some(result))) => { + results.push(result); + break; + } + Ok(Some(None)) => continue, + Ok(None) | Err(_) => break, } } - } - results - }; - - let timed_out = statuses.is_empty(); - let statuses_by_id = statuses.clone().into_iter().collect::>(); - let agent_statuses = build_wait_agent_statuses(&statuses_by_id, &receiver_agents); - let result = WaitAgentResult { - status: statuses - .into_iter() - .filter_map(|(thread_id, status)| { - target_by_thread_id - .get(&thread_id) - .cloned() - .map(|target| (target, status)) - }) - .collect(), - timed_out, - }; - - session - .send_event( - &turn, - CollabWaitingEndEvent { - sender_thread_id: session.conversation_id, - call_id, - agent_statuses, - statuses: statuses_by_id, + if !results.is_empty() { + loop { + match futures.next().now_or_never() { + Some(Some(Some(result))) => results.push(result), + Some(Some(None)) => continue, + Some(None) | None => break, + } + } } - .into(), - ) - .await; - - Ok(result) + results + }; + + let timed_out = statuses.is_empty(); + let statuses_by_id = statuses.clone().into_iter().collect::>(); + let agent_statuses = build_wait_agent_statuses(&statuses_by_id, &receiver_agents); + let result = WaitAgentResult { + status: statuses + .into_iter() + .filter_map(|(thread_id, status)| { + target_by_thread_id + .get(&thread_id) + .cloned() + .map(|target| (target, status)) + }) + .collect(), + timed_out, + }; + + session + .send_event( + &turn, + CollabWaitingEndEvent { + sender_thread_id: session.conversation_id, + call_id: call_id.clone(), + agent_statuses, + statuses: statuses_by_id, + } + .into(), + ) + .await; + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(result), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/multi_agents_v2.rs b/codex-rs/core/src/tools/handlers/multi_agents_v2.rs index 44191ed85e8d..afd6210452e8 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents_v2.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents_v2.rs @@ -9,9 +9,9 @@ use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; use crate::tools::handlers::multi_agents_common::*; use crate::tools::handlers::parse_arguments; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; -use async_trait::async_trait; use codex_protocol::AgentPath; use codex_protocol::models::ResponseInputItem; use codex_protocol::openai_models::ReasoningEffort; @@ -24,6 +24,7 @@ use codex_protocol::protocol::CollabCloseEndEvent; use codex_protocol::protocol::CollabWaitingBeginEvent; use codex_protocol::protocol::CollabWaitingEndEvent; use codex_protocol::user_input::UserInput; +use futures::future::BoxFuture; use serde::Deserialize; use serde::Serialize; use serde_json::Value as JsonValue; diff --git a/codex-rs/core/src/tools/handlers/multi_agents_v2/close_agent.rs b/codex-rs/core/src/tools/handlers/multi_agents_v2/close_agent.rs index 2296ae5dd3ea..9cd5e9bbe327 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents_v2/close_agent.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents_v2/close_agent.rs @@ -2,10 +2,7 @@ use super::*; pub(crate) struct Handler; -#[async_trait] impl ToolHandler for Handler { - type Output = CloseAgentResult; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -14,93 +11,103 @@ impl ToolHandler for Handler { matches!(payload, ToolPayload::Function { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - call_id, - .. - } = invocation; - let arguments = function_arguments(payload)?; - let args: CloseAgentArgs = parse_arguments(&arguments)?; - let agent_id = resolve_agent_target(&session, &turn, &args.target).await?; - let receiver_agent = session - .services - .agent_control - .get_agent_metadata(agent_id) - .unwrap_or_default(); - if receiver_agent - .agent_path - .as_ref() - .is_some_and(AgentPath::is_root) - { - return Err(FunctionCallError::RespondToModel( - "root is not a spawned agent".to_string(), - )); - } - session - .send_event( - &turn, - CollabCloseBeginEvent { - call_id: call_id.clone(), - sender_thread_id: session.conversation_id, - receiver_thread_id: agent_id, - } - .into(), - ) - .await; - let status = match session - .services - .agent_control - .subscribe_status(agent_id) - .await - { - Ok(mut status_rx) => status_rx.borrow_and_update().clone(), - Err(err) => { - let status = session.services.agent_control.get_status(agent_id).await; - session - .send_event( - &turn, - CollabCloseEndEvent { - call_id: call_id.clone(), - sender_thread_id: session.conversation_id, - receiver_thread_id: agent_id, - receiver_agent_nickname: receiver_agent.agent_nickname.clone(), - receiver_agent_role: receiver_agent.agent_role.clone(), - status, - } - .into(), - ) - .await; - return Err(collab_agent_error(agent_id, err)); + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + payload, + call_id, + .. + } = invocation; + let payload_for_result = payload.clone(); + let arguments = function_arguments(payload)?; + let args: CloseAgentArgs = parse_arguments(&arguments)?; + let agent_id = resolve_agent_target(&session, &turn, &args.target).await?; + let receiver_agent = session + .services + .agent_control + .get_agent_metadata(agent_id) + .unwrap_or_default(); + if receiver_agent + .agent_path + .as_ref() + .is_some_and(AgentPath::is_root) + { + return Err(FunctionCallError::RespondToModel( + "root is not a spawned agent".to_string(), + )); } - }; - let result = session - .services - .agent_control - .close_agent(agent_id) - .await - .map_err(|err| collab_agent_error(agent_id, err)) - .map(|_| ()); - session - .send_event( - &turn, - CollabCloseEndEvent { - call_id, - sender_thread_id: session.conversation_id, - receiver_thread_id: agent_id, - receiver_agent_nickname: receiver_agent.agent_nickname, - receiver_agent_role: receiver_agent.agent_role, - status: status.clone(), + session + .send_event( + &turn, + CollabCloseBeginEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + receiver_thread_id: agent_id, + } + .into(), + ) + .await; + let status = match session + .services + .agent_control + .subscribe_status(agent_id) + .await + { + Ok(mut status_rx) => status_rx.borrow_and_update().clone(), + Err(err) => { + let status = session.services.agent_control.get_status(agent_id).await; + session + .send_event( + &turn, + CollabCloseEndEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + receiver_thread_id: agent_id, + receiver_agent_nickname: receiver_agent.agent_nickname.clone(), + receiver_agent_role: receiver_agent.agent_role.clone(), + status, + } + .into(), + ) + .await; + return Err(collab_agent_error(agent_id, err)); } - .into(), - ) - .await; - result?; + }; + let result = session + .services + .agent_control + .close_agent(agent_id) + .await + .map_err(|err| collab_agent_error(agent_id, err)) + .map(|_| ()); + session + .send_event( + &turn, + CollabCloseEndEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + receiver_thread_id: agent_id, + receiver_agent_nickname: receiver_agent.agent_nickname, + receiver_agent_role: receiver_agent.agent_role, + status: status.clone(), + } + .into(), + ) + .await; + result?; - Ok(CloseAgentResult { - previous_status: status, + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(CloseAgentResult { + previous_status: status, + }), + }) }) } } diff --git a/codex-rs/core/src/tools/handlers/multi_agents_v2/followup_task.rs b/codex-rs/core/src/tools/handlers/multi_agents_v2/followup_task.rs index d0f87f34268b..f3b1fcaf252f 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents_v2/followup_task.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents_v2/followup_task.rs @@ -1,15 +1,11 @@ use super::message_tool::FollowupTaskArgs; use super::message_tool::MessageDeliveryMode; -use super::message_tool::MessageToolResult; use super::message_tool::handle_message_string_tool; use super::*; pub(crate) struct Handler; -#[async_trait] impl ToolHandler for Handler { - type Output = MessageToolResult; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -18,16 +14,29 @@ impl ToolHandler for Handler { matches!(payload, ToolPayload::Function { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let arguments = function_arguments(invocation.payload.clone())?; - let args: FollowupTaskArgs = parse_arguments(&arguments)?; - handle_message_string_tool( - invocation, - MessageDeliveryMode::TriggerTurn, - args.target, - args.message, - args.interrupt, - ) - .await + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let call_id = invocation.call_id.clone(); + let payload_for_result = invocation.payload.clone(); + let arguments = function_arguments(invocation.payload.clone())?; + let args: FollowupTaskArgs = parse_arguments(&arguments)?; + let result = handle_message_string_tool( + invocation, + MessageDeliveryMode::TriggerTurn, + args.target, + args.message, + args.interrupt, + ) + .await?; + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(result), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/multi_agents_v2/list_agents.rs b/codex-rs/core/src/tools/handlers/multi_agents_v2/list_agents.rs index b3b3f7160b7a..3a1383841e35 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents_v2/list_agents.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents_v2/list_agents.rs @@ -3,10 +3,7 @@ use crate::agent::control::ListedAgent; pub(crate) struct Handler; -#[async_trait] impl ToolHandler for Handler { - type Output = ListAgentsResult; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -15,27 +12,38 @@ impl ToolHandler for Handler { matches!(payload, ToolPayload::Function { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - .. - } = invocation; - let arguments = function_arguments(payload)?; - let args: ListAgentsArgs = parse_arguments(&arguments)?; - session - .services - .agent_control - .register_session_root(session.conversation_id, &turn.session_source); - let agents = session - .services - .agent_control - .list_agents(&turn.session_source, args.path_prefix.as_deref()) - .await - .map_err(collab_spawn_error)?; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + payload, + call_id, + .. + } = invocation; + let payload_for_result = payload.clone(); + let arguments = function_arguments(payload)?; + let args: ListAgentsArgs = parse_arguments(&arguments)?; + session + .services + .agent_control + .register_session_root(session.conversation_id, &turn.session_source); + let agents = session + .services + .agent_control + .list_agents(&turn.session_source, args.path_prefix.as_deref()) + .await + .map_err(collab_spawn_error)?; - Ok(ListAgentsResult { agents }) + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(ListAgentsResult { agents }), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/multi_agents_v2/send_message.rs b/codex-rs/core/src/tools/handlers/multi_agents_v2/send_message.rs index a16aebd4c299..18da7e8a8b4a 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents_v2/send_message.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents_v2/send_message.rs @@ -1,15 +1,11 @@ use super::message_tool::MessageDeliveryMode; -use super::message_tool::MessageToolResult; use super::message_tool::SendMessageArgs; use super::message_tool::handle_message_string_tool; use super::*; pub(crate) struct Handler; -#[async_trait] impl ToolHandler for Handler { - type Output = MessageToolResult; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -18,16 +14,29 @@ impl ToolHandler for Handler { matches!(payload, ToolPayload::Function { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let arguments = function_arguments(invocation.payload.clone())?; - let args: SendMessageArgs = parse_arguments(&arguments)?; - handle_message_string_tool( - invocation, - MessageDeliveryMode::QueueOnly, - args.target, - args.message, - /*interrupt*/ false, - ) - .await + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let call_id = invocation.call_id.clone(); + let payload_for_result = invocation.payload.clone(); + let arguments = function_arguments(invocation.payload.clone())?; + let args: SendMessageArgs = parse_arguments(&arguments)?; + let result = handle_message_string_tool( + invocation, + MessageDeliveryMode::QueueOnly, + args.target, + args.message, + /*interrupt*/ false, + ) + .await?; + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(result), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/multi_agents_v2/spawn.rs b/codex-rs/core/src/tools/handlers/multi_agents_v2/spawn.rs index 3dc8d7202464..a81aad656457 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents_v2/spawn.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents_v2/spawn.rs @@ -11,10 +11,7 @@ use codex_protocol::protocol::Op; pub(crate) struct Handler; -#[async_trait] impl ToolHandler for Handler { - type Output = SpawnAgentResult; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -23,178 +20,188 @@ impl ToolHandler for Handler { matches!(payload, ToolPayload::Function { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - call_id, - .. - } = invocation; - let arguments = function_arguments(payload)?; - let args: SpawnAgentArgs = parse_arguments(&arguments)?; - let fork_mode = args.fork_mode()?; - let role_name = args - .agent_type - .as_deref() - .map(str::trim) - .filter(|role| !role.is_empty()); + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + payload, + call_id, + .. + } = invocation; + let payload_for_result = payload.clone(); + let arguments = function_arguments(payload)?; + let args: SpawnAgentArgs = parse_arguments(&arguments)?; + let fork_mode = args.fork_mode()?; + let role_name = args + .agent_type + .as_deref() + .map(str::trim) + .filter(|role| !role.is_empty()); - let initial_operation = parse_collab_input(Some(args.message), /*items*/ None)?; - let prompt = render_input_preview(&initial_operation); + let initial_operation = parse_collab_input(Some(args.message), /*items*/ None)?; + let prompt = render_input_preview(&initial_operation); - let session_source = turn.session_source.clone(); - let child_depth = next_thread_spawn_depth(&session_source); - let max_depth = turn.config.agent_max_depth; - if exceeds_thread_spawn_depth_limit(child_depth, max_depth) { - return Err(FunctionCallError::RespondToModel( - "Agent depth limit reached. Solve the task yourself.".to_string(), - )); - } - session - .send_event( - &turn, - CollabAgentSpawnBeginEvent { - call_id: call_id.clone(), - sender_thread_id: session.conversation_id, - prompt: prompt.clone(), - model: args.model.clone().unwrap_or_default(), - reasoning_effort: args.reasoning_effort.unwrap_or_default(), - } - .into(), + let session_source = turn.session_source.clone(); + let child_depth = next_thread_spawn_depth(&session_source); + let max_depth = turn.config.agent_max_depth; + if exceeds_thread_spawn_depth_limit(child_depth, max_depth) { + return Err(FunctionCallError::RespondToModel( + "Agent depth limit reached. Solve the task yourself.".to_string(), + )); + } + session + .send_event( + &turn, + CollabAgentSpawnBeginEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + prompt: prompt.clone(), + model: args.model.clone().unwrap_or_default(), + reasoning_effort: args.reasoning_effort.unwrap_or_default(), + } + .into(), + ) + .await; + let mut config = + build_agent_spawn_config(&session.get_base_instructions().await, turn.as_ref())?; + apply_requested_spawn_agent_model_overrides( + &session, + turn.as_ref(), + &mut config, + args.model.as_deref(), + args.reasoning_effort, ) - .await; - let mut config = - build_agent_spawn_config(&session.get_base_instructions().await, turn.as_ref())?; - apply_requested_spawn_agent_model_overrides( - &session, - turn.as_ref(), - &mut config, - args.model.as_deref(), - args.reasoning_effort, - ) - .await?; - apply_role_to_config(&mut config, role_name) - .await - .map_err(FunctionCallError::RespondToModel)?; - apply_spawn_agent_runtime_overrides(&mut config, turn.as_ref())?; - apply_spawn_agent_overrides(&mut config, child_depth); + .await?; + apply_role_to_config(&mut config, role_name) + .await + .map_err(FunctionCallError::RespondToModel)?; + apply_spawn_agent_runtime_overrides(&mut config, turn.as_ref())?; + apply_spawn_agent_overrides(&mut config, child_depth); - let spawn_source = thread_spawn_source( - session.conversation_id, - &turn.session_source, - child_depth, - role_name, - Some(args.task_name.clone()), - )?; - let result = session - .services - .agent_control - .spawn_agent_with_metadata( - config, - match (spawn_source.get_agent_path(), initial_operation) { - (Some(recipient), Op::UserInput { items, .. }) - if items - .iter() - .all(|item| matches!(item, UserInput::Text { .. })) => - { - Op::InterAgentCommunication { - communication: InterAgentCommunication::new( - turn.session_source - .get_agent_path() - .unwrap_or_else(AgentPath::root), - recipient, - Vec::new(), - prompt.clone(), - /*trigger_turn*/ true, - ), + let spawn_source = thread_spawn_source( + session.conversation_id, + &turn.session_source, + child_depth, + role_name, + Some(args.task_name.clone()), + )?; + let result = session + .services + .agent_control + .spawn_agent_with_metadata( + config, + match (spawn_source.get_agent_path(), initial_operation) { + (Some(recipient), Op::UserInput { items, .. }) + if items + .iter() + .all(|item| matches!(item, UserInput::Text { .. })) => + { + Op::InterAgentCommunication { + communication: InterAgentCommunication::new( + turn.session_source + .get_agent_path() + .unwrap_or_else(AgentPath::root), + recipient, + Vec::new(), + prompt.clone(), + /*trigger_turn*/ true, + ), + } } - } - (_, initial_operation) => initial_operation, - }, - Some(spawn_source), - SpawnAgentOptions { - fork_parent_spawn_call_id: fork_mode.as_ref().map(|_| call_id.clone()), - fork_mode, - }, - ) - .await - .map_err(collab_spawn_error); - let (new_thread_id, new_agent_metadata, status) = match &result { - Ok(spawned_agent) => ( - Some(spawned_agent.thread_id), - Some(spawned_agent.metadata.clone()), - spawned_agent.status.clone(), - ), - Err(_) => (None, None, AgentStatus::NotFound), - }; - let agent_snapshot = match new_thread_id { - Some(thread_id) => { - session - .services - .agent_control - .get_agent_config_snapshot(thread_id) - .await - } - None => None, - }; - let (new_agent_path, new_agent_nickname, new_agent_role) = - match (&agent_snapshot, new_agent_metadata) { - (Some(snapshot), _) => ( - snapshot.session_source.get_agent_path().map(String::from), - snapshot.session_source.get_nickname(), - snapshot.session_source.get_agent_role(), + (_, initial_operation) => initial_operation, + }, + Some(spawn_source), + SpawnAgentOptions { + fork_parent_spawn_call_id: fork_mode.as_ref().map(|_| call_id.clone()), + fork_mode, + }, + ) + .await + .map_err(collab_spawn_error); + let (new_thread_id, new_agent_metadata, status) = match &result { + Ok(spawned_agent) => ( + Some(spawned_agent.thread_id), + Some(spawned_agent.metadata.clone()), + spawned_agent.status.clone(), ), - (None, Some(metadata)) => ( - metadata.agent_path.map(String::from), - metadata.agent_nickname, - metadata.agent_role, - ), - (None, None) => (None, None, None), + Err(_) => (None, None, AgentStatus::NotFound), }; - let effective_model = agent_snapshot - .as_ref() - .map(|snapshot| snapshot.model.clone()) - .unwrap_or_else(|| args.model.clone().unwrap_or_default()); - let effective_reasoning_effort = agent_snapshot - .as_ref() - .and_then(|snapshot| snapshot.reasoning_effort) - .unwrap_or(args.reasoning_effort.unwrap_or_default()); - let nickname = new_agent_nickname.clone(); - session - .send_event( - &turn, - CollabAgentSpawnEndEvent { - call_id, - sender_thread_id: session.conversation_id, - new_thread_id, - new_agent_nickname, - new_agent_role, - prompt, - model: effective_model, - reasoning_effort: effective_reasoning_effort, - status, + let agent_snapshot = match new_thread_id { + Some(thread_id) => { + session + .services + .agent_control + .get_agent_config_snapshot(thread_id) + .await } - .into(), - ) - .await; - let _ = result?; - let role_tag = role_name.unwrap_or(DEFAULT_ROLE_NAME); - turn.session_telemetry.counter( - "codex.multi_agent.spawn", - /*inc*/ 1, - &[("role", role_tag)], - ); - let task_name = new_agent_path.ok_or_else(|| { - FunctionCallError::RespondToModel( - "spawned agent is missing a canonical task name".to_string(), - ) - })?; + None => None, + }; + let (new_agent_path, new_agent_nickname, new_agent_role) = + match (&agent_snapshot, new_agent_metadata) { + (Some(snapshot), _) => ( + snapshot.session_source.get_agent_path().map(String::from), + snapshot.session_source.get_nickname(), + snapshot.session_source.get_agent_role(), + ), + (None, Some(metadata)) => ( + metadata.agent_path.map(String::from), + metadata.agent_nickname, + metadata.agent_role, + ), + (None, None) => (None, None, None), + }; + let effective_model = agent_snapshot + .as_ref() + .map(|snapshot| snapshot.model.clone()) + .unwrap_or_else(|| args.model.clone().unwrap_or_default()); + let effective_reasoning_effort = agent_snapshot + .as_ref() + .and_then(|snapshot| snapshot.reasoning_effort) + .unwrap_or(args.reasoning_effort.unwrap_or_default()); + let nickname = new_agent_nickname.clone(); + session + .send_event( + &turn, + CollabAgentSpawnEndEvent { + call_id: call_id.clone(), + sender_thread_id: session.conversation_id, + new_thread_id, + new_agent_nickname, + new_agent_role, + prompt, + model: effective_model, + reasoning_effort: effective_reasoning_effort, + status, + } + .into(), + ) + .await; + let _ = result?; + let role_tag = role_name.unwrap_or(DEFAULT_ROLE_NAME); + turn.session_telemetry.counter( + "codex.multi_agent.spawn", + /*inc*/ 1, + &[("role", role_tag)], + ); + let task_name = new_agent_path.ok_or_else(|| { + FunctionCallError::RespondToModel( + "spawned agent is missing a canonical task name".to_string(), + ) + })?; - Ok(SpawnAgentResult { - agent_id: None, - task_name, - nickname, + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(SpawnAgentResult { + agent_id: None, + task_name, + nickname, + }), + }) }) } } diff --git a/codex-rs/core/src/tools/handlers/multi_agents_v2/wait.rs b/codex-rs/core/src/tools/handlers/multi_agents_v2/wait.rs index bf3ad8e93409..702aa5cbd479 100644 --- a/codex-rs/core/src/tools/handlers/multi_agents_v2/wait.rs +++ b/codex-rs/core/src/tools/handlers/multi_agents_v2/wait.rs @@ -6,10 +6,7 @@ use tokio::time::timeout_at; pub(crate) struct Handler; -#[async_trait] impl ToolHandler for Handler { - type Output = WaitAgentResult; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -18,59 +15,69 @@ impl ToolHandler for Handler { matches!(payload, ToolPayload::Function { .. }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - payload, - call_id, - .. - } = invocation; - let arguments = function_arguments(payload)?; - let args: WaitArgs = parse_arguments(&arguments)?; - let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS); - let timeout_ms = match timeout_ms { - ms if ms <= 0 => { - return Err(FunctionCallError::RespondToModel( - "timeout_ms must be greater than zero".to_owned(), - )); - } - ms => ms.clamp(MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS), - }; - - let mut mailbox_seq_rx = session.subscribe_mailbox_seq(); - - session - .send_event( - &turn, - CollabWaitingBeginEvent { - sender_thread_id: session.conversation_id, - receiver_thread_ids: Vec::new(), - receiver_agents: Vec::new(), - call_id: call_id.clone(), - } - .into(), - ) - .await; - - let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64); - let timed_out = !wait_for_mailbox_change(&mut mailbox_seq_rx, deadline).await; - let result = WaitAgentResult::from_timed_out(timed_out); - - session - .send_event( - &turn, - CollabWaitingEndEvent { - sender_thread_id: session.conversation_id, - call_id, - agent_statuses: Vec::new(), - statuses: HashMap::new(), + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + payload, + call_id, + .. + } = invocation; + let payload_for_result = payload.clone(); + let arguments = function_arguments(payload)?; + let args: WaitArgs = parse_arguments(&arguments)?; + let timeout_ms = args.timeout_ms.unwrap_or(DEFAULT_WAIT_TIMEOUT_MS); + let timeout_ms = match timeout_ms { + ms if ms <= 0 => { + return Err(FunctionCallError::RespondToModel( + "timeout_ms must be greater than zero".to_owned(), + )); } - .into(), - ) - .await; - - Ok(result) + ms => ms.clamp(MIN_WAIT_TIMEOUT_MS, MAX_WAIT_TIMEOUT_MS), + }; + + let mut mailbox_seq_rx = session.subscribe_mailbox_seq(); + + session + .send_event( + &turn, + CollabWaitingBeginEvent { + sender_thread_id: session.conversation_id, + receiver_thread_ids: Vec::new(), + receiver_agents: Vec::new(), + call_id: call_id.clone(), + } + .into(), + ) + .await; + + let deadline = Instant::now() + Duration::from_millis(timeout_ms as u64); + let timed_out = !wait_for_mailbox_change(&mut mailbox_seq_rx, deadline).await; + let result = WaitAgentResult::from_timed_out(timed_out); + + session + .send_event( + &turn, + CollabWaitingEndEvent { + sender_thread_id: session.conversation_id, + call_id: call_id.clone(), + agent_statuses: Vec::new(), + statuses: HashMap::new(), + } + .into(), + ) + .await; + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(result), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/plan.rs b/codex-rs/core/src/tools/handlers/plan.rs index ce0b98b1cb93..8bf3fa3ec57c 100644 --- a/codex-rs/core/src/tools/handlers/plan.rs +++ b/codex-rs/core/src/tools/handlers/plan.rs @@ -4,14 +4,15 @@ use crate::function_tool::FunctionCallError; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; -use async_trait::async_trait; use codex_protocol::config_types::ModeKind; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseInputItem; use codex_protocol::plan_tool::UpdatePlanArgs; use codex_protocol::protocol::EventMsg; +use futures::future::BoxFuture; use serde_json::Value as JsonValue; pub struct PlanHandler; @@ -44,35 +45,42 @@ impl ToolOutput for PlanToolOutput { } } -#[async_trait] impl ToolHandler for PlanHandler { - type Output = PlanToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - call_id, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "update_plan handler received unsupported payload".to_string(), - )); - } - }; - - handle_update_plan(session.as_ref(), turn.as_ref(), arguments, call_id).await?; - - Ok(PlanToolOutput) + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + call_id, + payload, + .. + } = invocation; + + let payload_for_result = payload.clone(); + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "update_plan handler received unsupported payload".to_string(), + )); + } + }; + + handle_update_plan(session.as_ref(), turn.as_ref(), arguments, call_id.clone()).await?; + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(PlanToolOutput), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/request_permissions.rs b/codex-rs/core/src/tools/handlers/request_permissions.rs index e0755deac7a9..9a5edc912c69 100644 --- a/codex-rs/core/src/tools/handlers/request_permissions.rs +++ b/codex-rs/core/src/tools/handlers/request_permissions.rs @@ -1,69 +1,77 @@ -use async_trait::async_trait; use codex_protocol::request_permissions::RequestPermissionsArgs; use codex_sandboxing::policy_transforms::normalize_additional_permissions; +use futures::future::BoxFuture; use crate::function_tool::FunctionCallError; use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; use crate::tools::handlers::parse_arguments_with_base_path; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; pub struct RequestPermissionsHandler; -#[async_trait] impl ToolHandler for RequestPermissionsHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - call_id, - payload, - .. - } = invocation; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + call_id, + payload, + .. + } = invocation; + + let payload_for_result = payload.clone(); + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "request_permissions handler received unsupported payload".to_string(), + )); + } + }; - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { + let mut args: RequestPermissionsArgs = + parse_arguments_with_base_path(&arguments, turn.cwd.as_path())?; + args.permissions = normalize_additional_permissions(args.permissions.into()) + .map(codex_protocol::request_permissions::RequestPermissionProfile::from) + .map_err(FunctionCallError::RespondToModel)?; + if args.permissions.is_empty() { return Err(FunctionCallError::RespondToModel( - "request_permissions handler received unsupported payload".to_string(), + "request_permissions requires at least one permission".to_string(), )); } - }; - let mut args: RequestPermissionsArgs = - parse_arguments_with_base_path(&arguments, turn.cwd.as_path())?; - args.permissions = normalize_additional_permissions(args.permissions.into()) - .map(codex_protocol::request_permissions::RequestPermissionProfile::from) - .map_err(FunctionCallError::RespondToModel)?; - if args.permissions.is_empty() { - return Err(FunctionCallError::RespondToModel( - "request_permissions requires at least one permission".to_string(), - )); - } + let response = session + .request_permissions(turn.as_ref(), call_id.clone(), args) + .await + .ok_or_else(|| { + FunctionCallError::RespondToModel( + "request_permissions was cancelled before receiving a response".to_string(), + ) + })?; - let response = session - .request_permissions(turn.as_ref(), call_id, args) - .await - .ok_or_else(|| { - FunctionCallError::RespondToModel( - "request_permissions was cancelled before receiving a response".to_string(), - ) + let content = serde_json::to_string(&response).map_err(|err| { + FunctionCallError::Fatal(format!( + "failed to serialize request_permissions response: {err}" + )) })?; - let content = serde_json::to_string(&response).map_err(|err| { - FunctionCallError::Fatal(format!( - "failed to serialize request_permissions response: {err}" - )) - })?; - - Ok(FunctionToolOutput::from_text(content, Some(true))) + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(FunctionToolOutput::from_text(content, Some(true))), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/request_user_input.rs b/codex-rs/core/src/tools/handlers/request_user_input.rs index f12defd0e7e9..f1093f12cc8c 100644 --- a/codex-rs/core/src/tools/handlers/request_user_input.rs +++ b/codex-rs/core/src/tools/handlers/request_user_input.rs @@ -3,69 +3,77 @@ use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; use crate::tools::handlers::parse_arguments; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; -use async_trait::async_trait; use codex_protocol::request_user_input::RequestUserInputArgs; use codex_tools::REQUEST_USER_INPUT_TOOL_NAME; use codex_tools::normalize_request_user_input_args; use codex_tools::request_user_input_unavailable_message; +use futures::future::BoxFuture; pub struct RequestUserInputHandler { pub default_mode_request_user_input: bool, } -#[async_trait] impl ToolHandler for RequestUserInputHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - call_id, - payload, - .. - } = invocation; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + call_id, + payload, + .. + } = invocation; + + let payload_for_result = payload.clone(); + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel(format!( + "{REQUEST_USER_INPUT_TOOL_NAME} handler received unsupported payload" + ))); + } + }; - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel(format!( - "{REQUEST_USER_INPUT_TOOL_NAME} handler received unsupported payload" - ))); + let mode = session.collaboration_mode().await.mode; + if let Some(message) = + request_user_input_unavailable_message(mode, self.default_mode_request_user_input) + { + return Err(FunctionCallError::RespondToModel(message)); } - }; - let mode = session.collaboration_mode().await.mode; - if let Some(message) = - request_user_input_unavailable_message(mode, self.default_mode_request_user_input) - { - return Err(FunctionCallError::RespondToModel(message)); - } + let args: RequestUserInputArgs = parse_arguments(&arguments)?; + let args = normalize_request_user_input_args(args) + .map_err(FunctionCallError::RespondToModel)?; + let response = session + .request_user_input(turn.as_ref(), call_id.clone(), args) + .await + .ok_or_else(|| { + FunctionCallError::RespondToModel(format!( + "{REQUEST_USER_INPUT_TOOL_NAME} was cancelled before receiving a response" + )) + })?; - let args: RequestUserInputArgs = parse_arguments(&arguments)?; - let args = - normalize_request_user_input_args(args).map_err(FunctionCallError::RespondToModel)?; - let response = session - .request_user_input(turn.as_ref(), call_id, args) - .await - .ok_or_else(|| { - FunctionCallError::RespondToModel(format!( - "{REQUEST_USER_INPUT_TOOL_NAME} was cancelled before receiving a response" + let content = serde_json::to_string(&response).map_err(|err| { + FunctionCallError::Fatal(format!( + "failed to serialize {REQUEST_USER_INPUT_TOOL_NAME} response: {err}" )) })?; - let content = serde_json::to_string(&response).map_err(|err| { - FunctionCallError::Fatal(format!( - "failed to serialize {REQUEST_USER_INPUT_TOOL_NAME} response: {err}" - )) - })?; - - Ok(FunctionToolOutput::from_text(content, Some(true))) + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(FunctionToolOutput::from_text(content, Some(true))), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/shell.rs b/codex-rs/core/src/tools/handlers/shell.rs index 34cdd5b976e6..cbf7ed9ceefa 100644 --- a/codex-rs/core/src/tools/handlers/shell.rs +++ b/codex-rs/core/src/tools/handlers/shell.rs @@ -1,4 +1,3 @@ -use async_trait::async_trait; use codex_protocol::ThreadId; use codex_protocol::models::ShellCommandToolCallParams; use codex_protocol::models::ShellToolCallParams; @@ -27,6 +26,7 @@ use crate::tools::handlers::parse_arguments; use crate::tools::handlers::parse_arguments_with_base_path; use crate::tools::handlers::resolve_workdir_base_path; use crate::tools::orchestrator::ToolOrchestrator; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::PostToolUsePayload; use crate::tools::registry::PreToolUsePayload; use crate::tools::registry::ToolHandler; @@ -40,6 +40,7 @@ use codex_protocol::models::PermissionProfile; use codex_protocol::protocol::ExecCommandSource; use codex_shell_command::is_safe_command::is_known_safe_command; use codex_tools::ShellCommandBackendConfig; +use futures::future::BoxFuture; pub struct ShellHandler; @@ -178,10 +179,7 @@ impl From for ShellCommandHandler { } } -#[async_trait] impl ToolHandler for ShellHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -193,7 +191,7 @@ impl ToolHandler for ShellHandler { ) } - async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { + fn is_mutating(&self, invocation: &ToolInvocation) -> bool { match &invocation.payload { ToolPayload::Function { arguments } => { serde_json::from_str::(arguments) @@ -222,67 +220,78 @@ impl ToolHandler for ShellHandler { }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tracker, - call_id, - tool_name, - payload, - .. - } = invocation; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + tracker, + call_id, + tool_name, + payload, + .. + } = invocation; + let payload_for_result = payload.clone(); + + let output = match payload { + ToolPayload::Function { arguments } => { + let cwd = resolve_workdir_base_path(&arguments, turn.cwd.as_path())?; + let params: ShellToolCallParams = + parse_arguments_with_base_path(&arguments, cwd.as_path())?; + let prefix_rule = params.prefix_rule.clone(); + let exec_params = + Self::to_exec_params(¶ms, turn.as_ref(), session.conversation_id); + Self::run_exec_like(RunExecLikeArgs { + tool_name: tool_name.clone(), + exec_params, + additional_permissions: params.additional_permissions.clone(), + prefix_rule, + session, + turn, + tracker, + call_id: call_id.clone(), + freeform: false, + shell_runtime_backend: ShellRuntimeBackend::Generic, + }) + .await? + } + ToolPayload::LocalShell { params } => { + let exec_params = + Self::to_exec_params(¶ms, turn.as_ref(), session.conversation_id); + Self::run_exec_like(RunExecLikeArgs { + tool_name: tool_name.clone(), + exec_params, + additional_permissions: None, + prefix_rule: None, + session, + turn, + tracker, + call_id: call_id.clone(), + freeform: false, + shell_runtime_backend: ShellRuntimeBackend::Generic, + }) + .await? + } + _ => { + return Err(FunctionCallError::RespondToModel(format!( + "unsupported payload for shell handler: {tool_name}" + ))); + } + }; - match payload { - ToolPayload::Function { arguments } => { - let cwd = resolve_workdir_base_path(&arguments, turn.cwd.as_path())?; - let params: ShellToolCallParams = - parse_arguments_with_base_path(&arguments, cwd.as_path())?; - let prefix_rule = params.prefix_rule.clone(); - let exec_params = - Self::to_exec_params(¶ms, turn.as_ref(), session.conversation_id); - Self::run_exec_like(RunExecLikeArgs { - tool_name: tool_name.clone(), - exec_params, - additional_permissions: params.additional_permissions.clone(), - prefix_rule, - session, - turn, - tracker, - call_id, - freeform: false, - shell_runtime_backend: ShellRuntimeBackend::Generic, - }) - .await - } - ToolPayload::LocalShell { params } => { - let exec_params = - Self::to_exec_params(¶ms, turn.as_ref(), session.conversation_id); - Self::run_exec_like(RunExecLikeArgs { - tool_name: tool_name.clone(), - exec_params, - additional_permissions: None, - prefix_rule: None, - session, - turn, - tracker, - call_id, - freeform: false, - shell_runtime_backend: ShellRuntimeBackend::Generic, - }) - .await - } - _ => Err(FunctionCallError::RespondToModel(format!( - "unsupported payload for shell handler: {tool_name}" - ))), - } + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(output), + }) + }) } } -#[async_trait] impl ToolHandler for ShellCommandHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -291,7 +300,7 @@ impl ToolHandler for ShellCommandHandler { matches!(payload, ToolPayload::Function { .. }) } - async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { + fn is_mutating(&self, invocation: &ToolInvocation) -> bool { let ToolPayload::Function { arguments } = &invocation.payload else { return true; }; @@ -330,55 +339,67 @@ impl ToolHandler for ShellCommandHandler { }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tracker, - call_id, - tool_name, - payload, - .. - } = invocation; - - let ToolPayload::Function { arguments } = payload else { - return Err(FunctionCallError::RespondToModel(format!( - "unsupported payload for shell_command handler: {tool_name}" - ))); - }; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + tracker, + call_id, + tool_name, + payload, + .. + } = invocation; + let payload_for_result = payload.clone(); + + let ToolPayload::Function { arguments } = payload else { + return Err(FunctionCallError::RespondToModel(format!( + "unsupported payload for shell_command handler: {tool_name}" + ))); + }; + + let cwd = resolve_workdir_base_path(&arguments, turn.cwd.as_path())?; + let params: ShellCommandToolCallParams = + parse_arguments_with_base_path(&arguments, cwd.as_path())?; + let workdir = turn.resolve_path(params.workdir.clone()); + maybe_emit_implicit_skill_invocation( + session.as_ref(), + turn.as_ref(), + ¶ms.command, + &workdir, + ) + .await; + let prefix_rule = params.prefix_rule.clone(); + let exec_params = Self::to_exec_params( + ¶ms, + session.as_ref(), + turn.as_ref(), + session.conversation_id, + turn.tools_config.allow_login_shell, + )?; + let output = ShellHandler::run_exec_like(RunExecLikeArgs { + tool_name, + exec_params, + additional_permissions: params.additional_permissions.clone(), + prefix_rule, + session, + turn, + tracker, + call_id: call_id.clone(), + freeform: true, + shell_runtime_backend: self.shell_runtime_backend(), + }) + .await?; - let cwd = resolve_workdir_base_path(&arguments, turn.cwd.as_path())?; - let params: ShellCommandToolCallParams = - parse_arguments_with_base_path(&arguments, cwd.as_path())?; - let workdir = turn.resolve_path(params.workdir.clone()); - maybe_emit_implicit_skill_invocation( - session.as_ref(), - turn.as_ref(), - ¶ms.command, - &workdir, - ) - .await; - let prefix_rule = params.prefix_rule.clone(); - let exec_params = Self::to_exec_params( - ¶ms, - session.as_ref(), - turn.as_ref(), - session.conversation_id, - turn.tools_config.allow_login_shell, - )?; - ShellHandler::run_exec_like(RunExecLikeArgs { - tool_name, - exec_params, - additional_permissions: params.additional_permissions.clone(), - prefix_rule, - session, - turn, - tracker, - call_id, - freeform: true, - shell_runtime_backend: self.shell_runtime_backend(), + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(output), + }) }) - .await } } diff --git a/codex-rs/core/src/tools/handlers/test_sync.rs b/codex-rs/core/src/tools/handlers/test_sync.rs index 2d6e351488c5..5b86c0c4abb7 100644 --- a/codex-rs/core/src/tools/handlers/test_sync.rs +++ b/codex-rs/core/src/tools/handlers/test_sync.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use std::sync::OnceLock; use std::time::Duration; -use async_trait::async_trait; +use futures::future::BoxFuture; use serde::Deserialize; use tokio::sync::Barrier; use tokio::time::sleep; @@ -14,6 +14,7 @@ use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; use crate::tools::handlers::parse_arguments; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; @@ -54,45 +55,54 @@ fn barrier_map() -> &'static tokio::sync::Mutex> { BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new())) } -#[async_trait] impl ToolHandler for TestSyncHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { payload, .. } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "test_sync_tool handler received unsupported payload".to_string(), - )); - } - }; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + call_id, payload, .. + } = invocation; + let payload_for_result = payload.clone(); + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "test_sync_tool handler received unsupported payload".to_string(), + )); + } + }; - let args: TestSyncArgs = parse_arguments(&arguments)?; + let args: TestSyncArgs = parse_arguments(&arguments)?; - if let Some(delay) = args.sleep_before_ms - && delay > 0 - { - sleep(Duration::from_millis(delay)).await; - } + if let Some(delay) = args.sleep_before_ms + && delay > 0 + { + sleep(Duration::from_millis(delay)).await; + } - if let Some(barrier) = args.barrier { - wait_on_barrier(barrier).await?; - } + if let Some(barrier) = args.barrier { + wait_on_barrier(barrier).await?; + } - if let Some(delay) = args.sleep_after_ms - && delay > 0 - { - sleep(Duration::from_millis(delay)).await; - } + if let Some(delay) = args.sleep_after_ms + && delay > 0 + { + sleep(Duration::from_millis(delay)).await; + } - Ok(FunctionToolOutput::from_text("ok".to_string(), Some(true))) + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(FunctionToolOutput::from_text("ok".to_string(), Some(true))), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/tool_search.rs b/codex-rs/core/src/tools/handlers/tool_search.rs index aa9d8e92f19e..d73463bd0543 100644 --- a/codex-rs/core/src/tools/handlers/tool_search.rs +++ b/codex-rs/core/src/tools/handlers/tool_search.rs @@ -2,9 +2,9 @@ use crate::function_tool::FunctionCallError; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; use crate::tools::context::ToolSearchOutput; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; -use async_trait::async_trait; use bm25::Document; use bm25::Language; use bm25::SearchEngineBuilder; @@ -13,6 +13,7 @@ use codex_tools::TOOL_SEARCH_DEFAULT_LIMIT; use codex_tools::TOOL_SEARCH_TOOL_NAME; use codex_tools::ToolSearchResultSource; use codex_tools::collect_tool_search_output_tools; +use futures::future::BoxFuture; use std::collections::HashMap; pub struct ToolSearchHandler { @@ -25,78 +26,85 @@ impl ToolSearchHandler { } } -#[async_trait] impl ToolHandler for ToolSearchHandler { - type Output = ToolSearchOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn handle( + fn handle( &self, invocation: ToolInvocation, - ) -> Result { - let ToolInvocation { payload, .. } = invocation; - - let args = match payload { - ToolPayload::ToolSearch { arguments } => arguments, - _ => { - return Err(FunctionCallError::Fatal(format!( - "{TOOL_SEARCH_TOOL_NAME} handler received unsupported payload" - ))); + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + call_id, payload, .. + } = invocation; + + let payload_for_result = payload.clone(); + let args = match payload { + ToolPayload::ToolSearch { arguments } => arguments, + _ => { + return Err(FunctionCallError::Fatal(format!( + "{TOOL_SEARCH_TOOL_NAME} handler received unsupported payload" + ))); + } + }; + + let query = args.query.trim(); + if query.is_empty() { + return Err(FunctionCallError::RespondToModel( + "query must not be empty".to_string(), + )); + } + let limit = args.limit.unwrap_or(TOOL_SEARCH_DEFAULT_LIMIT); + + if limit == 0 { + return Err(FunctionCallError::RespondToModel( + "limit must be greater than zero".to_string(), + )); } - }; - - let query = args.query.trim(); - if query.is_empty() { - return Err(FunctionCallError::RespondToModel( - "query must not be empty".to_string(), - )); - } - let limit = args.limit.unwrap_or(TOOL_SEARCH_DEFAULT_LIMIT); - - if limit == 0 { - return Err(FunctionCallError::RespondToModel( - "limit must be greater than zero".to_string(), - )); - } - - let mut entries: Vec<(String, ToolInfo)> = self.tools.clone().into_iter().collect(); - entries.sort_by(|a, b| a.0.cmp(&b.0)); - - if entries.is_empty() { - return Ok(ToolSearchOutput { tools: Vec::new() }); - } - - let documents: Vec> = entries - .iter() - .enumerate() - .map(|(idx, (name, info))| Document::new(idx, build_search_text(name, info))) - .collect(); - let search_engine = - SearchEngineBuilder::::with_documents(Language::English, documents).build(); - let results = search_engine.search(query, limit); - - let tools = collect_tool_search_output_tools( - results - .into_iter() - .filter_map(|result| entries.get(result.document.id)) - .map(|(_name, tool)| ToolSearchResultSource { - tool_namespace: tool.tool_namespace.as_str(), - tool_name: tool.tool_name.as_str(), - tool: &tool.tool, - connector_name: tool.connector_name.as_deref(), - connector_description: tool.connector_description.as_deref(), - }), - ) - .map_err(|err| { - FunctionCallError::Fatal(format!( - "failed to encode {TOOL_SEARCH_TOOL_NAME} output: {err}" - )) - })?; - - Ok(ToolSearchOutput { tools }) + + let mut entries: Vec<(String, ToolInfo)> = self.tools.clone().into_iter().collect(); + entries.sort_by(|a, b| a.0.cmp(&b.0)); + + let tools = if entries.is_empty() { + Vec::new() + } else { + let documents: Vec> = entries + .iter() + .enumerate() + .map(|(idx, (name, info))| Document::new(idx, build_search_text(name, info))) + .collect(); + let search_engine = + SearchEngineBuilder::::with_documents(Language::English, documents) + .build(); + let results = search_engine.search(query, limit); + + collect_tool_search_output_tools( + results + .into_iter() + .filter_map(|result| entries.get(result.document.id)) + .map(|(_name, tool)| ToolSearchResultSource { + tool_namespace: tool.tool_namespace.as_str(), + tool_name: tool.tool_name.as_str(), + tool: &tool.tool, + connector_name: tool.connector_name.as_deref(), + connector_description: tool.connector_description.as_deref(), + }), + ) + .map_err(|err| { + FunctionCallError::Fatal(format!( + "failed to encode {TOOL_SEARCH_TOOL_NAME} output: {err}" + )) + })? + }; + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(ToolSearchOutput { tools }), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/tool_suggest.rs b/codex-rs/core/src/tools/handlers/tool_suggest.rs index 45c77a350c3b..40a4c68d5f0d 100644 --- a/codex-rs/core/src/tools/handlers/tool_suggest.rs +++ b/codex-rs/core/src/tools/handlers/tool_suggest.rs @@ -1,6 +1,5 @@ use std::collections::HashSet; -use async_trait::async_trait; use codex_app_server_protocol::AppInfo; use codex_mcp::mcp::CODEX_APPS_MCP_SERVER_NAME; use codex_rmcp_client::ElicitationAction; @@ -23,136 +22,145 @@ use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; use crate::tools::handlers::parse_arguments; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; +use futures::future::BoxFuture; pub struct ToolSuggestHandler; -#[async_trait] impl ToolHandler for ToolSuggestHandler { - type Output = FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - payload, - session, - turn, - call_id, - .. - } = invocation; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + payload, + session, + turn, + call_id, + .. + } = invocation; + let payload_for_result = payload.clone(); - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::Fatal(format!( - "{TOOL_SUGGEST_TOOL_NAME} handler received unsupported payload" - ))); - } - }; + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::Fatal(format!( + "{TOOL_SUGGEST_TOOL_NAME} handler received unsupported payload" + ))); + } + }; - let args: ToolSuggestArgs = parse_arguments(&arguments)?; - let suggest_reason = args.suggest_reason.trim(); - if suggest_reason.is_empty() { - return Err(FunctionCallError::RespondToModel( - "suggest_reason must not be empty".to_string(), - )); - } - if args.action_type != DiscoverableToolAction::Install { - return Err(FunctionCallError::RespondToModel( - "tool suggestions currently support only action_type=\"install\"".to_string(), - )); - } - if args.tool_type == DiscoverableToolType::Plugin - && turn.app_server_client_name.as_deref() == Some("codex-tui") - { - return Err(FunctionCallError::RespondToModel( - "plugin tool suggestions are not available in codex-tui yet".to_string(), - )); - } + let args: ToolSuggestArgs = parse_arguments(&arguments)?; + let suggest_reason = args.suggest_reason.trim(); + if suggest_reason.is_empty() { + return Err(FunctionCallError::RespondToModel( + "suggest_reason must not be empty".to_string(), + )); + } + if args.action_type != DiscoverableToolAction::Install { + return Err(FunctionCallError::RespondToModel( + "tool suggestions currently support only action_type=\"install\"".to_string(), + )); + } + if args.tool_type == DiscoverableToolType::Plugin + && turn.app_server_client_name.as_deref() == Some("codex-tui") + { + return Err(FunctionCallError::RespondToModel( + "plugin tool suggestions are not available in codex-tui yet".to_string(), + )); + } - let auth = session.services.auth_manager.auth().await; - let manager = session.services.mcp_connection_manager.read().await; - let mcp_tools = manager.list_all_tools().await; - drop(manager); - let accessible_connectors = connectors::with_app_enabled_state( - connectors::accessible_connectors_from_mcp_tools(&mcp_tools), - &turn.config, - ); - let discoverable_tools = connectors::list_tool_suggest_discoverable_tools_with_auth( - &turn.config, - auth.as_ref(), - &accessible_connectors, - ) - .await - .map(|discoverable_tools| { - filter_tool_suggest_discoverable_tools_for_client( - discoverable_tools, - turn.app_server_client_name.as_deref(), + let auth = session.services.auth_manager.auth().await; + let manager = session.services.mcp_connection_manager.read().await; + let mcp_tools = manager.list_all_tools().await; + drop(manager); + let accessible_connectors = connectors::with_app_enabled_state( + connectors::accessible_connectors_from_mcp_tools(&mcp_tools), + &turn.config, + ); + let discoverable_tools = connectors::list_tool_suggest_discoverable_tools_with_auth( + &turn.config, + auth.as_ref(), + &accessible_connectors, ) - }) - .map_err(|err| { - FunctionCallError::RespondToModel(format!( - "tool suggestions are unavailable right now: {err}" - )) - })?; - - let tool = discoverable_tools - .into_iter() - .find(|tool| tool.tool_type() == args.tool_type && tool.id() == args.tool_id) - .ok_or_else(|| { + .await + .map(|discoverable_tools| { + filter_tool_suggest_discoverable_tools_for_client( + discoverable_tools, + turn.app_server_client_name.as_deref(), + ) + }) + .map_err(|err| { FunctionCallError::RespondToModel(format!( - "tool_id must match one of the discoverable tools exposed by {TOOL_SUGGEST_TOOL_NAME}" + "tool suggestions are unavailable right now: {err}" )) })?; - let request_id = RequestId::String(format!("tool_suggestion_{call_id}").into()); - let params = build_tool_suggestion_elicitation_request( - CODEX_APPS_MCP_SERVER_NAME, - session.conversation_id.to_string(), - turn.sub_id.clone(), - &args, - suggest_reason, - &tool, - ); - let response = session - .request_mcp_server_elicitation(turn.as_ref(), request_id, params) - .await; - let user_confirmed = response - .as_ref() - .is_some_and(|response| response.action == ElicitationAction::Accept); - - let completed = if user_confirmed { - verify_tool_suggestion_completed(&session, &turn, &tool, auth.as_ref()).await - } else { - false - }; + let tool = discoverable_tools + .into_iter() + .find(|tool| tool.tool_type() == args.tool_type && tool.id() == args.tool_id) + .ok_or_else(|| { + FunctionCallError::RespondToModel(format!( + "tool_id must match one of the discoverable tools exposed by {TOOL_SUGGEST_TOOL_NAME}" + )) + })?; - if completed && let DiscoverableTool::Connector(connector) = &tool { - session - .merge_connector_selection(HashSet::from([connector.id.clone()])) + let request_id = RequestId::String(format!("tool_suggestion_{call_id}").into()); + let params = build_tool_suggestion_elicitation_request( + CODEX_APPS_MCP_SERVER_NAME, + session.conversation_id.to_string(), + turn.sub_id.clone(), + &args, + suggest_reason, + &tool, + ); + let response = session + .request_mcp_server_elicitation(turn.as_ref(), request_id, params) .await; - } + let user_confirmed = response + .as_ref() + .is_some_and(|response| response.action == ElicitationAction::Accept); - let content = serde_json::to_string(&ToolSuggestResult { - completed, - user_confirmed, - tool_type: args.tool_type, - action_type: args.action_type, - tool_id: tool.id().to_string(), - tool_name: tool.name().to_string(), - suggest_reason: suggest_reason.to_string(), - }) - .map_err(|err| { - FunctionCallError::Fatal(format!( - "failed to serialize {TOOL_SUGGEST_TOOL_NAME} response: {err}" - )) - })?; + let completed = if user_confirmed { + verify_tool_suggestion_completed(&session, &turn, &tool, auth.as_ref()).await + } else { + false + }; + + if completed && let DiscoverableTool::Connector(connector) = &tool { + session + .merge_connector_selection(HashSet::from([connector.id.clone()])) + .await; + } - Ok(FunctionToolOutput::from_text(content, Some(true))) + let content = serde_json::to_string(&ToolSuggestResult { + completed, + user_confirmed, + tool_type: args.tool_type, + action_type: args.action_type, + tool_id: tool.id().to_string(), + tool_name: tool.name().to_string(), + suggest_reason: suggest_reason.to_string(), + }) + .map_err(|err| { + FunctionCallError::Fatal(format!( + "failed to serialize {TOOL_SUGGEST_TOOL_NAME} response: {err}" + )) + })?; + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(FunctionToolOutput::from_text(content, Some(true))), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/unified_exec.rs b/codex-rs/core/src/tools/handlers/unified_exec.rs index 98d7d717fced..05b6cc1cfb69 100644 --- a/codex-rs/core/src/tools/handlers/unified_exec.rs +++ b/codex-rs/core/src/tools/handlers/unified_exec.rs @@ -14,6 +14,7 @@ use crate::tools::handlers::normalize_and_validate_additional_permissions; use crate::tools::handlers::parse_arguments; use crate::tools::handlers::parse_arguments_with_base_path; use crate::tools::handlers::resolve_workdir_base_path; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::PostToolUsePayload; use crate::tools::registry::PreToolUsePayload; use crate::tools::registry::ToolHandler; @@ -22,7 +23,6 @@ use crate::unified_exec::ExecCommandRequest; use crate::unified_exec::UnifiedExecContext; use crate::unified_exec::UnifiedExecProcessManager; use crate::unified_exec::WriteStdinRequest; -use async_trait::async_trait; use codex_features::Feature; use codex_otel::SessionTelemetry; use codex_otel::metrics::names::TOOL_CALL_UNIFIED_EXEC_METRIC; @@ -31,6 +31,7 @@ use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::TerminalInteractionEvent; use codex_shell_command::is_safe_command::is_known_safe_command; use codex_tools::UnifiedExecShellMode; +use futures::future::BoxFuture; use serde::Deserialize; use std::path::PathBuf; use std::sync::Arc; @@ -86,10 +87,7 @@ fn default_tty() -> bool { false } -#[async_trait] impl ToolHandler for UnifiedExecHandler { - type Output = ExecCommandToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } @@ -98,7 +96,7 @@ impl ToolHandler for UnifiedExecHandler { matches!(payload, ToolPayload::Function { .. }) } - async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { + fn is_mutating(&self, invocation: &ToolInvocation) -> bool { let ToolPayload::Function { arguments } = &invocation.payload else { tracing::error!( "This should never happen, invocation payload is wrong: {:?}", @@ -158,211 +156,226 @@ impl ToolHandler for UnifiedExecHandler { }) } - async fn handle(&self, invocation: ToolInvocation) -> Result { - let ToolInvocation { - session, - turn, - tracker, - call_id, - tool_name, - payload, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { - return Err(FunctionCallError::RespondToModel( - "unified_exec handler received unsupported payload".to_string(), - )); - } - }; - - let manager: &UnifiedExecProcessManager = &session.services.unified_exec_manager; - let context = UnifiedExecContext::new(session.clone(), turn.clone(), call_id.clone()); - - let response = match tool_name.as_str() { - "exec_command" => { - let cwd = resolve_workdir_base_path(&arguments, context.turn.cwd.as_path())?; - let args: ExecCommandArgs = - parse_arguments_with_base_path(&arguments, cwd.as_path())?; - let workdir = context.turn.resolve_path(args.workdir.clone()); - maybe_emit_implicit_skill_invocation( - session.as_ref(), - context.turn.as_ref(), - &args.cmd, - &workdir, - ) - .await; - let process_id = manager.allocate_process_id().await; - let command = get_command( - &args, - session.user_shell(), - &turn.tools_config.unified_exec_shell_mode, - turn.tools_config.allow_login_shell, - ) - .map_err(FunctionCallError::RespondToModel)?; - let command_for_display = codex_shell_command::parse_command::shlex_join(&command); - - let ExecCommandArgs { - workdir, - tty, - yield_time_ms, - max_output_tokens, - sandbox_permissions, - additional_permissions, - justification, - prefix_rule, - .. - } = args; - - let exec_permission_approvals_enabled = - session.features().enabled(Feature::ExecPermissionApprovals); - let requested_additional_permissions = additional_permissions.clone(); - let effective_additional_permissions = apply_granted_turn_permissions( - context.session.as_ref(), - sandbox_permissions, - additional_permissions, - ) - .await; - let additional_permissions_allowed = exec_permission_approvals_enabled - || (session.features().enabled(Feature::RequestPermissionsTool) - && effective_additional_permissions.permissions_preapproved); - - // Sticky turn permissions have already been approved, so they should - // continue through the normal exec approval flow for the command. - if effective_additional_permissions - .sandbox_permissions - .requests_sandbox_override() - && !effective_additional_permissions.permissions_preapproved - && !matches!( - context.turn.approval_policy.value(), - codex_protocol::protocol::AskForApproval::OnRequest - ) - { - let approval_policy = context.turn.approval_policy.value(); - manager.release_process_id(process_id).await; - return Err(FunctionCallError::RespondToModel(format!( - "approval policy is {approval_policy:?}; reject command — you cannot ask for escalated permissions if the approval policy is {approval_policy:?}" - ))); + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + let ToolInvocation { + session, + turn, + tracker, + call_id, + tool_name, + payload, + .. + } = invocation; + let payload_for_result = payload.clone(); + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "unified_exec handler received unsupported payload".to_string(), + )); } - - let workdir = workdir.filter(|value| !value.is_empty()); - - let workdir = workdir.map(|dir| context.turn.resolve_path(Some(dir))); - let cwd = workdir.clone().unwrap_or(cwd); - let normalized_additional_permissions = match implicit_granted_permissions( - sandbox_permissions, - requested_additional_permissions.as_ref(), - &effective_additional_permissions, - ) - .map_or_else( - || { - normalize_and_validate_additional_permissions( - additional_permissions_allowed, + }; + + let manager: &UnifiedExecProcessManager = &session.services.unified_exec_manager; + let context = UnifiedExecContext::new(session.clone(), turn.clone(), call_id.clone()); + + let response = match tool_name.as_str() { + "exec_command" => { + let cwd = resolve_workdir_base_path(&arguments, context.turn.cwd.as_path())?; + let args: ExecCommandArgs = + parse_arguments_with_base_path(&arguments, cwd.as_path())?; + let workdir = context.turn.resolve_path(args.workdir.clone()); + maybe_emit_implicit_skill_invocation( + session.as_ref(), + context.turn.as_ref(), + &args.cmd, + &workdir, + ) + .await; + let process_id = manager.allocate_process_id().await; + let command = get_command( + &args, + session.user_shell(), + &turn.tools_config.unified_exec_shell_mode, + turn.tools_config.allow_login_shell, + ) + .map_err(FunctionCallError::RespondToModel)?; + let command_for_display = + codex_shell_command::parse_command::shlex_join(&command); + + let ExecCommandArgs { + workdir, + tty, + yield_time_ms, + max_output_tokens, + sandbox_permissions, + additional_permissions, + justification, + prefix_rule, + .. + } = args; + + let exec_permission_approvals_enabled = + session.features().enabled(Feature::ExecPermissionApprovals); + let requested_additional_permissions = additional_permissions.clone(); + let effective_additional_permissions = apply_granted_turn_permissions( + context.session.as_ref(), + sandbox_permissions, + additional_permissions, + ) + .await; + let additional_permissions_allowed = exec_permission_approvals_enabled + || (session.features().enabled(Feature::RequestPermissionsTool) + && effective_additional_permissions.permissions_preapproved); + + // Sticky turn permissions have already been approved, so they should + // continue through the normal exec approval flow for the command. + if effective_additional_permissions + .sandbox_permissions + .requests_sandbox_override() + && !effective_additional_permissions.permissions_preapproved + && !matches!( context.turn.approval_policy.value(), - effective_additional_permissions.sandbox_permissions, - effective_additional_permissions.additional_permissions, - effective_additional_permissions.permissions_preapproved, - &cwd, + codex_protocol::protocol::AskForApproval::OnRequest ) - }, - |permissions| Ok(Some(permissions)), - ) { - Ok(normalized) => normalized, - Err(err) => { + { + let approval_policy = context.turn.approval_policy.value(); manager.release_process_id(process_id).await; - return Err(FunctionCallError::RespondToModel(err)); + return Err(FunctionCallError::RespondToModel(format!( + "approval policy is {approval_policy:?}; reject command — you cannot ask for escalated permissions if the approval policy is {approval_policy:?}" + ))); } - }; - - if let Some(output) = intercept_apply_patch( - &command, - &cwd, - Some(yield_time_ms), - context.session.clone(), - context.turn.clone(), - Some(&tracker), - &context.call_id, - tool_name.as_str(), - ) - .await? - { - manager.release_process_id(process_id).await; - return Ok(ExecCommandToolOutput { - event_call_id: String::new(), - chunk_id: String::new(), - wall_time: std::time::Duration::ZERO, - raw_output: output.into_text().into_bytes(), - max_output_tokens: None, - process_id: None, - exit_code: None, - original_token_count: None, - session_command: None, - }); - } - emit_unified_exec_tty_metric(&turn.session_telemetry, tty); - manager - .exec_command( - ExecCommandRequest { - command, - process_id, - yield_time_ms, - max_output_tokens, - workdir, - network: context.turn.network.clone(), - tty, - sandbox_permissions: effective_additional_permissions - .sandbox_permissions, - additional_permissions: normalized_additional_permissions, - additional_permissions_preapproved: effective_additional_permissions - .permissions_preapproved, - justification, - prefix_rule, + let workdir = workdir.filter(|value| !value.is_empty()); + + let workdir = workdir.map(|dir| context.turn.resolve_path(Some(dir))); + let cwd = workdir.clone().unwrap_or(cwd); + let normalized_additional_permissions = match implicit_granted_permissions( + sandbox_permissions, + requested_additional_permissions.as_ref(), + &effective_additional_permissions, + ) + .map_or_else( + || { + normalize_and_validate_additional_permissions( + additional_permissions_allowed, + context.turn.approval_policy.value(), + effective_additional_permissions.sandbox_permissions, + effective_additional_permissions.additional_permissions, + effective_additional_permissions.permissions_preapproved, + &cwd, + ) }, - &context, + |permissions| Ok(Some(permissions)), + ) { + Ok(normalized) => normalized, + Err(err) => { + manager.release_process_id(process_id).await; + return Err(FunctionCallError::RespondToModel(err)); + } + }; + + if let Some(output) = intercept_apply_patch( + &command, + &cwd, + Some(yield_time_ms), + context.session.clone(), + context.turn.clone(), + Some(&tracker), + &context.call_id, + tool_name.as_str(), ) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!( - "exec_command failed for `{command_for_display}`: {err:?}" - )) - })? - } - "write_stdin" => { - let args: WriteStdinArgs = parse_arguments(&arguments)?; - let response = manager - .write_stdin(WriteStdinRequest { - process_id: args.session_id, - input: &args.chars, - yield_time_ms: args.yield_time_ms, - max_output_tokens: args.max_output_tokens, - }) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!("write_stdin failed: {err}")) - })?; - - let interaction = TerminalInteractionEvent { - call_id: response.event_call_id.clone(), - process_id: args.session_id.to_string(), - stdin: args.chars.clone(), - }; - session - .send_event(turn.as_ref(), EventMsg::TerminalInteraction(interaction)) - .await; + .await? + { + manager.release_process_id(process_id).await; + return Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(ExecCommandToolOutput { + event_call_id: String::new(), + chunk_id: String::new(), + wall_time: std::time::Duration::ZERO, + raw_output: output.into_text().into_bytes(), + max_output_tokens: None, + process_id: None, + exit_code: None, + original_token_count: None, + session_command: None, + }), + }); + } - response - } - other => { - return Err(FunctionCallError::RespondToModel(format!( - "unsupported unified exec function {other}" - ))); - } - }; + emit_unified_exec_tty_metric(&turn.session_telemetry, tty); + manager + .exec_command( + ExecCommandRequest { + command, + process_id, + yield_time_ms, + max_output_tokens, + workdir, + network: context.turn.network.clone(), + tty, + sandbox_permissions: effective_additional_permissions + .sandbox_permissions, + additional_permissions: normalized_additional_permissions, + additional_permissions_preapproved: + effective_additional_permissions.permissions_preapproved, + justification, + prefix_rule, + }, + &context, + ) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "exec_command failed for `{command_for_display}`: {err:?}" + )) + })? + } + "write_stdin" => { + let args: WriteStdinArgs = parse_arguments(&arguments)?; + let response = manager + .write_stdin(WriteStdinRequest { + process_id: args.session_id, + input: &args.chars, + yield_time_ms: args.yield_time_ms, + max_output_tokens: args.max_output_tokens, + }) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!("write_stdin failed: {err}")) + })?; + + let interaction = TerminalInteractionEvent { + call_id: response.event_call_id.clone(), + process_id: args.session_id.to_string(), + stdin: args.chars.clone(), + }; + session + .send_event(turn.as_ref(), EventMsg::TerminalInteraction(interaction)) + .await; + + response + } + other => { + return Err(FunctionCallError::RespondToModel(format!( + "unsupported unified exec function {other}" + ))); + } + }; - Ok(response) + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(response), + }) + }) } } diff --git a/codex-rs/core/src/tools/handlers/view_image.rs b/codex-rs/core/src/tools/handlers/view_image.rs index e87638bba6ac..20e56ff412e4 100644 --- a/codex-rs/core/src/tools/handlers/view_image.rs +++ b/codex-rs/core/src/tools/handlers/view_image.rs @@ -1,4 +1,3 @@ -use async_trait::async_trait; use codex_protocol::models::FunctionCallOutputBody; use codex_protocol::models::FunctionCallOutputContentItem; use codex_protocol::models::FunctionCallOutputPayload; @@ -16,10 +15,12 @@ use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; use crate::tools::handlers::parse_arguments; +use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolHandler; use crate::tools::registry::ToolKind; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::ViewImageToolCallEvent; +use futures::future::BoxFuture; pub struct ViewImageHandler; @@ -37,127 +38,137 @@ enum ViewImageDetail { Original, } -#[async_trait] impl ToolHandler for ViewImageHandler { - type Output = ViewImageOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn handle(&self, invocation: ToolInvocation) -> Result { - if !invocation - .turn - .model_info - .input_modalities - .contains(&InputModality::Image) - { - return Err(FunctionCallError::RespondToModel( - VIEW_IMAGE_UNSUPPORTED_MESSAGE.to_string(), - )); - } - - let ToolInvocation { - session, - turn, - payload, - call_id, - .. - } = invocation; - - let arguments = match payload { - ToolPayload::Function { arguments } => arguments, - _ => { + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { + if !invocation + .turn + .model_info + .input_modalities + .contains(&InputModality::Image) + { return Err(FunctionCallError::RespondToModel( - "view_image handler received unsupported payload".to_string(), + VIEW_IMAGE_UNSUPPORTED_MESSAGE.to_string(), )); } - }; - let args: ViewImageArgs = parse_arguments(&arguments)?; - // `view_image` accepts only its documented detail values: omit - // `detail` for the default path or set it to `original`. - // Other string values remain invalid rather than being silently - // reinterpreted. - let detail = match args.detail.as_deref() { - None => None, - Some("original") => Some(ViewImageDetail::Original), - Some(detail) => { + let ToolInvocation { + session, + turn, + payload, + call_id, + .. + } = invocation; + let payload_for_result = payload.clone(); + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "view_image handler received unsupported payload".to_string(), + )); + } + }; + + let args: ViewImageArgs = parse_arguments(&arguments)?; + // `view_image` accepts only its documented detail values: omit + // `detail` for the default path or set it to `original`. + // Other string values remain invalid rather than being silently + // reinterpreted. + let detail = match args.detail.as_deref() { + None => None, + Some("original") => Some(ViewImageDetail::Original), + Some(detail) => { + return Err(FunctionCallError::RespondToModel(format!( + "view_image.detail only supports `original`; omit `detail` for default resized behavior, got `{detail}`" + ))); + } + }; + + let abs_path = + AbsolutePathBuf::try_from(turn.resolve_path(Some(args.path))).map_err(|error| { + FunctionCallError::RespondToModel(format!( + "unable to resolve image path: {error}" + )) + })?; + + let metadata = turn + .environment + .get_filesystem() + .get_metadata(&abs_path) + .await + .map_err(|error| { + FunctionCallError::RespondToModel(format!( + "unable to locate image at `{}`: {error}", + abs_path.display() + )) + })?; + + if !metadata.is_file { return Err(FunctionCallError::RespondToModel(format!( - "view_image.detail only supports `original`; omit `detail` for default resized behavior, got `{detail}`" + "image path `{}` is not a file", + abs_path.display() ))); } - }; - - let abs_path = - AbsolutePathBuf::try_from(turn.resolve_path(Some(args.path))).map_err(|error| { - FunctionCallError::RespondToModel(format!("unable to resolve image path: {error}")) - })?; - - let metadata = turn - .environment - .get_filesystem() - .get_metadata(&abs_path) - .await - .map_err(|error| { - FunctionCallError::RespondToModel(format!( - "unable to locate image at `{}`: {error}", - abs_path.display() - )) - })?; - - if !metadata.is_file { - return Err(FunctionCallError::RespondToModel(format!( - "image path `{}` is not a file", - abs_path.display() - ))); - } - let file_bytes = turn - .environment - .get_filesystem() - .read_file(&abs_path) - .await - .map_err(|error| { - FunctionCallError::RespondToModel(format!( - "unable to read image at `{}`: {error}", - abs_path.display() - )) - })?; - let event_path = abs_path.to_path_buf(); - - let can_request_original_detail = - can_request_original_image_detail(turn.features.get(), &turn.model_info); - let use_original_detail = - can_request_original_detail && matches!(detail, Some(ViewImageDetail::Original)); - let image_mode = if use_original_detail { - PromptImageMode::Original - } else { - PromptImageMode::ResizeToFit - }; - let image_detail = use_original_detail.then_some(ImageDetail::Original); - - let image = - load_for_prompt_bytes(abs_path.as_path(), file_bytes, image_mode).map_err(|error| { - FunctionCallError::RespondToModel(format!( - "unable to process image at `{}`: {error}", - abs_path.display() - )) - })?; - let image_url = image.into_data_url(); - - session - .send_event( - turn.as_ref(), - EventMsg::ViewImageToolCall(ViewImageToolCallEvent { - call_id, - path: event_path, + let file_bytes = turn + .environment + .get_filesystem() + .read_file(&abs_path) + .await + .map_err(|error| { + FunctionCallError::RespondToModel(format!( + "unable to read image at `{}`: {error}", + abs_path.display() + )) + })?; + let event_path = abs_path.to_path_buf(); + + let can_request_original_detail = + can_request_original_image_detail(turn.features.get(), &turn.model_info); + let use_original_detail = + can_request_original_detail && matches!(detail, Some(ViewImageDetail::Original)); + let image_mode = if use_original_detail { + PromptImageMode::Original + } else { + PromptImageMode::ResizeToFit + }; + let image_detail = use_original_detail.then_some(ImageDetail::Original); + + let image = load_for_prompt_bytes(abs_path.as_path(), file_bytes, image_mode).map_err( + |error| { + FunctionCallError::RespondToModel(format!( + "unable to process image at `{}`: {error}", + abs_path.display() + )) + }, + )?; + let image_url = image.into_data_url(); + + session + .send_event( + turn.as_ref(), + EventMsg::ViewImageToolCall(ViewImageToolCallEvent { + call_id: call_id.clone(), + path: event_path, + }), + ) + .await; + + Ok(AnyToolResult { + call_id, + payload: payload_for_result, + result: Box::new(ViewImageOutput { + image_url, + image_detail, }), - ) - .await; - - Ok(ViewImageOutput { - image_url, - image_detail, + }) }) } } diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index f1ebb38c06c8..8672e6967186 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -13,7 +13,6 @@ use crate::tools::context::FunctionToolOutput; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; -use async_trait::async_trait; use codex_hooks::HookEvent; use codex_hooks::HookEventAfterToolUse; use codex_hooks::HookPayload; @@ -26,6 +25,7 @@ use codex_protocol::protocol::SandboxPolicy; use codex_tools::ConfiguredToolSpec; use codex_tools::ToolSpec; use codex_utils_readiness::Readiness; +use futures::future::BoxFuture; use serde_json::Value; use tracing::warn; @@ -35,10 +35,7 @@ pub enum ToolKind { Mcp, } -#[async_trait] pub trait ToolHandler: Send + Sync { - type Output: ToolOutput + 'static; - fn kind(&self) -> ToolKind; fn matches_kind(&self, payload: &ToolPayload) -> bool { @@ -54,7 +51,7 @@ pub trait ToolHandler: Send + Sync { /// user (through file system, OS operations, ...). /// This function must remains defensive and return `true` if a doubt exist on the /// exact effect of a ToolInvocation. - async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { + fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { false } @@ -73,7 +70,10 @@ pub trait ToolHandler: Send + Sync { /// Perform the actual [ToolInvocation] and returns a [ToolOutput] containing /// the final output to return to the model. - async fn handle(&self, invocation: ToolInvocation) -> Result; + fn handle( + &self, + invocation: ToolInvocation, + ) -> BoxFuture<'_, Result>; } pub(crate) struct AnyToolResult { @@ -82,7 +82,26 @@ pub(crate) struct AnyToolResult { pub(crate) result: Box, } +impl std::fmt::Debug for AnyToolResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AnyToolResult") + .field("call_id", &self.call_id) + .field("payload", &self.payload) + .field("log_preview", &self.result.log_preview()) + .field("success", &self.result.success_for_logging()) + .finish() + } +} + impl AnyToolResult { + pub(crate) fn new(invocation: &ToolInvocation, result: impl ToolOutput + 'static) -> Self { + Self { + call_id: invocation.call_id.clone(), + payload: invocation.payload.clone(), + result: Box::new(result), + } + } + pub(crate) fn into_response(self) -> ResponseInputItem { let Self { call_id, @@ -101,79 +120,39 @@ impl AnyToolResult { } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct PreToolUsePayload { - pub(crate) command: String, -} - -#[derive(Debug, Clone, PartialEq)] -pub(crate) struct PostToolUsePayload { - pub(crate) command: String, - pub(crate) tool_response: Value, -} - -#[async_trait] -trait AnyToolHandler: Send + Sync { - fn matches_kind(&self, payload: &ToolPayload) -> bool; - - async fn is_mutating(&self, invocation: &ToolInvocation) -> bool; - - fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option; - - fn post_tool_use_payload( - &self, - call_id: &str, - payload: &ToolPayload, - result: &dyn ToolOutput, - ) -> Option; - - async fn handle_any( - &self, - invocation: ToolInvocation, - ) -> Result; -} - -#[async_trait] -impl AnyToolHandler for T -where - T: ToolHandler, -{ - fn matches_kind(&self, payload: &ToolPayload) -> bool { - ToolHandler::matches_kind(self, payload) +impl ToolOutput for AnyToolResult { + fn log_preview(&self) -> String { + self.result.log_preview() } - async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { - ToolHandler::is_mutating(self, invocation).await + fn success_for_logging(&self) -> bool { + self.result.success_for_logging() } - fn pre_tool_use_payload(&self, invocation: &ToolInvocation) -> Option { - ToolHandler::pre_tool_use_payload(self, invocation) + fn to_response_item(&self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem { + self.result.to_response_item(call_id, payload) } - fn post_tool_use_payload( - &self, - call_id: &str, - payload: &ToolPayload, - result: &dyn ToolOutput, - ) -> Option { - ToolHandler::post_tool_use_payload(self, call_id, payload, result) + fn post_tool_use_response(&self, call_id: &str, payload: &ToolPayload) -> Option { + self.result.post_tool_use_response(call_id, payload) } - async fn handle_any( - &self, - invocation: ToolInvocation, - ) -> Result { - let call_id = invocation.call_id.clone(); - let payload = invocation.payload.clone(); - let output = self.handle(invocation).await?; - Ok(AnyToolResult { - call_id, - payload, - result: Box::new(output), - }) + fn code_mode_result(&self, payload: &ToolPayload) -> serde_json::Value { + self.result.code_mode_result(payload) } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct PreToolUsePayload { + pub(crate) command: String, +} + +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct PostToolUsePayload { + pub(crate) command: String, + pub(crate) tool_response: Value, +} + pub(crate) fn tool_handler_key(tool_name: &str, namespace: Option<&str>) -> String { if let Some(namespace) = namespace { format!("{namespace}:{tool_name}") @@ -183,15 +162,15 @@ pub(crate) fn tool_handler_key(tool_name: &str, namespace: Option<&str>) -> Stri } pub struct ToolRegistry { - handlers: HashMap>, + handlers: HashMap>, } impl ToolRegistry { - fn new(handlers: HashMap>) -> Self { + fn new(handlers: HashMap>) -> Self { Self { handlers } } - fn handler(&self, name: &str, namespace: Option<&str>) -> Option> { + fn handler(&self, name: &str, namespace: Option<&str>) -> Option> { self.handlers .get(&tool_handler_key(name, namespace)) .map(Arc::clone) @@ -311,7 +290,7 @@ impl ToolRegistry { ))); } - let is_mutating = handler.is_mutating(&invocation).await; + let is_mutating = handler.is_mutating(&invocation); let response_cell = tokio::sync::Mutex::new(None); let invocation_for_tool = invocation.clone(); @@ -333,7 +312,7 @@ impl ToolRegistry { invocation_for_tool.turn.tool_call_gate.wait_ready().await; tracing::trace!("tool gate released"); } - match handler.handle_any(invocation_for_tool).await { + match handler.handle(invocation_for_tool).await { Ok(result) => { let preview = result.result.log_preview(); let success = result.result.success_for_logging(); @@ -438,7 +417,7 @@ impl ToolRegistry { } pub struct ToolRegistryBuilder { - handlers: HashMap>, + handlers: HashMap>, specs: Vec, } @@ -468,7 +447,7 @@ impl ToolRegistryBuilder { H: ToolHandler + 'static, { let name = name.into(); - let handler: Arc = handler; + let handler: Arc = handler; if self .handlers .insert(name.clone(), handler.clone()) diff --git a/codex-rs/core/src/tools/registry_tests.rs b/codex-rs/core/src/tools/registry_tests.rs index 687364872c96..e1b399fc9387 100644 --- a/codex-rs/core/src/tools/registry_tests.rs +++ b/codex-rs/core/src/tools/registry_tests.rs @@ -1,26 +1,26 @@ use super::*; -use async_trait::async_trait; +use futures::future::BoxFuture; use pretty_assertions::assert_eq; struct TestHandler; -#[async_trait] impl ToolHandler for TestHandler { - type Output = crate::tools::context::FunctionToolOutput; - fn kind(&self) -> ToolKind { ToolKind::Function } - async fn handle(&self, _invocation: ToolInvocation) -> Result { - unreachable!("test handler should not be invoked") + fn handle( + &self, + _invocation: ToolInvocation, + ) -> BoxFuture<'_, Result> { + Box::pin(async move { unreachable!("test handler should not be invoked") }) } } #[test] fn handler_looks_up_namespaced_aliases_explicitly() { - let plain_handler = Arc::new(TestHandler) as Arc; - let namespaced_handler = Arc::new(TestHandler) as Arc; + let plain_handler = Arc::new(TestHandler) as Arc; + let namespaced_handler = Arc::new(TestHandler) as Arc; let namespace = "mcp__codex_apps__gmail"; let tool_name = "gmail_get_recent_emails"; let namespaced_name = tool_handler_key(tool_name, Some(namespace)); diff --git a/codex-rs/core/src/tools/router_tests.rs b/codex-rs/core/src/tools/router_tests.rs index 641adb56de03..d60c0413f419 100644 --- a/codex-rs/core/src/tools/router_tests.rs +++ b/codex-rs/core/src/tools/router_tests.rs @@ -59,8 +59,7 @@ async fn js_repl_tools_only_blocks_direct_tool_calls() -> anyhow::Result<()> { ToolCallSource::Direct, ) .await - .err() - .expect("direct tool calls should be blocked"); + .expect_err("direct tool calls should be blocked"); let FunctionCallError::RespondToModel(message) = err else { panic!("expected RespondToModel, got {err:?}"); }; @@ -117,8 +116,7 @@ async fn js_repl_tools_only_allows_js_repl_source_calls() -> anyhow::Result<()> ToolCallSource::JsRepl, ) .await - .err() - .expect("shell call with empty args should fail"); + .expect_err("shell call with empty args should fail"); let message = err.to_string(); assert!( !message.contains("direct tool calls are disabled"),