diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 7126644e6..6df775d20 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -644,7 +644,10 @@ async def _run_loop( Yields: Events from the event loop cycle. """ - await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) + before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async( + BeforeInvocationEvent(agent=self, messages=messages) + ) + messages = before_invocation_event.messages if before_invocation_event.messages is not None else messages agent_result: AgentResult | None = None try: diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 5e11524d1..70764e342 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from ..agent.agent_result import AgentResult -from ..types.content import Message +from ..types.content import Message, Messages from ..types.interrupt import _Interruptible from ..types.streaming import StopReason from ..types.tools import AgentTool, ToolResult, ToolUse @@ -43,9 +43,16 @@ class BeforeInvocationEvent(HookEvent): - Agent.__call__ - Agent.stream_async - Agent.structured_output + + Attributes: + messages: The input messages for this invocation. Can be modified by hooks + to redact or transform content before processing. """ - pass + messages: Messages | None = None + + def _can_write(self, name: str) -> bool: + return name == "messages" @dataclass diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index 9203478b2..83cb1af24 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -11,7 +11,7 @@ BeforeToolCallEvent, MessageAddedEvent, ) -from strands.types.content import Message +from strands.types.content import Message, Messages from strands.types.tools import ToolResult, ToolUse @@ -20,6 +20,11 @@ def agent(): return Mock() +@pytest.fixture +def sample_messages() -> Messages: + return [{"role": "user", "content": [{"text": "Hello, agent!"}]}] + + @pytest.fixture def tool(): tool = Mock() @@ -52,6 +57,11 @@ def start_request_event(agent): return BeforeInvocationEvent(agent=agent) +@pytest.fixture +def start_request_event_with_messages(agent, sample_messages): + return BeforeInvocationEvent(agent=agent, messages=sample_messages) + + @pytest.fixture def messaged_added_event(agent): return MessageAddedEvent(agent=agent, message=Mock()) @@ -159,3 +169,29 @@ def test_after_invocation_event_properties_not_writable(agent): with pytest.raises(AttributeError, match="Property agent is not writable"): event.agent = Mock() + + +def test_before_invocation_event_messages_default_none(agent): + """Test that BeforeInvocationEvent.messages defaults to None for backward compatibility.""" + event = BeforeInvocationEvent(agent=agent) + assert event.messages is None + + +def test_before_invocation_event_messages_writable(agent, sample_messages): + """Test that BeforeInvocationEvent.messages can be modified in-place for guardrail redaction.""" + event = BeforeInvocationEvent(agent=agent, messages=sample_messages) + + # Should be able to modify the messages list in-place + event.messages[0]["content"] = [{"text": "[REDACTED]"}] + assert event.messages[0]["content"] == [{"text": "[REDACTED]"}] + + # Should be able to reassign messages entirely + new_messages: Messages = [{"role": "user", "content": [{"text": "Different message"}]}] + event.messages = new_messages + assert event.messages == new_messages + + +def test_before_invocation_event_agent_not_writable(start_request_event_with_messages): + """Test that BeforeInvocationEvent.agent is not writable.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + start_request_event_with_messages.agent = Mock() diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 00b9d368a..be71b5fcf 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -160,7 +160,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1]) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], @@ -214,7 +214,11 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m """Verify that the correct hook events are emitted as part of stream_async.""" iterator = agent.stream_async("test message") await anext(iterator) - assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent)] + + # Verify first event is BeforeInvocationEvent with messages + assert len(hook_provider.events_received) == 1 + assert hook_provider.events_received[0].messages is not None + assert hook_provider.events_received[0].messages[0]["role"] == "user" # iterate the rest result = None @@ -226,7 +230,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m assert length == 12 - assert next(events) == BeforeInvocationEvent(agent=agent) + assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1]) assert next(events) == MessageAddedEvent( agent=agent, message=agent.messages[0], @@ -596,3 +600,89 @@ async def handle_after_model_call(event: AfterModelCallEvent): # Should succeed after: custom retry + 2 throttle retries assert result.stop_reason == "end_turn" assert result.message["content"][0]["text"] == "Success after mixed retries" + + +def test_before_invocation_event_message_modification(): + """Test that hooks can modify messages in BeforeInvocationEvent for input guardrails.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "I received your redacted message"}], + }, + ] + ) + + modified_content = None + + async def input_guardrail_hook(event: BeforeInvocationEvent): + """Simulates a guardrail that redacts sensitive content.""" + nonlocal modified_content + if event.messages is not None: + for message in event.messages: + if message.get("role") == "user": + content = message.get("content", []) + for block in content: + if "text" in block and "SECRET" in block["text"]: + # Redact sensitive content in-place + block["text"] = block["text"].replace("SECRET", "[REDACTED]") + modified_content = event.messages[0]["content"][0]["text"] + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, input_guardrail_hook) + + agent("My password is SECRET123") + + # Verify the message was modified before being processed + assert modified_content == "My password is [REDACTED]123" + # Verify the modified message was added to agent's conversation history + assert agent.messages[0]["content"][0]["text"] == "My password is [REDACTED]123" + + +def test_before_invocation_event_message_overwrite(): + """Test that hooks can overwrite messages in BeforeInvocationEvent.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "I received your message message"}], + }, + ] + ) + + async def overwrite_input_hook(event: BeforeInvocationEvent): + event.messages = [{"role": "user", "content": [{"text": "GOODBYE"}]}] + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, overwrite_input_hook) + + agent("HELLO") + + # Verify the message was overwritten to agent's conversation history + assert agent.messages[0]["content"][0]["text"] == "GOODBYE" + + +@pytest.mark.asyncio +async def test_before_invocation_event_messages_none_in_structured_output(agenerator): + """Test that BeforeInvocationEvent.messages is None when called from deprecated structured_output.""" + + class Person(BaseModel): + name: str + age: int + + mock_provider = MockedModelProvider([]) + mock_provider.structured_output = Mock(return_value=agenerator([{"output": Person(name="Test", age=30)}])) + + received_messages = "not_set" + + async def capture_messages_hook(event: BeforeInvocationEvent): + nonlocal received_messages + received_messages = event.messages + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeInvocationEvent, capture_messages_hook) + + await agent.structured_output_async(Person, "Test prompt") + + # structured_output_async uses deprecated path that doesn't pass messages + assert received_messages is None