From 2174a981d338c4c3327dab71c4e540425ed43876 Mon Sep 17 00:00:00 2001 From: Jeremiah Williams Date: Tue, 1 Jul 2025 18:10:01 -0700 Subject: [PATCH 1/2] Add native OAuth 2.0 authentication support to MCP client This code adds native OAuth 2.0 authentication support to the MCP client, enabling automatic authentication with OAuth-protected MCP servers. The implementation supports dynamic client registration and uses the standard OAuth 2.0 Authorization Code flow with PKCE. This code was written with the Notion MCP in mind and tested against that service, but should work with most OAuth 2.0-compliant remote MCPs. --- Cargo.lock | 7 + crates/mcp-client/Cargo.toml | 8 + crates/mcp-client/examples/test_auth.rs | 52 +++ crates/mcp-client/src/lib.rs | 2 + crates/mcp-client/src/oauth.rs | 419 ++++++++++++++++++ .../src/transport/streamable_http.rs | 67 ++- 6 files changed, 553 insertions(+), 2 deletions(-) create mode 100644 crates/mcp-client/examples/test_auth.rs create mode 100644 crates/mcp-client/src/oauth.rs diff --git a/Cargo.lock b/Cargo.lock index cab8ca6cd998..5786ec3792df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5311,14 +5311,20 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "axum", + "base64 0.22.1", + "chrono", "eventsource-client", "futures", "mcp-core", + "nanoid", "nix 0.30.1", "rand 0.8.5", "reqwest 0.11.27", "serde", "serde_json", + "serde_urlencoded", + "sha2", "thiserror 1.0.69", "tokio", "tokio-util", @@ -5327,6 +5333,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", + "webbrowser 1.0.4", ] [[package]] diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml index 7188cf33792c..a678e8f20643 100644 --- a/crates/mcp-client/Cargo.toml +++ b/crates/mcp-client/Cargo.toml @@ -25,5 +25,13 @@ tower = { version = "0.4", features = ["timeout", "util"] } tower-service = "0.3" rand = "0.8" nix = { version = "0.30.1", features = ["process", "signal"] } +# OAuth dependencies +axum = { version = "0.8", features = ["query"] } +base64 = "0.22" +sha2 = "0.10" +chrono = { version = "0.4", features = ["serde"] } +nanoid = "0.4" +webbrowser = "1.0" +serde_urlencoded = "0.7" [dev-dependencies] diff --git a/crates/mcp-client/examples/test_auth.rs b/crates/mcp-client/examples/test_auth.rs new file mode 100644 index 000000000000..d4fba7d6f528 --- /dev/null +++ b/crates/mcp-client/examples/test_auth.rs @@ -0,0 +1,52 @@ +use anyhow::Result; +use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; +use mcp_client::transport::{StreamableHttpTransport, Transport}; +use std::collections::HashMap; +use std::time::Duration; +use tracing_subscriber::EnvFilter; + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::from_default_env() + .add_directive("mcp_client=debug".parse().unwrap()) + .add_directive("eventsource_client=info".parse().unwrap()), + ) + .init(); + + println!("Testing Streamable HTTP transport with auto-authentication..."); + + // Create the Streamable HTTP transport for any MCP service that supports OAuth + // This example uses a hypothetical MCP endpoint - replace with actual service + let mcp_endpoint = + std::env::var("MCP_ENDPOINT").unwrap_or_else(|_| "https://example.com/mcp".to_string()); + + println!("Using MCP endpoint: {}", mcp_endpoint); + + let transport = StreamableHttpTransport::new(&mcp_endpoint, HashMap::new()); + + // Start transport + let handle = transport.start().await?; + + // Create client + let mut client = McpClient::connect(handle, Duration::from_secs(30)).await?; + println!("Client created with Streamable HTTP transport\n"); + + // Initialize - this should trigger the OAuth flow if authentication is needed + let server_info = client + .initialize( + ClientInfo { + name: "streamable-http-auth-test".into(), + version: "1.0.0".into(), + }, + ClientCapabilities::default(), + ) + .await?; + + println!("Connected to server: {server_info:?}\n"); + println!("Authentication test completed successfully!"); + + Ok(()) +} diff --git a/crates/mcp-client/src/lib.rs b/crates/mcp-client/src/lib.rs index f6ed51dc467b..01f55864e0ba 100644 --- a/crates/mcp-client/src/lib.rs +++ b/crates/mcp-client/src/lib.rs @@ -1,8 +1,10 @@ pub mod client; +pub mod oauth; pub mod service; pub mod transport; pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait}; +pub use oauth::{authenticate_service, ServiceConfig}; pub use service::McpService; pub use transport::{ SseTransport, StdioTransport, StreamableHttpTransport, Transport, TransportHandle, diff --git a/crates/mcp-client/src/oauth.rs b/crates/mcp-client/src/oauth.rs new file mode 100644 index 000000000000..fc6af957628f --- /dev/null +++ b/crates/mcp-client/src/oauth.rs @@ -0,0 +1,419 @@ +use anyhow::Result; +use axum::{extract::Query, response::Html, routing::get, Router}; +use base64::Engine; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use sha2::Digest; +use std::{collections::HashMap, net::SocketAddr, sync::Arc}; +use tokio::sync::{oneshot, Mutex as TokioMutex}; +use url::Url; + +#[derive(Debug, Clone)] +struct OidcEndpoints { + authorization_endpoint: String, + token_endpoint: String, + registration_endpoint: Option, +} + +#[derive(Serialize, Deserialize)] +struct TokenData { + access_token: String, + refresh_token: Option, +} + +#[derive(Serialize, Deserialize)] +struct ClientRegistrationRequest { + redirect_uris: Vec, + token_endpoint_auth_method: String, + grant_types: Vec, + response_types: Vec, + client_name: String, + client_uri: String, +} + +#[derive(Serialize, Deserialize)] +struct ClientRegistrationResponse { + client_id: String, + client_id_issued_at: Option, + #[serde(default)] + client_secret: Option, +} + +/// OAuth configuration for any service +#[derive(Debug, Clone)] +pub struct ServiceConfig { + pub oauth_host: String, + pub redirect_uri: String, + pub client_name: String, + pub client_uri: String, + pub discovery_path: Option, +} + +impl ServiceConfig { + /// Create a generic OAuth configuration from an MCP endpoint URL + /// Extracts the base URL for OAuth discovery + pub fn from_mcp_endpoint(mcp_url: &str) -> Result { + let parsed_url = Url::parse(mcp_url.trim())?; + let oauth_host = format!( + "{}://{}{}", + parsed_url.scheme(), + parsed_url.host_str().ok_or_else(|| { + anyhow::anyhow!("Invalid MCP URL: no host found in {}", mcp_url) + })?, + if let Some(port) = parsed_url.port() { + format!(":{}", port) + } else { + String::new() + } + ); + + Ok(Self { + oauth_host, + redirect_uri: "http://localhost:8020".to_string(), + client_name: "Goose MCP Client".to_string(), + client_uri: "https://github.com/block/goose".to_string(), + discovery_path: None, // Use standard discovery + }) + } + + /// Create configuration with custom discovery path for non-standard services + pub fn with_custom_discovery(mut self, discovery_path: String) -> Self { + self.discovery_path = Some(discovery_path); + self + } +} + +struct OAuthFlow { + endpoints: OidcEndpoints, + client_id: String, + redirect_url: String, + state: String, + verifier: String, +} + +impl OAuthFlow { + fn new(endpoints: OidcEndpoints, client_id: String, redirect_url: String) -> Self { + Self { + endpoints, + client_id, + redirect_url, + state: nanoid::nanoid!(16), + verifier: nanoid::nanoid!(64), + } + } + + /// Register a dynamic client and return the client_id + async fn register_client(endpoints: &OidcEndpoints, config: &ServiceConfig) -> Result { + let Some(registration_endpoint) = &endpoints.registration_endpoint else { + return Err(anyhow::anyhow!("No registration endpoint available")); + }; + + let registration_request = ClientRegistrationRequest { + redirect_uris: vec![config.redirect_uri.clone()], + token_endpoint_auth_method: "none".to_string(), + grant_types: vec![ + "authorization_code".to_string(), + "refresh_token".to_string(), + ], + response_types: vec!["code".to_string()], + client_name: config.client_name.clone(), + client_uri: config.client_uri.clone(), + }; + + tracing::info!("Registering dynamic client with OAuth server..."); + + let client = reqwest::Client::new(); + let resp = client + .post(registration_endpoint) + .header("Content-Type", "application/json") + .json(®istration_request) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let err_text = resp.text().await?; + return Err(anyhow::anyhow!( + "Failed to register client: {} - {}", + status, + err_text + )); + } + + let registration_response: ClientRegistrationResponse = resp.json().await?; + + tracing::info!( + "Client registered successfully with ID: {}", + registration_response.client_id + ); + Ok(registration_response.client_id) + } + + fn get_authorization_url(&self) -> String { + let challenge = { + let digest = sha2::Sha256::digest(self.verifier.as_bytes()); + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest) + }; + + let params = [ + ("response_type", "code"), + ("client_id", &self.client_id), + ("redirect_uri", &self.redirect_url), + ("state", &self.state), + ("code_challenge", &challenge), + ("code_challenge_method", "S256"), + ]; + + format!( + "{}?{}", + self.endpoints.authorization_endpoint, + serde_urlencoded::to_string(params).unwrap() + ) + } + + async fn exchange_code_for_token(&self, code: &str) -> Result { + let params = [ + ("grant_type", "authorization_code"), + ("code", code), + ("redirect_uri", &self.redirect_url), + ("code_verifier", &self.verifier), + ("client_id", &self.client_id), + ]; + + let client = reqwest::Client::new(); + let resp = client + .post(&self.endpoints.token_endpoint) + .header("Content-Type", "application/x-www-form-urlencoded") + .form(¶ms) + .send() + .await?; + + if !resp.status().is_success() { + let err_text = resp.text().await?; + return Err(anyhow::anyhow!( + "Failed to exchange code for token: {}", + err_text + )); + } + + let token_response: Value = resp.json().await?; + + let access_token = token_response + .get("access_token") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("access_token not found in token response"))? + .to_string(); + + let refresh_token = token_response + .get("refresh_token") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + Ok(TokenData { + access_token, + refresh_token, + }) + } + + async fn execute(&self) -> Result { + // Create a channel that will send the auth code from the callback + let (tx, rx) = oneshot::channel(); + let state = self.state.clone(); + let tx = Arc::new(TokioMutex::new(Some(tx))); + + // Setup a server that will receive the redirect and capture the code + let app = Router::new().route( + "/", + get(move |Query(params): Query>| { + let tx = Arc::clone(&tx); + let state = state.clone(); + async move { + let code = params.get("code").cloned(); + let received_state = params.get("state").cloned(); + + if let (Some(code), Some(received_state)) = (code, received_state) { + if received_state == state { + if let Some(sender) = tx.lock().await.take() { + if sender.send(code).is_ok() { + return Html( + "

