diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 2b08f292..1ee302e9 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -230,3 +230,13 @@ path = "tests/test_sampling.rs" name = "test_close_connection" required-features = ["server", "client"] path = "tests/test_close_connection.rs" + +[[test]] +name = "test_custom_headers" +required-features = [ + "client", + "server", + "transport-streamable-http-client-reqwest", + "transport-streamable-http-server", +] +path = "tests/test_custom_headers.rs" diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index ad2d69ab..a612a52b 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -16,6 +16,8 @@ use thiserror::Error; use tokio::sync::{Mutex, RwLock}; use tracing::{debug, error, warn}; +use crate::transport::common::http_header::HEADER_MCP_PROTOCOL_VERSION; + const DEFAULT_EXCHANGE_URL: &str = "http://localhost"; /// Stored credentials for OAuth2 authorization @@ -1051,7 +1053,7 @@ impl AuthorizationManager { let response = match self .http_client .get(discovery_url.clone()) - .header("MCP-Protocol-Version", "2024-11-05") + .header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05") .send() .await { @@ -1171,7 +1173,7 @@ impl AuthorizationManager { let response = match self .http_client .get(url.clone()) - .header("MCP-Protocol-Version", "2024-11-05") + .header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05") .send() .await { @@ -1224,7 +1226,7 @@ impl AuthorizationManager { let response = match self .http_client .get(resource_metadata_url.clone()) - .header("MCP-Protocol-Version", "2024-11-05") + .header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05") .send() .await { diff --git a/crates/rmcp/src/transport/common/auth/streamable_http_client.rs b/crates/rmcp/src/transport/common/auth/streamable_http_client.rs index 49ebefcd..35e3ed5a 100644 --- a/crates/rmcp/src/transport/common/auth/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/auth/streamable_http_client.rs @@ -1,3 +1,7 @@ +use std::collections::HashMap; + +use http::{HeaderName, HeaderValue}; + use crate::transport::{ auth::AuthClient, streamable_http_client::{StreamableHttpClient, StreamableHttpError}, @@ -47,6 +51,7 @@ where message: crate::model::ClientJsonRpcMessage, session_id: Option>, mut auth_token: Option, + custom_headers: HashMap, ) -> Result< crate::transport::streamable_http_client::StreamableHttpPostResponse, StreamableHttpError, @@ -55,7 +60,7 @@ where auth_token = Some(self.get_access_token().await?); } self.http_client - .post_message(uri, message, session_id, auth_token) + .post_message(uri, message, session_id, auth_token, custom_headers) .await } } diff --git a/crates/rmcp/src/transport/common/http_header.rs b/crates/rmcp/src/transport/common/http_header.rs index 84bc7bfb..44175326 100644 --- a/crates/rmcp/src/transport/common/http_header.rs +++ b/crates/rmcp/src/transport/common/http_header.rs @@ -1,4 +1,5 @@ pub const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; pub const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; +pub const HEADER_MCP_PROTOCOL_VERSION: &str = "MCP-Protocol-Version"; pub const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream"; pub const JSON_MIME_TYPE: &str = "application/json"; diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index cc26bdc9..b4cdafd1 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -1,7 +1,7 @@ -use std::{borrow::Cow, sync::Arc}; +use std::{borrow::Cow, collections::HashMap, sync::Arc}; use futures::{StreamExt, stream::BoxStream}; -use http::header::WWW_AUTHENTICATE; +use http::{HeaderName, HeaderValue, header::WWW_AUTHENTICATE}; use reqwest::header::ACCEPT; use sse_stream::{Sse, SseStream}; @@ -9,7 +9,8 @@ use crate::{ model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, transport::{ common::http_header::{ - EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, + EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION, + HEADER_SESSION_ID, JSON_MIME_TYPE, }, streamable_http_client::*, }, @@ -94,6 +95,7 @@ impl StreamableHttpClient for reqwest::Client { message: ClientJsonRpcMessage, session_id: Option>, auth_token: Option, + custom_headers: HashMap, ) -> Result> { let mut request = self .post(uri.as_ref()) @@ -101,6 +103,26 @@ impl StreamableHttpClient for reqwest::Client { if let Some(auth_header) = auth_token { request = request.bearer_auth(auth_header); } + + // Apply custom headers + let reserved_headers = [ + ACCEPT.as_str(), + HEADER_SESSION_ID, + HEADER_MCP_PROTOCOL_VERSION, + HEADER_LAST_EVENT_ID, + ]; + for (name, value) in custom_headers { + if reserved_headers + .iter() + .any(|&r| name.as_str().eq_ignore_ascii_case(r)) + { + return Err(StreamableHttpError::ReservedHeaderConflict( + name.to_string(), + )); + } + + request = request.header(name, value); + } if let Some(session_id) = session_id { request = request.header(HEADER_SESSION_ID, session_id.as_ref()); } diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 550d261b..37653c42 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -1,6 +1,7 @@ -use std::{borrow::Cow, sync::Arc, time::Duration}; +use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration}; use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream}; +use http::{HeaderName, HeaderValue}; pub use sse_stream::Error as SseError; use sse_stream::Sse; use thiserror::Error; @@ -76,6 +77,8 @@ pub enum StreamableHttpError { AuthRequired(AuthRequiredError), #[error("Insufficient scope")] InsufficientScope(InsufficientScopeError), + #[error("Header name '{0}' is reserved and conflicts with default headers")] + ReservedHeaderConflict(String), } #[derive(Debug, Clone, Error)] @@ -173,6 +176,7 @@ pub trait StreamableHttpClient: Clone + Send + 'static { message: ClientJsonRpcMessage, session_id: Option>, auth_header: Option, + custom_headers: HashMap, ) -> impl Future>> + Send + '_; @@ -324,6 +328,7 @@ impl Worker for StreamableHttpClientWorker { initialize_request, None, self.config.auth_header, + self.config.custom_headers, ) .await { @@ -372,6 +377,7 @@ impl Worker for StreamableHttpClientWorker { initialized_notification.message, session_id.clone(), config.auth_header.clone(), + config.custom_headers.clone(), ) .await .map_err(WorkerQuitReason::fatal_context( @@ -477,6 +483,7 @@ impl Worker for StreamableHttpClientWorker { message, session_id.clone(), config.auth_header.clone(), + config.custom_headers.clone(), ) .await; let send_result = match response { @@ -609,8 +616,10 @@ impl Worker for StreamableHttpClientWorker { /// StreamableHttpClientTransportConfig /// }; /// use std::sync::Arc; +/// use std::collections::HashMap; /// use futures::stream::BoxStream; /// use rmcp::model::ClientJsonRpcMessage; +/// use http::{HeaderName, HeaderValue}; /// use sse_stream::{Sse, Error as SseError}; /// /// #[derive(Clone)] @@ -634,6 +643,7 @@ impl Worker for StreamableHttpClientWorker { /// _message: ClientJsonRpcMessage, /// _session_id: Option>, /// _auth_header: Option, +/// _custom_headers: HashMap, /// ) -> Result> { /// todo!() /// } @@ -690,8 +700,10 @@ impl StreamableHttpClientTransport { /// StreamableHttpClientTransportConfig /// }; /// use std::sync::Arc; + /// use std::collections::HashMap; /// use futures::stream::BoxStream; /// use rmcp::model::ClientJsonRpcMessage; + /// use http::{HeaderName, HeaderValue}; /// use sse_stream::{Sse, Error as SseError}; /// /// // Define your custom client @@ -716,6 +728,7 @@ impl StreamableHttpClientTransport { /// _message: ClientJsonRpcMessage, /// _session_id: Option>, /// _auth_header: Option, + /// _custom_headers: HashMap, /// ) -> Result> { /// todo!() /// } @@ -759,6 +772,8 @@ pub struct StreamableHttpClientTransportConfig { pub allow_stateless: bool, /// The value to send in the authorization header pub auth_header: Option, + /// Custom HTTP headers to include with every request + pub custom_headers: HashMap, } impl StreamableHttpClientTransportConfig { @@ -779,6 +794,33 @@ impl StreamableHttpClientTransportConfig { self.auth_header = Some(value.into()); self } + + /// Set custom HTTP headers to include with every request + /// + /// # Arguments + /// + /// * `custom_headers` - A HashMap of header names to header values + /// + /// # Example + /// + /// ```rust,no_run + /// use std::collections::HashMap; + /// use http::{HeaderName, HeaderValue}; + /// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; + /// + /// let mut headers = HashMap::new(); + /// headers.insert( + /// HeaderName::from_static("x-custom-header"), + /// HeaderValue::from_static("custom-value") + /// ); + /// + /// let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8000") + /// .custom_headers(headers); + /// ``` + pub fn custom_headers(mut self, custom_headers: HashMap) -> Self { + self.custom_headers = custom_headers; + self + } } impl Default for StreamableHttpClientTransportConfig { @@ -789,6 +831,7 @@ impl Default for StreamableHttpClientTransportConfig { channel_buffer_capacity: 16, allow_stateless: true, auth_header: None, + custom_headers: HashMap::new(), } } } diff --git a/crates/rmcp/tests/test_custom_headers.rs b/crates/rmcp/tests/test_custom_headers.rs new file mode 100644 index 00000000..c9307109 --- /dev/null +++ b/crates/rmcp/tests/test_custom_headers.rs @@ -0,0 +1,531 @@ +use std::collections::HashMap; + +use http::{HeaderName, HeaderValue}; + +#[test] +fn test_config_custom_headers_default_empty() { + use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; + + let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8080"); + assert!( + config.custom_headers.is_empty(), + "Default custom_headers should be empty" + ); +} + +#[test] +fn test_config_custom_headers_builder() { + use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; + + let mut headers = HashMap::new(); + headers.insert( + HeaderName::from_static("x-test-header"), + HeaderValue::from_static("test-value"), + ); + + let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8080") + .custom_headers(headers); + + assert_eq!(config.custom_headers.len(), 1); + assert_eq!( + config + .custom_headers + .get(&HeaderName::from_static("x-test-header")), + Some(&HeaderValue::from_static("test-value")) + ); +} + +#[test] +fn test_config_custom_headers_multiple_values() { + use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; + + let mut headers = HashMap::new(); + headers.insert( + HeaderName::from_static("x-header-1"), + HeaderValue::from_static("value-1"), + ); + headers.insert( + HeaderName::from_static("x-header-2"), + HeaderValue::from_static("value-2"), + ); + headers.insert( + HeaderName::from_static("authorization"), + HeaderValue::from_static("Bearer token123"), + ); + + let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8080") + .custom_headers(headers); + + assert_eq!(config.custom_headers.len(), 3); + assert_eq!( + config + .custom_headers + .get(&HeaderName::from_static("x-header-1")), + Some(&HeaderValue::from_static("value-1")) + ); + assert_eq!( + config + .custom_headers + .get(&HeaderName::from_static("x-header-2")), + Some(&HeaderValue::from_static("value-2")) + ); + assert_eq!( + config + .custom_headers + .get(&HeaderName::from_static("authorization")), + Some(&HeaderValue::from_static("Bearer token123")) + ); +} + +#[test] +fn test_config_auth_header_and_custom_headers_together() { + use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; + + let mut headers = HashMap::new(); + headers.insert( + HeaderName::from_static("x-custom-header"), + HeaderValue::from_static("custom-value"), + ); + + let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8080") + .auth_header("my-bearer-token") + .custom_headers(headers); + + assert_eq!(config.auth_header, Some("my-bearer-token".to_string())); + assert_eq!( + config + .custom_headers + .get(&HeaderName::from_static("x-custom-header")), + Some(&HeaderValue::from_static("custom-value")) + ); +} + +/// Unit test: post_message should reject reserved header "accept" +#[tokio::test] +#[cfg(feature = "transport-streamable-http-client-reqwest")] +async fn test_post_message_rejects_accept_header() { + use std::sync::Arc; + + use rmcp::{ + model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, + transport::streamable_http_client::{StreamableHttpClient, StreamableHttpError}, + }; + + let client = reqwest::Client::new(); + let mut custom_headers = HashMap::new(); + custom_headers.insert( + HeaderName::from_static("accept"), + HeaderValue::from_static("text/html"), + ); + + let message = ClientJsonRpcMessage::request( + ClientRequest::PingRequest(PingRequest::default()), + RequestId::Number(1), + ); + + let result = client + .post_message( + Arc::from("http://localhost:9999/mcp"), + message, + None, + None, + custom_headers, + ) + .await; + + assert!(result.is_err(), "Should reject 'accept' header"); + match result { + Err(StreamableHttpError::ReservedHeaderConflict(header_name)) => { + assert_eq!( + header_name, "accept", + "Error should indicate 'accept' header" + ); + } + other => panic!("Expected ReservedHeaderConflict error, got: {:?}", other), + } +} + +/// Unit test: post_message should reject reserved header "mcp-session-id" +#[tokio::test] +#[cfg(feature = "transport-streamable-http-client-reqwest")] +async fn test_post_message_rejects_mcp_session_id() { + use std::sync::Arc; + + use rmcp::{ + model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, + transport::streamable_http_client::{StreamableHttpClient, StreamableHttpError}, + }; + + let client = reqwest::Client::new(); + let mut custom_headers = HashMap::new(); + custom_headers.insert( + HeaderName::from_static("mcp-session-id"), + HeaderValue::from_static("my-session"), + ); + + let message = ClientJsonRpcMessage::request( + ClientRequest::PingRequest(PingRequest::default()), + RequestId::Number(1), + ); + + let result = client + .post_message( + Arc::from("http://localhost:9999/mcp"), + message, + None, + None, + custom_headers, + ) + .await; + + assert!(result.is_err(), "Should reject 'mcp-session-id' header"); + match result { + Err(StreamableHttpError::ReservedHeaderConflict(header_name)) => { + assert_eq!( + header_name, "mcp-session-id", + "Error should indicate 'mcp-session-id' header" + ); + } + other => panic!("Expected ReservedHeaderConflict error, got: {:?}", other), + } +} + +/// Unit test: post_message should reject reserved header "mcp-protocol-version" +#[tokio::test] +#[cfg(feature = "transport-streamable-http-client-reqwest")] +async fn test_post_message_rejects_mcp_protocol_version() { + use std::sync::Arc; + + use rmcp::{ + model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, + transport::streamable_http_client::{StreamableHttpClient, StreamableHttpError}, + }; + + let client = reqwest::Client::new(); + let mut custom_headers = HashMap::new(); + custom_headers.insert( + HeaderName::from_static("mcp-protocol-version"), + HeaderValue::from_static("1.0"), + ); + + let message = ClientJsonRpcMessage::request( + ClientRequest::PingRequest(PingRequest::default()), + RequestId::Number(1), + ); + + let result = client + .post_message( + Arc::from("http://localhost:9999/mcp"), + message, + None, + None, + custom_headers, + ) + .await; + + assert!( + result.is_err(), + "Should reject 'mcp-protocol-version' header" + ); + match result { + Err(StreamableHttpError::ReservedHeaderConflict(header_name)) => { + assert_eq!( + header_name, "mcp-protocol-version", + "Error should indicate 'mcp-protocol-version' header" + ); + } + other => panic!("Expected ReservedHeaderConflict error, got: {:?}", other), + } +} + +/// Unit test: post_message should reject reserved header "last-event-id" +#[tokio::test] +#[cfg(feature = "transport-streamable-http-client-reqwest")] +async fn test_post_message_rejects_last_event_id() { + use std::sync::Arc; + + use rmcp::{ + model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, + transport::streamable_http_client::{StreamableHttpClient, StreamableHttpError}, + }; + + let client = reqwest::Client::new(); + let mut custom_headers = HashMap::new(); + custom_headers.insert( + HeaderName::from_static("last-event-id"), + HeaderValue::from_static("event-123"), + ); + + let message = ClientJsonRpcMessage::request( + ClientRequest::PingRequest(PingRequest::default()), + RequestId::Number(1), + ); + + let result = client + .post_message( + Arc::from("http://localhost:9999/mcp"), + message, + None, + None, + custom_headers, + ) + .await; + + assert!(result.is_err(), "Should reject 'last-event-id' header"); + match result { + Err(StreamableHttpError::ReservedHeaderConflict(header_name)) => { + assert_eq!( + header_name, "last-event-id", + "Error should indicate 'last-event-id' header" + ); + } + other => panic!("Expected ReservedHeaderConflict error, got: {:?}", other), + } +} + +/// Unit test: post_message should do case-insensitive matching for reserved headers +#[tokio::test] +#[cfg(feature = "transport-streamable-http-client-reqwest")] +async fn test_post_message_case_insensitive_matching() { + use std::sync::Arc; + + use rmcp::{ + model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, + transport::streamable_http_client::{StreamableHttpClient, StreamableHttpError}, + }; + + let client = reqwest::Client::new(); + let message = ClientJsonRpcMessage::request( + ClientRequest::PingRequest(PingRequest::default()), + RequestId::Number(1), + ); + + // Test different casings + let test_cases = vec![ + ("Accept", "Should reject 'Accept' (capitalized)"), + ("ACCEPT", "Should reject 'ACCEPT' (uppercase)"), + ("Mcp-Session-Id", "Should reject 'Mcp-Session-Id'"), + ("MCP-SESSION-ID", "Should reject 'MCP-SESSION-ID'"), + ]; + + for (header_name, error_msg) in test_cases { + let mut custom_headers = HashMap::new(); + custom_headers.insert( + HeaderName::from_bytes(header_name.as_bytes()).unwrap(), + HeaderValue::from_static("value"), + ); + + let result = client + .post_message( + Arc::from("http://localhost:9999/mcp"), + message.clone(), + None, + None, + custom_headers, + ) + .await; + + assert!(result.is_err(), "{}", error_msg); + if let Err(StreamableHttpError::ReservedHeaderConflict(_)) = result { + // Success + } else { + panic!( + "{}: Expected ReservedHeaderConflict, got: {:?}", + error_msg, result + ); + } + } +} + +/// Integration test: Verify that custom headers are actually sent in MCP HTTP requests +#[tokio::test] +#[cfg(all( + feature = "transport-streamable-http-client", + feature = "transport-streamable-http-client-reqwest" +))] +async fn test_mcp_custom_headers_sent_to_server() -> anyhow::Result<()> { + use std::{net::SocketAddr, sync::Arc}; + + use axum::{ + Router, body::Bytes, extract::State, http::StatusCode, response::IntoResponse, + routing::post, + }; + use rmcp::{ + ServiceExt, + transport::{ + StreamableHttpClientTransport, + streamable_http_client::StreamableHttpClientTransportConfig, + }, + }; + use serde_json::json; + use tokio::sync::Mutex; + + // State to capture received headers + #[derive(Clone)] + struct ServerState { + received_headers: Arc>>, + initialize_called: Arc, + } + + // Handler that captures headers from MCP requests + async fn mcp_handler( + State(state): State, + headers: http::HeaderMap, + body: Bytes, + ) -> impl IntoResponse { + // Capture all custom headers (starting with x-) + let mut headers_map = HashMap::new(); + for (name, value) in headers.iter() { + let name_str = name.as_str(); + if name_str.starts_with("x-") { + if let Ok(v) = value.to_str() { + headers_map.insert(name_str.to_string(), v.to_string()); + } + } + } + + // Store captured headers + let mut stored = state.received_headers.lock().await; + stored.extend(headers_map); + + // Parse the MCP request + if let Ok(json_body) = serde_json::from_slice::(&body) { + if let Some(method) = json_body.get("method").and_then(|m| m.as_str()) { + if method == "initialize" { + state.initialize_called.notify_one(); + // Return a valid MCP initialize response with session header + let response = json!({ + "jsonrpc": "2.0", + "id": json_body.get("id"), + "result": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "serverInfo": { + "name": "test-server", + "version": "1.0.0" + } + } + }); + return ( + StatusCode::OK, + [ + (http::header::CONTENT_TYPE, "application/json"), + ( + http::HeaderName::from_static("mcp-session-id"), + "test-session-123", + ), + ], + response.to_string(), + ); + } else if method == "notifications/initialized" { + // For initialized notification, return 202 Accepted + return ( + StatusCode::ACCEPTED, + [ + (http::header::CONTENT_TYPE, "application/json"), + ( + http::HeaderName::from_static("mcp-session-id"), + "test-session-123", + ), + ], + String::new(), + ); + } + } + } + + // Default response for other requests + let response = json!({ + "jsonrpc": "2.0", + "id": 1, + "result": {} + }); + ( + StatusCode::OK, + [ + (http::header::CONTENT_TYPE, "application/json"), + ( + http::HeaderName::from_static("mcp-session-id"), + "test-session-123", + ), + ], + response.to_string(), + ) + } + + // Setup test server + let state = ServerState { + received_headers: Arc::new(Mutex::new(HashMap::new())), + initialize_called: Arc::new(tokio::sync::Notify::new()), + }; + + let app = Router::new() + .route("/mcp", post(mcp_handler)) + .with_state(state.clone()); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = tokio::net::TcpListener::bind(addr).await?; + let port = listener.local_addr()?.port(); + + let server_handle = tokio::spawn(async move { axum::serve(listener, app).await }); + + // Wait for server to be ready + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + // Create MCP client with custom headers + let mut custom_headers = HashMap::new(); + custom_headers.insert( + HeaderName::from_static("x-test-header"), + HeaderValue::from_static("test-value-123"), + ); + custom_headers.insert( + HeaderName::from_static("x-another-header"), + HeaderValue::from_static("another-value-456"), + ); + custom_headers.insert( + HeaderName::from_static("x-client-id"), + HeaderValue::from_static("test-client"), + ); + + let config = + StreamableHttpClientTransportConfig::with_uri(format!("http://127.0.0.1:{}/mcp", port)) + .custom_headers(custom_headers); + + let transport = StreamableHttpClientTransport::from_config(config); + + // Start MCP client with empty handler (this will trigger initialize request) + let client = ().serve(transport).await.expect("Failed to start client"); + + // Wait for initialize to be called + tokio::time::timeout( + std::time::Duration::from_secs(5), + state.initialize_called.notified(), + ) + .await + .expect("Initialize request should be received"); + + // Verify that custom headers were received + let headers = state.received_headers.lock().await; + + assert_eq!( + headers.get("x-test-header"), + Some(&"test-value-123".to_string()), + "Custom header x-test-header should be sent to MCP server" + ); + assert_eq!( + headers.get("x-another-header"), + Some(&"another-value-456".to_string()), + "Custom header x-another-header should be sent to MCP server" + ); + assert_eq!( + headers.get("x-client-id"), + Some(&"test-client".to_string()), + "Custom header x-client-id should be sent to MCP server" + ); + + // Cleanup + drop(client); + server_handle.abort(); + + Ok(()) +}