diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 7031d52fc8..b5f291f8e1 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -167,71 +167,39 @@ async def pop_record(self, context: list) -> None: def _has_tool_calls(message: dict) -> bool: return bool(message.get("tool_calls")) - def _first_non_system_index() -> int | None: + def _next_unit_bounds() -> tuple[int, int] | None: for idx, record in enumerate(context): if record.get("role") != "system": - return idx + end_idx = idx + role = record.get("role") + if role == "assistant" and _has_tool_calls(record): + # Keep assistant(tool_calls) and following tool messages atomic. + while end_idx + 1 < len(context) and ( + context[end_idx + 1].get("role") == "tool" + ): + end_idx += 1 + elif role == "tool": + # Remove leading orphan tool messages together. + while end_idx + 1 < len(context) and ( + context[end_idx + 1].get("role") == "tool" + ): + end_idx += 1 + return idx, end_idx return None - def _pop_earliest_unit() -> int: - start_idx = _first_non_system_index() - if start_idx is None: - return 0 - - record = context[start_idx] - role = record.get("role") - end_idx = start_idx - - if role == "assistant" and _has_tool_calls(record): - # Keep assistant(tool_calls) and following tool messages atomic. - while end_idx + 1 < len(context) and ( - context[end_idx + 1].get("role") == "tool" - ): - end_idx += 1 - elif role == "tool": - # Remove leading orphan tool messages together. - while end_idx + 1 < len(context) and ( - context[end_idx + 1].get("role") == "tool" - ): - end_idx += 1 - - removed_count = end_idx - start_idx + 1 - del context[start_idx : end_idx + 1] - return removed_count - - def _peek_earliest_unit_count() -> int: - start_idx = _first_non_system_index() - if start_idx is None: - return 0 - - record = context[start_idx] - role = record.get("role") - end_idx = start_idx - if role == "assistant" and _has_tool_calls(record): - while end_idx + 1 < len(context) and ( - context[end_idx + 1].get("role") == "tool" - ): - end_idx += 1 - elif role == "tool": - while end_idx + 1 < len(context) and ( - context[end_idx + 1].get("role") == "tool" - ): - end_idx += 1 - return end_idx - start_idx + 1 - removed = 0 while removed < 2: - next_unit_count = _peek_earliest_unit_count() - if next_unit_count == 0: + next_unit = _next_unit_bounds() + if next_unit is None: break + start_idx, end_idx = next_unit + next_unit_count = end_idx - start_idx + 1 # Keep behavior close to the old "pop around 2 records" strategy, # while still preserving tool-call atomicity. if removed > 0 and removed + next_unit_count > 3: break - removed_now = _pop_earliest_unit() - if removed_now == 0: - break - removed += removed_now + del context[start_idx : end_idx + 1] + removed += next_unit_count def _ensure_message_to_dicts( self, diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index 55b5fb46e5..0f31dea3da 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -1,5 +1,5 @@ -from types import SimpleNamespace from pathlib import Path +from types import SimpleNamespace from urllib.parse import urlparse, urlunparse import pytest