diff --git a/src/stirrup/core/agent.py b/src/stirrup/core/agent.py
index 52d3c2c..32d73a0 100644
--- a/src/stirrup/core/agent.py
+++ b/src/stirrup/core/agent.py
@@ -24,6 +24,7 @@
     TURNS_REMAINING_WARNING_THRESHOLD,
 )
 from stirrup.core.cache import CacheManager, CacheState, compute_task_hash
+from stirrup.core.exceptions import ContextOverflowError
 from stirrup.core.models import (
     AssistantMessage,
     ChatMessage,
@@ -178,6 +179,27 @@ def _get_model_speed_stats(messages: list[list[ChatMessage]], model_slug: str) -
     }
 
 
+def _split_latest_assistant_suffix(
+    messages: list[ChatMessage],
+) -> tuple[list[ChatMessage], list[ChatMessage]] | None:
+    """Split messages into older history and the newest assistant-led suffix."""
+    for index in range(len(messages) - 1, -1, -1):
+        if isinstance(messages[index], AssistantMessage):
+            return messages[:index], messages[index:]
+    return None
+
+
+def _copy_messages_for_summarized_trajectory(messages: list[ChatMessage]) -> list[ChatMessage]:
+    """Copy preserved messages, zeroing assistant token usage in the new trajectory."""
+    copied_messages: list[ChatMessage] = []
+    for message in messages:
+        if isinstance(message, AssistantMessage):
+            copied_messages.append(message.model_copy(update={"token_usage": TokenUsage()}))
+            continue
+        copied_messages.append(message.model_copy())
+    return copied_messages
+
+
 type JsonSchema = dict[str, object]
 
 
