Skip to content
41 changes: 41 additions & 0 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ use futures::future::BoxFuture;
use futures::future::Shared;
use futures::prelude::*;
use futures::stream::FuturesOrdered;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use rmcp::model::ListResourceTemplatesResult;
use rmcp::model::ListResourcesResult;
use rmcp::model::PaginatedRequestParams;
Expand Down Expand Up @@ -3951,6 +3953,45 @@ impl Session {
.await
}

pub(crate) async fn sync_mcp_request_headers_for_turn(&self, turn_context: &TurnContext) {
let mut request_headers = HeaderMap::new();
let session_id = self.conversation_id.to_string();
if let Ok(value) = HeaderValue::from_str(&session_id) {
request_headers.insert("session_id", value.clone());
request_headers.insert("x-client-request-id", value);
}
if let Some(turn_metadata) = turn_context.turn_metadata_state.current_header_value()
&& let Ok(value) = HeaderValue::from_str(&turn_metadata)
{
request_headers.insert(crate::X_CODEX_TURN_METADATA_HEADER, value);
}

let request_headers = if request_headers.is_empty() {
None
} else {
Some(request_headers)
};
self.services
.mcp_connection_manager
.read()
.await
.set_request_headers_for_server(
crate::mcp::CODEX_APPS_MCP_SERVER_NAME,
request_headers,
);
}

pub(crate) async fn clear_mcp_request_headers(&self) {
self.services
.mcp_connection_manager
.read()
.await
.set_request_headers_for_server(
crate::mcp::CODEX_APPS_MCP_SERVER_NAME,
/*request_headers*/ None,
);
}

pub(crate) async fn parse_mcp_tool_name(
&self,
name: &str,
Expand Down
35 changes: 33 additions & 2 deletions codex-rs/core/src/mcp_connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ impl ManagedClient {
#[derive(Clone)]
struct AsyncManagedClient {
client: Shared<BoxFuture<'static, Result<ManagedClient, StartupOutcomeError>>>,
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
startup_snapshot: Option<Vec<ToolInfo>>,
startup_complete: Arc<AtomicBool>,
tool_plugin_provenance: Arc<ToolPluginProvenance>,
Expand All @@ -448,17 +449,26 @@ impl AsyncManagedClient {
codex_apps_tools_cache_context.as_ref(),
)
.map(|tools| filter_tools(tools, &tool_filter));
let request_headers = Arc::new(StdMutex::new(None));
let startup_tool_filter = tool_filter;
let startup_complete = Arc::new(AtomicBool::new(false));
let startup_complete_for_fut = Arc::clone(&startup_complete);
let request_headers_for_client = Arc::clone(&request_headers);
let fut = async move {
let outcome = async {
if let Err(error) = validate_mcp_server_name(&server_name) {
return Err(error.into());
}

let client =
Arc::new(make_rmcp_client(&server_name, config.transport, store_mode).await?);
let client = Arc::new(
make_rmcp_client(
&server_name,
config.transport,
store_mode,
request_headers_for_client,
)
.await?,
);
match start_server_task(
server_name,
client,
Expand Down Expand Up @@ -495,6 +505,7 @@ impl AsyncManagedClient {

Self {
client,
request_headers,
startup_snapshot,
startup_complete,
tool_plugin_provenance,
Expand Down Expand Up @@ -576,6 +587,14 @@ impl AsyncManagedClient {
let managed = self.client().await?;
managed.notify_sandbox_state_change(sandbox_state).await
}

fn set_request_headers(&self, request_headers: Option<reqwest::header::HeaderMap>) {
let mut guard = self
.request_headers
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
*guard = request_headers;
}
}

pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state";
Expand Down Expand Up @@ -1046,6 +1065,16 @@ impl McpConnectionManager {
})
}

pub(crate) fn set_request_headers_for_server(
&self,
server_name: &str,
request_headers: Option<reqwest::header::HeaderMap>,
) {
if let Some(client) = self.clients.get(server_name) {
client.set_request_headers(request_headers);
}
}

/// List resources from the specified server.
pub async fn list_resources(
&self,
Expand Down Expand Up @@ -1429,6 +1458,7 @@ async fn make_rmcp_client(
server_name: &str,
transport: McpServerTransportConfig,
store_mode: OAuthCredentialsStoreMode,
request_headers: Arc<StdMutex<Option<reqwest::header::HeaderMap>>>,
) -> Result<RmcpClient, StartupOutcomeError> {
match transport {
McpServerTransportConfig::Stdio {
Expand Down Expand Up @@ -1462,6 +1492,7 @@ async fn make_rmcp_client(
http_headers,
env_http_headers,
store_mode,
request_headers,
)
.await
.map_err(StartupOutcomeError::from)
Expand Down
5 changes: 5 additions & 0 deletions codex-rs/core/src/mcp_connection_manager_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use codex_protocol::protocol::McpAuthStatus;
use rmcp::model::JsonObject;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;
use tempfile::tempdir;

fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
Expand Down Expand Up @@ -413,6 +414,7 @@ async fn list_all_tools_uses_startup_snapshot_while_client_is_pending() {
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: pending_client,
request_headers: Arc::new(StdMutex::new(None)),
startup_snapshot: Some(startup_tools),
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
Expand All @@ -438,6 +440,7 @@ async fn list_all_tools_blocks_while_client_is_pending_without_startup_snapshot(
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: pending_client,
request_headers: Arc::new(StdMutex::new(None)),
startup_snapshot: None,
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
Expand All @@ -460,6 +463,7 @@ async fn list_all_tools_does_not_block_when_startup_snapshot_cache_hit_is_empty(
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: pending_client,
request_headers: Arc::new(StdMutex::new(None)),
startup_snapshot: Some(Vec::new()),
startup_complete: Arc::new(std::sync::atomic::AtomicBool::new(false)),
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
Expand Down Expand Up @@ -492,6 +496,7 @@ async fn list_all_tools_uses_startup_snapshot_when_client_startup_fails() {
CODEX_APPS_MCP_SERVER_NAME.to_string(),
AsyncManagedClient {
client: failed_client,
request_headers: Arc::new(StdMutex::new(None)),
startup_snapshot: Some(startup_tools),
startup_complete,
tool_plugin_provenance: Arc::new(ToolPluginProvenance::default()),
Expand Down
6 changes: 6 additions & 0 deletions codex-rs/core/src/tasks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ impl Session {
) {
self.abort_all_tasks(TurnAbortReason::Replaced).await;
self.clear_connector_selection().await;
self.sync_mcp_request_headers_for_turn(turn_context.as_ref())
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The MCP client lifecycle is different than that of a turn and can issues parallel calls, so there's an inherent race condition changing headers inside of a turn. Given turns seem sequential, we set the headers as a turn is starting and unset them once done. This allows to have the headers in the MCP calls without needing blocking concurrency or introducing a race.

.await;

let task: Arc<dyn SessionTask> = Arc::new(task);
let task_kind = task.kind();
Expand Down Expand Up @@ -233,6 +235,7 @@ impl Session {
// in-flight approval wait can surface as a model-visible rejection before TurnAborted.
active_turn.clear_pending().await;
}
self.clear_mcp_request_headers().await;
}

pub async fn on_task_finished(
Expand Down Expand Up @@ -262,6 +265,9 @@ impl Session {
*active = None;
}
drop(active);
if should_clear_active_turn {
self.clear_mcp_request_headers().await;
}
if !pending_input.is_empty() {
for pending_input_item in pending_input {
match inspect_pending_input(self, &turn_context, pending_input_item).await {
Expand Down
Loading
Loading