From cd613e674f5c0c7f37521059c6928725b2d71e79 Mon Sep 17 00:00:00 2001 From: CompilError-bts <61622360+CompilError-bts@users.noreply.github.com> Date: Tue, 31 Mar 2026 20:44:32 +0800 Subject: [PATCH] refactor: deduplicate unit range detection in pop_record --- astrbot/core/provider/provider.py | 49 +++++++++--- tests/test_openai_source.py | 128 +++++++++++++++++++++++++++--- 2 files changed, 155 insertions(+), 22 deletions(-) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index fab3ce6104..b5f291f8e1 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -162,19 +162,44 @@ async def text_chat_stream( raise NotImplementedError() async def pop_record(self, context: list) -> None: - """弹出 context 第一条非系统提示词对话记录""" - poped = 0 - indexs_to_pop = [] - for idx, record in enumerate(context): - if record["role"] == "system": - continue - indexs_to_pop.append(idx) - poped += 1 - if poped == 2: + """弹出最早的非 system 记录,同时保持 tool_calls 与 tool 配对完整。""" + + def _has_tool_calls(message: dict) -> bool: + return bool(message.get("tool_calls")) + + def _next_unit_bounds() -> tuple[int, int] | None: + for idx, record in enumerate(context): + if record.get("role") != "system": + end_idx = idx + role = record.get("role") + if role == "assistant" and _has_tool_calls(record): + # Keep assistant(tool_calls) and following tool messages atomic. + while end_idx + 1 < len(context) and ( + context[end_idx + 1].get("role") == "tool" + ): + end_idx += 1 + elif role == "tool": + # Remove leading orphan tool messages together. + while end_idx + 1 < len(context) and ( + context[end_idx + 1].get("role") == "tool" + ): + end_idx += 1 + return idx, end_idx + return None + + removed = 0 + while removed < 2: + next_unit = _next_unit_bounds() + if next_unit is None: break - - for idx in reversed(indexs_to_pop): - context.pop(idx) + start_idx, end_idx = next_unit + next_unit_count = end_idx - start_idx + 1 + # Keep behavior close to the old "pop around 2 records" strategy, + # while still preserving tool-call atomicity. + if removed > 0 and removed + next_unit_count > 3: + break + del context[start_idx : end_idx + 1] + removed += next_unit_count def _ensure_message_to_dicts( self, diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index 39bb6d3810..0f31dea3da 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -1,4 +1,6 @@ +from pathlib import Path from types import SimpleNamespace +from urllib.parse import urlparse, urlunparse import pytest from openai.types.chat.chat_completion import ChatCompletion @@ -244,6 +246,112 @@ async def test_openai_payload_keeps_reasoning_content_in_assistant_history(): await provider.terminate() +@pytest.mark.asyncio +async def test_pop_record_removes_assistant_tool_calls_with_following_tools_atomically(): + provider = _make_provider() + try: + context = [ + {"role": "system", "content": "system"}, + {"role": "assistant", "tool_calls": [{"id": "call_1"}], "content": None}, + {"role": "tool", "tool_call_id": "call_1", "content": "result"}, + {"role": "user", "content": "keep me"}, + ] + + await provider.pop_record(context) + + assert context == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "keep me"}, + ] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_pop_record_removes_leading_orphan_tool_messages(): + provider = _make_provider() + try: + context = [ + {"role": "system", "content": "system"}, + {"role": "tool", "tool_call_id": "call_1", "content": "orphan"}, + {"role": "user", "content": "old user"}, + {"role": "assistant", "content": "old assistant"}, + {"role": "user", "content": "new user"}, + ] + + await provider.pop_record(context) + + assert context == [ + {"role": "system", "content": "system"}, + {"role": "assistant", "content": "old assistant"}, + {"role": "user", "content": "new user"}, + ] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_pop_record_normal_messages_no_regression(): + provider = _make_provider() + try: + context = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "user1"}, + {"role": "assistant", "content": "assistant1"}, + {"role": "user", "content": "user2"}, + {"role": "assistant", "content": "assistant2"}, + ] + + await provider.pop_record(context) + + assert context == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "user2"}, + {"role": "assistant", "content": "assistant2"}, + ] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_pop_record_assistant_with_multiple_tool_calls(): + provider = _make_provider() + try: + context = [ + {"role": "system", "content": "system"}, + { + "role": "assistant", + "tool_calls": [{"id": "call_1"}, {"id": "call_2"}], + "content": None, + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result1"}, + {"role": "tool", "tool_call_id": "call_2", "content": "result2"}, + {"role": "user", "content": "keep me"}, + ] + + await provider.pop_record(context) + + assert context == [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "keep me"}, + ] + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_pop_record_only_system_messages(): + provider = _make_provider() + try: + context = [{"role": "system", "content": "system"}] + + await provider.pop_record(context) + + assert context == [{"role": "system", "content": "system"}] + finally: + await provider.terminate() + + @pytest.mark.asyncio async def test_groq_payload_drops_reasoning_content_from_assistant_history(): provider = _make_groq_provider() @@ -782,9 +890,8 @@ async def test_prepare_chat_payload_materializes_context_file_uri_image_urls(tmp async def test_file_uri_to_path_preserves_windows_drive_letter(): provider = _make_provider() try: - assert provider._file_uri_to_path("file:///C:/tmp/quoted-image.png") == ( - "C:/tmp/quoted-image.png" - ) + resolved = provider._file_uri_to_path("file:///C:/tmp/quoted-image.png") + assert Path(resolved) == Path("C:/tmp/quoted-image.png") finally: await provider.terminate() @@ -793,9 +900,8 @@ async def test_file_uri_to_path_preserves_windows_drive_letter(): async def test_file_uri_to_path_preserves_windows_netloc_drive_letter(): provider = _make_provider() try: - assert provider._file_uri_to_path("file://C:/tmp/quoted-image.png") == ( - "C:/tmp/quoted-image.png" - ) + resolved = provider._file_uri_to_path("file://C:/tmp/quoted-image.png") + assert Path(resolved) == Path("C:/tmp/quoted-image.png") finally: await provider.terminate() @@ -804,9 +910,8 @@ async def test_file_uri_to_path_preserves_windows_netloc_drive_letter(): async def test_file_uri_to_path_preserves_remote_netloc_as_unc_path(): provider = _make_provider() try: - assert provider._file_uri_to_path("file://server/share/quoted-image.png") == ( - "//server/share/quoted-image.png" - ) + resolved = provider._file_uri_to_path("file://server/share/quoted-image.png") + assert Path(resolved) == Path("//server/share/quoted-image.png") finally: await provider.terminate() @@ -977,7 +1082,10 @@ async def test_prepare_chat_payload_materializes_context_localhost_file_uri_imag image_path = tmp_path / "quoted-image.png" PILImage.new("RGBA", (1, 1), (255, 0, 0, 255)).save(image_path) - localhost_uri = f"file://localhost{image_path.as_posix()}" + parsed_local_uri = urlparse(image_path.as_uri()) + localhost_uri = urlunparse( + ("file", "localhost", parsed_local_uri.path, "", "", "") + ) payloads, _ = await provider._prepare_chat_payload( prompt=None, contexts=[