From ca170161d9cf13406e408e9fe5b6195cc18c7d2a Mon Sep 17 00:00:00 2001 From: Sacha <55644767+SoldierSacha@users.noreply.github.com> Date: Tue, 11 Nov 2025 21:29:52 -0500 Subject: [PATCH] Add tests for OAuth client coverage --- tests/unit/client/test_oauth2_providers.py | 326 +++++++++++++++++++++ 1 file changed, 326 insertions(+) diff --git a/tests/unit/client/test_oauth2_providers.py b/tests/unit/client/test_oauth2_providers.py index 253c4f0bf2..c90cb919af 100644 --- a/tests/unit/client/test_oauth2_providers.py +++ b/tests/unit/client/test_oauth2_providers.py @@ -10,6 +10,7 @@ from mcp.client.auth.oauth2 import ( ClientCredentialsProvider, + OAuthClientProvider, OAuthFlowError, TokenExchangeProvider, ) @@ -17,6 +18,7 @@ OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, + ProtectedResourceMetadata, OAuthToken, ) @@ -343,6 +345,95 @@ async def fake_request_token() -> None: assert not request_called +def test_client_credentials_has_valid_token_false_when_expired() -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris()) + provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage) + provider._current_tokens = OAuthToken(access_token="token") + provider._token_expiry_time = time.time() - 1 + + assert not provider._has_valid_token() + + +@pytest.mark.anyio +async def test_client_credentials_validate_token_scopes_returns_when_missing() -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris()) + provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage) + + token = OAuthToken(access_token="token", scope=None) + + await provider._validate_token_scopes(token) + + +@pytest.mark.anyio +async def test_client_credentials_get_or_register_client_triggers_registration( + monkeypatch: pytest.MonkeyPatch, +) -> None: + storage = InMemoryStorage() + metadata = OAuthClientMetadata(redirect_uris=_redirect_uris()) + provider = ClientCredentialsProvider("https://api.example.com/service", metadata, storage) + provider._metadata = OAuthMetadata.model_validate(_metadata_json()) + + registration_response = _make_response(200, json_data=_registration_json()) + clients = [DummyAsyncClient(send_responses=[registration_response])] + monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) + + client_info = await provider._get_or_register_client() + + assert client_info.client_id == "client-id" + assert storage.client_info is not None + + +@pytest.mark.anyio +async def test_client_credentials_request_token_handles_metadata_errors( + monkeypatch: pytest.MonkeyPatch, +) -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha") + provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage) + + invalid_metadata = _make_response(200, json_data={}) + break_response = _make_response(500) + registration_response = _make_response(200, json_data=_registration_json()) + token_response = _make_response(200, json_data={"access_token": "token", "token_type": "Bearer", "scope": "alpha"}) + + clients = [ + DummyAsyncClient(send_responses=[invalid_metadata, break_response]), + DummyAsyncClient(send_responses=[registration_response]), + DummyAsyncClient(post_responses=[token_response]), + ] + monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) + + await provider._request_token() + + assert storage.tokens is not None + assert storage.tokens.expires_in is None + assert provider._token_expiry_time is None + + +@pytest.mark.anyio +async def test_client_credentials_request_token_raises_on_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha") + provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage) + provider._metadata = OAuthMetadata.model_validate(_metadata_json()) + + registration_response = _make_response(200, json_data=_registration_json()) + error_response = _make_response(500) + + clients = [ + DummyAsyncClient(send_responses=[registration_response]), + DummyAsyncClient(post_responses=[error_response]), + ] + monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) + + with pytest.raises(Exception, match="Token request failed"): + await provider._request_token() + + @pytest.mark.anyio async def test_client_credentials_validate_token_scopes_rejects_extra() -> None: storage = InMemoryStorage() @@ -394,6 +485,103 @@ async def fake_ensure_token() -> None: assert provider._current_tokens is None +@pytest.mark.anyio +async def test_oauth_client_provider_full_flow(monkeypatch: pytest.MonkeyPatch) -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha") + async def fake_redirect(url: str) -> None: + return None + + async def fake_callback() -> tuple[str, str | None]: + return "code", None + + provider = OAuthClientProvider( + "https://api.example.com/service", + client_metadata, + storage, + redirect_handler=fake_redirect, + callback_handler=fake_callback, + ) + + provider._initialized = True + provider.context.current_tokens = None + + prm_metadata = ProtectedResourceMetadata( + resource="https://api.example.com/resource", + authorization_servers=["https://auth.example.com"], + scopes_supported=["alpha", "beta"], + ) + + monkeypatch.setattr( + provider, + "_build_protected_resource_discovery_urls", + lambda response: ["https://api.example.com/.well-known/prm"], + ) + + async def fake_handle_prm(response: httpx.Response) -> bool: + provider.context.protected_resource_metadata = prm_metadata + provider.context.auth_server_url = "https://auth.example.com" + return True + + monkeypatch.setattr(provider, "_handle_protected_resource_response", fake_handle_prm) + + def fake_get_discovery_urls(server_url: str | None = None) -> list[str]: + return ["https://auth.example.com/.well-known/oauth"] + + monkeypatch.setattr(provider, "_get_discovery_urls", fake_get_discovery_urls) + + async def fake_handle_metadata(response: httpx.Response) -> None: + provider._metadata = OAuthMetadata( + issuer="https://auth.example.com", + authorization_endpoint="https://auth.example.com/authorize", + token_endpoint="https://auth.example.com/token", + registration_endpoint="https://auth.example.com/register", + ) + + monkeypatch.setattr(provider, "_handle_oauth_metadata_response", fake_handle_metadata) + + def fake_create_registration_request(metadata: OAuthMetadata | None) -> httpx.Request: + return httpx.Request("POST", "https://auth.example.com/register") + + monkeypatch.setattr(provider, "_create_registration_request", fake_create_registration_request) + + async def fake_handle_registration(response: httpx.Response) -> None: + client_info = OAuthClientInformationFull(client_id="client", client_secret="secret") + provider._client_info = client_info + provider.context.client_info = client_info + + monkeypatch.setattr(provider, "_handle_registration_response", fake_handle_registration) + + async def fake_perform_authorization() -> httpx.Request: + return httpx.Request("POST", "https://auth.example.com/token") + + monkeypatch.setattr(provider, "_perform_authorization", fake_perform_authorization) + + async def fake_handle_token_response(response: httpx.Response) -> None: + token = OAuthToken(access_token="new-token") + provider.context.current_tokens = token + await storage.set_tokens(token) + + monkeypatch.setattr(provider, "_handle_token_response", fake_handle_token_response) + + request = httpx.Request("GET", "https://api.example.com/resource") + flow = provider.async_auth_flow(request) + + initial_request = await anext(flow) + assert "Authorization" not in initial_request.headers + + prm_request = await flow.asend(httpx.Response(401, request=initial_request)) + metadata_request = await flow.asend(httpx.Response(200, request=prm_request)) + registration_request = await flow.asend(httpx.Response(200, request=metadata_request)) + token_request = await flow.asend(httpx.Response(200, request=registration_request)) + retry_request = await flow.asend(httpx.Response(200, request=token_request)) + + assert retry_request.headers["Authorization"] == "Bearer new-token" + + with pytest.raises(StopAsyncIteration): + await flow.asend(httpx.Response(200, request=retry_request)) + + @pytest.mark.anyio async def test_token_exchange_request_token(monkeypatch: pytest.MonkeyPatch) -> None: storage = InMemoryStorage() @@ -540,6 +728,144 @@ async def fake_ensure_token() -> None: assert provider._current_tokens is None +def test_token_exchange_has_valid_token_false_when_expired() -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris()) + + async def provide_subject() -> str: + return "subject-token" + + provider = TokenExchangeProvider( + "https://api.example.com/service", + client_metadata, + storage, + subject_token_supplier=provide_subject, + ) + + provider._current_tokens = OAuthToken(access_token="token") + provider._token_expiry_time = time.time() - 1 + + assert not provider._has_valid_token() + + +@pytest.mark.anyio +async def test_token_exchange_validate_token_scopes_returns_when_missing() -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris()) + + async def provide_subject() -> str: + return "subject-token" + + provider = TokenExchangeProvider( + "https://api.example.com/service", + client_metadata, + storage, + subject_token_supplier=provide_subject, + ) + + token = OAuthToken(access_token="token", scope=None) + + await provider._validate_token_scopes(token) + + +@pytest.mark.anyio +async def test_token_exchange_get_or_register_client_triggers_registration( + monkeypatch: pytest.MonkeyPatch, +) -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris()) + + async def provide_subject() -> str: + return "subject-token" + + provider = TokenExchangeProvider( + "https://api.example.com/service", + client_metadata, + storage, + subject_token_supplier=provide_subject, + ) + provider._metadata = OAuthMetadata.model_validate(_metadata_json()) + + registration_response = _make_response(200, json_data=_registration_json()) + clients = [DummyAsyncClient(send_responses=[registration_response])] + monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) + + client_info = await provider._get_or_register_client() + + assert client_info.client_id == "client-id" + assert storage.client_info is not None + + +@pytest.mark.anyio +async def test_token_exchange_request_token_handles_metadata_errors( + monkeypatch: pytest.MonkeyPatch, +) -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha") + + async def provide_subject() -> str: + return "subject-token" + + provider = TokenExchangeProvider( + "https://api.example.com/service", + client_metadata, + storage, + subject_token_supplier=provide_subject, + ) + + invalid_metadata = _make_response(200, json_data={}) + break_response = _make_response(500) + registration_response = _make_response(200, json_data=_registration_json()) + token_response = _make_response( + 200, + json_data={"access_token": "token", "token_type": "Bearer", "scope": "alpha"}, + ) + + clients = [ + DummyAsyncClient(send_responses=[invalid_metadata, break_response]), + DummyAsyncClient(send_responses=[registration_response]), + DummyAsyncClient(post_responses=[token_response]), + ] + monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) + + await provider._request_token() + + assert storage.tokens is not None + assert storage.tokens.expires_in is None + assert provider._token_expiry_time is None + + +@pytest.mark.anyio +async def test_token_exchange_request_token_raises_on_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + storage = InMemoryStorage() + client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha") + + async def provide_subject() -> str: + return "subject-token" + + provider = TokenExchangeProvider( + "https://api.example.com/service", + client_metadata, + storage, + subject_token_supplier=provide_subject, + ) + provider._metadata = OAuthMetadata.model_validate(_metadata_json()) + + registration_response = _make_response(200, json_data=_registration_json()) + error_response = _make_response(500) + + clients = [ + DummyAsyncClient(send_responses=[registration_response]), + DummyAsyncClient(post_responses=[error_response]), + ] + monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients)) + + with pytest.raises(Exception, match="Token request failed"): + await provider._request_token() + + @pytest.mark.anyio async def test_token_exchange_ensure_token_returns_when_valid() -> None: storage = InMemoryStorage()