diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 514f4a08..17b426cb 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -14,6 +14,7 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [dependencies] +async-trait = "0.1.89" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" thiserror = "2" diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 7ebbb955..3535410a 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -1,5 +1,6 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; +use async_trait::async_trait; use oauth2::{ AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, Scope, @@ -17,6 +18,62 @@ use tracing::{debug, error, warn}; const DEFAULT_EXCHANGE_URL: &str = "http://localhost"; +/// Stored credentials for OAuth2 authorization +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StoredCredentials { + pub client_id: String, + pub token_response: Option, +} + +/// Trait for storing and retrieving OAuth2 credentials +/// +/// Implementations of this trait can provide custom storage backends +/// for OAuth2 credentials, such as file-based storage, keychain integration, +/// or database storage. +#[async_trait] +pub trait CredentialStore: Send + Sync { + async fn load(&self) -> Result, AuthError>; + + async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError>; + + async fn clear(&self) -> Result<(), AuthError>; +} + +/// In-memory credential store (default implementation) +/// +/// This store keeps credentials in memory only and does not persist them +/// between application restarts. This is the default behavior when no +/// custom credential store is provided. +#[derive(Debug, Default, Clone)] +pub struct InMemoryCredentialStore { + credentials: Arc>>, +} + +impl InMemoryCredentialStore { + pub fn new() -> Self { + Self { + credentials: Arc::new(RwLock::new(None)), + } + } +} + +#[async_trait::async_trait] +impl CredentialStore for InMemoryCredentialStore { + async fn load(&self) -> Result, AuthError> { + Ok(self.credentials.read().await.clone()) + } + + async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError> { + *self.credentials.write().await = Some(credentials); + Ok(()) + } + + async fn clear(&self) -> Result<(), AuthError> { + *self.credentials.write().await = None; + Ok(()) + } +} + /// sse client with oauth2 authorization #[derive(Clone)] pub struct AuthClient { @@ -151,7 +208,7 @@ pub struct AuthorizationManager { http_client: HttpClient, metadata: Option, oauth_client: Option, - credentials: RwLock>, + credential_store: Arc, state: RwLock>, base_url: Url, } @@ -222,7 +279,7 @@ impl AuthorizationManager { http_client, metadata: None, oauth_client: None, - credentials: RwLock::new(None), + credential_store: Arc::new(InMemoryCredentialStore::new()), state: RwLock::new(None), base_url, }; @@ -230,6 +287,34 @@ impl AuthorizationManager { Ok(manager) } + /// Set a custom credential store + /// + /// This allows you to provide your own implementation of credential storage, + /// such as file-based storage, keychain integration, or database storage. + /// This should be called before any operations that need credentials. + pub fn set_credential_store(&mut self, store: S) { + self.credential_store = Arc::new(store); + } + + /// Initialize from stored credentials if available + /// + /// This will load credentials from the credential store and configure + /// the client if credentials are found. + pub async fn initialize_from_store(&mut self) -> Result { + if let Some(stored) = self.credential_store.load().await? { + if stored.token_response.is_some() { + if self.metadata.is_none() { + let metadata = self.discover_metadata().await?; + self.metadata = Some(metadata); + } + + self.configure_client_id(&stored.client_id)?; + return Ok(true); + } + } + Ok(false) + } + pub fn with_client(&mut self, http_client: HttpClient) -> Result<(), AuthError> { self.http_client = http_client; Ok(()) @@ -252,13 +337,16 @@ impl AuthorizationManager { /// get client id and credentials pub async fn get_credentials(&self) -> Result { - let credentials = self.credentials.read().await; let client_id = self .oauth_client .as_ref() .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))? .client_id(); - Ok((client_id.to_string(), credentials.clone())) + + let stored = self.credential_store.load().await?; + let token_response = stored.and_then(|s| s.token_response); + + Ok((client_id.to_string(), token_response)) } /// configure oauth2 client with client credentials @@ -309,7 +397,6 @@ impl AuthorizationManager { )); }; - // prepare registration request let registration_request = ClientRegistrationRequest { client_name: name.to_string(), redirect_uris: vec![redirect_uri.to_string()], @@ -479,23 +566,28 @@ impl AuthorizationManager { }; debug!("exchange token result: {:?}", token_result); - // store credentials - *self.credentials.write().await = Some(token_result.clone()); + + // Store credentials in the credential store + let client_id = oauth_client.client_id().to_string(); + let stored = StoredCredentials { + client_id, + token_response: Some(token_result.clone()), + }; + self.credential_store.save(stored).await?; Ok(token_result) } /// get access token, if expired, refresh it automatically pub async fn get_access_token(&self) -> Result { - let credentials = self.credentials.read().await; + // Load credentials from store + let stored = self.credential_store.load().await?; + let credentials = stored.and_then(|s| s.token_response); if let Some(creds) = credentials.as_ref() { - // check if the token is expire let expires_in = creds.expires_in().unwrap_or(Duration::from_secs(0)); if expires_in <= Duration::from_secs(0) { tracing::info!("Access token expired, refreshing."); - // token expired, try to refresh , release the lock - drop(credentials); let new_creds = self.refresh_token().await?; tracing::info!("Refreshed access token."); @@ -517,26 +609,28 @@ impl AuthorizationManager { .as_ref() .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; - let current_credentials = self - .credentials - .read() - .await - .clone() + let stored = self.credential_store.load().await?; + let current_credentials = stored + .and_then(|s| s.token_response) .ok_or_else(|| AuthError::AuthorizationRequired)?; let refresh_token = current_credentials.refresh_token().ok_or_else(|| { AuthError::TokenRefreshFailed("No refresh token available".to_string()) })?; debug!("refresh token: {:?}", refresh_token); - // refresh token + let token_result = oauth_client .exchange_refresh_token(&RefreshToken::new(refresh_token.secret().to_string())) .request_async(&self.http_client) .await .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; - // store new credentials - *self.credentials.write().await = Some(token_result.clone()); + let client_id = oauth_client.client_id().to_string(); + let stored = StoredCredentials { + client_id, + token_response: Some(token_result.clone()), + }; + self.credential_store.save(stored).await?; Ok(token_result) } @@ -1003,14 +1097,15 @@ impl OAuthState { AuthorizationManager::new(DEFAULT_EXCHANGE_URL).await?, ); - // write credentials - *manager.credentials.write().await = Some(credentials); + let stored = StoredCredentials { + client_id: client_id.to_string(), + token_response: Some(credentials), + }; + manager.credential_store.save(stored).await?; - // discover metadata let metadata = manager.discover_metadata().await?; manager.metadata = Some(metadata); - // set client id and secret manager.configure_client_id(client_id)?; *self = OAuthState::Authorized(manager);