Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 46 additions & 33 deletions astrbot/core/agent/context/truncator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,37 @@ def _has_tool_calls(self, message: Message) -> bool:
and len(message.tool_calls) > 0
)

@staticmethod
def _split_system_rest(
messages: list[Message],
) -> tuple[list[Message], list[Message]]:
"""把 system 消息和后面的对话消息分开。"""
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
return messages[:first_non_system], messages[first_non_system:]
Comment on lines +20 to +25
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

这个辅助函数的实现存在一个边界情况的 bug:当 messages 列表里全部是 system 消息时,它会错误地返回 ([], messages) 而不是预期的 (messages, [])。这是因为 first_non_system 初始值为 0,并且在循环结束后若未找到非系统消息,其值保持不变。

这个错误会导致后续的截断逻辑错误地作用于 system 消息上。

建议使用更简洁和健壮的方式来实现,可以优雅地处理这个边界情况。

Suggested change
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
return messages[:first_non_system], messages[first_non_system:]
try:
first_non_system = next(i for i, msg in enumerate(messages) if msg.role != "system")
return messages[:first_non_system], messages[first_non_system:]
except StopIteration:
return messages, []

Comment on lines +20 to +25
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): _split_system_rest misclassifies all-system message lists as having no system messages.

If all messages have role == "system", first_non_system is never updated and stays 0, so the function returns ([], messages) and misclassifies the entire list as non-system. Since downstream logic expects system_messages to contain the system-only prefix, this can skew truncation. You can fix this by defaulting first_non_system to len(messages) when no non-system message is found (e.g., via a for/else or a post-loop check), so all-system input yields (messages, []).


@staticmethod
def _ensure_user_message(
system_messages: list[Message],
truncated: list[Message],
original_messages: list[Message],
) -> list[Message]:
"""截断后如果没有 user 消息了,从原始列表里把第一条 user 补回来。
很多 provider (智谱、Gemini 等) 要求 system 之后必须紧跟 user,否则直接 400。
"""
if any(m.role == "user" for m in truncated):
return system_messages + truncated

# 从原始消息里找第一条 user
first_user = next((m for m in original_messages if m.role == "user"), None)
if first_user is None:
return system_messages + truncated

return system_messages + [first_user] + truncated

def fix_messages(self, messages: list[Message]) -> list[Message]:
"""修复消息列表,确保 tool call 和 tool response 的配对关系有效。

Expand Down Expand Up @@ -81,14 +112,7 @@ def truncate_by_turns(
if keep_most_recent_turns == -1:
return messages

first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break

system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
system_messages, non_system_messages = self._split_system_rest(messages)

if len(non_system_messages) // 2 <= keep_most_recent_turns:
return messages
Expand All @@ -99,16 +123,17 @@ def truncate_by_turns(
else:
truncated_contexts = non_system_messages[-num_to_keep * 2 :]

# 找到第一个 role 为 user 的索引,确保上下文格式正确
# 对齐到第一条 user 消息
index = next(
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]

result = system_messages + truncated_contexts

result = self._ensure_user_message(
system_messages, truncated_contexts, messages
)
return self.fix_messages(result)

def truncate_by_dropping_oldest_turns(
Expand All @@ -120,31 +145,24 @@ def truncate_by_dropping_oldest_turns(
if drop_turns <= 0:
return messages

first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break

system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
system_messages, non_system_messages = self._split_system_rest(messages)

if len(non_system_messages) // 2 <= drop_turns:
truncated_non_system = []
else:
truncated_non_system = non_system_messages[drop_turns * 2 :]

# 对齐到第一条 user
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
elif truncated_non_system:
truncated_non_system = []

result = system_messages + truncated_non_system

result = self._ensure_user_message(
system_messages, truncated_non_system, messages
)
return self.fix_messages(result)

def truncate_by_halving(
Expand All @@ -155,28 +173,23 @@ def truncate_by_halving(
if len(messages) <= 2:
return messages

first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break

system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
system_messages, non_system_messages = self._split_system_rest(messages)

messages_to_delete = len(non_system_messages) // 2
if messages_to_delete == 0:
return messages

truncated_non_system = non_system_messages[messages_to_delete:]

# 对齐到第一条 user
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]

result = system_messages + truncated_non_system

result = self._ensure_user_message(
system_messages, truncated_non_system, messages
)
return self.fix_messages(result)
72 changes: 66 additions & 6 deletions tests/agent/test_truncator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ def test_truncate_by_turns_zero_keep(self):
messages, keep_most_recent_turns=0, drop_turns=1
)

# Should result in empty or minimal list
assert len(result) == 0
# 截断后至少保留一条 user 消息 (#6196)
assert len(result) >= 1
assert result[0].role == "user"

def test_truncate_by_turns_below_threshold(self):
"""Test truncate_by_turns when messages are below threshold."""
Expand Down Expand Up @@ -201,8 +202,9 @@ def test_truncate_by_dropping_oldest_turns_drop_all(self):
messages = self.create_messages(4)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)

# Should drop all turns
assert len(result) == 0
# 即使 drop 掉所有 turn,也会把 user 消息补回来 (#6196)
assert len(result) >= 1
assert result[0].role == "user"

def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
"""Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
Expand All @@ -211,8 +213,9 @@ def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
messages = self.create_messages(4)
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)

