diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 7b947249bda6..fcd827fddc98 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -125,6 +125,8 @@ use futures::future::BoxFuture; use futures::future::Shared; use futures::prelude::*; use futures::stream::FuturesOrdered; +use reqwest::header::HeaderMap; +use reqwest::header::HeaderValue; use rmcp::model::ListResourceTemplatesResult; use rmcp::model::ListResourcesResult; use rmcp::model::PaginatedRequestParams; @@ -3943,6 +3945,12 @@ impl Session { arguments: Option, meta: Option, ) -> anyhow::Result { + if server == CODEX_APPS_MCP_SERVER_NAME + && let Some((turn_context, _)) = self.active_turn_context_and_cancellation_token().await + { + self.sync_mcp_request_headers_for_turn(turn_context.as_ref()) + .await; + } self.services .mcp_connection_manager .read() @@ -3951,6 +3959,45 @@ impl Session { .await } + pub(crate) async fn sync_mcp_request_headers_for_turn(&self, turn_context: &TurnContext) { + let mut request_headers = HeaderMap::new(); + let session_id = self.conversation_id.to_string(); + if let Ok(value) = HeaderValue::from_str(&session_id) { + request_headers.insert("session_id", value.clone()); + request_headers.insert("x-client-request-id", value); + } + if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value() + && let Ok(value) = HeaderValue::from_str(&turn_metadata) + { + request_headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value); + } + + let request_headers = if request_headers.is_empty() { + None + } else { + Some(request_headers) + }; + self.services + .mcp_connection_manager + .read() + .await + .set_request_headers_for_server( + crate::mcp::CODEX_APPS_MCP_SERVER_NAME, + request_headers, + ); + } + + pub(crate) async fn clear_mcp_request_headers(&self) { + self.services + .mcp_connection_manager + .read() + .await + .set_request_headers_for_server( + crate::mcp::CODEX_APPS_MCP_SERVER_NAME, + /*request_headers*/ None, + ); + } + pub(crate) async fn parse_mcp_tool_name( &self, name: &str, diff --git a/codex-rs/core/src/codex_tests.rs b/codex-rs/core/src/codex_tests.rs index f767c05f4eb0..22e83ee216cc 100644 --- a/codex-rs/core/src/codex_tests.rs +++ b/codex-rs/core/src/codex_tests.rs @@ -26,6 +26,8 @@ use codex_protocol::protocol::ReadOnlyAccess; use codex_protocol::protocol::SandboxPolicy; use codex_protocol::request_permissions::PermissionGrantScope; use codex_protocol::request_permissions::RequestPermissionProfile; +use codex_protocol::user_input::UserInput; +use reqwest::header::HeaderValue; use tracing::Span; use crate::protocol::CompactedItem; @@ -56,6 +58,7 @@ use crate::tools::handlers::UnifiedExecHandler; use crate::tools::registry::ToolHandler; use crate::tools::router::ToolCallSource; use crate::turn_diff_tracker::TurnDiffTracker; +use async_trait::async_trait; use codex_app_server_protocol::AppInfo; use codex_execpolicy::Decision; use codex_execpolicy::NetworkRuleProtocol; @@ -78,7 +81,10 @@ use opentelemetry::trace::TracerProvider as _; use opentelemetry_sdk::trace::SdkTracerProvider; use std::path::Path; use std::time::Duration; +use tokio::sync::Notify; use tokio::time::sleep; +use tokio_util::sync::CancellationToken; +use tokio_util::task::AbortOnDropHandle; use tracing_opentelemetry::OpenTelemetrySpanExt; use tracing_subscriber::prelude::*; @@ -2716,6 +2722,99 @@ async fn request_permissions_is_auto_denied_when_granular_policy_blocks_tool_req ); } +struct NoopTask; + +#[async_trait] +impl SessionTask for NoopTask { + fn kind(&self) -> TaskKind { + TaskKind::Regular + } + + fn span_name(&self) -> &'static str { + "noop" + } + + async fn run( + self: Arc, + _ctx: Arc, + _turn_context: Arc, + _input: Vec, + _cancellation_token: CancellationToken, + ) -> Option { + None + } +} + +#[tokio::test] +async fn call_tool_refreshes_codex_apps_request_headers_from_active_turn() { + let (session, turn_context) = make_session_and_context().await; + let session = Arc::new(session); + let turn_context = Arc::new(turn_context); + + session + .services + .mcp_connection_manager + .write() + .await + .register_test_server_for_request_headers(CODEX_APPS_MCP_SERVER_NAME); + + session + .sync_mcp_request_headers_for_turn(turn_context.as_ref()) + .await; + let base_header = turn_context + .turn_metadata_state + .current_header_value() + .expect("base turn metadata header"); + + let updated_header = serde_json::json!({ + "turn_id": turn_context.sub_id, + "sandbox": "test", + "workspaces": { + "/tmp/repo": { + "has_changes": true + } + } + }) + .to_string(); + turn_context + .turn_metadata_state + .set_enriched_header_for_tests(Some(updated_header.clone())); + + let mut active_turn = ActiveTurn::default(); + let handle = tokio::spawn(async {}); + active_turn.add_task(crate::state::RunningTask { + done: Arc::new(Notify::new()), + kind: TaskKind::Regular, + task: Arc::new(NoopTask), + cancellation_token: CancellationToken::new(), + handle: Arc::new(AbortOnDropHandle::new(handle)), + turn_context: Arc::clone(&turn_context), + _timer: None, + }); + *session.active_turn.lock().await = Some(active_turn); + + let _err = session + .call_tool(CODEX_APPS_MCP_SERVER_NAME, "echo", None, None) + .await + .expect_err("test server is not initialized"); + + let headers = session + .services + .mcp_connection_manager + .read() + .await + .request_headers_for_server(CODEX_APPS_MCP_SERVER_NAME) + .expect("request headers should be tracked for codex apps"); + assert_eq!( + headers.get(crate::X_CODEX_TURN_METADATA_HEADER), + Some(&HeaderValue::from_str(&updated_header).expect("valid enriched header")), + ); + assert_ne!( + headers.get(crate::X_CODEX_TURN_METADATA_HEADER), + Some(&HeaderValue::from_str(&base_header).expect("valid base header")), + ); +} + #[tokio::test] async fn submit_with_id_captures_current_span_trace_context() { let (session, _turn_context) = make_session_and_context().await; diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 938d6d0b2bf3..cd80220d716b 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -423,6 +423,7 @@ impl ManagedClient { #[derive(Clone)] struct AsyncManagedClient { client: Shared>>, + request_headers: Arc>>, startup_snapshot: Option>, startup_complete: Arc, tool_plugin_provenance: Arc, @@ -448,17 +449,26 @@ impl AsyncManagedClient { codex_apps_tools_cache_context.as_ref(), ) .map(|tools| filter_tools(tools, &tool_filter)); + let request_headers = Arc::new(StdMutex::new(None)); let startup_tool_filter = tool_filter; let startup_complete = Arc::new(AtomicBool::new(false)); let startup_complete_for_fut = Arc::clone(&startup_complete); + let request_headers_for_client = Arc::clone(&request_headers); let fut = async move { let outcome = async { if let Err(error) = validate_mcp_server_name(&server_name) { return Err(error.into()); } - let client = - Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?); + let client = Arc::new( + make_rmcp_client( + &server_name, + config.transport, + store_mode, + request_headers_for_client, + ) + .await?, + ); match start_server_task( server_name, client, @@ -495,6 +505,7 @@ impl AsyncManagedClient { Self { client, + request_headers, startup_snapshot, startup_complete, tool_plugin_provenance, @@ -576,6 +587,14 @@ impl AsyncManagedClient { let managed = self.client().await?; managed.notify_sandbox_state_change(sandbox_state).await } + + fn set_request_headers(&self, request_headers: Option) { + let mut guard = self + .request_headers + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *guard = request_headers; + } } pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state"; @@ -617,6 +636,40 @@ impl McpConnectionManager { Self::new_uninitialized(approval_policy) } + #[cfg(test)] + pub(crate) fn register_test_server_for_request_headers(&mut self, server_name: &str) { + let failed_client = futures::future::ready::>( + Err(StartupOutcomeError::Failed { + error: "test request headers stub".to_string(), + }), + ) + .boxed() + .shared(); + self.clients.insert( + server_name.to_string(), + AsyncManagedClient { + client: failed_client, + request_headers: Arc::new(StdMutex::new(None)), + startup_snapshot: None, + startup_complete: Arc::new(AtomicBool::new(true)), + tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), + }, + ); + } + + #[cfg(test)] + pub(crate) fn request_headers_for_server( + &self, + server_name: &str, + ) -> Option { + let client = self.clients.get(server_name)?; + client + .request_headers + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + } + pub(crate) fn has_servers(&self) -> bool { !self.clients.is_empty() } @@ -1046,6 +1099,16 @@ impl McpConnectionManager { }) } + pub(crate) fn set_request_headers_for_server( + &self, + server_name: &str, + request_headers: Option, + ) { + if let Some(client) = self.clients.get(server_name) { + client.set_request_headers(request_headers); + } + } + /// List resources from the specified server. pub async fn list_resources( &self, @@ -1429,6 +1492,7 @@ async fn make_rmcp_client( server_name: &str, transport: McpServerTransportConfig, store_mode: OAuthCredentialsStoreMode, + request_headers: Arc>>, ) -> Result { match transport { McpServerTransportConfig::Stdio { @@ -1462,6 +1526,7 @@ async fn make_rmcp_client( http_headers, env_http_headers, store_mode, + request_headers, ) .await .map_err(StartupOutcomeError::from) diff --git a/codex-rs/core/src/mcp_connection_manager_tests.rs b/codex-rs/core/src/mcp_connection_manager_tests.rs index c5f7fc4a4086..9401b379bcbf 100644 --- a/codex-rs/core/src/mcp_connection_manager_tests.rs +++ b/codex-rs/core/src/mcp_connection_manager_tests.rs @@ -4,6 +4,7 @@ use codex_protocol::protocol::McpAuthStatus; use rmcp::model::JsonObject; use std::collections::HashSet; use std::sync::Arc; +use std::sync::Mutex as StdMutex; use tempfile::tempdir; fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo { @@ -413,6 +414,7 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() { CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(startup_tools), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -438,6 +440,7 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot( CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: None, startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -460,6 +463,7 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty( CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: pending_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(Vec::new()), startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)), tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), @@ -492,6 +496,7 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() { CODEX_APPS_MCP_SERVER_NAME.to_string(), AsyncManagedClient { client: failed_client, + request_headers: Arc::new(StdMutex::new(None)), startup_snapshot: Some(startup_tools), startup_complete, tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()), diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index c52e4f91780e..049ed56d45f2 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -153,6 +153,8 @@ impl Session { ) { self.abort_all_tasks(TurnAbortReason::Replaced).await; self.clear_connector_selection().await; + self.sync_mcp_request_headers_for_turn(turn_context.as_ref()) + .await; let task: Arc = Arc::new(task); let task_kind = task.kind(); @@ -233,6 +235,7 @@ impl Session { // in-flight approval wait can surface as a model-visible rejection before TurnAborted. active_turn.clear_pending().await; } + self.clear_mcp_request_headers().await; } pub async fn on_task_finished( @@ -262,6 +265,9 @@ impl Session { *active = None; } drop(active); + if should_clear_active_turn { + self.clear_mcp_request_headers().await; + } if !pending_input.is_empty() { for pending_input_item in pending_input { match inspect_pending_input(self, &turn_context, pending_input_item).await { diff --git a/codex-rs/core/src/turn_metadata.rs b/codex-rs/core/src/turn_metadata.rs index c0298c522122..105675272869 100644 --- a/codex-rs/core/src/turn_metadata.rs +++ b/codex-rs/core/src/turn_metadata.rs @@ -217,6 +217,14 @@ impl TurnMetadataState { } } + #[cfg(test)] + pub(crate) fn set_enriched_header_for_tests(&self, header: Option) { + *self + .enriched_header + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) = header; + } + async fn fetch_workspace_git_metadata(&self) -> WorkspaceGitMetadata { let (latest_git_commit_hash, associated_remote_urls, has_changes) = tokio::join!( get_head_commit_hash(&self.cwd), diff --git a/codex-rs/core/tests/suite/search_tool.rs b/codex-rs/core/tests/suite/search_tool.rs index 118f1bd585c9..62659aa4de68 100644 --- a/codex-rs/core/tests/suite/search_tool.rs +++ b/codex-rs/core/tests/suite/search_tool.rs @@ -30,6 +30,8 @@ use core_test_support::wait_for_event; use pretty_assertions::assert_eq; use serde_json::Value; use serde_json::json; +use std::fs; +use std::process::Command; const SEARCH_TOOL_DESCRIPTION_SNIPPETS: [&str; 2] = [ "You have access to all the tools of the following apps/connectors", @@ -86,6 +88,15 @@ fn tool_search_output_tools(request: &ResponsesRequest, call_id: &str) -> Vec Option { + request + .body_json::() + .ok()? + .get("method") + .and_then(Value::as_str) + .map(str::to_string) +} + fn configure_apps(config: &mut Config, apps_base_url: &str) { config .features @@ -499,5 +510,195 @@ async fn tool_search_returns_deferred_tools_without_follow_up_tool_injection() - "post-tool follow-up should still rely on tool_search_output history, not tool injection: {third_request_tools:?}" ); + let mcp_requests = server + .received_requests() + .await + .expect("failed to fetch recorded requests"); + let tools_list_request = mcp_requests + .iter() + .find(|request| json_rpc_method(request).as_deref() == Some("tools/list")) + .expect("tools/list MCP request"); + assert!( + tools_list_request + .headers + .get("x-codex-turn-metadata") + .is_none(), + "tools/list should not include per-turn MCP headers" + ); + + let tools_call_request = mcp_requests + .iter() + .find(|request| json_rpc_method(request).as_deref() == Some("tools/call")) + .expect("tools/call MCP request"); + let session_id_header = tools_call_request + .headers + .get("session_id") + .expect("tools/call session_id header"); + let request_id_header = tools_call_request + .headers + .get("x-client-request-id") + .expect("tools/call x-client-request-id header"); + let turn_metadata_header = tools_call_request + .headers + .get("x-codex-turn-metadata") + .expect("tools/call turn metadata header"); + assert_eq!( + session_id_header + .to_str() + .expect("session_id header to be utf8"), + request_id_header + .to_str() + .expect("x-client-request-id header to be utf8") + ); + assert!( + turn_metadata_header + .to_str() + .expect("turn metadata header to be utf8") + .contains("\"turn_id\""), + "expected turn metadata header to contain serialized turn metadata" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn apps_mcp_tool_call_uses_enriched_turn_metadata_header() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let apps_server = AppsTestServer::mount_searchable(&server).await?; + let call_id = "tool-search-git-metadata"; + let mock = mount_sse_sequence( + &server, + vec![ + sse(vec![ + ev_response_created("resp-1"), + ev_tool_search_call( + call_id, + &json!({ + "query": "create calendar event", + "limit": 1, + }), + ), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_response_created("resp-2"), + json!({ + "type": "response.output_item.done", + "item": { + "type": "function_call", + "call_id": "calendar-call-git-metadata", + "name": SEARCH_CALENDAR_CREATE_TOOL, + "namespace": SEARCH_CALENDAR_NAMESPACE, + "arguments": serde_json::to_string(&json!({ + "title": "Lunch", + "starts_at": "2026-03-10T12:00:00Z" + })).expect("serialize calendar args") + } + }), + ev_completed("resp-2"), + ]), + sse(vec![ + ev_response_created("resp-3"), + ev_assistant_message("msg-1", "done"), + ev_completed("resp-3"), + ]), + ], + ) + .await; + + let mut builder = configured_builder(apps_server.chatgpt_base_url.clone()); + let test = builder.build(&server).await?; + + let cwd = test.cwd_path().to_path_buf(); + assert!( + Command::new("git") + .arg("init") + .current_dir(&cwd) + .status()? + .success() + ); + assert!( + Command::new("git") + .args(["config", "user.name", "Codex Test"]) + .current_dir(&cwd) + .status()? + .success() + ); + assert!( + Command::new("git") + .args(["config", "user.email", "codex@example.com"]) + .current_dir(&cwd) + .status()? + .success() + ); + assert!( + Command::new("git") + .args(["remote", "add", "origin", "https://example.test/repo.git"]) + .current_dir(&cwd) + .status()? + .success() + ); + for idx in 0..400 { + fs::write( + cwd.join(format!("file-{idx:04}.txt")), + format!("fixture file {idx}\n"), + )?; + } + assert!( + Command::new("git") + .args(["add", "."]) + .current_dir(&cwd) + .status()? + .success() + ); + assert!( + Command::new("git") + .args(["commit", "-m", "init"]) + .current_dir(&cwd) + .status()? + .success() + ); + + test.codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "Find the calendar create tool".to_string(), + text_elements: Vec::new(), + }], + final_output_json_schema: None, + }) + .await?; + + wait_for_event(&test.codex, |event| { + matches!(event, EventMsg::TurnComplete(_)) + }) + .await; + + let _requests = mock.requests(); + let mcp_requests = server + .received_requests() + .await + .expect("failed to fetch recorded requests"); + let tools_call_request = mcp_requests + .iter() + .find(|request| json_rpc_method(request).as_deref() == Some("tools/call")) + .expect("tools/call MCP request"); + let turn_metadata_header = tools_call_request + .headers + .get("x-codex-turn-metadata") + .expect("tools/call turn metadata header") + .to_str() + .expect("turn metadata header to be utf8"); + let parsed: Value = serde_json::from_str(turn_metadata_header)?; + assert!( + parsed + .get("workspaces") + .and_then(Value::as_object) + .is_some_and(|workspaces| !workspaces.is_empty()), + "expected enriched MCP turn metadata header with workspace git metadata, got {parsed:#?}" + ); + Ok(()) } diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index b898403b25c7..cf4f90ad3b05 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -5,6 +5,7 @@ use std::io; use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; +use std::sync::Mutex as StdMutex; use std::time::Duration; use anyhow::Result; @@ -22,6 +23,7 @@ use reqwest::header::HeaderMap; use reqwest::header::WWW_AUTHENTICATE; use rmcp::model::CallToolRequestParams; use rmcp::model::CallToolResult; +use rmcp::model::ClientJsonRpcMessage; use rmcp::model::ClientNotification; use rmcp::model::ClientRequest; use rmcp::model::CreateElicitationRequestParams; @@ -83,14 +85,45 @@ const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192; +fn message_uses_request_scoped_headers(message: &ClientJsonRpcMessage) -> bool { + matches!( + message, + ClientJsonRpcMessage::Request(request) + if request.request.method() == "tools/call" + ) +} + +fn apply_request_scoped_headers( + mut request: reqwest::RequestBuilder, + request_headers_state: &Arc>>, +) -> reqwest::RequestBuilder { + let extra_headers = request_headers_state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone(); + if let Some(extra_headers) = extra_headers { + for (name, value) in &extra_headers { + request = request.header(name, value.clone()); + } + } + request +} + #[derive(Clone)] struct StreamableHttpResponseClient { inner: reqwest::Client, + request_headers_state: Arc>>, } impl StreamableHttpResponseClient { - fn new(inner: reqwest::Client) -> Self { - Self { inner } + fn new( + inner: reqwest::Client, + request_headers_state: Arc>>, + ) -> Self { + Self { + inner, + request_headers_state, + } } fn reqwest_error( @@ -133,6 +166,9 @@ impl StreamableHttpClient for StreamableHttpResponseClient { if let Some(session_id_value) = session_id.as_ref() { request = request.header(HEADER_SESSION_ID, session_id_value.as_ref()); } + if message_uses_request_scoped_headers(&message) { + request = apply_request_scoped_headers(request, &self.request_headers_state); + } let response = request .json(&message) @@ -472,6 +508,7 @@ pub struct RmcpClient { transport_recipe: TransportRecipe, initialize_context: Mutex>, session_recovery_lock: Mutex<()>, + request_headers: Option>>>, } impl RmcpClient { @@ -489,9 +526,10 @@ impl RmcpClient { env_vars: env_vars.to_vec(), cwd, }; - let transport = Self::create_pending_transport(&transport_recipe) - .await - .map_err(io::Error::other)?; + let transport = + Self::create_pending_transport(&transport_recipe, /*request_headers*/ None) + .await + .map_err(io::Error::other)?; Ok(Self { state: Mutex::new(ClientState::Connecting { @@ -500,6 +538,7 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), + request_headers: None, }) } @@ -511,6 +550,7 @@ impl RmcpClient { http_headers: Option>, env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, + request_headers: Arc>>, ) -> Result { let transport_recipe = TransportRecipe::StreamableHttp { server_name: server_name.to_string(), @@ -520,7 +560,9 @@ impl RmcpClient { env_http_headers, store_mode, }; - let transport = Self::create_pending_transport(&transport_recipe).await?; + let transport = + Self::create_pending_transport(&transport_recipe, Some(Arc::clone(&request_headers))) + .await?; Ok(Self { state: Mutex::new(ClientState::Connecting { transport: Some(transport), @@ -528,6 +570,7 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), + request_headers: Some(request_headers), }) } @@ -830,6 +873,7 @@ impl RmcpClient { async fn create_pending_transport( transport_recipe: &TransportRecipe, + request_headers: Option>>>, ) -> Result { match transport_recipe { TransportRecipe::Stdio { @@ -946,7 +990,12 @@ impl RmcpClient { .auth_header(access_token); let http_client = build_http_client(&default_headers)?; let transport = StreamableHttpClientTransport::with_client( - StreamableHttpResponseClient::new(http_client), + StreamableHttpResponseClient::new( + http_client, + request_headers + .clone() + .unwrap_or_else(|| Arc::new(StdMutex::new(None))), + ), http_config, ); Ok(PendingTransport::StreamableHttp { transport }) @@ -963,7 +1012,12 @@ impl RmcpClient { let http_client = build_http_client(&default_headers)?; let transport = StreamableHttpClientTransport::with_client( - StreamableHttpResponseClient::new(http_client), + StreamableHttpResponseClient::new( + http_client, + request_headers + .clone() + .unwrap_or_else(|| Arc::new(StdMutex::new(None))), + ), http_config, ); Ok(PendingTransport::StreamableHttp { transport }) @@ -1111,7 +1165,9 @@ impl RmcpClient { .await .clone() .ok_or_else(|| anyhow!("MCP client cannot recover before initialize succeeds"))?; - let pending_transport = Self::create_pending_transport(&self.transport_recipe).await?; + let pending_transport = + Self::create_pending_transport(&self.transport_recipe, self.request_headers.clone()) + .await?; let (service, oauth_persistor, process_group_guard) = Self::connect_pending_transport( pending_transport, initialize_context.handler, @@ -1166,7 +1222,10 @@ async fn create_oauth_transport_and_runtime( } }; - let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager); + let auth_client = AuthClient::new( + StreamableHttpResponseClient::new(http_client, Arc::new(StdMutex::new(None))), + manager, + ); let auth_manager = auth_client.auth_manager.clone(); let transport = StreamableHttpClientTransport::with_client( diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index fb2fc96d20f1..8b03da8f1ad6 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -1,5 +1,7 @@ use std::net::TcpListener; use std::path::PathBuf; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; use std::time::Duration; use std::time::Instant; @@ -77,6 +79,7 @@ async fn create_client(base_url: &str) -> anyhow::Result { None, None, OAuthCredentialsStoreMode::File, + Arc::new(StdMutex::new(None)), ) .await?;