Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 115 additions & 20 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Directed Graph Multi-Agent Pattern Implementation.

This module provides a deterministic graph-based agent orchestration system where
Expand Down Expand Up @@ -34,11 +34,13 @@
MultiAgentInitializedEvent,
)
from ..hooks import HookProvider, HookRegistry
from ..interrupt import Interrupt, _InterruptState
from ..session import SessionManager
from ..telemetry import get_tracer
from ..types._events import (
MultiAgentHandoffEvent,
MultiAgentNodeCancelEvent,
MultiAgentNodeInterruptEvent,
MultiAgentNodeStartEvent,
MultiAgentNodeStopEvent,
MultiAgentNodeStreamEvent,
Expand All @@ -63,10 +65,15 @@
status: Current execution status of the graph.
completed_nodes: Set of nodes that have completed execution.
failed_nodes: Set of nodes that failed during execution.
interrupted_nodes: Set of nodes that user interrupted during execution.
execution_order: List of nodes in the order they were executed.
task: The original input prompt/query provided to the graph execution.
This represents the actual work to be performed by the graph as a whole.
Entry point nodes receive this task as their input if they have no dependencies.
start_time: Timestamp when the current invocation started.
Resets on each invocation, even when resuming from interrupt.
execution_time: Execution time of current invocation in milliseconds.
Excludes time spent waiting for interrupt responses.
"""

# Task (with default empty string)
Expand All @@ -76,6 +83,7 @@
status: Status = Status.PENDING
completed_nodes: set["GraphNode"] = field(default_factory=set)
failed_nodes: set["GraphNode"] = field(default_factory=set)
interrupted_nodes: set["GraphNode"] = field(default_factory=set)
execution_order: list["GraphNode"] = field(default_factory=list)
start_time: float = field(default_factory=time.time)

Expand Down Expand Up @@ -108,7 +116,7 @@

# Check timeout (only if set)
if execution_timeout is not None:
elapsed = time.time() - self.start_time
elapsed = self.execution_time / 1000 + time.time() - self.start_time
if elapsed > execution_timeout:
return False, f"Execution timed out: {execution_timeout}s"

Expand All @@ -122,6 +130,7 @@
total_nodes: int = 0
completed_nodes: int = 0
failed_nodes: int = 0
interrupted_nodes: int = 0
execution_order: list["GraphNode"] = field(default_factory=list)
edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list)
entry_points: list["GraphNode"] = field(default_factory=list)
Expand All @@ -148,13 +157,7 @@

@dataclass
class GraphNode:
"""Represents a node in the graph.

The execution_status tracks the node's lifecycle within graph orchestration:
- PENDING: Node hasn't started executing yet
- EXECUTING: Node is currently running
- COMPLETED/FAILED: Node finished executing (regardless of result quality)
"""
"""Represents a node in the graph."""

node_id: str
executor: Agent | MultiAgentBase
Expand Down Expand Up @@ -445,6 +448,7 @@
self.node_timeout = node_timeout
self.reset_on_revisit = reset_on_revisit
self.state = GraphState()
self._interrupt_state = _InterruptState()
self.tracer = get_tracer()
self.trace_attributes: dict[str, AttributeValue] = self._parse_trace_attributes(trace_attributes)
self.session_manager = session_manager
Expand Down Expand Up @@ -519,6 +523,8 @@
- multi_agent_node_stop: When a node stops execution
- result: Final graph result
"""
self._interrupt_state.resume(task)

if invocation_state is None:
invocation_state = {}

Expand All @@ -528,7 +534,7 @@

# Initialize state
start_time = time.time()
if not self._resume_from_session:
if not self._resume_from_session and not self._interrupt_state.activated:
# Initialize state
self.state = GraphState(
status=Status.EXECUTING,
Expand All @@ -544,6 +550,8 @@

span = self.tracer.start_multiagent_span(task, "graph", custom_trace_attributes=self.trace_attributes)
with trace_api.use_span(span, end_on_exit=True):
interrupts = []

try:
logger.debug(
"max_node_executions=<%s>, execution_timeout=<%s>s, node_timeout=<%s>s | graph execution config",
Expand All @@ -553,6 +561,9 @@
)

async for event in self._execute_graph(invocation_state):
if isinstance(event, MultiAgentNodeInterruptEvent):
interrupts.extend(event.interrupts)

yield event.as_dict()

# Set final status based on execution results
Expand All @@ -564,7 +575,7 @@
logger.debug("status=<%s> | graph execution completed", self.state.status)

# Yield final result (consistent with Agent's AgentResultEvent format)
result = self._build_result()
result = self._build_result(interrupts)

# Use the same event format as Agent for consistency
yield MultiAgentResultEvent(result=result).as_dict()
Expand All @@ -574,7 +585,7 @@
self.state.status = Status.FAILED
raise
finally:
self.state.execution_time = round((time.time() - start_time) * 1000)
self.state.execution_time += round((time.time() - start_time) * 1000)
await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self))
self._resume_from_session = False
self._resume_next_nodes.clear()
Expand All @@ -591,9 +602,41 @@
# Validate Agent-specific constraints for each node
_validate_node_executor(node.executor)

