Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,10 @@ async def _run_loop(
Yields:
Events from the event loop cycle.
"""
await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self))
before_invocation_event, _interrupts = await self.hooks.invoke_callbacks_async(
BeforeInvocationEvent(agent=self, messages=messages)
)
messages = before_invocation_event.messages if before_invocation_event.messages is not None else messages

agent_result: AgentResult | None = None
try:
Expand Down
11 changes: 9 additions & 2 deletions src/strands/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
if TYPE_CHECKING:
from ..agent.agent_result import AgentResult

from ..types.content import Message
from ..types.content import Message, Messages
from ..types.interrupt import _Interruptible
from ..types.streaming import StopReason
from ..types.tools import AgentTool, ToolResult, ToolUse
Expand Down Expand Up @@ -43,9 +43,16 @@ class BeforeInvocationEvent(HookEvent):
- Agent.__call__
- Agent.stream_async
- Agent.structured_output

Attributes:
messages: The input messages for this invocation. Can be modified by hooks
to redact or transform content before processing.
"""

pass
messages: Messages | None = None

def _can_write(self, name: str) -> bool:
return name == "messages"


@dataclass
Expand Down
38 changes: 37 additions & 1 deletion tests/strands/agent/hooks/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
BeforeToolCallEvent,
MessageAddedEvent,
)
from strands.types.content import Message
from strands.types.content import Message, Messages
from strands.types.tools import ToolResult, ToolUse


Expand All @@ -20,6 +20,11 @@ def agent():
return Mock()


@pytest.fixture
def sample_messages() -> Messages:
return [{"role": "user", "content": [{"text": "Hello, agent!"}]}]


@pytest.fixture
def tool():
tool = Mock()
Expand Down Expand Up @@ -52,6 +57,11 @@ def start_request_event(agent):
return BeforeInvocationEvent(agent=agent)


@pytest.fixture
def start_request_event_with_messages(agent, sample_messages):
return BeforeInvocationEvent(agent=agent, messages=sample_messages)


@pytest.fixture
def messaged_added_event(agent):
return MessageAddedEvent(agent=agent, message=Mock())
Expand Down Expand Up @@ -159,3 +169,29 @@ def test_after_invocation_event_properties_not_writable(agent):

with pytest.raises(AttributeError, match="Property agent is not writable"):
event.agent = Mock()


def test_before_invocation_event_messages_default_none(agent):
"""Test that BeforeInvocationEvent.messages defaults to None for backward compatibility."""
event = BeforeInvocationEvent(agent=agent)
assert event.messages is None


def test_before_invocation_event_messages_writable(agent, sample_messages):
"""Test that BeforeInvocationEvent.messages can be modified in-place for guardrail redaction."""
event = BeforeInvocationEvent(agent=agent, messages=sample_messages)

# Should be able to modify the messages list in-place
event.messages[0]["content"] = [{"text": "[REDACTED]"}]
assert event.messages[0]["content"] == [{"text": "[REDACTED]"}]

# Should be able to reassign messages entirely
new_messages: Messages = [{"role": "user", "content": [{"text": "Different message"}]}]
event.messages = new_messages
assert event.messages == new_messages


def test_before_invocation_event_agent_not_writable(start_request_event_with_messages):
"""Test that BeforeInvocationEvent.agent is not writable."""
with pytest.raises(AttributeError, match="Property agent is not writable"):
start_request_event_with_messages.agent = Mock()
96 changes: 93 additions & 3 deletions tests/strands/agent/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, mock_model, tool_u

assert length == 12

assert next(events) == BeforeInvocationEvent(agent=agent)
assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1])
assert next(events) == MessageAddedEvent(
agent=agent,
message=agent.messages[0],
Expand Down Expand Up @@ -214,7 +214,11 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m
"""Verify that the correct hook events are emitted as part of stream_async."""
iterator = agent.stream_async("test message")
await anext(iterator)
assert hook_provider.events_received == [BeforeInvocationEvent(agent=agent)]

# Verify first event is BeforeInvocationEvent with messages
assert len(hook_provider.events_received) == 1
assert hook_provider.events_received[0].messages is not None
assert hook_provider.events_received[0].messages[0]["role"] == "user"

# iterate the rest
result = None
Expand All @@ -226,7 +230,7 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, mock_m

assert length == 12

assert next(events) == BeforeInvocationEvent(agent=agent)
assert next(events) == BeforeInvocationEvent(agent=agent, messages=agent.messages[0:1])
assert next(events) == MessageAddedEvent(
agent=agent,
message=agent.messages[0],
Expand Down Expand Up @@ -596,3 +600,89 @@ async def handle_after_model_call(event: AfterModelCallEvent):
# Should succeed after: custom retry + 2 throttle retries
assert result.stop_reason == "end_turn"
assert result.message["content"][0]["text"] == "Success after mixed retries"


def test_before_invocation_event_message_modification():
"""Test that hooks can modify messages in BeforeInvocationEvent for input guardrails."""
mock_provider = MockedModelProvider(
[
{
"role": "assistant",
"content": [{"text": "I received your redacted message"}],
},
]
)

modified_content = None

async def input_guardrail_hook(event: BeforeInvocationEvent):
"""Simulates a guardrail that redacts sensitive content."""
nonlocal modified_content
if event.messages is not None:
for message in event.messages:
if message.get("role") == "user":
content = message.get("content", [])
for block in content:
if "text" in block and "SECRET" in block["text"]:
# Redact sensitive content in-place
block["text"] = block["text"].replace("SECRET", "[REDACTED]")
modified_content = event.messages[0]["content"][0]["text"]

agent = Agent(model=mock_provider)
agent.hooks.add_callback(BeforeInvocationEvent, input_guardrail_hook)

agent("My password is SECRET123")

# Verify the message was modified before being processed
assert modified_content == "My password is [REDACTED]123"
# Verify the modified message was added to agent's conversation history
assert agent.messages[0]["content"][0]["text"] == "My password is [REDACTED]123"


def test_before_invocation_event_message_overwrite():
"""Test that hooks can overwrite messages in BeforeInvocationEvent."""
mock_provider = MockedModelProvider(
[
{
"role": "assistant",
"content": [{"text": "I received your message message"}],
},
]
)

async def overwrite_input_hook(event: BeforeInvocationEvent):
event.messages = [{"role": "user", "content": [{"text": "GOODBYE"}]}]

agent = Agent(model=mock_provider)
agent.hooks.add_callback(BeforeInvocationEvent, overwrite_input_hook)

agent("HELLO")

# Verify the message was overwritten to agent's conversation history
assert agent.messages[0]["content"][0]["text"] == "GOODBYE"


@pytest.mark.asyncio
async def test_before_invocation_event_messages_none_in_structured_output(agenerator):
"""Test that BeforeInvocationEvent.messages is None when called from deprecated structured_output."""

class Person(BaseModel):
name: str
age: int

mock_provider = MockedModelProvider([])
mock_provider.structured_output = Mock(return_value=agenerator([{"output": Person(name="Test", age=30)}]))

received_messages = "not_set"

async def capture_messages_hook(event: BeforeInvocationEvent):
nonlocal received_messages
received_messages = event.messages

agent = Agent(model=mock_provider)
agent.hooks.add_callback(BeforeInvocationEvent, capture_messages_hook)

await agent.structured_output_async(Person, "Test prompt")

# structured_output_async uses deprecated path that doesn't pass messages
assert received_messages is None
Loading