@@ -1098,27 +1120,71 @@ async def step(
 
         return assistant_message, tool_messages, finish_params
 
+    async def _step_with_overflow_recovery(
+        self,
+        messages: list[ChatMessage],
+        full_msg_history: list[list[ChatMessage]],
+        run_metadata: dict[str, list[Any]],
+        *,
+        turn: int = 0,
+        max_turns: int = 0,
+    ) -> tuple[AssistantMessage, list[ToolMessage], FinishParams | None, list[ChatMessage]]:
+        """Run one step, summarizing once if the client reports context overflow."""
+        try:
+            assistant_message, tool_messages, finish_params = await self.step(
+                messages,
+                run_metadata,
+                turn=turn,
+                max_turns=max_turns,
+            )
+            return assistant_message, tool_messages, finish_params, messages
+        except ContextOverflowError:
+            self._logger.context_summarization_start(1.0, self._context_summarization_cutoff)
+            full_msg_history.append(list(messages))
+            summarized_messages = await self.summarize_messages(messages)
+            assistant_message, tool_messages, finish_params = await self.step(
+                summarized_messages,
+                run_metadata,
+                turn=turn,
+                max_turns=max_turns,
+            )
+            return assistant_message, tool_messages, finish_params, summarized_messages
+
     async def summarize_messages(self, messages: list[ChatMessage]) -> list[ChatMessage]:
         """Condense message history using LLM to stay within context window."""
-        task_context: list[ChatMessage] = list(
-            takewhile(lambda m: not isinstance(m, (AssistantMessage, SummaryMessage)), messages)
-        )
+        messages_to_summarize = list(messages)
+        preserved_tail: list[ChatMessage] = []
+
+        while True:
+            task_context: list[ChatMessage] = list(
+                takewhile(lambda m: not isinstance(m, (AssistantMessage, SummaryMessage)), messages_to_summarize)
+            )
 
-        summary_prompt = [*messages, UserMessage(content=MESSAGE_SUMMARIZER)]
+            summary_prompt = [*messages_to_summarize, UserMessage(content=MESSAGE_SUMMARIZER)]
 
-        # We need to pass the tools to the client so that it has context of tools used in the conversation
-        summary = await self._client.generate(summary_prompt, self._active_tools)
+            try:
+                # We need to pass the tools to the client so that it has context of tools used in the conversation
+                summary = await self._client.generate(summary_prompt, self._active_tools)
+            except ContextOverflowError:
+                split_messages = _split_latest_assistant_suffix(messages_to_summarize)
+                if split_messages is None:
+                    raise ContextOverflowError("Message summarization overflowed with no AssistantMessage left to peel.")
+
+                messages_to_summarize, latest_suffix = split_messages
+                copied_suffix = _copy_messages_for_summarized_trajectory(latest_suffix)
+                preserved_tail = [*copied_suffix, *preserved_tail]
+                continue
 
-        summary_bridge_prompt = MESSAGE_SUMMARIZER_BRIDGE_TEMPLATE.format(summary=summary.content)
-        summary_bridge = SummaryMessage(content=summary_bridge_prompt)
-        # UserMessage (not AssistantMessage) to avoid consecutive assistant messages which some providers reject
-        acknowledgement_msg = UserMessage(content="Got it, thanks!")
+            summary_bridge_prompt = MESSAGE_SUMMARIZER_BRIDGE_TEMPLATE.format(summary=summary.content)
+            summary_bridge = SummaryMessage(content=summary_bridge_prompt)
+            # UserMessage (not AssistantMessage) to avoid consecutive assistant messages which some providers reject
+            acknowledgement_msg = UserMessage(content="Got it, thanks!")
 
-        # Log the completed summary
-        summary_content = summary.content if isinstance(summary.content, str) else str(summary.content)
-        self._logger.context_summarization_complete(summary_content, summary_bridge_prompt)
+            # Log the completed summary
+            summary_content = summary.content if isinstance(summary.content, str) else str(summary.content)
+            self._logger.context_summarization_complete(summary_content, summary_bridge_prompt)
 
-        return [*task_context, summary_bridge, acknowledgement_msg]
+            return [*task_context, summary_bridge, acknowledgement_msg, *preserved_tail]
 
     async def run(
         self,
@@ -1250,8 +1316,9 @@
                 self._logger.user_message(num_turns_remaining_msg)
 
             # Pass turn info to step() for real-time logging
-            assistant_message, tool_messages, finish_params = await self.step(
+            assistant_message, tool_messages, finish_params, msgs = await self._step_with_overflow_recovery(
                 msgs,
+                full_msg_history,
                 run_metadata,
                 turn=i + 1,
                 max_turns=self._max_turns,
diff --git a/tests/test_agent.py b/tests/test_agent.py
index e042898..b3eac99 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -4,9 +4,11 @@
 
 from PIL import Image
 from pydantic import BaseModel
+import pytest
 
 from stirrup.constants import FINISH_TOOL_NAME
 from stirrup.core.agent import Agent
+from stirrup.core.exceptions import ContextOverflowError
 from stirrup.core.models import (
     AssistantMessage,
     ChatMessage,
@@ -27,9 +29,10 @@
 class MockLLMClient(LLMClient):
     """Mock LLM client for testing."""
 
-    def __init__(self, responses: list[AssistantMessage], max_tokens: int = 100_000) -> None:
+    def __init__(self, responses: list[AssistantMessage | Exception], max_tokens: int = 100_000) -> None:
         self.responses = responses
         self.call_count = 0
+        self.calls: list[list[ChatMessage]] = []
         self._max_tokens = max_tokens
 
     @property
@@ -41,8 +44,11 @@
     def max_tokens(self) -> int:
         return self._max_tokens
 
     async def generate(self, messages: list[ChatMessage], tools: dict[str, Tool]) -> AssistantMessage:  # noqa: ARG002
+        self.calls.append(list(messages))
         response = self.responses[self.call_count]
         self.call_count += 1
+        if isinstance(response, Exception):
+            raise response
         return response
 
@@ -686,3 +692,104 @@
 
     # The summary content should be different between history[1] and history[2]
     assert summaries_1[0].content != summaries_2[0].content
+
+
+async def test_summarize_messages_peels_latest_turn_on_overflow() -> None:
+    """Retry summarization without the newest assistant-led suffix when needed."""
+
+    class EchoParams(BaseModel):
+        message: str
+
+    def echo_executor(params: EchoParams) -> ToolResult:
+        return ToolResult(content=f"Echo: {params.message}")
+
+    echo_tool = Tool[EchoParams, None](
+        name="echo",
+        description="Echo a message",
+        parameters=EchoParams,
+        executor=echo_executor,  # ty: ignore[invalid-argument-type]
+    )
+
+    client = MockLLMClient(
+        [
+            ContextOverflowError("summary overflow"),
+            AssistantMessage(
+                content="Condensed earlier progress.",
+                tool_calls=[],
+                token_usage=TokenUsage(input=50, answer=20),
+            ),
+        ]
+    )
+    agent = Agent(
+        client=client,
+        name="test-agent",
+        tools=[echo_tool],
+        finish_tool=SIMPLE_FINISH_TOOL,
+    )
+
+    messages = [
+        SystemMessage(content="System prompt"),
+        UserMessage(content="Do the task"),
+        AssistantMessage(
+            content="Earlier progress",
+            tool_calls=[],
+            token_usage=TokenUsage(input=100, answer=40),
+        ),
+        AssistantMessage(
+            content="Using a tool now",
+            tool_calls=[ToolCall(name="echo", arguments='{"message": "hello"}', tool_call_id="call_1")],
+            token_usage=TokenUsage(input=120, answer=60),
+        ),
+        ToolMessage(
+            content="Echo: hello",
+            tool_call_id="call_1",
+            name="echo",
+            success=True,
+        ),
+    ]
+
+    async with agent.session() as session:
+        summarized = await session.summarize_messages(messages)
+
+    assert client.call_count == 2
+
+    retry_prompt = client.calls[1]
+    assert len(retry_prompt) == 4
+    assert isinstance(retry_prompt[0], SystemMessage)
+    assert isinstance(retry_prompt[1], UserMessage)
+    assert isinstance(retry_prompt[2], AssistantMessage)
+    assert retry_prompt[2].content == "Earlier progress"
+    assert isinstance(retry_prompt[3], UserMessage)
+
+    assert isinstance(summarized[0], SystemMessage)
+    assert isinstance(summarized[1], UserMessage)
+    assert isinstance(summarized[2], SummaryMessage)
+    assert isinstance(summarized[3], UserMessage)
+    assert summarized[3].content == "Got it, thanks!"
+    assert isinstance(summarized[4], AssistantMessage)
+    assert summarized[4].content == "Using a tool now"
+    assert summarized[4].token_usage.total == 0
+    assert messages[3].token_usage.total == 180
+    assert isinstance(summarized[5], ToolMessage)
+    assert summarized[5].content == "Echo: hello"
+
+
+async def test_summarize_messages_raises_when_nothing_can_be_peeled() -> None:
+    """Raise ContextOverflowError when summarization cannot shrink the trajectory."""
+
+    client = MockLLMClient([ContextOverflowError("summary overflow")])
+    agent = Agent(
+        client=client,
+        name="test-agent",
+        tools=[],
+        finish_tool=SIMPLE_FINISH_TOOL,
+    )
+
+    messages = [
+        SystemMessage(content="System prompt"),
+        UserMessage(content="Do the task"),
+    ]
+
+    async with agent.session() as session:
+        with pytest.raises(ContextOverflowError, match="no AssistantMessage left to peel"):
+            await session.summarize_messages(messages)