Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,7 @@ async def _resolve_tool_exec(
func_tool=param_subset,
model=self.req.model,
session_id=self.req.session_id,
extra_user_content_parts=self.req.extra_user_content_parts,
abort_signal=self._abort_signal,
)
if requery_resp:
Expand Down
81 changes: 81 additions & 0 deletions tests/test_tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,87 @@ async def test_follow_up_ticket_not_consumed_when_no_next_tool_call(
assert ticket.consumed is False


@pytest.mark.asyncio
async def test_skills_like_requery_passes_extra_user_content_parts():
"""skills-like 模式 re-query 时应传递 extra_user_content_parts(如 image_caption)"""
from astrbot.core.agent.message import TextPart

captured_kwargs = {}

class SkillsLikeProvider(MockProvider):
async def text_chat(self, **kwargs) -> LLMResponse:
self.call_count += 1
if self.call_count == 1:
# 第一次调用:返回工具选择(light schema)
return LLMResponse(
role="assistant",
completion_text="选择工具",
tools_call_name=["test_tool"],
tools_call_args=[{"query": "test"}],
tools_call_ids=["call_1"],
usage=TokenUsage(input_other=10, output=5),
)
if self.call_count == 2:
# 第二次调用:re-query with param schema
captured_kwargs.update(kwargs)
return LLMResponse(
role="assistant",
completion_text="调用工具",
tools_call_name=["test_tool"],
tools_call_args=[{"query": "actual"}],
tools_call_ids=["call_2"],
usage=TokenUsage(input_other=10, output=5),
)
# 后续调用:正常回复
return LLMResponse(
role="assistant",
completion_text="最终回复",
usage=TokenUsage(input_other=10, output=5),
)

provider = SkillsLikeProvider()
tool = FunctionTool(
name="test_tool",
description="测试",
parameters={"type": "object", "properties": {"query": {"type": "string"}}},
handler=AsyncMock(),
)
tool_set = ToolSet(tools=[tool])

caption_part = TextPart(text="<image_caption>一张猫的照片</image_caption>")
req = ProviderRequest(
prompt="看看这张图",
func_tool=tool_set,
contexts=[],
extra_user_content_parts=[caption_part],
)

event = MockEvent(umo="test_umo", sender_id="test_sender")
ctx = MockAgentContext(event)
run_context = ContextWrapper(context=ctx)
runner = ToolLoopAgentRunner()

await runner.reset(
provider=provider,
request=req,
run_context=run_context,
tool_executor=MockToolExecutor(),
agent_hooks=MockHooks(),
tool_schema_mode="skills_like",
)

async for _ in runner.step():
pass

# 验证 re-query 调用包含了 extra_user_content_parts
assert "extra_user_content_parts" in captured_kwargs, (
"re-query 应该传递 extra_user_content_parts"
)
parts = captured_kwargs["extra_user_content_parts"]
assert len(parts) == 1
assert parts[0].text == "<image_caption>一张猫的照片</image_caption>"


@pytest.mark.asyncio
async def test_follow_up_accepted_when_active_and_not_stopping(
runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
Expand Down
Loading