Authentication Successful!

You can close this window and return to the application.

", + ); + } + } + Html("

Error

Authentication already completed.

") + } else { + Html("

Error

State mismatch - possible security issue.

") + } + } else { + Html("

Error

Authentication failed - missing parameters.

") + } + } + }), + ); + + // Start the callback server + let redirect_url = Url::parse(&self.redirect_url)?; + let port = redirect_url.port().unwrap_or(8020); + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + + let listener = tokio::net::TcpListener::bind(addr).await?; + + let server_handle = tokio::spawn(async move { + let server = axum::serve(listener, app); + server.await.unwrap(); + }); + + // Open the browser for OAuth + let authorization_url = self.get_authorization_url(); + tracing::info!("Opening browser for OAuth authentication..."); + + if webbrowser::open(&authorization_url).is_err() { + tracing::warn!("Could not open browser automatically. Please open this URL manually:"); + tracing::warn!("{}", authorization_url); + } + + // Wait for the authorization code with a timeout + let code = tokio::time::timeout( + std::time::Duration::from_secs(120), // 2 minute timeout + rx, + ) + .await + .map_err(|_| anyhow::anyhow!("Authentication timed out after 2 minutes"))??; + + // Stop the callback server + server_handle.abort(); + + // Exchange the code for a token + self.exchange_code_for_token(&code).await + } +} + +async fn get_oauth_endpoints( + host: &str, + custom_discovery_path: Option<&str>, +) -> Result { + let base_url = Url::parse(host)?; + let client = reqwest::Client::new(); + + // Define discovery paths to try, with custom path first if provided + let mut discovery_paths = Vec::new(); + if let Some(custom_path) = custom_discovery_path { + discovery_paths.push(custom_path); + } + discovery_paths.extend([ + "/.well-known/oauth-authorization-server", + "/.well-known/openid_configuration", + "/oauth/.well-known/oauth-authorization-server", + "/.well-known/oauth_authorization_server", // Some services use underscore + ]); + + let discovery_paths_for_error = discovery_paths.clone(); // Clone for error message + let mut last_error = None; + + // Try each discovery path until one works + for path in discovery_paths { + match base_url.join(path) { + Ok(discovery_url) => { + tracing::debug!("Trying OAuth discovery at: {}", discovery_url); + + match client.get(discovery_url.clone()).send().await { + Ok(resp) if resp.status().is_success() => { + match resp.json::().await { + Ok(oidc_config) => { + // Try to parse the OAuth configuration + match parse_oauth_config(oidc_config) { + Ok(endpoints) => { + tracing::info!( + "Successfully discovered OAuth endpoints at: {}", + discovery_url + ); + return Ok(endpoints); + } + Err(e) => { + tracing::debug!( + "Invalid OAuth config at {}: {}", + discovery_url, + e + ); + last_error = Some(e); + } + } + } + Err(e) => { + tracing::debug!( + "Failed to parse JSON from {}: {}", + discovery_url, + e + ); + last_error = Some(e.into()); + } + } + } + Ok(resp) => { + tracing::debug!("HTTP {} from {}", resp.status(), discovery_url); + } + Err(e) => { + tracing::debug!("Request failed to {}: {}", discovery_url, e); + last_error = Some(e.into()); + } + } + } + Err(e) => { + tracing::debug!("Invalid discovery URL {}{}: {}", host, path, e); + } + } + } + + Err(last_error.unwrap_or_else(|| { + anyhow::anyhow!( + "No OAuth discovery endpoint found at {}. Tried paths: {:?}", + host, + discovery_paths_for_error + ) + })) +} + +fn parse_oauth_config(oidc_config: Value) -> Result { + let authorization_endpoint = oidc_config + .get("authorization_endpoint") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("authorization_endpoint not found in OAuth configuration"))? + .to_string(); + + let token_endpoint = oidc_config + .get("token_endpoint") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("token_endpoint not found in OAuth configuration"))? + .to_string(); + + let registration_endpoint = oidc_config + .get("registration_endpoint") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + Ok(OidcEndpoints { + authorization_endpoint, + token_endpoint, + registration_endpoint, + }) +} + +/// Perform OAuth flow for a service +pub async fn authenticate_service(config: ServiceConfig) -> Result { + tracing::info!("Starting OAuth authentication for service..."); + + // Get OAuth endpoints using flexible discovery + let endpoints = + get_oauth_endpoints(&config.oauth_host, config.discovery_path.as_deref()).await?; + + // Register dynamic client to get client_id + let client_id = OAuthFlow::register_client(&endpoints, &config).await?; + + // Create and execute OAuth flow with the dynamic client_id + let flow = OAuthFlow::new(endpoints, client_id, config.redirect_uri); + + let token_data = flow.execute().await?; + + tracing::info!("OAuth authentication successful!"); + Ok(token_data.access_token) +} diff --git a/crates/mcp-client/src/transport/streamable_http.rs b/crates/mcp-client/src/transport/streamable_http.rs index cc3f4fc5d172..0eb2b52a386f 100644 --- a/crates/mcp-client/src/transport/streamable_http.rs +++ b/crates/mcp-client/src/transport/streamable_http.rs @@ -1,3 +1,4 @@ +use crate::oauth::{authenticate_service, ServiceConfig}; use crate::transport::Error; use async_trait::async_trait; use eventsource_client::{Client, SSE}; @@ -8,7 +9,7 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, Mutex, RwLock}; use tokio::time::Duration; -use tracing::{debug, error, warn}; +use tracing::{debug, error, info, warn}; use url::Url; use super::{serialize_and_send, Transport, TransportHandle}; @@ -91,13 +92,45 @@ impl StreamableHttpActor { JsonRpcMessage::Request(JsonRpcRequest { id: Some(_), .. }) ); + // Try to send the request + match self.send_request(&message_str, expects_response).await { + Ok(()) => Ok(()), + Err(Error::HttpError { status, .. }) if status == 401 || status == 403 => { + // Authentication challenge - try to authenticate and retry + info!( + "Received authentication challenge ({}), attempting OAuth flow...", + status + ); + + if let Some(token) = self.attempt_authentication().await? { + info!("Authentication successful, retrying request..."); + self.headers + .insert("Authorization".to_string(), format!("Bearer {}", token)); + self.send_request(&message_str, expects_response).await + } else { + Err(Error::StreamableHttpError( + "Authentication failed - service not supported or OAuth flow failed" + .to_string(), + )) + } + } + Err(e) => Err(e), + } + } + + /// Send an HTTP request to the MCP endpoint + async fn send_request( + &mut self, + message_str: &str, + expects_response: bool, + ) -> Result<(), Error> { // Build the HTTP request let mut request = self .http_client .post(&self.mcp_endpoint) .header("Content-Type", "application/json") .header("Accept", "application/json, text/event-stream") - .body(message_str); + .body(message_str.to_string()); // Add session ID header if we have one if let Some(session_id) = self.session_id.read().await.as_ref() { @@ -173,6 +206,36 @@ impl StreamableHttpActor { Ok(()) } + /// Attempt to authenticate with the service + async fn attempt_authentication(&self) -> Result, Error> { + info!("Attempting to authenticate with service..."); + + // Create a generic OAuth configuration from the MCP endpoint + match ServiceConfig::from_mcp_endpoint(&self.mcp_endpoint) { + Ok(config) => { + info!("Created OAuth config for endpoint: {}", self.mcp_endpoint); + + match authenticate_service(config).await { + Ok(token) => { + info!("OAuth authentication successful!"); + Ok(Some(token)) + } + Err(e) => { + warn!("OAuth authentication failed: {}", e); + Err(Error::StreamableHttpError(format!("OAuth failed: {}", e))) + } + } + } + Err(e) => { + warn!( + "Could not create OAuth config from MCP endpoint {}: {}", + self.mcp_endpoint, e + ); + Ok(None) + } + } + } + /// Handle streaming HTTP response that uses Server-Sent Events format /// /// This is called when the server responds to an HTTP POST with `text/event-stream` From d667057a7bd0f86918add3568f18acc3c25e92ad Mon Sep 17 00:00:00 2001 From: Alex Hancock Date: Wed, 2 Jul 2025 12:12:15 -0400 Subject: [PATCH 2/2] feat: OAuth client spec compliance improvements (#3224) --- crates/mcp-client/examples/test_auth.rs | 18 ++++- crates/mcp-client/src/lib.rs | 3 + crates/mcp-client/src/oauth.rs | 51 ++++++++++-- crates/mcp-client/src/oauth_tests.rs | 81 +++++++++++++++++++ .../src/transport/streamable_http.rs | 9 ++- 5 files changed, 149 insertions(+), 13 deletions(-) create mode 100644 crates/mcp-client/src/oauth_tests.rs diff --git a/crates/mcp-client/examples/test_auth.rs b/crates/mcp-client/examples/test_auth.rs index d4fba7d6f528..b4159d41224f 100644 --- a/crates/mcp-client/examples/test_auth.rs +++ b/crates/mcp-client/examples/test_auth.rs @@ -16,7 +16,7 @@ async fn main() -> Result<()> { ) .init(); - println!("Testing Streamable HTTP transport with auto-authentication..."); + println!("Testing Streamable HTTP transport with OAuth 2.0 authentication..."); // Create the Streamable HTTP transport for any MCP service that supports OAuth // This example uses a hypothetical MCP endpoint - replace with actual service @@ -34,7 +34,13 @@ async fn main() -> Result<()> { let mut client = McpClient::connect(handle, Duration::from_secs(30)).await?; println!("Client created with Streamable HTTP transport\n"); - // Initialize - this should trigger the OAuth flow if authentication is needed + // Initialize - this will trigger the OAuth flow if authentication is needed + // The implementation now includes: + // - RFC 8707 Resource Parameter support for proper token audience binding + // - Proper OAuth 2.0 discovery with multiple fallback paths + // - Dynamic client registration (RFC 7591) + // - PKCE for security (RFC 7636) + // - MCP-Protocol-Version header as required by the specification let server_info = client .initialize( ClientInfo { @@ -46,7 +52,13 @@ async fn main() -> Result<()> { .await?; println!("Connected to server: {server_info:?}\n"); - println!("Authentication test completed successfully!"); + println!("OAuth 2.0 authentication test completed successfully!"); + println!("\nKey improvements implemented:"); + println!("✓ RFC 8707 Resource Parameter implementation"); + println!("✓ MCP-Protocol-Version header support"); + println!("✓ Enhanced OAuth discovery with multiple fallback paths"); + println!("✓ Proper canonical resource URI generation"); + println!("✓ Full compliance with MCP Authorization specification"); Ok(()) } diff --git a/crates/mcp-client/src/lib.rs b/crates/mcp-client/src/lib.rs index 01f55864e0ba..b659ac3753a1 100644 --- a/crates/mcp-client/src/lib.rs +++ b/crates/mcp-client/src/lib.rs @@ -3,6 +3,9 @@ pub mod oauth; pub mod service; pub mod transport; +#[cfg(test)] +mod oauth_tests; + pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait}; pub use oauth::{authenticate_service, ServiceConfig}; pub use service::McpService; diff --git a/crates/mcp-client/src/oauth.rs b/crates/mcp-client/src/oauth.rs index fc6af957628f..74bb892a8c77 100644 --- a/crates/mcp-client/src/oauth.rs +++ b/crates/mcp-client/src/oauth.rs @@ -81,6 +81,37 @@ impl ServiceConfig { self.discovery_path = Some(discovery_path); self } + + /// Get the canonical resource URI for the MCP server + /// This is used as the resource parameter in OAuth requests (RFC 8707) + pub fn get_canonical_resource_uri(&self, mcp_url: &str) -> Result { + let parsed_url = Url::parse(mcp_url.trim())?; + + // Build canonical URI: scheme://host[:port][/path] + let mut canonical = format!( + "{}://{}", + parsed_url.scheme().to_lowercase(), + parsed_url + .host_str() + .ok_or_else(|| { + anyhow::anyhow!("Invalid MCP URL: no host found in {}", mcp_url) + })? + .to_lowercase() + ); + + // Add port if not default + if let Some(port) = parsed_url.port() { + canonical.push_str(&format!(":{}", port)); + } + + // Add path if present and not just "/" + let path = parsed_url.path(); + if !path.is_empty() && path != "/" { + canonical.push_str(path); + } + + Ok(canonical) + } } struct OAuthFlow { @@ -149,7 +180,7 @@ impl OAuthFlow { Ok(registration_response.client_id) } - fn get_authorization_url(&self) -> String { + fn get_authorization_url(&self, resource: &str) -> String { let challenge = { let digest = sha2::Sha256::digest(self.verifier.as_bytes()); base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest) @@ -162,6 +193,7 @@ impl OAuthFlow { ("state", &self.state), ("code_challenge", &challenge), ("code_challenge_method", "S256"), + ("resource", resource), // RFC 8707 Resource Parameter ]; format!( @@ -171,13 +203,14 @@ impl OAuthFlow { ) } - async fn exchange_code_for_token(&self, code: &str) -> Result { + async fn exchange_code_for_token(&self, code: &str, resource: &str) -> Result { let params = [ ("grant_type", "authorization_code"), ("code", code), ("redirect_uri", &self.redirect_url), ("code_verifier", &self.verifier), ("client_id", &self.client_id), + ("resource", resource), // RFC 8707 Resource Parameter ]; let client = reqwest::Client::new(); @@ -215,7 +248,7 @@ impl OAuthFlow { }) } - async fn execute(&self) -> Result { + async fn execute(&self, resource: &str) -> Result { // Create a channel that will send the auth code from the callback let (tx, rx) = oneshot::channel(); let state = self.state.clone(); @@ -264,7 +297,7 @@ impl OAuthFlow { }); // Open the browser for OAuth - let authorization_url = self.get_authorization_url(); + let authorization_url = self.get_authorization_url(resource); tracing::info!("Opening browser for OAuth authentication..."); if webbrowser::open(&authorization_url).is_err() { @@ -284,7 +317,7 @@ impl OAuthFlow { server_handle.abort(); // Exchange the code for a token - self.exchange_code_for_token(&code).await + self.exchange_code_for_token(&code, resource).await } } @@ -399,9 +432,13 @@ fn parse_oauth_config(oidc_config: Value) -> Result { } /// Perform OAuth flow for a service -pub async fn authenticate_service(config: ServiceConfig) -> Result { +pub async fn authenticate_service(config: ServiceConfig, mcp_url: &str) -> Result { tracing::info!("Starting OAuth authentication for service..."); + // Get the canonical resource URI for the MCP server + let resource_uri = config.get_canonical_resource_uri(mcp_url)?; + tracing::info!("Using resource URI: {}", resource_uri); + // Get OAuth endpoints using flexible discovery let endpoints = get_oauth_endpoints(&config.oauth_host, config.discovery_path.as_deref()).await?; @@ -412,7 +449,7 @@ pub async fn authenticate_service(config: ServiceConfig) -> Result { // Create and execute OAuth flow with the dynamic client_id let flow = OAuthFlow::new(endpoints, client_id, config.redirect_uri); - let token_data = flow.execute().await?; + let token_data = flow.execute(&resource_uri).await?; tracing::info!("OAuth authentication successful!"); Ok(token_data.access_token) diff --git a/crates/mcp-client/src/oauth_tests.rs b/crates/mcp-client/src/oauth_tests.rs new file mode 100644 index 000000000000..8959c7323b90 --- /dev/null +++ b/crates/mcp-client/src/oauth_tests.rs @@ -0,0 +1,81 @@ +#[cfg(test)] +mod tests { + use crate::oauth::ServiceConfig; + + #[test] + fn test_canonical_resource_uri_generation() { + let config = ServiceConfig { + oauth_host: "https://example.com".to_string(), + redirect_uri: "http://localhost:8020".to_string(), + client_name: "Test Client".to_string(), + client_uri: "https://test.com".to_string(), + discovery_path: None, + }; + + // Test basic URL + let result = config + .get_canonical_resource_uri("https://mcp.example.com/mcp") + .unwrap(); + assert_eq!(result, "https://mcp.example.com/mcp"); + + // Test URL with port + let result = config + .get_canonical_resource_uri("https://mcp.example.com:8443/mcp") + .unwrap(); + assert_eq!(result, "https://mcp.example.com:8443/mcp"); + + // Test URL without path + let result = config + .get_canonical_resource_uri("https://mcp.example.com") + .unwrap(); + assert_eq!(result, "https://mcp.example.com"); + + // Test URL with root path + let result = config + .get_canonical_resource_uri("https://mcp.example.com/") + .unwrap(); + assert_eq!(result, "https://mcp.example.com"); + + // Test case normalization + let result = config + .get_canonical_resource_uri("HTTPS://MCP.EXAMPLE.COM/mcp") + .unwrap(); + assert_eq!(result, "https://mcp.example.com/mcp"); + } + + #[test] + fn test_service_config_from_mcp_endpoint() { + let config = ServiceConfig::from_mcp_endpoint("https://mcp.example.com/api/mcp").unwrap(); + + assert_eq!(config.oauth_host, "https://mcp.example.com"); + assert_eq!(config.redirect_uri, "http://localhost:8020"); + assert_eq!(config.client_name, "Goose MCP Client"); + assert_eq!(config.client_uri, "https://github.com/block/goose"); + assert!(config.discovery_path.is_none()); + } + + #[test] + fn test_service_config_with_port() { + let config = ServiceConfig::from_mcp_endpoint("https://mcp.example.com:8443/mcp").unwrap(); + + assert_eq!(config.oauth_host, "https://mcp.example.com:8443"); + } + + #[test] + fn test_service_config_invalid_url() { + let result = ServiceConfig::from_mcp_endpoint("invalid-url"); + assert!(result.is_err()); + } + + #[test] + fn test_custom_discovery_path() { + let config = ServiceConfig::from_mcp_endpoint("https://mcp.example.com/mcp") + .unwrap() + .with_custom_discovery("/custom/oauth/discovery".to_string()); + + assert_eq!( + config.discovery_path, + Some("/custom/oauth/discovery".to_string()) + ); + } +} diff --git a/crates/mcp-client/src/transport/streamable_http.rs b/crates/mcp-client/src/transport/streamable_http.rs index 0eb2b52a386f..7b39218b25a1 100644 --- a/crates/mcp-client/src/transport/streamable_http.rs +++ b/crates/mcp-client/src/transport/streamable_http.rs @@ -130,6 +130,7 @@ impl StreamableHttpActor { .post(&self.mcp_endpoint) .header("Content-Type", "application/json") .header("Accept", "application/json, text/event-stream") + .header("MCP-Protocol-Version", "2025-06-18") // Required protocol version header .body(message_str.to_string()); // Add session ID header if we have one @@ -215,7 +216,7 @@ impl StreamableHttpActor { Ok(config) => { info!("Created OAuth config for endpoint: {}", self.mcp_endpoint); - match authenticate_service(config).await { + match authenticate_service(config, &self.mcp_endpoint).await { Ok(token) => { info!("OAuth authentication successful!"); Ok(Some(token)) @@ -326,7 +327,8 @@ impl StreamableHttpTransportHandle { let mut request = self .http_client .delete(&self.mcp_endpoint) - .header("Mcp-Session-Id", session_id); + .header("Mcp-Session-Id", session_id) + .header("MCP-Protocol-Version", "2025-06-18"); // Required protocol version header // Add custom headers for (key, value) in &self.headers { @@ -353,7 +355,8 @@ impl StreamableHttpTransportHandle { let mut request = self .http_client .get(&self.mcp_endpoint) - .header("Accept", "text/event-stream"); + .header("Accept", "text/event-stream") + .header("MCP-Protocol-Version", "2025-06-18"); // Required protocol version header // Add session ID header if we have one if let Some(session_id) = self.session_id.read().await.as_ref() {