Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions crates/rmcp/src/transport/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, error};

const DEFAULT_EXCHANGE_URL: &str = "http://localhost";

/// sse client with oauth2 authorization
#[derive(Clone)]
pub struct AuthClient<C> {
Expand Down Expand Up @@ -225,10 +227,17 @@ impl AuthorizationManager {
// discard the path part, only keep scheme, host, port
auth_base.set_path("");

// Helper function to create endpoint URL
let create_endpoint = |path: &str| -> String {
let mut url = auth_base.clone();
url.set_path(path);
url.to_string()
};

Ok(AuthorizationMetadata {
authorization_endpoint: format!("{}/authorize", auth_base),
token_endpoint: format!("{}/token", auth_base),
registration_endpoint: format!("{}/register", auth_base),
authorization_endpoint: create_endpoint("authorize"),
token_endpoint: create_endpoint("token"),
registration_endpoint: create_endpoint("register"),
issuer: None,
jwks_uri: None,
scopes_supported: None,
Expand Down Expand Up @@ -686,7 +695,7 @@ impl OAuthState {
if let OAuthState::Unauthorized(manager) = self {
let mut manager = std::mem::replace(
manager,
AuthorizationManager::new("http://localhost").await?,
AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?,
);

// write credentials
Expand Down Expand Up @@ -716,7 +725,7 @@ impl OAuthState {
) -> Result<(), AuthError> {
if let OAuthState::Unauthorized(mut manager) = std::mem::replace(
self,
OAuthState::Unauthorized(AuthorizationManager::new("http://localhost").await?),
OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?),
) {
debug!("start discovery");
let metadata = manager.discover_metadata().await?;
Expand All @@ -736,7 +745,7 @@ impl OAuthState {
pub async fn complete_authorization(&mut self) -> Result<(), AuthError> {
if let OAuthState::Session(session) = std::mem::replace(
self,
OAuthState::Unauthorized(AuthorizationManager::new("http://localhost").await?),
OAuthState::Unauthorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?),
) {
*self = OAuthState::Authorized(session.auth_manager);
Ok(())
Expand All @@ -748,7 +757,7 @@ impl OAuthState {
pub async fn to_authorized_http_client(&mut self) -> Result<(), AuthError> {
if let OAuthState::Authorized(manager) = std::mem::replace(
self,
OAuthState::Authorized(AuthorizationManager::new("http://localhost").await?),
OAuthState::Authorized(AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?),
) {
*self = OAuthState::AuthorizedHttpClient(AuthorizedHttpClient::new(
Arc::new(manager),
Expand Down