diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py
index fab3ce6104..7031d52fc8 100644
--- a/astrbot/core/provider/provider.py
+++ b/astrbot/core/provider/provider.py
@@ -162,19 +162,55 @@ async def text_chat_stream(
         raise NotImplementedError()
 
     async def pop_record(self, context: list) -> None:
-        """弹出 context 第一条非系统提示词对话记录"""
-        poped = 0
-        indexs_to_pop = []
-        for idx, record in enumerate(context):
-            if record["role"] == "system":
-                continue
-            indexs_to_pop.append(idx)
-            poped += 1
-            if poped == 2:
-                break
-
-        for idx in reversed(indexs_to_pop):
-            context.pop(idx)
+        """Pop the earliest non-system records, keeping tool calls paired.
+
+        Records are removed in units. A unit is a single message, except that
+        an assistant message carrying ``tool_calls`` is removed together with
+        the ``tool`` messages directly after it, and a leading run of orphan
+        ``tool`` messages is removed as a whole. The first unit is always
+        removed; further units are removed until at least two records are
+        gone, unless the next unit would push the total above three.
+
+        Args:
+            context: Message dicts of the conversation; modified in place.
+        """
+
+        def _earliest_unit_span() -> tuple[int, int] | None:
+            """Return the ``(start, end)`` slice bounds of the earliest unit."""
+            start = next(
+                (
+                    idx
+                    for idx, record in enumerate(context)
+                    if record.get("role") != "system"
+                ),
+                None,
+            )
+            if start is None:
+                return None
+
+            first = context[start]
+            role = first.get("role")
+            end = start + 1
+            # Both an assistant(tool_calls) message and a leading orphan tool
+            # message absorb the tool messages that directly follow them.
+            if role == "tool" or (role == "assistant" and first.get("tool_calls")):
+                while end < len(context) and context[end].get("role") == "tool":
+                    end += 1
+            return start, end
+
+        removed = 0
+        while removed < 2:
+            span = _earliest_unit_span()
+            if span is None:
+                break
+            start, end = span
+            unit_size = end - start
+            # Keep behavior close to the old "pop around 2 records" strategy,
+            # while still preserving tool-call atomicity.
+            if removed > 0 and removed + unit_size > 3:
+                break
+            del context[start:end]
+            removed += unit_size
 
     def _ensure_message_to_dicts(
         self,
diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py
index 39bb6d3810..55b5fb46e5 100644
--- a/tests/test_openai_source.py
+++ b/tests/test_openai_source.py
@@ -1,4 +1,6 @@
+from pathlib import Path
 from types import SimpleNamespace
+from urllib.parse import urlparse, urlunparse
 
 import pytest
 from openai.types.chat.chat_completion import ChatCompletion
@@ -244,6 +246,112 @@ async def test_openai_payload_keeps_reasoning_content_in_assistant_history():
         await provider.terminate()
 
 
+@pytest.mark.asyncio
+async def test_pop_record_removes_assistant_tool_calls_with_following_tools_atomically():
+    provider = _make_provider()
+    try:
+        context = [
+            {"role": "system", "content": "system"},
+            {"role": "assistant", "tool_calls": [{"id": "call_1"}], "content": None},
+            {"role": "tool", "tool_call_id": "call_1", "content": "result"},
+            {"role": "user", "content": "keep me"},
+        ]
+
+        await provider.pop_record(context)
+
+        assert context == [
+            {"role": "system", "content": "system"},
+            {"role": "user", "content": "keep me"},
+        ]
+    finally:
+        await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_pop_record_removes_leading_orphan_tool_messages():
+    provider = _make_provider()
+    try:
+        context = [
+            {"role": "system", "content": "system"},
+            {"role": "tool", "tool_call_id": "call_1", "content": "orphan"},
+            {"role": "user", "content": "old user"},
+            {"role": "assistant", "content": "old assistant"},
+            {"role": "user", "content": "new user"},
+        ]
+
+        await provider.pop_record(context)
+
+        assert context == [
+            {"role": "system", "content": "system"},
+            {"role": "assistant", "content": "old assistant"},
+            {"role": "user", "content": "new user"},
+        ]
+    finally:
+        await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_pop_record_normal_messages_no_regression():
+    provider = _make_provider()
+    try:
+        context = [
+            {"role": "system", "content": "system"},
+            {"role": "user", "content": "user1"},
+            {"role": "assistant", "content": "assistant1"},
+            {"role": "user", "content": "user2"},
+            {"role": "assistant", "content": "assistant2"},
+        ]
+
+        await provider.pop_record(context)
+
+        assert context == [
+            {"role": "system", "content": "system"},
+            {"role": "user", "content": "user2"},
+            {"role": "assistant", "content": "assistant2"},
+        ]
+    finally:
+        await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_pop_record_assistant_with_multiple_tool_calls():
+    provider = _make_provider()
+    try:
+        context = [
+            {"role": "system", "content": "system"},
+            {
+                "role": "assistant",
+                "tool_calls": [{"id": "call_1"}, {"id": "call_2"}],
+                "content": None,
+            },
+            {"role": "tool", "tool_call_id": "call_1", "content": "result1"},
+            {"role": "tool", "tool_call_id": "call_2", "content": "result2"},
+            {"role": "user", "content": "keep me"},
+        ]
+
+        await provider.pop_record(context)
+
+        assert context == [
+            {"role": "system", "content": "system"},
+            {"role": "user", "content": "keep me"},
+        ]
+    finally:
+        await provider.terminate()
+
+
+@pytest.mark.asyncio
+async def test_pop_record_only_system_messages():
+    provider = _make_provider()
+    try:
+        context = [{"role": "system", "content": "system"}]
+
+        await provider.pop_record(context)
+
+        assert context == [{"role": "system", "content": "system"}]
+    finally:
+        await provider.terminate()
+
+
 @pytest.mark.asyncio
 async def test_groq_payload_drops_reasoning_content_from_assistant_history():
     provider = _make_groq_provider()
@@ -782,9 +890,8 @@ async def test_prepare_chat_payload_materializes_context_file_uri_image_urls(tmp
 async def test_file_uri_to_path_preserves_windows_drive_letter():
     provider = _make_provider()
     try:
-        assert provider._file_uri_to_path("file:///C:/tmp/quoted-image.png") == (
-            "C:/tmp/quoted-image.png"
-        )
+        resolved = provider._file_uri_to_path("file:///C:/tmp/quoted-image.png")
+        assert Path(resolved) == Path("C:/tmp/quoted-image.png")
     finally:
         await provider.terminate()
 
@@ -793,9 +900,8 @@ async def test_file_uri_to_path_preserves_windows_drive_letter():
 async def test_file_uri_to_path_preserves_windows_netloc_drive_letter():
     provider = _make_provider()
     try:
-        assert provider._file_uri_to_path("file://C:/tmp/quoted-image.png") == (
-            "C:/tmp/quoted-image.png"
-        )
+        resolved = provider._file_uri_to_path("file://C:/tmp/quoted-image.png")
+        assert Path(resolved) == Path("C:/tmp/quoted-image.png")
     finally:
         await provider.terminate()
 
@@ -804,9 +910,8 @@ async def test_file_uri_to_path_preserves_windows_netloc_drive_letter():
 async def test_file_uri_to_path_preserves_remote_netloc_as_unc_path():
     provider = _make_provider()
     try:
-        assert provider._file_uri_to_path("file://server/share/quoted-image.png") == (
-            "//server/share/quoted-image.png"
-        )
+        resolved = provider._file_uri_to_path("file://server/share/quoted-image.png")
+        assert Path(resolved) == Path("//server/share/quoted-image.png")
     finally:
         await provider.terminate()
 
@@ -977,7 +1082,10 @@ async def test_prepare_chat_payload_materializes_context_localhost_file_uri_imag
         image_path = tmp_path / "quoted-image.png"
         PILImage.new("RGBA", (1, 1), (255, 0, 0, 255)).save(image_path)
 
-        localhost_uri = f"file://localhost{image_path.as_posix()}"
+        parsed_local_uri = urlparse(image_path.as_uri())
+        localhost_uri = urlunparse(
+            ("file", "localhost", parsed_local_uri.path, "", "", "")
+        )
         payloads, _ = await provider._prepare_chat_payload(
             prompt=None,
             contexts=[