diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index f789c97dd3..cb7e22f341 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -317,7 +317,7 @@ def follow_up( message_text: str, ) -> FollowUpTicket | None: """Queue a follow-up message for the next tool result.""" - if self.done(): + if self.done() or self._is_stop_requested(): return None text = (message_text or "").strip() if not text: diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index f1796600f6..f983c5b4d4 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -699,6 +699,283 @@ async def test_follow_up_ticket_not_consumed_when_no_next_tool_call( assert ticket.consumed is False +@pytest.mark.asyncio +async def test_follow_up_accepted_when_active_and_not_stopping( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + """Test that follow-up is accepted when runner is active and stop is not requested.""" + + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + # Runner is active (not done) and stop is not requested + assert not runner.done() + assert runner._stop_requested is False + + ticket = runner.follow_up(message_text="valid follow-up message") + + assert ticket is not None, "Follow-up should be accepted when runner is active and not stopping" + assert ticket.text == "valid follow-up message" + assert ticket.consumed is False + assert ticket in runner._pending_follow_ups + + +@pytest.mark.asyncio +async def test_follow_up_rejected_when_stop_requested( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + """Test that follow-up is rejected when stop has been requested.""" + + 
await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + # Request stop + runner.request_stop() + assert runner._stop_requested is True + + ticket = runner.follow_up(message_text="follow-up after stop") + + assert ticket is None, "Follow-up should be rejected after stop is requested" + assert len(runner._pending_follow_ups) == 0 + + +@pytest.mark.asyncio +async def test_follow_up_rejected_when_runner_done( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + """Test that follow-up is rejected when runner is done.""" + + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + # Run to completion + async for _ in runner.step_until_done(10): + pass + + # Runner should be done + assert runner.done() + + ticket = runner.follow_up(message_text="follow-up after done") + + assert ticket is None, "Follow-up should be rejected when runner is done" + + +@pytest.mark.asyncio +async def test_follow_up_rejected_after_stop_before_tool_call( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + """Test that follow-ups submitted after stop are not merged into tool results.""" + + mock_event = MockEvent("test:FriendMessage:stop_race", "u1") + run_context = ContextWrapper(context=MockAgentContext(mock_event)) + + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=run_context, + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + # Add a follow-up before stop + ticket_before_stop = runner.follow_up(message_text="before stop") + assert ticket_before_stop is not None + + # Request stop + runner.request_stop() + + # Try to add a follow-up after stop + ticket_after_stop 
= runner.follow_up(message_text="after stop") + assert ticket_after_stop is None, "Follow-up after stop should be rejected" + + # Verify only the pre-stop follow-up is in the queue + assert len(runner._pending_follow_ups) == 1 + assert runner._pending_follow_ups[0].text == "before stop" + + +@pytest.mark.asyncio +async def test_follow_up_merged_into_tool_result_before_stop( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + """Test that follow-ups queued before stop are merged into tool results.""" + + mock_event = MockEvent("test:FriendMessage:merge_before_stop", "u1") + run_context = ContextWrapper(context=MockAgentContext(mock_event)) + + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=run_context, + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + # Queue follow-ups before stop + ticket1 = runner.follow_up(message_text="follow up 1 before stop") + ticket2 = runner.follow_up(message_text="follow up 2 before stop") + assert ticket1 is not None + assert ticket2 is not None + + # Run the agent step (should execute tool and merge follow-ups) + async for _ in runner.step(): + pass + + # Verify follow-ups were merged into tool result + assert provider_request.tool_calls_result is not None + assert isinstance(provider_request.tool_calls_result, list) + assert provider_request.tool_calls_result + tool_result = str( + provider_request.tool_calls_result[0].tool_calls_result[0].content + ) + + # Should contain the follow-up notice + assert "SYSTEM NOTICE" in tool_result + assert "follow up 1 before stop" in tool_result + assert "follow up 2 before stop" in tool_result + + # Tickets should be marked as consumed + assert ticket1.consumed is True + assert ticket2.consumed is True + + +@pytest.mark.asyncio +async def test_follow_up_rejected_and_runner_stops_without_execution( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + """Test that when 
stop is requested before execution, follow-ups are rejected and runner stops gracefully.""" + + mock_event = MockEvent("test:FriendMessage:stop_before_execution", "u1") + run_context = ContextWrapper(context=MockAgentContext(mock_event)) + + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=run_context, + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + # Request stop before any execution (simulates /stop command received at start) + runner.request_stop() + assert runner._stop_requested is True + + # Try to add follow-up after stop (should be rejected) + ticket_after = runner.follow_up(message_text="follow-up after stop") + assert ticket_after is None, "Post-stop follow-up should be rejected" + + # Verify queue is empty + assert len(runner._pending_follow_ups) == 0 + + # Run the agent step - should stop immediately without executing tools + async for response in runner.step(): + # Should yield an aborted response + if response.type == "aborted": + break + + # Verify runner stopped gracefully + assert runner.done() + assert runner.was_aborted() + + # No tool execution should have occurred + assert provider_request.tool_calls_result is None + + +@pytest.mark.asyncio +async def test_follow_up_after_stop_not_merged_into_tool_result( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + """Regression test for issue #6626: verify post-stop follow-ups are not injected into tool results. + + This test simulates the race condition where: + 1. Runner is active and executing tools + 2. A follow-up is queued (should be included in tool result) + 3. Stop is requested + 4. Another follow-up is attempted (should be rejected) + 5. Tool execution completes and merges follow-ups into result + + The key assertion is that only pre-stop follow-ups are merged into the tool result. 
+ """ + + mock_event = MockEvent("test:FriendMessage:regression_6626", "u1") + run_context = ContextWrapper(context=MockAgentContext(mock_event)) + + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=run_context, + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + # Add a follow-up before stop (should be included in tool result) + ticket_before = runner.follow_up(message_text="valid before stop") + assert ticket_before is not None + assert ticket_before in runner._pending_follow_ups + + # Request stop (simulates /stop command during active execution) + runner.request_stop() + assert runner._stop_requested is True + + # Try to add follow-up after stop (should be rejected) + ticket_after = runner.follow_up(message_text="invalid after stop") + assert ticket_after is None, "Post-stop follow-up should be rejected" + + # Verify queue only contains pre-stop follow-up + assert len(runner._pending_follow_ups) == 1 + assert runner._pending_follow_ups[0].text == "valid before stop" + + # Run the agent step - this will execute tool and merge follow-ups into result + async for response in runner.step(): + # The runner should execute tools and then stop + pass + + # Verify tool result was created with follow-up merged + # Note: When stop is requested, the tool may or may not execute depending on timing. + # The key assertion is that IF tool_calls_result exists, it only contains pre-stop follow-ups. 
+ if provider_request.tool_calls_result is not None: + assert isinstance(provider_request.tool_calls_result, list) + assert provider_request.tool_calls_result + tool_result = str( + provider_request.tool_calls_result[0].tool_calls_result[0].content + ) + + # Should contain the pre-stop follow-up + assert "valid before stop" in tool_result + + # Should NOT contain the post-stop follow-up + assert "invalid after stop" not in tool_result + assert "after stop" not in tool_result + + # Ticket should be marked as consumed (merged into tool result) + assert ticket_before.consumed is True + else: + # If tool execution was aborted by stop, the ticket should still be resolved + # but not consumed (since there was no tool call to merge into) + assert ticket_before.resolved.is_set() + + if __name__ == "__main__": # 运行测试 pytest.main([__file__, "-v"])