diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 6156d332c..d50b67e48 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -34,11 +34,13 @@ MultiAgentInitializedEvent, ) from ..hooks import HookProvider, HookRegistry +from ..interrupt import Interrupt, _InterruptState from ..session import SessionManager from ..telemetry import get_tracer from ..types._events import ( MultiAgentHandoffEvent, MultiAgentNodeCancelEvent, + MultiAgentNodeInterruptEvent, MultiAgentNodeStartEvent, MultiAgentNodeStopEvent, MultiAgentNodeStreamEvent, @@ -63,10 +65,15 @@ class GraphState: status: Current execution status of the graph. completed_nodes: Set of nodes that have completed execution. failed_nodes: Set of nodes that failed during execution. + interrupted_nodes: Set of nodes that user interrupted during execution. execution_order: List of nodes in the order they were executed. task: The original input prompt/query provided to the graph execution. This represents the actual work to be performed by the graph as a whole. Entry point nodes receive this task as their input if they have no dependencies. + start_time: Timestamp when the current invocation started. + Resets on each invocation, even when resuming from interrupt. + execution_time: Execution time of current invocation in milliseconds. + Excludes time spent waiting for interrupt responses. """ # Task (with default empty string) @@ -76,6 +83,7 @@ class GraphState: status: Status = Status.PENDING completed_nodes: set["GraphNode"] = field(default_factory=set) failed_nodes: set["GraphNode"] = field(default_factory=set) + interrupted_nodes: set["GraphNode"] = field(default_factory=set) execution_order: list["GraphNode"] = field(default_factory=list) start_time: float = field(default_factory=time.time) @@ -108,7 +116,7 @@ def should_continue( # Check timeout (only if set) if execution_timeout is not None: - elapsed = time.time() - self.start_time + elapsed = self.execution_time / 1000 + time.time() - self.start_time if elapsed > execution_timeout: return False, f"Execution timed out: {execution_timeout}s" @@ -122,6 +130,7 @@ class GraphResult(MultiAgentResult): total_nodes: int = 0 completed_nodes: int = 0 failed_nodes: int = 0 + interrupted_nodes: int = 0 execution_order: list["GraphNode"] = field(default_factory=list) edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list) entry_points: list["GraphNode"] = field(default_factory=list) @@ -148,13 +157,7 @@ def should_traverse(self, state: GraphState) -> bool: @dataclass class GraphNode: - """Represents a node in the graph. - - The execution_status tracks the node's lifecycle within graph orchestration: - - PENDING: Node hasn't started executing yet - - EXECUTING: Node is currently running - - COMPLETED/FAILED: Node finished executing (regardless of result quality) - """ + """Represents a node in the graph.""" node_id: str executor: Agent | MultiAgentBase @@ -445,6 +448,7 @@ def __init__( self.node_timeout = node_timeout self.reset_on_revisit = reset_on_revisit self.state = GraphState() + self._interrupt_state = _InterruptState() self.tracer = get_tracer() self.trace_attributes: dict[str, AttributeValue] = self._parse_trace_attributes(trace_attributes) self.session_manager = session_manager @@ -519,6 +523,8 @@ async def stream_async( - multi_agent_node_stop: When a node stops execution - result: Final graph result """ + self._interrupt_state.resume(task) + if invocation_state is None: invocation_state = {} @@ -528,7 +534,7 @@ async def stream_async( # Initialize state start_time = time.time() - if not self._resume_from_session: + if not self._resume_from_session and not self._interrupt_state.activated: # Initialize state self.state = GraphState( status=Status.EXECUTING, @@ -544,6 +550,8 @@ async def stream_async( span = self.tracer.start_multiagent_span(task, "graph", custom_trace_attributes=self.trace_attributes) with trace_api.use_span(span, end_on_exit=True): + interrupts = [] + try: logger.debug( "max_node_executions=<%s>, execution_timeout=<%s>s, node_timeout=<%s>s | graph execution config", @@ -553,6 +561,9 @@ async def stream_async( ) async for event in self._execute_graph(invocation_state): + if isinstance(event, MultiAgentNodeInterruptEvent): + interrupts.extend(event.interrupts) + yield event.as_dict() # Set final status based on execution results @@ -564,7 +575,7 @@ async def stream_async( logger.debug("status=<%s> | graph execution completed", self.state.status) # Yield final result (consistent with Agent's AgentResultEvent format) - result = self._build_result() + result = self._build_result(interrupts) # Use the same event format as Agent for consistency yield MultiAgentResultEvent(result=result).as_dict() @@ -574,7 +585,7 @@ async def stream_async( self.state.status = Status.FAILED raise finally: - self.state.execution_time = round((time.time() - start_time) * 1000) + self.state.execution_time += round((time.time() - start_time) * 1000) await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self)) self._resume_from_session = False self._resume_next_nodes.clear() @@ -591,9 +602,41 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: # Validate Agent-specific constraints for each node _validate_node_executor(node.executor) + def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent: + """Activate the interrupt state. + + Args: + node: The interrupted node. + interrupts: The interrupts raised by the user. + + Returns: + MultiAgentNodeInterruptEvent + """ + logger.debug("node=<%s> | node interrupted", node.node_id) + + node.execution_status = Status.INTERRUPTED + + self.state.status = Status.INTERRUPTED + self.state.interrupted_nodes.add(node) + + self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts}) + self._interrupt_state.activate() + + return MultiAgentNodeInterruptEvent(node.node_id, interrupts) + async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: """Execute graph and yield TypedEvent objects.""" - ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points) + if self._interrupt_state.activated: + ready_nodes = [self.nodes[node_id] for node_id in self._interrupt_state.context["completed_nodes"]] + ready_nodes.extend(self.state.interrupted_nodes) + + self.state.interrupted_nodes.clear() + + elif self._resume_from_session: + ready_nodes = self._resume_next_nodes + + else: + ready_nodes = list(self.entry_points) while ready_nodes: # Check execution limits before continuing @@ -613,6 +656,14 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato async for event in self._execute_nodes_parallel(current_batch, invocation_state): yield event + if self.state.status == Status.INTERRUPTED: + self._interrupt_state.context["completed_nodes"] = [ + node.node_id for node in current_batch if node.execution_status == Status.COMPLETED + ] + return + + self._interrupt_state.deactivate() + # Find newly ready nodes after batch execution # We add all nodes in current batch as completed batch, # because a failure would throw exception and code would not make it here @@ -641,6 +692,9 @@ async def _execute_nodes_parallel( Uses a shared queue where each node's stream runs independently and pushes events as they occur, enabling true real-time event propagation without round-robin delays. """ + if self._interrupt_state.activated: + nodes = [node for node in nodes if node.execution_status == Status.INTERRUPTED] + event_queue: asyncio.Queue[Any | None | Exception] = asyncio.Queue() # Start all node streams as independent tasks @@ -797,12 +851,16 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) ) yield start_event - before_event, _ = await self.hooks.invoke_callbacks_async( + before_event, interrupts = await self.hooks.invoke_callbacks_async( BeforeNodeCallEvent(self, node.node_id, invocation_state) ) start_time = time.time() try: + if interrupts: + yield self._activate_interrupt(node, interrupts) + return + if before_event.cancel_node: cancel_message = ( before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user" @@ -830,6 +888,13 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if multi_agent_result is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") + if multi_agent_result.status == Status.INTERRUPTED: + raise NotImplementedError( + f"node_id=<{node.node_id}>, " + "issue= " + "| user raised interrupt from a multi agent node" + ) + node_result = NodeResult( result=multi_agent_result, execution_time=multi_agent_result.execution_time, @@ -854,12 +919,15 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) if agent_response is None: raise ValueError(f"Node '{node.node_id}' did not produce a result event") - # Check for interrupt (from main branch) if agent_response.stop_reason == "interrupt": node.executor.messages.pop() # remove interrupted tool use message node.executor._interrupt_state.deactivate() - raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in graphs") + raise NotImplementedError( + f"node_id=<{node.node_id}>, " + "issue= " + "| user raised interrupt from an agent node" + ) # Extract metrics with defaults response_metrics = getattr(agent_response, "metrics", None) @@ -1006,8 +1074,15 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: return node_input - def _build_result(self) -> GraphResult: - """Build graph result from current state.""" + def _build_result(self, interrupts: list[Interrupt]) -> GraphResult: + """Build graph result from current state. + + Args: + interrupts: List of interrupts collected during execution. + + Returns: + GraphResult with current state. + """ return GraphResult( status=self.state.status, results=self.state.results, @@ -1018,9 +1093,11 @@ def _build_result(self) -> GraphResult: total_nodes=self.state.total_nodes, completed_nodes=len(self.state.completed_nodes), failed_nodes=len(self.state.failed_nodes), + interrupted_nodes=len(self.state.interrupted_nodes), execution_order=self.state.execution_order, edges=self.state.edges, entry_points=self.state.entry_points, + interrupts=interrupts, ) def serialize_state(self) -> dict[str, Any]: @@ -1033,10 +1110,14 @@ def serialize_state(self) -> dict[str, Any]: "status": self.state.status.value, "completed_nodes": [n.node_id for n in self.state.completed_nodes], "failed_nodes": [n.node_id for n in self.state.failed_nodes], + "interrupted_nodes": [n.node_id for n in self.state.interrupted_nodes], "node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()}, "next_nodes_to_execute": next_nodes, "current_task": self.state.task, "execution_order": [n.node_id for n in self.state.execution_order], + "_internal_state": { + "interrupt_state": self._interrupt_state.to_dict(), + }, } def deserialize_state(self, payload: dict[str, Any]) -> None: @@ -1052,6 +1133,10 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: payload: Dictionary containing persisted state data including status, completed nodes, results, and next nodes to execute. """ + if "_internal_state" in payload: + internal_state = payload["_internal_state"] + self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"]) + if not payload.get("next_nodes_to_execute"): # Reset all nodes for node in self.nodes.values(): @@ -1098,10 +1183,20 @@ def _from_dict(self, payload: dict[str, Any]) -> None: self.state.failed_nodes = set( self.nodes[node_id] for node_id in (payload.get("failed_nodes") or []) if node_id in self.nodes ) + for node in self.state.failed_nodes: + node.execution_status = Status.FAILED - # Restore completed nodes from persisted data - completed_node_ids = payload.get("completed_nodes") or [] - self.state.completed_nodes = {self.nodes[node_id] for node_id in completed_node_ids if node_id in self.nodes} + self.state.interrupted_nodes = set( + self.nodes[node_id] for node_id in (payload.get("interrupted_nodes") or []) if node_id in self.nodes + ) + for node in self.state.interrupted_nodes: + node.execution_status = Status.INTERRUPTED + + self.state.completed_nodes = set( + self.nodes[node_id] for node_id in (payload.get("completed_nodes") or []) if node_id in self.nodes + ) + for node in self.state.completed_nodes: + node.execution_status = Status.COMPLETED # Execution order (only nodes that still exist) order_node_ids = payload.get("execution_order") or [] diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 4875d1bec..ab2d86e70 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1,6 +1,6 @@ import asyncio import time -from unittest.mock import AsyncMock, MagicMock, Mock, call, patch +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, call, patch import pytest @@ -9,6 +9,7 @@ from strands.experimental.hooks.multiagent import BeforeNodeCallEvent from strands.hooks import AgentInitializedEvent from strands.hooks.registry import HookProvider, HookRegistry +from strands.interrupt import Interrupt, _InterruptState from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status from strands.session.file_session_manager import FileSessionManager @@ -2004,6 +2005,9 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): state = graph.serialize_state() assert state["type"] == "graph" assert state["id"] == "default_graph" + assert state["_internal_state"] == { + "interrupt_state": {"activated": False, "context": {}, "interrupts": {}}, + } assert "status" in state assert "completed_nodes" in state assert "node_results" in state @@ -2013,14 +2017,33 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): "status": "executing", "completed_nodes": [], "failed_nodes": [], + "interrupted_nodes": [], "node_results": {}, "current_task": "persisted task", "execution_order": [], "next_nodes_to_execute": ["test_node"], + "_internal_state": { + "interrupt_state": { + "activated": False, + "context": {"a": 1}, + "interrupts": { + "i1": { + "id": "i1", + "name": "test_name", + "reason": "test_reason", + }, + }, + }, + }, } graph.deserialize_state(persisted_state) assert graph.state.task == "persisted task" + assert graph._interrupt_state == _InterruptState( + activated=False, + context={"a": 1}, + interrupts={"i1": Interrupt(id="i1", name="test_name", reason="test_reason")}, + ) # Execute graph to test persistence integration result = await graph.invoke_async("Test persistence") @@ -2068,3 +2091,66 @@ def cancel_callback(event): tru_status = graph.state.status exp_status = Status.FAILED assert tru_status == exp_status + + +def test_graph_interrupt_on_before_node_call_event(interrupt_hook): + agent = create_mock_agent("test_agent", "Task completed") + + builder = GraphBuilder() + builder.add_node(agent, "test_agent") + builder.set_hook_providers([interrupt_hook]) + graph = builder.build() + + multiagent_result = graph("Test task") + + first_execution_time = multiagent_result.execution_time + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = [node.node_id for node in graph.state.interrupted_nodes] + exp_node_ids = ["test_agent"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_name", + reason="test_reason", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "test_response", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 1 + agent_result = multiagent_result.results["test_agent"] + + tru_message = agent_result.result.message["content"][0]["text"] + exp_message = "Task completed" + assert tru_message == exp_message + + assert multiagent_result.execution_time >= first_execution_time diff --git a/tests_integ/interrupts/multiagent/test_hook.py b/tests_integ/interrupts/multiagent/test_hook.py index be7682082..9350b3535 100644 --- a/tests_integ/interrupts/multiagent/test_hook.py +++ b/tests_integ/interrupts/multiagent/test_hook.py @@ -7,7 +7,7 @@ from strands.experimental.hooks.multiagent import BeforeNodeCallEvent from strands.hooks import HookProvider from strands.interrupt import Interrupt -from strands.multiagent import Swarm +from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status @@ -18,16 +18,34 @@ def register_hooks(self, registry): registry.add_callback(BeforeNodeCallEvent, self.interrupt) def interrupt(self, event): - if event.node_id == "info": + if event.node_id == "info" or event.node_id == "time": return - response = event.interrupt("test_interrupt", reason="need approval") + response = event.interrupt(f"{event.node_id}_interrupt", reason="need approval") if response != "APPROVE": event.cancel_node = "node rejected" return Hook() +@pytest.fixture +def day_tool(): + @tool(name="day_tool") + def func(): + return "monday" + + return func + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:01" + + return func + + @pytest.fixture def weather_tool(): @tool(name="weather_tool") @@ -38,13 +56,49 @@ def func(): @pytest.fixture -def swarm(interrupt_hook, weather_tool): - info_agent = Agent(name="info") - weather_agent = Agent(name="weather", tools=[weather_tool]) +def info_agent(): + return Agent(name="info") + +@pytest.fixture +def day_agent(day_tool): + return Agent(name="day", tools=[day_tool]) + + +@pytest.fixture +def time_agent(time_tool): + return Agent(name="time", tools=[time_tool]) + + +@pytest.fixture +def weather_agent(weather_tool): + return Agent(name="weather", tools=[weather_tool]) + + +@pytest.fixture +def swarm(interrupt_hook, info_agent, weather_agent): return Swarm([info_agent, weather_agent], hooks=[interrupt_hook]) +@pytest.fixture +def graph(interrupt_hook, info_agent, day_agent, time_agent, weather_agent): + builder = GraphBuilder() + + builder.add_node(info_agent, "info") + builder.add_node(day_agent, "day") + builder.add_node(time_agent, "time") + builder.add_node(weather_agent, "weather") + + builder.add_edge("info", "day") + builder.add_edge("info", "time") + builder.add_edge("info", "weather") + + builder.set_entry_point("info") + builder.set_hook_providers([interrupt_hook]) + + return builder.build() + + def test_swarm_interrupt(swarm): multiagent_result = swarm("What is the weather?") @@ -56,7 +110,7 @@ def test_swarm_interrupt(swarm): exp_interrupts = [ Interrupt( id=ANY, - name="test_interrupt", + name="weather_interrupt", reason="need approval", ), ] @@ -97,7 +151,7 @@ async def test_swarm_interrupt_reject(swarm): exp_interrupts = [ Interrupt( id=ANY, - name="test_interrupt", + name="weather_interrupt", reason="need approval", ), ] @@ -131,3 +185,120 @@ async def test_swarm_interrupt_reject(swarm): tru_node_id = multiagent_result.node_history[0].node_id exp_node_id = "info" assert tru_node_id == exp_node_id + + +def test_graph_interrupt(graph): + multiagent_result = graph("What is the day, time, and weather?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_node_ids = sorted([node.node_id for node in graph.state.interrupted_nodes]) + exp_node_ids = ["day", "weather"] + assert tru_node_ids == exp_node_ids + + tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name) + exp_interrupts = [ + Interrupt( + id=ANY, + name="day_interrupt", + reason="need approval", + ), + Interrupt( + id=ANY, + name="weather_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + } + for interrupt in multiagent_result.interrupts + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 4 + + day_message = json.dumps(multiagent_result.results["day"].result.message).lower() + time_message = json.dumps(multiagent_result.results["time"].result.message).lower() + weather_message = json.dumps(multiagent_result.results["weather"].result.message).lower() + assert "monday" in day_message + assert "12:01" in time_message + assert "sunny" in weather_message + + +@pytest.mark.asyncio +async def test_graph_interrupt_reject(graph): + multiagent_result = graph("What is the day, time, and weather?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_interrupts = sorted(multiagent_result.interrupts, key=lambda interrupt: interrupt.name) + exp_interrupts = [ + Interrupt( + id=ANY, + name="day_interrupt", + reason="need approval", + ), + Interrupt( + id=ANY, + name="weather_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + responses = [ + { + "interruptResponse": { + "interruptId": tru_interrupts[0].id, + "response": "APPROVE", + }, + }, + { + "interruptResponse": { + "interruptId": tru_interrupts[1].id, + "response": "REJECT", + }, + }, + ] + + try: + async for event in graph.stream_async(responses): + if event.get("type") == "multiagent_node_cancel": + tru_cancel_id = event["node_id"] + + except RuntimeError as e: + assert "node rejected" in str(e) + + exp_cancel_id = "weather" + assert tru_cancel_id == exp_cancel_id + + tru_state_status = graph.state.status + exp_state_status = Status.FAILED + assert tru_state_status == exp_state_status diff --git a/tests_integ/interrupts/multiagent/test_session.py b/tests_integ/interrupts/multiagent/test_session.py index d6e8cdbf8..bab4b428f 100644 --- a/tests_integ/interrupts/multiagent/test_session.py +++ b/tests_integ/interrupts/multiagent/test_session.py @@ -4,13 +4,30 @@ import pytest from strands import Agent, tool +from strands.experimental.hooks.multiagent import BeforeNodeCallEvent +from strands.hooks import HookProvider from strands.interrupt import Interrupt -from strands.multiagent import Swarm +from strands.multiagent import GraphBuilder, Swarm from strands.multiagent.base import Status from strands.session import FileSessionManager from strands.types.tools import ToolContext +@pytest.fixture +def interrupt_hook(): + class Hook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BeforeNodeCallEvent, self.interrupt) + + def interrupt(self, event): + if event.node_id == "time": + response = event.interrupt("test_interrupt", reason="need approval") + if response != "APPROVE": + event.cancel_node = "node rejected" + + return Hook() + + @pytest.fixture def weather_tool(): @tool(name="weather_tool", context=True) @@ -22,9 +39,12 @@ def func(tool_context: ToolContext) -> str: @pytest.fixture -def swarm(weather_tool): - weather_agent = Agent(name="weather", tools=[weather_tool]) - return Swarm([weather_agent]) +def time_tool(): + @tool(name="time_tool") + def func(): + return "12:01" + + return func def test_swarm_interrupt_session(weather_tool, tmpdir): @@ -75,3 +95,73 @@ def test_swarm_interrupt_session(weather_tool, tmpdir): summarizer_message = json.dumps(summarizer_result.result.message).lower() assert "sunny" in summarizer_message + + +def test_graph_interrupt_session(interrupt_hook, time_tool, tmpdir): + time_agent = Agent(name="time", tools=[time_tool]) + summarizer_agent = Agent(name="summarizer") + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + + builder = GraphBuilder() + builder.add_node(time_agent, "time") + builder.add_node(summarizer_agent, "summarizer") + builder.add_edge("time", "summarizer") + builder.set_hook_providers([interrupt_hook]) + builder.set_session_manager(session_manager) + graph = builder.build() + + multiagent_result = graph("Can you check the time and then summarize the results?") + + tru_result_status = multiagent_result.status + exp_result_status = Status.INTERRUPTED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.INTERRUPTED + assert tru_state_status == exp_state_status + + tru_interrupts = multiagent_result.interrupts + exp_interrupts = [ + Interrupt( + id=ANY, + name="test_interrupt", + reason="need approval", + ), + ] + assert tru_interrupts == exp_interrupts + + interrupt = multiagent_result.interrupts[0] + + time_agent = Agent(name="time", tools=[time_tool]) + summarizer_agent = Agent(name="summarizer") + session_manager = FileSessionManager(session_id="strands-interrupt-test", storage_dir=tmpdir) + + builder = GraphBuilder() + builder.add_node(time_agent, "time") + builder.add_node(summarizer_agent, "summarizer") + builder.add_edge("time", "summarizer") + builder.set_hook_providers([interrupt_hook]) + builder.set_session_manager(session_manager) + graph = builder.build() + + responses = [ + { + "interruptResponse": { + "interruptId": interrupt.id, + "response": "APPROVE", + }, + }, + ] + multiagent_result = graph(responses) + + tru_result_status = multiagent_result.status + exp_result_status = Status.COMPLETED + assert tru_result_status == exp_result_status + + tru_state_status = graph.state.status + exp_state_status = Status.COMPLETED + assert tru_state_status == exp_state_status + + assert len(multiagent_result.results) == 2 + summarizer_message = json.dumps(multiagent_result.results["summarizer"].result.message).lower() + assert "12:01" in summarizer_message