def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent:
    """Record a user interrupt against a node and switch the graph into the interrupted state.

    Marks both the node and the overall graph as INTERRUPTED, registers the
    raised interrupts with the internal interrupt state (keyed by interrupt id),
    and activates that state so the next invocation resumes instead of restarting.

    Args:
        node: The node whose execution was interrupted.
        interrupts: The interrupts raised by the user.

    Returns:
        MultiAgentNodeInterruptEvent carrying the node id and its interrupts.
    """
    logger.debug("node=<%s> | node interrupted", node.node_id)

    # Flag the node itself, then the graph-level state that tracks it.
    node.execution_status = Status.INTERRUPTED
    self.state.interrupted_nodes.add(node)
    self.state.status = Status.INTERRUPTED

    # Register each interrupt by id so it can be matched on resume.
    for raised in interrupts:
        self._interrupt_state.interrupts[raised.id] = raised
    self._interrupt_state.activate()

    return MultiAgentNodeInterruptEvent(node.node_id, interrupts)

async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
"""Execute graph and yield TypedEvent objects."""
ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points)
if self._interrupt_state.activated:
ready_nodes = [self.nodes[node_id] for node_id in self._interrupt_state.context["completed_nodes"]]
ready_nodes.extend(self.state.interrupted_nodes)

self.state.interrupted_nodes.clear()

elif self._resume_from_session:
ready_nodes = self._resume_next_nodes

else:
ready_nodes = list(self.entry_points)

while ready_nodes:
# Check execution limits before continuing
Expand All @@ -613,6 +656,14 @@
async for event in self._execute_nodes_parallel(current_batch, invocation_state):
yield event

if self.state.status == Status.INTERRUPTED:
self._interrupt_state.context["completed_nodes"] = [
node.node_id for node in current_batch if node.execution_status == Status.COMPLETED
]
return

self._interrupt_state.deactivate()

