From bccce0f2d8e94a961dfb90c47ed7855010118136 Mon Sep 17 00:00:00 2001 From: nicholasclark-openai Date: Tue, 17 Mar 2026 18:14:32 -0700 Subject: [PATCH 1/6] Forward tool call task headers to MCP HTTP requests Co-authored-by: Codex --- codex-rs/core/src/codex.rs | 3 +- codex-rs/core/src/codex_delegate.rs | 1 + codex-rs/core/src/mcp_connection_manager.rs | 137 ++++++++++++++---- .../core/src/mcp_connection_manager_tests.rs | 5 + codex-rs/core/src/mcp_tool_call.rs | 69 ++++++++- codex-rs/rmcp-client/src/rmcp_client.rs | 106 +++++++++++++- codex-rs/rmcp-client/tests/resources.rs | 1 + .../tests/streamable_http_recovery.rs | 2 + 8 files changed, 283 insertions(+), 41 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 9382401bacdf..45d7486248f1 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -3912,12 +3912,13 @@ impl Session { tool: &str, arguments: Option, meta: Option, + request_headers: Option, ) -> anyhow::Result { self.services .mcp_connection_manager .read() .await - .call_tool(server, tool, arguments, meta) + .call_tool(server, tool, arguments, meta, request_headers) .await } diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index 4369b81dff3b..6bbc4f01bec6 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -669,6 +669,7 @@ async fn maybe_auto_review_mcp_request_user_input( parent_ctx.as_ref(), &invocation.server, &invocation.tool, + None, ) .await; let review_cancel = cancel_token.child_token(); diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 938d6d0b2bf3..939affa9dff4 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -378,7 +378,7 @@ struct ManagedClient { } impl ManagedClient { - fn listed_tools(&self) -> Vec { + fn listed_tools(&self, _request_headers: Option) -> Vec { let total_start = Instant::now(); if let Some(cache_context) = self.codex_apps_tools_cache_context.as_ref() && let CachedCodexAppsToolsLoad::Hit(tools) = @@ -425,9 +425,42 @@ struct AsyncManagedClient { client: Shared>>, startup_snapshot: Option>, startup_complete: Arc, + startup_request_headers: Arc>>, tool_plugin_provenance: Arc, } +struct StartupRequestHeadersGuard { + state: Arc>>, + previous: Option, +} + +impl StartupRequestHeadersGuard { + fn set( + state: Arc>>, + headers: Option, + ) -> Self { + let previous = { + let mut guard = state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let previous = guard.clone(); + *guard = headers; + previous + }; + Self { state, previous } + } +} + +impl Drop for StartupRequestHeadersGuard { + fn drop(&mut self) { + let mut guard = self + .state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *guard = self.previous.clone(); + } +} + impl AsyncManagedClient { // Keep this constructor flat so the startup inputs remain readable at the // single call site instead of introducing a one-off params wrapper. @@ -451,6 +484,8 @@ impl AsyncManagedClient { let startup_tool_filter = tool_filter; let startup_complete = Arc::new(AtomicBool::new(false)); let startup_complete_for_fut = Arc::clone(&startup_complete); + let startup_request_headers = Arc::new(StdMutex::new(None)); + let startup_request_headers_for_fut = Arc::clone(&startup_request_headers); let fut = async move { let outcome = async { if let Err(error) = validate_mcp_server_name(&server_name) { @@ -471,6 +506,7 @@ impl AsyncManagedClient { tx_event, elicitation_requests, codex_apps_tools_cache_context, + startup_request_headers: startup_request_headers_for_fut, }, ) .or_cancel(&cancel_token) @@ -497,11 +533,17 @@ impl AsyncManagedClient { client, startup_snapshot, startup_complete, + startup_request_headers, tool_plugin_provenance, } } - async fn client(&self) -> Result { + async fn client( + &self, + request_headers: Option, + ) -> Result { + let _request_headers_guard = + StartupRequestHeadersGuard::set(self.startup_request_headers.clone(), request_headers); self.client.clone().await } @@ -512,7 +554,10 @@ impl AsyncManagedClient { None } - async fn listed_tools(&self) -> Option> { + async fn listed_tools( + &self, + request_headers: Option, + ) -> Option> { let annotate_tools = |tools: Vec| { let mut tools = tools; for tool in &mut tools { @@ -564,8 +609,8 @@ impl AsyncManagedClient { let tools = if let Some(startup_tools) = self.startup_snapshot_while_initializing() { Some(startup_tools) } else { - match self.client().await { - Ok(client) => Some(client.listed_tools()), + match self.client(request_headers.clone()).await { + Ok(client) => Some(client.listed_tools(request_headers)), Err(_) => self.startup_snapshot.clone(), } }; @@ -573,7 +618,7 @@ impl AsyncManagedClient { } async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { - let managed = self.client().await?; + let managed = self.client(None).await?; managed.notify_sandbox_state_change(sandbox_state).await } } @@ -686,7 +731,7 @@ impl McpConnectionManager { let auth_entry = auth_entries.get(&server_name).cloned(); let sandbox_state = initial_sandbox_state.clone(); join_set.spawn(async move { - let outcome = async_managed_client.client().await; + let outcome = async_managed_client.client(None).await; if cancel_token.is_cancelled() { return (server_name, Err(StartupOutcomeError::Cancelled)); } @@ -755,11 +800,15 @@ impl McpConnectionManager { (manager, cancel_token) } - async fn client_by_name(&self, name: &str) -> Result { + async fn client_by_name( + &self, + name: &str, + request_headers: Option, + ) -> Result { self.clients .get(name) .ok_or_else(|| anyhow!("unknown MCP server '{name}'"))? - .client() + .client(request_headers) .await .context("failed to get client") } @@ -780,7 +829,7 @@ impl McpConnectionManager { return false; }; - match tokio::time::timeout(timeout, async_managed_client.client()).await { + match tokio::time::timeout(timeout, async_managed_client.client(None)).await { Ok(Ok(_)) => true, Ok(Err(_)) | Err(_) => false, } @@ -800,7 +849,7 @@ impl McpConnectionManager { continue; }; - match async_managed_client.client().await { + match async_managed_client.client(None).await { Ok(_) => {} Err(error) => failures.push(McpStartupFailure { server: server_name.clone(), @@ -815,9 +864,18 @@ impl McpConnectionManager { /// fully-qualified name for the tool. #[instrument(level = "trace", skip_all)] pub async fn list_all_tools(&self) -> HashMap { + self.list_all_tools_with_request_headers(None).await + } + + #[instrument(level = "trace", skip_all)] + pub async fn list_all_tools_with_request_headers( + &self, + request_headers: Option, + ) -> HashMap { let mut tools = HashMap::new(); for managed_client in self.clients.values() { - let Some(server_tools) = managed_client.listed_tools().await else { + let Some(server_tools) = managed_client.listed_tools(request_headers.clone()).await + else { continue; }; tools.extend(qualify_tools(server_tools)); @@ -835,7 +893,7 @@ impl McpConnectionManager { .clients .get(CODEX_APPS_MCP_SERVER_NAME) .ok_or_else(|| anyhow!("unknown MCP server '{CODEX_APPS_MCP_SERVER_NAME}'"))? - .client() + .client(None) .await .context("failed to get client")?; @@ -845,6 +903,7 @@ impl McpConnectionManager { CODEX_APPS_MCP_SERVER_NAME, &managed_client.client, managed_client.tool_timeout, + None, ) .await .with_context(|| { @@ -881,7 +940,7 @@ impl McpConnectionManager { for (server_name, async_managed_client) in clients_snapshot { let server_name = server_name.clone(); - let Ok(managed_client) = async_managed_client.client().await else { + let Ok(managed_client) = async_managed_client.client(None).await else { continue; }; let timeout = managed_client.tool_timeout; @@ -947,7 +1006,7 @@ impl McpConnectionManager { for (server_name, async_managed_client) in clients_snapshot { let server_name_cloned = server_name.clone(); - let Ok(managed_client) = async_managed_client.client().await else { + let Ok(managed_client) = async_managed_client.client(None).await else { continue; }; let client = managed_client.client.clone(); @@ -1015,8 +1074,9 @@ impl McpConnectionManager { tool: &str, arguments: Option, meta: Option, + request_headers: Option, ) -> Result { - let client = self.client_by_name(server).await?; + let client = self.client_by_name(server, request_headers.clone()).await?; if !client.tool_filter.allows(tool) { return Err(anyhow!( "tool '{tool}' is disabled for MCP server '{server}'" @@ -1025,7 +1085,13 @@ impl McpConnectionManager { let result: rmcp::model::CallToolResult = client .client - .call_tool(tool.to_string(), arguments, meta, client.tool_timeout) + .call_tool( + tool.to_string(), + arguments, + meta, + client.tool_timeout, + request_headers, + ) .await .with_context(|| format!("tool call failed for `{server}/{tool}`"))?; @@ -1052,7 +1118,7 @@ impl McpConnectionManager { server: &str, params: Option, ) -> Result { - let managed = self.client_by_name(server).await?; + let managed = self.client_by_name(server, None).await?; let timeout = managed.tool_timeout; managed @@ -1068,7 +1134,7 @@ impl McpConnectionManager { server: &str, params: Option, ) -> Result { - let managed = self.client_by_name(server).await?; + let managed = self.client_by_name(server, None).await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; @@ -1084,7 +1150,7 @@ impl McpConnectionManager { server: &str, params: ReadResourceRequestParams, ) -> Result { - let managed = self.client_by_name(server).await?; + let managed = self.client_by_name(server, None).await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; let uri = params.uri.clone(); @@ -1344,6 +1410,7 @@ async fn start_server_task( tx_event, elicitation_requests, codex_apps_tools_cache_context, + startup_request_headers, } = params; let elicitation = elicitation_capability_for_server(&server_name); let params = InitializeRequestParams { @@ -1369,16 +1436,34 @@ async fn start_server_task( let send_elicitation = elicitation_requests.make_sender(server_name.clone(), tx_event); + let initialize_request_headers = startup_request_headers + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone(); let initialize_result = client - .initialize(params, startup_timeout, send_elicitation) + .initialize( + params, + startup_timeout, + send_elicitation, + initialize_request_headers, + ) .await .map_err(StartupOutcomeError::from)?; let list_start = Instant::now(); let fetch_start = Instant::now(); - let tools = list_tools_for_client_uncached(&server_name, &client, startup_timeout) - .await - .map_err(StartupOutcomeError::from)?; + let list_request_headers = startup_request_headers + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone(); + let tools = list_tools_for_client_uncached( + &server_name, + &client, + startup_timeout, + list_request_headers, + ) + .await + .map_err(StartupOutcomeError::from)?; emit_duration( MCP_TOOLS_FETCH_UNCACHED_DURATION_METRIC, fetch_start.elapsed(), @@ -1423,6 +1508,7 @@ struct StartServerTaskParams { tx_event: Sender, elicitation_requests: ElicitationRequestManager, codex_apps_tools_cache_context: Option, + startup_request_headers: Arc>>, } async fn make_rmcp_client( @@ -1584,9 +1670,10 @@ async fn list_tools_for_client_uncached( server_name: &str, client: &Arc, timeout: Option, + request_headers: Option, ) -> Result> { let resp = client - .list_tools_with_connector_ids(/*params*/ None, timeout) + .list_tools_with_connector_ids(/*params*/ None, timeout, request_headers) .await?; let tools = resp .tools diff --git a/codex-rs/core/src/mcp_connection_manager_tests.rs b/codex-rs/core/src/mcp_connection_manager_tests.rs index c5f7fc4a4086..980f4091a829 100644 --- a/codex-rs/core/src/mcp_connection_manager_tests.rs +++ b/codex-rs/core/src/mcp_connection_manager_tests.rs @@ -4,6 +4,7 @@ use codex_protocol::protocol::McpAuthStatus; use rmcp::model::JsonObject; use std::collections::HashSet; use std::sync::Arc; +use std::sync::Mutex as StdMutex; use tempfile::tempdir; fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { @@ -415,6 +416,7 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() { client: pending_client, startup_snapshot: Some(startup_tools), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), + startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); @@ -440,6 +442,7 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot( client: pending_client, startup_snapshot: None, startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), + startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); @@ -462,6 +465,7 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty( client: pending_client, startup_snapshot: Some(Vec::new()), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), + startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); @@ -494,6 +498,7 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() { client: failed_client, startup_snapshot: Some(startup_tools), startup_complete, + startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs index 06d801cbac8c..cfa6b471db26 100644 --- a/codex-rs/core/src/mcp_tool_call.rs +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -45,6 +45,8 @@ use codex_protocol::request_user_input::RequestUserInputQuestionOption; use codex_protocol::request_user_input::RequestUserInputResponse; use codex_rmcp_client::ElicitationAction; use codex_rmcp_client::ElicitationResponse; +use reqwest::header::HeaderMap; +use reqwest::header::HeaderValue; use rmcp::model::ToolAnnotations; use serde::Serialize; use std::path::Path; @@ -81,8 +83,15 @@ pub(crate) async fn handle_mcp_tool_call( arguments: arguments_value.clone(), }; - let metadata = - lookup_mcp_tool_metadata(sess.as_ref(), turn_context.as_ref(), &server, &tool_name).await; + let request_headers = build_mcp_request_headers(sess.as_ref(), turn_context.as_ref(), &server); + let metadata = lookup_mcp_tool_metadata( + sess.as_ref(), + turn_context.as_ref(), + &server, + &tool_name, + request_headers.clone(), + ) + .await; let app_tool_policy = if server == CODEX_APPS_MCP_SERVER_NAME { connectors::app_tool_policy( &turn_context.config, @@ -150,6 +159,7 @@ pub(crate) async fn handle_mcp_tool_call( &tool_name, arguments_value.clone(), request_meta.clone(), + request_headers.clone(), ) .await .map_err(|e| format!("tool call error: {e:?}")); @@ -180,6 +190,7 @@ pub(crate) async fn handle_mcp_tool_call( turn_context.as_ref(), &server, &tool_name, + request_headers.clone(), ) .await; result @@ -236,7 +247,13 @@ pub(crate) async fn handle_mcp_tool_call( let start = Instant::now(); // Perform the tool call. let result = sess - .call_tool(&server, &tool_name, arguments_value.clone(), request_meta) + .call_tool( + &server, + &tool_name, + arguments_value.clone(), + request_meta, + request_headers.clone(), + ) .await .map_err(|e| format!("tool call error: {e:?}")); let result = sanitize_mcp_tool_result_for_model( @@ -262,7 +279,14 @@ pub(crate) async fn handle_mcp_tool_call( tool_call_end_event.clone(), ) .await; - maybe_track_codex_app_used(sess.as_ref(), turn_context.as_ref(), &server, &tool_name).await; + maybe_track_codex_app_used( + sess.as_ref(), + turn_context.as_ref(), + &server, + &tool_name, + request_headers.clone(), + ) + .await; let status = if result.is_ok() { "ok" } else { "error" }; turn_context @@ -333,11 +357,12 @@ async fn maybe_track_codex_app_used( turn_context: &TurnContext, server: &str, tool_name: &str, + request_headers: Option, ) { if server != CODEX_APPS_MCP_SERVER_NAME { return; } - let metadata = lookup_mcp_app_usage_metadata(sess, server, tool_name).await; + let metadata = lookup_mcp_app_usage_metadata(sess, server, tool_name, request_headers).await; let (connector_id, app_name) = metadata .map(|metadata| (metadata.connector_id, metadata.app_name)) .unwrap_or((None, None)); @@ -389,6 +414,34 @@ pub(crate) struct McpToolApprovalMetadata { const MCP_TOOL_CODEX_APPS_META_KEY: &str = "_codex_apps"; +fn build_mcp_request_headers( + sess: &Session, + turn_context: &TurnContext, + server: &str, +) -> Option { + if server != CODEX_APPS_MCP_SERVER_NAME { + return None; + } + + let session_id = sess.conversation_id.to_string(); + let mut headers = HeaderMap::new(); + if let Ok(value) = HeaderValue::from_str(&session_id) { + headers.insert("session_id", value.clone()); + headers.insert("x-client-request-id", value); + } + if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value() + && let Ok(value) = HeaderValue::from_str(&turn_metadata) + { + headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value); + } + + if headers.is_empty() { + None + } else { + Some(headers) + } +} + fn build_mcp_tool_call_request_meta( server: &str, metadata: Option<&McpToolApprovalMetadata>, @@ -741,13 +794,14 @@ pub(crate) async fn lookup_mcp_tool_metadata( turn_context: &TurnContext, server: &str, tool_name: &str, + request_headers: Option, ) -> Option { let tools = sess .services .mcp_connection_manager .read() .await - .list_all_tools() + .list_all_tools_with_request_headers(request_headers.clone()) .await; let tool_info = tools @@ -798,13 +852,14 @@ async fn lookup_mcp_app_usage_metadata( sess: &Session, server: &str, tool_name: &str, + request_headers: Option, ) -> Option { let tools = sess .services .mcp_connection_manager .read() .await - .list_all_tools() + .list_all_tools_with_request_headers(request_headers) .await; tools.into_values().find_map(|tool_info| { diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index b898403b25c7..42aeb149ad6d 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -5,6 +5,7 @@ use std::io; use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; +use std::sync::Mutex as StdMutex; use std::time::Duration; use anyhow::Result; @@ -82,15 +83,37 @@ const JSON_MIME_TYPE: &str = "application/json"; const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192; +fn apply_request_scoped_headers( + mut request: reqwest::RequestBuilder, + request_headers_state: &Arc>>, +) -> reqwest::RequestBuilder { + let extra_headers = request_headers_state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone(); + if let Some(extra_headers) = extra_headers { + for (name, value) in &extra_headers { + request = request.header(name, value.clone()); + } + } + request +} #[derive(Clone)] struct StreamableHttpResponseClient { inner: reqwest::Client, + request_headers_state: Arc>>, } impl StreamableHttpResponseClient { - fn new(inner: reqwest::Client) -> Self { - Self { inner } + fn new( + inner: reqwest::Client, + request_headers_state: Arc>>, + ) -> Self { + Self { + inner, + request_headers_state, + } } fn reqwest_error( @@ -133,6 +156,7 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(session_id_value) = session_id.as_ref() { request = request.header(HEADER_SESSION_ID, session_id_value.as_ref()); } + request = apply_request_scoped_headers(request, &self.request_headers_state); let response = request .json(&message) @@ -228,6 +252,8 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); } + request_builder = + apply_request_scoped_headers(request_builder, &self.request_headers_state); let response = request_builder .header(HEADER_SESSION_ID, session.as_ref()) .send() @@ -265,6 +291,8 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); } + request_builder = + apply_request_scoped_headers(request_builder, &self.request_headers_state); let response = request_builder .send() @@ -472,6 +500,39 @@ pub struct RmcpClient { transport_recipe: TransportRecipe, initialize_context: Mutex>, session_recovery_lock: Mutex<()>, + request_headers: Option>>>, +} + +struct RequestHeadersGuard { + state: Option>>>, + previous: Option, +} + +impl RequestHeadersGuard { + fn set(state: Option>>>, headers: Option) -> Self { + let previous = if let Some(state_ref) = state.as_ref() { + let mut guard = state_ref + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let previous = guard.clone(); + *guard = headers; + previous + } else { + None + }; + Self { state, previous } + } +} + +impl Drop for RequestHeadersGuard { + fn drop(&mut self) { + if let Some(state) = self.state.as_ref() { + let mut guard = state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *guard = self.previous.clone(); + } + } } impl RmcpClient { @@ -489,7 +550,7 @@ impl RmcpClient { env_vars: env_vars.to_vec(), cwd, }; - let transport = Self::create_pending_transport(&transport_recipe) + let transport = Self::create_pending_transport(&transport_recipe, None) .await .map_err(io::Error::other)?; @@ -500,6 +561,7 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), + request_headers: None, }) } @@ -520,7 +582,9 @@ impl RmcpClient { env_http_headers, store_mode, }; - let transport = Self::create_pending_transport(&transport_recipe).await?; + let request_headers = Some(Arc::new(StdMutex::new(None))); + let transport = + Self::create_pending_transport(&transport_recipe, request_headers.clone()).await?; Ok(Self { state: Mutex::new(ClientState::Connecting { transport: Some(transport), @@ -528,6 +592,7 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), + request_headers, }) } @@ -538,6 +603,7 @@ impl RmcpClient { params: InitializeRequestParams, timeout: Option, send_elicitation: SendElicitation, + request_headers: Option, ) -> Result { let client_handler = LoggingClientHandler::new(params.clone(), send_elicitation); let pending_transport = { @@ -551,6 +617,8 @@ impl RmcpClient { } }; + let _request_headers_guard = + RequestHeadersGuard::set(self.request_headers.clone(), request_headers); let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport(pending_transport, client_handler.clone(), timeout) .await?; @@ -607,8 +675,11 @@ impl RmcpClient { &self, params: Option, timeout: Option, + request_headers: Option, ) -> Result { self.refresh_oauth_if_needed().await; + let _request_headers_guard = + RequestHeadersGuard::set(self.request_headers.clone(), request_headers); let result = self .run_service_operation("tools/list", timeout, move |service| { let params = params.clone(); @@ -702,6 +773,7 @@ impl RmcpClient { arguments: Option, meta: Option, timeout: Option, + request_headers: Option, ) -> Result { self.refresh_oauth_if_needed().await; let arguments = match arguments { @@ -728,6 +800,8 @@ impl RmcpClient { arguments, task: None, }; + let _request_headers_guard = + RequestHeadersGuard::set(self.request_headers.clone(), request_headers); let result = self .run_service_operation("tools/call", timeout, move |service| { let rmcp_params = rmcp_params.clone(); @@ -830,6 +904,7 @@ impl RmcpClient { async fn create_pending_transport( transport_recipe: &TransportRecipe, + request_headers: Option>>>, ) -> Result { match transport_recipe { TransportRecipe::Stdio { @@ -946,7 +1021,12 @@ impl RmcpClient { .auth_header(access_token); let http_client = build_http_client(&default_headers)?; let transport = StreamableHttpClientTransport::with_client( - StreamableHttpResponseClient::new(http_client), + StreamableHttpResponseClient::new( + http_client, + request_headers + .clone() + .unwrap_or_else(|| Arc::new(StdMutex::new(None))), + ), http_config, ); Ok(PendingTransport::StreamableHttp { transport }) @@ -963,7 +1043,12 @@ impl RmcpClient { let http_client = build_http_client(&default_headers)?; let transport = StreamableHttpClientTransport::with_client( - StreamableHttpResponseClient::new(http_client), + StreamableHttpResponseClient::new( + http_client, + request_headers + .clone() + .unwrap_or_else(|| Arc::new(StdMutex::new(None))), + ), http_config, ); Ok(PendingTransport::StreamableHttp { transport }) @@ -1111,7 +1196,9 @@ impl RmcpClient { .await .clone() .ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?; - let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?; + let pending_transport = + Self::create_pending_transport(&self.transport_recipe, self.request_headers.clone()) + .await?; let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport( pending_transport, initialize_context.handler, @@ -1166,7 +1253,10 @@ async fn create_oauth_transport_and_runtime( } }; - let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager); + let auth_client = AuthClient::new( + StreamableHttpResponseClient::new(http_client, Arc::new(StdMutex::new(None))), + manager, + ); let auth_manager = auth_client.auth_manager.clone(); let transport = StreamableHttpClientTransport::with_client( diff --git a/codex-rs/rmcp-client/tests/resources.rs b/codex-rs/rmcp-client/tests/resources.rs index ba1a8e43106a..4a3e587d11bb 100644 --- a/codex-rs/rmcp-client/tests/resources.rs +++ b/codex-rs/rmcp-client/tests/resources.rs @@ -78,6 +78,7 @@ async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> { } .boxed() }), + /*request_headers*/ None, ) .await?; diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index fb2fc96d20f1..6512e248693a 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -94,6 +94,7 @@ async fn create_client(base_url: &str) -> anyhow::Result { } .boxed() }), + /*request_headers*/ None, ) .await?; @@ -107,6 +108,7 @@ async fn call_echo_tool(client: &RmcpClient, message: &str) -> anyhow::Result Date: Wed, 18 Mar 2026 09:12:10 -0700 Subject: [PATCH 2/6] codex: fix CI failure on PR #15011 Co-authored-by: Codex --- codex-rs/rmcp-client/src/rmcp_client.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 42aeb149ad6d..86d14cf881bf 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -550,9 +550,10 @@ impl RmcpClient { env_vars: env_vars.to_vec(), cwd, }; - let transport = Self::create_pending_transport(&transport_recipe, None) - .await - .map_err(io::Error::other)?; + let transport = + Self::create_pending_transport(&transport_recipe, /*request_headers*/ None) + .await + .map_err(io::Error::other)?; Ok(Self { state: Mutex::new(ClientState::Connecting { From de9864340312e1c60307a8c55a965676c6d3db56 Mon Sep 17 00:00:00 2001 From: nicholasclark-openai Date: Wed, 18 Mar 2026 09:41:06 -0700 Subject: [PATCH 3/6] codex: fix CI failure on PR #15011 Co-authored-by: Codex --- codex-rs/core/src/auth_env_telemetry.rs | 1 + codex-rs/core/src/codex_delegate.rs | 2 +- codex-rs/core/src/mcp_connection_manager.rs | 38 ++++++++++++++------- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/codex-rs/core/src/auth_env_telemetry.rs b/codex-rs/core/src/auth_env_telemetry.rs index be281e05a1ff..85cd23fe06f7 100644 --- a/codex-rs/core/src/auth_env_telemetry.rs +++ b/codex-rs/core/src/auth_env_telemetry.rs @@ -71,6 +71,7 @@ mod tests { request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index 6bbc4f01bec6..1838150408d8 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -669,7 +669,7 @@ async fn maybe_auto_review_mcp_request_user_input( parent_ctx.as_ref(), &invocation.server, &invocation.tool, - None, + /*request_headers*/ None, ) .await; let review_cancel = cancel_token.child_token(); diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 939affa9dff4..579a79b433f8 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -618,7 +618,7 @@ impl AsyncManagedClient { } async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { - let managed = self.client(None).await?; + let managed = self.client(/*request_headers*/ None).await?; managed.notify_sandbox_state_change(sandbox_state).await } } @@ -731,7 +731,7 @@ impl McpConnectionManager { let auth_entry = auth_entries.get(&server_name).cloned(); let sandbox_state = initial_sandbox_state.clone(); join_set.spawn(async move { - let outcome = async_managed_client.client(None).await; + let outcome = async_managed_client.client(/*request_headers*/ None).await; if cancel_token.is_cancelled() { return (server_name, Err(StartupOutcomeError::Cancelled)); } @@ -829,7 +829,12 @@ impl McpConnectionManager { return false; }; - match tokio::time::timeout(timeout, async_managed_client.client(None)).await { + match tokio::time::timeout( + timeout, + async_managed_client.client(/*request_headers*/ None), + ) + .await + { Ok(Ok(_)) => true, Ok(Err(_)) | Err(_) => false, } @@ -849,7 +854,7 @@ impl McpConnectionManager { continue; }; - match async_managed_client.client(None).await { + match async_managed_client.client(/*request_headers*/ None).await { Ok(_) => {} Err(error) => failures.push(McpStartupFailure { server: server_name.clone(), @@ -864,7 +869,8 @@ impl McpConnectionManager { /// fully-qualified name for the tool. #[instrument(level = "trace", skip_all)] pub async fn list_all_tools(&self) -> HashMap { - self.list_all_tools_with_request_headers(None).await + self.list_all_tools_with_request_headers(/*request_headers*/ None) + .await } #[instrument(level = "trace", skip_all)] @@ -893,7 +899,7 @@ impl McpConnectionManager { .clients .get(CODEX_APPS_MCP_SERVER_NAME) .ok_or_else(|| anyhow!("unknown MCP server '{CODEX_APPS_MCP_SERVER_NAME}'"))? - .client(None) + .client(/*request_headers*/ None) .await .context("failed to get client")?; @@ -903,7 +909,7 @@ impl McpConnectionManager { CODEX_APPS_MCP_SERVER_NAME, &managed_client.client, managed_client.tool_timeout, - None, + /*request_headers*/ None, ) .await .with_context(|| { @@ -940,7 +946,8 @@ impl McpConnectionManager { for (server_name, async_managed_client) in clients_snapshot { let server_name = server_name.clone(); - let Ok(managed_client) = async_managed_client.client(None).await else { + let Ok(managed_client) = async_managed_client.client(/*request_headers*/ None).await + else { continue; }; let timeout = managed_client.tool_timeout; @@ -1006,7 +1013,8 @@ impl McpConnectionManager { for (server_name, async_managed_client) in clients_snapshot { let server_name_cloned = server_name.clone(); - let Ok(managed_client) = async_managed_client.client(None).await else { + let Ok(managed_client) = async_managed_client.client(/*request_headers*/ None).await + else { continue; }; let client = managed_client.client.clone(); @@ -1118,7 +1126,9 @@ impl McpConnectionManager { server: &str, params: Option, ) -> Result { - let managed = self.client_by_name(server, None).await?; + let managed = self + .client_by_name(server, /*request_headers*/ None) + .await?; let timeout = managed.tool_timeout; managed @@ -1134,7 +1144,9 @@ impl McpConnectionManager { server: &str, params: Option, ) -> Result { - let managed = self.client_by_name(server, None).await?; + let managed = self + .client_by_name(server, /*request_headers*/ None) + .await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; @@ -1150,7 +1162,9 @@ impl McpConnectionManager { server: &str, params: ReadResourceRequestParams, ) -> Result { - let managed = self.client_by_name(server, None).await?; + let managed = self + .client_by_name(server, /*request_headers*/ None) + .await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; let uri = params.uri.clone(); From 732d7ac81fd1ef112dd21714055cc6fc83feb316 Mon Sep 17 00:00:00 2001 From: nicholasclark-openai Date: Wed, 18 Mar 2026 16:42:10 -0700 Subject: [PATCH 4/6] codex: scope MCP request headers to turn lifecycle Set Codex Apps MCP request headers once per active turn and clear them on turn end, instead of threading request-scoped headers through every tool call. Keep RMCP header injection limited to streamable HTTP tools/call requests so list/init paths stay unchanged and concurrent tool calls on the same client are not serialized. Co-authored-by: Codex --- codex-rs/core/src/codex.rs | 41 +++- codex-rs/core/src/codex_delegate.rs | 1 - codex-rs/core/src/mcp_connection_manager.rs | 186 ++++++------------ .../core/src/mcp_connection_manager_tests.rs | 5 - codex-rs/core/src/mcp_tool_call.rs | 69 +------ codex-rs/core/src/tasks/mod.rs | 6 + codex-rs/rmcp-client/src/rmcp_client.rs | 66 ++----- codex-rs/rmcp-client/tests/resources.rs | 1 - .../tests/streamable_http_recovery.rs | 2 - 9 files changed, 127 insertions(+), 250 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 170b052f2a67..db0e4472e9e7 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -125,6 +125,8 @@ use futures::future::BoxFuture; use futures::future::Shared; use futures::prelude::*; use futures::stream::FuturesOrdered; +use reqwest::header::HeaderMap; +use reqwest::header::HeaderValue; use rmcp::model::ListResourceTemplatesResult; use rmcp::model::ListResourcesResult; use rmcp::model::PaginatedRequestParams; @@ -3942,16 +3944,51 @@ impl Session { tool: &str, arguments: Option, meta: Option, - request_headers: Option, ) -> anyhow::Result { self.services .mcp_connection_manager .read() .await - .call_tool(server, tool, arguments, meta, request_headers) + .call_tool(server, tool, arguments, meta) .await } + pub(crate) async fn sync_mcp_request_headers_for_turn(&self, turn_context: &TurnContext) { + let mut request_headers = HeaderMap::new(); + let session_id = self.conversation_id.to_string(); + if let Ok(value) = HeaderValue::from_str(&session_id) { + request_headers.insert("session_id", value.clone()); + request_headers.insert("x-client-request-id", value); + } + if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value() + && let Ok(value) = HeaderValue::from_str(&turn_metadata) + { + request_headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value); + } + + let request_headers = if request_headers.is_empty() { + None + } else { + Some(request_headers) + }; + self.services + .mcp_connection_manager + .read() + .await + .set_request_headers_for_server( + crate::mcp::CODEX_APPS_MCP_SERVER_NAME, + request_headers, + ); + } + + pub(crate) async fn clear_mcp_request_headers(&self) { + self.services + .mcp_connection_manager + .read() + .await + .set_request_headers_for_server(crate::mcp::CODEX_APPS_MCP_SERVER_NAME, None); + } + pub(crate) async fn parse_mcp_tool_name( &self, name: &str, diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index 084c6f1fb37f..e560cd9c7f15 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -670,7 +670,6 @@ async fn maybe_auto_review_mcp_request_user_input( parent_ctx.as_ref(), &invocation.server, &invocation.tool, - /*request_headers*/ None, ) .await; let review_cancel = cancel_token.child_token(); diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 579a79b433f8..7c8a34307022 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -378,7 +378,7 @@ struct ManagedClient { } impl ManagedClient { - fn listed_tools(&self, _request_headers: Option) -> Vec { + fn listed_tools(&self) -> Vec { let total_start = Instant::now(); if let Some(cache_context) = self.codex_apps_tools_cache_context.as_ref() && let CachedCodexAppsToolsLoad::Hit(tools) = @@ -423,44 +423,12 @@ impl ManagedClient { #[derive(Clone)] struct AsyncManagedClient { client: Shared>>, + request_headers: Arc>>, startup_snapshot: Option>, startup_complete: Arc, - startup_request_headers: Arc>>, tool_plugin_provenance: Arc, } -struct StartupRequestHeadersGuard { - state: Arc>>, - previous: Option, -} - -impl StartupRequestHeadersGuard { - fn set( - state: Arc>>, - headers: Option, - ) -> Self { - let previous = { - let mut guard = state - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - let previous = guard.clone(); - *guard = headers; - previous - }; - Self { state, previous } - } -} - -impl Drop for StartupRequestHeadersGuard { - fn drop(&mut self) { - let mut guard = self - .state - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - *guard = self.previous.clone(); - } -} - impl AsyncManagedClient { // Keep this constructor flat so the startup inputs remain readable at the // single call site instead of introducing a one-off params wrapper. @@ -481,19 +449,26 @@ impl AsyncManagedClient { codex_apps_tools_cache_context.as_ref(), ) .map(|tools| filter_tools(tools, &tool_filter)); + let request_headers = Arc::new(StdMutex::new(None)); let startup_tool_filter = tool_filter; let startup_complete = Arc::new(AtomicBool::new(false)); let startup_complete_for_fut = Arc::clone(&startup_complete); - let startup_request_headers = Arc::new(StdMutex::new(None)); - let startup_request_headers_for_fut = Arc::clone(&startup_request_headers); + let request_headers_for_client = Arc::clone(&request_headers); let fut = async move { let outcome = async { if let Err(error) = validate_mcp_server_name(&server_name) { return Err(error.into()); } - let client = - Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?); + let client = Arc::new( + make_rmcp_client( + &server_name, + config.transport, + store_mode, + request_headers_for_client, + ) + .await?, + ); match start_server_task( server_name, client, @@ -506,7 +481,6 @@ impl AsyncManagedClient { tx_event, elicitation_requests, codex_apps_tools_cache_context, - startup_request_headers: startup_request_headers_for_fut, }, ) .or_cancel(&cancel_token) @@ -531,19 +505,14 @@ impl AsyncManagedClient { Self { client, + request_headers, startup_snapshot, startup_complete, - startup_request_headers, tool_plugin_provenance, } } - async fn client( - &self, - request_headers: Option, - ) -> Result { - let _request_headers_guard = - StartupRequestHeadersGuard::set(self.startup_request_headers.clone(), request_headers); + async fn client(&self) -> Result { self.client.clone().await } @@ -554,10 +523,7 @@ impl AsyncManagedClient { None } - async fn listed_tools( - &self, - request_headers: Option, - ) -> Option> { + async fn listed_tools(&self) -> Option> { let annotate_tools = |tools: Vec| { let mut tools = tools; for tool in &mut tools { @@ -609,8 +575,8 @@ impl AsyncManagedClient { let tools = if let Some(startup_tools) = self.startup_snapshot_while_initializing() { Some(startup_tools) } else { - match self.client(request_headers.clone()).await { - Ok(client) => Some(client.listed_tools(request_headers)), + match self.client().await { + Ok(client) => Some(client.listed_tools()), Err(_) => self.startup_snapshot.clone(), } }; @@ -618,9 +584,17 @@ impl AsyncManagedClient { } async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { - let managed = self.client(/*request_headers*/ None).await?; + let managed = self.client().await?; managed.notify_sandbox_state_change(sandbox_state).await } + + fn set_request_headers(&self, request_headers: Option) { + let mut guard = self + .request_headers + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *guard = request_headers; + } } pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state"; @@ -731,7 +705,7 @@ impl McpConnectionManager { let auth_entry = auth_entries.get(&server_name).cloned(); let sandbox_state = initial_sandbox_state.clone(); join_set.spawn(async move { - let outcome = async_managed_client.client(/*request_headers*/ None).await; + let outcome = async_managed_client.client().await; if cancel_token.is_cancelled() { return (server_name, Err(StartupOutcomeError::Cancelled)); } @@ -800,15 +774,11 @@ impl McpConnectionManager { (manager, cancel_token) } - async fn client_by_name( - &self, - name: &str, - request_headers: Option, - ) -> Result { + async fn client_by_name(&self, name: &str) -> Result { self.clients .get(name) .ok_or_else(|| anyhow!("unknown MCP server '{name}'"))? - .client(request_headers) + .client() .await .context("failed to get client") } @@ -829,12 +799,7 @@ impl McpConnectionManager { return false; }; - match tokio::time::timeout( - timeout, - async_managed_client.client(/*request_headers*/ None), - ) - .await - { + match tokio::time::timeout(timeout, async_managed_client.client()).await { Ok(Ok(_)) => true, Ok(Err(_)) | Err(_) => false, } @@ -854,7 +819,7 @@ impl McpConnectionManager { continue; }; - match async_managed_client.client(/*request_headers*/ None).await { + match async_managed_client.client().await { Ok(_) => {} Err(error) => failures.push(McpStartupFailure { server: server_name.clone(), @@ -869,19 +834,9 @@ impl McpConnectionManager { /// fully-qualified name for the tool. #[instrument(level = "trace", skip_all)] pub async fn list_all_tools(&self) -> HashMap { - self.list_all_tools_with_request_headers(/*request_headers*/ None) - .await - } - - #[instrument(level = "trace", skip_all)] - pub async fn list_all_tools_with_request_headers( - &self, - request_headers: Option, - ) -> HashMap { let mut tools = HashMap::new(); for managed_client in self.clients.values() { - let Some(server_tools) = managed_client.listed_tools(request_headers.clone()).await - else { + let Some(server_tools) = managed_client.listed_tools().await else { continue; }; tools.extend(qualify_tools(server_tools)); @@ -899,7 +854,7 @@ impl McpConnectionManager { .clients .get(CODEX_APPS_MCP_SERVER_NAME) .ok_or_else(|| anyhow!("unknown MCP server '{CODEX_APPS_MCP_SERVER_NAME}'"))? - .client(/*request_headers*/ None) + .client() .await .context("failed to get client")?; @@ -909,7 +864,6 @@ impl McpConnectionManager { CODEX_APPS_MCP_SERVER_NAME, &managed_client.client, managed_client.tool_timeout, - /*request_headers*/ None, ) .await .with_context(|| { @@ -946,8 +900,7 @@ impl McpConnectionManager { for (server_name, async_managed_client) in clients_snapshot { let server_name = server_name.clone(); - let Ok(managed_client) = async_managed_client.client(/*request_headers*/ None).await - else { + let Ok(managed_client) = async_managed_client.client().await else { continue; }; let timeout = managed_client.tool_timeout; @@ -1013,8 +966,7 @@ impl McpConnectionManager { for (server_name, async_managed_client) in clients_snapshot { let server_name_cloned = server_name.clone(); - let Ok(managed_client) = async_managed_client.client(/*request_headers*/ None).await - else { + let Ok(managed_client) = async_managed_client.client().await else { continue; }; let client = managed_client.client.clone(); @@ -1082,9 +1034,8 @@ impl McpConnectionManager { tool: &str, arguments: Option, meta: Option, - request_headers: Option, ) -> Result { - let client = self.client_by_name(server, request_headers.clone()).await?; + let client = self.client_by_name(server).await?; if !client.tool_filter.allows(tool) { return Err(anyhow!( "tool '{tool}' is disabled for MCP server '{server}'" @@ -1093,13 +1044,7 @@ impl McpConnectionManager { let result: rmcp::model::CallToolResult = client .client - .call_tool( - tool.to_string(), - arguments, - meta, - client.tool_timeout, - request_headers, - ) + .call_tool(tool.to_string(), arguments, meta, client.tool_timeout) .await .with_context(|| format!("tool call failed for `{server}/{tool}`"))?; @@ -1120,15 +1065,23 @@ impl McpConnectionManager { }) } + pub(crate) fn set_request_headers_for_server( + &self, + server_name: &str, + request_headers: Option, + ) { + if let Some(client) = self.clients.get(server_name) { + client.set_request_headers(request_headers); + } + } + /// List resources from the specified server. pub async fn list_resources( &self, server: &str, params: Option, ) -> Result { - let managed = self - .client_by_name(server, /*request_headers*/ None) - .await?; + let managed = self.client_by_name(server).await?; let timeout = managed.tool_timeout; managed @@ -1144,9 +1097,7 @@ impl McpConnectionManager { server: &str, params: Option, ) -> Result { - let managed = self - .client_by_name(server, /*request_headers*/ None) - .await?; + let managed = self.client_by_name(server).await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; @@ -1162,9 +1113,7 @@ impl McpConnectionManager { server: &str, params: ReadResourceRequestParams, ) -> Result { - let managed = self - .client_by_name(server, /*request_headers*/ None) - .await?; + let managed = self.client_by_name(server).await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; let uri = params.uri.clone(); @@ -1424,7 +1373,6 @@ async fn start_server_task( tx_event, elicitation_requests, codex_apps_tools_cache_context, - startup_request_headers, } = params; let elicitation = elicitation_capability_for_server(&server_name); let params = InitializeRequestParams { @@ -1450,34 +1398,16 @@ async fn start_server_task( let send_elicitation = elicitation_requests.make_sender(server_name.clone(), tx_event); - let initialize_request_headers = startup_request_headers - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .clone(); let initialize_result = client - .initialize( - params, - startup_timeout, - send_elicitation, - initialize_request_headers, - ) + .initialize(params, startup_timeout, send_elicitation) .await .map_err(StartupOutcomeError::from)?; let list_start = Instant::now(); let fetch_start = Instant::now(); - let list_request_headers = startup_request_headers - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .clone(); - let tools = list_tools_for_client_uncached( - &server_name, - &client, - startup_timeout, - list_request_headers, - ) - .await - .map_err(StartupOutcomeError::from)?; + let tools = list_tools_for_client_uncached(&server_name, &client, startup_timeout) + .await + .map_err(StartupOutcomeError::from)?; emit_duration( MCP_TOOLS_FETCH_UNCACHED_DURATION_METRIC, fetch_start.elapsed(), @@ -1522,13 +1452,13 @@ struct StartServerTaskParams { tx_event: Sender, elicitation_requests: ElicitationRequestManager, codex_apps_tools_cache_context: Option, - startup_request_headers: Arc>>, } async fn make_rmcp_client( server_name: &str, transport: McpServerTransportConfig, store_mode: OAuthCredentialsStoreMode, + request_headers: Arc>>, ) -> Result { match transport { McpServerTransportConfig::Stdio { @@ -1562,6 +1492,7 @@ async fn make_rmcp_client( http_headers, env_http_headers, store_mode, + request_headers, ) .await .map_err(StartupOutcomeError::from) @@ -1684,10 +1615,9 @@ async fn list_tools_for_client_uncached( server_name: &str, client: &Arc, timeout: Option, - request_headers: Option, ) -> Result> { let resp = client - .list_tools_with_connector_ids(/*params*/ None, timeout, request_headers) + .list_tools_with_connector_ids(/*params*/ None, timeout) .await?; let tools = resp .tools diff --git a/codex-rs/core/src/mcp_connection_manager_tests.rs b/codex-rs/core/src/mcp_connection_manager_tests.rs index 980f4091a829..c5f7fc4a4086 100644 --- a/codex-rs/core/src/mcp_connection_manager_tests.rs +++ b/codex-rs/core/src/mcp_connection_manager_tests.rs @@ -4,7 +4,6 @@ use codex_protocol::protocol::McpAuthStatus; use rmcp::model::JsonObject; use std::collections::HashSet; use std::sync::Arc; -use std::sync::Mutex as StdMutex; use tempfile::tempdir; fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { @@ -416,7 +415,6 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() { client: pending_client, startup_snapshot: Some(startup_tools), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), - startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); @@ -442,7 +440,6 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot( client: pending_client, startup_snapshot: None, startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), - startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); @@ -465,7 +462,6 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty( client: pending_client, startup_snapshot: Some(Vec::new()), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), - startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); @@ -498,7 +494,6 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() { client: failed_client, startup_snapshot: Some(startup_tools), startup_complete, - startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs index cfa6b471db26..06d801cbac8c 100644 --- a/codex-rs/core/src/mcp_tool_call.rs +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -45,8 +45,6 @@ use codex_protocol::request_user_input::RequestUserInputQuestionOption; use codex_protocol::request_user_input::RequestUserInputResponse; use codex_rmcp_client::ElicitationAction; use codex_rmcp_client::ElicitationResponse; -use reqwest::header::HeaderMap; -use reqwest::header::HeaderValue; use rmcp::model::ToolAnnotations; use serde::Serialize; use std::path::Path; @@ -83,15 +81,8 @@ pub(crate) async fn handle_mcp_tool_call( arguments: arguments_value.clone(), }; - let request_headers = build_mcp_request_headers(sess.as_ref(), turn_context.as_ref(), &server); - let metadata = lookup_mcp_tool_metadata( - sess.as_ref(), - turn_context.as_ref(), - &server, - &tool_name, - request_headers.clone(), - ) - .await; + let metadata = + lookup_mcp_tool_metadata(sess.as_ref(), turn_context.as_ref(), &server, &tool_name).await; let app_tool_policy = if server == CODEX_APPS_MCP_SERVER_NAME { connectors::app_tool_policy( &turn_context.config, @@ -159,7 +150,6 @@ pub(crate) async fn handle_mcp_tool_call( &tool_name, arguments_value.clone(), request_meta.clone(), - request_headers.clone(), ) .await .map_err(|e| format!("tool call error: {e:?}")); @@ -190,7 +180,6 @@ pub(crate) async fn handle_mcp_tool_call( turn_context.as_ref(), &server, &tool_name, - request_headers.clone(), ) .await; result @@ -247,13 +236,7 @@ pub(crate) async fn handle_mcp_tool_call( let start = Instant::now(); // Perform the tool call. let result = sess - .call_tool( - &server, - &tool_name, - arguments_value.clone(), - request_meta, - request_headers.clone(), - ) + .call_tool(&server, &tool_name, arguments_value.clone(), request_meta) .await .map_err(|e| format!("tool call error: {e:?}")); let result = sanitize_mcp_tool_result_for_model( @@ -279,14 +262,7 @@ pub(crate) async fn handle_mcp_tool_call( tool_call_end_event.clone(), ) .await; - maybe_track_codex_app_used( - sess.as_ref(), - turn_context.as_ref(), - &server, - &tool_name, - request_headers.clone(), - ) - .await; + maybe_track_codex_app_used(sess.as_ref(), turn_context.as_ref(), &server, &tool_name).await; let status = if result.is_ok() { "ok" } else { "error" }; turn_context @@ -357,12 +333,11 @@ async fn maybe_track_codex_app_used( turn_context: &TurnContext, server: &str, tool_name: &str, - request_headers: Option, ) { if server != CODEX_APPS_MCP_SERVER_NAME { return; } - let metadata = lookup_mcp_app_usage_metadata(sess, server, tool_name, request_headers).await; + let metadata = lookup_mcp_app_usage_metadata(sess, server, tool_name).await; let (connector_id, app_name) = metadata .map(|metadata| (metadata.connector_id, metadata.app_name)) .unwrap_or((None, None)); @@ -414,34 +389,6 @@ pub(crate) struct McpToolApprovalMetadata { const MCP_TOOL_CODEX_APPS_META_KEY: &str = "_codex_apps"; -fn build_mcp_request_headers( - sess: &Session, - turn_context: &TurnContext, - server: &str, -) -> Option { - if server != CODEX_APPS_MCP_SERVER_NAME { - return None; - } - - let session_id = sess.conversation_id.to_string(); - let mut headers = HeaderMap::new(); - if let Ok(value) = HeaderValue::from_str(&session_id) { - headers.insert("session_id", value.clone()); - headers.insert("x-client-request-id", value); - } - if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value() - && let Ok(value) = HeaderValue::from_str(&turn_metadata) - { - headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value); - } - - if headers.is_empty() { - None - } else { - Some(headers) - } -} - fn build_mcp_tool_call_request_meta( server: &str, metadata: Option<&McpToolApprovalMetadata>, @@ -794,14 +741,13 @@ pub(crate) async fn lookup_mcp_tool_metadata( turn_context: &TurnContext, server: &str, tool_name: &str, - request_headers: Option, ) -> Option { let tools = sess .services .mcp_connection_manager .read() .await - .list_all_tools_with_request_headers(request_headers.clone()) + .list_all_tools() .await; let tool_info = tools @@ -852,14 +798,13 @@ async fn lookup_mcp_app_usage_metadata( sess: &Session, server: &str, tool_name: &str, - request_headers: Option, ) -> Option { let tools = sess .services .mcp_connection_manager .read() .await - .list_all_tools_with_request_headers(request_headers) + .list_all_tools() .await; tools.into_values().find_map(|tool_info| { diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index c52e4f91780e..049ed56d45f2 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -153,6 +153,8 @@ impl Session { ) { self.abort_all_tasks(TurnAbortReason::Replaced).await; self.clear_connector_selection().await; + self.sync_mcp_request_headers_for_turn(turn_context.as_ref()) + .await; let task: Arc = Arc::new(task); let task_kind = task.kind(); @@ -233,6 +235,7 @@ impl Session { // in-flight approval wait can surface as a model-visible rejection before TurnAborted. active_turn.clear_pending().await; } + self.clear_mcp_request_headers().await; } pub async fn on_task_finished( @@ -262,6 +265,9 @@ impl Session { *active = None; } drop(active); + if should_clear_active_turn { + self.clear_mcp_request_headers().await; + } if !pending_input.is_empty() { for pending_input_item in pending_input { match inspect_pending_input(self, &turn_context, pending_input_item).await { diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 86d14cf881bf..cf4f90ad3b05 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -23,6 +23,7 @@ use reqwest::header::HeaderMap; use reqwest::header::WWW_AUTHENTICATE; use rmcp::model::CallToolRequestParams; use rmcp::model::CallToolResult; +use rmcp::model::ClientJsonRpcMessage; use rmcp::model::ClientNotification; use rmcp::model::ClientRequest; use rmcp::model::CreateElicitationRequestParams; @@ -83,6 +84,15 @@ const JSON_MIME_TYPE: &str = "application/json"; const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192; + +fn message_uses_request_scoped_headers(message: &ClientJsonRpcMessage) -> bool { + matches!( + message, + ClientJsonRpcMessage::Request(request) + if request.request.method() == "tools/call" + ) +} + fn apply_request_scoped_headers( mut request: reqwest::RequestBuilder, request_headers_state: &Arc>>, @@ -156,7 +166,9 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(session_id_value) = session_id.as_ref() { request = request.header(HEADER_SESSION_ID, session_id_value.as_ref()); } - request = apply_request_scoped_headers(request, &self.request_headers_state); + if message_uses_request_scoped_headers(&message) { + request = apply_request_scoped_headers(request, &self.request_headers_state); + } let response = request .json(&message) @@ -252,8 +264,6 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); } - request_builder = - apply_request_scoped_headers(request_builder, &self.request_headers_state); let response = request_builder .header(HEADER_SESSION_ID, session.as_ref()) .send() @@ -291,8 +301,6 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); } - request_builder = - apply_request_scoped_headers(request_builder, &self.request_headers_state); let response = request_builder .send() @@ -503,38 +511,6 @@ pub struct RmcpClient { request_headers: Option>>>, } -struct RequestHeadersGuard { - state: Option>>>, - previous: Option, -} - -impl RequestHeadersGuard { - fn set(state: Option>>>, headers: Option) -> Self { - let previous = if let Some(state_ref) = state.as_ref() { - let mut guard = state_ref - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - let previous = guard.clone(); - *guard = headers; - previous - } else { - None - }; - Self { state, previous } - } -} - -impl Drop for RequestHeadersGuard { - fn drop(&mut self) { - if let Some(state) = self.state.as_ref() { - let mut guard = state - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - *guard = self.previous.clone(); - } - } -} - impl RmcpClient { pub async fn new_stdio_client( program: OsString, @@ -574,6 +550,7 @@ impl RmcpClient { http_headers: Option>, env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, + request_headers: Arc>>, ) -> Result { let transport_recipe = TransportRecipe::StreamableHttp { server_name: server_name.to_string(), @@ -583,9 +560,9 @@ impl RmcpClient { env_http_headers, store_mode, }; - let request_headers = Some(Arc::new(StdMutex::new(None))); let transport = - Self::create_pending_transport(&transport_recipe, request_headers.clone()).await?; + Self::create_pending_transport(&transport_recipe, Some(Arc::clone(&request_headers))) + .await?; Ok(Self { state: Mutex::new(ClientState::Connecting { transport: Some(transport), @@ -593,7 +570,7 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), - request_headers, + request_headers: Some(request_headers), }) } @@ -604,7 +581,6 @@ impl RmcpClient { params: InitializeRequestParams, timeout: Option, send_elicitation: SendElicitation, - request_headers: Option, ) -> Result { let client_handler = LoggingClientHandler::new(params.clone(), send_elicitation); let pending_transport = { @@ -618,8 +594,6 @@ impl RmcpClient { } }; - let _request_headers_guard = - RequestHeadersGuard::set(self.request_headers.clone(), request_headers); let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport(pending_transport, client_handler.clone(), timeout) .await?; @@ -676,11 +650,8 @@ impl RmcpClient { &self, params: Option, timeout: Option, - request_headers: Option, ) -> Result { self.refresh_oauth_if_needed().await; - let _request_headers_guard = - RequestHeadersGuard::set(self.request_headers.clone(), request_headers); let result = self .run_service_operation("tools/list", timeout, move |service| { let params = params.clone(); @@ -774,7 +745,6 @@ impl RmcpClient { arguments: Option, meta: Option, timeout: Option, - request_headers: Option, ) -> Result { self.refresh_oauth_if_needed().await; let arguments = match arguments { @@ -801,8 +771,6 @@ impl RmcpClient { arguments, task: None, }; - let _request_headers_guard = - RequestHeadersGuard::set(self.request_headers.clone(), request_headers); let result = self .run_service_operation("tools/call", timeout, move |service| { let rmcp_params = rmcp_params.clone(); diff --git a/codex-rs/rmcp-client/tests/resources.rs b/codex-rs/rmcp-client/tests/resources.rs index 4a3e587d11bb..ba1a8e43106a 100644 --- a/codex-rs/rmcp-client/tests/resources.rs +++ b/codex-rs/rmcp-client/tests/resources.rs @@ -78,7 +78,6 @@ async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> { } .boxed() }), - /*request_headers*/ None, ) .await?; diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index 6512e248693a..fb2fc96d20f1 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -94,7 +94,6 @@ async fn create_client(base_url: &str) -> anyhow::Result { } .boxed() }), - /*request_headers*/ None, ) .await?; @@ -108,7 +107,6 @@ async fn call_echo_tool(client: &RmcpClient, message: &str) -> anyhow::Result Date: Wed, 18 Mar 2026 17:12:13 -0700 Subject: [PATCH 5/6] codex: fix MCP header CI regressions Add the new request_headers test fixture state, update the streamable HTTP recovery test to pass the new client constructor argument, and satisfy the argument comment lint for clearing turn-scoped MCP request headers. Co-authored-by: Codex --- codex-rs/core/src/codex.rs | 5 ++++- codex-rs/core/src/mcp_connection_manager_tests.rs | 5 +++++ codex-rs/rmcp-client/tests/streamable_http_recovery.rs | 3 +++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index db0e4472e9e7..b6f7505d1b63 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -3986,7 +3986,10 @@ impl Session { .mcp_connection_manager .read() .await - .set_request_headers_for_server(crate::mcp::CODEX_APPS_MCP_SERVER_NAME, None); + .set_request_headers_for_server( + crate::mcp::CODEX_APPS_MCP_SERVER_NAME, + /*request_headers*/ None, + ); } pub(crate) async fn parse_mcp_tool_name( diff --git a/codex-rs/core/src/mcp_connection_manager_tests.rs b/codex-rs/core/src/mcp_connection_manager_tests.rs index c5f7fc4a4086..9401b379bcbf 100644 --- a/codex-rs/core/src/mcp_connection_manager_tests.rs +++ b/codex-rs/core/src/mcp_connection_manager_tests.rs @@ -4,6 +4,7 @@ use codex_protocol::protocol::McpAuthStatus; use rmcp::model::JsonObject; use std::collections::HashSet; use std::sync::Arc; +use std::sync::Mutex as StdMutex; use tempfile::tempdir; fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { @@ -413,6 +414,7 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() { CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(startup_tools), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -438,6 +440,7 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot( CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: None, startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -460,6 +463,7 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty( CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(Vec::new()), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -492,6 +496,7 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() { CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: failed_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(startup_tools), startup_complete, tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index fb2fc96d20f1..8b03da8f1ad6 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -1,5 +1,7 @@ use std::net::TcpListener; use std::path::PathBuf; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; use std::time::Duration; use std::time::Instant; @@ -77,6 +79,7 @@ async fn create_client(base_url: &str) -> anyhow::Result { None, None, OAuthCredentialsStoreMode::File, + Arc::new(StdMutex::new(None)), ) .await?; From 55be8ca8029b75a6d96a8930730fabf8cec22cba Mon Sep 17 00:00:00 2001 From: ychhabria Date: Wed, 18 Mar 2026 18:01:31 -0700 Subject: [PATCH 6/6] Refresh codex apps MCP headers at tool call time --- codex-rs/core/src/codex.rs | 6 + codex-rs/core/src/codex_tests.rs | 99 ++++++++++ codex-rs/core/src/mcp_connection_manager.rs | 34 ++++ codex-rs/core/src/turn_metadata.rs | 8 + codex-rs/core/tests/suite/search_tool.rs | 201 ++++++++++++++++++++ 5 files changed, 348 insertions(+) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index b6f7505d1b63..fcd827fddc98 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -3945,6 +3945,12 @@ impl Session { arguments: Option, meta: Option, ) -> anyhow::Result { + if server == CODEX_APPS_MCP_SERVER_NAME + && let Some((turn_context, _)) = self.active_turn_context_and_cancellation_token().await + { + self.sync_mcp_request_headers_for_turn(turn_context.as_ref()) + .await; + } self.services .mcp_connection_manager .read() diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs index f767c05f4eb0..22e83ee216cc 100644 --- a/codex-rs/core/src/codex_tests.rs +++ b/codex-rs/core/src/codex_tests.rs @@ -26,6 +26,8 @@ use codex_protocol::protocol::ReadOnlyAccess; use codex_protocol::protocol::SandboxPolicy; use codex_protocol::request_permissions::PermissionGrantScope; use codex_protocol::request_permissions::RequestPermissionProfile; +use codex_protocol::user_input::UserInput; +use reqwest::header::HeaderValue; use tracing::Span; use crate::protocol::CompactedItem; @@ -56,6 +58,7 @@ use crate::tools::handlers::UnifiedExecHandler; use crate::tools::registry::ToolHandler; use crate::tools::router::ToolCallSource; use crate::turn_diff_tracker::TurnDiffTracker; +use async_trait::async_trait; use codex_app_server_protocol::AppInfo; use codex_execpolicy::Decision; use codex_execpolicy::NetworkRuleProtocol; @@ -78,7 +81,10 @@ use opentelemetry::trace::TracerProvider as _; use opentelemetry_sdk::trace::SdkTracerProvider; use std::path::Path; use std::time::Duration; +use tokio::sync::Notify; use tokio::time::sleep; +use tokio_util::sync::CancellationToken; +use tokio_util::task::AbortOnDropHandle; use tracing_opentelemetry::OpenTelemetrySpanExt; use tracing_subscriber::prelude::*; @@ -2716,6 +2722,99 @@ async fn request_permissions_is_auto_denied_when_granular_policy_blocks_tool_req ); } +struct NoopTask; + +#[async_trait] +impl SessionTask for NoopTask { + fn kind(&self) -> TaskKind { + TaskKind::Regular + } + + fn span_name(&self) -> &'static str { + "noop" + } + + async fn run( + self: Arc, + _ctx: Arc, + _turn_context: Arc, + _input: Vec, + _cancellation_token: CancellationToken, + ) -> Option { + None + } +} + +#[tokio::test] +async fn call_tool_refreshes_codex_apps_request_headers_from_active_turn() { + let (session, turn_context) = make_session_and_context().await; + let session = Arc::new(session); + let turn_context = Arc::new(turn_context); + + session + .services + .mcp_connection_manager + .write() + .await + .register_test_server_for_request_headers(CODEX_APPS_MCP_SERVER_NAME); + + session + .sync_mcp_request_headers_for_turn(turn_context.as_ref()) + .await; + let base_header = turn_context + .turn_metadata_state + .current_header_value() + .expect("base turn metadata header"); + + let updated_header = serde_json::json!({ + "turn_id": turn_context.sub_id, + "sandbox": "test", + "workspaces": { + "/tmp/repo": { + "has_changes": true + } + } + }) + .to_string(); + turn_context + .turn_metadata_state + .set_enriched_header_for_tests(Some(updated_header.clone())); + + let mut active_turn = ActiveTurn::default(); + let handle = tokio::spawn(async {}); + active_turn.add_task(crate::state::RunningTask { + done: Arc::new(Notify::new()), + kind: TaskKind::Regular, + task: Arc::new(NoopTask), + cancellation_token: CancellationToken::new(), + handle: Arc::new(AbortOnDropHandle::new(handle)), + turn_context: Arc::clone(&turn_context), + _timer: None, + }); + *session.active_turn.lock().await = Some(active_turn); + + let _err = session + .call_tool(CODEX_APPS_MCP_SERVER_NAME, "echo", None, None) + .await + .expect_err("test server is not initialized"); + + let headers = session + .services + .mcp_connection_manager + .read() + .await + .request_headers_for_server(CODEX_APPS_MCP_SERVER_NAME) + .expect("request headers should be tracked for codex apps"); + assert_eq!( + headers.get(crate::X_CODEX_TURN_METADATA_HEADER), + Some(&HeaderValue::from_str(&updated_header).expect("valid enriched header")), + ); + assert_ne!( + headers.get(crate::X_CODEX_TURN_METADATA_HEADER), + Some(&HeaderValue::from_str(&base_header).expect("valid base header")), + ); +} + #[tokio::test] async fn submit_with_id_captures_current_span_trace_context() { let (session, _turn_context) = make_session_and_context().await; diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 7c8a34307022..cd80220d716b 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -636,6 +636,40 @@ impl McpConnectionManager { Self::new_uninitialized(approval_policy) } + #[cfg(test)] + pub(crate) fn register_test_server_for_request_headers(&mut self, server_name: &str) { + let failed_client = futures::future::ready::>( + Err(StartupOutcomeError::Failed { + error: "test request headers stub".to_string(), + }), + ) + .boxed() + .shared(); + self.clients.insert( + server_name.to_string(), + AsyncManagedClient { + client: failed_client, + request_headers: Arc::new(StdMutex::new(None)), + startup_snapshot: None, + startup_complete: Arc::new(AtomicBool::new(true)), + tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), + }, + ); + } + + #[cfg(test)] + pub(crate) fn request_headers_for_server( + &self, + server_name: &str, + ) -> Option { + let client = self.clients.get(server_name)?; + client + .request_headers + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + } + pub(crate) fn has_servers(&self) -> bool { !self.clients.is_empty() } diff --git a/codex-rs/core/src/turn_metadata.rs b/codex-rs/core/src/turn_metadata.rs index c0298c522122..105675272869 100644 --- a/codex-rs/core/src/turn_metadata.rs +++ b/codex-rs/core/src/turn_metadata.rs @@ -217,6 +217,14 @@ impl TurnMetadataState { } } + #[cfg(test)] + pub(crate) fn set_enriched_header_for_tests(&self, header: Option) { + *self + .enriched_header + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) = header; + } + async fn fetch_workspace_git_metadata(&self) -> WorkspaceGitMetadata { let (latest_git_commit_hash, associated_remote_urls, has_changes) = tokio::join!( get_head_commit_hash(&self.cwd), diff --git a/codex-rs/core/tests/suite/search_tool.rs b/codex-rs/core/tests/suite/search_tool.rs index 118f1bd585c9..62659aa4de68 100644 --- a/codex-rs/core/tests/suite/search_tool.rs +++ b/codex-rs/core/tests/suite/search_tool.rs @@ -30,6 +30,8 @@ use core_test_support::wait_for_event; use pretty_assertions::assert_eq; use serde_json::Value; use serde_json::json; +use std::fs; +use std::process::Command; const SEARCH_TOOL_DESCRIPTION_SNIPPETS: [&str; 2] = [ "You have access to all the tools of the following apps/connectors", @@ -86,6 +88,15 @@ fn tool_search_output_tools(request: &ResponsesRequest, call_id: &str) -> Vec Option { + request + .body_json::() + .ok()? + .get("method") + .and_then(Value::as_str) + .map(str::to_string) +} + fn configure_apps(config: &mut Config, apps_base_url: &str) { config .features @@ -499,5 +510,195 @@ async fn tool_search_returns_deferred_tools_without_follow_up_tool_injection() - "post-tool follow-up should still rely on tool_search_output history, not tool injection: {third_request_tools:?}" ); + let mcp_requests = server + .received_requests() + .await + .expect("failed to fetch recorded requests"); + let tools_list_request = mcp_requests + .iter() + .find(|request| json_rpc_method(request).as_deref() == Some("tools/list")) + .expect("tools/list MCP request"); + assert!( + tools_list_request + .headers + .get("x-codex-turn-metadata") + .is_none(), + "tools/list should not include per-turn MCP headers" + ); + + let tools_call_request = mcp_requests + .iter() + .find(|request| json_rpc_method(request).as_deref() == Some("tools/call")) + .expect("tools/call MCP request"); + let session_id_header = tools_call_request + .headers + .get("session_id") + .expect("tools/call session_id header"); + let request_id_header = tools_call_request + .headers + .get("x-client-request-id") + .expect("tools/call x-client-request-id header"); + let turn_metadata_header = tools_call_request + .headers + .get("x-codex-turn-metadata") + .expect("tools/call turn metadata header"); + assert_eq!( + session_id_header + .to_str() + .expect("session_id header to be utf8"), + request_id_header + .to_str() + .expect("x-client-request-id header to be utf8") + ); + assert!( + turn_metadata_header + .to_str() + .expect("turn metadata header to be utf8") + .contains("\"turn_id\""), + "expected turn metadata header to contain serialized turn metadata" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn apps_mcp_tool_call_uses_enriched_turn_metadata_header() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let apps_server = AppsTestServer::mount_searchable(&server).await?; + let call_id = "tool-search-git-metadata"; + let mock = mount_sse_sequence( + &server, + vec![ + sse(vec![ + ev_response_created("resp-1"), + ev_tool_search_call( + call_id, + &json!({ + "query": "create calendar event", + "limit": 1, + }), + ), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_response_created("resp-2"), + json!({ + "type": "response.output_item.done", + "item": { + "type": "function_call", + "call_id": "calendar-call-git-metadata", + "name": SEARCH_CALENDAR_CREATE_TOOL, + "namespace": SEARCH_CALENDAR_NAMESPACE, + "arguments": serde_json::to_string(&json!({ + "title": "Lunch", + "starts_at": "2026-03-10T12:00:00Z" + })).expect("serialize calendar args") + } + }), + ev_completed("resp-2"), + ]), + sse(vec![ + ev_response_created("resp-3"), + ev_assistant_message("msg-1", "done"), + ev_completed("resp-3"), + ]), + ], + ) + .await; + + let mut builder = configured_builder(apps_server.chatgpt_base_url.clone()); + let test = builder.build(&server).await?; + + let cwd = test.cwd_path().to_path_buf(); + assert!( + Command::new("git") + .arg("init") + .current_dir(&cwd) + .status()? + .success() + ); + assert!( + Command::new("git") + .args(["config", "user.name", "Codex Test"]) + .current_dir(&cwd) + .status()? + .success() + ); + assert!( + Command::new("git") + .args(["config", "user.email", "codex@example.com"]) + .current_dir(&cwd) + .status()? + .success() + ); + assert!( + Command::new("git") + .args(["remote", "add", "origin", "https://example.test/repo.git"]) + .current_dir(&cwd) + .status()? + .success() + ); + for idx in 0..400 { + fs::write( + cwd.join(format!("file-{idx:04}.txt")), + format!("fixture file {idx}\n"), + )?; + } + assert!( + Command::new("git") + .args(["add", "."]) + .current_dir(&cwd) + .status()? + .success() + ); + assert!( + Command::new("git") + .args(["commit", "-m", "init"]) + .current_dir(&cwd) + .status()? + .success() + ); + + test.codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "Find the calendar create tool".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + + wait_for_event(&test.codex, |event| { + matches!(event, EventMsg::TurnComplete(_)) + }) + .await; + + let _requests = mock.requests(); + let mcp_requests = server + .received_requests() + .await + .expect("failed to fetch recorded requests"); + let tools_call_request = mcp_requests + .iter() + .find(|request| json_rpc_method(request).as_deref() == Some("tools/call")) + .expect("tools/call MCP request"); + let turn_metadata_header = tools_call_request + .headers + .get("x-codex-turn-metadata") + .expect("tools/call turn metadata header") + .to_str() + .expect("turn metadata header to be utf8"); + let parsed: Value = serde_json::from_str(turn_metadata_header)?; + assert!( + parsed + .get("workspaces") + .and_then(Value::as_object) + .is_some_and(|workspaces| !workspaces.is_empty()), + "expected enriched MCP turn metadata header with workspace git metadata, got {parsed:#?}" + ); + Ok(()) }