diff --git a/python/packages/main/agent_framework/workflow/__init__.py b/python/packages/main/agent_framework/workflow/__init__.py index bbfd05c2a9..971b7fb004 100644 --- a/python/packages/main/agent_framework/workflow/__init__.py +++ b/python/packages/main/agent_framework/workflow/__init__.py @@ -32,6 +32,8 @@ "InMemoryCheckpointStorage", "CheckpointStorage", "WorkflowCheckpoint", + "Case", + "Default", ] diff --git a/python/packages/main/agent_framework/workflow/__init__.pyi b/python/packages/main/agent_framework/workflow/__init__.pyi index 6506d4a936..e36c6ee461 100644 --- a/python/packages/main/agent_framework/workflow/__init__.pyi +++ b/python/packages/main/agent_framework/workflow/__init__.pyi @@ -6,7 +6,9 @@ from agent_framework_workflow import ( AgentExecutorResponse, AgentRunEvent, AgentRunStreamingEvent, + Case, CheckpointStorage, + Default, Executor, ExecutorCompletedEvent, ExecutorEvent, @@ -34,7 +36,9 @@ __all__ = [ "AgentExecutorResponse", "AgentRunEvent", "AgentRunStreamingEvent", + "Case", "CheckpointStorage", + "Default", "Executor", "ExecutorCompletedEvent", "ExecutorEvent", diff --git a/python/packages/workflow/agent_framework_workflow/__init__.py b/python/packages/workflow/agent_framework_workflow/__init__.py index 1d59918525..8196d5f767 100644 --- a/python/packages/workflow/agent_framework_workflow/__init__.py +++ b/python/packages/workflow/agent_framework_workflow/__init__.py @@ -11,6 +11,7 @@ from ._const import ( DEFAULT_MAX_ITERATIONS, ) +from ._edge import Case, Default from ._events import ( AgentRunEvent, AgentRunStreamingEvent, @@ -60,7 +61,9 @@ "AgentExecutorResponse", "AgentRunEvent", "AgentRunStreamingEvent", + "Case", "CheckpointStorage", + "Default", "EdgeDuplicationError", "Executor", "ExecutorCompletedEvent", diff --git a/python/packages/workflow/agent_framework_workflow/_edge.py b/python/packages/workflow/agent_framework_workflow/_edge.py index 0d7ca8bb18..fef5e3376d 100644 --- 
a/python/packages/workflow/agent_framework_workflow/_edge.py +++ b/python/packages/workflow/agent_framework_workflow/_edge.py @@ -1,7 +1,11 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from collections.abc import Callable +import logging +import uuid +from collections import defaultdict +from collections.abc import Callable, Sequence +from dataclasses import dataclass from typing import Any, ClassVar from ._executor import Executor @@ -9,6 +13,8 @@ from ._shared_state import SharedState from ._workflow_context import WorkflowContext +logger = logging.getLogger(__name__) + class Edge: """Represents a directed edge in a graph.""" @@ -34,10 +40,6 @@ def __init__( self.target = target self._condition = condition - # Edge group is used to group edges that share the same target executor. - # It allows for sending messages to the target executor only when all edges in the group have data. - self._edge_group_ids: list[str] = [] - @property def source_id(self) -> str: """Get the source executor ID.""" @@ -53,27 +55,6 @@ def id(self) -> str: """Get the unique ID of the edge.""" return f"{self.source_id}{self.ID_SEPARATOR}{self.target_id}" - def has_edge_group(self) -> bool: - """Check if the edge is part of an edge group.""" - return bool(self._edge_group_ids) - - @classmethod - def source_and_target_from_id(cls, edge_id: str) -> tuple[str, str]: - """Extract the source and target IDs from the edge ID. - - Args: - edge_id (str): The edge ID in the format "source_id->target_id". - - Returns: - tuple[str, str]: A tuple containing the source ID and target ID. - """ - if cls.ID_SEPARATOR not in edge_id: - raise ValueError(f"Invalid edge ID format: {edge_id}") - ids = edge_id.split(cls.ID_SEPARATOR) - if len(ids) != 2: - raise ValueError(f"Invalid edge ID format: {edge_id}") - return ids[0], ids[1] - def can_handle(self, message_data: Any) -> bool: """Check if the edge can handle the given data. 
@@ -83,11 +64,14 @@ def can_handle(self, message_data: Any) -> bool: Returns: bool: True if the edge can handle the data, False otherwise. """ - if not self._edge_group_ids: - return self.target.can_handle(message_data) + return self.target.can_handle(message_data) - # If the edge is part of an edge group, the target should expect a list of the data type. - return self.target.can_handle([message_data]) + def should_route(self, data: Any) -> bool: + """Determine if message should be routed through this edge based on the condition.""" + if self._condition is None: + return True + + return self._condition(data) async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> None: """Send a message along this edge. @@ -98,57 +82,338 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R ctx (RunnerContext): The context for the runner. """ if not self.can_handle(message.data): + # Caller of this method should ensure that the edge can handle the data. raise RuntimeError(f"Edge {self.id} cannot handle data of type {type(message.data)}.") - if not self._edge_group_ids and self._should_route(message.data): + if self.should_route(message.data): await self.target.execute( message.data, WorkflowContext(self.target.id, [self.source.id], shared_state, ctx) ) - elif self._edge_group_ids: - # Logic: - # 1. If not all edges in the edge group have data in the shared state, - # add the data to the shared state. - # 2. If all edges in the edge group have data in the shared state, - # copy the data to a list and send it to the target executor. 
- message_list: list[Message] = [] - async with shared_state.hold() as held_shared_state: - has_data = await asyncio.gather( - *(held_shared_state.has_within_hold(edge_id) for edge_id in self._edge_group_ids) - ) - if not all(has_data): - await held_shared_state.set_within_hold(self.id, message) - else: - message_list = [ - await held_shared_state.get_within_hold(edge_id) for edge_id in self._edge_group_ids - ] + [message] - # Remove the data from the shared state after retrieving it - await asyncio.gather( - *(held_shared_state.delete_within_hold(edge_id) for edge_id in self._edge_group_ids) - ) - - if message_list: - data_list = [msg.data for msg in message_list] - source_ids = [msg.source_id for msg in message_list] - await self.target.execute(data_list, WorkflowContext(self.target.id, source_ids, shared_state, ctx)) - - def _should_route(self, data: Any) -> bool: - """Determine if message should be routed through this edge.""" - if self._condition is None: + + +class EdgeGroup: + """Represents a group of edges that share some common properties and can be triggered together.""" + + def __init__(self) -> None: + """Initialize the edge group.""" + self._id = f"{self.__class__.__name__}/{uuid.uuid4()}" + + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through the edge group. + + Args: + message (Message): The message to send. + shared_state (SharedState): The shared state to use for holding data. + ctx (RunnerContext): The context for the runner. + + Returns: + bool: True if the message was sent successfully, False if the target executor cannot handle the message. + If a message can be delivered but rejected due to a condition, it will still return True. + + Note: + Exception will not be raised if the target executor cannot handle the message. 
This is because + a source executor can be connected to multiple target executors, and not every target executor may + be able to handle all the messages sent by the source executor. + """ + raise NotImplementedError + + @property + def id(self) -> str: + """Get the unique ID of the edge group.""" + return self._id + + @property + def source_executors(self) -> list[Executor]: + """Get the source executors of the edges in the group.""" + raise NotImplementedError + + @property + def target_executors(self) -> list[Executor]: + """Get the target executors of the edges in the group.""" + raise NotImplementedError + + @property + def edges(self) -> list[Edge]: + """Get the edges in the group.""" + raise NotImplementedError + + +class SingleEdgeGroup(EdgeGroup): + """Represents a single edge group that contains only one edge. + + A concrete implementation of EdgeGroup that represents a group containing exactly one edge. + """ + + def __init__(self, source: Executor, target: Executor, condition: Callable[[Any], bool] | None = None) -> None: + """Initialize the single edge group with an edge. + + Args: + source (Executor): The source executor. + target (Executor): The target executor that the source executor can send messages to. + condition (Callable[[Any], bool], optional): A condition function that determines + if the edge will pass the data to the target executor. If None, the edge + will always pass the data to the target executor.
+ """ + self._edge = Edge(source=source, target=target, condition=condition) + + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through the single edge.""" + if message.target_id and message.target_id != self._edge.target_id: + return False + + if self._edge.can_handle(message.data): + await self._edge.send_message(message, shared_state, ctx) return True - return self._condition(data) + return False + + @property + def source_executors(self) -> list[Executor]: + """Get the source executor of the edge.""" + return [self._edge.source] + + @property + def target_executors(self) -> list[Executor]: + """Get the target executor of the edge.""" + return [self._edge.target] + + @property + def edges(self) -> list[Edge]: + """Get the edges in the group.""" + return [self._edge] + + +class FanOutEdgeGroup(EdgeGroup): + """Represents a group of edges that share the same source executor. + + Assembles a Fan-out pattern where multiple edges share the same source executor + and send messages to their respective target executors. + """ + + def __init__( + self, + source: Executor, + targets: Sequence[Executor], + selection_func: Callable[[Any, list[str]], list[str]] | None = None, + ) -> None: + """Initialize the fan-out edge group with a list of edges. + + Args: + source (Executor): The source executor. + targets (Sequence[Executor]): A list of target executors that the source executor can send messages to. + selection_func (Callable[[Any, list[str]], list[str]], optional): A function that selects which target + executors to send messages to. The function takes in the message data and a list of target executor + IDs, and returns a list of selected target executor IDs. 
+ """ + if len(targets) <= 1: + raise ValueError("FanOutEdgeGroup must contain at least two targets.") + self._edges = [Edge(source=source, target=target) for target in targets] + self._target_ids = [edge.target_id for edge in self._edges] + self._target_map = {edge.target_id: edge for edge in self._edges} + self._selection_func = selection_func + + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through all edges in the fan-out edge group.""" + selection_results = ( + self._selection_func(message.data, self._target_ids) if self._selection_func else self._target_ids + ) + if not self._validate_selection_result(selection_results): + raise RuntimeError( + f"Invalid selection result: {selection_results}. " + f"Expected selections to be a subset of valid target executor IDs: {self._target_ids}." + ) + + if message.target_id: + # If the target ID is specified and the selection result contains it, send the message to that edge + if message.target_id in selection_results: + edge = next((edge for edge in self._edges if edge.target_id == message.target_id), None) + if edge and edge.can_handle(message.data): + await edge.send_message(message, shared_state, ctx) + return True + return False + + # If no target ID, send the message to the selected targets + async def send_to_edge(edge: Edge) -> bool: + """Send the message to the edge at the specified index.""" + if edge.can_handle(message.data): + await edge.send_message(message, shared_state, ctx) + return True + return False + + tasks = [send_to_edge(self._target_map[target_id]) for target_id in selection_results] + results = await asyncio.gather(*tasks) + return any(results) + + @property + def source_executors(self) -> list[Executor]: + """Get the source executor of the edges in the group.""" + return [self._edges[0].source] + + @property + def target_executors(self) -> list[Executor]: + """Get the target executors of the edges in the group.""" + 
return [edge.target for edge in self._edges] + + @property + def edges(self) -> list[Edge]: + """Get the edges in the group.""" + return self._edges + + def _validate_selection_result(self, selection_results: list[str]) -> bool: + """Validate the selection results to ensure all IDs are valid target executor IDs.""" + return all(result in self._target_ids for result in selection_results) + + +class FanInEdgeGroup(EdgeGroup): + """Represents a group of edges that share the same target executor. + + Assembles a Fan-in pattern where multiple edges send messages to a single target executor. + Messages are buffered until all edges in the group have data to send. + """ + + def __init__(self, sources: Sequence[Executor], target: Executor) -> None: + """Initialize the fan-in edge group with a list of edges. + + Args: + sources (Sequence[Executor]): A list of source executors that can send messages to the target executor. + target (Executor): The target executor that receives a list of messages aggregated from all sources. 
+ """ + if len(sources) <= 1: + raise ValueError("FanInEdgeGroup must contain at least two sources.") + self._edges = [Edge(source=source, target=target) for source in sources] + # Buffer to hold messages before sending them to the target executor + # Key is the source executor ID, value is a list of messages + self._buffer: dict[str, list[Message]] = defaultdict(list) + + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through all edges in the fan-in edge group.""" + if message.target_id and message.target_id != self._edges[0].target_id: + return False + + if self._edges[0].can_handle([message.data]): + # If the edge can handle the data, buffer the message + self._buffer[message.source_id].append(message) + else: + # If the edge cannot handle the data, return False + return False + + if self._is_ready_to_send(): + # If all edges in the group have data, send the buffered messages to the target executor + messages_to_send = [msg for edge in self._edges for msg in self._buffer[edge.source_id]] + self._buffer.clear() + # Only trigger one edge to send the messages to avoid duplicate sends + await self._edges[0].send_message( + Message([msg.data for msg in messages_to_send], self.__class__.__name__), + shared_state, + ctx, + ) - def set_edge_group(self, edge_group_ids: list[str]) -> None: - """Set the edge group IDs for this edge. 
+ return True + + def _is_ready_to_send(self) -> bool: + """Check if all edges in the group have data to send.""" + return all(self._buffer[edge.source_id] for edge in self._edges) + + @property + def source_executors(self) -> list[Executor]: + """Get the source executors of the edges in the group.""" + return [edge.source for edge in self._edges] + + @property + def target_executors(self) -> list[Executor]: + """Get the target executor of the edges in the group.""" + return [self._edges[0].target] + + @property + def edges(self) -> list[Edge]: + """Get the edges in the group.""" + return self._edges + + +@dataclass +class Case: + """Represents a single case in the conditional edge group. + + Args: + condition (Callable[[Any], bool]): The condition function for the case. + target (Executor): The target executor for the case. + """ + + condition: Callable[[Any], bool] + target: Executor + + +@dataclass +class Default: + """Represents the default case in the conditional edge group. + + Args: + target (Executor): The target executor for the default case. + """ + + target: Executor + + +class SwitchCaseEdgeGroup(FanOutEdgeGroup): + """Represents a group of edges that assemble a conditional routing pattern. + + This is similar to a switch-case construct: + switch(data): + case condition_1: + edge_1 + break + case condition_2: + edge_2 + break + default: + edge_3 + break + Or equivalently an if-elif-else construct: + if condition_1: + edge_1 + elif condition_2: + edge_2 + else: + edge_3 + """ + + def __init__( + self, + source: Executor, + cases: Sequence[Case | Default], + ) -> None: + """Initialize the conditional edge group with a list of edges. Args: - edge_group_ids (list[str]): A list of edge IDs that belong to the same edge group. + source (Executor): The source executor. + cases (Sequence[Case | Default]): A list of cases for the conditional edge group. + There should be exactly one default case. 
""" - # Validate that the edges in the edge group contain the same target executor as this edge - # TODO(@taochen): An edge cannot be part of multiple edge groups. - # TODO(@taochen): Can an edge have both a condition and an edge group? - if edge_group_ids: - for edge_id in edge_group_ids: - if Edge.source_and_target_from_id(edge_id)[1] != self.target.id: - raise ValueError("All edges in the group must have the same target executor.") - self._edge_group_ids = edge_group_ids + if len(cases) < 2: + raise ValueError("SwitchCaseEdgeGroup must contain at least two cases (including the default case).") + + default_case = [isinstance(case, Default) for case in cases] + if sum(default_case) != 1: + raise ValueError("SwitchCaseEdgeGroup must contain exactly one default case.") + + if isinstance(cases[-1], Default): + logger.warning( + "Default case in the conditional edge group is not the last case. " + "This will result in unexpected behavior." + ) + + def selection_func(data: Any, targets: list[str]) -> list[str]: + """Select the target executor based on the conditions.""" + for index, case in enumerate(cases): + if isinstance(case, Default): + return [case.target.id] + if isinstance(case, Case): + try: + if case.condition(data): + return [case.target.id] + except Exception as e: + logger.warning(f"Error occurred while evaluating condition for case {index}: {e}") + + raise RuntimeError("No matching case found in SwitchCaseEdgeGroup.") + + super().__init__(source, [case.target for case in cases], selection_func=selection_func) diff --git a/python/packages/workflow/agent_framework_workflow/_executor.py b/python/packages/workflow/agent_framework_workflow/_executor.py index ea1630eab4..b1ddad61b6 100644 --- a/python/packages/workflow/agent_framework_workflow/_executor.py +++ b/python/packages/workflow/agent_framework_workflow/_executor.py @@ -31,7 +31,7 @@ def __init__(self, id: str | None = None) -> None: Args: id: A unique identifier for the executor. 
If None, a new UUID will be generated. """ - self._id = id or str(uuid.uuid4()) + self._id = id or f"{self.__class__.__name__}/{uuid.uuid4()}" self._handlers: dict[type, Callable[[Any, WorkflowContext], Any]] = {} self._discover_handlers() diff --git a/python/packages/workflow/agent_framework_workflow/_runner.py b/python/packages/workflow/agent_framework_workflow/_runner.py index 1688a2f5fa..ac6ca41fb1 100644 --- a/python/packages/workflow/agent_framework_workflow/_runner.py +++ b/python/packages/workflow/agent_framework_workflow/_runner.py @@ -3,10 +3,10 @@ import asyncio import logging from collections import defaultdict -from collections.abc import AsyncIterable +from collections.abc import AsyncIterable, Sequence from typing import Any -from ._edge import Edge +from ._edge import EdgeGroup from ._events import WorkflowEvent from ._executor import Executor from ._runner_context import Message, RunnerContext @@ -20,7 +20,7 @@ class Runner: def __init__( self, - edges: list[Edge], + edge_groups: Sequence[EdgeGroup], shared_state: SharedState, ctx: RunnerContext, max_iterations: int = 100, @@ -29,13 +29,13 @@ def __init__( """Initialize the runner with edges, shared state, and context. Args: - edges: The edges of the workflow. + edge_groups: The edge groups of the workflow. shared_state: The shared state for the workflow. ctx: The runner context for the workflow. max_iterations: The maximum number of iterations to run. workflow_id: The workflow ID for checkpointing. 
""" - self._edge_map = self._parse_edges(edges) + self._edge_group_map = self._parse_edge_groups(edge_groups) self._ctx = ctx self._iteration = 0 self._max_iterations = max_iterations @@ -125,26 +125,25 @@ async def run_until_convergence(self) -> AsyncIterable[WorkflowEvent]: async def _run_iteration(self): async def _deliver_messages(source_executor_id: str, messages: list[Message]) -> None: - async def _deliver_messages_inner( - edge: Edge, - messages: list[Message], - ) -> None: - for message in messages: - if message.target_id is not None and message.target_id != edge.target_id: - continue - if not edge.can_handle(message.data): - continue - await edge.send_message(message, self._shared_state, self._ctx) - - associated_edges = self._edge_map.get(source_executor_id, []) - tasks = [asyncio.create_task(_deliver_messages_inner(edge, messages)) for edge in associated_edges] - await asyncio.gather(*tasks) + """Outer loop to concurrently deliver messages from all sources to their targets.""" + + async def _deliver_message_inner(edge_group: EdgeGroup, message: Message) -> bool: + """Inner loop to deliver a single message through an edge group.""" + return await edge_group.send_message(message, self._shared_state, self._ctx) + + associated_edge_groups = self._edge_group_map.get(source_executor_id, []) + for message in messages: + # Deliver a message through all edge groups associated with the source executor concurrently. + tasks = [_deliver_message_inner(edge_group, message) for edge_group in associated_edge_groups] + results = await asyncio.gather(*tasks) + if not any(results): + logger.warning( + f"Message {message} could not be delivered. " + "This may be due to type incompatibility or no matching targets." 
+ ) messages = await self._ctx.drain_messages() - tasks = [ - asyncio.create_task(_deliver_messages(source_executor_id, messages)) - for source_executor_id, messages in messages.items() - ] + tasks = [_deliver_messages(source_executor_id, messages) for source_executor_id, messages in messages.items()] await asyncio.gather(*tasks) async def _create_checkpoint_if_enabled(self, checkpoint_type: str) -> str | None: @@ -177,10 +176,11 @@ async def _auto_snapshot_executor_states(self) -> None: Only JSON-serializable dicts should be provided by executors. """ executors: dict[str, Executor] = {} - for edge_list in self._edge_map.values(): - for edge in edge_list: - executors[edge.source.id] = edge.source - executors[edge.target.id] = edge.target + for edge_groups in self._edge_group_map.values(): + for edge_group in edge_groups: + for edge in edge_group.edges: + executors[edge.source.id] = edge.source + executors[edge.target.id] = edge.target for exec_id, executor in executors.items(): state_dict: dict[str, Any] | None = None snapshot = getattr(executor, "snapshot_state", None) @@ -268,16 +268,18 @@ async def _restore_shared_state_from_context(self) -> None: except Exception as e: logger.warning(f"Failed to restore shared state from context: {e}") - def _parse_edges(self, edges: list[Edge]) -> dict[str, list[Edge]]: - """Parse the edges of the workflow into a more convenient format. + def _parse_edge_groups(self, edge_groups: Sequence[EdgeGroup]) -> dict[str, list[EdgeGroup]]: + """Parse the edge groups of the workflow into a mapping where each source executor ID maps to its edge groups. Args: - edges: A list of edges in the workflow. + edge_groups: A list of edge groups in the workflow. Returns: - A dictionary mapping each source executor ID to a list of target executor IDs. + A dictionary mapping each source executor ID to a list of edge groups. 
""" - parsed: defaultdict[str, list[Edge]] = defaultdict(list) - for edge in edges: - parsed[edge.source_id].append(edge) + parsed: defaultdict[str, list[EdgeGroup]] = defaultdict(list) + for group in edge_groups: + for source_executor in group.source_executors: + parsed[source_executor.id].append(group) + return parsed diff --git a/python/packages/workflow/agent_framework_workflow/_typing_utils.py b/python/packages/workflow/agent_framework_workflow/_typing_utils.py index f8547d886e..46500bdf52 100644 --- a/python/packages/workflow/agent_framework_workflow/_typing_utils.py +++ b/python/packages/workflow/agent_framework_workflow/_typing_utils.py @@ -13,6 +13,10 @@ def is_instance_of(data: Any, target_type: type) -> bool: Returns: bool: True if data is an instance of target_type, False otherwise. """ + # Case 0: target_type is Any - always return True + if target_type is Any: + return True + origin = get_origin(target_type) args = get_args(target_type) diff --git a/python/packages/workflow/agent_framework_workflow/_validation.py b/python/packages/workflow/agent_framework_workflow/_validation.py index 0ce9cd2e76..3c4d8c12fd 100644 --- a/python/packages/workflow/agent_framework_workflow/_validation.py +++ b/python/packages/workflow/agent_framework_workflow/_validation.py @@ -3,10 +3,11 @@ import inspect import logging from collections import defaultdict +from collections.abc import Sequence from enum import Enum from typing import Any, Union, get_args, get_origin -from ._edge import Edge +from ._edge import Edge, EdgeGroup, FanInEdgeGroup from ._executor import Executor logger = logging.getLogger(__name__) @@ -92,18 +93,19 @@ def __init__(self): self._executors: dict[str, Executor] = {} # region Core Validation Methods - def validate_workflow(self, edges: list[Edge], start_executor: Executor | str) -> None: + def validate_workflow(self, edge_groups: Sequence[EdgeGroup], start_executor: Executor | str) -> None: """Validate the entire workflow graph. 
Args: - edges: list of edges in the workflow + edge_groups: list of edge groups in the workflow start_executor: The starting executor (can be instance or ID) Raises: WorkflowValidationError: If any validation fails """ - self._edges = edges - self._executors = self._build_executor_map(edges) + self._executors = self._build_executor_map(edge_groups) + self._edges = [edge for group in edge_groups for edge in group.edges] + self._edge_groups = edge_groups # Validate that start_executor exists in the graph # It should because we check for it in the WorkflowBuilder @@ -121,12 +123,13 @@ def validate_workflow(self, edges: list[Edge], start_executor: Executor | str) - self._validate_dead_ends() self._validate_cycles() - def _build_executor_map(self, edges: list[Edge]) -> dict[str, Executor]: + def _build_executor_map(self, edge_groups: Sequence[EdgeGroup]) -> dict[str, Executor]: """Build a map of executor IDs to executor instances.""" executors: dict[str, Executor] = {} - for edge in edges: - executors[edge.source_id] = edge.source - executors[edge.target_id] = edge.target + for group in edge_groups: + for executor in group.source_executors + group.target_executors: + executors[executor.id] = executor + return executors # endregion @@ -155,64 +158,80 @@ def _validate_type_compatibility(self) -> None: Raises: TypeCompatibilityError: If type incompatibility is detected """ - for edge in self._edges: - source_executor = edge.source - target_executor = edge.target - - # Get output types from source executor - source_output_types = self._get_executor_output_types(source_executor) - - # Get input types from target executor - target_input_types = self._get_executor_input_types(target_executor) - - # If either executor has no type information, log warning and skip validation - # This allows for dynamic typing scenarios but warns about reduced validation coverage - if not source_output_types or not target_input_types: - if not source_output_types: - logger.warning( - f"Executor 
'{source_executor.id}' has no output type annotations. " - f"Type compatibility validation will be skipped for edges from this executor. " - f"Consider adding output_types to @handler decorators for better validation." - ) - if not target_input_types: - logger.warning( - f"Executor '{target_executor.id}' has no input type annotations. " - f"Type compatibility validation will be skipped for edges to this executor. " - f"Consider adding type annotations to message handler parameters for better validation." - ) - continue - - # Check if any source output type is compatible with any target input type - compatible = False - compatible_pairs: list[tuple[type[Any], type[Any]]] = [] - - for source_type in source_output_types: - for target_type in target_input_types: - if edge.has_edge_group(): - # If the edge is part of an edge group, the target expects a list of data types - if self._is_type_compatible(list[source_type], target_type): - compatible = True - compatible_pairs.append((list[source_type], target_type)) - else: - if self._is_type_compatible(source_type, target_type): - compatible = True - compatible_pairs.append((source_type, target_type)) - - # Log successful type compatibility for debugging - if compatible: - logger.debug( - f"Type compatibility validated for edge '{source_executor.id}' -> '{target_executor.id}'. " - f"Compatible type pairs: {[(str(s), str(t)) for s, t in compatible_pairs]}" - ) + for edge_group in self._edge_groups: + for edge in edge_group.edges: + self._validate_edge_type_compatibility(edge, edge_group) + + def _validate_edge_type_compatibility(self, edge: Edge, edge_group: EdgeGroup) -> None: + """Validate type compatibility for a specific edge. + + This checks that the output types of the source executor are compatible + with the input types expected by the target executor. 
- if not compatible: - # Enhanced error with more detailed information - raise TypeCompatibilityError( - source_executor.id, - target_executor.id, - source_output_types, - target_input_types, + Args: + edge: The edge to validate + edge_group: The edge group containing this edge + + Raises: + TypeCompatibilityError: If type incompatibility is detected + """ + source_executor = edge.source + target_executor = edge.target + + # Get output types from source executor + source_output_types = self._get_executor_output_types(source_executor) + + # Get input types from target executor + target_input_types = self._get_executor_input_types(target_executor) + + # If either executor has no type information, log warning and skip validation + # This allows for dynamic typing scenarios but warns about reduced validation coverage + if not source_output_types or not target_input_types: + if not source_output_types: + logger.warning( + f"Executor '{source_executor.id}' has no output type annotations. " + f"Type compatibility validation will be skipped for edges from this executor. " + f"Consider adding output_types to @handler decorators for better validation." ) + if not target_input_types: + logger.warning( + f"Executor '{target_executor.id}' has no input type annotations. " + f"Type compatibility validation will be skipped for edges to this executor. " + f"Consider adding type annotations to message handler parameters for better validation." 
+ ) + return + + # Check if any source output type is compatible with any target input type + compatible = False + compatible_pairs: list[tuple[type[Any], type[Any]]] = [] + + for source_type in source_output_types: + for target_type in target_input_types: + if isinstance(edge_group, FanInEdgeGroup): + # If the edge is part of an edge group, the target expects a list of data types + if self._is_type_compatible(list[source_type], target_type): + compatible = True + compatible_pairs.append((list[source_type], target_type)) + else: + if self._is_type_compatible(source_type, target_type): + compatible = True + compatible_pairs.append((source_type, target_type)) + + # Log successful type compatibility for debugging + if compatible: + logger.debug( + f"Type compatibility validated for edge '{source_executor.id}' -> '{target_executor.id}'. " + f"Compatible type pairs: {[(str(s), str(t)) for s, t in compatible_pairs]}" + ) + + if not compatible: + # Enhanced error with more detailed information + raise TypeCompatibilityError( + source_executor.id, + target_executor.id, + source_output_types, + target_input_types, + ) def _get_executor_output_types(self, executor: Executor) -> list[type[Any]]: """Extract output types from an executor's message handlers. @@ -479,15 +498,15 @@ def _is_type_compatible(source_type: type[Any], target_type: type[Any]) -> bool: # endregion -def validate_workflow_graph(edges: list[Edge], start_executor: Executor | str) -> None: +def validate_workflow_graph(edge_groups: Sequence[EdgeGroup], start_executor: Executor | str) -> None: """Convenience function to validate a workflow graph. 
Args: - edges: list of edges in the workflow + edge_groups: list of edge groups in the workflow start_executor: The starting executor (can be instance or ID) Raises: WorkflowValidationError: If any validation fails """ validator = WorkflowGraphValidator() - validator.validate_workflow(edges, start_executor) + validator.validate_workflow(edge_groups, start_executor) diff --git a/python/packages/workflow/agent_framework_workflow/_workflow.py b/python/packages/workflow/agent_framework_workflow/_workflow.py index bc3ae00f19..f2b80ce953 100644 --- a/python/packages/workflow/agent_framework_workflow/_workflow.py +++ b/python/packages/workflow/agent_framework_workflow/_workflow.py @@ -9,7 +9,15 @@ from ._checkpoint import CheckpointStorage from ._const import DEFAULT_MAX_ITERATIONS -from ._edge import Edge +from ._edge import ( + Case, + Default, + EdgeGroup, + FanInEdgeGroup, + FanOutEdgeGroup, + SingleEdgeGroup, + SwitchCaseEdgeGroup, +) from ._events import RequestInfoEvent, WorkflowCompletedEvent, WorkflowEvent from ._executor import Executor, RequestInfoExecutor from ._runner import Runner @@ -67,7 +75,7 @@ class Workflow: def __init__( self, - edges: list[Edge], + edge_groups: list[EdgeGroup], start_executor: Executor | str, runner_context: RunnerContext, max_iterations: int, @@ -75,28 +83,29 @@ def __init__( """Initialize the workflow with a list of edges. Args: - edges: A list of directed edges representing the connections between nodes in the workflow. + edge_groups: A list of EdgeGroup instances that define the workflow edges. start_executor: The starting executor for the workflow, which can be an Executor instance or its ID. runner_context: The RunnerContext instance to be used during workflow execution. max_iterations: The maximum number of iterations the workflow will run for convergence. 
""" - self._edges = edges + self._edge_groups = edge_groups + self._executors = self._build_executor_map(edge_groups) self._start_executor = start_executor - self._executors = {edge.source_id: edge.source for edge in edges} | { - edge.target_id: edge.target for edge in edges - } self._shared_state = SharedState() - workflow_id = str(uuid.uuid4()) self._runner = Runner( - self._edges, self._shared_state, runner_context, max_iterations=max_iterations, workflow_id=workflow_id + self._edge_groups, + self._shared_state, + runner_context, + max_iterations=max_iterations, + workflow_id=workflow_id, ) @property - def edges(self) -> list[Edge]: - """Get the list of edges in the workflow.""" - return self._edges + def edge_groups(self) -> list[EdgeGroup]: + """Get the list of edge groups in the workflow.""" + return self._edge_groups @property def start_executor(self) -> Executor: @@ -298,6 +307,22 @@ def _get_executor_by_id(self, executor_id: str) -> Executor: raise ValueError(f"Executor with ID {executor_id} not found.") return self._executors[executor_id] + def _build_executor_map(self, edge_groups: list[EdgeGroup]) -> dict[str, Executor]: + """Build the executor map from edge groups. + + Args: + edge_groups: A list of EdgeGroup instances. + + Returns: + A dictionary mapping executor IDs to Executor instances. 
+ """ + executors: dict[str, Executor] = {} + for group in edge_groups: + for executor in group.source_executors + group.target_executors: + executors[executor.id] = executor + + return executors + async def _restore_from_external_checkpoint( self, checkpoint_id: str, checkpoint_storage: CheckpointStorage ) -> bool: @@ -405,7 +430,7 @@ class WorkflowBuilder: def __init__(self, max_iterations: int = DEFAULT_MAX_ITERATIONS): """Initialize the WorkflowBuilder with an empty list of edges and no starting executor.""" - self._edges: list[Edge] = [] + self._edge_groups: list[EdgeGroup] = [] self._start_executor: Executor | str | None = None self._checkpoint_storage: CheckpointStorage | None = None self._max_iterations: int = max_iterations @@ -427,21 +452,66 @@ def add_edge( should be traversed based on the message type. """ # TODO(@taochen): Support executor factories for lazy initialization - self._edges.append(Edge(source, target, condition)) + self._edge_groups.append(SingleEdgeGroup(source, target, condition)) return self def add_fan_out_edges(self, source: Executor, targets: Sequence[Executor]) -> "Self": - """Add multiple edges to the workflow. + """Add multiple edges to the workflow where messages from the source will be sent to all target. + + The output types of the source and the input types of the targets must be compatible. + + Args: + source: The source executor of the edges. + targets: A list of target executors for the edges. + """ + self._edge_groups.append(FanOutEdgeGroup(source, targets)) + + return self + + def add_switch_case_edge_group(self, source: Executor, cases: Sequence[Case | Default]) -> "Self": + """Add an edge group that represents a switch-case statement. The output types of the source and the input types of the targets must be compatible. - Messages from the source executor will be sent to all target executors. + Messages from the source executor will be sent to one of the target executors based on + the provided conditions. 
+ + Think of this as a switch statement where each target executor corresponds to a case. + Each condition function will be evaluated in order, and the first one that returns True + will determine which target executor receives the message. + + The last case (the default case) will receive messages that fall through all conditions + (i.e., no condition matched). + + Args: + source: The source executor of the edges. + cases: A list of case objects that determine the target executor for each message. + """ + self._edge_groups.append(SwitchCaseEdgeGroup(source, cases)) + + return self + + def add_multi_selection_edge_group( + self, + source: Executor, + targets: Sequence[Executor], + selection_func: Callable[[Any, list[str]], list[str]], + ) -> "Self": + """Add an edge group that represents a multi-selection execution model. + + The output types of the source and the input types of the targets must be compatible. + Messages from the source executor will be sent to multiple target executors based on + the provided selection function. + + The selection function should take a message and the name of the target executors, + and return a list of indices indicating which target executors should receive the message. Args: source: The source executor of the edges. targets: A list of target executors for the edges. + selection_func: A function that selects target executors for messages. """ - for target in targets: - self._edges.append(Edge(source, target)) + self._edge_groups.append(FanOutEdgeGroup(source, targets, selection_func)) + return self def add_fan_in_edges(self, sources: Sequence[Executor], target: Executor) -> "Self": @@ -478,16 +548,7 @@ def handle_message(self, message: Message) -> None: sources: A list of source executors for the edges. target: The target executor for the edges. """ - edges = [Edge(source, target) for source in sources] - - # Set the edge groups for the edges to ensure they are processed together. 
- for i, edge in enumerate(edges): - group_ids: list[str] = [] - group_ids.extend([e.id for e in edges[0:i]]) - group_ids.extend([e.id for e in edges[i + 1 :]]) - edge.set_edge_group(group_ids) - - self._edges.extend(edges) + self._edge_groups.append(FanInEdgeGroup(sources, target)) return self @@ -549,11 +610,11 @@ def build(self) -> Workflow: if not self._start_executor: raise ValueError("Starting executor must be set before building the workflow.") - validate_workflow_graph(self._edges, self._start_executor) + validate_workflow_graph(self._edge_groups, self._start_executor) context = InProcRunnerContext(self._checkpoint_storage) - return Workflow(self._edges, self._start_executor, context, self._max_iterations) + return Workflow(self._edge_groups, self._start_executor, context, self._max_iterations) # endregion diff --git a/python/packages/workflow/tests/test_edge.py b/python/packages/workflow/tests/test_edge.py index b1c41c4470..ddef2cd649 100644 --- a/python/packages/workflow/tests/test_edge.py +++ b/python/packages/workflow/tests/test_edge.py @@ -2,10 +2,20 @@ from dataclasses import dataclass from typing import Any +from unittest.mock import patch +import pytest from agent_framework.workflow import Executor, WorkflowContext, handler -from agent_framework_workflow._edge import Edge +from agent_framework_workflow._edge import ( + Case, + Default, + Edge, + FanInEdgeGroup, + FanOutEdgeGroup, + SingleEdgeGroup, + SwitchCaseEdgeGroup, +) @dataclass @@ -15,6 +25,13 @@ class MockMessage: data: Any +@dataclass +class MockMessageSecondary: + """A secondary mock message for testing purposes.""" + + data: Any + + class MockExecutor(Executor): """A mock executor for testing purposes.""" @@ -24,6 +41,27 @@ async def mock_handler(self, message: MockMessage, ctx: WorkflowContext) -> None pass +class MockExecutorSecondary(Executor): + """A secondary mock executor for testing purposes.""" + + @handler + async def mock_handler_secondary(self, message: MockMessageSecondary, 
ctx: WorkflowContext) -> None: + """A secondary mock handler that does nothing.""" + pass + + +class MockAggregator(Executor): + """A mock aggregator for testing purposes.""" + + @handler + async def mock_aggregator_handler(self, message: list[MockMessage], ctx: WorkflowContext) -> None: + """A mock aggregator handler that does nothing.""" + pass + + +# region Edge + + def test_create_edge(): """Test creating an edge with a source and target executor.""" source = MockExecutor(id="source_executor") @@ -34,7 +72,6 @@ def test_create_edge(): assert edge.source_id == "source_executor" assert edge.target_id == "target_executor" assert edge.id == f"{edge.source_id}{Edge.ID_SEPARATOR}{edge.target_id}" - assert (edge.source_id, edge.target_id) == Edge.source_and_target_from_id(edge.id) def test_edge_can_handle(): @@ -45,3 +82,750 @@ def test_edge_can_handle(): edge = Edge(source=source, target=target) assert edge.can_handle(MockMessage(data="test")) + + +# endregion Edge + +# region SingleEdgeGroup + + +def test_single_edge_group(): + """Test creating a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target] + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor" + + +def test_single_edge_group_with_condition(): + """Test creating a single edge group with a condition.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target, condition=lambda x: x.data == "test") + + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target] + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor" + assert 
edge_group.edges[0]._condition is not None # type: ignore + + +async def test_single_edge_group_send_message(): + """Test sending a message through a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is True + + +async def test_single_edge_group_send_message_with_target(): + """Test sending a message through a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is True + + +async def test_single_edge_group_send_message_with_invalid_target(): + """Test sending a message through a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, 
source_id=source.id, target_id="invalid_target") + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_single_edge_group_send_message_with_invalid_data(): + """Test sending a message through a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +# endregion SingleEdgeGroup + + +# region FanOutEdgeGroup + + +def test_source_edge_group(): + """Test creating a fan-out group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) + + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target1, target2] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor_1" + assert edge_group.edges[1].source_id == "source_executor" + assert edge_group.edges[1].target_id == "target_executor_2" + + +def test_source_edge_group_invalid_number_of_targets(): + """Test creating a fan-out group with an invalid number of targets.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + with pytest.raises(ValueError, match="FanOutEdgeGroup must contain at least two targets"): + FanOutEdgeGroup(source=source, targets=[target]) + + +async def test_source_edge_group_send_message(): + 
"""Test sending a message through a fan-out group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 2 + + +async def test_source_edge_group_send_message_with_target(): + """Test sending a message through a fan-out group with a target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target1.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + assert mock_send.call_args[0][0].target_id == target1.id + + +async def test_source_edge_group_send_message_with_invalid_target(): + """Test sending a message through a fan-out group with an invalid target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 
= MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id="invalid_target") + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_source_edge_group_send_message_with_invalid_data(): + """Test sending a message through a fan-out group with invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_source_edge_group_send_message_only_one_successful_send(): + """Test sending a message through a fan-out group where only one edge can handle the message.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutorSecondary(id="target_executor_2") + + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, 
source_id=source.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + + +def test_source_edge_group_with_selection_func(): + """Test creating a partitioning edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id], + ) + + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target1, target2] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor_1" + assert edge_group.edges[1].source_id == "source_executor" + assert edge_group.edges[1].target_id == "target_executor_2" + + +async def test_source_edge_group_with_selection_func_send_message(): + """Test sending a message through a fan-out group with a selection function.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id, target2.id], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 2 + + +async def 
test_source_edge_group_with_selection_func_send_message_with_invalid_selection_result(): + """Test sending a message through a fan-out group with a selection func with an invalid selection result.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id, "invalid_target"], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with pytest.raises(RuntimeError): + await edge_group.send_message(message, shared_state, ctx) + + +async def test_source_edge_group_with_selection_func_send_message_with_target(): + """Test sending a message through a fan-out group with a selection func with a target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id, target2.id], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target1.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + assert mock_send.call_args[0][0].target_id == target1.id + + +async def 
test_source_edge_group_with_selection_func_send_message_with_target_not_in_selection(): + """Test sending a message through a fan-out group with a selection func with a target not in the selection.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id], # Only target1 will receive the message + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target2.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_source_edge_group_with_selection_func_send_message_with_invalid_data(): + """Test sending a message through a fan-out group with a selection func with invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, target2.id] + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_source_edge_group_with_selection_func_send_message_with_target_invalid_data(): + """Test sending a message through a fan-out group with a selection func with a 
target and invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = FanOutEdgeGroup( + source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, target2.id] + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id, target_id=target1.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +# endregion FanOutEdgeGroup + +# region FanInEdgeGroup + + +def test_target_edge_group(): + """Test creating a fan-in edge group.""" + source1 = MockExecutor(id="source_executor_1") + source2 = MockExecutor(id="source_executor_2") + target = MockAggregator(id="target_executor") + + edge_group = FanInEdgeGroup(sources=[source1, source2], target=target) + + assert edge_group.source_executors == [source1, source2] + assert edge_group.target_executors == [target] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor_1" + assert edge_group.edges[0].target_id == "target_executor" + assert edge_group.edges[1].source_id == "source_executor_2" + assert edge_group.edges[1].target_id == "target_executor" + + +def test_target_edge_group_invalid_number_of_sources(): + """Test creating a fan-in edge group with an invalid number of sources.""" + source = MockExecutor(id="source_executor") + target = MockAggregator(id="target_executor") + + with pytest.raises(ValueError, match="FanInEdgeGroup must contain at least two sources"): + FanInEdgeGroup(sources=[source], target=target) + + +async def test_target_edge_group_send_message_buffer(): + """Test sending a message through a fan-in edge group with buffering.""" + 
source1 = MockExecutor(id="source_executor_1") + source2 = MockExecutor(id="source_executor_2") + target = MockAggregator(id="target_executor") + + edge_group = FanInEdgeGroup(sources=[source1, source2], target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message( + Message(data=data, source_id=source1.id), + shared_state, + ctx, + ) + + assert success is True + assert mock_send.call_count == 0 # The message should be buffered and wait for the second source + assert len(edge_group._buffer[source1.id]) == 1 # type: ignore + + success = await edge_group.send_message( + Message(data=data, source_id=source2.id), + shared_state, + ctx, + ) + assert success is True + assert mock_send.call_count == 1 # The message should be sent now that both sources have sent their messages + + # Buffer should be cleared after sending + assert not edge_group._buffer # type: ignore + + +async def test_target_edge_group_send_message_with_invalid_target(): + """Test sending a message through a fan-in edge group with an invalid target.""" + source1 = MockExecutor(id="source_executor_1") + source2 = MockExecutor(id="source_executor_2") + target = MockAggregator(id="target_executor") + + edge_group = FanInEdgeGroup(sources=[source1, source2], target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source1.id, target_id="invalid_target") + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is 
False + + +async def test_target_edge_group_send_message_with_invalid_data(): + """Test sending a message through a fan-in edge group with invalid data.""" + source1 = MockExecutor(id="source_executor_1") + source2 = MockExecutor(id="source_executor_2") + target = MockAggregator(id="target_executor") + + edge_group = FanInEdgeGroup(sources=[source1, source2], target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source1.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +# endregion FanInEdgeGroup + +# region SwitchCaseEdgeGroup + + +def test_switch_case_edge_group(): + """Test creating a switch case edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], + ) + + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target1, target2] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor_1" + assert edge_group.edges[1].source_id == "source_executor" + assert edge_group.edges[1].target_id == "target_executor_2" + + assert edge_group._selection_func is not None # type: ignore + assert edge_group._selection_func(MockMessage(data=-1), [target1.id, target2.id]) == [target1.id] # type: ignore + assert edge_group._selection_func(MockMessage(data=1), [target1.id, target2.id]) == [target2.id] # type: ignore + + +def test_switch_case_edge_group_invalid_number_of_cases(): + """Test creating a switch 
case edge group with an invalid number of cases.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + with pytest.raises( + ValueError, match=r"SwitchCaseEdgeGroup must contain at least two cases \(including the default case\)." + ): + SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target), + ], + ) + + with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."): + SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target), + Case(condition=lambda x: x.data >= 0, target=target), + ], + ) + + +def test_switch_case_edge_group_invalid_number_of_default_cases(): + """Test creating a switch case edge group with an invalid number of conditions.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."): + SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + Default(target=target2), + ], + ) + + +async def test_switch_case_edge_group_send_message(): + """Test sending a message through a switch case edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data=-1) + message = Message(data=data, source_id=source.id) + + with 
patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + + # Default case should be selected since no other condition matches + data = MockMessage(data=1) + message = Message(data=data, source_id=source.id) + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + + +async def test_switch_case_edge_group_send_message_with_invalid_target(): + """Test sending a message through a switch case edge group with an invalid target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data=-1) + message = Message(data=data, source_id=source.id, target_id="invalid_target") + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_switch_case_edge_group_send_message_with_valid_target(): + """Test sending a message through a switch case edge group with a target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from
agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data=1) # Condition will fail + message = Message(data=data, source_id=source.id, target_id=target1.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + data = MockMessage(data=-1) # Condition will pass + message = Message(data=data, source_id=source.id, target_id=target1.id) + success = await edge_group.send_message(message, shared_state, ctx) + assert success is True + + +async def test_switch_case_edge_group_send_message_with_invalid_data(): + """Test sending a message through a switch case edge group with invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SwitchCaseEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +# endregion SwitchCaseEdgeGroup diff --git a/python/packages/workflow/tests/test_runner.py b/python/packages/workflow/tests/test_runner.py index a4a1abb43c..d4e4ef79ad 100644 --- a/python/packages/workflow/tests/test_runner.py +++ b/python/packages/workflow/tests/test_runner.py @@ -6,7 +6,7 @@ import pytest from agent_framework.workflow import Executor, WorkflowCompletedEvent, WorkflowContext, WorkflowEvent, handler -from agent_framework_workflow._edge import Edge +from agent_framework_workflow._edge import SingleEdgeGroup from agent_framework_workflow._runner import 
Runner from agent_framework_workflow._runner_context import InProcRunnerContext, RunnerContext from agent_framework_workflow._shared_state import SharedState @@ -36,12 +36,12 @@ def test_create_runner(): executor_b = MockExecutor(id="executor_b") # Create a loop - edges = [ - Edge(source=executor_a, target=executor_b), - Edge(source=executor_b, target=executor_a), + edge_groups = [ + SingleEdgeGroup(executor_a, executor_b), + SingleEdgeGroup(executor_b, executor_a), ] - runner = Runner(edges, shared_state=SharedState(), ctx=InProcRunnerContext()) + runner = Runner(edge_groups, shared_state=SharedState(), ctx=InProcRunnerContext()) assert runner.context is not None and isinstance(runner.context, RunnerContext) @@ -53,8 +53,8 @@ async def test_runner_run_until_convergence(): # Create a loop edges = [ - Edge(source=executor_a, target=executor_b), - Edge(source=executor_b, target=executor_a), + SingleEdgeGroup(executor_a, executor_b), + SingleEdgeGroup(executor_b, executor_a), ] shared_state = SharedState() @@ -87,8 +87,8 @@ async def test_runner_run_until_convergence_not_completed(): # Create a loop edges = [ - Edge(source=executor_a, target=executor_b), - Edge(source=executor_b, target=executor_a), + SingleEdgeGroup(executor_a, executor_b), + SingleEdgeGroup(executor_b, executor_a), ] shared_state = SharedState() @@ -117,8 +117,8 @@ async def test_runner_already_running(): # Create a loop edges = [ - Edge(source=executor_a, target=executor_b), - Edge(source=executor_b, target=executor_a), + SingleEdgeGroup(executor_a, executor_b), + SingleEdgeGroup(executor_b, executor_a), ] shared_state = SharedState() diff --git a/python/packages/workflow/tests/test_validation.py b/python/packages/workflow/tests/test_validation.py index 23a14a4490..8cd6e78c2f 100644 --- a/python/packages/workflow/tests/test_validation.py +++ b/python/packages/workflow/tests/test_validation.py @@ -17,7 +17,7 @@ handler, validate_workflow_graph, ) -from agent_framework_workflow._edge import Edge 
+from agent_framework_workflow._edge import SingleEdgeGroup class StringExecutor(Executor): @@ -159,10 +159,13 @@ def test_graph_connectivity_isolated_executors(): executor3 = StringExecutor(id="executor3") # This will be isolated # Create edges that include an isolated executor (self-loop that's not connected to main graph) - edges = [Edge(executor1, executor2), Edge(executor3, executor3)] # Self-loop to include in graph + edge_groups = [ + SingleEdgeGroup(executor1, executor2), + SingleEdgeGroup(executor3, executor3), + ] # Self-loop to include in graph with pytest.raises(GraphConnectivityError) as exc_info: - validate_workflow_graph(edges, executor1) + validate_workflow_graph(edge_groups, executor1) assert "unreachable" in str(exc_info.value).lower() assert "executor3" in str(exc_info.value) @@ -239,15 +242,15 @@ async def handle_derived(self, message: str, ctx: WorkflowContext) -> None: def test_direct_validation_function(): executor1 = StringExecutor(id="executor1") executor2 = StringExecutor(id="executor2") - edges = [Edge(executor1, executor2)] + edge_groups = [SingleEdgeGroup(executor1, executor2)] # This should not raise any exceptions - validate_workflow_graph(edges, executor1) + validate_workflow_graph(edge_groups, executor1) # Test with invalid start executor executor3 = StringExecutor(id="executor3") with pytest.raises(GraphConnectivityError): - validate_workflow_graph(edges, executor3) + validate_workflow_graph(edge_groups, executor3) def test_fan_out_validation(): diff --git a/python/packages/workflow/tests/test_workflow_builder.py b/python/packages/workflow/tests/test_workflow_builder.py index 5135314485..55b3d32ec1 100644 --- a/python/packages/workflow/tests/test_workflow_builder.py +++ b/python/packages/workflow/tests/test_workflow_builder.py @@ -60,6 +60,6 @@ def test_workflow_builder_fluent_api(): .build() ) - assert len(workflow.edges) == 6 + assert len(workflow.edge_groups) == 4 assert workflow.start_executor.id == executor_a.id assert 
len(workflow.executors) == 6 diff --git a/python/samples/getting_started/workflow/step_00_foundation_patterns.py b/python/samples/getting_started/workflow/step_00_foundation_patterns.py new file mode 100644 index 0000000000..2827d84fab --- /dev/null +++ b/python/samples/getting_started/workflow/step_00_foundation_patterns.py @@ -0,0 +1,291 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from typing import Any + +from agent_framework.workflow import Case, Default, Executor, WorkflowBuilder, WorkflowContext, handler + +""" +The following sample demonstrates the foundation patterns that the workflow framework supports. +These patterns include: +- Single connection +- Single connection with condition +- Fan-out and fan-in connections +- Conditional fan-out connections +- Partitioning fan-out connections + +The samples here use numbers and simple arithmetic operations to demonstrate the patterns. +""" + + +class AddOneExecutor(Executor): + """An executor that processes a number by adding one.""" + + @handler(output_types=[int]) + async def add_one(self, number: int, ctx: WorkflowContext) -> None: + """Execute the task by adding one to the input number.""" + result = number + 1 + + # Send the result to the next executor in the workflow. + await ctx.send_message(result) + + print("Adding one to the number:", number, "Result:", result) + + +class MultiplyByTwoExecutor(Executor): + """An executor that processes a number by multiplying it by two.""" + + @handler(output_types=[int]) + async def multiply_by_two(self, number: int, ctx: WorkflowContext) -> None: + """Execute the task by multiplying the input number by two.""" + result = number * 2 + + # Send the result to the next executor in the workflow. 
+ await ctx.send_message(result) + + print("Multiplying the number by two:", number, "Result:", result) + + +class DivideByTwoExecutor(Executor): + """An executor that processes a number by dividing it by two.""" + + @handler(output_types=[float]) + async def divide_by_two(self, number: int, ctx: WorkflowContext) -> None: + """Execute the task by dividing the input number by two.""" + result = number / 2 + + # Send the result with a workflow completion event. + await ctx.send_message(result) + + print("Dividing the number by two:", number, "Result:", result) + + +class AggregateResultExecutor(Executor): + """An executor that receives results and prints them.""" + + @handler + async def aggregate_results(self, results: Any, ctx: WorkflowContext) -> None: + """Print whatever results are received.""" + print("Aggregating results:", results) + + +async def single_edge(): + """A sample to demonstrate a single directed connection between two executors. + + Three executors are connected in a sequence: AddOneExecutor -> AddOneExecutor -> AggregateResultExecutor. + + Expected output: + Adding one to the number: 1 Result: 2 + Adding one to the number: 2 Result: 3 + Aggregating results: 3 + """ + add_one_executor_a = AddOneExecutor() + add_one_executor_b = AddOneExecutor() + aggregate_result_executor = AggregateResultExecutor() + + workflow = ( + WorkflowBuilder() + .add_edge(add_one_executor_a, add_one_executor_b) + .add_edge(add_one_executor_b, aggregate_result_executor) + .set_start_executor(add_one_executor_a) + .build() + ) + + await workflow.run(1) + + +async def single_edge_with_condition(): + """A sample to demonstrate a single directed connection with a condition. + + Three executors are connected: AddOneExecutor -> AddOneExecutor, AggregateResultExecutor. + The AddOneExecutor will loop back to itself until the number reaches 10, then it will start + sending the result to AggregateResultExecutor when the number is greater than 8. 
The workflow + stops when the number reaches 11. + + Expected output: + Adding one to the number: 1 Result: 2 + Adding one to the number: 2 Result: 3 + Adding one to the number: 3 Result: 4 + Adding one to the number: 4 Result: 5 + Adding one to the number: 5 Result: 6 + Adding one to the number: 6 Result: 7 + Adding one to the number: 7 Result: 8 + Adding one to the number: 8 Result: 9 + Adding one to the number: 9 Result: 10 + Aggregating results: 9 + Adding one to the number: 10 Result: 11 + Aggregating results: 10 + Aggregating results: 11 + """ + add_one_executor_a = AddOneExecutor() + aggregate_result_executor = AggregateResultExecutor() + + workflow = ( + WorkflowBuilder() + .add_edge(add_one_executor_a, add_one_executor_a, condition=lambda x: x < 11) + .add_edge(add_one_executor_a, aggregate_result_executor, condition=lambda x: x > 8) + .set_start_executor(add_one_executor_a) + .build() + ) + + await workflow.run(1) + + +async def fan_out_fan_in_edge_group(): + """A sample to demonstrate a fan-out and fan-in connection between executors. + + Four executors are connected in a fan-out and fan-in pattern: + AddOneExecutor -> MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor. + The AddOneExecutor sends its output to both MultiplyByTwoExecutor and DivideByTwoExecutor, + and both of these executors send their results to AggregateResultExecutor. + + The target of the fan-in connection will wait for all the results from the sources before proceeding. 
+ + Expected output: + Adding one to the number: 1 Result: 2 + Multiplying the number by two: 2 Result: 4 + Dividing the number by two: 2 Result: 1.0 + Aggregating results: [4, 1.0] + """ + add_one_executor = AddOneExecutor() + multiply_by_two_executor = MultiplyByTwoExecutor() + divide_by_two_executor = DivideByTwoExecutor() + aggregate_result_executor = AggregateResultExecutor() + + workflow = ( + WorkflowBuilder() + .add_fan_out_edges(add_one_executor, [multiply_by_two_executor, divide_by_two_executor]) + .add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor) + .set_start_executor(add_one_executor) + .build() + ) + + await workflow.run(1) + + +async def switch_case_edge_group(): + """A sample to demonstrate a switch-case connection. + + Four executors are connected in a switch-case pattern: + AddOneExecutor -> AddOneExecutor, MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor. + + The message from AddOneExecutor will be evaluated against the conditions one by one, and the first condition + that evaluates to True will determine the target executors. If no conditions match, the message will be sent + to the last targets. + + This pattern resembles a switch-case statement with a default case where the first matching case is executed. 
+ + Expected output: + Adding one to the number: 1 Result: 2 + Adding one to the number: 2 Result: 3 + Adding one to the number: 3 Result: 4 + Adding one to the number: 4 Result: 5 + Adding one to the number: 5 Result: 6 + Adding one to the number: 6 Result: 7 + Adding one to the number: 7 Result: 8 + Adding one to the number: 8 Result: 9 + Adding one to the number: 9 Result: 10 + Adding one to the number: 10 Result: 11 + Multiplying the number by two: 11 Result: 22 + """ + add_one_executor = AddOneExecutor() + multiply_by_two_executor = MultiplyByTwoExecutor() + divide_by_two_executor = DivideByTwoExecutor() + aggregate_result_executor = AggregateResultExecutor() + + workflow = ( + WorkflowBuilder() + .set_start_executor(add_one_executor) + .add_switch_case_edge_group( + source=add_one_executor, + cases=[ + # Loop back to the add_one_executor if the number is less than 11 + Case(condition=lambda x: x < 11, target=add_one_executor), + # multiply_by_two_executor when the number is larger than or equal to 11 and even. + Case(condition=lambda x: x % 2 == 0, target=multiply_by_two_executor), + # Otherwise, send to the divide_by_two_executor. + Default(target=divide_by_two_executor), + ], + ) + .add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor) + .build() + ) + + await workflow.run(1) + + +async def multi_selection_edge_group(): + """A sample to demonstrate a multi-selection edge connection. + + Four executors are connected in a multi-selection edge pattern: + AddOneExecutor -> AddOneExecutor, MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor. + + The AddOneExecutor sends its output to one or more executors based on the partitioning function. 
+ + Expected output: + Adding one to the number: 1 Result: 2 + Adding one to the number: 2 Result: 3 + Adding one to the number: 3 Result: 4 + Adding one to the number: 4 Result: 5 + Adding one to the number: 5 Result: 6 + Adding one to the number: 6 Result: 7 + Adding one to the number: 7 Result: 8 + Adding one to the number: 8 Result: 9 + Adding one to the number: 9 Result: 10 + Adding one to the number: 10 Result: 11 + Adding one to the number: 11 Result: 12 + Adding one to the number: 12 Result: 13 + Dividing the number by two: 12 Result: 6.0 + Multiplying the number by two: 13 Result: 26 + Aggregating results: [26, 6.0] + """ + add_one_executor = AddOneExecutor() + multiply_by_two_executor = MultiplyByTwoExecutor() + divide_by_two_executor = DivideByTwoExecutor() + aggregate_result_executor = AggregateResultExecutor() + + def selection_func(number: int, target_ids: list[str]) -> list[str]: + """Selection function to determine which executor to send the number to.""" + if number < 12: + # Loop back to the add_one_executor if the number is less than 12 + return [add_one_executor.id] + + if number % 2 == 0: + # Send it to the add_one_executor to add one more time and the + # divide_by_two_executor to divide the result by two. + return [add_one_executor.id, divide_by_two_executor.id] + + # Otherwise, send it to the multiply_by_two_executor to multiply the result by two. 
+ return [multiply_by_two_executor.id] + + workflow = ( + WorkflowBuilder() + .set_start_executor(add_one_executor) + .add_multi_selection_edge_group( + add_one_executor, + [add_one_executor, multiply_by_two_executor, divide_by_two_executor], + selection_func=selection_func, + ) + .add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor) + .build() + ) + + await workflow.run(1) + + +async def main(): + """Main function to run the workflows.""" + print("**Running single connection workflow**") + await single_edge() + print("**Running single connection with condition workflow**") + await single_edge_with_condition() + print("**Running fan-out and fan-in connection workflow**") + await fan_out_fan_in_edge_group() + print("**Running conditional fan-out connection workflow**") + await switch_case_edge_group() + print("**Running multi-selection edge group workflow**") + await multi_selection_edge_group() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py b/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py index c6f4f21f60..3a526ce7b7 100644 --- a/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py +++ b/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py @@ -3,7 +3,15 @@ import asyncio from dataclasses import dataclass -from agent_framework.workflow import Executor, WorkflowBuilder, WorkflowCompletedEvent, WorkflowContext, handler +from agent_framework.workflow import ( + Case, + Default, + Executor, + WorkflowBuilder, + WorkflowCompletedEvent, + WorkflowContext, + handler, +) """ The following sample demonstrates a basic workflow with two executors @@ -91,15 +99,12 @@ async def main(): workflow = ( WorkflowBuilder() .set_start_executor(spam_detector) - .add_edge( + .add_switch_case_edge_group( spam_detector, - send_response, - condition=lambda x: x.is_spam is 
False, - ) - .add_edge( - spam_detector, - remove_spam, - condition=lambda x: x.is_spam is True, + [ + Case(condition=lambda x: x.is_spam, target=remove_spam), + Default(target=send_response), + ], ) .build() ) diff --git a/python/samples/getting_started/workflow/step_04_simple_group_chat.py b/python/samples/getting_started/workflow/step_04_simple_group_chat.py index 0e0d6eda8d..496f506b6b 100644 --- a/python/samples/getting_started/workflow/step_04_simple_group_chat.py +++ b/python/samples/getting_started/workflow/step_04_simple_group_chat.py @@ -120,8 +120,7 @@ async def main(): workflow = ( WorkflowBuilder() .set_start_executor(group_chat_manager) - .add_edge(group_chat_manager, writer) - .add_edge(group_chat_manager, reviewer) + .add_fan_out_edges(group_chat_manager, [writer, reviewer]) .add_edge(writer, group_chat_manager) .add_edge(reviewer, group_chat_manager) .build() diff --git a/python/samples/getting_started/workflow/step_05_simple_group_chat_with_hil.py b/python/samples/getting_started/workflow/step_05_simple_group_chat_with_hil.py index 1c30404426..f10eb41d47 100644 --- a/python/samples/getting_started/workflow/step_05_simple_group_chat_with_hil.py +++ b/python/samples/getting_started/workflow/step_05_simple_group_chat_with_hil.py @@ -167,8 +167,7 @@ async def main(): .set_start_executor(group_chat_manager) .add_edge(group_chat_manager, request_info_executor) .add_edge(request_info_executor, group_chat_manager) - .add_edge(group_chat_manager, writer) - .add_edge(group_chat_manager, reviewer) + .add_fan_out_edges(group_chat_manager, [writer, reviewer]) .add_edge(writer, group_chat_manager) .add_edge(reviewer, group_chat_manager) .build() diff --git a/python/samples/getting_started/workflow/step_06_map_reduce.py b/python/samples/getting_started/workflow/step_06_map_reduce.py index e7665f67ed..929cb5f4c7 100644 --- a/python/samples/getting_started/workflow/step_06_map_reduce.py +++ b/python/samples/getting_started/workflow/step_06_map_reduce.py @@ -113,7 
+113,6 @@ async def map(self, _: SplitCompleted, ctx: WorkflowContext) -> None: ctx: The execution context containing the shared state and other information. """ # Retrieve the data to be processed from the shared state. - # Define a key for the shared state to store the data to be processed data_to_be_processed: list[str] = await ctx.get_shared_state(SHARED_STATE_DATA_KEY) chunk_start, chunk_end = await ctx.get_shared_state(self.id)