Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 82 additions & 15 deletions src/stirrup/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
TURNS_REMAINING_WARNING_THRESHOLD,
)
from stirrup.core.cache import CacheManager, CacheState, compute_task_hash
from stirrup.core.exceptions import ContextOverflowError
from stirrup.core.models import (
AssistantMessage,
ChatMessage,
Expand Down Expand Up @@ -178,6 +179,27 @@ def _get_model_speed_stats(messages: list[list[ChatMessage]], model_slug: str) -
}


def _split_latest_assistant_suffix(
    messages: list[ChatMessage],
) -> tuple[list[ChatMessage], list[ChatMessage]] | None:
    """Split messages into older history and the newest assistant-led suffix.

    Scans backwards for the most recent ``AssistantMessage`` and splits the
    list there. Returns ``None`` when no assistant message exists, i.e. there
    is no suffix left to peel off.
    """
    last_assistant_index = next(
        (i for i in reversed(range(len(messages))) if isinstance(messages[i], AssistantMessage)),
        None,
    )
    if last_assistant_index is None:
        return None
    return messages[:last_assistant_index], messages[last_assistant_index:]


def _copy_messages_for_summarized_trajectory(messages: list[ChatMessage]) -> list[ChatMessage]:
    """Copy preserved messages, zeroing assistant token usage in the new trajectory."""

    def _fresh_copy(message: ChatMessage) -> ChatMessage:
        # Assistant messages get their usage reset so the summarized
        # trajectory does not double-count tokens already accounted for.
        if isinstance(message, AssistantMessage):
            return message.model_copy(update={"token_usage": TokenUsage()})
        return message.model_copy()

    return [_fresh_copy(message) for message in messages]


# A JSON Schema document: an arbitrary string-keyed mapping.
type JsonSchema = dict[str, object]


