Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 52 additions & 51 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,69 +741,70 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:
if isinstance(resp, CallToolResult):
res = resp
_final_resp = resp
if isinstance(res.content[0], TextContent):
if not res.content:
_append_tool_call_result(
func_tool_id,
res.content[0].text,
"The tool returned no content.",
)
elif isinstance(res.content[0], ImageContent):
# Cache the image instead of sending directly
cached_img = tool_image_cache.save_image(
base64_data=res.content[0].data,
tool_call_id=func_tool_id,
tool_name=func_tool_name,
index=0,
mime_type=res.content[0].mimeType or "image/png",
)
_append_tool_call_result(
func_tool_id,
(
f"Image returned and cached at path='{cached_img.file_path}'. "
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
f"with type='image' and path='{cached_img.file_path}'."
),
)
# Yield image info for LLM visibility (will be handled in step())
yield _HandleFunctionToolsResult.from_cached_image(
cached_img
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
_append_tool_call_result(
func_tool_id,
resource.text,
)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
continue

result_parts: list[str] = []
for index, content_item in enumerate(res.content):
if isinstance(content_item, TextContent):
result_parts.append(content_item.text)
elif isinstance(content_item, ImageContent):
# Cache the image instead of sending directly
cached_img = tool_image_cache.save_image(
base64_data=resource.blob,
base64_data=content_item.data,
tool_call_id=func_tool_id,
tool_name=func_tool_name,
index=0,
mime_type=resource.mimeType,
index=index,
mime_type=content_item.mimeType or "image/png",
)
_append_tool_call_result(
func_tool_id,
(
f"Image returned and cached at path='{cached_img.file_path}'. "
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
f"with type='image' and path='{cached_img.file_path}'."
),
result_parts.append(
f"Image returned and cached at path='{cached_img.file_path}'. "
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
f"with type='image' and path='{cached_img.file_path}'."
)
# Yield image info for LLM visibility
# Yield image info for LLM visibility (will be handled in step())
yield _HandleFunctionToolsResult.from_cached_image(
cached_img
)
else:
_append_tool_call_result(
func_tool_id,
"The tool has returned a data type that is not supported.",
)
elif isinstance(content_item, EmbeddedResource):
resource = content_item.resource
if isinstance(resource, TextResourceContents):
result_parts.append(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
# Cache the image instead of sending directly
cached_img = tool_image_cache.save_image(
base64_data=resource.blob,
tool_call_id=func_tool_id,
tool_name=func_tool_name,
index=index,
mime_type=resource.mimeType,
)
result_parts.append(
f"Image returned and cached at path='{cached_img.file_path}'. "
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
f"with type='image' and path='{cached_img.file_path}'."
)
# Yield image info for LLM visibility
yield _HandleFunctionToolsResult.from_cached_image(
cached_img
)
else:
result_parts.append(
"The tool has returned a data type that is not supported."
)
if result_parts:
_append_tool_call_result(
func_tool_id,
"\n\n".join(result_parts),
)

elif resp is None:
# Tool 直接请求发送消息给用户
Expand Down
83 changes: 83 additions & 0 deletions tests/test_tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,29 @@ async def generator():
return generator()


class MockMixedContentToolExecutor:
"""模拟返回图片 + 文本的工具执行器"""

@classmethod
def execute(cls, tool, run_context, **tool_args):
async def generator():
from mcp.types import CallToolResult, ImageContent, TextContent

result = CallToolResult(
content=[
ImageContent(
type="image",
data="dGVzdA==",
mimeType="image/png",
),
TextContent(type="text", text="直播间标题:新游首发:零~红蝶~"),
]
)
yield result

return generator()


class MockFailingProvider(MockProvider):
async def text_chat(self, **kwargs) -> LLMResponse:
self.call_count += 1
Expand Down Expand Up @@ -438,6 +461,66 @@ async def test_hooks_called_with_max_step(
assert mock_hooks.tool_end_called, "on_tool_end应该被调用"


@pytest.mark.asyncio
async def test_tool_result_includes_all_calltoolresult_content(
runner, mock_provider, provider_request, mock_hooks, monkeypatch
):
"""工具返回多个 content 项时,tool result 应包含全部内容。"""

from astrbot.core.agent.tool_image_cache import tool_image_cache

mock_provider.should_call_tools = True
mock_provider.max_calls_before_normal_response = 1

saved_images = []

def fake_save_image(
base64_data, tool_call_id, tool_name, index=0, mime_type="image/png"
):
saved_images.append(
{
"base64_data": base64_data,
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"index": index,
"mime_type": mime_type,
}
)
return SimpleNamespace(file_path=f"/tmp/{tool_call_id}_{index}.png")

monkeypatch.setattr(tool_image_cache, "save_image", fake_save_image)

await runner.reset(
provider=mock_provider,
request=provider_request,
run_context=ContextWrapper(context=None),
tool_executor=MockMixedContentToolExecutor,
agent_hooks=mock_hooks,
streaming=False,
)

async for _ in runner.step_until_done(3):
pass

tool_messages = [
m for m in runner.run_context.messages if getattr(m, "role", None) == "tool"
]
assert len(tool_messages) == 1

content = str(tool_messages[0].content)
assert "Image returned and cached at path='/tmp/call_123_0.png'." in content
assert "直播间标题:新游首发:零~红蝶~" in content
assert saved_images == [
{
"base64_data": "dGVzdA==",
"tool_call_id": "call_123",
"tool_name": "test_tool",
"index": 0,
"mime_type": "image/png",
}
]


@pytest.mark.asyncio
async def test_fallback_provider_used_when_primary_raises(
runner, provider_request, mock_tool_executor, mock_hooks
Expand Down
Loading