diff --git a/astrbot/core/agent/context/truncator.py b/astrbot/core/agent/context/truncator.py index 8d1da6f569..afd89f2bed 100644 --- a/astrbot/core/agent/context/truncator.py +++ b/astrbot/core/agent/context/truncator.py @@ -4,19 +4,60 @@ class ContextTruncator: """Context truncator.""" + def _has_tool_calls(self, message: Message) -> bool: + """Check if a message contains tool calls.""" + return ( + message.role == "assistant" + and message.tool_calls is not None + and len(message.tool_calls) > 0 + ) + def fix_messages(self, messages: list[Message]) -> list[Message]: - fixed_messages = [] - for message in messages: - if message.role == "tool": - # tool block 前面必须要有 user 和 assistant block - if len(fixed_messages) < 2: - # 这种情况可能是上下文被截断导致的 - # 我们直接将之前的上下文都清空 - fixed_messages = [] - else: - fixed_messages.append(message) - else: - fixed_messages.append(message) + """修复消息列表,确保 tool call 和 tool response 的配对关系有效。 + + 此方法确保: + 1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息 + 2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应 + + 这是 OpenAI Chat Completions API 规范的要求(Gemini 对此执行严格检查)。 + """ + if not messages: + return messages + + fixed_messages: list[Message] = [] + pending_assistant: Message | None = None + pending_tools: list[Message] = [] + + def flush_pending_if_valid() -> None: + nonlocal pending_assistant, pending_tools + if pending_assistant is not None and pending_tools: + fixed_messages.append(pending_assistant) + fixed_messages.extend(pending_tools) + pending_assistant = None + pending_tools = [] + + for msg in messages: + if msg.role == "tool": + # 只有在有挂起的 assistant(tool_calls) 时才记录 tool 响应 + if pending_assistant is not None: + pending_tools.append(msg) + # else: 孤立的 tool 消息,直接忽略 + continue + + if self._has_tool_calls(msg): + # 遇到新的 assistant(tool_calls) 前,先处理旧的 pending 链 + flush_pending_if_valid() + pending_assistant = msg + continue + + # 非 tool,且不含 tool_calls 的消息 + # 先结束任何 pending 链,再正常追加 + flush_pending_if_valid() + fixed_messages.append(msg) + + # 结束时处理最后一个 pending 链 + flush_pending_if_valid() + return fixed_messages def truncate_by_turns(