Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 326 additions & 0 deletions tests/unit/client/test_oauth2_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@

from mcp.client.auth.oauth2 import (
ClientCredentialsProvider,
OAuthClientProvider,
OAuthFlowError,
TokenExchangeProvider,
)
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
ProtectedResourceMetadata,
OAuthToken,
)

Expand Down Expand Up @@ -343,6 +345,95 @@ async def fake_request_token() -> None:
assert not request_called


def test_client_credentials_has_valid_token_false_when_expired() -> None:
storage = InMemoryStorage()
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)
provider._current_tokens = OAuthToken(access_token="token")
provider._token_expiry_time = time.time() - 1

assert not provider._has_valid_token()


@pytest.mark.anyio
async def test_client_credentials_validate_token_scopes_returns_when_missing() -> None:
storage = InMemoryStorage()
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)

token = OAuthToken(access_token="token", scope=None)

await provider._validate_token_scopes(token)


@pytest.mark.anyio
async def test_client_credentials_get_or_register_client_triggers_registration(
monkeypatch: pytest.MonkeyPatch,
) -> None:
storage = InMemoryStorage()
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
provider = ClientCredentialsProvider("https://api.example.com/service", metadata, storage)
provider._metadata = OAuthMetadata.model_validate(_metadata_json())

registration_response = _make_response(200, json_data=_registration_json())
clients = [DummyAsyncClient(send_responses=[registration_response])]
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))

client_info = await provider._get_or_register_client()

assert client_info.client_id == "client-id"
assert storage.client_info is not None


@pytest.mark.anyio
async def test_client_credentials_request_token_handles_metadata_errors(
monkeypatch: pytest.MonkeyPatch,
) -> None:
storage = InMemoryStorage()
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)

invalid_metadata = _make_response(200, json_data={})
break_response = _make_response(500)
registration_response = _make_response(200, json_data=_registration_json())
token_response = _make_response(200, json_data={"access_token": "token", "token_type": "Bearer", "scope": "alpha"})

clients = [
DummyAsyncClient(send_responses=[invalid_metadata, break_response]),
DummyAsyncClient(send_responses=[registration_response]),
DummyAsyncClient(post_responses=[token_response]),
]
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))

await provider._request_token()

assert storage.tokens is not None
assert storage.tokens.expires_in is None
assert provider._token_expiry_time is None


@pytest.mark.anyio
async def test_client_credentials_request_token_raises_on_failure(
monkeypatch: pytest.MonkeyPatch,
) -> None:
storage = InMemoryStorage()
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)
provider._metadata = OAuthMetadata.model_validate(_metadata_json())

registration_response = _make_response(200, json_data=_registration_json())
error_response = _make_response(500)

clients = [
DummyAsyncClient(send_responses=[registration_response]),
DummyAsyncClient(post_responses=[error_response]),
]
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))

with pytest.raises(Exception, match="Token request failed"):
await provider._request_token()


@pytest.mark.anyio
async def test_client_credentials_validate_token_scopes_rejects_extra() -> None:
storage = InMemoryStorage()
Expand Down Expand Up @@ -394,6 +485,103 @@ async def fake_ensure_token() -> None:
assert provider._current_tokens is None


@pytest.mark.anyio
async def test_oauth_client_provider_full_flow(monkeypatch: pytest.MonkeyPatch) -> None:
storage = InMemoryStorage()
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
async def fake_redirect(url: str) -> None:
return None

async def fake_callback() -> tuple[str, str | None]:
return "code", None

provider = OAuthClientProvider(
"https://api.example.com/service",
client_metadata,
storage,
redirect_handler=fake_redirect,
callback_handler=fake_callback,
)

provider._initialized = True
provider.context.current_tokens = None

prm_metadata = ProtectedResourceMetadata(
resource="https://api.example.com/resource",
authorization_servers=["https://auth.example.com"],
scopes_supported=["alpha", "beta"],
)

monkeypatch.setattr(
provider,
"_build_protected_resource_discovery_urls",
lambda response: ["https://api.example.com/.well-known/prm"],
)

async def fake_handle_prm(response: httpx.Response) -> bool:
provider.context.protected_resource_metadata = prm_metadata
provider.context.auth_server_url = "https://auth.example.com"
return True

monkeypatch.setattr(provider, "_handle_protected_resource_response", fake_handle_prm)

def fake_get_discovery_urls(server_url: str | None = None) -> list[str]:
return ["https://auth.example.com/.well-known/oauth"]

monkeypatch.setattr(provider, "_get_discovery_urls", fake_get_discovery_urls)

async def fake_handle_metadata(response: httpx.Response) -> None:
provider._metadata = OAuthMetadata(
issuer="https://auth.example.com",
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
registration_endpoint="https://auth.example.com/register",
)

monkeypatch.setattr(provider, "_handle_oauth_metadata_response", fake_handle_metadata)

def fake_create_registration_request(metadata: OAuthMetadata | None) -> httpx.Request:
return httpx.Request("POST", "https://auth.example.com/register")

monkeypatch.setattr(provider, "_create_registration_request", fake_create_registration_request)

