@@ -499,3 +499,207 @@ async def _refresh_access_token(self) -> bool:
499499 except Exception :
500500 logger .exception ("Token refresh failed" )
501501 return False
502+
503+
504+ class ClientCredentialsProvider (httpx .Auth ):
505+ """HTTPX auth using the OAuth2 client credentials grant."""
506+
507+ def __init__ (
508+ self ,
509+ server_url : str ,
510+ client_metadata : OAuthClientMetadata ,
511+ storage : TokenStorage ,
512+ timeout : float = 300.0 ,
513+ ):
514+ self .server_url = server_url
515+ self .client_metadata = client_metadata
516+ self .storage = storage
517+ self .timeout = timeout
518+
519+ self ._current_tokens : OAuthToken | None = None
520+ self ._metadata : OAuthMetadata | None = None
521+ self ._client_info : OAuthClientInformationFull | None = None
522+ self ._token_expiry_time : float | None = None
523+
524+ self ._token_lock = anyio .Lock ()
525+
526+ def _get_authorization_base_url (self , server_url : str ) -> str :
527+ from urllib .parse import urlparse , urlunparse
528+
529+ parsed = urlparse (server_url )
530+ return urlunparse ((parsed .scheme , parsed .netloc , "" , "" , "" , "" ))
531+
532+ async def _discover_oauth_metadata (self , server_url : str ) -> OAuthMetadata | None :
533+ auth_base_url = self ._get_authorization_base_url (server_url )
534+ url = urljoin (auth_base_url , "/.well-known/oauth-authorization-server" )
535+ headers = {"MCP-Protocol-Version" : LATEST_PROTOCOL_VERSION }
536+
537+ async with httpx .AsyncClient () as client :
538+ try :
539+ response = await client .get (url , headers = headers )
540+ if response .status_code == 404 :
541+ return None
542+ response .raise_for_status ()
543+ return OAuthMetadata .model_validate (response .json ())
544+ except Exception :
545+ try :
546+ response = await client .get (url )
547+ if response .status_code == 404 :
548+ return None
549+ response .raise_for_status ()
550+ return OAuthMetadata .model_validate (response .json ())
551+ except Exception :
552+ logger .exception ("Failed to discover OAuth metadata" )
553+ return None
554+
555+ async def _register_oauth_client (
556+ self ,
557+ server_url : str ,
558+ client_metadata : OAuthClientMetadata ,
559+ metadata : OAuthMetadata | None = None ,
560+ ) -> OAuthClientInformationFull :
561+ if not metadata :
562+ metadata = await self ._discover_oauth_metadata (server_url )
563+
564+ if metadata and metadata .registration_endpoint :
565+ registration_url = str (metadata .registration_endpoint )
566+ else :
567+ auth_base_url = self ._get_authorization_base_url (server_url )
568+ registration_url = urljoin (auth_base_url , "/register" )
569+
570+ if (
571+ client_metadata .scope is None
572+ and metadata
573+ and metadata .scopes_supported is not None
574+ ):
575+ client_metadata .scope = " " .join (metadata .scopes_supported )
576+
577+ registration_data = client_metadata .model_dump (
578+ by_alias = True , mode = "json" , exclude_none = True
579+ )
580+
581+ async with httpx .AsyncClient () as client :
582+ response = await client .post (
583+ registration_url ,
584+ json = registration_data ,
585+ headers = {"Content-Type" : "application/json" },
586+ )
587+
588+ if response .status_code not in (200 , 201 ):
589+ raise httpx .HTTPStatusError (
590+ f"Registration failed: { response .status_code } " ,
591+ request = response .request ,
592+ response = response ,
593+ )
594+
595+ return OAuthClientInformationFull .model_validate (response .json ())
596+
597+ def _has_valid_token (self ) -> bool :
598+ if not self ._current_tokens or not self ._current_tokens .access_token :
599+ return False
600+
601+ if self ._token_expiry_time and time .time () > self ._token_expiry_time :
602+ return False
603+ return True
604+
605+ async def _validate_token_scopes (self , token_response : OAuthToken ) -> None :
606+ if not token_response .scope :
607+ return
608+
609+ requested_scopes : set [str ] = set ()
610+ if self .client_metadata .scope :
611+ requested_scopes = set (self .client_metadata .scope .split ())
612+ returned_scopes = set (token_response .scope .split ())
613+ unauthorized_scopes = returned_scopes - requested_scopes
614+ if unauthorized_scopes :
615+ raise Exception (
616+ f"Server granted unauthorized scopes: { unauthorized_scopes } ."
617+ )
618+ else :
619+ granted = set (token_response .scope .split ())
620+ logger .debug (
621+ "No explicit scopes requested, accepting server-granted scopes: %s" ,
622+ granted ,
623+ )
624+
625+ async def initialize (self ) -> None :
626+ self ._current_tokens = await self .storage .get_tokens ()
627+ self ._client_info = await self .storage .get_client_info ()
628+
629+ async def _get_or_register_client (self ) -> OAuthClientInformationFull :
630+ if not self ._client_info :
631+ self ._client_info = await self ._register_oauth_client (
632+ self .server_url , self .client_metadata , self ._metadata
633+ )
634+ await self .storage .set_client_info (self ._client_info )
635+ return self ._client_info
636+
637+ async def _request_token (self ) -> None :
638+ if not self ._metadata :
639+ self ._metadata = await self ._discover_oauth_metadata (self .server_url )
640+
641+ client_info = await self ._get_or_register_client ()
642+
643+ if self ._metadata and self ._metadata .token_endpoint :
644+ token_url = str (self ._metadata .token_endpoint )
645+ else :
646+ auth_base_url = self ._get_authorization_base_url (self .server_url )
647+ token_url = urljoin (auth_base_url , "/token" )
648+
649+ token_data = {
650+ "grant_type" : "client_credentials" ,
651+ "client_id" : client_info .client_id ,
652+ }
653+
654+ if client_info .client_secret :
655+ token_data ["client_secret" ] = client_info .client_secret
656+
657+ if self .client_metadata .scope :
658+ token_data ["scope" ] = self .client_metadata .scope
659+
660+ async with httpx .AsyncClient () as client :
661+ response = await client .post (
662+ token_url ,
663+ data = token_data ,
664+ headers = {"Content-Type" : "application/x-www-form-urlencoded" },
665+ timeout = 30.0 ,
666+ )
667+
668+ if response .status_code != 200 :
669+ raise Exception (
670+ f"Token request failed: { response .status_code } { response .text } "
671+ )
672+
673+ token_response = OAuthToken .model_validate (response .json ())
674+ await self ._validate_token_scopes (token_response )
675+
676+ if token_response .expires_in :
677+ self ._token_expiry_time = time .time () + token_response .expires_in
678+ else :
679+ self ._token_expiry_time = None
680+
681+ await self .storage .set_tokens (token_response )
682+ self ._current_tokens = token_response
683+
684+ async def ensure_token (self ) -> None :
685+ async with self ._token_lock :
686+ if self ._has_valid_token ():
687+ return
688+ await self ._request_token ()
689+
690+ async def async_auth_flow (
691+ self , request : httpx .Request
692+ ) -> AsyncGenerator [httpx .Request , httpx .Response ]:
693+ if not self ._has_valid_token ():
694+ await self .initialize ()
695+ await self .ensure_token ()
696+
697+ if self ._current_tokens and self ._current_tokens .access_token :
698+ request .headers ["Authorization" ] = (
699+ f"Bearer { self ._current_tokens .access_token } "
700+ )
701+
702+ response = yield request
703+
704+ if response .status_code == 401 :
705+ self ._current_tokens = None
0 commit comments