From 621081e8d4306c2a9042ecabb84c4ec16f04e5ea Mon Sep 17 00:00:00 2001 From: Felix Weinberger Date: Fri, 16 Jan 2026 16:29:21 +0100 Subject: [PATCH] feat: add ergonomic Client class for testing MCP servers Add a high-level Client class that wraps ClientSession with transport management, making it easier to test MCP servers. Key changes: - Add Client class with async context manager support - Add InMemoryTransport for in-memory server connections - Migrate test files to use new Client(server) API - Update testing documentation with Client usage examples Usage: async with Client(server) as client: result = await client.call_tool('my_tool', {'arg': 'value'}) --- docs/testing.md | 25 +- examples/fastmcp/weather_structured.py | 4 +- src/mcp/__init__.py | 2 + src/mcp/client/__init__.py | 9 + src/mcp/client/_memory.py | 97 +++++ src/mcp/client/client.py | 302 ++++++++++++++++ src/mcp/client/session.py | 1 - src/mcp/shared/memory.py | 57 --- tests/client/conftest.py | 16 +- tests/client/test_client.py | 336 ++++++++++++++++++ tests/client/test_list_methods_cursor.py | 179 +++++----- tests/client/test_list_roots_callback.py | 12 +- tests/client/test_logging_callback.py | 16 +- tests/client/test_output_schema_validation.py | 14 +- tests/client/test_sampling_callback.py | 16 +- tests/client/transports/__init__.py | 0 tests/client/transports/test_memory.py | 114 ++++++ tests/issues/test_141_resource_templates.py | 9 +- tests/issues/test_152_resource_mime_type.py | 9 +- .../test_1574_resource_uri_validation.py | 9 +- .../issues/test_1754_mime_type_parameters.py | 6 +- tests/issues/test_188_concurrency.py | 6 +- tests/server/fastmcp/test_elicitation.py | 19 +- tests/server/fastmcp/test_server.py | 152 +++++--- tests/server/fastmcp/test_title.py | 19 +- tests/server/fastmcp/test_url_elicitation.py | 90 ++--- .../test_url_elicitation_error_throw.py | 21 +- tests/server/test_cancel_handling.py | 105 +++--- tests/server/test_completion_with_context.py | 10 +- tests/shared/test_memory.py | 21 +- tests/shared/test_progress_notifications.py | 50 +-- tests/shared/test_session.py | 76 ++-- tests/test_examples.py | 22 +- 33 files changed, 1298 insertions(+), 526 deletions(-) create mode 100644 src/mcp/client/_memory.py create mode 100644 src/mcp/client/client.py create mode 100644 tests/client/test_client.py create mode 100644 tests/client/transports/__init__.py create mode 100644 tests/client/transports/test_memory.py diff --git a/docs/testing.md b/docs/testing.md index 8d84449893..f869873608 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -1,10 +1,11 @@ # Testing MCP Servers -If you call yourself a developer, you will want to test your MCP server. -The Python SDK offers the `create_connected_server_and_client_session` function to create a session -using an in-memory transport. I know, I know, the name is too long... We are working on improving it. +The Python SDK provides a `Client` class for testing MCP servers with an in-memory transport. +This makes it easy to write tests without network overhead. -Anyway, let's assume you have a simple server with a single tool: +## Basic Usage + +Let's assume you have a simple server with a single tool: ```python title="server.py" from mcp.server import FastMCP @@ -40,12 +41,9 @@ To run the below test, you'll need to install the following dependencies: server - you don't need to use it, but we are spreading the word for best practices. ```python title="test_server.py" -from collections.abc import AsyncGenerator - import pytest from inline_snapshot import snapshot -from mcp.client.session import ClientSession -from mcp.shared.memory import create_connected_server_and_client_session +from mcp import Client from mcp.types import CallToolResult, TextContent from server import app @@ -57,14 +55,14 @@ def anyio_backend(): # (1)! @pytest.fixture -async def client_session() -> AsyncGenerator[ClientSession]: - async with create_connected_server_and_client_session(app, raise_exceptions=True) as _session: - yield _session +async def client(): # (2)! + async with Client(app, raise_exceptions=True) as c: + yield c @pytest.mark.anyio -async def test_call_add_tool(client_session: ClientSession): - result = await client_session.call_tool("add", {"a": 1, "b": 2}) +async def test_call_add_tool(client: Client): + result = await client.call_tool("add", {"a": 1, "b": 2}) assert result == snapshot( CallToolResult( content=[TextContent(type="text", text="3")], @@ -74,5 +72,6 @@ async def test_call_add_tool(client_session: ClientSession): ``` 1. If you are using `trio`, you should set `"trio"` as the `anyio_backend`. Check more information in the [anyio documentation](https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on). +2. The `client` fixture creates a connected client that can be reused across multiple tests. There you go! You can now extend your tests to cover more scenarios. diff --git a/examples/fastmcp/weather_structured.py b/examples/fastmcp/weather_structured.py index 87ad8993fc..60c24a8f53 100644 --- a/examples/fastmcp/weather_structured.py +++ b/examples/fastmcp/weather_structured.py @@ -14,8 +14,8 @@ from pydantic import BaseModel, Field +from mcp.client import Client from mcp.server.fastmcp import FastMCP -from mcp.shared.memory import create_connected_server_and_client_session as client_session # Create server mcp = FastMCP("Weather Service") @@ -157,7 +157,7 @@ async def test() -> None: print("Testing Weather Service Tools (via MCP protocol)\n") print("=" * 80) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: # Test get_weather result = await client.call_tool("get_weather", {"city": "London"}) print("\nWeather in London:") diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 65a2bd50e7..9823523148 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -1,3 +1,4 @@ +from .client.client import Client from .client.session import ClientSession from .client.session_group import ClientSessionGroup from .client.stdio import StdioServerParameters, stdio_client @@ -66,6 +67,7 @@ __all__ = [ "CallToolRequest", + "Client", "ClientCapabilities", "ClientNotification", "ClientRequest", diff --git a/src/mcp/client/__init__.py b/src/mcp/client/__init__.py index e69de29bb2..7b94647102 100644 --- a/src/mcp/client/__init__.py +++ b/src/mcp/client/__init__.py @@ -0,0 +1,9 @@ +"""MCP Client module.""" + +from mcp.client.client import Client +from mcp.client.session import ClientSession + +__all__ = [ + "Client", + "ClientSession", +] diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py new file mode 100644 index 0000000000..8b959e3c4d --- /dev/null +++ b/src/mcp/client/_memory.py @@ -0,0 +1,97 @@ +"""In-memory transport for testing MCP servers without network overhead.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Any + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +from mcp.server import Server +from mcp.server.fastmcp import FastMCP +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage + + +class InMemoryTransport: + """ + In-memory transport for testing MCP servers without network overhead. + + This transport starts the server in a background task and provides + streams for client-side communication. The server is automatically + stopped when the context manager exits. + + Example: + server = FastMCP("test") + transport = InMemoryTransport(server) + + async with transport.connect() as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + # Use the session... + + Or more commonly, use with Client: + async with Client(server) as client: + result = await client.call_tool("my_tool", {...}) + """ + + def __init__( + self, + server: Server[Any] | FastMCP, + *, + raise_exceptions: bool = False, + ) -> None: + """ + Initialize the in-memory transport. + + Args: + server: The MCP server to connect to (Server or FastMCP instance) + raise_exceptions: Whether to raise exceptions from the server + """ + self._server = server + self._raise_exceptions = raise_exceptions + + @asynccontextmanager + async def connect( + self, + ) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + ], + None, + ]: + """ + Connect to the server and return streams for communication. + + Yields: + A tuple of (read_stream, write_stream) for bidirectional communication + """ + # Unwrap FastMCP to get underlying Server + actual_server: Server[Any] + if isinstance(self._server, FastMCP): + actual_server = self._server._mcp_server # type: ignore[reportPrivateUsage] + else: + actual_server = self._server + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as tg: + # Start server in background + tg.start_soon( + lambda: actual_server.run( + server_read, + server_write, + actual_server.create_initialization_options(), + raise_exceptions=self._raise_exceptions, + ) + ) + + try: + yield client_read, client_write + finally: + tg.cancel_scope.cancel() diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py new file mode 100644 index 0000000000..699a24f046 --- /dev/null +++ b/src/mcp/client/client.py @@ -0,0 +1,302 @@ +"""Unified MCP Client that wraps ClientSession with transport management.""" + +from __future__ import annotations + +import logging +from contextlib import AsyncExitStack +from typing import Any + +from pydantic import AnyUrl + +import mcp.types as types +from mcp.client._memory import InMemoryTransport +from mcp.client.session import ( + ClientSession, + ElicitationFnT, + ListRootsFnT, + LoggingFnT, + MessageHandlerFnT, + SamplingFnT, +) +from mcp.server import Server +from mcp.server.fastmcp import FastMCP +from mcp.shared.session import ProgressFnT + +logger = logging.getLogger(__name__) + + +class Client: + """A high-level MCP client for connecting to MCP servers. + + Currently supports in-memory transport for testing. Pass a Server or + FastMCP instance directly to the constructor. + + Example: + ```python + from mcp.client import Client + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + @server.tool() + def add(a: int, b: int) -> int: + return a + b + + async with Client(server) as client: + result = await client.call_tool("add", {"a": 1, "b": 2}) + ``` + """ + + # TODO(felixweinberger): Expand to support all transport types (like FastMCP 2): + # - Add ClientTransport base class with connect_session() method + # - Add StreamableHttpTransport, SSETransport, StdioTransport + # - Add infer_transport() to auto-detect transport from input type + # - Accept URL strings, Path objects, config dicts in constructor + # - Add auth support (OAuth, bearer tokens) + + def __init__( + self, + server: Server[Any] | FastMCP, + *, + raise_exceptions: bool = False, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: types.Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, + ) -> None: + """ + Initialize the client with a server. + + Args: + server: The MCP server to connect to (Server or FastMCP instance) + raise_exceptions: Whether to raise exceptions from the server + read_timeout_seconds: Timeout for read operations + sampling_callback: Callback for handling sampling requests + list_roots_callback: Callback for handling list roots requests + logging_callback: Callback for handling logging notifications + message_handler: Callback for handling raw messages + client_info: Client implementation info to send to server + elicitation_callback: Callback for handling elicitation requests + """ + self._server = server + self._raise_exceptions = raise_exceptions + self._read_timeout_seconds = read_timeout_seconds + self._sampling_callback = sampling_callback + self._list_roots_callback = list_roots_callback + self._logging_callback = logging_callback + self._message_handler = message_handler + self._client_info = client_info + self._elicitation_callback = elicitation_callback + + self._session: ClientSession | None = None + self._exit_stack: AsyncExitStack | None = None + + async def __aenter__(self) -> Client: + """Enter the async context manager.""" + if self._session is not None: + raise RuntimeError("Client is already entered; cannot reenter") + + async with AsyncExitStack() as exit_stack: + # Create transport and connect + transport = InMemoryTransport(self._server, raise_exceptions=self._raise_exceptions) + read_stream, write_stream = await exit_stack.enter_async_context(transport.connect()) + + # Create session + self._session = await exit_stack.enter_async_context( + ClientSession( + read_stream=read_stream, + write_stream=write_stream, + read_timeout_seconds=self._read_timeout_seconds, + sampling_callback=self._sampling_callback, + list_roots_callback=self._list_roots_callback, + logging_callback=self._logging_callback, + message_handler=self._message_handler, + client_info=self._client_info, + elicitation_callback=self._elicitation_callback, + ) + ) + + # Initialize the session + await self._session.initialize() + + # Transfer ownership to self for __aexit__ to handle + self._exit_stack = exit_stack.pop_all() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Exit the async context manager.""" + if self._exit_stack: + await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) + self._session = None + + @property + def session(self) -> ClientSession: + """ + Get the underlying ClientSession. + + This provides access to the full ClientSession API for advanced use cases. + + Raises: + RuntimeError: If accessed before entering the context manager. + """ + if self._session is None: + raise RuntimeError("Client must be used within an async context manager") + return self._session + + @property + def server_capabilities(self) -> types.ServerCapabilities | None: + """The server capabilities received during initialization, or None if not yet initialized.""" + return self.session.get_server_capabilities() + + async def send_ping(self) -> types.EmptyResult: + """Send a ping request to the server.""" + return await self.session.send_ping() + + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + ) -> None: + """Send a progress notification to the server.""" + await self.session.send_progress_notification( + progress_token=progress_token, + progress=progress, + total=total, + message=message, + ) + + async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: + """Set the logging level on the server.""" + return await self.session.set_logging_level(level) + + async def list_resources( + self, + params: types.PaginatedRequestParams | None = None, + ) -> types.ListResourcesResult: + """List available resources from the server.""" + return await self.session.list_resources(params=params) + + async def list_resource_templates( + self, + params: types.PaginatedRequestParams | None = None, + ) -> types.ListResourceTemplatesResult: + """List available resource templates from the server.""" + return await self.session.list_resource_templates(params=params) + + async def read_resource(self, uri: str | AnyUrl) -> types.ReadResourceResult: + """ + Read a resource from the server. + + Args: + uri: The URI of the resource to read + + Returns: + The resource content + """ + return await self.session.read_resource(uri) + + async def subscribe_resource(self, uri: str | AnyUrl) -> types.EmptyResult: + """Subscribe to resource updates.""" + return await self.session.subscribe_resource(uri) + + async def unsubscribe_resource(self, uri: str | AnyUrl) -> types.EmptyResult: + """Unsubscribe from resource updates.""" + return await self.session.unsubscribe_resource(uri) + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: float | None = None, + progress_callback: ProgressFnT | None = None, + *, + meta: dict[str, Any] | None = None, + ) -> types.CallToolResult: + """ + Call a tool on the server. + + Args: + name: The name of the tool to call + arguments: Arguments to pass to the tool + read_timeout_seconds: Timeout for the tool call + progress_callback: Callback for progress updates + meta: Additional metadata for the request + + Returns: + The tool result + """ + return await self.session.call_tool( + name=name, + arguments=arguments, + read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, + meta=meta, + ) + + async def list_prompts( + self, + params: types.PaginatedRequestParams | None = None, + ) -> types.ListPromptsResult: + """List available prompts from the server.""" + return await self.session.list_prompts(params=params) + + async def get_prompt( + self, + name: str, + arguments: dict[str, str] | None = None, + ) -> types.GetPromptResult: + """ + Get a prompt from the server. + + Args: + name: The name of the prompt + arguments: Arguments to pass to the prompt + + Returns: + The prompt content + """ + return await self.session.get_prompt(name=name, arguments=arguments) + + async def complete( + self, + ref: types.ResourceTemplateReference | types.PromptReference, + argument: dict[str, str], + context_arguments: dict[str, str] | None = None, + ) -> types.CompleteResult: + """ + Get completions for a prompt or resource template argument. + + Args: + ref: Reference to the prompt or resource template + argument: The argument to complete + context_arguments: Additional context arguments + + Returns: + Completion suggestions + """ + return await self.session.complete( + ref=ref, + argument=argument, + context_arguments=context_arguments, + ) + + async def list_tools( + self, + params: types.PaginatedRequestParams | None = None, + ) -> types.ListToolsResult: + """List available tools from the server.""" + return await self.session.list_tools(params=params) + + async def send_roots_list_changed(self) -> None: + """Send a notification that the roots list has changed.""" + await self.session.send_roots_list_changed() diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 637a3d1b15..7aeee2cd8a 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -407,7 +407,6 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None """Send a tools/list request. Args: - cursor: Simple cursor string for pagination (deprecated, use params instead) params: Full pagination parameters including cursor and any future fields """ result = await self.send_request( diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index c7c6dbabc2..e35c487b92 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -6,15 +6,10 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import Any import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -import mcp.types as types -from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT -from mcp.server import Server -from mcp.server.fastmcp import FastMCP from mcp.shared.message import SessionMessage MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] @@ -43,55 +38,3 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS server_to_client_send, ): yield client_streams, server_streams - - -@asynccontextmanager -async def create_connected_server_and_client_session( - server: Server[Any] | FastMCP, - read_timeout_seconds: float | None = None, - sampling_callback: SamplingFnT | None = None, - list_roots_callback: ListRootsFnT | None = None, - logging_callback: LoggingFnT | None = None, - message_handler: MessageHandlerFnT | None = None, - client_info: types.Implementation | None = None, - raise_exceptions: bool = False, - elicitation_callback: ElicitationFnT | None = None, -) -> AsyncGenerator[ClientSession, None]: - """Creates a ClientSession that is connected to a running MCP server.""" - - # TODO(Marcelo): we should have a proper `Client` that can use this "in-memory transport", - # and we should expose a method in the `FastMCP` so we don't access a private attribute. - if isinstance(server, FastMCP): # pragma: no cover - server = server._mcp_server # type: ignore[reportPrivateUsage] - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - # Create a cancel scope for the server task - async with anyio.create_task_group() as tg: - tg.start_soon( - lambda: server.run( - server_read, - server_write, - server.create_initialization_options(), - raise_exceptions=raise_exceptions, - ) - ) - - try: - async with ClientSession( - read_stream=client_read, - write_stream=client_write, - read_timeout_seconds=read_timeout_seconds, - sampling_callback=sampling_callback, - list_roots_callback=list_roots_callback, - logging_callback=logging_callback, - message_handler=message_handler, - client_info=client_info, - elicitation_callback=elicitation_callback, - ) as client_session: - await client_session.initialize() - yield client_session - finally: # pragma: no cover - tg.cancel_scope.cancel() diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 1e5c4d524c..dfcad8215d 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -123,11 +123,13 @@ async def patched_create_streams(): yield (client_read, spy_client_write), (server_read, spy_server_write) # Apply the patch for the duration of the test + # Patch both locations since InMemoryTransport imports it directly with patch("mcp.shared.memory.create_client_server_memory_streams", patched_create_streams): - # Return a collection with helper methods - def get_spy_collection() -> StreamSpyCollection: - assert client_spy is not None, "client_spy was not initialized" - assert server_spy is not None, "server_spy was not initialized" - return StreamSpyCollection(client_spy, server_spy) - - yield get_spy_collection + with patch("mcp.client._memory.create_client_server_memory_streams", patched_create_streams): + # Return a collection with helper methods + def get_spy_collection() -> StreamSpyCollection: + assert client_spy is not None, "client_spy was not initialized" + assert server_spy is not None, "server_spy was not initialized" + return StreamSpyCollection(client_spy, server_spy) + + yield get_spy_collection diff --git a/tests/client/test_client.py b/tests/client/test_client.py new file mode 100644 index 0000000000..148debaccc --- /dev/null +++ b/tests/client/test_client.py @@ -0,0 +1,336 @@ +"""Tests for the unified Client class.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +import mcp.types as types +from mcp.client.client import Client +from mcp.server import Server +from mcp.server.fastmcp import FastMCP +from mcp.types import EmptyResult, Resource + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +def simple_server() -> Server: + """Create a simple MCP server for testing.""" + server = Server(name="test_server") + + @server.list_resources() + async def handle_list_resources(): + return [ + Resource( + uri="memory://test", + name="Test Resource", + description="A test resource", + ) + ] + + return server + + +@pytest.fixture +def app() -> FastMCP: + """Create a FastMCP server for testing.""" + server = FastMCP("test") + + @server.tool() + def greet(name: str) -> str: + """Greet someone by name.""" + return f"Hello, {name}!" + + @server.tool() + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + @server.resource("test://resource") + def test_resource() -> str: + """A test resource.""" + return "Test content" + + @server.prompt() + def greeting_prompt(name: str) -> str: + """A greeting prompt.""" + return f"Please greet {name} warmly." + + return server + + +async def test_creates_client(app: FastMCP): + """Test that from_server creates a connected client.""" + async with Client(app) as client: + assert client is not None + + +async def test_client_is_initialized(app: FastMCP): + """Test that the client is initialized after entering context.""" + async with Client(app) as client: + caps = client.server_capabilities + assert caps is not None + assert caps.tools is not None + + +async def test_with_simple_server(simple_server: Server): + """Test that from_server works with a basic Server instance.""" + async with Client(simple_server) as client: + assert client is not None + caps = client.server_capabilities + assert caps is not None + # Verify list_resources works and returns expected resource + resources = await client.list_resources() + assert len(resources.resources) == 1 + assert resources.resources[0].uri == "memory://test" + + +async def test_ping_returns_empty_result(app: FastMCP): + """Test that ping returns an EmptyResult.""" + async with Client(app) as client: + result = await client.send_ping() + assert isinstance(result, EmptyResult) + + +async def test_list_tools(app: FastMCP): + """Test listing tools.""" + async with Client(app) as client: + result = await client.list_tools() + assert result.tools is not None + tool_names = [t.name for t in result.tools] + assert "greet" in tool_names + assert "add" in tool_names + + +async def test_list_tools_with_pagination(app: FastMCP): + """Test listing tools with pagination params.""" + from mcp.types import PaginatedRequestParams + + async with Client(app) as client: + result = await client.list_tools(params=PaginatedRequestParams()) + assert result.tools is not None + + +async def test_call_tool(app: FastMCP): + """Test calling a tool.""" + async with Client(app) as client: + result = await client.call_tool("greet", {"name": "World"}) + assert result.content is not None + assert len(result.content) > 0 + content_str = str(result.content[0]) + assert "Hello, World!" in content_str + + +async def test_call_tool_with_multiple_args(app: FastMCP): + """Test calling a tool with multiple arguments.""" + async with Client(app) as client: + result = await client.call_tool("add", {"a": 5, "b": 3}) + assert result.content is not None + content_str = str(result.content[0]) + assert "8" in content_str + + +async def test_list_resources(app: FastMCP): + """Test listing resources.""" + async with Client(app) as client: + result = await client.list_resources() + # FastMCP may have different resource listing behavior + assert result is not None + + +async def test_read_resource(app: FastMCP): + """Test reading a resource.""" + async with Client(app) as client: + result = await client.read_resource("test://resource") + assert result.contents is not None + assert len(result.contents) > 0 + + +async def test_list_prompts(app: FastMCP): + """Test listing prompts.""" + async with Client(app) as client: + result = await client.list_prompts() + prompt_names = [p.name for p in result.prompts] + assert "greeting_prompt" in prompt_names + + +async def test_get_prompt(app: FastMCP): + """Test getting a prompt.""" + async with Client(app) as client: + result = await client.get_prompt("greeting_prompt", {"name": "Alice"}) + assert result.messages is not None + assert len(result.messages) > 0 + + +async def test_session_property(app: FastMCP): + """Test that the session property returns the ClientSession.""" + from mcp.client.session import ClientSession + + async with Client(app) as client: + session = client.session + assert isinstance(session, ClientSession) + + +async def test_session_is_same_as_internal(app: FastMCP): + """Test that session property returns consistent instance.""" + async with Client(app) as client: + session1 = client.session + session2 = client.session + assert session1 is session2 + + +async def test_enters_and_exits_cleanly(app: FastMCP): + """Test that the client enters and exits cleanly.""" + async with Client(app) as client: + # Should be able to use client + await client.send_ping() + # After exiting, resources should be cleaned up + + +async def test_exception_during_use(app: FastMCP): + """Test that exceptions during use don't prevent cleanup.""" + with pytest.raises(Exception): # May be wrapped in ExceptionGroup by anyio + async with Client(app) as client: + await client.send_ping() + raise ValueError("Test exception") + # Should exit cleanly despite exception + + +async def test_aexit_without_aenter(app: FastMCP): + """Test that calling __aexit__ without __aenter__ doesn't raise.""" + client = Client(app) + # This should not raise even though __aenter__ was never called + await client.__aexit__(None, None, None) + assert client._session is None + + +async def test_server_capabilities_after_init(app: FastMCP): + """Test server_capabilities property after initialization.""" + async with Client(app) as client: + caps = client.server_capabilities + assert caps is not None + # FastMCP should advertise tools capability + assert caps.tools is not None + + +def test_session_property_before_enter(app: FastMCP): + """Test that accessing session before context manager raises RuntimeError.""" + client = Client(app) + with pytest.raises(RuntimeError, match="Client must be used within an async context manager"): + _ = client.session + + +async def test_reentry_raises_runtime_error(app: FastMCP): + """Test that reentering a client raises RuntimeError.""" + async with Client(app) as client: + with pytest.raises(RuntimeError, match="Client is already entered"): + await client.__aenter__() + + +async def test_cleanup_on_init_failure(app: FastMCP): + """Test that resources are cleaned up if initialization fails.""" + with patch("mcp.client.client.ClientSession") as mock_session_class: + # Create a mock context manager that fails on __aenter__ + mock_session = AsyncMock() + mock_session.__aenter__.side_effect = RuntimeError("Session init failed") + mock_session.__aexit__ = AsyncMock(return_value=None) + mock_session_class.return_value = mock_session + + client = Client(app) + with pytest.raises(BaseException) as exc_info: + await client.__aenter__() + + # The error should contain our message (may be wrapped in ExceptionGroup) + # Use repr() to see nested exceptions in ExceptionGroup + assert "Session init failed" in repr(exc_info.value) + + # Verify the client is in a clean state (session should be None) + assert client._session is None + + +async def test_send_progress_notification(app: FastMCP): + """Test sending progress notification.""" + async with Client(app) as client: + # Send a progress notification - this should not raise + await client.send_progress_notification( + progress_token="test-token", + progress=50.0, + total=100.0, + message="Half done", + ) + + +async def test_subscribe_resource(app: FastMCP): + """Test subscribing to a resource.""" + async with Client(app) as client: + # Mock the session's subscribe_resource since FastMCP doesn't support it + with patch.object(client.session, "subscribe_resource", return_value=EmptyResult()): + result = await client.subscribe_resource("test://resource") + assert isinstance(result, EmptyResult) + + +async def test_unsubscribe_resource(app: FastMCP): + """Test unsubscribing from a resource.""" + async with Client(app) as client: + # Mock the session's unsubscribe_resource since FastMCP doesn't support it + with patch.object(client.session, "unsubscribe_resource", return_value=EmptyResult()): + result = await client.unsubscribe_resource("test://resource") + assert isinstance(result, EmptyResult) + + +async def test_send_roots_list_changed(app: FastMCP): + """Test sending roots list changed notification.""" + async with Client(app) as client: + # Send roots list changed notification - should not raise + await client.send_roots_list_changed() + + +async def test_set_logging_level(app: FastMCP): + """Test setting logging level.""" + async with Client(app) as client: + # Mock the session's set_logging_level since FastMCP doesn't support it + with patch.object(client.session, "set_logging_level", return_value=EmptyResult()): + result = await client.set_logging_level("debug") + assert isinstance(result, EmptyResult) + + +async def test_list_resources_with_params(app: FastMCP): + """Test listing resources with params parameter.""" + async with Client(app) as client: + result = await client.list_resources(params=types.PaginatedRequestParams()) + assert result is not None + + +async def test_list_resource_templates_with_params(app: FastMCP): + """Test listing resource templates with params parameter.""" + async with Client(app) as client: + result = await client.list_resource_templates(params=types.PaginatedRequestParams()) + assert result is not None + + +async def test_list_resource_templates_default(app: FastMCP): + """Test listing resource templates with no params or cursor.""" + async with Client(app) as client: + result = await client.list_resource_templates() + assert result is not None + + +async def test_list_prompts_with_params(app: FastMCP): + """Test listing prompts with params parameter.""" + async with Client(app) as client: + result = await client.list_prompts(params=types.PaginatedRequestParams()) + assert result is not None + + +async def test_complete_with_prompt_reference(app: FastMCP): + """Test getting completions for a prompt argument.""" + async with Client(app) as client: + ref = types.PromptReference(type="ref/prompt", name="greeting_prompt") + # Mock the session's complete method since FastMCP may not support it + with patch.object( + client.session, + "complete", + return_value=types.CompleteResult(completion=types.Completion(values=[])), + ): + result = await client.complete(ref=ref, argument={"name": "test"}) + assert result is not None diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index 9b6e886f9a..a5f79910f6 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -3,9 +3,10 @@ import pytest import mcp.types as types +from mcp.client._memory import InMemoryTransport +from mcp.client.session import ClientSession from mcp.server import Server from mcp.server.fastmcp import FastMCP -from mcp.shared.memory import create_connected_server_and_client_session as create_session from mcp.types import ListToolsRequest, ListToolsResult from .conftest import StreamSpyCollection @@ -18,30 +19,28 @@ async def full_featured_server(): """Create a server with tools, resources, prompts, and templates.""" server = FastMCP("test") - @server.tool(name="test_tool_1") - async def test_tool_1() -> str: # pragma: no cover - """First test tool""" - return "Result 1" + # pragma: no cover on handlers below - these exist only to register items with the + # server so list_* methods return results. The handlers themselves are never called + # because these tests only verify pagination/cursor behavior, not tool/resource invocation. + @server.tool() + def greet(name: str) -> str: # pragma: no cover + """Greet someone by name.""" + return f"Hello, {name}!" - @server.tool(name="test_tool_2") - async def test_tool_2() -> str: # pragma: no cover - """Second test tool""" - return "Result 2" + @server.resource("test://resource") + def test_resource() -> str: # pragma: no cover + """A test resource.""" + return "Test content" - @server.resource("resource://test/data") - async def test_resource() -> str: # pragma: no cover - """Test resource""" - return "Test data" + @server.resource("test://template/{id}") + def test_template(id: str) -> str: # pragma: no cover + """A test resource template.""" + return f"Template content for {id}" @server.prompt() - async def test_prompt(name: str) -> str: # pragma: no cover - """Test prompt""" - return f"Hello, {name}!" - - @server.resource("resource://test/{name}") - async def test_template(name: str) -> str: # pragma: no cover - """Test resource template""" - return f"Data for {name}" + def greeting_prompt(name: str) -> str: # pragma: no cover + """A greeting prompt.""" + return f"Please greet {name}." return server @@ -61,78 +60,82 @@ async def test_list_methods_params_parameter( method_name: str, request_method: str, ): - """Test that the params parameter works correctly for list methods. + """Test that the params parameter is accepted and correctly passed to the server. Covers: list_tools, list_resources, list_prompts, list_resource_templates - This tests the new params parameter API (non-deprecated) to ensure - it correctly handles all parameter combinations. + See: https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/pagination#request-format """ - async with create_session(full_featured_server._mcp_server) as client_session: - spies = stream_spy() - method = getattr(client_session, method_name) - - # Test without params parameter (omitted) - _ = await method() - requests = spies.get_client_requests(method=request_method) - assert len(requests) == 1 - assert requests[0].params is None - - spies.clear() - - # Test with params=None - _ = await method(params=None) - requests = spies.get_client_requests(method=request_method) - assert len(requests) == 1 - assert requests[0].params is None - - spies.clear() - - # Test with empty params (for strict servers) - _ = await method(params=types.PaginatedRequestParams()) - requests = spies.get_client_requests(method=request_method) - assert len(requests) == 1 - assert requests[0].params is not None - assert requests[0].params.get("cursor") is None - - spies.clear() - - # Test with params containing cursor - _ = await method(params=types.PaginatedRequestParams(cursor="some_cursor_value")) - requests = spies.get_client_requests(method=request_method) - assert len(requests) == 1 - assert requests[0].params is not None - assert requests[0].params["cursor"] == "some_cursor_value" - - -async def test_list_tools_with_strict_server_validation(): - """Test that list_tools works with strict servers require a params field, - even if it is empty. - - Some MCP servers may implement strict JSON-RPC validation that requires - the params field to always be present in requests, even if empty {}. + transport = InMemoryTransport(full_featured_server) + async with transport.connect() as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + spies = stream_spy() + + # Test without params (omitted) + method = getattr(session, method_name) + _ = await method() + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + assert requests[0].params is None + + spies.clear() + + # Test with params containing cursor + _ = await method(params=types.PaginatedRequestParams(cursor="from_params")) + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + assert requests[0].params is not None + assert requests[0].params["cursor"] == "from_params" + + spies.clear() + + # Test with empty params + _ = await method(params=types.PaginatedRequestParams()) + requests = spies.get_client_requests(method=request_method) + assert len(requests) == 1 + # Empty params means no cursor + assert requests[0].params is None or "cursor" not in requests[0].params + + +async def test_list_tools_with_strict_server_validation( + full_featured_server: FastMCP, +): + """Test pagination with a server that validates request format strictly.""" + transport = InMemoryTransport(full_featured_server) + async with transport.connect() as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + result = await session.list_tools(params=types.PaginatedRequestParams()) + assert isinstance(result, ListToolsResult) + assert len(result.tools) > 0 - This test ensures such servers are supported by the client SDK for list_resources - requests without a cursor. - """ - server = Server("strict_server") +async def test_list_tools_with_lowlevel_server(): + """Test that list_tools works with a lowlevel Server using params.""" + server = Server("test-lowlevel") @server.list_tools() - async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: # pragma: no cover - """Strict handler that validates params field exists""" - - # Simulate strict server validation - if request.params is None: - raise ValueError( - "Strict server validation failed: params field must be present. " - "Expected params: {} for requests without cursor." - ) - - # Return empty tools list - return ListToolsResult(tools=[]) - - async with create_session(server) as client_session: - # Use params to explicitly send params: {} for strict server compatibility - result = await client_session.list_tools(params=types.PaginatedRequestParams()) - assert result is not None + async def handle_list_tools(request: ListToolsRequest) -> ListToolsResult: + # Echo back what cursor we received in the tool description + cursor = request.params.cursor if request.params else None + return ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description=f"cursor={cursor}", + input_schema={}, + ) + ] + ) + + transport = InMemoryTransport(server) + async with transport.connect() as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + result = await session.list_tools(params=types.PaginatedRequestParams()) + assert result.tools[0].description == "cursor=None" + + result = await session.list_tools(params=types.PaginatedRequestParams(cursor="page2")) + assert result.tools[0].description == "cursor=page2" diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index c664676163..a8f8823fe5 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -1,14 +1,12 @@ import pytest from pydantic import FileUrl +from mcp import Client from mcp.client.session import ClientSession from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.server import Context from mcp.server.session import ServerSession from mcp.shared.context import RequestContext -from mcp.shared.memory import ( - create_connected_server_and_client_session as create_session, -) from mcp.types import ListRootsResult, Root, TextContent @@ -41,17 +39,17 @@ async def test_list_roots(context: Context[ServerSession, None], message: str): return True # Test with list_roots callback - async with create_session(server._mcp_server, list_roots_callback=list_roots_callback) as client_session: + async with Client(server, list_roots_callback=list_roots_callback) as client: # Make a request to trigger sampling callback - result = await client_session.call_tool("test_list_roots", {"message": "test message"}) + result = await client.call_tool("test_list_roots", {"message": "test message"}) assert result.is_error is False assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" # Test without list_roots callback - async with create_session(server._mcp_server) as client_session: + async with Client(server) as client: # Make a request to trigger sampling callback - result = await client_session.call_tool("test_list_roots", {"message": "test message"}) + result = await client.call_tool("test_list_roots", {"message": "test message"}) assert result.is_error is True assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Error executing tool test_list_roots: List roots not supported" diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 066837885d..687efca71e 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -3,10 +3,8 @@ import pytest import mcp.types as types +from mcp import Client from mcp.server.fastmcp import FastMCP -from mcp.shared.memory import ( - create_connected_server_and_client_session as create_session, -) from mcp.shared.session import RequestResponder from mcp.types import ( LoggingMessageNotificationParams, @@ -70,19 +68,19 @@ async def message_handler( if isinstance(message, Exception): # pragma: no cover raise message - async with create_session( - server._mcp_server, + async with Client( + server, logging_callback=logging_collector, message_handler=message_handler, - ) as client_session: + ) as client: # First verify our test tool works - result = await client_session.call_tool("test_tool", {}) + result = await client.call_tool("test_tool", {}) assert result.is_error is False assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" # Now send a log message via our tool - log_result = await client_session.call_tool( + log_result = await client.call_tool( "test_tool_with_log", { "message": "Test log message", @@ -90,7 +88,7 @@ async def message_handler( "logger": "test_logger", }, ) - log_result_with_extra = await client_session.call_tool( + log_result_with_extra = await client.call_tool( "test_tool_with_log_extra", { "message": "Test log message", diff --git a/tests/client/test_output_schema_validation.py b/tests/client/test_output_schema_validation.py index 8a9c93acad..24f7b2b69c 100644 --- a/tests/client/test_output_schema_validation.py +++ b/tests/client/test_output_schema_validation.py @@ -7,10 +7,8 @@ import jsonschema import pytest +from mcp import Client from mcp.server.lowlevel import Server -from mcp.shared.memory import ( - create_connected_server_and_client_session as client_session, -) from mcp.types import Tool @@ -77,7 +75,7 @@ async def call_tool(name: str, arguments: dict[str, Any]): # Test that client validates the structured content with bypass_server_output_validation(): - async with client_session(server) as client: + async with Client(server) as client: # The client validates structured content and should raise an error with pytest.raises(RuntimeError) as exc_info: await client.call_tool("get_user", {}) @@ -114,7 +112,7 @@ async def call_tool(name: str, arguments: dict[str, Any]): return {"result": "not_a_number"} # Invalid: should be int with bypass_server_output_validation(): - async with client_session(server) as client: + async with Client(server) as client: # The client validates structured content and should raise an error with pytest.raises(RuntimeError) as exc_info: await client.call_tool("calculate", {}) @@ -145,7 +143,7 @@ async def call_tool(name: str, arguments: dict[str, Any]): return {"alice": "100", "bob": "85"} # Invalid: values should be int with bypass_server_output_validation(): - async with client_session(server) as client: + async with Client(server) as client: # The client validates structured content and should raise an error with pytest.raises(RuntimeError) as exc_info: await client.call_tool("get_scores", {}) @@ -180,7 +178,7 @@ async def call_tool(name: str, arguments: dict[str, Any]): return {"name": "John", "age": 30} # Missing required 'email' with bypass_server_output_validation(): - async with client_session(server) as client: + async with Client(server) as client: # The client validates structured content and should raise an error with pytest.raises(RuntimeError) as exc_info: await client.call_tool("get_person", {}) @@ -205,7 +203,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]: caplog.set_level(logging.WARNING) with bypass_server_output_validation(): - async with client_session(server) as client: + async with Client(server) as client: # Call a tool that wasn't listed result = await client.call_tool("mystery_tool", {}) assert result.structured_content == {"result": 42} diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 3d0b58ed1d..1394e665ca 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,11 +1,9 @@ import pytest +from mcp import Client from mcp.client.session import ClientSession from mcp.server.fastmcp import FastMCP from mcp.shared.context import RequestContext -from mcp.shared.memory import ( - create_connected_server_and_client_session as create_session, -) from mcp.types import ( CreateMessageRequestParams, CreateMessageResult, @@ -43,17 +41,17 @@ async def test_sampling_tool(message: str): return True # Test with sampling callback - async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session: + async with Client(server, sampling_callback=sampling_callback) as client: # Make a request to trigger sampling callback - result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"}) + result = await client.call_tool("test_sampling", {"message": "Test message for sampling"}) assert result.is_error is False assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" # Test without sampling callback - async with create_session(server._mcp_server) as client_session: + async with Client(server) as client: # Make a request to trigger sampling callback - result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"}) + result = await client.call_tool("test_sampling", {"message": "Test message for sampling"}) assert result.is_error is True assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported" @@ -94,8 +92,8 @@ async def test_tool(message: str): assert not hasattr(result, "content_as_list") or not callable(getattr(result, "content_as_list", None)) return True - async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session: - result = await client_session.call_tool("test_backwards_compat", {"message": "Test"}) + async with Client(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("test_backwards_compat", {"message": "Test"}) assert result.is_error is False assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" diff --git a/tests/client/transports/__init__.py b/tests/client/transports/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py new file mode 100644 index 0000000000..fbaf9a9820 --- /dev/null +++ b/tests/client/transports/test_memory.py @@ -0,0 +1,114 @@ +"""Tests for InMemoryTransport.""" + +import pytest + +from mcp.client._memory import InMemoryTransport +from mcp.server import Server +from mcp.server.fastmcp import FastMCP +from mcp.types import Resource + + +@pytest.fixture +def simple_server() -> Server: + """Create a simple MCP server for testing.""" + server = Server(name="test_server") + + # pragma: no cover - handler exists only to register a resource capability. + # Transport tests verify stream creation, not handler invocation. + @server.list_resources() + async def handle_list_resources(): # pragma: no cover + return [ + Resource( + uri="memory://test", + name="Test Resource", + description="A test resource", + ) + ] + + return server + + +@pytest.fixture +def fastmcp_server() -> FastMCP: + """Create a FastMCP server for testing.""" + server = FastMCP("test") + + # pragma: no cover on handlers below - they exist only to register capabilities. + # Transport tests verify stream creation and basic protocol, not handler invocation. + @server.tool() + def greet(name: str) -> str: # pragma: no cover + """Greet someone by name.""" + return f"Hello, {name}!" + + @server.resource("test://resource") + def test_resource() -> str: # pragma: no cover + """A test resource.""" + return "Test content" + + return server + + +pytestmark = pytest.mark.anyio + + +async def test_with_server(simple_server: Server): + """Test creating transport with a Server instance.""" + transport = InMemoryTransport(simple_server) + async with transport.connect() as (read_stream, write_stream): + assert read_stream is not None + assert write_stream is not None + + +async def test_with_fastmcp(fastmcp_server: FastMCP): + """Test creating transport with a FastMCP instance.""" + transport = InMemoryTransport(fastmcp_server) + async with transport.connect() as (read_stream, write_stream): + assert read_stream is not None + assert write_stream is not None + + +async def test_server_is_running(fastmcp_server: FastMCP): + """Test that the server is running and responding to requests.""" + from mcp.client.session import ClientSession + + transport = InMemoryTransport(fastmcp_server) + async with transport.connect() as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert result is not None + assert result.server_info.name == "test" + + +async def test_list_tools(fastmcp_server: FastMCP): + """Test listing tools through the transport.""" + from mcp.client.session import ClientSession + + transport = InMemoryTransport(fastmcp_server) + async with transport.connect() as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + tools_result = await session.list_tools() + assert len(tools_result.tools) > 0 + tool_names = [t.name for t in tools_result.tools] + assert "greet" in tool_names + + +async def test_call_tool(fastmcp_server: FastMCP): + """Test calling a tool through the transport.""" + from mcp.client.session import ClientSession + + transport = InMemoryTransport(fastmcp_server) + async with transport.connect() as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + result = await session.call_tool("greet", {"name": "World"}) + assert result is not None + assert len(result.content) > 0 + assert "Hello, World!" in str(result.content[0]) + + +async def test_raise_exceptions(fastmcp_server: FastMCP): + """Test that raise_exceptions parameter is passed through.""" + transport = InMemoryTransport(fastmcp_server, raise_exceptions=True) + async with transport.connect() as (read_stream, _write_stream): + assert read_stream is not None diff --git a/tests/issues/test_141_resource_templates.py b/tests/issues/test_141_resource_templates.py index 2300f7f736..b024d8e923 100644 --- a/tests/issues/test_141_resource_templates.py +++ b/tests/issues/test_141_resource_templates.py @@ -1,10 +1,8 @@ import pytest from pydantic import AnyUrl +from mcp import Client from mcp.server.fastmcp import FastMCP -from mcp.shared.memory import ( - create_connected_server_and_client_session as client_session, -) from mcp.types import ( ListResourceTemplatesResult, TextResourceContents, @@ -78,10 +76,7 @@ def get_user_post(user_id: str, post_id: str) -> str: def get_user_profile(user_id: str) -> str: return f"Profile for user {user_id}" - async with client_session(mcp._mcp_server) as session: - # Initialize the session - await session.initialize() - + async with Client(mcp) as session: # List available resources resources = await session.list_resource_templates() assert isinstance(resources, ListResourceTemplatesResult) diff --git a/tests/issues/test_152_resource_mime_type.py b/tests/issues/test_152_resource_mime_type.py index 07c129aadf..9618d8414a 100644 --- a/tests/issues/test_152_resource_mime_type.py +++ b/tests/issues/test_152_resource_mime_type.py @@ -3,13 +3,10 @@ import pytest from pydantic import AnyUrl -from mcp import types +from mcp import Client, types from mcp.server.fastmcp import FastMCP from mcp.server.lowlevel import Server from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.shared.memory import ( - create_connected_server_and_client_session as client_session, -) pytestmark = pytest.mark.anyio @@ -33,7 +30,7 @@ def get_image_as_bytes() -> bytes: return image_bytes # Test that resources are listed with correct mime type - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: # List resources and verify mime types resources = await client.list_resources() assert resources.resources is not None @@ -91,7 +88,7 @@ async def handle_read_resource(uri: str): raise Exception(f"Resource not found: {uri}") # pragma: no cover # Test that resources are listed with correct mime type - async with client_session(server) as client: + async with Client(server) as client: # List resources and verify mime types resources = await client.list_resources() assert resources.resources is not None diff --git a/tests/issues/test_1574_resource_uri_validation.py b/tests/issues/test_1574_resource_uri_validation.py index c936af09f9..e6ff568774 100644 --- a/tests/issues/test_1574_resource_uri_validation.py +++ b/tests/issues/test_1574_resource_uri_validation.py @@ -12,12 +12,9 @@ import pytest -from mcp import types +from mcp import Client, types from mcp.server.lowlevel import Server from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.shared.memory import ( - create_connected_server_and_client_session as client_session, -) pytestmark = pytest.mark.anyio @@ -48,7 +45,7 @@ async def read_resource(uri: str): ) ] - async with client_session(server) as client: + async with Client(server) as client: # List should return the exact URIs we specified resources = await client.list_resources() uri_map = {r.uri: r for r in resources.resources} @@ -83,7 +80,7 @@ async def list_resources(): async def read_resource(uri: str): return [ReadResourceContents(content="data", mime_type="text/plain")] - async with client_session(server) as client: + async with Client(server) as client: resources = await client.list_resources() uri_map = {r.uri: r for r in resources.resources} diff --git a/tests/issues/test_1754_mime_type_parameters.py b/tests/issues/test_1754_mime_type_parameters.py index 0260a5b691..c48d56b810 100644 --- a/tests/issues/test_1754_mime_type_parameters.py +++ b/tests/issues/test_1754_mime_type_parameters.py @@ -7,10 +7,8 @@ import pytest from pydantic import AnyUrl +from mcp import Client from mcp.server.fastmcp import FastMCP -from mcp.shared.memory import ( - create_connected_server_and_client_session as client_session, -) pytestmark = pytest.mark.anyio @@ -63,7 +61,7 @@ async def test_mime_type_preserved_in_read_resource(): def my_widget() -> str: return "Hello MCP-UI" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: # Read the resource result = await client.read_resource(AnyUrl("ui://my-widget")) assert len(result.contents) == 1 diff --git a/tests/issues/test_188_concurrency.py b/tests/issues/test_188_concurrency.py index 831736510b..615df3d8e8 100644 --- a/tests/issues/test_188_concurrency.py +++ b/tests/issues/test_188_concurrency.py @@ -2,8 +2,8 @@ import pytest from pydantic import AnyUrl +from mcp import Client from mcp.server.fastmcp import FastMCP -from mcp.shared.memory import create_connected_server_and_client_session as create_session @pytest.mark.anyio @@ -30,7 +30,7 @@ async def trigger(): call_order.append("trigger_end") return "slow" - async with create_session(server._mcp_server) as client_session: + async with Client(server) as client_session: # First tool will wait on event, second will set it async with anyio.create_task_group() as tg: # Start the tool first (it will wait on event) @@ -70,7 +70,7 @@ async def slow_resource(): call_order.append("resource_end") return "slow" - async with create_session(server._mcp_server) as client_session: + async with Client(server) as client_session: # First tool will wait on event, second will set it async with anyio.create_task_group() as tg: # Start the tool first (it will wait on event) diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 4ba5ac0007..a8bf2815c5 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -7,12 +7,11 @@ import pytest from pydantic import BaseModel, Field -from mcp import types +from mcp import Client, types from mcp.client.session import ClientSession, ElicitationFnT from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.shared.context import RequestContext -from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -47,12 +46,8 @@ async def call_tool_and_assert( text_contains: list[str] | None = None, ): """Helper to create session, call tool, and assert result.""" - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - await client_session.initialize() - - result = await client_session.call_tool(tool_name, args) + async with Client(mcp, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool(tool_name, args) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -134,14 +129,10 @@ async def elicitation_callback( ): # pragma: no cover return ElicitResult(action="accept", content={}) - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - await client_session.initialize() - + async with Client(mcp, elicitation_callback=elicitation_callback) as client: # Test both invalid schemas for tool_name, field_name in [("invalid_list", "numbers"), ("nested_model", "nested")]: - result = await client_session.call_tool(tool_name, {}) + result = await client.call_tool(tool_name, {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert "Validation failed as expected" in result.content[0].text diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index 8a27732fbf..6d1cee58ef 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -8,6 +8,7 @@ from starlette.applications import Starlette from starlette.routing import Mount, Route +from mcp.client import Client from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.prompts.base import Message, UserMessage @@ -16,7 +17,6 @@ from mcp.server.session import ServerSession from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import McpError -from mcp.shared.memory import create_connected_server_and_client_session as client_session from mcp.types import ( AudioContent, BlobResourceContents, @@ -77,7 +77,7 @@ async def test_non_ascii_description(self): def hello_world(name: str = "世界") -> str: return f"¡Hola, {name}! 👋" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: tools = await client.list_tools() assert len(tools.tools) == 1 tool = tools.tools[0] @@ -230,7 +230,7 @@ async def test_add_tool(self): async def test_list_tools(self): mcp = FastMCP() mcp.add_tool(tool_fn) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: tools = await client.list_tools() assert len(tools.tools) == 1 @@ -238,7 +238,7 @@ async def test_list_tools(self): async def test_call_tool(self): mcp = FastMCP() mcp.add_tool(tool_fn) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("my_tool", {"arg1": "value"}) assert not hasattr(result, "error") assert len(result.content) > 0 @@ -247,7 +247,7 @@ async def test_call_tool(self): async def test_tool_exception_handling(self): mcp = FastMCP() mcp.add_tool(error_tool_fn) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("error_tool_fn", {}) assert len(result.content) == 1 content = result.content[0] @@ -259,7 +259,7 @@ async def test_tool_exception_handling(self): async def test_tool_error_handling(self): mcp = FastMCP() mcp.add_tool(error_tool_fn) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("error_tool_fn", {}) assert len(result.content) == 1 content = result.content[0] @@ -272,7 +272,7 @@ async def test_tool_error_details(self): """Test that exception details are properly formatted in the response""" mcp = FastMCP() mcp.add_tool(error_tool_fn) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("error_tool_fn", {}) content = result.content[0] assert isinstance(content, TextContent) @@ -284,7 +284,7 @@ async def test_tool_error_details(self): async def test_tool_return_value_conversion(self): mcp = FastMCP() mcp.add_tool(tool_fn) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("tool_fn", {"x": 1, "y": 2}) assert len(result.content) == 1 content = result.content[0] @@ -302,7 +302,7 @@ async def test_tool_image_helper(self, tmp_path: Path): mcp = FastMCP() mcp.add_tool(image_tool_fn) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("image_tool_fn", {"path": str(image_path)}) assert len(result.content) == 1 content = result.content[0] @@ -323,7 +323,7 @@ async def test_tool_audio_helper(self, tmp_path: Path): mcp = FastMCP() mcp.add_tool(audio_tool_fn) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("audio_tool_fn", {"path": str(audio_path)}) assert len(result.content) == 1 content = result.content[0] @@ -358,7 +358,7 @@ async def test_tool_audio_suffix_detection(self, tmp_path: Path, filename: str, audio_path = tmp_path / filename audio_path.write_bytes(b"fake audio data") - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("audio_tool_fn", {"path": str(audio_path)}) assert len(result.content) == 1 content = result.content[0] @@ -373,7 +373,7 @@ async def test_tool_audio_suffix_detection(self, tmp_path: Path, filename: str, async def test_tool_mixed_content(self): mcp = FastMCP() mcp.add_tool(mixed_content_tool_fn) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("mixed_content_tool_fn", {}) assert len(result.content) == 3 content1, content2, content3 = result.content @@ -425,7 +425,7 @@ def mixed_list_fn() -> list: # type: ignore mcp = FastMCP() mcp.add_tool(mixed_list_fn) # type: ignore - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("mixed_list_fn", {}) assert len(result.content) == 5 # Check text conversion @@ -469,7 +469,7 @@ def get_user(user_id: int) -> UserOutput: mcp = FastMCP() mcp.add_tool(get_user) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: # Check that the tool has outputSchema tools = await client.list_tools() tool = next(t for t in tools.tools if t.name == "get_user") @@ -499,7 +499,7 @@ def calculate_sum(a: int, b: int) -> int: mcp = FastMCP() mcp.add_tool(calculate_sum) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: # Check that the tool has outputSchema tools = await client.list_tools() tool = next(t for t in tools.tools if t.name == "calculate_sum") @@ -526,7 +526,7 @@ def get_numbers() -> list[int]: mcp = FastMCP() mcp.add_tool(get_numbers) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("get_numbers", {}) assert result.is_error is False assert result.structured_content is not None @@ -542,7 +542,7 @@ def get_numbers() -> list[int]: mcp = FastMCP() mcp.add_tool(get_numbers) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("get_numbers", {}) assert result.is_error is True assert result.structured_content is None @@ -566,7 +566,7 @@ def get_metadata() -> dict[str, Any]: mcp = FastMCP() mcp.add_tool(get_metadata) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: # Check schema tools = await client.list_tools() tool = next(t for t in tools.tools if t.name == "get_metadata") @@ -602,7 +602,7 @@ def get_settings() -> dict[str, str]: mcp = FastMCP() mcp.add_tool(get_settings) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: # Check schema tools = await client.list_tools() tool = next(t for t in tools.tools if t.name == "get_settings") @@ -646,7 +646,7 @@ async def test_remove_tool_and_list(self): mcp.add_tool(error_tool_fn) # Verify both tools exist - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: tools = await client.list_tools() assert len(tools.tools) == 2 tool_names = [t.name for t in tools.tools] @@ -657,7 +657,7 @@ async def test_remove_tool_and_list(self): mcp.remove_tool("tool_fn") # Verify only one tool remains - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: tools = await client.list_tools() assert len(tools.tools) == 1 assert tools.tools[0].name == "error_tool_fn" @@ -669,7 +669,7 @@ async def test_remove_tool_and_call(self): mcp.add_tool(tool_fn) # Verify tool works before removal - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("tool_fn", {"x": 1, "y": 2}) assert not result.is_error content = result.content[0] @@ -680,7 +680,7 @@ async def test_remove_tool_and_call(self): mcp.remove_tool("tool_fn") # Verify calling removed tool returns an error - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("tool_fn", {"x": 1, "y": 2}) assert result.is_error content = result.content[0] @@ -699,8 +699,12 @@ def get_text(): resource = FunctionResource(uri="resource://test", name="test", fn=get_text) mcp.add_resource(resource) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.read_resource("resource://test") + + async with Client(mcp) as client: + result = await client.read_resource("resource://test") + assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Hello, world!" @@ -719,8 +723,12 @@ def get_binary(): ) mcp.add_resource(resource) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: + result = await client.read_resource("resource://binary") + + async with Client(mcp) as client: result = await client.read_resource("resource://binary") + assert isinstance(result.contents[0], BlobResourceContents) assert result.contents[0].blob == base64.b64encode(b"Binary data").decode() @@ -735,8 +743,12 @@ async def test_file_resource_text(self, tmp_path: Path): resource = FileResource(uri="file://test.txt", name="test.txt", path=text_file) mcp.add_resource(resource) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: + result = await client.read_resource("file://test.txt") + + async with Client(mcp) as client: result = await client.read_resource("file://test.txt") + assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Hello from file!" @@ -756,8 +768,12 @@ async def test_file_resource_binary(self, tmp_path: Path): ) mcp.add_resource(resource) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.read_resource("file://test.bin") + + async with Client(mcp) as client: + result = await client.read_resource("file://test.bin") + assert isinstance(result.contents[0], BlobResourceContents) assert result.contents[0].blob == base64.b64encode(b"Binary file data").decode() @@ -770,7 +786,7 @@ def get_data() -> str: # pragma: no cover """get_data returns a string""" return "Hello, world!" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: resources = await client.list_resources() assert len(resources.resources) == 1 resource = resources.resources[0] @@ -822,8 +838,12 @@ async def test_resource_matching_params(self): def get_data(name: str) -> str: return f"Data for {name}" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.read_resource("resource://test/data") + + async with Client(mcp) as client: + result = await client.read_resource("resource://test/data") + assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Data for test" @@ -847,8 +867,12 @@ async def test_resource_multiple_params(self): def get_data(org: str, repo: str) -> str: return f"Data for {org}/{repo}" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.read_resource("resource://cursor/fastmcp/data") + + async with Client(mcp) as client: + result = await client.read_resource("resource://cursor/fastmcp/data") + assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Data for cursor/fastmcp" @@ -870,8 +894,12 @@ def get_data_mismatched(org: str, repo_2: str) -> str: # pragma: no cover def get_static_data() -> str: return "Static data" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: + result = await client.read_resource("resource://static") + + async with Client(mcp) as client: result = await client.read_resource("resource://static") + assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Static data" @@ -910,8 +938,12 @@ def get_csv(user: str) -> str: assert hasattr(template, "mime_type") assert template.mime_type == "text/csv" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.read_resource("resource://bob/csv") + + async with Client(mcp) as client: + result = await client.read_resource("resource://bob/csv") + assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "csv for bob" @@ -972,7 +1004,10 @@ async def test_read_resource_returns_meta(self): def get_data() -> str: return "test data" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: + result = await client.read_resource("resource://data") + + async with Client(mcp) as client: result = await client.read_resource("resource://data") # Verify content and metadata in protocol response @@ -1008,7 +1043,7 @@ def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: return f"Request {ctx.request_id}: {x}" mcp.add_tool(tool_with_context) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("tool_with_context", {"x": 42}) assert len(result.content) == 1 content = result.content[0] @@ -1026,7 +1061,7 @@ async def async_tool(x: int, ctx: Context[ServerSession, None]) -> str: return f"Async request {ctx.request_id}: {x}" mcp.add_tool(async_tool) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("async_tool", {"x": 42}) assert len(result.content) == 1 content = result.content[0] @@ -1049,7 +1084,7 @@ async def logging_tool(msg: str, ctx: Context[ServerSession, None]) -> str: mcp.add_tool(logging_tool) with patch("mcp.server.session.ServerSession.send_log_message") as mock_log: - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("logging_tool", {"msg": "test"}) assert len(result.content) == 1 content = result.content[0] @@ -1091,7 +1126,7 @@ def no_context(x: int) -> int: return x * 2 mcp.add_tool(no_context) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("no_context", {"x": 21}) assert len(result.content) == 1 content = result.content[0] @@ -1115,7 +1150,7 @@ async def tool_with_resource(ctx: Context[ServerSession, None]) -> str: r = r_list[0] return f"Read resource: {r.content} with mime type {r.mime_type}" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.call_tool("tool_with_resource", {}) assert len(result.content) == 1 content = result.content[0] @@ -1141,8 +1176,13 @@ def resource_with_context(name: str, ctx: Context[ServerSession, None]) -> str: assert template.context_kwarg == "ctx" # Test via client - async with client_session(mcp._mcp_server) as client: + + async with Client(mcp) as client: + result = await client.read_resource("resource://context/test") + + async with Client(mcp) as client: result = await client.read_resource("resource://context/test") + assert len(result.contents) == 1 content = result.contents[0] assert isinstance(content, TextResourceContents) @@ -1166,8 +1206,13 @@ def resource_no_context(name: str) -> str: assert template.context_kwarg is None # Test via client - async with client_session(mcp._mcp_server) as client: + + async with Client(mcp) as client: result = await client.read_resource("resource://nocontext/test") + + async with Client(mcp) as client: + result = await client.read_resource("resource://nocontext/test") + assert len(result.contents) == 1 content = result.contents[0] assert isinstance(content, TextResourceContents) @@ -1191,8 +1236,13 @@ def resource_custom_ctx(id: str, my_ctx: Context[ServerSession, None]) -> str: assert template.context_kwarg == "my_ctx" # Test via client - async with client_session(mcp._mcp_server) as client: + + async with Client(mcp) as client: + result = await client.read_resource("resource://custom/123") + + async with Client(mcp) as client: result = await client.read_resource("resource://custom/123") + assert len(result.contents) == 1 content = result.contents[0] assert isinstance(content, TextResourceContents) @@ -1214,7 +1264,7 @@ def prompt_with_context(text: str, ctx: Context[ServerSession, None]) -> str: assert len(prompts) == 1 # Test via client - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: # Try calling without passing ctx explicitly result = await client.get_prompt("prompt_with_ctx", {"text": "test"}) # If this succeeds, check if context was injected @@ -1234,7 +1284,7 @@ def prompt_no_context(text: str) -> str: return f"Prompt '{text}' works" # Test via client - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.get_prompt("prompt_no_ctx", {"text": "test"}) assert len(result.messages) == 1 message = result.messages[0] @@ -1313,7 +1363,7 @@ async def test_list_prompts(self): def fn(name: str, optional: str = "default") -> str: # pragma: no cover return f"Hello, {name}!" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.list_prompts() assert result.prompts is not None assert len(result.prompts) == 1 @@ -1335,7 +1385,7 @@ async def test_get_prompt(self): def fn(name: str) -> str: return f"Hello, {name}!" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.get_prompt("fn", {"name": "World"}) assert len(result.messages) == 1 message = result.messages[0] @@ -1353,7 +1403,7 @@ async def test_get_prompt_with_description(self): def fn(name: str) -> str: return f"Hello, {name}!" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.get_prompt("fn", {"name": "World"}) assert result.description == "Test prompt description" @@ -1366,7 +1416,7 @@ async def test_get_prompt_without_description(self): def fn(name: str) -> str: return f"Hello, {name}!" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.get_prompt("fn", {"name": "World"}) assert result.description == "" @@ -1380,7 +1430,7 @@ def fn(name: str) -> str: """This is the function docstring.""" return f"Hello, {name}!" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.get_prompt("fn", {"name": "World"}) assert result.description == "This is the function docstring." @@ -1402,7 +1452,7 @@ def fn() -> Message: ) ) - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: result = await client.get_prompt("fn") assert len(result.messages) == 1 message = result.messages[0] @@ -1418,7 +1468,7 @@ def fn() -> Message: async def test_get_unknown_prompt(self): """Test error when getting unknown prompt.""" mcp = FastMCP() - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: with pytest.raises(McpError, match="Unknown prompt"): await client.get_prompt("unknown") @@ -1431,7 +1481,7 @@ async def test_get_prompt_missing_args(self): def prompt_fn(name: str) -> str: # pragma: no cover return f"Hello, {name}!" - async with client_session(mcp._mcp_server) as client: + async with Client(mcp) as client: with pytest.raises(McpError, match="Missing required arguments"): await client.get_prompt("prompt_fn") diff --git a/tests/server/fastmcp/test_title.py b/tests/server/fastmcp/test_title.py index 7986db08c6..2cb1173b3e 100644 --- a/tests/server/fastmcp/test_title.py +++ b/tests/server/fastmcp/test_title.py @@ -2,9 +2,9 @@ import pytest +from mcp import Client from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.resources import FunctionResource -from mcp.shared.memory import create_connected_server_and_client_session from mcp.shared.metadata_utils import get_display_name from mcp.types import Prompt, Resource, ResourceTemplate, Tool, ToolAnnotations @@ -24,8 +24,9 @@ async def test_server_name_title_description_version(): assert mcp.version == "1.0" # Start server and connect client - async with create_connected_server_and_client_session(mcp._mcp_server) as client: - init_result = await client.initialize() + async with Client(mcp) as client: + # Access initialization result from session + init_result = await client.session.initialize() assert init_result.server_info.name == "TestServer" assert init_result.server_info.title == "Test Server Title" assert init_result.server_info.description == "This is a test server description." @@ -60,9 +61,7 @@ def tool_with_both(message: str) -> str: # pragma: no cover return message # Start server and connect client - async with create_connected_server_and_client_session(mcp._mcp_server) as client: - await client.initialize() - + async with Client(mcp) as client: # List tools tools_result = await client.list_tools() tools = {tool.name: tool for tool in tools_result.tools} @@ -104,9 +103,7 @@ def titled_prompt(topic: str) -> str: # pragma: no cover return f"Tell me about {topic}" # Start server and connect client - async with create_connected_server_and_client_session(mcp._mcp_server) as client: - await client.initialize() - + async with Client(mcp) as client: # List prompts prompts_result = await client.list_prompts() prompts = {prompt.name: prompt for prompt in prompts_result.prompts} @@ -164,9 +161,7 @@ def titled_dynamic_resource(id: str) -> str: # pragma: no cover return f"Data for {id}" # Start server and connect client - async with create_connected_server_and_client_session(mcp._mcp_server) as client: - await client.initialize() - + async with Client(mcp) as client: # List resources resources_result = await client.list_resources() resources = {str(res.uri): res for res in resources_result.resources} diff --git a/tests/server/fastmcp/test_url_elicitation.py b/tests/server/fastmcp/test_url_elicitation.py index c960232223..cade2aa564 100644 --- a/tests/server/fastmcp/test_url_elicitation.py +++ b/tests/server/fastmcp/test_url_elicitation.py @@ -4,13 +4,12 @@ import pytest from pydantic import BaseModel, Field -from mcp import types +from mcp import Client, types from mcp.client.session import ClientSession from mcp.server.elicitation import CancelledElicitation, DeclinedElicitation, elicit_url from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.shared.context import RequestContext -from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -37,12 +36,8 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par assert params.message == "Please provide your API key to continue." return ElicitResult(action="accept") - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - await client_session.initialize() - - result = await client_session.call_tool("request_api_key", {}) + async with Client(mcp, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("request_api_key", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert result.content[0].text == "User accept" @@ -67,12 +62,8 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par assert params.mode == "url" return ElicitResult(action="decline") - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - await client_session.initialize() - - result = await client_session.call_tool("oauth_flow", {}) + async with Client(mcp, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("oauth_flow", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert result.content[0].text == "User decline authorization" @@ -97,12 +88,8 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par assert params.mode == "url" return ElicitResult(action="cancel") - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - await client_session.initialize() - - result = await client_session.call_tool("payment_flow", {}) + async with Client(mcp, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("payment_flow", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert result.content[0].text == "User cancel payment" @@ -127,12 +114,8 @@ async def setup_credentials(ctx: Context[ServerSession, None]) -> str: async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): return ElicitResult(action="accept") - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - await client_session.initialize() - - result = await client_session.call_tool("setup_credentials", {}) + async with Client(mcp, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("setup_credentials", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert result.content[0].text == "AcceptedUrlElicitation" @@ -165,12 +148,8 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par # Return without content - this is correct for URL mode return ElicitResult(action="accept") - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - await client_session.initialize() - - result = await client_session.call_tool("check_url_response", {}) + async with Client(mcp, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("check_url_response", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert "Content: None" in result.content[0].text @@ -200,12 +179,8 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par assert params.requested_schema is not None return ElicitResult(action="accept", content={"name": "Alice"}) - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - await client_session.initialize() - - result = await client_session.call_tool("ask_name", {}) + async with Client(mcp, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("ask_name", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Hello, Alice!" @@ -235,12 +210,8 @@ async def trigger_elicitation(ctx: Context[ServerSession, None]) -> str: async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): return ElicitResult(action="accept") # pragma: no cover - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - await client_session.initialize() - - result = await client_session.call_tool("trigger_elicitation", {}) + async with Client(mcp, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("trigger_elicitation", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Elicitation completed" @@ -296,12 +267,8 @@ async def test_cancel(ctx: Context[ServerSession, None]) -> str: async def decline_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): return ElicitResult(action="decline") - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=decline_callback - ) as client_session: - await client_session.initialize() - - result = await client_session.call_tool("test_decline", {}) + async with Client(mcp, elicitation_callback=decline_callback) as client: + result = await client.call_tool("test_decline", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Declined" @@ -310,12 +277,8 @@ async def decline_callback(context: RequestContext[ClientSession, None], params: async def cancel_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): return ElicitResult(action="cancel") - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=cancel_callback - ) as client_session: - await client_session.initialize() - - result = await client_session.call_tool("test_cancel", {}) + async with Client(mcp, elicitation_callback=cancel_callback) as client: + result = await client.call_tool("test_cancel", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Cancelled" @@ -347,12 +310,8 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par assert params.requested_schema is not None return ElicitResult(action="accept", content={"email": "test@example.com"}) - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - await client_session.initialize() - - result = await client_session.call_tool("use_deprecated_elicit", {}) + async with Client(mcp, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("use_deprecated_elicit", {}) assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Email: test@example.com" @@ -378,10 +337,7 @@ async def elicitation_callback(context: RequestContext[ClientSession, None], par assert params.elicitation_id == "ctx-test-001" return ElicitResult(action="accept") - async with create_connected_server_and_client_session( - mcp._mcp_server, elicitation_callback=elicitation_callback - ) as client_session: - await client_session.initialize() - result = await client_session.call_tool("direct_elicit_url", {}) + async with Client(mcp, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("direct_elicit_url", {}) assert isinstance(result.content[0], TextContent) assert result.content[0].text == "Result: accept" diff --git a/tests/server/fastmcp/test_url_elicitation_error_throw.py b/tests/server/fastmcp/test_url_elicitation_error_throw.py index 27effe55b6..cacc0b741c 100644 --- a/tests/server/fastmcp/test_url_elicitation_error_throw.py +++ b/tests/server/fastmcp/test_url_elicitation_error_throw.py @@ -2,11 +2,10 @@ import pytest -from mcp import types +from mcp import Client, types from mcp.server.fastmcp import Context, FastMCP from mcp.server.session import ServerSession from mcp.shared.exceptions import McpError, UrlElicitationRequiredError -from mcp.shared.memory import create_connected_server_and_client_session @pytest.mark.anyio @@ -28,12 +27,10 @@ async def connect_service(service_name: str, ctx: Context[ServerSession, None]) ] ) - async with create_connected_server_and_client_session(mcp._mcp_server) as client_session: - await client_session.initialize() - + async with Client(mcp) as client: # Call the tool - it should raise McpError with URL_ELICITATION_REQUIRED code with pytest.raises(McpError) as exc_info: - await client_session.call_tool("connect_service", {"service_name": "github"}) + await client.call_tool("connect_service", {"service_name": "github"}) # Verify the error details error = exc_info.value.error @@ -74,12 +71,10 @@ async def multi_auth(ctx: Context[ServerSession, None]) -> str: ] ) - async with create_connected_server_and_client_session(mcp._mcp_server) as client_session: - await client_session.initialize() - + async with Client(mcp) as client: # Call the tool and catch the error with pytest.raises(McpError) as exc_info: - await client_session.call_tool("multi_auth", {}) + await client.call_tool("multi_auth", {}) # Reconstruct the typed error mcp_error = exc_info.value @@ -102,11 +97,9 @@ async def test_normal_exceptions_still_return_error_result(): async def failing_tool(ctx: Context[ServerSession, None]) -> str: raise ValueError("Something went wrong") - async with create_connected_server_and_client_session(mcp._mcp_server) as client_session: - await client_session.initialize() - + async with Client(mcp) as client: # Normal exceptions should be returned as error results, not McpError - result = await client_session.call_tool("failing_tool", {}) + result = await client.call_tool("failing_tool", {}) assert result.is_error is True assert len(result.content) == 1 assert isinstance(result.content[0], types.TextContent) diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index ef3ef49367..8f109d9fb2 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -6,9 +6,10 @@ import pytest import mcp.types as types +from mcp.client._memory import InMemoryTransport +from mcp.client.session import ClientSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError -from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import ( CallToolRequest, CallToolRequestParams, @@ -54,57 +55,61 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ return [types.TextContent(type="text", text=f"Call number: {call_count}")] raise ValueError(f"Unknown tool: {name}") # pragma: no cover - async with create_connected_server_and_client_session(server) as client: - # First request (will be cancelled) - async def first_request(): - try: - await client.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}), - ) - ), - CallToolResult, - ) - pytest.fail("First request should have been cancelled") # pragma: no cover - except McpError: - pass # Expected - - # Start first request - async with anyio.create_task_group() as tg: - tg.start_soon(first_request) - - # Wait for it to start - await ev_first_call.wait() - - # Cancel it - assert first_request_id is not None - await client.send_notification( - ClientNotification( - CancelledNotification( - params=CancelledNotificationParams( - request_id=first_request_id, - reason="Testing server recovery", + transport = InMemoryTransport(server) + async with transport.connect() as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as client: + await client.initialize() + + # First request (will be cancelled) + async def first_request(): + try: + await client.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams(name="test_tool", arguments={}), + ) ), + CallToolResult, + ) + pytest.fail("First request should have been cancelled") # pragma: no cover + except McpError: + pass # Expected + + # Start first request + async with anyio.create_task_group() as tg: + tg.start_soon(first_request) + + # Wait for it to start + await ev_first_call.wait() + + # Cancel it + assert first_request_id is not None + await client.send_notification( + ClientNotification( + CancelledNotification( + params=CancelledNotificationParams( + request_id=first_request_id, + reason="Testing server recovery", + ), + ) ) ) + + # Second request (should work normally) + result = await client.send_request( + ClientRequest( + CallToolRequest( + params=CallToolRequestParams(name="test_tool", arguments={}), + ) + ), + CallToolResult, ) - # Second request (should work normally) - result = await client.send_request( - ClientRequest( - CallToolRequest( - params=CallToolRequestParams(name="test_tool", arguments={}), - ) - ), - CallToolResult, - ) - - # Verify second request completed successfully - assert len(result.content) == 1 - # Type narrowing for pyright - content = result.content[0] - assert content.type == "text" - assert isinstance(content, types.TextContent) - assert content.text == "Call number: 2" - assert call_count == 2 + # Verify second request completed successfully + assert len(result.content) == 1 + # Type narrowing for pyright + content = result.content[0] + assert content.type == "text" + assert isinstance(content, types.TextContent) + assert content.text == "Call number: 2" + assert call_count == 2 diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py index c59916ef22..bbaa4018f8 100644 --- a/tests/server/test_completion_with_context.py +++ b/tests/server/test_completion_with_context.py @@ -6,8 +6,8 @@ import pytest +from mcp import Client from mcp.server.lowlevel import Server -from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import ( Completion, CompletionArgument, @@ -38,7 +38,7 @@ async def handle_completion( # Return test completion return Completion(values=["test-completion"], total=1, has_more=False) - async with create_connected_server_and_client_session(server) as client: + async with Client(server) as client: # Test with context result = await client.complete( ref=ResourceTemplateReference(type="ref/resource", uri="test://resource/{param}"), @@ -70,7 +70,7 @@ async def handle_completion( return Completion(values=["no-context-completion"], total=1, has_more=False) - async with create_connected_server_and_client_session(server) as client: + async with Client(server) as client: # Test without context result = await client.complete( ref=PromptReference(type="ref/prompt", name="test-prompt"), argument={"name": "arg", "value": "val"} @@ -109,7 +109,7 @@ async def handle_completion( return Completion(values=[], total=0, has_more=False) # pragma: no cover - async with create_connected_server_and_client_session(server) as client: + async with Client(server) as client: # First, complete database db_result = await client.complete( ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"), @@ -160,7 +160,7 @@ async def handle_completion( return Completion(values=[], total=0, has_more=False) # pragma: no cover - async with create_connected_server_and_client_session(server) as client: + async with Client(server) as client: # Try to complete table without database context - should raise error with pytest.raises(Exception) as exc_info: await client.complete( diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index 10f580a6c5..31238b9ffd 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -1,9 +1,7 @@ import pytest -from typing_extensions import AsyncGenerator -from mcp.client.session import ClientSession +from mcp import Client from mcp.server import Server -from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import EmptyResult, Resource @@ -24,18 +22,9 @@ async def handle_list_resources(): # pragma: no cover return server -@pytest.fixture -async def client_connected_to_server( - mcp_server: Server, -) -> AsyncGenerator[ClientSession, None]: - async with create_connected_server_and_client_session(mcp_server) as client_session: - yield client_session - - @pytest.mark.anyio -async def test_memory_server_and_client_connection( - client_connected_to_server: ClientSession, -): +async def test_memory_server_and_client_connection(mcp_server: Server): """Shows how a client and server can communicate over memory streams.""" - response = await client_connected_to_server.send_ping() - assert isinstance(response, EmptyResult) + async with Client(mcp_server) as client: + response = await client.send_ping() + assert isinstance(response, EmptyResult) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 5f0ac83fdc..1d7de0b346 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -5,15 +5,16 @@ import pytest import mcp.types as types +from mcp.client._memory import InMemoryTransport from mcp.client.session import ClientSession from mcp.server import Server from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.context import RequestContext -from mcp.shared.memory import create_connected_server_and_client_session +from mcp.shared.message import SessionMessage from mcp.shared.progress import progress -from mcp.shared.session import BaseSession, RequestResponder, SessionMessage +from mcp.shared.session import BaseSession, RequestResponder @pytest.mark.anyio @@ -368,25 +369,30 @@ async def handle_list_tools() -> list[types.Tool]: # Test with mocked logging with patch("mcp.shared.session.logging.error", side_effect=mock_log_error): - async with create_connected_server_and_client_session(server) as client_session: - # Send a request with a failing progress callback - result = await client_session.send_request( - types.ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams(name="progress_tool", arguments={}), - ) - ), - types.CallToolResult, - progress_callback=failing_progress_callback, - ) + transport = InMemoryTransport(server) + async with transport.connect() as (read_stream, write_stream): + async with ClientSession( # pragma: no branch + read_stream=read_stream, write_stream=write_stream + ) as session: + await session.initialize() + # Send a request with a failing progress callback + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams(name="progress_tool", arguments={}), + ) + ), + types.CallToolResult, + progress_callback=failing_progress_callback, + ) - # Verify the request completed successfully despite the callback failure - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, types.TextContent) - assert content.text == "progress_result" + # Verify the request completed successfully despite the callback failure + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, types.TextContent) + assert content.text == "progress_result" - # Check that a warning was logged for the progress callback exception - assert len(logged_errors) > 0 - assert any("Progress callback raised an exception" in warning for warning in logged_errors) + # Check that a warning was logged for the progress callback exception + assert len(logged_errors) > 0 + assert any("Progress callback raised an exception" in warning for warning in logged_errors) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index bdb7052843..0656a01a6f 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,14 +1,14 @@ -from collections.abc import AsyncGenerator from typing import Any import anyio import pytest import mcp.types as types +from mcp.client._memory import InMemoryTransport from mcp.client.session import ClientSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError -from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session +from mcp.shared.memory import create_client_server_memory_streams from mcp.shared.message import SessionMessage from mcp.types import ( CancelledNotification, @@ -30,25 +30,20 @@ def mcp_server() -> Server: return Server(name="test server") -@pytest.fixture -async def client_connected_to_server( - mcp_server: Server, -) -> AsyncGenerator[ClientSession, None]: - async with create_connected_server_and_client_session(mcp_server) as client_session: - yield client_session - - @pytest.mark.anyio -async def test_in_flight_requests_cleared_after_completion( - client_connected_to_server: ClientSession, -): +async def test_in_flight_requests_cleared_after_completion(mcp_server: Server): """Verify that _in_flight is empty after all requests complete.""" - # Send a request and wait for response - response = await client_connected_to_server.send_ping() - assert isinstance(response, EmptyResult) + transport = InMemoryTransport(mcp_server) + async with transport.connect() as (read_stream, write_stream): + async with ClientSession(read_stream=read_stream, write_stream=write_stream) as session: + await session.initialize() - # Verify _in_flight is empty - assert len(client_connected_to_server._in_flight) == 0 + # Send a request and wait for response + response = await session.send_ping() + assert isinstance(response, EmptyResult) + + # Verify _in_flight is empty + assert len(session._in_flight) == 0 @pytest.mark.anyio @@ -88,10 +83,10 @@ async def handle_list_tools() -> list[types.Tool]: return server - async def make_request(client_session: ClientSession): + async def make_request(session: ClientSession): nonlocal ev_cancelled try: - await client_session.send_request( + await session.send_request( ClientRequest( types.CallToolRequest( params=types.CallToolRequestParams(name="slow_tool", arguments={}), @@ -105,28 +100,31 @@ async def make_request(client_session: ClientSession): assert "Request cancelled" in str(e) ev_cancelled.set() - async with create_connected_server_and_client_session(make_server()) as client_session: - async with anyio.create_task_group() as tg: - tg.start_soon(make_request, client_session) - - # Wait for the request to be in-flight - with anyio.fail_after(1): # Timeout after 1 second - await ev_tool_called.wait() - - # Send cancellation notification - assert request_id is not None - await client_session.send_notification( - ClientNotification( - CancelledNotification( - params=CancelledNotificationParams(request_id=request_id), + transport = InMemoryTransport(make_server()) + async with transport.connect() as (read_stream, write_stream): + async with ClientSession(read_stream=read_stream, write_stream=write_stream) as session: + await session.initialize() + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(make_request, session) + + # Wait for the request to be in-flight + with anyio.fail_after(1): # Timeout after 1 second + await ev_tool_called.wait() + + # Send cancellation notification + assert request_id is not None + await session.send_notification( + ClientNotification( + CancelledNotification( + params=CancelledNotificationParams(request_id=request_id), + ) ) ) - ) - # Give cancellation time to process - # TODO(Marcelo): Drop the pragma once https://github.com/coveragepy/coveragepy/issues/1987 is fixed. - with anyio.fail_after(1): # pragma: no cover - await ev_cancelled.wait() + # Give cancellation time to process + # TODO(Marcelo): Drop the pragma once https://github.com/coveragepy/coveragepy/issues/1987 is fixed. + with anyio.fail_after(1): # pragma: no cover + await ev_cancelled.wait() @pytest.mark.anyio diff --git a/tests/test_examples.py b/tests/test_examples.py index c7ef81e1e3..187cda3218 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -12,18 +12,16 @@ from pydantic import AnyUrl from pytest_examples import CodeExample, EvalExample, find_examples -from examples.fastmcp.complex_inputs import mcp as complex_inputs_mcp -from examples.fastmcp.desktop import mcp as desktop_mcp -from examples.fastmcp.direct_call_tool_result_return import mcp as direct_call_tool_result_mcp -from examples.fastmcp.simple_echo import mcp as simple_echo_mcp -from mcp.shared.memory import create_connected_server_and_client_session as client_session +from mcp import Client from mcp.types import TextContent, TextResourceContents @pytest.mark.anyio async def test_simple_echo(): """Test the simple echo server""" - async with client_session(simple_echo_mcp._mcp_server) as client: + from examples.fastmcp.simple_echo import mcp + + async with Client(mcp) as client: result = await client.call_tool("echo", {"text": "hello"}) assert len(result.content) == 1 content = result.content[0] @@ -34,7 +32,9 @@ async def test_simple_echo(): @pytest.mark.anyio async def test_complex_inputs(): """Test the complex inputs server""" - async with client_session(complex_inputs_mcp._mcp_server) as client: + from examples.fastmcp.complex_inputs import mcp + + async with Client(mcp) as client: tank = {"shrimp": [{"name": "bob"}, {"name": "alice"}]} result = await client.call_tool("name_shrimp", {"tank": tank, "extra_names": ["charlie"]}) assert len(result.content) == 3 @@ -49,7 +49,9 @@ async def test_complex_inputs(): @pytest.mark.anyio async def test_direct_call_tool_result_return(): """Test the CallToolResult echo server""" - async with client_session(direct_call_tool_result_mcp._mcp_server) as client: + from examples.fastmcp.direct_call_tool_result_return import mcp + + async with Client(mcp) as client: result = await client.call_tool("echo", {"text": "hello"}) assert len(result.content) == 1 content = result.content[0] @@ -69,7 +71,9 @@ async def test_desktop(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(Path, "iterdir", lambda self: mock_files) # type: ignore[reportUnknownArgumentType] monkeypatch.setattr(Path, "home", lambda: Path("/fake/home")) - async with client_session(desktop_mcp._mcp_server) as client: + from examples.fastmcp.desktop import mcp + + async with Client(mcp) as client: # Test the sum function result = await client.call_tool("sum", {"a": 1, "b": 2}) assert len(result.content) == 1