Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/stirrup/clients/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def to_openai_messages(msgs: list[ChatMessage]) -> list[dict[str, Any]]:
out.append({"role": "user", "content": content_to_openai(m.content)})
elif isinstance(m, AssistantMessage):
msg: dict[str, Any] = {"role": "assistant", "content": content_to_openai(m.content)}
if m.metadata:
msg["metadata"] = m.metadata

if m.reasoning:
if m.reasoning.content:
Expand Down
4 changes: 4 additions & 0 deletions src/stirrup/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tempfile import NamedTemporaryFile
from types import TracebackType
from typing import Annotated, Any, ClassVar, Literal, Protocol, Self, overload, runtime_checkable
from uuid import uuid4

import filetype
from moviepy import AudioFileClip, VideoFileClip
Expand Down Expand Up @@ -118,6 +119,7 @@ class ImageContentBlock(BinaryContentBlock):
allowed_mime_types: ClassVar[set[str]] = {
"image/jpeg", # JPEG
"image/png", # PNG
"image/webp", # WebP
"image/gif", # GIF
"image/bmp", # BMP
"image/tiff", # TIFF
Expand Down Expand Up @@ -616,11 +618,13 @@ class Reasoning(BaseModel):
class AssistantMessage(BaseModel):
"""LLM response message with optional tool calls and token usage tracking."""

id: str = Field(default_factory=lambda: uuid4().hex)
role: Literal["assistant"] = "assistant"
reasoning: Reasoning | None = None
content: Content
tool_calls: Annotated[list[ToolCall], Field(default_factory=list)]
token_usage: Annotated[TokenUsage, Field(default_factory=TokenUsage)]
metadata: Annotated[dict[str, Any], Field(default_factory=dict)]
request_start_time: float | None = None
request_end_time: float | None = None

Expand Down
56 changes: 46 additions & 10 deletions src/stirrup/tools/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,32 @@
from json_schema_to_pydantic import create_model
from pydantic import BaseModel, Field, model_validator

from stirrup.core.models import Tool, ToolProvider, ToolResult, ToolUseCountMetadata
from stirrup.core.models import (
AudioContentBlock,
Content,
ContentBlock,
ImageContentBlock,
Tool,
ToolProvider,
ToolResult,
ToolUseCountMetadata,
)

# MCP imports (optional dependency)
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import (
AudioContent as MCPAudioContent,
)
from mcp.types import (
ImageContent as MCPImageContent,
)
from mcp.types import (
TextContent as MCPTextContent,
)
except ImportError as e:
raise ImportError(
"Requires installation of the mcp extra. Install with (for example): `uv pip install stirrup[mcp]` or `uv add stirrup[mcp]`",
Expand Down Expand Up @@ -345,7 +363,29 @@ def all_tools(self) -> dict[str, list[str]]:
"""
return {server: [t["name"] for t in tools] for server, tools in self._tools.items()}

async def call_tool(self, server: str, tool_name: str, arguments: dict[str, Any]) -> str:
def _convert_mcp_content(self, content_blocks: list[Any]) -> Content:
    """Translate MCP result blocks into Stirrup's ``Content`` representation.

    Text blocks become plain strings, while image and audio blocks are wrapped
    in the corresponding Stirrup content-block types. An empty result collapses
    to ``""`` and a single text-only result collapses to a bare string; any
    unrecognised MCP block type raises ``TypeError``.
    """
    converted: list[ContentBlock] = []
    for item in content_blocks:
        if isinstance(item, MCPTextContent):
            converted.append(item.text)
        elif isinstance(item, MCPImageContent):
            converted.append(ImageContentBlock(data=item.data))
        elif isinstance(item, MCPAudioContent):
            converted.append(AudioContentBlock(data=item.data))
        else:
            raise TypeError(f"Unsupported MCP content block: {type(item).__name__}")

    # Collapse trivial shapes so callers get the simplest equivalent Content.
    if not converted:
        return ""
    if len(converted) == 1 and isinstance(converted[0], str):
        return converted[0]
    return converted

async def call_tool(self, server: str, tool_name: str, arguments: dict[str, Any]) -> Content:
"""Call a tool on a specific MCP server.

Args:
Expand All @@ -354,7 +394,7 @@ async def call_tool(self, server: str, tool_name: str, arguments: dict[str, Any]
arguments: Arguments to pass to the tool.

Returns:
Tool result as a string (text content extracted from response).
Tool result converted into Stirrup content blocks.

Raises:
ValueError: If server is not connected.
Expand All @@ -364,10 +404,7 @@ async def call_tool(self, server: str, tool_name: str, arguments: dict[str, Any]
raise ValueError(f"Server '{server}' not connected. Available: {self.servers}")

result = await session.call_tool(tool_name, arguments)

# Extract text content from result
text_parts = [str(content.text) for content in result.content if hasattr(content, "text")]
return "\n".join(text_parts)
return self._convert_mcp_content(result.content)

def get_all_tools(self) -> list[Tool[Any, ToolUseCountMetadata]]:
"""Get individual Tool objects for each tool from all connected MCP servers.
Expand Down Expand Up @@ -403,15 +440,14 @@ async def executor(
_tool: str = mcp_tool_name,
) -> ToolResult[ToolUseCountMetadata]:
content = await self.call_tool(_server, _tool, params.model_dump())
xml_content = f"<mcp_result>\n{content}\n</mcp_result>"
return ToolResult(content=xml_content, metadata=ToolUseCountMetadata())
return ToolResult(content=content, metadata=ToolUseCountMetadata())

tools.append(
Tool(
name=unique_name,
description=tool_info.get("description") or f"Tool '{mcp_tool_name}' from {server_name}",
parameters=params_model,
executor=executor, # ty: ignore[invalid-argument-type]
executor=executor,
)
)

Expand Down
50 changes: 50 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""Tests for agent core functionality."""

from io import BytesIO

from PIL import Image
from pydantic import BaseModel

from stirrup.constants import FINISH_TOOL_NAME
from stirrup.core.agent import Agent
from stirrup.core.models import (
AssistantMessage,
ChatMessage,
ImageContentBlock,
LLMClient,
SummaryMessage,
SystemMessage,
Expand Down Expand Up @@ -42,23 +46,30 @@
return response


def _sample_png_block() -> ImageContentBlock:
    """Build an ImageContentBlock holding a real 1x1 red PNG payload."""
    buf = BytesIO()
    Image.new("RGB", (1, 1), color=(255, 0, 0)).save(buf, format="PNG")
    return ImageContentBlock(data=buf.getvalue())


async def test_agent_basic_finish() -> None:
"""Test agent completes successfully when finish tool is called."""
# Create mock responses
responses = [
AssistantMessage(
content="I'll finish now",
tool_calls=[
ToolCall(
name=FINISH_TOOL_NAME,
arguments='{"reason": "Task completed successfully", "paths": []}',
tool_call_id="call_1",
)
],
token_usage=TokenUsage(input=100, answer=50),
request_start_time=100.0,
request_end_time=100.4,
)

Check failure on line 72 in tests/test_agent.py

View workflow job for this annotation

GitHub Actions / type-check

ty (missing-argument)

tests/test_agent.py:60:9: missing-argument: No argument provided for required parameter `metadata`
]

# Create agent with mock client
Expand Down Expand Up @@ -95,11 +106,11 @@
"""Test agent stops after max_turns is reached."""
# Create mock responses (never calls finish)
responses = [
AssistantMessage(
content=f"Turn {i}",
tool_calls=[],
token_usage=TokenUsage(input=100, answer=50),
)

Check failure on line 113 in tests/test_agent.py

View workflow job for this annotation

GitHub Actions / type-check

ty (missing-argument)

tests/test_agent.py:109:9: missing-argument: No argument provided for required parameter `metadata`
for i in range(5)
]

Expand Down Expand Up @@ -149,29 +160,29 @@
# Create mock responses
responses = [
# First turn: call echo tool
AssistantMessage(
content="I'll echo your message",
tool_calls=[
ToolCall(
name="echo",
arguments='{"message": "Hello"}',
tool_call_id="call_1",
)
],
token_usage=TokenUsage(input=100, answer=50),
),

Check failure on line 173 in tests/test_agent.py

View workflow job for this annotation

GitHub Actions / type-check

ty (missing-argument)

tests/test_agent.py:163:9: missing-argument: No argument provided for required parameter `metadata`
# Second turn: finish
AssistantMessage(
content="Done",
tool_calls=[
ToolCall(
name=FINISH_TOOL_NAME,
arguments='{"reason": "Echoed successfully", "paths": []}',
tool_call_id="call_2",
)
],
token_usage=TokenUsage(input=100, answer=50),
),

Check failure on line 185 in tests/test_agent.py

View workflow job for this annotation

GitHub Actions / type-check

ty (missing-argument)

tests/test_agent.py:175:9: missing-argument: No argument provided for required parameter `metadata`
]

# Create agent with mock client
Expand Down Expand Up @@ -212,34 +223,73 @@
assert "Echo: Hello" in echo_messages[0].content


async def test_run_tool_preserves_image_content() -> None:
    """run_tool must hand back image blocks from a tool without flattening them."""

    class EmptyParams(BaseModel):
        pass

    png_block = _sample_png_block()

    def image_executor(_params: EmptyParams) -> ToolResult:
        # Return the image wrapped in a one-element content list.
        return ToolResult(content=[png_block])

    image_tool = Tool[EmptyParams, None](
        name="image_tool",
        description="Return an image",
        parameters=EmptyParams,
        executor=image_executor,  # ty: ignore[invalid-argument-type]
    )

    agent = Agent(
        client=MockLLMClient([]),
        name="test-agent",
        max_turns=1,
        tools=[image_tool],
        finish_tool=SIMPLE_FINISH_TOOL,
    )

    call = ToolCall(name="image_tool", arguments="{}", tool_call_id="call_1")
    async with agent.session() as session:
        tool_message = await session.run_tool(call, run_metadata={})

    content = tool_message.content
    assert isinstance(content, list)
    assert len(content) == 1
    first = content[0]
    assert isinstance(first, ImageContentBlock)
    assert first.mime_type == "image/png"


async def test_agent_invalid_tool_call() -> None:
"""Test agent handles invalid tool calls gracefully."""
# Create mock responses
responses = [
# Call non-existent tool
AssistantMessage(
content="I'll call a tool",
tool_calls=[
ToolCall(
name="nonexistent_tool",
arguments='{"param": "value"}',
tool_call_id="call_1",
)
],
token_usage=TokenUsage(input=100, answer=50),
),

Check failure on line 280 in tests/test_agent.py

View workflow job for this annotation

GitHub Actions / type-check

ty (missing-argument)

tests/test_agent.py:270:9: missing-argument: No argument provided for required parameter `metadata`
# Then finish
AssistantMessage(
content="Done",
tool_calls=[
ToolCall(
name=FINISH_TOOL_NAME,
arguments='{"reason": "Handled error", "paths": []}',
tool_call_id="call_2",
)
],
token_usage=TokenUsage(input=100, answer=50),
),

Check failure on line 292 in tests/test_agent.py

View workflow job for this annotation

GitHub Actions / type-check

ty (missing-argument)

tests/test_agent.py:282:9: missing-argument: No argument provided for required parameter `metadata`
]

# Create agent with mock client
Expand Down Expand Up @@ -305,17 +355,17 @@
# Create mock responses
responses = [
# First: invalid finish (status != "complete")
AssistantMessage(
content="Trying to finish",
tool_calls=[
ToolCall(
name=FINISH_TOOL_NAME,
arguments='{"reason": "Not ready", "status": "pending"}',
tool_call_id="call_1",
)
],
token_usage=TokenUsage(input=100, answer=50),
),

Check failure on line 368 in tests/test_agent.py

View workflow job for this annotation

GitHub Actions / type-check

ty (missing-argument)

tests/test_agent.py:358:9: missing-argument: No argument provided for required parameter `metadata`
# Second: valid finish (status == "complete")
AssistantMessage(
content="Now finishing",
Expand Down
37 changes: 37 additions & 0 deletions tests/test_clients_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Tests for OpenAI client utility helpers."""

from stirrup.clients.utils import to_openai_messages
from stirrup.core.models import AssistantMessage, TokenUsage


def test_assistant_message_generates_id() -> None:
    """Every AssistantMessage receives a distinct, non-empty auto-generated id."""
    messages = [
        AssistantMessage(content=text, tool_calls=[], token_usage=TokenUsage(), metadata={})
        for text in ("Hello", "Hello again")
    ]

    # Ids are populated by the default factory and unique per instance.
    assert all(m.id for m in messages)
    assert messages[0].id != messages[1].id


def test_to_openai_messages_forwards_assistant_metadata() -> None:
    """Assistant metadata must survive conversion to the OpenAI message format."""
    metadata = {"source": "cache", "attempt": 2}
    message = AssistantMessage(
        content="Hello",
        tool_calls=[],
        token_usage=TokenUsage(),
        metadata=metadata,
    )

    (converted,) = to_openai_messages([message])

    assert converted == {
        "role": "assistant",
        "content": [{"type": "text", "text": "Hello"}],
        "metadata": metadata,
    }
108 changes: 108 additions & 0 deletions tests/test_mcp_image_smoke.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Smoke test for MCP image tool results."""

import base64
import inspect
import sys
from io import BytesIO
from pathlib import Path
from typing import cast

import pytest
from PIL import Image

from stirrup.clients.utils import to_openai_messages
from stirrup.core.models import ImageContentBlock, ToolMessage, ToolResult, ToolUseCountMetadata
from stirrup.tools.mcp import MCPConfig, MCPToolProvider

pytest.importorskip("mcp.server.fastmcp")


def _image_b64(image_format: str) -> str:
    """Encode a tiny real 1x1 red image in *image_format* as base64 ASCII."""
    buffer = BytesIO()
    Image.new("RGB", (1, 1), color=(255, 0, 0)).save(buffer, format=image_format)
    return base64.b64encode(buffer.getvalue()).decode("ascii")


def _write_image_server(script_path: Path, image_format: str, tool_name: str) -> None:
    """Write a one-file stdio MCP server with a single image-returning tool.

    Args:
        script_path: Destination file for the generated server script.
        image_format: PIL image format name (e.g. "PNG", "WEBP"); its lowercase
            form is also used as the MCP ``Image`` format string.
        tool_name: Name of the single tool the generated server exposes.
    """
    # The image payload is generated now and baked into the script as base64,
    # so the spawned server needs no filesystem access of its own.
    image_b64 = _image_b64(image_format)
    image_ext = image_format.lower()
    script_path.write_text(
        f"""
import base64

from mcp.server.fastmcp import FastMCP, Image

mcp = FastMCP("image-server")


@mcp.tool()
def {tool_name}() -> Image:
    return Image(data=base64.b64decode("{image_b64}"), format="{image_ext}")


if __name__ == "__main__":
    mcp.run(transport="stdio")
""".strip()
    )


def _make_provider(script_path: Path) -> MCPToolProvider:
    """Create a provider that launches the temp MCP server over stdio."""
    server_spec = {
        "command": sys.executable,
        "args": [str(script_path)],
    }
    config = MCPConfig.model_validate({"mcpServers": {"image_server": server_spec}})
    return MCPToolProvider(config=config)


async def _assert_tool_returns_image(
    tmp_path: Path,
    *,
    image_format: str,
    tool_name: str,
) -> None:
    """Assert the MCP bridge preserves an image through OpenAI-style serialization."""
    server_script = tmp_path / "image_server.py"
    _write_image_server(server_script, image_format=image_format, tool_name=tool_name)

    qualified_name = f"image_server__{tool_name}"
    provider = _make_provider(server_script)
    async with provider as tools:
        tool = next(t for t in tools if t.name == qualified_name)
        maybe_coro = tool.executor(tool.parameters())
        # Executors may be sync or async; normalize before inspecting the result.
        if inspect.isawaitable(maybe_coro):
            maybe_coro = await maybe_coro
        result = cast(ToolResult[ToolUseCountMetadata], maybe_coro)

        # Step 1: the bridge must yield exactly one real Stirrup image block.
        blocks = result.content
        assert isinstance(blocks, list)
        assert len(blocks) == 1
        assert isinstance(blocks[0], ImageContentBlock)
        assert blocks[0].mime_type == f"image/{image_format.lower()}"

        # Step 2: the image must survive serialization into the model layer.
        tool_message = ToolMessage(
            content=blocks,
            tool_call_id="call_1",
            name=qualified_name,
        )
        messages = to_openai_messages([tool_message])
        assert messages[0]["content"][0]["type"] == "image_url"


async def test_mcp_png_result_reaches_openai_message(tmp_path: Path) -> None:
    """Smoke-check that a PNG MCP tool result survives to an OpenAI image part."""
    await _assert_tool_returns_image(tmp_path, image_format="PNG", tool_name="read_png")


async def test_mcp_webp_result_reaches_openai_message(tmp_path: Path) -> None:
    """Smoke-check that a WebP MCP tool result survives to an OpenAI image part."""
    await _assert_tool_returns_image(tmp_path, image_format="WEBP", tool_name="read_webp")
Loading