From 49d79de55b3332c3bc9871b210277a635d2e4893 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+universeplayer@users.noreply.github.com> Date: Wed, 18 Mar 2026 23:52:37 +0800 Subject: [PATCH 1/2] =?UTF-8?q?fix:=20=E6=88=AA=E6=96=AD=E5=99=A8=E4=B8=A2?= =?UTF-8?q?=E5=A4=B1=E5=94=AF=E4=B8=80=20user=20=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E5=AF=BC=E8=87=B4=20API=20400?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复 #6196 当对话只有一条 user 消息(长 tool chain 场景:system → user → assistant → tool → assistant → tool → ...),三个截断方法都会把这条 user 消息丢掉, 导致智谱、Gemini 等要求 user 消息的 provider 返回 400。 改动: - 提取 `_split_system_rest()` 去掉三个方法里重复的 system/non-system 拆分 - 新增 `_ensure_user_message()`:截断后如果没有 user 了,从原始消息里补回 第一条 user,避免违反 API 格式要求 - 删掉 `truncate_by_dropping_oldest_turns` 里把没有 user 就清空全部消息的逻辑 - 5 个新测试覆盖单 user + 长 tool chain 场景,3 个旧测试更新断言 --- astrbot/core/agent/context/truncator.py | 81 +++++++++++++++---------- tests/agent/test_truncator.py | 72 ++++++++++++++++++++-- 2 files changed, 114 insertions(+), 39 deletions(-) diff --git a/astrbot/core/agent/context/truncator.py b/astrbot/core/agent/context/truncator.py index afd89f2bed..5b2290a6e7 100644 --- a/astrbot/core/agent/context/truncator.py +++ b/astrbot/core/agent/context/truncator.py @@ -12,6 +12,39 @@ def _has_tool_calls(self, message: Message) -> bool: and len(message.tool_calls) > 0 ) + @staticmethod + def _split_system_rest( + messages: list[Message], + ) -> tuple[list[Message], list[Message]]: + """把 system 消息和后面的对话消息分开。""" + first_non_system = 0 + for i, msg in enumerate(messages): + if msg.role != "system": + first_non_system = i + break + return messages[:first_non_system], messages[first_non_system:] + + @staticmethod + def _ensure_user_message( + system_messages: list[Message], + truncated: list[Message], + original_messages: list[Message], + ) -> list[Message]: + """截断后如果没有 user 消息了,从原始列表里把第一条 user 补回来。 + 很多 provider (智谱、Gemini 等) 要求 system 之后必须紧跟 user,否则直接 400。 + """ + if any(m.role == "user" for m in truncated): + return system_messages + truncated + + # 从原始消息里找第一条 user + first_user = next( + (m for m in original_messages if m.role == "user"), None + ) + if first_user is None: + return system_messages + truncated + + return system_messages + [first_user] + truncated + def fix_messages(self, messages: list[Message]) -> list[Message]: """修复消息列表,确保 tool call 和 tool response 的配对关系有效。 @@ -81,14 +114,7 @@ def truncate_by_turns( if keep_most_recent_turns == -1: return messages - first_non_system = 0 - for i, msg in enumerate(messages): - if msg.role != "system": - first_non_system = i - break - - system_messages = messages[:first_non_system] - non_system_messages = messages[first_non_system:] + system_messages, non_system_messages = self._split_system_rest(messages) if len(non_system_messages) // 2 <= keep_most_recent_turns: return messages @@ -99,7 +125,7 @@ def truncate_by_turns( else: truncated_contexts = non_system_messages[-num_to_keep * 2 :] - # 找到第一个 role 为 user 的索引,确保上下文格式正确 + # 对齐到第一条 user 消息 index = next( (i for i, item in enumerate(truncated_contexts) if item.role == "user"), None, @@ -107,8 +133,9 @@ def truncate_by_turns( if index is not None and index > 0: truncated_contexts = truncated_contexts[index:] - result = system_messages + truncated_contexts - + result = self._ensure_user_message( + system_messages, truncated_contexts, messages + ) return self.fix_messages(result) def truncate_by_dropping_oldest_turns( @@ -120,31 +147,24 @@ def truncate_by_dropping_oldest_turns( if drop_turns <= 0: return messages - first_non_system = 0 - for i, msg in enumerate(messages): - if msg.role != "system": - first_non_system = i - break - - system_messages = messages[:first_non_system] - non_system_messages = messages[first_non_system:] + system_messages, non_system_messages = self._split_system_rest(messages) if len(non_system_messages) // 2 <= drop_turns: truncated_non_system = [] else: truncated_non_system = non_system_messages[drop_turns * 2 :] + # 对齐到第一条 user index = next( (i for i, item in enumerate(truncated_non_system) if item.role == "user"), None, ) if index is not None: truncated_non_system = truncated_non_system[index:] - elif truncated_non_system: - truncated_non_system = [] - - result = system_messages + truncated_non_system + result = self._ensure_user_message( + system_messages, truncated_non_system, messages + ) return self.fix_messages(result) def truncate_by_halving( @@ -155,14 +175,7 @@ def truncate_by_halving( if len(messages) <= 2: return messages - first_non_system = 0 - for i, msg in enumerate(messages): - if msg.role != "system": - first_non_system = i - break - - system_messages = messages[:first_non_system] - non_system_messages = messages[first_non_system:] + system_messages, non_system_messages = self._split_system_rest(messages) messages_to_delete = len(non_system_messages) // 2 if messages_to_delete == 0: @@ -170,6 +183,7 @@ def truncate_by_halving( truncated_non_system = non_system_messages[messages_to_delete:] + # 对齐到第一条 user index = next( (i for i, item in enumerate(truncated_non_system) if item.role == "user"), None, @@ -177,6 +191,7 @@ def truncate_by_halving( if index is not None: truncated_non_system = truncated_non_system[index:] - result = system_messages + truncated_non_system - + result = self._ensure_user_message( + system_messages, truncated_non_system, messages + ) return self.fix_messages(result) diff --git a/tests/agent/test_truncator.py b/tests/agent/test_truncator.py index c85fb7d558..7dac80f9ce 100644 --- a/tests/agent/test_truncator.py +++ b/tests/agent/test_truncator.py @@ -104,8 +104,9 @@ def test_truncate_by_turns_zero_keep(self): messages, keep_most_recent_turns=0, drop_turns=1 ) - # Should result in empty or minimal list - assert len(result) == 0 + # 截断后至少保留一条 user 消息 (#6196) + assert len(result) >= 1 + assert result[0].role == "user" def test_truncate_by_turns_below_threshold(self): """Test truncate_by_turns when messages are below threshold.""" @@ -201,8 +202,9 @@ def test_truncate_by_dropping_oldest_turns_drop_all(self): messages = self.create_messages(4) result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2) - # Should drop all turns - assert len(result) == 0 + # 即使 drop 掉所有 turn,也会把 user 消息补回来 (#6196) + assert len(result) >= 1 + assert result[0].role == "user" def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self): """Test truncate_by_dropping_oldest_turns with drop_turns > available turns.""" @@ -211,8 +213,9 @@ def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self): messages = self.create_messages(4) result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5) - # Should result in empty list - assert len(result) == 0 + # 同理,user 消息会被保留 (#6196) + assert len(result) >= 1 + assert result[0].role == "user" def test_truncate_by_dropping_oldest_turns_ensures_user_first(self): """Test that result starts with user message after dropping.""" @@ -372,3 +375,60 @@ def test_all_system_messages(self): assert len(result) >= 0 # May keep system messages or clear all if len(result) > 0: assert all(msg.role == "system" for msg in result) + + # ==================== #6196: 长 tool chain 只有一条 user 消息 ==================== + + def _build_tool_chain(self, tool_rounds: int = 20) -> list[Message]: + """构造 system -> user -> (assistant -> tool) * N 的长链,只有一条 user。""" + msgs = [ + self.create_message("system", "You are a helpful assistant."), + self.create_message("user", "帮我查一下天气"), + ] + for i in range(tool_rounds): + msgs.append(self.create_message("assistant", f"调用工具 {i}")) + msgs.append(self.create_message("tool", f"工具结果 {i}")) + return msgs + + def test_drop_oldest_preserves_sole_user(self): + """#6196: drop 1 turn 不应丢掉唯一的 user 消息。""" + truncator = ContextTruncator() + msgs = self._build_tool_chain(20) # 1 system + 1 user + 40 asst/tool = 42 + result = truncator.truncate_by_dropping_oldest_turns(msgs, drop_turns=1) + roles = [m.role for m in result] + assert "user" in roles, "唯一的 user 消息被丢掉了" + assert roles[0] == "system" + + def test_halving_preserves_sole_user(self): + """#6196: 对半砍不应丢掉唯一的 user 消息。""" + truncator = ContextTruncator() + msgs = self._build_tool_chain(20) + result = truncator.truncate_by_halving(msgs) + roles = [m.role for m in result] + assert "user" in roles, "唯一的 user 消息被丢掉了" + + def test_truncate_by_turns_preserves_sole_user(self): + """#6196: keep_most_recent_turns 也不应丢掉唯一的 user 消息。""" + truncator = ContextTruncator() + msgs = self._build_tool_chain(20) + result = truncator.truncate_by_turns( + msgs, keep_most_recent_turns=3, drop_turns=1 + ) + roles = [m.role for m in result] + assert "user" in roles, "唯一的 user 消息被丢掉了" + + def test_drop_oldest_heavy_drops_still_has_user(self): + """#6196: 大量 drop 也不会丢 user。""" + truncator = ContextTruncator() + msgs = self._build_tool_chain(30) + result = truncator.truncate_by_dropping_oldest_turns(msgs, drop_turns=10) + roles = [m.role for m in result] + assert "user" in roles + + def test_normal_multi_user_not_affected(self): + """正常多 user 对话不受影响。""" + truncator = ContextTruncator() + msgs = self.create_messages(20, include_system=True) + result_before = truncator.truncate_by_dropping_oldest_turns(msgs, drop_turns=2) + # 多 user 场景下截断后仍有 user + roles = [m.role for m in result_before] + assert "user" in roles From 6bd12006779bcb9943c59ac6935e563e98a51f7e Mon Sep 17 00:00:00 2001 From: RC-CHN <1051989940@qq.com> Date: Fri, 20 Mar 2026 10:00:20 +0800 Subject: [PATCH 2/2] style: format code --- astrbot/core/agent/context/truncator.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/astrbot/core/agent/context/truncator.py b/astrbot/core/agent/context/truncator.py index 5b2290a6e7..b552b5ef2f 100644 --- a/astrbot/core/agent/context/truncator.py +++ b/astrbot/core/agent/context/truncator.py @@ -37,9 +37,7 @@ def _ensure_user_message( return system_messages + truncated # 从原始消息里找第一条 user - first_user = next( - (m for m in original_messages if m.role == "user"), None - ) + first_user = next((m for m in original_messages if m.role == "user"), None) if first_user is None: return system_messages + truncated