# Find newly ready nodes after batch execution
# We add all nodes in current batch as completed batch,
# because a failure would throw exception and code would not make it here
Expand Down Expand Up @@ -641,6 +692,9 @@
Uses a shared queue where each node's stream runs independently and pushes events
as they occur, enabling true real-time event propagation without round-robin delays.
"""
if self._interrupt_state.activated:
nodes = [node for node in nodes if node.execution_status == Status.INTERRUPTED]

event_queue: asyncio.Queue[Any | None | Exception] = asyncio.Queue()

# Start all node streams as independent tasks
Expand Down Expand Up @@ -797,12 +851,16 @@
)
yield start_event

before_event, _ = await self.hooks.invoke_callbacks_async(
before_event, interrupts = await self.hooks.invoke_callbacks_async(
BeforeNodeCallEvent(self, node.node_id, invocation_state)
)

start_time = time.time()
try:
if interrupts:
yield self._activate_interrupt(node, interrupts)
return

if before_event.cancel_node:
cancel_message = (
before_event.cancel_node if isinstance(before_event.cancel_node, str) else "node cancelled by user"
Expand Down Expand Up @@ -830,6 +888,13 @@
if multi_agent_result is None:
raise ValueError(f"Node '{node.node_id}' did not produce a result event")

if multi_agent_result.status == Status.INTERRUPTED:
raise NotImplementedError(
f"node_id=<{node.node_id}>, "
"issue=<https://github.com/strands-agents/sdk-python/issues/204> "
"| user raised interrupt from a multi agent node"
)

node_result = NodeResult(
result=multi_agent_result,
execution_time=multi_agent_result.execution_time,
Expand All @@ -854,12 +919,15 @@
if agent_response is None:
raise ValueError(f"Node '{node.node_id}' did not produce a result event")

# Check for interrupt (from main branch)
if agent_response.stop_reason == "interrupt":
node.executor.messages.pop() # remove interrupted tool use message
node.executor._interrupt_state.deactivate()

raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in graphs")
raise NotImplementedError(
f"node_id=<{node.node_id}>, "
"issue=<https://github.com/strands-agents/sdk-python/issues/204> "
"| user raised interrupt from an agent node"
)

# Extract metrics with defaults
response_metrics = getattr(agent_response, "metrics", None)
Expand Down Expand Up @@ -1006,8 +1074,15 @@

return node_input

def _build_result(self) -> GraphResult:
"""Build graph result from current state."""
def _build_result(self, interrupts: list[Interrupt]) -> GraphResult:
"""Build graph result from current state.

Args:
interrupts: List of interrupts collected during execution.

Returns:
GraphResult with current state.
"""
return GraphResult(
status=self.state.status,
results=self.state.results,
Expand All @@ -1018,9 +1093,11 @@
total_nodes=self.state.total_nodes,
completed_nodes=len(self.state.completed_nodes),
failed_nodes=len(self.state.failed_nodes),
interrupted_nodes=len(self.state.interrupted_nodes),
execution_order=self.state.execution_order,
edges=self.state.edges,
entry_points=self.state.entry_points,
interrupts=interrupts,
)

def serialize_state(self) -> dict[str, Any]:
Expand All @@ -1033,10 +1110,14 @@
"status": self.state.status.value,
"completed_nodes": [n.node_id for n in self.state.completed_nodes],
"failed_nodes": [n.node_id for n in self.state.failed_nodes],
"interrupted_nodes": [n.node_id for n in self.state.interrupted_nodes],
"node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()},
"next_nodes_to_execute": next_nodes,
"current_task": self.state.task,
"execution_order": [n.node_id for n in self.state.execution_order],
"_internal_state": {
"interrupt_state": self._interrupt_state.to_dict(),
},
}

def deserialize_state(self, payload: dict[str, Any]) -> None:
Expand All @@ -1052,6 +1133,10 @@
payload: Dictionary containing persisted state data including status,
completed nodes, results, and next nodes to execute.
"""
if "_internal_state" in payload:
internal_state = payload["_internal_state"]
self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"])

if not payload.get("next_nodes_to_execute"):
# Reset all nodes
for node in self.nodes.values():
Expand Down Expand Up @@ -1098,10 +1183,20 @@
self.state.failed_nodes = set(
self.nodes[node_id] for node_id in (payload.get("failed_nodes") or []) if node_id in self.nodes
)
for node in self.state.failed_nodes:
node.execution_status = Status.FAILED

# Restore completed nodes from persisted data
completed_node_ids = payload.get("completed_nodes") or []
self.state.completed_nodes = {self.nodes[node_id] for node_id in completed_node_ids if node_id in self.nodes}
self.state.interrupted_nodes = set(
self.nodes[node_id] for node_id in (payload.get("interrupted_nodes") or []) if node_id in self.nodes
)
for node in self.state.interrupted_nodes:
node.execution_status = Status.INTERRUPTED

self.state.completed_nodes = set(
self.nodes[node_id] for node_id in (payload.get("completed_nodes") or []) if node_id in self.nodes
)
for node in self.state.completed_nodes:
node.execution_status = Status.COMPLETED

# Execution order (only nodes that still exist)
order_node_ids = payload.get("execution_order") or []
Expand Down
Loading
Loading