From a7bd4848f2b658cb28dff4eebd885497f8390c09 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Mon, 2 Jun 2025 12:20:30 +0000 Subject: [PATCH] refactor: Update conversation manager interface --- src/strands/agent/agent.py | 6 +++--- .../conversation_manager.py | 15 ++++++------- .../null_conversation_manager.py | 14 +++++++------ .../sliding_window_conversation_manager.py | 21 ++++++++++++------- tests/strands/agent/test_agent.py | 4 ++-- .../agent/test_conversation_manager.py | 13 ++++++++---- 6 files changed, 43 insertions(+), 30 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 0f912b54b..bfa83fe20 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -165,7 +165,7 @@ def caller(**kwargs: Any) -> Any: self._agent._record_tool_execution(tool_use, tool_result, user_message_override, messages) # Apply window management - self._agent.conversation_manager.apply_management(self._agent.messages) + self._agent.conversation_manager.apply_management(self._agent) return tool_result @@ -439,7 +439,7 @@ def _run_loop( return self._execute_event_loop_cycle(invocation_callback_handler, kwargs) finally: - self.conversation_manager.apply_management(self.messages) + self.conversation_manager.apply_management(self) def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str, Any]) -> AgentResult: """Execute the event loop cycle with retry logic for context window limits. 
@@ -483,7 +483,7 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str except ContextWindowOverflowException as e: # Try reducing the context size and retrying - self.conversation_manager.reduce_context(messages, e=e) + self.conversation_manager.reduce_context(self, e=e) return self._execute_event_loop_cycle(callback_handler_override, kwargs) def _record_tool_execution( diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index d18ae69a3..dbccf9410 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -1,9 +1,10 @@ """Abstract interface for conversation history management.""" from abc import ABC, abstractmethod -from typing import Optional +from typing import TYPE_CHECKING, Optional -from ...types.content import Messages +if TYPE_CHECKING: + from ...agent.agent import Agent class ConversationManager(ABC): @@ -19,22 +20,22 @@ class ConversationManager(ABC): @abstractmethod # pragma: no cover - def apply_management(self, messages: Messages) -> None: - """Applies management strategy to the provided list of messages. + def apply_management(self, agent: "Agent") -> None: + """Applies management strategy to the provided agent. Processes the conversation history to maintain appropriate size by modifying the messages list in-place. Implementations should handle message pruning, summarization, or other size management techniques to keep the conversation context within desired bounds. Args: - messages: The conversation history to manage. + agent: The agent whose conversation history will be managed. This list is modified in-place. 
""" pass @abstractmethod # pragma: no cover - def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: + def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: """Called when the model's context window is exceeded. This method should implement the specific strategy for reducing the window size when a context overflow occurs. @@ -48,7 +49,7 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N - Maintaining critical conversation markers Args: - messages: The conversation history to reduce. + agent: The agent whose conversation history will be reduced. This list is modified in-place. e: The exception that triggered the context reduction, if any. """ diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 2066c08bb..4af4eb788 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -1,8 +1,10 @@ """Null implementation of conversation management.""" -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from ...agent.agent import Agent -from ...types.content import Messages from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager @@ -17,19 +19,19 @@ class NullConversationManager(ConversationManager): - Situations where the full conversation history should be preserved """ - def apply_management(self, messages: Messages) -> None: + def apply_management(self, _agent: "Agent") -> None: """Does nothing to the conversation history. Args: - messages: The conversation history that will remain unmodified. + agent: The agent whose conversation history will remain unmodified. 
""" pass - def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: + def reduce_context(self, _agent: "Agent", e: Optional[Exception] = None) -> None: """Does not reduce context and raises an exception. Args: - messages: The conversation history that will remain unmodified. + agent: The agent whose conversation history will remain unmodified. e: The exception that triggered the context reduction, if any. Raises: diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index f367b2721..3381247cb 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -1,7 +1,10 @@ """Sliding window conversation history management.""" import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from ...agent.agent import Agent from ...types.content import Message, Messages from ...types.exceptions import ContextWindowOverflowException @@ -45,13 +48,13 @@ def __init__(self, window_size: int = 40): """Initialize the sliding window conversation manager. Args: - window_size: Maximum number of messages to keep in history. + window_size: Maximum number of messages to keep in the agent's history. Defaults to 40 messages. """ self.window_size = window_size - def apply_management(self, messages: Messages) -> None: - """Apply the sliding window to the messages array to maintain a manageable history size. + def apply_management(self, agent: "Agent") -> None: + """Apply the sliding window to the agent's messages array to maintain a manageable history size. This method is called after every event loop cycle, as the messages array may have been modified with tool results and assistant responses. 
It first removes any dangling messages that might create an invalid @@ -62,9 +65,10 @@ def apply_management(self, messages: Messages) -> None: blocks to maintain conversation coherence. Args: - messages: The messages to manage. + agent: The agent whose messages will be managed. This list is modified in-place. """ + messages = agent.messages self._remove_dangling_messages(messages) if len(messages) <= self.window_size: @@ -72,7 +76,7 @@ def apply_management(self, messages: Messages) -> None: "window_size=<%s>, message_count=<%s> | skipping context reduction", len(messages), self.window_size ) return - self.reduce_context(messages) + self.reduce_context(agent) def _remove_dangling_messages(self, messages: Messages) -> None: """Remove dangling messages that would create an invalid conversation state. @@ -105,7 +109,7 @@ def _remove_dangling_messages(self, messages: Messages) -> None: if not any("toolResult" in content for content in messages[-1]["content"]): messages.pop() - def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: + def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: """Trim the oldest messages to reduce the conversation context size. The method handles special cases where trimming the messages leads to: @@ -113,7 +117,7 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N - toolUse with no corresponding toolResult Args: - messages: The messages to reduce. + agent: The agent whose messages will be reduced. This list is modified in-place. e: The exception that triggered the context reduction, if any. @@ -122,6 +126,7 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N Such as when the conversation is already minimal or when tool result messages cannot be properly converted. 
""" + messages = agent.messages # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index ea06fb4ee..4a63fa31f 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -318,7 +318,7 @@ def test_agent__call__( ) callback_handler.assert_called() - conversation_manager_spy.apply_management.assert_called_with(agent.messages) + conversation_manager_spy.apply_management.assert_called_with(agent) def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler, agent, tool, mock_event_loop_cycle): @@ -583,7 +583,7 @@ def test_agent_tool(mock_randint, agent): } assert tru_result == exp_result - conversation_manager_spy.apply_management.assert_called_with(agent.messages) + conversation_manager_spy.apply_management.assert_called_with(agent) def test_agent_tool_user_message_override(agent): diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index b6132f1db..bbec3cd11 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -1,6 +1,7 @@ import pytest import strands +from strands.agent.agent import Agent from strands.types.exceptions import ContextWindowOverflowException @@ -160,7 +161,8 @@ def conversation_manager(request): indirect=["conversation_manager"], ) def test_apply_management(conversation_manager, messages, expected_messages): - conversation_manager.apply_management(messages) + test_agent = Agent(messages=messages) + conversation_manager.apply_management(test_agent) assert messages == expected_messages @@ -172,9 +174,10 @@ def test_sliding_window_conversation_manager_with_untrimmable_history_raises_con {"role": "user", "content": [{"toolResult": {"toolUseId": "789", 
"content": [], "status": "success"}}]}, ] original_messages = messages.copy() + test_agent = Agent(messages=messages) with pytest.raises(ContextWindowOverflowException): - manager.apply_management(messages) + manager.apply_management(test_agent) assert messages == original_messages @@ -187,8 +190,9 @@ def test_null_conversation_manager_reduce_context_raises_context_window_overflow {"role": "assistant", "content": [{"text": "Hi there"}]}, ] original_messages = messages.copy() + test_agent = Agent(messages=messages) - manager.apply_management(messages) + manager.apply_management(test_agent) with pytest.raises(ContextWindowOverflowException): manager.reduce_context(messages) @@ -204,8 +208,9 @@ def test_null_conversation_manager_reduce_context_with_exception_raises_same_exc {"role": "assistant", "content": [{"text": "Hi there"}]}, ] original_messages = messages.copy() + test_agent = Agent(messages=messages) - manager.apply_management(messages) + manager.apply_management(test_agent) with pytest.raises(RuntimeError): manager.reduce_context(messages, RuntimeError("test"))