async def fake_handle_registration(response: httpx.Response) -> None:
client_info = OAuthClientInformationFull(client_id="client", client_secret="secret")
provider._client_info = client_info
provider.context.client_info = client_info

monkeypatch.setattr(provider, "_handle_registration_response", fake_handle_registration)

async def fake_perform_authorization() -> httpx.Request:
return httpx.Request("POST", "https://auth.example.com/token")

monkeypatch.setattr(provider, "_perform_authorization", fake_perform_authorization)

async def fake_handle_token_response(response: httpx.Response) -> None:
token = OAuthToken(access_token="new-token")
provider.context.current_tokens = token
await storage.set_tokens(token)

monkeypatch.setattr(provider, "_handle_token_response", fake_handle_token_response)

request = httpx.Request("GET", "https://api.example.com/resource")
flow = provider.async_auth_flow(request)

initial_request = await anext(flow)
assert "Authorization" not in initial_request.headers

prm_request = await flow.asend(httpx.Response(401, request=initial_request))
metadata_request = await flow.asend(httpx.Response(200, request=prm_request))
registration_request = await flow.asend(httpx.Response(200, request=metadata_request))
token_request = await flow.asend(httpx.Response(200, request=registration_request))
retry_request = await flow.asend(httpx.Response(200, request=token_request))

assert retry_request.headers["Authorization"] == "Bearer new-token"

with pytest.raises(StopAsyncIteration):
await flow.asend(httpx.Response(200, request=retry_request))


@pytest.mark.anyio
async def test_token_exchange_request_token(monkeypatch: pytest.MonkeyPatch) -> None:
storage = InMemoryStorage()
Expand Down Expand Up @@ -540,6 +728,144 @@ async def fake_ensure_token() -> None:
assert provider._current_tokens is None


def test_token_exchange_has_valid_token_false_when_expired() -> None:
storage = InMemoryStorage()
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())

async def provide_subject() -> str:
return "subject-token"

provider = TokenExchangeProvider(
"https://api.example.com/service",
client_metadata,
storage,
subject_token_supplier=provide_subject,
)

provider._current_tokens = OAuthToken(access_token="token")
provider._token_expiry_time = time.time() - 1

assert not provider._has_valid_token()


@pytest.mark.anyio
async def test_token_exchange_validate_token_scopes_returns_when_missing() -> None:
storage = InMemoryStorage()
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())

async def provide_subject() -> str:
return "subject-token"

provider = TokenExchangeProvider(
"https://api.example.com/service",
client_metadata,
storage,
subject_token_supplier=provide_subject,
)

token = OAuthToken(access_token="token", scope=None)

await provider._validate_token_scopes(token)


@pytest.mark.anyio
async def test_token_exchange_get_or_register_client_triggers_registration(
monkeypatch: pytest.MonkeyPatch,
) -> None:
storage = InMemoryStorage()
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())

async def provide_subject() -> str:
return "subject-token"

provider = TokenExchangeProvider(
"https://api.example.com/service",
client_metadata,
storage,
subject_token_supplier=provide_subject,
)
provider._metadata = OAuthMetadata.model_validate(_metadata_json())

registration_response = _make_response(200, json_data=_registration_json())
clients = [DummyAsyncClient(send_responses=[registration_response])]
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))

client_info = await provider._get_or_register_client()

assert client_info.client_id == "client-id"
assert storage.client_info is not None


@pytest.mark.anyio
async def test_token_exchange_request_token_handles_metadata_errors(
monkeypatch: pytest.MonkeyPatch,
) -> None:
storage = InMemoryStorage()
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")

async def provide_subject() -> str:
return "subject-token"

provider = TokenExchangeProvider(
"https://api.example.com/service",
client_metadata,
storage,
subject_token_supplier=provide_subject,
)

invalid_metadata = _make_response(200, json_data={})
break_response = _make_response(500)
registration_response = _make_response(200, json_data=_registration_json())
token_response = _make_response(
200,
json_data={"access_token": "token", "token_type": "Bearer", "scope": "alpha"},
)

clients = [
DummyAsyncClient(send_responses=[invalid_metadata, break_response]),
DummyAsyncClient(send_responses=[registration_response]),
DummyAsyncClient(post_responses=[token_response]),
]
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))

await provider._request_token()

assert storage.tokens is not None
assert storage.tokens.expires_in is None
assert provider._token_expiry_time is None


@pytest.mark.anyio
async def test_token_exchange_request_token_raises_on_failure(
monkeypatch: pytest.MonkeyPatch,
) -> None:
storage = InMemoryStorage()
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")

async def provide_subject() -> str:
return "subject-token"

provider = TokenExchangeProvider(
"https://api.example.com/service",
client_metadata,
storage,
subject_token_supplier=provide_subject,
)
provider._metadata = OAuthMetadata.model_validate(_metadata_json())

registration_response = _make_response(200, json_data=_registration_json())
error_response = _make_response(500)

clients = [
DummyAsyncClient(send_responses=[registration_response]),
DummyAsyncClient(post_responses=[error_response]),
]
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))

with pytest.raises(Exception, match="Token request failed"):
await provider._request_token()


@pytest.mark.anyio
async def test_token_exchange_ensure_token_returns_when_valid() -> None:
storage = InMemoryStorage()
Expand Down