Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,13 @@ def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> M

self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts})
self._interrupt_state.activate()
if isinstance(node.executor, Agent):
self._interrupt_state.context[node.node_id] = {
"activated": node.executor._interrupt_state.activated,
"interrupt_state": node.executor._interrupt_state.to_dict(),
"state": node.executor.state.get(),
"messages": node.executor.messages,
}

return MultiAgentNodeInterruptEvent(node.node_id, interrupts)

Expand Down Expand Up @@ -920,16 +927,6 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
if agent_response is None:
raise ValueError(f"Node '{node.node_id}' did not produce a result event")

if agent_response.stop_reason == "interrupt":
node.executor.messages.pop() # remove interrupted tool use message
node.executor._interrupt_state.deactivate()

raise NotImplementedError(
f"node_id=<{node.node_id}>, "
"issue=<https://github.com/strands-agents/sdk-python/issues/204> "
"| user raised interrupt from an agent node"
)

# Extract metrics with defaults
response_metrics = getattr(agent_response, "metrics", None)
usage = getattr(
Expand All @@ -940,18 +937,24 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
node_result = NodeResult(
result=agent_response,
execution_time=round((time.time() - start_time) * 1000),
status=Status.COMPLETED,
status=Status.INTERRUPTED if agent_response.stop_reason == "interrupt" else Status.COMPLETED,
accumulated_usage=usage,
accumulated_metrics=metrics,
execution_count=1,
interrupts=agent_response.interrupts or [],
)
else:
raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported")

# Mark as completed
node.execution_status = Status.COMPLETED
node.result = node_result
node.execution_time = node_result.execution_time

if node_result.status == Status.INTERRUPTED:
yield self._activate_interrupt(node, node_result.interrupts)
return

# Mark as completed
node.execution_status = Status.COMPLETED
self.state.completed_nodes.add(node)
self.state.results[node.node_id] = node_result
self.state.execution_order.append(node)
Expand Down Expand Up @@ -1018,6 +1021,8 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None:
def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
"""Build input text for a node based on dependency outputs.

If resuming from an interrupt, return user responses.

Example formatted output:
```
Original Task: Analyze the quarterly sales data and create a summary report
Expand All @@ -1032,6 +1037,21 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
- Agent: Data validation complete. All records verified, no anomalies detected.
```
"""
if self._interrupt_state.activated:
context = self._interrupt_state.context
if node.node_id in context and context[node.node_id]["activated"]:
agent_context = context[node.node_id]
agent = cast(Agent, node.executor)
agent.messages = agent_context["messages"]
agent.state = AgentState(agent_context["state"])
agent._interrupt_state = _InterruptState.from_dict(agent_context["interrupt_state"])

responses = context["responses"]
interrupts = agent._interrupt_state.interrupts
return [
response for response in responses if response["interruptResponse"]["interruptId"] in interrupts
]

# Get satisfied dependencies
dependency_results = {}
for edge in self.edges:
Expand Down
98 changes: 98 additions & 0 deletions tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen
agent.id = agent_id or f"{name}_id"
agent._session_manager = None
agent.hooks = HookRegistry()
agent.state = AgentState()
agent.messages = []
agent._interrupt_state = _InterruptState()

if metrics is None:
metrics = Mock(
Expand Down Expand Up @@ -2153,3 +2156,98 @@ def test_graph_interrupt_on_before_node_call_event(interrupt_hook):
assert tru_message == exp_message

assert multiagent_result.execution_time >= first_execution_time


def test_graph_interrupt_on_agent(agenerator):
exp_interrupts = [
Interrupt(
id="test_id",
name="test_name",
reason="test_reason",
)
]

agent = create_mock_agent("test_agent", "Task completed")
agent.stream_async = Mock()
agent.stream_async.return_value = agenerator(
[
{
"result": AgentResult(
message={},
stop_reason="interrupt",
state={},
metrics=None,
interrupts=exp_interrupts,
),
},
],
)

builder = GraphBuilder()
builder.add_node(agent, "test_agent")
graph = builder.build()

multiagent_result = graph("Test task")

tru_result_status = multiagent_result.status
exp_result_status = Status.INTERRUPTED
assert tru_result_status == exp_result_status

tru_state_status = graph.state.status
exp_state_status = Status.INTERRUPTED
assert tru_state_status == exp_state_status

tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes]
exp_node_ids = ["test_agent"]
assert tru_node_ids == exp_node_ids

tru_interrupts = multiagent_result.interrupts
assert tru_interrupts == exp_interrupts

interrupt = multiagent_result.interrupts[0]

agent.stream_async = Mock()
agent.stream_async.return_value = agenerator(
[
{
"result": AgentResult(
message={},
stop_reason="end_turn",
state={},
metrics=None,
),
},
],
)
graph._interrupt_state.context["test_agent"] = {
"activated": True,
"interrupt_state": {
"activated": True,
"context": {},
"interrupts": {interrupt.id: interrupt.to_dict()},
},
"messages": [],
"state": {},
}

responses = [
{
"interruptResponse": {
"interruptId": interrupt.id,
"response": "test_response",
},
},
]
multiagent_result = graph(responses)

tru_result_status = multiagent_result.status
exp_result_status = Status.COMPLETED
assert tru_result_status == exp_result_status

tru_state_status = graph.state.status
exp_state_status = Status.COMPLETED
assert tru_state_status == exp_state_status

assert len(multiagent_result.results) == 1

agent.stream_async.assert_called_once_with(responses, invocation_state={})
129 changes: 124 additions & 5 deletions tests_integ/interrupts/multiagent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,83 @@

from strands import Agent, tool
from strands.interrupt import Interrupt
from strands.multiagent import Swarm
from strands.multiagent import GraphBuilder, Swarm
from strands.multiagent.base import Status
from strands.types.tools import ToolContext


@pytest.fixture
def day_tool():
@tool(name="day_tool", context=True)
def func(tool_context: ToolContext) -> str:
response = tool_context.interrupt("day_interrupt", reason="need day")
return response

return func


@pytest.fixture
def time_tool():
@tool(name="time_tool")
def func():
return "12:01"

return func


@pytest.fixture
def weather_tool():
@tool(name="weather_tool", context=True)
def func(tool_context: ToolContext) -> str:
response = tool_context.interrupt("test_interrupt", reason="need weather")
response = tool_context.interrupt("weather_interrupt", reason="need weather")
return response

return func


@pytest.fixture
def swarm(weather_tool):
weather_agent = Agent(name="weather", tools=[weather_tool])
def info_agent():
return Agent(name="info")


@pytest.fixture
def day_agent(day_tool):
return Agent(name="day", tools=[day_tool])


@pytest.fixture
def time_agent(time_tool):
return Agent(name="time", tools=[time_tool])


@pytest.fixture
def weather_agent(weather_tool):
return Agent(name="weather", tools=[weather_tool])


@pytest.fixture
def swarm(weather_agent):
return Swarm([weather_agent])


@pytest.fixture
def graph(info_agent, day_agent, time_agent, weather_agent):
builder = GraphBuilder()

builder.add_node(info_agent, "info")
builder.add_node(day_agent, "day")
builder.add_node(time_agent, "time")
builder.add_node(weather_agent, "weather")

builder.add_edge("info", "day")
builder.add_edge("info", "time")
builder.add_edge("info", "weather")

builder.set_entry_point("info")

return builder.build()


def test_swarm_interrupt_agent(swarm):
multiagent_result = swarm("What is the weather?")

Expand All @@ -38,7 +93,7 @@ def test_swarm_interrupt_agent(swarm):
exp_interrupts = [
Interrupt(
id=ANY,
name="test_interrupt",
name="weather_interrupt",
reason="need weather",
),
]
Expand All @@ -65,3 +120,67 @@ def test_swarm_interrupt_agent(swarm):

weather_message = json.dumps(weather_result.result.message).lower()
assert "sunny" in weather_message


def test_graph_interrupt_agent(graph):
multiagent_result = graph("What is the day, time, and weather?")

tru_result_status = multiagent_result.status
exp_result_status = Status.INTERRUPTED
assert tru_result_status == exp_result_status

tru_state_status = graph.state.status
exp_state_status = Status.INTERRUPTED
assert tru_state_status == exp_state_status

tru_node_ids = sorted([node.node_id for node in graph.state.interrupted_nodes])
exp_node_ids = ["day", "weather"]
assert tru_node_ids == exp_node_ids

tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name)
exp_interrupts = [
Interrupt(
id=ANY,
name="day_interrupt",
reason="need day",
),
Interrupt(
id=ANY,
name="weather_interrupt",
reason="need weather",
),
]
assert tru_interrupts == exp_interrupts

responses = [
{
"interruptResponse": {
"interruptId": tru_interrupts[0].id,
"response": "monday",
},
},
{
"interruptResponse": {
"interruptId": tru_interrupts[1].id,
"response": "sunny",
},
},
]
multiagent_result = graph(responses)

tru_result_status = multiagent_result.status
exp_result_status = Status.COMPLETED
assert tru_result_status == exp_result_status

tru_state_status = graph.state.status
exp_state_status = Status.COMPLETED
assert tru_state_status == exp_state_status

assert len(multiagent_result.results) == 4

day_message = json.dumps(multiagent_result.results["day"].result.message).lower()
time_message = json.dumps(multiagent_result.results["time"].result.message).lower()
weather_message = json.dumps(multiagent_result.results["weather"].result.message).lower()
assert "monday" in day_message
assert "12:01" in time_message
assert "sunny" in weather_message
Loading
Loading