Expand Down Expand Up @@ -1098,27 +1120,71 @@ async def step(

return assistant_message, tool_messages, finish_params

async def _step_with_overflow_recovery(
    self,
    messages: list[ChatMessage],
    full_msg_history: list[list[ChatMessage]],
    run_metadata: dict[str, list[Any]],
    *,
    turn: int = 0,
    max_turns: int = 0,
) -> tuple[AssistantMessage, list[ToolMessage], FinishParams | None, list[ChatMessage]]:
    """Run one step, summarizing once if the client reports context overflow.

    Returns the step result plus the message list actually used, which is
    the summarized trajectory when recovery was needed.
    """
    try:
        step_result = await self.step(messages, run_metadata, turn=turn, max_turns=max_turns)
    except ContextOverflowError:
        self._logger.context_summarization_start(1.0, self._context_summarization_cutoff)
        # Snapshot the overflowing trajectory before condensing it, so the
        # full history still records exactly what was sent.
        full_msg_history.append(list(messages))
        condensed = await self.summarize_messages(messages)
        retry_result = await self.step(condensed, run_metadata, turn=turn, max_turns=max_turns)
        return (*retry_result, condensed)
    return (*step_result, messages)

async def summarize_messages(self, messages: list[ChatMessage]) -> list[ChatMessage]:
    """Condense message history using LLM to stay within context window.

    Asks the client to summarize the trajectory. If the summarization
    request itself overflows the context window, the newest assistant-led
    suffix is peeled off (copied with assistant token usage zeroed) and
    preserved verbatim after the summary, then the shorter prefix is
    retried until a summary succeeds.

    Raises:
        ContextOverflowError: If summarization still overflows once no
            AssistantMessage remains to peel off.
    """
    messages_to_summarize = list(messages)
    preserved_tail: list[ChatMessage] = []

    # The task context is the prefix before the first assistant/summary
    # message. Peeling only ever removes messages from the tail, so this
    # prefix is loop-invariant and can be computed once up front.
    task_context: list[ChatMessage] = list(
        takewhile(lambda m: not isinstance(m, (AssistantMessage, SummaryMessage)), messages_to_summarize)
    )

    while True:
        summary_prompt = [*messages_to_summarize, UserMessage(content=MESSAGE_SUMMARIZER)]

        try:
            # We need to pass the tools to the client so that it has context of tools used in the conversation
            summary = await self._client.generate(summary_prompt, self._active_tools)
        except ContextOverflowError as err:
            split_messages = _split_latest_assistant_suffix(messages_to_summarize)
            if split_messages is None:
                # Nothing left to shrink: chain the original overflow so the
                # caller sees both failures.
                raise ContextOverflowError(
                    "Message summarization overflowed with no AssistantMessage left to peel."
                ) from err

            messages_to_summarize, latest_suffix = split_messages
            copied_suffix = _copy_messages_for_summarized_trajectory(latest_suffix)
            preserved_tail = [*copied_suffix, *preserved_tail]
            continue

        summary_bridge_prompt = MESSAGE_SUMMARIZER_BRIDGE_TEMPLATE.format(summary=summary.content)
        summary_bridge = SummaryMessage(content=summary_bridge_prompt)
        # UserMessage (not AssistantMessage) to avoid consecutive assistant messages which some providers reject
        acknowledgement_msg = UserMessage(content="Got it, thanks!")

        # Log the completed summary
        summary_content = summary.content if isinstance(summary.content, str) else str(summary.content)
        self._logger.context_summarization_complete(summary_content, summary_bridge_prompt)

        return [*task_context, summary_bridge, acknowledgement_msg, *preserved_tail]

async def run(
self,
Expand Down Expand Up @@ -1250,8 +1316,9 @@ async def run(
self._logger.user_message(num_turns_remaining_msg)

# Pass turn info to step() for real-time logging
assistant_message, tool_messages, finish_params = await self.step(
assistant_message, tool_messages, finish_params, msgs = await self._step_with_overflow_recovery(
msgs,
full_msg_history,
run_metadata,
turn=i + 1,
max_turns=self._max_turns,
Expand Down
109 changes: 108 additions & 1 deletion tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from PIL import Image
from pydantic import BaseModel
import pytest

from stirrup.constants import FINISH_TOOL_NAME
from stirrup.core.agent import Agent
from stirrup.core.exceptions import ContextOverflowError
from stirrup.core.models import (
AssistantMessage,
ChatMessage,
Expand All @@ -27,9 +29,10 @@
class MockLLMClient(LLMClient):
"""Mock LLM client for testing."""

def __init__(self, responses: list[AssistantMessage], max_tokens: int = 100_000) -> None:
def __init__(self, responses: list[AssistantMessage | Exception], max_tokens: int = 100_000) -> None:
self.responses = responses
self.call_count = 0
self.calls: list[list[ChatMessage]] = []
self._max_tokens = max_tokens

@property
Expand All @@ -41,8 +44,11 @@ def max_tokens(self) -> int:
return self._max_tokens

async def generate(self, messages: list[ChatMessage], tools: dict[str, Tool]) -> AssistantMessage:  # noqa: ARG002
    """Record the prompt and return (or raise) the next scripted response.

    Appends a shallow copy of *messages* to ``self.calls`` for later
    inspection, then pops the next entry from ``self.responses`` by
    index; an ``Exception`` entry is raised instead of returned.
    """
    self.calls.append(list(messages))
    response = self.responses[self.call_count]
    self.call_count += 1
    if isinstance(response, Exception):
        raise response
    return response


Expand Down Expand Up @@ -686,3 +692,104 @@ async def test_summarize_history_has_one_summary_per_trajectory() -> None:

# The summary content should be different between history[1] and history[2]
assert summaries_1[0].content != summaries_2[0].content


async def test_summarize_messages_peels_latest_turn_on_overflow() -> None:
    """Retry summarization without the newest assistant-led suffix when needed."""

    class EchoParams(BaseModel):
        message: str

    def echo_executor(params: EchoParams) -> ToolResult:
        return ToolResult(content=f"Echo: {params.message}")

    echo_tool = Tool[EchoParams, None](
        name="echo",
        description="Echo a message",
        parameters=EchoParams,
        executor=echo_executor,  # ty: ignore[invalid-argument-type]
    )

    # First generate() call overflows; the retry (with the newest
    # assistant-led suffix peeled off) succeeds with a summary.
    client = MockLLMClient(
        [
            ContextOverflowError("summary overflow"),
            AssistantMessage(
                content="Condensed earlier progress.",
                tool_calls=[],
                token_usage=TokenUsage(input=50, answer=20),
            ),
        ]
    )
    agent = Agent(
        client=client,
        name="test-agent",
        tools=[echo_tool],
        finish_tool=SIMPLE_FINISH_TOOL,
    )

    # Trajectory with two assistant turns: the second one (tool call plus
    # its ToolMessage) is the suffix expected to be peeled and preserved.
    messages = [
        SystemMessage(content="System prompt"),
        UserMessage(content="Do the task"),
        AssistantMessage(
            content="Earlier progress",
            tool_calls=[],
            token_usage=TokenUsage(input=100, answer=40),
        ),
        AssistantMessage(
            content="Using a tool now",
            tool_calls=[ToolCall(name="echo", arguments='{"message": "hello"}', tool_call_id="call_1")],
            token_usage=TokenUsage(input=120, answer=60),
        ),
        ToolMessage(
            content="Echo: hello",
            tool_call_id="call_1",
            name="echo",
            success=True,
        ),
    ]

    async with agent.session() as session:
        summarized = await session.summarize_messages(messages)

    # One overflowing call plus one successful retry.
    assert client.call_count == 2

    # The retry prompt must contain only the shortened prefix (system, user,
    # first assistant turn) plus the appended summarizer instruction.
    retry_prompt = client.calls[1]
    assert len(retry_prompt) == 4
    assert isinstance(retry_prompt[0], SystemMessage)
    assert isinstance(retry_prompt[1], UserMessage)
    assert isinstance(retry_prompt[2], AssistantMessage)
    assert retry_prompt[2].content == "Earlier progress"
    assert isinstance(retry_prompt[3], UserMessage)

    # Resulting trajectory: task context, summary bridge, acknowledgement,
    # then the preserved (peeled) suffix.
    assert isinstance(summarized[0], SystemMessage)
    assert isinstance(summarized[1], UserMessage)
    assert isinstance(summarized[2], SummaryMessage)
    assert isinstance(summarized[3], UserMessage)
    assert summarized[3].content == "Got it, thanks!"
    assert isinstance(summarized[4], AssistantMessage)
    assert summarized[4].content == "Using a tool now"
    # Preserved assistant copy has token usage zeroed; the original message
    # object is untouched (copy, not mutation).
    assert summarized[4].token_usage.total == 0
    assert messages[3].token_usage.total == 180
    assert isinstance(summarized[5], ToolMessage)
    assert summarized[5].content == "Echo: hello"


async def test_summarize_messages_raises_when_nothing_can_be_peeled() -> None:
    """Raise ContextOverflowError when summarization cannot shrink the trajectory."""

    # The only scripted response is an overflow, and the trajectory holds no
    # AssistantMessage to peel, so summarization has no way to recover.
    overflowing_client = MockLLMClient([ContextOverflowError("summary overflow")])
    agent = Agent(
        client=overflowing_client,
        name="test-agent",
        tools=[],
        finish_tool=SIMPLE_FINISH_TOOL,
    )

    history: list[ChatMessage] = [
        SystemMessage(content="System prompt"),
        UserMessage(content="Do the task"),
    ]

    async with agent.session() as session:
        with pytest.raises(ContextOverflowError, match="no AssistantMessage left to peel"):
            await session.summarize_messages(history)