# Should result in empty list
assert len(result) == 0
# 同理,user 消息会被保留 (#6196)
assert len(result) >= 1
assert result[0].role == "user"

def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
"""Test that result starts with user message after dropping."""
Expand Down Expand Up @@ -372,3 +375,60 @@ def test_all_system_messages(self):
assert len(result) >= 0 # May keep system messages or clear all
if len(result) > 0:
assert all(msg.role == "system" for msg in result)

# ==================== #6196: 长 tool chain 只有一条 user 消息 ====================

def _build_tool_chain(self, tool_rounds: int = 20) -> list[Message]:
"""构造 system -> user -> (assistant -> tool) * N 的长链,只有一条 user。"""
msgs = [
self.create_message("system", "You are a helpful assistant."),
self.create_message("user", "帮我查一下天气"),
]
for i in range(tool_rounds):
msgs.append(self.create_message("assistant", f"调用工具 {i}"))
msgs.append(self.create_message("tool", f"工具结果 {i}"))
return msgs

def test_drop_oldest_preserves_sole_user(self):
"""#6196: drop 1 turn 不应丢掉唯一的 user 消息。"""
truncator = ContextTruncator()
msgs = self._build_tool_chain(20) # 1 system + 1 user + 40 asst/tool = 42
result = truncator.truncate_by_dropping_oldest_turns(msgs, drop_turns=1)
roles = [m.role for m in result]
assert "user" in roles, "唯一的 user 消息被丢掉了"
assert roles[0] == "system"

def test_halving_preserves_sole_user(self):
"""#6196: 对半砍不应丢掉唯一的 user 消息。"""
truncator = ContextTruncator()
msgs = self._build_tool_chain(20)
result = truncator.truncate_by_halving(msgs)
roles = [m.role for m in result]
assert "user" in roles, "唯一的 user 消息被丢掉了"

def test_truncate_by_turns_preserves_sole_user(self):
"""#6196: keep_most_recent_turns 也不应丢掉唯一的 user 消息。"""
truncator = ContextTruncator()
msgs = self._build_tool_chain(20)
result = truncator.truncate_by_turns(
msgs, keep_most_recent_turns=3, drop_turns=1
)
roles = [m.role for m in result]
assert "user" in roles, "唯一的 user 消息被丢掉了"

def test_drop_oldest_heavy_drops_still_has_user(self):
"""#6196: 大量 drop 也不会丢 user。"""
truncator = ContextTruncator()
msgs = self._build_tool_chain(30)
result = truncator.truncate_by_dropping_oldest_turns(msgs, drop_turns=10)
roles = [m.role for m in result]
assert "user" in roles
Comment on lines +424 to +425
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Strengthen assertion: also verify system -> user ordering for heavy-drop case

This test only checks that a user role is present after heavy dropping. Since the provider requires that system is immediately followed by user, please also assert the ordering here (e.g. assert roles[0] == "system" and assert roles[1] == "user", or that the first non-system message is user) to better guard against regressions when many turns are dropped.

Suggested change
roles = [m.role for m in result]
assert "user" in roles
roles = [m.role for m in result]
# 仍然确保至少有一条 user 消息
assert "user" in roles
# provider 要求 system 后面紧跟 user:
# 1) 如果存在 system,则第一条必须是 system
if roles:
assert roles[0] == "system"
# 2) 第一条非 system 消息必须是 user
first_non_system = next((r for r in roles if r != "system"), None)
assert first_non_system == "user"


def test_normal_multi_user_not_affected(self):
"""正常多 user 对话不受影响。"""
truncator = ContextTruncator()
msgs = self.create_messages(20, include_system=True)
result_before = truncator.truncate_by_dropping_oldest_turns(msgs, drop_turns=2)
# 多 user 场景下截断后仍有 user
roles = [m.role for m in result_before]
assert "user" in roles
Loading