From bccce0f2d8e94a961dfb90c47ed7855010118136 Mon Sep 17 00:00:00 2001 From: nicholasclark-openai Date: Tue, 17 Mar 2026 18:14:32 -0700 Subject: [PATCH 1/5] Forward tool call task headers to MCP HTTP requests Co-authored-by: Codex --- codex-rs/core/src/codex.rs | 3 +- codex-rs/core/src/codex_delegate.rs | 1 + codex-rs/core/src/mcp_connection_manager.rs | 137 ++++++++++++++---- .../core/src/mcp_connection_manager_tests.rs | 5 + codex-rs/core/src/mcp_tool_call.rs | 69 ++++++++- codex-rs/rmcp-client/src/rmcp_client.rs | 106 +++++++++++++- codex-rs/rmcp-client/tests/resources.rs | 1 + .../tests/streamable_http_recovery.rs | 2 + 8 files changed, 283 insertions(+), 41 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 9382401bacdf..45d7486248f1 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -3912,12 +3912,13 @@ impl Session { tool: &str, arguments: Option, meta: Option, + request_headers: Option, ) -> anyhow::Result { self.services .mcp_connection_manager .read() .await - .call_tool(server, tool, arguments, meta) + .call_tool(server, tool, arguments, meta, request_headers) .await } diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index 4369b81dff3b..6bbc4f01bec6 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -669,6 +669,7 @@ async fn maybe_auto_review_mcp_request_user_input( parent_ctx.as_ref(), &invocation.server, &invocation.tool, + None, ) .await; let review_cancel = cancel_token.child_token(); diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 938d6d0b2bf3..939affa9dff4 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -378,7 +378,7 @@ struct ManagedClient { } impl ManagedClient { - fn listed_tools(&self) -> Vec { + fn listed_tools(&self, _request_headers: Option) -> Vec { let total_start = Instant::now(); if let Some(cache_context) = self.codex_apps_tools_cache_context.as_ref() && let CachedCodexAppsToolsLoad::Hit(tools) = @@ -425,9 +425,42 @@ struct AsyncManagedClient { client: Shared>>, startup_snapshot: Option>, startup_complete: Arc, + startup_request_headers: Arc>>, tool_plugin_provenance: Arc, } +struct StartupRequestHeadersGuard { + state: Arc>>, + previous: Option, +} + +impl StartupRequestHeadersGuard { + fn set( + state: Arc>>, + headers: Option, + ) -> Self { + let previous = { + let mut guard = state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let previous = guard.clone(); + *guard = headers; + previous + }; + Self { state, previous } + } +} + +impl Drop for StartupRequestHeadersGuard { + fn drop(&mut self) { + let mut guard = self + .state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *guard = self.previous.clone(); + } +} + impl AsyncManagedClient { // Keep this constructor flat so the startup inputs remain readable at the // single call site instead of introducing a one-off params wrapper. @@ -451,6 +484,8 @@ impl AsyncManagedClient { let startup_tool_filter = tool_filter; let startup_complete = Arc::new(AtomicBool::new(false)); let startup_complete_for_fut = Arc::clone(&startup_complete); + let startup_request_headers = Arc::new(StdMutex::new(None)); + let startup_request_headers_for_fut = Arc::clone(&startup_request_headers); let fut = async move { let outcome = async { if let Err(error) = validate_mcp_server_name(&server_name) { @@ -471,6 +506,7 @@ impl AsyncManagedClient { tx_event, elicitation_requests, codex_apps_tools_cache_context, + startup_request_headers: startup_request_headers_for_fut, }, ) .or_cancel(&cancel_token) @@ -497,11 +533,17 @@ impl AsyncManagedClient { client, startup_snapshot, startup_complete, + startup_request_headers, tool_plugin_provenance, } } - async fn client(&self) -> Result { + async fn client( + &self, + request_headers: Option, + ) -> Result { + let _request_headers_guard = + StartupRequestHeadersGuard::set(self.startup_request_headers.clone(), request_headers); self.client.clone().await } @@ -512,7 +554,10 @@ impl AsyncManagedClient { None } - async fn listed_tools(&self) -> Option> { + async fn listed_tools( + &self, + request_headers: Option, + ) -> Option> { let annotate_tools = |tools: Vec| { let mut tools = tools; for tool in &mut tools { @@ -564,8 +609,8 @@ impl AsyncManagedClient { let tools = if let Some(startup_tools) = self.startup_snapshot_while_initializing() { Some(startup_tools) } else { - match self.client().await { - Ok(client) => Some(client.listed_tools()), + match self.client(request_headers.clone()).await { + Ok(client) => Some(client.listed_tools(request_headers)), Err(_) => self.startup_snapshot.clone(), } }; @@ -573,7 +618,7 @@ impl AsyncManagedClient { } async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { - let managed = self.client().await?; + let managed = self.client(None).await?; managed.notify_sandbox_state_change(sandbox_state).await } } @@ -686,7 +731,7 @@ impl McpConnectionManager { let auth_entry = auth_entries.get(&server_name).cloned(); let sandbox_state = initial_sandbox_state.clone(); join_set.spawn(async move { - let outcome = async_managed_client.client().await; + let outcome = async_managed_client.client(None).await; if cancel_token.is_cancelled() { return (server_name, Err(StartupOutcomeError::Cancelled)); } @@ -755,11 +800,15 @@ impl McpConnectionManager { (manager, cancel_token) } - async fn client_by_name(&self, name: &str) -> Result { + async fn client_by_name( + &self, + name: &str, + request_headers: Option, + ) -> Result { self.clients .get(name) .ok_or_else(|| anyhow!("unknown MCP server '{name}'"))? - .client() + .client(request_headers) .await .context("failed to get client") } @@ -780,7 +829,7 @@ impl McpConnectionManager { return false; }; - match tokio::time::timeout(timeout, async_managed_client.client()).await { + match tokio::time::timeout(timeout, async_managed_client.client(None)).await { Ok(Ok(_)) => true, Ok(Err(_)) | Err(_) => false, } @@ -800,7 +849,7 @@ impl McpConnectionManager { continue; }; - match async_managed_client.client().await { + match async_managed_client.client(None).await { Ok(_) => {} Err(error) => failures.push(McpStartupFailure { server: server_name.clone(), @@ -815,9 +864,18 @@ impl McpConnectionManager { /// fully-qualified name for the tool. #[instrument(level = "trace", skip_all)] pub async fn list_all_tools(&self) -> HashMap { + self.list_all_tools_with_request_headers(None).await + } + + #[instrument(level = "trace", skip_all)] + pub async fn list_all_tools_with_request_headers( + &self, + request_headers: Option, + ) -> HashMap { let mut tools = HashMap::new(); for managed_client in self.clients.values() { - let Some(server_tools) = managed_client.listed_tools().await else { + let Some(server_tools) = managed_client.listed_tools(request_headers.clone()).await + else { continue; }; tools.extend(qualify_tools(server_tools)); @@ -835,7 +893,7 @@ impl McpConnectionManager { .clients .get(CODEX_APPS_MCP_SERVER_NAME) .ok_or_else(|| anyhow!("unknown MCP server '{CODEX_APPS_MCP_SERVER_NAME}'"))? - .client() + .client(None) .await .context("failed to get client")?; @@ -845,6 +903,7 @@ impl McpConnectionManager { CODEX_APPS_MCP_SERVER_NAME, &managed_client.client, managed_client.tool_timeout, + None, ) .await .with_context(|| { @@ -881,7 +940,7 @@ impl McpConnectionManager { for (server_name, async_managed_client) in clients_snapshot { let server_name = server_name.clone(); - let Ok(managed_client) = async_managed_client.client().await else { + let Ok(managed_client) = async_managed_client.client(None).await else { continue; }; let timeout = managed_client.tool_timeout; @@ -947,7 +1006,7 @@ impl McpConnectionManager { for (server_name, async_managed_client) in clients_snapshot { let server_name_cloned = server_name.clone(); - let Ok(managed_client) = async_managed_client.client().await else { + let Ok(managed_client) = async_managed_client.client(None).await else { continue; }; let client = managed_client.client.clone(); @@ -1015,8 +1074,9 @@ impl McpConnectionManager { tool: &str, arguments: Option, meta: Option, + request_headers: Option, ) -> Result { - let client = self.client_by_name(server).await?; + let client = self.client_by_name(server, request_headers.clone()).await?; if !client.tool_filter.allows(tool) { return Err(anyhow!( "tool '{tool}' is disabled for MCP server '{server}'" @@ -1025,7 +1085,13 @@ impl McpConnectionManager { let result: rmcp::model::CallToolResult = client .client - .call_tool(tool.to_string(), arguments, meta, client.tool_timeout) + .call_tool( + tool.to_string(), + arguments, + meta, + client.tool_timeout, + request_headers, + ) .await .with_context(|| format!("tool call failed for `{server}/{tool}`"))?; @@ -1052,7 +1118,7 @@ impl McpConnectionManager { server: &str, params: Option, ) -> Result { - let managed = self.client_by_name(server).await?; + let managed = self.client_by_name(server, None).await?; let timeout = managed.tool_timeout; managed @@ -1068,7 +1134,7 @@ impl McpConnectionManager { server: &str, params: Option, ) -> Result { - let managed = self.client_by_name(server).await?; + let managed = self.client_by_name(server, None).await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; @@ -1084,7 +1150,7 @@ impl McpConnectionManager { server: &str, params: ReadResourceRequestParams, ) -> Result { - let managed = self.client_by_name(server).await?; + let managed = self.client_by_name(server, None).await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; let uri = params.uri.clone(); @@ -1344,6 +1410,7 @@ async fn start_server_task( tx_event, elicitation_requests, codex_apps_tools_cache_context, + startup_request_headers, } = params; let elicitation = elicitation_capability_for_server(&server_name); let params = InitializeRequestParams { @@ -1369,16 +1436,34 @@ async fn start_server_task( let send_elicitation = elicitation_requests.make_sender(server_name.clone(), tx_event); + let initialize_request_headers = startup_request_headers + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone(); let initialize_result = client - .initialize(params, startup_timeout, send_elicitation) + .initialize( + params, + startup_timeout, + send_elicitation, + initialize_request_headers, + ) .await .map_err(StartupOutcomeError::from)?; let list_start = Instant::now(); let fetch_start = Instant::now(); - let tools = list_tools_for_client_uncached(&server_name, &client, startup_timeout) - .await - .map_err(StartupOutcomeError::from)?; + let list_request_headers = startup_request_headers + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone(); + let tools = list_tools_for_client_uncached( + &server_name, + &client, + startup_timeout, + list_request_headers, + ) + .await + .map_err(StartupOutcomeError::from)?; emit_duration( MCP_TOOLS_FETCH_UNCACHED_DURATION_METRIC, fetch_start.elapsed(), @@ -1423,6 +1508,7 @@ struct StartServerTaskParams { tx_event: Sender, elicitation_requests: ElicitationRequestManager, codex_apps_tools_cache_context: Option, + startup_request_headers: Arc>>, } async fn make_rmcp_client( @@ -1584,9 +1670,10 @@ async fn list_tools_for_client_uncached( server_name: &str, client: &Arc, timeout: Option, + request_headers: Option, ) -> Result> { let resp = client - .list_tools_with_connector_ids(/*params*/ None, timeout) + .list_tools_with_connector_ids(/*params*/ None, timeout, request_headers) .await?; let tools = resp .tools diff --git a/codex-rs/core/src/mcp_connection_manager_tests.rs b/codex-rs/core/src/mcp_connection_manager_tests.rs index c5f7fc4a4086..980f4091a829 100644 --- a/codex-rs/core/src/mcp_connection_manager_tests.rs +++ b/codex-rs/core/src/mcp_connection_manager_tests.rs @@ -4,6 +4,7 @@ use codex_protocol::protocol::McpAuthStatus; use rmcp::model::JsonObject; use std::collections::HashSet; use std::sync::Arc; +use std::sync::Mutex as StdMutex; use tempfile::tempdir; fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { @@ -415,6 +416,7 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() { client: pending_client, startup_snapshot: Some(startup_tools), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), + startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); @@ -440,6 +442,7 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot( client: pending_client, startup_snapshot: None, startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), + startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); @@ -462,6 +465,7 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty( client: pending_client, startup_snapshot: Some(Vec::new()), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), + startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); @@ -494,6 +498,7 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() { client: failed_client, startup_snapshot: Some(startup_tools), startup_complete, + startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs index 06d801cbac8c..cfa6b471db26 100644 --- a/codex-rs/core/src/mcp_tool_call.rs +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -45,6 +45,8 @@ use codex_protocol::request_user_input::RequestUserInputQuestionOption; use codex_protocol::request_user_input::RequestUserInputResponse; use codex_rmcp_client::ElicitationAction; use codex_rmcp_client::ElicitationResponse; +use reqwest::header::HeaderMap; +use reqwest::header::HeaderValue; use rmcp::model::ToolAnnotations; use serde::Serialize; use std::path::Path; @@ -81,8 +83,15 @@ pub(crate) async fn handle_mcp_tool_call( arguments: arguments_value.clone(), }; - let metadata = - lookup_mcp_tool_metadata(sess.as_ref(), turn_context.as_ref(), &server, &tool_name).await; + let request_headers = build_mcp_request_headers(sess.as_ref(), turn_context.as_ref(), &server); + let metadata = lookup_mcp_tool_metadata( + sess.as_ref(), + turn_context.as_ref(), + &server, + &tool_name, + request_headers.clone(), + ) + .await; let app_tool_policy = if server == CODEX_APPS_MCP_SERVER_NAME { connectors::app_tool_policy( &turn_context.config, @@ -150,6 +159,7 @@ pub(crate) async fn handle_mcp_tool_call( &tool_name, arguments_value.clone(), request_meta.clone(), + request_headers.clone(), ) .await .map_err(|e| format!("tool call error: {e:?}")); @@ -180,6 +190,7 @@ pub(crate) async fn handle_mcp_tool_call( turn_context.as_ref(), &server, &tool_name, + request_headers.clone(), ) .await; result @@ -236,7 +247,13 @@ pub(crate) async fn handle_mcp_tool_call( let start = Instant::now(); // Perform the tool call. let result = sess - .call_tool(&server, &tool_name, arguments_value.clone(), request_meta) + .call_tool( + &server, + &tool_name, + arguments_value.clone(), + request_meta, + request_headers.clone(), + ) .await .map_err(|e| format!("tool call error: {e:?}")); let result = sanitize_mcp_tool_result_for_model( @@ -262,7 +279,14 @@ pub(crate) async fn handle_mcp_tool_call( tool_call_end_event.clone(), ) .await; - maybe_track_codex_app_used(sess.as_ref(), turn_context.as_ref(), &server, &tool_name).await; + maybe_track_codex_app_used( + sess.as_ref(), + turn_context.as_ref(), + &server, + &tool_name, + request_headers.clone(), + ) + .await; let status = if result.is_ok() { "ok" } else { "error" }; turn_context @@ -333,11 +357,12 @@ async fn maybe_track_codex_app_used( turn_context: &TurnContext, server: &str, tool_name: &str, + request_headers: Option, ) { if server != CODEX_APPS_MCP_SERVER_NAME { return; } - let metadata = lookup_mcp_app_usage_metadata(sess, server, tool_name).await; + let metadata = lookup_mcp_app_usage_metadata(sess, server, tool_name, request_headers).await; let (connector_id, app_name) = metadata .map(|metadata| (metadata.connector_id, metadata.app_name)) .unwrap_or((None, None)); @@ -389,6 +414,34 @@ pub(crate) struct McpToolApprovalMetadata { const MCP_TOOL_CODEX_APPS_META_KEY: &str = "_codex_apps"; +fn build_mcp_request_headers( + sess: &Session, + turn_context: &TurnContext, + server: &str, +) -> Option { + if server != CODEX_APPS_MCP_SERVER_NAME { + return None; + } + + let session_id = sess.conversation_id.to_string(); + let mut headers = HeaderMap::new(); + if let Ok(value) = HeaderValue::from_str(&session_id) { + headers.insert("session_id", value.clone()); + headers.insert("x-client-request-id", value); + } + if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value() + && let Ok(value) = HeaderValue::from_str(&turn_metadata) + { + headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value); + } + + if headers.is_empty() { + None + } else { + Some(headers) + } +} + fn build_mcp_tool_call_request_meta( server: &str, metadata: Option<&McpToolApprovalMetadata>, @@ -741,13 +794,14 @@ pub(crate) async fn lookup_mcp_tool_metadata( turn_context: &TurnContext, server: &str, tool_name: &str, + request_headers: Option, ) -> Option { let tools = sess .services .mcp_connection_manager .read() .await - .list_all_tools() + .list_all_tools_with_request_headers(request_headers.clone()) .await; let tool_info = tools @@ -798,13 +852,14 @@ async fn lookup_mcp_app_usage_metadata( sess: &Session, server: &str, tool_name: &str, + request_headers: Option, ) -> Option { let tools = sess .services .mcp_connection_manager .read() .await - .list_all_tools() + .list_all_tools_with_request_headers(request_headers) .await; tools.into_values().find_map(|tool_info| { diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index b898403b25c7..42aeb149ad6d 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -5,6 +5,7 @@ use std::io; use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; +use std::sync::Mutex as StdMutex; use std::time::Duration; use anyhow::Result; @@ -82,15 +83,37 @@ const JSON_MIME_TYPE: &str = "application/json"; const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192; +fn apply_request_scoped_headers( + mut request: reqwest::RequestBuilder, + request_headers_state: &Arc>>, +) -> reqwest::RequestBuilder { + let extra_headers = request_headers_state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone(); + if let Some(extra_headers) = extra_headers { + for (name, value) in &extra_headers { + request = request.header(name, value.clone()); + } + } + request +} #[derive(Clone)] struct StreamableHttpResponseClient { inner: reqwest::Client, + request_headers_state: Arc>>, } impl StreamableHttpResponseClient { - fn new(inner: reqwest::Client) -> Self { - Self { inner } + fn new( + inner: reqwest::Client, + request_headers_state: Arc>>, + ) -> Self { + Self { + inner, + request_headers_state, + } } fn reqwest_error( @@ -133,6 +156,7 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(session_id_value) = session_id.as_ref() { request = request.header(HEADER_SESSION_ID, session_id_value.as_ref()); } + request = apply_request_scoped_headers(request, &self.request_headers_state); let response = request .json(&message) @@ -228,6 +252,8 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); } + request_builder = + apply_request_scoped_headers(request_builder, &self.request_headers_state); let response = request_builder .header(HEADER_SESSION_ID, session.as_ref()) .send() @@ -265,6 +291,8 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); } + request_builder = + apply_request_scoped_headers(request_builder, &self.request_headers_state); let response = request_builder .send() @@ -472,6 +500,39 @@ pub struct RmcpClient { transport_recipe: TransportRecipe, initialize_context: Mutex>, session_recovery_lock: Mutex<()>, + request_headers: Option>>>, +} + +struct RequestHeadersGuard { + state: Option>>>, + previous: Option, +} + +impl RequestHeadersGuard { + fn set(state: Option>>>, headers: Option) -> Self { + let previous = if let Some(state_ref) = state.as_ref() { + let mut guard = state_ref + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let previous = guard.clone(); + *guard = headers; + previous + } else { + None + }; + Self { state, previous } + } +} + +impl Drop for RequestHeadersGuard { + fn drop(&mut self) { + if let Some(state) = self.state.as_ref() { + let mut guard = state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *guard = self.previous.clone(); + } + } } impl RmcpClient { @@ -489,7 +550,7 @@ impl RmcpClient { env_vars: env_vars.to_vec(), cwd, }; - let transport = Self::create_pending_transport(&transport_recipe) + let transport = Self::create_pending_transport(&transport_recipe, None) .await .map_err(io::Error::other)?; @@ -500,6 +561,7 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), + request_headers: None, }) } @@ -520,7 +582,9 @@ impl RmcpClient { env_http_headers, store_mode, }; - let transport = Self::create_pending_transport(&transport_recipe).await?; + let request_headers = Some(Arc::new(StdMutex::new(None))); + let transport = + Self::create_pending_transport(&transport_recipe, request_headers.clone()).await?; Ok(Self { state: Mutex::new(ClientState::Connecting { transport: Some(transport), @@ -528,6 +592,7 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), + request_headers, }) } @@ -538,6 +603,7 @@ impl RmcpClient { params: InitializeRequestParams, timeout: Option, send_elicitation: SendElicitation, + request_headers: Option, ) -> Result { let client_handler = LoggingClientHandler::new(params.clone(), send_elicitation); let pending_transport = { @@ -551,6 +617,8 @@ impl RmcpClient { } }; + let _request_headers_guard = + RequestHeadersGuard::set(self.request_headers.clone(), request_headers); let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport(pending_transport, client_handler.clone(), timeout) .await?; @@ -607,8 +675,11 @@ impl RmcpClient { &self, params: Option, timeout: Option, + request_headers: Option, ) -> Result { self.refresh_oauth_if_needed().await; + let _request_headers_guard = + RequestHeadersGuard::set(self.request_headers.clone(), request_headers); let result = self .run_service_operation("tools/list", timeout, move |service| { let params = params.clone(); @@ -702,6 +773,7 @@ impl RmcpClient { arguments: Option, meta: Option, timeout: Option, + request_headers: Option, ) -> Result { self.refresh_oauth_if_needed().await; let arguments = match arguments { @@ -728,6 +800,8 @@ impl RmcpClient { arguments, task: None, }; + let _request_headers_guard = + RequestHeadersGuard::set(self.request_headers.clone(), request_headers); let result = self .run_service_operation("tools/call", timeout, move |service| { let rmcp_params = rmcp_params.clone(); @@ -830,6 +904,7 @@ impl RmcpClient { async fn create_pending_transport( transport_recipe: &TransportRecipe, + request_headers: Option>>>, ) -> Result { match transport_recipe { TransportRecipe::Stdio { @@ -946,7 +1021,12 @@ impl RmcpClient { .auth_header(access_token); let http_client = build_http_client(&default_headers)?; let transport = StreamableHttpClientTransport::with_client( - StreamableHttpResponseClient::new(http_client), + StreamableHttpResponseClient::new( + http_client, + request_headers + .clone() + .unwrap_or_else(|| Arc::new(StdMutex::new(None))), + ), http_config, ); Ok(PendingTransport::StreamableHttp { transport }) @@ -963,7 +1043,12 @@ impl RmcpClient { let http_client = build_http_client(&default_headers)?; let transport = StreamableHttpClientTransport::with_client( - StreamableHttpResponseClient::new(http_client), + StreamableHttpResponseClient::new( + http_client, + request_headers + .clone() + .unwrap_or_else(|| Arc::new(StdMutex::new(None))), + ), http_config, ); Ok(PendingTransport::StreamableHttp { transport }) @@ -1111,7 +1196,9 @@ impl RmcpClient { .await .clone() .ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?; - let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?; + let pending_transport = + Self::create_pending_transport(&self.transport_recipe, self.request_headers.clone()) + .await?; let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport( pending_transport, initialize_context.handler, @@ -1166,7 +1253,10 @@ async fn create_oauth_transport_and_runtime( } }; - let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager); + let auth_client = AuthClient::new( + StreamableHttpResponseClient::new(http_client, Arc::new(StdMutex::new(None))), + manager, + ); let auth_manager = auth_client.auth_manager.clone(); let transport = StreamableHttpClientTransport::with_client( diff --git a/codex-rs/rmcp-client/tests/resources.rs b/codex-rs/rmcp-client/tests/resources.rs index ba1a8e43106a..4a3e587d11bb 100644 --- a/codex-rs/rmcp-client/tests/resources.rs +++ b/codex-rs/rmcp-client/tests/resources.rs @@ -78,6 +78,7 @@ async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> { } .boxed() }), + /*request_headers*/ None, ) .await?; diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index fb2fc96d20f1..6512e248693a 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -94,6 +94,7 @@ async fn create_client(base_url: &str) -> anyhow::Result { } .boxed() }), + /*request_headers*/ None, ) .await?; @@ -107,6 +108,7 @@ async fn call_echo_tool(client: &RmcpClient, message: &str) -> anyhow::Result Date: Wed, 18 Mar 2026 09:12:10 -0700 Subject: [PATCH 2/5] codex: fix CI failure on PR #15011 Co-authored-by: Codex --- codex-rs/rmcp-client/src/rmcp_client.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 42aeb149ad6d..86d14cf881bf 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -550,9 +550,10 @@ impl RmcpClient { env_vars: env_vars.to_vec(), cwd, }; - let transport = Self::create_pending_transport(&transport_recipe, None) - .await - .map_err(io::Error::other)?; + let transport = + Self::create_pending_transport(&transport_recipe, /*request_headers*/ None) + .await + .map_err(io::Error::other)?; Ok(Self { state: Mutex::new(ClientState::Connecting { From de9864340312e1c60307a8c55a965676c6d3db56 Mon Sep 17 00:00:00 2001 From: nicholasclark-openai Date: Wed, 18 Mar 2026 09:41:06 -0700 Subject: [PATCH 3/5] codex: fix CI failure on PR #15011 Co-authored-by: Codex --- codex-rs/core/src/auth_env_telemetry.rs | 1 + codex-rs/core/src/codex_delegate.rs | 2 +- codex-rs/core/src/mcp_connection_manager.rs | 38 ++++++++++++++------- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/codex-rs/core/src/auth_env_telemetry.rs b/codex-rs/core/src/auth_env_telemetry.rs index be281e05a1ff..85cd23fe06f7 100644 --- a/codex-rs/core/src/auth_env_telemetry.rs +++ b/codex-rs/core/src/auth_env_telemetry.rs @@ -71,6 +71,7 @@ mod tests { request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, + websocket_connect_timeout_ms: None, requires_openai_auth: false, supports_websockets: false, }; diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index 6bbc4f01bec6..1838150408d8 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -669,7 +669,7 @@ async fn maybe_auto_review_mcp_request_user_input( parent_ctx.as_ref(), &invocation.server, &invocation.tool, - None, + /*request_headers*/ None, ) .await; let review_cancel = cancel_token.child_token(); diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 939affa9dff4..579a79b433f8 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -618,7 +618,7 @@ impl AsyncManagedClient { } async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { - let managed = self.client(None).await?; + let managed = self.client(/*request_headers*/ None).await?; managed.notify_sandbox_state_change(sandbox_state).await } } @@ -731,7 +731,7 @@ impl McpConnectionManager { let auth_entry = auth_entries.get(&server_name).cloned(); let sandbox_state = initial_sandbox_state.clone(); join_set.spawn(async move { - let outcome = async_managed_client.client(None).await; + let outcome = async_managed_client.client(/*request_headers*/ None).await; if cancel_token.is_cancelled() { return (server_name, Err(StartupOutcomeError::Cancelled)); } @@ -829,7 +829,12 @@ impl McpConnectionManager { return false; }; - match tokio::time::timeout(timeout, async_managed_client.client(None)).await { + match tokio::time::timeout( + timeout, + async_managed_client.client(/*request_headers*/ None), + ) + .await + { Ok(Ok(_)) => true, Ok(Err(_)) | Err(_) => false, } @@ -849,7 +854,7 @@ impl McpConnectionManager { continue; }; - match async_managed_client.client(None).await { + match async_managed_client.client(/*request_headers*/ None).await { Ok(_) => {} Err(error) => failures.push(McpStartupFailure { server: server_name.clone(), @@ -864,7 +869,8 @@ impl McpConnectionManager { /// fully-qualified name for the tool. #[instrument(level = "trace", skip_all)] pub async fn list_all_tools(&self) -> HashMap { - self.list_all_tools_with_request_headers(None).await + self.list_all_tools_with_request_headers(/*request_headers*/ None) + .await } #[instrument(level = "trace", skip_all)] @@ -893,7 +899,7 @@ impl McpConnectionManager { .clients .get(CODEX_APPS_MCP_SERVER_NAME) .ok_or_else(|| anyhow!("unknown MCP server '{CODEX_APPS_MCP_SERVER_NAME}'"))? - .client(None) + .client(/*request_headers*/ None) .await .context("failed to get client")?; @@ -903,7 +909,7 @@ impl McpConnectionManager { CODEX_APPS_MCP_SERVER_NAME, &managed_client.client, managed_client.tool_timeout, - None, + /*request_headers*/ None, ) .await .with_context(|| { @@ -940,7 +946,8 @@ impl McpConnectionManager { for (server_name, async_managed_client) in clients_snapshot { let server_name = server_name.clone(); - let Ok(managed_client) = async_managed_client.client(None).await else { + let Ok(managed_client) = async_managed_client.client(/*request_headers*/ None).await + else { continue; }; let timeout = managed_client.tool_timeout; @@ -1006,7 +1013,8 @@ impl McpConnectionManager { for (server_name, async_managed_client) in clients_snapshot { let server_name_cloned = server_name.clone(); - let Ok(managed_client) = async_managed_client.client(None).await else { + let Ok(managed_client) = async_managed_client.client(/*request_headers*/ None).await + else { continue; }; let client = managed_client.client.clone(); @@ -1118,7 +1126,9 @@ impl McpConnectionManager { server: &str, params: Option, ) -> Result { - let managed = self.client_by_name(server, None).await?; + let managed = self + .client_by_name(server, /*request_headers*/ None) + .await?; let timeout = managed.tool_timeout; managed @@ -1134,7 +1144,9 @@ impl McpConnectionManager { server: &str, params: Option, ) -> Result { - let managed = self.client_by_name(server, None).await?; + let managed = self + .client_by_name(server, /*request_headers*/ None) + .await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; @@ -1150,7 +1162,9 @@ impl McpConnectionManager { server: &str, params: ReadResourceRequestParams, ) -> Result { - let managed = self.client_by_name(server, None).await?; + let managed = self + .client_by_name(server, /*request_headers*/ None) + .await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; let uri = params.uri.clone(); From 732d7ac81fd1ef112dd21714055cc6fc83feb316 Mon Sep 17 00:00:00 2001 From: nicholasclark-openai Date: Wed, 18 Mar 2026 16:42:10 -0700 Subject: [PATCH 4/5] codex: scope MCP request headers to turn lifecycle Set Codex Apps MCP request headers once per active turn and clear them on turn end, instead of threading request-scoped headers through every tool call. Keep RMCP header injection limited to streamable HTTP tools/call requests so list/init paths stay unchanged and concurrent tool calls on the same client are not serialized. Co-authored-by: Codex --- codex-rs/core/src/codex.rs | 41 +++- codex-rs/core/src/codex_delegate.rs | 1 - codex-rs/core/src/mcp_connection_manager.rs | 186 ++++++------------ .../core/src/mcp_connection_manager_tests.rs | 5 - codex-rs/core/src/mcp_tool_call.rs | 69 +------ codex-rs/core/src/tasks/mod.rs | 6 + codex-rs/rmcp-client/src/rmcp_client.rs | 66 ++----- codex-rs/rmcp-client/tests/resources.rs | 1 - .../tests/streamable_http_recovery.rs | 2 - 9 files changed, 127 insertions(+), 250 deletions(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 170b052f2a67..db0e4472e9e7 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -125,6 +125,8 @@ use futures::future::BoxFuture; use futures::future::Shared; use futures::prelude::*; use futures::stream::FuturesOrdered; +use reqwest::header::HeaderMap; +use reqwest::header::HeaderValue; use rmcp::model::ListResourceTemplatesResult; use rmcp::model::ListResourcesResult; use rmcp::model::PaginatedRequestParams; @@ -3942,16 +3944,51 @@ impl Session { tool: &str, arguments: Option, meta: Option, - request_headers: Option, ) -> anyhow::Result { self.services .mcp_connection_manager .read() .await - .call_tool(server, tool, arguments, meta, request_headers) + .call_tool(server, tool, arguments, meta) .await } + pub(crate) async fn sync_mcp_request_headers_for_turn(&self, turn_context: &TurnContext) { + let mut request_headers = HeaderMap::new(); + let session_id = self.conversation_id.to_string(); + if let Ok(value) = HeaderValue::from_str(&session_id) { + request_headers.insert("session_id", value.clone()); + request_headers.insert("x-client-request-id", value); + } + if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value() + && let Ok(value) = HeaderValue::from_str(&turn_metadata) + { + request_headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value); + } + + let request_headers = if request_headers.is_empty() { + None + } else { + Some(request_headers) + }; + self.services + .mcp_connection_manager + .read() + .await + .set_request_headers_for_server( + crate::mcp::CODEX_APPS_MCP_SERVER_NAME, + request_headers, + ); + } + + pub(crate) async fn clear_mcp_request_headers(&self) { + self.services + .mcp_connection_manager + .read() + .await + .set_request_headers_for_server(crate::mcp::CODEX_APPS_MCP_SERVER_NAME, None); + } + pub(crate) async fn parse_mcp_tool_name( &self, name: &str, diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index 084c6f1fb37f..e560cd9c7f15 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -670,7 +670,6 @@ async fn maybe_auto_review_mcp_request_user_input( parent_ctx.as_ref(), &invocation.server, &invocation.tool, - /*request_headers*/ None, ) .await; let review_cancel = cancel_token.child_token(); diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 579a79b433f8..7c8a34307022 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -378,7 +378,7 @@ struct ManagedClient { } impl ManagedClient { - fn listed_tools(&self, _request_headers: Option) -> Vec { + fn listed_tools(&self) -> Vec { let total_start = Instant::now(); if let Some(cache_context) = self.codex_apps_tools_cache_context.as_ref() && let CachedCodexAppsToolsLoad::Hit(tools) = @@ -423,44 +423,12 @@ impl ManagedClient { #[derive(Clone)] struct AsyncManagedClient { client: Shared>>, + request_headers: Arc>>, startup_snapshot: Option>, startup_complete: Arc, - startup_request_headers: Arc>>, tool_plugin_provenance: Arc, } -struct StartupRequestHeadersGuard { - state: Arc>>, - previous: Option, -} - -impl StartupRequestHeadersGuard { - fn set( - state: Arc>>, - headers: Option, - ) -> Self { - let previous = { - let mut guard = state - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - let previous = guard.clone(); - *guard = headers; - previous - }; - Self { state, previous } - } -} - -impl Drop for StartupRequestHeadersGuard { - fn drop(&mut self) { - let mut guard = self - .state - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - *guard = self.previous.clone(); - } -} - impl AsyncManagedClient { // Keep this constructor flat so the startup inputs remain readable at the // single call site instead of introducing a one-off params wrapper. @@ -481,19 +449,26 @@ impl AsyncManagedClient { codex_apps_tools_cache_context.as_ref(), ) .map(|tools| filter_tools(tools, &tool_filter)); + let request_headers = Arc::new(StdMutex::new(None)); let startup_tool_filter = tool_filter; let startup_complete = Arc::new(AtomicBool::new(false)); let startup_complete_for_fut = Arc::clone(&startup_complete); - let startup_request_headers = Arc::new(StdMutex::new(None)); - let startup_request_headers_for_fut = Arc::clone(&startup_request_headers); + let request_headers_for_client = Arc::clone(&request_headers); let fut = async move { let outcome = async { if let Err(error) = validate_mcp_server_name(&server_name) { return Err(error.into()); } - let client = - Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?); + let client = Arc::new( + make_rmcp_client( + &server_name, + config.transport, + store_mode, + request_headers_for_client, + ) + .await?, + ); match start_server_task( server_name, client, @@ -506,7 +481,6 @@ impl AsyncManagedClient { tx_event, elicitation_requests, codex_apps_tools_cache_context, - startup_request_headers: startup_request_headers_for_fut, }, ) .or_cancel(&cancel_token) @@ -531,19 +505,14 @@ impl AsyncManagedClient { Self { client, + request_headers, startup_snapshot, startup_complete, - startup_request_headers, tool_plugin_provenance, } } - async fn client( - &self, - request_headers: Option, - ) -> Result { - let _request_headers_guard = - StartupRequestHeadersGuard::set(self.startup_request_headers.clone(), request_headers); + async fn client(&self) -> Result { self.client.clone().await } @@ -554,10 +523,7 @@ impl AsyncManagedClient { None } - async fn listed_tools( - &self, - request_headers: Option, - ) -> Option> { + async fn listed_tools(&self) -> Option> { let annotate_tools = |tools: Vec| { let mut tools = tools; for tool in &mut tools { @@ -609,8 +575,8 @@ impl AsyncManagedClient { let tools = if let Some(startup_tools) = self.startup_snapshot_while_initializing() { Some(startup_tools) } else { - match self.client(request_headers.clone()).await { - Ok(client) => Some(client.listed_tools(request_headers)), + match self.client().await { + Ok(client) => Some(client.listed_tools()), Err(_) => self.startup_snapshot.clone(), } }; @@ -618,9 +584,17 @@ impl AsyncManagedClient { } async fn notify_sandbox_state_change(&self, sandbox_state: &SandboxState) -> Result<()> { - let managed = self.client(/*request_headers*/ None).await?; + let managed = self.client().await?; managed.notify_sandbox_state_change(sandbox_state).await } + + fn set_request_headers(&self, request_headers: Option) { + let mut guard = self + .request_headers + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *guard = request_headers; + } } pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state"; @@ -731,7 +705,7 @@ impl McpConnectionManager { let auth_entry = auth_entries.get(&server_name).cloned(); let sandbox_state = initial_sandbox_state.clone(); join_set.spawn(async move { - let outcome = async_managed_client.client(/*request_headers*/ None).await; + let outcome = async_managed_client.client().await; if cancel_token.is_cancelled() { return (server_name, Err(StartupOutcomeError::Cancelled)); } @@ -800,15 +774,11 @@ impl McpConnectionManager { (manager, cancel_token) } - async fn client_by_name( - &self, - name: &str, - request_headers: Option, - ) -> Result { + async fn client_by_name(&self, name: &str) -> Result { self.clients .get(name) .ok_or_else(|| anyhow!("unknown MCP server '{name}'"))? - .client(request_headers) + .client() .await .context("failed to get client") } @@ -829,12 +799,7 @@ impl McpConnectionManager { return false; }; - match tokio::time::timeout( - timeout, - async_managed_client.client(/*request_headers*/ None), - ) - .await - { + match tokio::time::timeout(timeout, async_managed_client.client()).await { Ok(Ok(_)) => true, Ok(Err(_)) | Err(_) => false, } @@ -854,7 +819,7 @@ impl McpConnectionManager { continue; }; - match async_managed_client.client(/*request_headers*/ None).await { + match async_managed_client.client().await { Ok(_) => {} Err(error) => failures.push(McpStartupFailure { server: server_name.clone(), @@ -869,19 +834,9 @@ impl McpConnectionManager { /// fully-qualified name for the tool. #[instrument(level = "trace", skip_all)] pub async fn list_all_tools(&self) -> HashMap { - self.list_all_tools_with_request_headers(/*request_headers*/ None) - .await - } - - #[instrument(level = "trace", skip_all)] - pub async fn list_all_tools_with_request_headers( - &self, - request_headers: Option, - ) -> HashMap { let mut tools = HashMap::new(); for managed_client in self.clients.values() { - let Some(server_tools) = managed_client.listed_tools(request_headers.clone()).await - else { + let Some(server_tools) = managed_client.listed_tools().await else { continue; }; tools.extend(qualify_tools(server_tools)); @@ -899,7 +854,7 @@ impl McpConnectionManager { .clients .get(CODEX_APPS_MCP_SERVER_NAME) .ok_or_else(|| anyhow!("unknown MCP server '{CODEX_APPS_MCP_SERVER_NAME}'"))? - .client(/*request_headers*/ None) + .client() .await .context("failed to get client")?; @@ -909,7 +864,6 @@ impl McpConnectionManager { CODEX_APPS_MCP_SERVER_NAME, &managed_client.client, managed_client.tool_timeout, - /*request_headers*/ None, ) .await .with_context(|| { @@ -946,8 +900,7 @@ impl McpConnectionManager { for (server_name, async_managed_client) in clients_snapshot { let server_name = server_name.clone(); - let Ok(managed_client) = async_managed_client.client(/*request_headers*/ None).await - else { + let Ok(managed_client) = async_managed_client.client().await else { continue; }; let timeout = managed_client.tool_timeout; @@ -1013,8 +966,7 @@ impl McpConnectionManager { for (server_name, async_managed_client) in clients_snapshot { let server_name_cloned = server_name.clone(); - let Ok(managed_client) = async_managed_client.client(/*request_headers*/ None).await - else { + let Ok(managed_client) = async_managed_client.client().await else { continue; }; let client = managed_client.client.clone(); @@ -1082,9 +1034,8 @@ impl McpConnectionManager { tool: &str, arguments: Option, meta: Option, - request_headers: Option, ) -> Result { - let client = self.client_by_name(server, request_headers.clone()).await?; + let client = self.client_by_name(server).await?; if !client.tool_filter.allows(tool) { return Err(anyhow!( "tool '{tool}' is disabled for MCP server '{server}'" @@ -1093,13 +1044,7 @@ impl McpConnectionManager { let result: rmcp::model::CallToolResult = client .client - .call_tool( - tool.to_string(), - arguments, - meta, - client.tool_timeout, - request_headers, - ) + .call_tool(tool.to_string(), arguments, meta, client.tool_timeout) .await .with_context(|| format!("tool call failed for `{server}/{tool}`"))?; @@ -1120,15 +1065,23 @@ impl McpConnectionManager { }) } + pub(crate) fn set_request_headers_for_server( + &self, + server_name: &str, + request_headers: Option, + ) { + if let Some(client) = self.clients.get(server_name) { + client.set_request_headers(request_headers); + } + } + /// List resources from the specified server. pub async fn list_resources( &self, server: &str, params: Option, ) -> Result { - let managed = self - .client_by_name(server, /*request_headers*/ None) - .await?; + let managed = self.client_by_name(server).await?; let timeout = managed.tool_timeout; managed @@ -1144,9 +1097,7 @@ impl McpConnectionManager { server: &str, params: Option, ) -> Result { - let managed = self - .client_by_name(server, /*request_headers*/ None) - .await?; + let managed = self.client_by_name(server).await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; @@ -1162,9 +1113,7 @@ impl McpConnectionManager { server: &str, params: ReadResourceRequestParams, ) -> Result { - let managed = self - .client_by_name(server, /*request_headers*/ None) - .await?; + let managed = self.client_by_name(server).await?; let client = managed.client.clone(); let timeout = managed.tool_timeout; let uri = params.uri.clone(); @@ -1424,7 +1373,6 @@ async fn start_server_task( tx_event, elicitation_requests, codex_apps_tools_cache_context, - startup_request_headers, } = params; let elicitation = elicitation_capability_for_server(&server_name); let params = InitializeRequestParams { @@ -1450,34 +1398,16 @@ async fn start_server_task( let send_elicitation = elicitation_requests.make_sender(server_name.clone(), tx_event); - let initialize_request_headers = startup_request_headers - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .clone(); let initialize_result = client - .initialize( - params, - startup_timeout, - send_elicitation, - initialize_request_headers, - ) + .initialize(params, startup_timeout, send_elicitation) .await .map_err(StartupOutcomeError::from)?; let list_start = Instant::now(); let fetch_start = Instant::now(); - let list_request_headers = startup_request_headers - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .clone(); - let tools = list_tools_for_client_uncached( - &server_name, - &client, - startup_timeout, - list_request_headers, - ) - .await - .map_err(StartupOutcomeError::from)?; + let tools = list_tools_for_client_uncached(&server_name, &client, startup_timeout) + .await + .map_err(StartupOutcomeError::from)?; emit_duration( MCP_TOOLS_FETCH_UNCACHED_DURATION_METRIC, fetch_start.elapsed(), @@ -1522,13 +1452,13 @@ struct StartServerTaskParams { tx_event: Sender, elicitation_requests: ElicitationRequestManager, codex_apps_tools_cache_context: Option, - startup_request_headers: Arc>>, } async fn make_rmcp_client( server_name: &str, transport: McpServerTransportConfig, store_mode: OAuthCredentialsStoreMode, + request_headers: Arc>>, ) -> Result { match transport { McpServerTransportConfig::Stdio { @@ -1562,6 +1492,7 @@ async fn make_rmcp_client( http_headers, env_http_headers, store_mode, + request_headers, ) .await .map_err(StartupOutcomeError::from) @@ -1684,10 +1615,9 @@ async fn list_tools_for_client_uncached( server_name: &str, client: &Arc, timeout: Option, - request_headers: Option, ) -> Result> { let resp = client - .list_tools_with_connector_ids(/*params*/ None, timeout, request_headers) + .list_tools_with_connector_ids(/*params*/ None, timeout) .await?; let tools = resp .tools diff --git a/codex-rs/core/src/mcp_connection_manager_tests.rs b/codex-rs/core/src/mcp_connection_manager_tests.rs index 980f4091a829..c5f7fc4a4086 100644 --- a/codex-rs/core/src/mcp_connection_manager_tests.rs +++ b/codex-rs/core/src/mcp_connection_manager_tests.rs @@ -4,7 +4,6 @@ use codex_protocol::protocol::McpAuthStatus; use rmcp::model::JsonObject; use std::collections::HashSet; use std::sync::Arc; -use std::sync::Mutex as StdMutex; use tempfile::tempdir; fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { @@ -416,7 +415,6 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() { client: pending_client, startup_snapshot: Some(startup_tools), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), - startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); @@ -442,7 +440,6 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot( client: pending_client, startup_snapshot: None, startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), - startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); @@ -465,7 +462,6 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty( client: pending_client, startup_snapshot: Some(Vec::new()), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), - startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); @@ -498,7 +494,6 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() { client: failed_client, startup_snapshot: Some(startup_tools), startup_complete, - startup_request_headers: Arc::new(StdMutex::new(None)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), }, ); diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs index cfa6b471db26..06d801cbac8c 100644 --- a/codex-rs/core/src/mcp_tool_call.rs +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -45,8 +45,6 @@ use codex_protocol::request_user_input::RequestUserInputQuestionOption; use codex_protocol::request_user_input::RequestUserInputResponse; use codex_rmcp_client::ElicitationAction; use codex_rmcp_client::ElicitationResponse; -use reqwest::header::HeaderMap; -use reqwest::header::HeaderValue; use rmcp::model::ToolAnnotations; use serde::Serialize; use std::path::Path; @@ -83,15 +81,8 @@ pub(crate) async fn handle_mcp_tool_call( arguments: arguments_value.clone(), }; - let request_headers = build_mcp_request_headers(sess.as_ref(), turn_context.as_ref(), &server); - let metadata = lookup_mcp_tool_metadata( - sess.as_ref(), - turn_context.as_ref(), - &server, - &tool_name, - request_headers.clone(), - ) - .await; + let metadata = + lookup_mcp_tool_metadata(sess.as_ref(), turn_context.as_ref(), &server, &tool_name).await; let app_tool_policy = if server == CODEX_APPS_MCP_SERVER_NAME { connectors::app_tool_policy( &turn_context.config, @@ -159,7 +150,6 @@ pub(crate) async fn handle_mcp_tool_call( &tool_name, arguments_value.clone(), request_meta.clone(), - request_headers.clone(), ) .await .map_err(|e| format!("tool call error: {e:?}")); @@ -190,7 +180,6 @@ pub(crate) async fn handle_mcp_tool_call( turn_context.as_ref(), &server, &tool_name, - request_headers.clone(), ) .await; result @@ -247,13 +236,7 @@ pub(crate) async fn handle_mcp_tool_call( let start = Instant::now(); // Perform the tool call. let result = sess - .call_tool( - &server, - &tool_name, - arguments_value.clone(), - request_meta, - request_headers.clone(), - ) + .call_tool(&server, &tool_name, arguments_value.clone(), request_meta) .await .map_err(|e| format!("tool call error: {e:?}")); let result = sanitize_mcp_tool_result_for_model( @@ -279,14 +262,7 @@ pub(crate) async fn handle_mcp_tool_call( tool_call_end_event.clone(), ) .await; - maybe_track_codex_app_used( - sess.as_ref(), - turn_context.as_ref(), - &server, - &tool_name, - request_headers.clone(), - ) - .await; + maybe_track_codex_app_used(sess.as_ref(), turn_context.as_ref(), &server, &tool_name).await; let status = if result.is_ok() { "ok" } else { "error" }; turn_context @@ -357,12 +333,11 @@ async fn maybe_track_codex_app_used( turn_context: &TurnContext, server: &str, tool_name: &str, - request_headers: Option, ) { if server != CODEX_APPS_MCP_SERVER_NAME { return; } - let metadata = lookup_mcp_app_usage_metadata(sess, server, tool_name, request_headers).await; + let metadata = lookup_mcp_app_usage_metadata(sess, server, tool_name).await; let (connector_id, app_name) = metadata .map(|metadata| (metadata.connector_id, metadata.app_name)) .unwrap_or((None, None)); @@ -414,34 +389,6 @@ pub(crate) struct McpToolApprovalMetadata { const MCP_TOOL_CODEX_APPS_META_KEY: &str = "_codex_apps"; -fn build_mcp_request_headers( - sess: &Session, - turn_context: &TurnContext, - server: &str, -) -> Option { - if server != CODEX_APPS_MCP_SERVER_NAME { - return None; - } - - let session_id = sess.conversation_id.to_string(); - let mut headers = HeaderMap::new(); - if let Ok(value) = HeaderValue::from_str(&session_id) { - headers.insert("session_id", value.clone()); - headers.insert("x-client-request-id", value); - } - if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value() - && let Ok(value) = HeaderValue::from_str(&turn_metadata) - { - headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value); - } - - if headers.is_empty() { - None - } else { - Some(headers) - } -} - fn build_mcp_tool_call_request_meta( server: &str, metadata: Option<&McpToolApprovalMetadata>, @@ -794,14 +741,13 @@ pub(crate) async fn lookup_mcp_tool_metadata( turn_context: &TurnContext, server: &str, tool_name: &str, - request_headers: Option, ) -> Option { let tools = sess .services .mcp_connection_manager .read() .await - .list_all_tools_with_request_headers(request_headers.clone()) + .list_all_tools() .await; let tool_info = tools @@ -852,14 +798,13 @@ async fn lookup_mcp_app_usage_metadata( sess: &Session, server: &str, tool_name: &str, - request_headers: Option, ) -> Option { let tools = sess .services .mcp_connection_manager .read() .await - .list_all_tools_with_request_headers(request_headers) + .list_all_tools() .await; tools.into_values().find_map(|tool_info| { diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index c52e4f91780e..049ed56d45f2 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -153,6 +153,8 @@ impl Session { ) { self.abort_all_tasks(TurnAbortReason::Replaced).await; self.clear_connector_selection().await; + self.sync_mcp_request_headers_for_turn(turn_context.as_ref()) + .await; let task: Arc = Arc::new(task); let task_kind = task.kind(); @@ -233,6 +235,7 @@ impl Session { // in-flight approval wait can surface as a model-visible rejection before TurnAborted. active_turn.clear_pending().await; } + self.clear_mcp_request_headers().await; } pub async fn on_task_finished( @@ -262,6 +265,9 @@ impl Session { *active = None; } drop(active); + if should_clear_active_turn { + self.clear_mcp_request_headers().await; + } if !pending_input.is_empty() { for pending_input_item in pending_input { match inspect_pending_input(self, &turn_context, pending_input_item).await { diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 86d14cf881bf..cf4f90ad3b05 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -23,6 +23,7 @@ use reqwest::header::HeaderMap; use reqwest::header::WWW_AUTHENTICATE; use rmcp::model::CallToolRequestParams; use rmcp::model::CallToolResult; +use rmcp::model::ClientJsonRpcMessage; use rmcp::model::ClientNotification; use rmcp::model::ClientRequest; use rmcp::model::CreateElicitationRequestParams; @@ -83,6 +84,15 @@ const JSON_MIME_TYPE: &str = "application/json"; const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192; + +fn message_uses_request_scoped_headers(message: &ClientJsonRpcMessage) -> bool { + matches!( + message, + ClientJsonRpcMessage::Request(request) + if request.request.method() == "tools/call" + ) +} + fn apply_request_scoped_headers( mut request: reqwest::RequestBuilder, request_headers_state: &Arc>>, @@ -156,7 +166,9 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(session_id_value) = session_id.as_ref() { request = request.header(HEADER_SESSION_ID, session_id_value.as_ref()); } - request = apply_request_scoped_headers(request, &self.request_headers_state); + if message_uses_request_scoped_headers(&message) { + request = apply_request_scoped_headers(request, &self.request_headers_state); + } let response = request .json(&message) @@ -252,8 +264,6 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); } - request_builder = - apply_request_scoped_headers(request_builder, &self.request_headers_state); let response = request_builder .header(HEADER_SESSION_ID, session.as_ref()) .send() @@ -291,8 +301,6 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); } - request_builder = - apply_request_scoped_headers(request_builder, &self.request_headers_state); let response = request_builder .send() @@ -503,38 +511,6 @@ pub struct RmcpClient { request_headers: Option>>>, } -struct RequestHeadersGuard { - state: Option>>>, - previous: Option, -} - -impl RequestHeadersGuard { - fn set(state: Option>>>, headers: Option) -> Self { - let previous = if let Some(state_ref) = state.as_ref() { - let mut guard = state_ref - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - let previous = guard.clone(); - *guard = headers; - previous - } else { - None - }; - Self { state, previous } - } -} - -impl Drop for RequestHeadersGuard { - fn drop(&mut self) { - if let Some(state) = self.state.as_ref() { - let mut guard = state - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); - *guard = self.previous.clone(); - } - } -} - impl RmcpClient { pub async fn new_stdio_client( program: OsString, @@ -574,6 +550,7 @@ impl RmcpClient { http_headers: Option>, env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, + request_headers: Arc>>, ) -> Result { let transport_recipe = TransportRecipe::StreamableHttp { server_name: server_name.to_string(), @@ -583,9 +560,9 @@ impl RmcpClient { env_http_headers, store_mode, }; - let request_headers = Some(Arc::new(StdMutex::new(None))); let transport = - Self::create_pending_transport(&transport_recipe, request_headers.clone()).await?; + Self::create_pending_transport(&transport_recipe, Some(Arc::clone(&request_headers))) + .await?; Ok(Self { state: Mutex::new(ClientState::Connecting { transport: Some(transport), @@ -593,7 +570,7 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), - request_headers, + request_headers: Some(request_headers), }) } @@ -604,7 +581,6 @@ impl RmcpClient { params: InitializeRequestParams, timeout: Option, send_elicitation: SendElicitation, - request_headers: Option, ) -> Result { let client_handler = LoggingClientHandler::new(params.clone(), send_elicitation); let pending_transport = { @@ -618,8 +594,6 @@ impl RmcpClient { } }; - let _request_headers_guard = - RequestHeadersGuard::set(self.request_headers.clone(), request_headers); let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport(pending_transport, client_handler.clone(), timeout) .await?; @@ -676,11 +650,8 @@ impl RmcpClient { &self, params: Option, timeout: Option, - request_headers: Option, ) -> Result { self.refresh_oauth_if_needed().await; - let _request_headers_guard = - RequestHeadersGuard::set(self.request_headers.clone(), request_headers); let result = self .run_service_operation("tools/list", timeout, move |service| { let params = params.clone(); @@ -774,7 +745,6 @@ impl RmcpClient { arguments: Option, meta: Option, timeout: Option, - request_headers: Option, ) -> Result { self.refresh_oauth_if_needed().await; let arguments = match arguments { @@ -801,8 +771,6 @@ impl RmcpClient { arguments, task: None, }; - let _request_headers_guard = - RequestHeadersGuard::set(self.request_headers.clone(), request_headers); let result = self .run_service_operation("tools/call", timeout, move |service| { let rmcp_params = rmcp_params.clone(); diff --git a/codex-rs/rmcp-client/tests/resources.rs b/codex-rs/rmcp-client/tests/resources.rs index 4a3e587d11bb..ba1a8e43106a 100644 --- a/codex-rs/rmcp-client/tests/resources.rs +++ b/codex-rs/rmcp-client/tests/resources.rs @@ -78,7 +78,6 @@ async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> { } .boxed() }), - /*request_headers*/ None, ) .await?; diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index 6512e248693a..fb2fc96d20f1 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -94,7 +94,6 @@ async fn create_client(base_url: &str) -> anyhow::Result { } .boxed() }), - /*request_headers*/ None, ) .await?; @@ -108,7 +107,6 @@ async fn call_echo_tool(client: &RmcpClient, message: &str) -> anyhow::Result Date: Wed, 18 Mar 2026 17:12:13 -0700 Subject: [PATCH 5/5] codex: fix MCP header CI regressions Add the new request_headers test fixture state, update the streamable HTTP recovery test to pass the new client constructor argument, and satisfy the argument comment lint for clearing turn-scoped MCP request headers. Co-authored-by: Codex --- codex-rs/core/src/codex.rs | 5 ++++- codex-rs/core/src/mcp_connection_manager_tests.rs | 5 +++++ codex-rs/rmcp-client/tests/streamable_http_recovery.rs | 3 +++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index db0e4472e9e7..b6f7505d1b63 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -3986,7 +3986,10 @@ impl Session { .mcp_connection_manager .read() .await - .set_request_headers_for_server(crate::mcp::CODEX_APPS_MCP_SERVER_NAME, None); + .set_request_headers_for_server( + crate::mcp::CODEX_APPS_MCP_SERVER_NAME, + /*request_headers*/ None, + ); } pub(crate) async fn parse_mcp_tool_name( diff --git a/codex-rs/core/src/mcp_connection_manager_tests.rs b/codex-rs/core/src/mcp_connection_manager_tests.rs index c5f7fc4a4086..9401b379bcbf 100644 --- a/codex-rs/core/src/mcp_connection_manager_tests.rs +++ b/codex-rs/core/src/mcp_connection_manager_tests.rs @@ -4,6 +4,7 @@ use codex_protocol::protocol::McpAuthStatus; use rmcp::model::JsonObject; use std::collections::HashSet; use std::sync::Arc; +use std::sync::Mutex as StdMutex; use tempfile::tempdir; fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { @@ -413,6 +414,7 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() { CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(startup_tools), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -438,6 +440,7 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot( CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: None, startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -460,6 +463,7 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty( CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(Vec::new()), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -492,6 +496,7 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() { CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: failed_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(startup_tools), startup_complete, tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index fb2fc96d20f1..8b03da8f1ad6 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -1,5 +1,7 @@ use std::net::TcpListener; use std::path::PathBuf; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; use std::time::Duration; use std::time::Instant; @@ -77,6 +79,7 @@ async fn create_client(base_url: &str) -> anyhow::Result { None, None, OAuthCredentialsStoreMode::File, + Arc::new(StdMutex::new(None)), ) .await?;