-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
fix: 截断器丢失唯一 user 消息导致智谱等 provider 返回 400 #6581
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,37 @@ def _has_tool_calls(self, message: Message) -> bool: | |
| and len(message.tool_calls) > 0 | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def _split_system_rest( | ||
| messages: list[Message], | ||
| ) -> tuple[list[Message], list[Message]]: | ||
| """把 system 消息和后面的对话消息分开。""" | ||
| first_non_system = 0 | ||
| for i, msg in enumerate(messages): | ||
| if msg.role != "system": | ||
| first_non_system = i | ||
| break | ||
| return messages[:first_non_system], messages[first_non_system:] | ||
|
Comment on lines
+20
to
+25
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): _split_system_rest misclassifies all-system message lists as having no system messages. If all messages have |
||
|
|
||
| @staticmethod | ||
| def _ensure_user_message( | ||
| system_messages: list[Message], | ||
| truncated: list[Message], | ||
| original_messages: list[Message], | ||
| ) -> list[Message]: | ||
| """截断后如果没有 user 消息了,从原始列表里把第一条 user 补回来。 | ||
| 很多 provider (智谱、Gemini 等) 要求 system 之后必须紧跟 user,否则直接 400。 | ||
| """ | ||
| if any(m.role == "user" for m in truncated): | ||
| return system_messages + truncated | ||
|
|
||
| # 从原始消息里找第一条 user | ||
| first_user = next((m for m in original_messages if m.role == "user"), None) | ||
| if first_user is None: | ||
| return system_messages + truncated | ||
|
|
||
| return system_messages + [first_user] + truncated | ||
|
|
||
| def fix_messages(self, messages: list[Message]) -> list[Message]: | ||
| """修复消息列表,确保 tool call 和 tool response 的配对关系有效。 | ||
|
|
||
|
|
@@ -81,14 +112,7 @@ def truncate_by_turns( | |
| if keep_most_recent_turns == -1: | ||
| return messages | ||
|
|
||
| first_non_system = 0 | ||
| for i, msg in enumerate(messages): | ||
| if msg.role != "system": | ||
| first_non_system = i | ||
| break | ||
|
|
||
| system_messages = messages[:first_non_system] | ||
| non_system_messages = messages[first_non_system:] | ||
| system_messages, non_system_messages = self._split_system_rest(messages) | ||
|
|
||
| if len(non_system_messages) // 2 <= keep_most_recent_turns: | ||
| return messages | ||
|
|
@@ -99,16 +123,17 @@ def truncate_by_turns( | |
| else: | ||
| truncated_contexts = non_system_messages[-num_to_keep * 2 :] | ||
|
|
||
| # 找到第一个 role 为 user 的索引,确保上下文格式正确 | ||
| # 对齐到第一条 user 消息 | ||
| index = next( | ||
| (i for i, item in enumerate(truncated_contexts) if item.role == "user"), | ||
| None, | ||
| ) | ||
| if index is not None and index > 0: | ||
| truncated_contexts = truncated_contexts[index:] | ||
|
|
||
| result = system_messages + truncated_contexts | ||
|
|
||
| result = self._ensure_user_message( | ||
| system_messages, truncated_contexts, messages | ||
| ) | ||
| return self.fix_messages(result) | ||
|
|
||
| def truncate_by_dropping_oldest_turns( | ||
|
|
@@ -120,31 +145,24 @@ def truncate_by_dropping_oldest_turns( | |
| if drop_turns <= 0: | ||
| return messages | ||
|
|
||
| first_non_system = 0 | ||
| for i, msg in enumerate(messages): | ||
| if msg.role != "system": | ||
| first_non_system = i | ||
| break | ||
|
|
||
| system_messages = messages[:first_non_system] | ||
| non_system_messages = messages[first_non_system:] | ||
| system_messages, non_system_messages = self._split_system_rest(messages) | ||
|
|
||
| if len(non_system_messages) // 2 <= drop_turns: | ||
| truncated_non_system = [] | ||
| else: | ||
| truncated_non_system = non_system_messages[drop_turns * 2 :] | ||
|
|
||
| # 对齐到第一条 user | ||
| index = next( | ||
| (i for i, item in enumerate(truncated_non_system) if item.role == "user"), | ||
| None, | ||
| ) | ||
| if index is not None: | ||
| truncated_non_system = truncated_non_system[index:] | ||
| elif truncated_non_system: | ||
| truncated_non_system = [] | ||
|
|
||
| result = system_messages + truncated_non_system | ||
|
|
||
| result = self._ensure_user_message( | ||
| system_messages, truncated_non_system, messages | ||
| ) | ||
| return self.fix_messages(result) | ||
|
|
||
| def truncate_by_halving( | ||
|
|
@@ -155,28 +173,23 @@ def truncate_by_halving( | |
| if len(messages) <= 2: | ||
| return messages | ||
|
|
||
| first_non_system = 0 | ||
| for i, msg in enumerate(messages): | ||
| if msg.role != "system": | ||
| first_non_system = i | ||
| break | ||
|
|
||
| system_messages = messages[:first_non_system] | ||
| non_system_messages = messages[first_non_system:] | ||
| system_messages, non_system_messages = self._split_system_rest(messages) | ||
|
|
||
| messages_to_delete = len(non_system_messages) // 2 | ||
| if messages_to_delete == 0: | ||
| return messages | ||
|
|
||
| truncated_non_system = non_system_messages[messages_to_delete:] | ||
|
|
||
| # 对齐到第一条 user | ||
| index = next( | ||
| (i for i, item in enumerate(truncated_non_system) if item.role == "user"), | ||
| None, | ||
| ) | ||
| if index is not None: | ||
| truncated_non_system = truncated_non_system[index:] | ||
|
|
||
| result = system_messages + truncated_non_system | ||
|
|
||
| result = self._ensure_user_message( | ||
| system_messages, truncated_non_system, messages | ||
| ) | ||
| return self.fix_messages(result) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -104,8 +104,9 @@ def test_truncate_by_turns_zero_keep(self): | |||||||||||||||||||||||||||||
| messages, keep_most_recent_turns=0, drop_turns=1 | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Should result in empty or minimal list | ||||||||||||||||||||||||||||||
| assert len(result) == 0 | ||||||||||||||||||||||||||||||
| # 截断后至少保留一条 user 消息 (#6196) | ||||||||||||||||||||||||||||||
| assert len(result) >= 1 | ||||||||||||||||||||||||||||||
| assert result[0].role == "user" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_truncate_by_turns_below_threshold(self): | ||||||||||||||||||||||||||||||
| """Test truncate_by_turns when messages are below threshold.""" | ||||||||||||||||||||||||||||||
|
|
@@ -201,8 +202,9 @@ def test_truncate_by_dropping_oldest_turns_drop_all(self): | |||||||||||||||||||||||||||||
| messages = self.create_messages(4) | ||||||||||||||||||||||||||||||
| result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Should drop all turns | ||||||||||||||||||||||||||||||
| assert len(result) == 0 | ||||||||||||||||||||||||||||||
| # 即使 drop 掉所有 turn,也会把 user 消息补回来 (#6196) | ||||||||||||||||||||||||||||||
| assert len(result) >= 1 | ||||||||||||||||||||||||||||||
| assert result[0].role == "user" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self): | ||||||||||||||||||||||||||||||
| """Test truncate_by_dropping_oldest_turns with drop_turns > available turns.""" | ||||||||||||||||||||||||||||||
|
|
@@ -211,8 +213,9 @@ def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self): | |||||||||||||||||||||||||||||
| messages = self.create_messages(4) | ||||||||||||||||||||||||||||||
| result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Should result in empty list | ||||||||||||||||||||||||||||||
| assert len(result) == 0 | ||||||||||||||||||||||||||||||
| # 同理,user 消息会被保留 (#6196) | ||||||||||||||||||||||||||||||
| assert len(result) >= 1 | ||||||||||||||||||||||||||||||
| assert result[0].role == "user" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_truncate_by_dropping_oldest_turns_ensures_user_first(self): | ||||||||||||||||||||||||||||||
| """Test that result starts with user message after dropping.""" | ||||||||||||||||||||||||||||||
|
|
@@ -372,3 +375,60 @@ def test_all_system_messages(self): | |||||||||||||||||||||||||||||
| assert len(result) >= 0 # May keep system messages or clear all | ||||||||||||||||||||||||||||||
| if len(result) > 0: | ||||||||||||||||||||||||||||||
| assert all(msg.role == "system" for msg in result) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # ==================== #6196: 长 tool chain 只有一条 user 消息 ==================== | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _build_tool_chain(self, tool_rounds: int = 20) -> list[Message]: | ||||||||||||||||||||||||||||||
| """构造 system -> user -> (assistant -> tool) * N 的长链,只有一条 user。""" | ||||||||||||||||||||||||||||||
| msgs = [ | ||||||||||||||||||||||||||||||
| self.create_message("system", "You are a helpful assistant."), | ||||||||||||||||||||||||||||||
| self.create_message("user", "帮我查一下天气"), | ||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||
| for i in range(tool_rounds): | ||||||||||||||||||||||||||||||
| msgs.append(self.create_message("assistant", f"调用工具 {i}")) | ||||||||||||||||||||||||||||||
| msgs.append(self.create_message("tool", f"工具结果 {i}")) | ||||||||||||||||||||||||||||||
| return msgs | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_drop_oldest_preserves_sole_user(self): | ||||||||||||||||||||||||||||||
| """#6196: drop 1 turn 不应丢掉唯一的 user 消息。""" | ||||||||||||||||||||||||||||||
| truncator = ContextTruncator() | ||||||||||||||||||||||||||||||
| msgs = self._build_tool_chain(20) # 1 system + 1 user + 40 asst/tool = 42 | ||||||||||||||||||||||||||||||
| result = truncator.truncate_by_dropping_oldest_turns(msgs, drop_turns=1) | ||||||||||||||||||||||||||||||
| roles = [m.role for m in result] | ||||||||||||||||||||||||||||||
| assert "user" in roles, "唯一的 user 消息被丢掉了" | ||||||||||||||||||||||||||||||
| assert roles[0] == "system" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_halving_preserves_sole_user(self): | ||||||||||||||||||||||||||||||
| """#6196: 对半砍不应丢掉唯一的 user 消息。""" | ||||||||||||||||||||||||||||||
| truncator = ContextTruncator() | ||||||||||||||||||||||||||||||
| msgs = self._build_tool_chain(20) | ||||||||||||||||||||||||||||||
| result = truncator.truncate_by_halving(msgs) | ||||||||||||||||||||||||||||||
| roles = [m.role for m in result] | ||||||||||||||||||||||||||||||
| assert "user" in roles, "唯一的 user 消息被丢掉了" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_truncate_by_turns_preserves_sole_user(self): | ||||||||||||||||||||||||||||||
| """#6196: keep_most_recent_turns 也不应丢掉唯一的 user 消息。""" | ||||||||||||||||||||||||||||||
| truncator = ContextTruncator() | ||||||||||||||||||||||||||||||
| msgs = self._build_tool_chain(20) | ||||||||||||||||||||||||||||||
| result = truncator.truncate_by_turns( | ||||||||||||||||||||||||||||||
| msgs, keep_most_recent_turns=3, drop_turns=1 | ||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||
| roles = [m.role for m in result] | ||||||||||||||||||||||||||||||
| assert "user" in roles, "唯一的 user 消息被丢掉了" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_drop_oldest_heavy_drops_still_has_user(self): | ||||||||||||||||||||||||||||||
| """#6196: 大量 drop 也不会丢 user。""" | ||||||||||||||||||||||||||||||
| truncator = ContextTruncator() | ||||||||||||||||||||||||||||||
| msgs = self._build_tool_chain(30) | ||||||||||||||||||||||||||||||
| result = truncator.truncate_by_dropping_oldest_turns(msgs, drop_turns=10) | ||||||||||||||||||||||||||||||
| roles = [m.role for m in result] | ||||||||||||||||||||||||||||||
| assert "user" in roles | ||||||||||||||||||||||||||||||
|
Comment on lines
+424
to
+425
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Strengthen assertion: also verify This test only checks that a
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def test_normal_multi_user_not_affected(self): | ||||||||||||||||||||||||||||||
| """正常多 user 对话不受影响。""" | ||||||||||||||||||||||||||||||
| truncator = ContextTruncator() | ||||||||||||||||||||||||||||||
| msgs = self.create_messages(20, include_system=True) | ||||||||||||||||||||||||||||||
| result_before = truncator.truncate_by_dropping_oldest_turns(msgs, drop_turns=2) | ||||||||||||||||||||||||||||||
| # 多 user 场景下截断后仍有 user | ||||||||||||||||||||||||||||||
| roles = [m.role for m in result_before] | ||||||||||||||||||||||||||||||
| assert "user" in roles | ||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个辅助函数的实现存在一个边界情况的 bug:当
messages列表里全部是system消息时,它会错误地返回([], messages)而不是预期的(messages, [])。这是因为first_non_system初始值为 0,并且在循环结束后若未找到非系统消息,其值保持不变。这个错误会导致后续的截断逻辑错误地作用于
system消息上。建议使用更简洁和健壮的方式来实现,可以优雅地处理这个边界情况。