diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 9e9389ac14..8f071021d3 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -134,6 +134,7 @@ def __init__( self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} + self._server_capabilities: types.ServerCapabilities | None = None async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None @@ -170,10 +171,19 @@ async def initialize(self) -> types.InitializeResult: if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}") + self._server_capabilities = result.capabilities + await self.send_notification(types.ClientNotification(types.InitializedNotification())) return result + def get_server_capabilities(self) -> types.ServerCapabilities | None: + """Return the server capabilities received during initialization. + + Returns None if the session has not been initialized yet. + """ + return self._server_capabilities + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f2135e4552..11413d2265 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -504,6 +504,78 @@ async def mock_server(): assert received_capabilities.roots.listChanged is True +@pytest.mark.anyio +async def test_get_server_capabilities(): + """Test that get_server_capabilities returns None before init and capabilities after""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + expected_capabilities = ServerCapabilities( + logging=types.LoggingCapability(), + prompts=types.PromptsCapability(listChanged=True), + resources=types.ResourcesCapability(subscribe=True, listChanged=True), + tools=types.ToolsCapability(listChanged=False), + ) + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=expected_capabilities, + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + assert session.get_server_capabilities() is None + + tg.start_soon(mock_server) + await session.initialize() + + capabilities = session.get_server_capabilities() + assert capabilities is not None + assert capabilities == expected_capabilities + assert capabilities.logging is not None + assert capabilities.prompts is not None + assert capabilities.prompts.listChanged is True + assert capabilities.resources is not None + assert capabilities.resources.subscribe is True + assert capabilities.tools is not None + assert capabilities.tools.listChanged is False + + @pytest.mark.anyio @pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}]) async def test_client_tool_call_with_meta(meta: dict[str, Any] | None):