From f89fcbd68c2241bcbbc68dae82ba08681621e6e5 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 8 Aug 2025 14:23:03 -0700 Subject: [PATCH 01/11] Introducing edge groups --- .../agent_framework_workflow/_edge.py | 267 +++++++++++++----- .../agent_framework_workflow/_executor.py | 2 +- .../agent_framework_workflow/_runner.py | 72 ++--- .../agent_framework_workflow/_validation.py | 155 +++++----- .../agent_framework_workflow/_workflow.py | 59 ++-- python/packages/workflow/tests/test_edge.py | 1 - python/packages/workflow/tests/test_runner.py | 22 +- .../workflow/tests/test_validation.py | 15 +- .../workflow/tests/test_workflow_builder.py | 2 +- .../workflow/step_06_map_reduce.py | 9 +- 10 files changed, 370 insertions(+), 234 deletions(-) diff --git a/python/packages/workflow/agent_framework_workflow/_edge.py b/python/packages/workflow/agent_framework_workflow/_edge.py index 0d7ca8bb18..1893e91978 100644 --- a/python/packages/workflow/agent_framework_workflow/_edge.py +++ b/python/packages/workflow/agent_framework_workflow/_edge.py @@ -1,7 +1,12 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from collections.abc import Callable +import logging +import sys +import uuid +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Callable, Sequence from typing import Any, ClassVar from ._executor import Executor @@ -9,6 +14,13 @@ from ._shared_state import SharedState from ._workflow_context import WorkflowContext +if sys.version_info >= (3, 12): + from typing import override # pragma: no cover +else: + from typing_extensions import override # pragma: no cover + +logger = logging.getLogger(__name__) + class Edge: """Represents a directed edge in a graph.""" @@ -34,10 +46,6 @@ def __init__( self.target = target self._condition = condition - # Edge group is used to group edges that share the same target executor. - # It allows for sending messages to the target executor only when all edges in the group have data. - self._edge_group_ids: list[str] = [] - @property def source_id(self) -> str: """Get the source executor ID.""" @@ -53,27 +61,6 @@ def id(self) -> str: """Get the unique ID of the edge.""" return f"{self.source_id}{self.ID_SEPARATOR}{self.target_id}" - def has_edge_group(self) -> bool: - """Check if the edge is part of an edge group.""" - return bool(self._edge_group_ids) - - @classmethod - def source_and_target_from_id(cls, edge_id: str) -> tuple[str, str]: - """Extract the source and target IDs from the edge ID. - - Args: - edge_id (str): The edge ID in the format "source_id->target_id". - - Returns: - tuple[str, str]: A tuple containing the source ID and target ID. - """ - if cls.ID_SEPARATOR not in edge_id: - raise ValueError(f"Invalid edge ID format: {edge_id}") - ids = edge_id.split(cls.ID_SEPARATOR) - if len(ids) != 2: - raise ValueError(f"Invalid edge ID format: {edge_id}") - return ids[0], ids[1] - def can_handle(self, message_data: Any) -> bool: """Check if the edge can handle the given data. @@ -83,11 +70,7 @@ def can_handle(self, message_data: Any) -> bool: Returns: bool: True if the edge can handle the data, False otherwise. """ - if not self._edge_group_ids: - return self.target.can_handle(message_data) - - # If the edge is part of an edge group, the target should expect a list of the data type. - return self.target.can_handle([message_data]) + return self.target.can_handle(message_data) async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> None: """Send a message along this edge. @@ -98,57 +81,201 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R ctx (RunnerContext): The context for the runner. """ if not self.can_handle(message.data): + # Caller of this method should ensure that the edge can handle the data. raise RuntimeError(f"Edge {self.id} cannot handle data of type {type(message.data)}.") - if not self._edge_group_ids and self._should_route(message.data): + if self._should_route(message.data): await self.target.execute( message.data, WorkflowContext(self.target.id, [self.source.id], shared_state, ctx) ) - elif self._edge_group_ids: - # Logic: - # 1. If not all edges in the edge group have data in the shared state, - # add the data to the shared state. - # 2. If all edges in the edge group have data in the shared state, - # copy the data to a list and send it to the target executor. - message_list: list[Message] = [] - async with shared_state.hold() as held_shared_state: - has_data = await asyncio.gather( - *(held_shared_state.has_within_hold(edge_id) for edge_id in self._edge_group_ids) - ) - if not all(has_data): - await held_shared_state.set_within_hold(self.id, message) - else: - message_list = [ - await held_shared_state.get_within_hold(edge_id) for edge_id in self._edge_group_ids - ] + [message] - # Remove the data from the shared state after retrieving it - await asyncio.gather( - *(held_shared_state.delete_within_hold(edge_id) for edge_id in self._edge_group_ids) - ) - - if message_list: - data_list = [msg.data for msg in message_list] - source_ids = [msg.source_id for msg in message_list] - await self.target.execute(data_list, WorkflowContext(self.target.id, source_ids, shared_state, ctx)) def _should_route(self, data: Any) -> bool: - """Determine if message should be routed through this edge.""" + """Determine if message should be routed through this edge based on the condition.""" if self._condition is None: return True return self._condition(data) - def set_edge_group(self, edge_group_ids: list[str]) -> None: - """Set the edge group IDs for this edge. + +class EdgeGroup(ABC): + """Represents a group of edges that share some common properties and can be triggered together.""" + + def __init__(self) -> None: + """Initialize the edge group.""" + self._id = f"{self.__class__.__name__}/{uuid.uuid4()}" + + @abstractmethod + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through the edge group. Args: - edge_group_ids (list[str]): A list of edge IDs that belong to the same edge group. + message (Message): The message to send. + shared_state (SharedState): The shared state to use for holding data. + ctx (RunnerContext): The context for the runner. + + Returns: + bool: True if the message was sent successfully, False otherwise. If a message can be delivered + but rejected due to a condition, it will still return True. """ - # Validate that the edges in the edge group contain the same target executor as this edge - # TODO(@taochen): An edge cannot be part of multiple edge groups. - # TODO(@taochen): Can an edge have both a condition and an edge group? - if edge_group_ids: - for edge_id in edge_group_ids: - if Edge.source_and_target_from_id(edge_id)[1] != self.target.id: - raise ValueError("All edges in the group must have the same target executor.") - self._edge_group_ids = edge_group_ids + ... + + @property + def id(self) -> str: + """Get the unique ID of the edge group.""" + return self._id + + @abstractmethod + def source_executors(self) -> list[Executor]: + """Get the source executor IDs of the edges in the group.""" + ... + + @abstractmethod + def target_executors(self) -> list[Executor]: + """Get the target executor IDs of the edges in the group.""" + ... + + @abstractmethod + def edges(self) -> list[Edge]: + """Get the edges in the group.""" + ... + + +class SingleEdgeGroup(EdgeGroup): + """Represents a single edge group that contains only one edge.""" + + def __init__(self, source: Executor, target: Executor, condition: Callable[[Any], bool] | None = None) -> None: + """Initialize the single edge group with an edge.""" + self._edge = Edge(source=source, target=target, condition=condition) + + @override + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through the single edge.""" + if message.target_id and message.target_id != self._edge.target_id: + return False + + if self._edge.can_handle(message.data): + await self._edge.send_message(message, shared_state, ctx) + return True + + return False + + @override + def source_executors(self) -> list[Executor]: + """Get the source executor of the edge.""" + return [self._edge.source] + + @override + def target_executors(self) -> list[Executor]: + """Get the target executor of the edge.""" + return [self._edge.target] + + @override + def edges(self) -> list[Edge]: + """Get the edges in the group.""" + return [self._edge] + + +class SourceEdgeGroup(EdgeGroup): + """Represents a group of edges that share the same source executor. + + Assembles a Fan-out pattern where multiple edges share the same source executor + and send messages to their respective target executors. + """ + + def __init__(self, source: Executor, targets: Sequence[Executor]) -> None: + """Initialize the source edge group with a list of edges.""" + if len(targets) <= 1: + raise ValueError("SourceEdgeGroup must contain at least two targets.") + self._edges = [Edge(source=source, target=target) for target in targets] + + @override + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through all edges in the source edge group.""" + if message.target_id: + # If the message has a target ID, send it to the specific target executor + target_edge = next((edge for edge in self._edges if edge.target_id == message.target_id), None) + if target_edge and target_edge.can_handle(message.data): + await target_edge.send_message(message, shared_state, ctx) + return True + return False + + # If no target ID, send the message to all edges in the group + await asyncio.gather(*(edge.send_message(message, shared_state, ctx) for edge in self._edges)) + return True + + @override + def source_executors(self) -> list[Executor]: + """Get the source executor of the edges in the group.""" + return [self._edges[0].source] + + @override + def target_executors(self) -> list[Executor]: + """Get the target executors of the edges in the group.""" + return [edge.target for edge in self._edges] + + @override + def edges(self) -> list[Edge]: + """Get the edges in the group.""" + return self._edges + + +class TargetEdgeGroup(EdgeGroup): + """Represents a group of edges that share the same target executor. + + Assembles a Fan-in pattern where multiple edges send messages to a single target executor. + Messages are buffered until all edges in the group have data to send. + """ + + def __init__(self, sources: Sequence[Executor], target: Executor) -> None: + """Initialize the target edge group with a list of edges.""" + if len(sources) <= 1: + raise ValueError("TargetEdgeGroup must contain at least two sources.") + self._edges = [Edge(source=source, target=target) for source in sources] + # Buffer to hold messages before sending them to the target executor + # Key is the source executor ID, value is a list of messages + self._buffer: dict[str, list[Message]] = defaultdict(list) + + @override + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through all edges in the target edge group.""" + if message.target_id and message.target_id != self._edges[0].target_id: + return False + + if self._edges[0].can_handle([message.data]): + # If the edge can handle the data, buffer the message + self._buffer[message.source_id].append(message) + else: + # If the edge cannot handle the data, return False + return False + + if self._is_ready_to_send(): + # If all edges in the group have data, send the buffered messages to the target executor + messages_to_send = [msg for edge in self._edges for msg in self._buffer[edge.source_id]] + self._buffer.clear() + # Only trigger one edge to send the messages to avoid duplicate sends + await self._edges[0].send_message( + Message([msg.data for msg in messages_to_send], self._edges[0].source_id), + shared_state, + ctx, + ) + + return True + + def _is_ready_to_send(self) -> bool: + """Check if all edges in the group have data to send.""" + return all(self._buffer[edge.source_id] for edge in self._edges) + + @override + def source_executors(self) -> list[Executor]: + """Get the source executors of the edges in the group.""" + return [edge.source for edge in self._edges] + + @override + def target_executors(self) -> list[Executor]: + """Get the target executor of the edges in the group.""" + return [self._edges[0].target] + + @override + def edges(self) -> list[Edge]: + """Get the edges in the group.""" + return self._edges diff --git a/python/packages/workflow/agent_framework_workflow/_executor.py b/python/packages/workflow/agent_framework_workflow/_executor.py index aa43859426..5e51dfa3a4 100644 --- a/python/packages/workflow/agent_framework_workflow/_executor.py +++ b/python/packages/workflow/agent_framework_workflow/_executor.py @@ -31,7 +31,7 @@ def __init__(self, id: str | None = None) -> None: Args: id: A unique identifier for the executor. If None, a new UUID will be generated. """ - self._id = id or str(uuid.uuid4()) + self._id = id or f"{self.__class__.__name__}/{uuid.uuid4()}" self._handlers: dict[type, Callable[[Any, WorkflowContext], Any]] = {} self._discover_handlers() diff --git a/python/packages/workflow/agent_framework_workflow/_runner.py b/python/packages/workflow/agent_framework_workflow/_runner.py index dd53d8c4da..df4e39a564 100644 --- a/python/packages/workflow/agent_framework_workflow/_runner.py +++ b/python/packages/workflow/agent_framework_workflow/_runner.py @@ -3,9 +3,9 @@ import asyncio import logging from collections import defaultdict -from collections.abc import AsyncIterable +from collections.abc import AsyncIterable, Sequence -from ._edge import Edge +from ._edge import EdgeGroup from ._events import WorkflowEvent from ._runner_context import Message, RunnerContext from ._shared_state import SharedState @@ -20,7 +20,7 @@ class Runner: def __init__( self, - edges: list[Edge], + edge_groups: Sequence[EdgeGroup], shared_state: SharedState, ctx: RunnerContext, max_iterations: int = DEFAULT_MAX_ITERATIONS, @@ -28,12 +28,12 @@ def __init__( """Initialize the runner with edges, shared state, and context. Args: - edges: The edges of the workflow. + edge_groups: The edge groups of the workflow. shared_state: The shared state for the workflow. ctx: The runner context for the workflow. max_iterations: The maximum number of iterations to run. """ - self._edge_map = self._parse_edges(edges) + self._edge_group_map = self._parse_edge_groups(edge_groups) self._ctx = ctx self._iteration = 0 self._max_iterations = max_iterations @@ -72,49 +72,39 @@ async def _run_iteration(self): """Run a superstep of the workflow execution.""" async def _deliver_messages(source_executor_id: str, messages: list[Message]) -> None: - """Deliver messages to the executors. - - Outer loop to concurrently deliver messages from all sources to their targets. - """ - - async def _deliver_messages_inner( - edge: Edge, - messages: list[Message], - ) -> None: - """Deliver messages to a specific target executor. - - Inner loop to deliver messages to a specific target executor. - """ - for message in messages: - if message.target_id is not None and message.target_id != edge.target_id: - continue - - if not edge.can_handle(message.data): - continue - - await edge.send_message(message, self._shared_state, self._ctx) - - associated_edges = self._edge_map.get(source_executor_id, []) - tasks = [asyncio.create_task(_deliver_messages_inner(edge, messages)) for edge in associated_edges] - await asyncio.gather(*tasks) + """Outer loop to concurrently deliver messages from all sources to their targets.""" + + async def _deliver_message_inner(edge_group: EdgeGroup, message: Message) -> bool: + """Inner loop to deliver a single message through an edge group.""" + return await edge_group.send_message(message, self._shared_state, self._ctx) + + associated_edge_groups = self._edge_group_map.get(source_executor_id, []) + for message in messages: + # Deliver a message through all edge groups associated with the source executor concurrently. + tasks = [_deliver_message_inner(edge_group, message) for edge_group in associated_edge_groups] + results = await asyncio.gather(*tasks) + if not any(results): + logger.warning( + f"Message {message} could not be delivered. " + "This may be due to type incompatibility or no matching targets." + ) messages = await self._ctx.drain_messages() - tasks = [ - asyncio.create_task(_deliver_messages(source_executor_id, messages)) - for source_executor_id, messages in messages.items() - ] + tasks = [_deliver_messages(source_executor_id, messages) for source_executor_id, messages in messages.items()] await asyncio.gather(*tasks) - def _parse_edges(self, edges: list[Edge]) -> dict[str, list[Edge]]: - """Parse the edges of the workflow into a more convenient format. + def _parse_edge_groups(self, edge_groups: Sequence[EdgeGroup]) -> dict[str, list[EdgeGroup]]: + """Parse the edge groups of the workflow into a mapping where each source executor ID maps to its edge groups. Args: - edges: A list of edges in the workflow. + edge_groups: A list of edge groups in the workflow. Returns: - A dictionary mapping each source executor ID to a list of target executor IDs. + A dictionary mapping each source executor ID to a list of edge groups. """ - parsed: defaultdict[str, list[Edge]] = defaultdict(list) - for edge in edges: - parsed[edge.source_id].append(edge) + parsed: defaultdict[str, list[EdgeGroup]] = defaultdict(list) + for group in edge_groups: + for source_executor in group.source_executors(): + parsed[source_executor.id].append(group) + return parsed diff --git a/python/packages/workflow/agent_framework_workflow/_validation.py b/python/packages/workflow/agent_framework_workflow/_validation.py index 0ce9cd2e76..ad7b731d0d 100644 --- a/python/packages/workflow/agent_framework_workflow/_validation.py +++ b/python/packages/workflow/agent_framework_workflow/_validation.py @@ -3,10 +3,11 @@ import inspect import logging from collections import defaultdict +from collections.abc import Sequence from enum import Enum from typing import Any, Union, get_args, get_origin -from ._edge import Edge +from ._edge import Edge, EdgeGroup, TargetEdgeGroup from ._executor import Executor logger = logging.getLogger(__name__) @@ -92,18 +93,19 @@ def __init__(self): self._executors: dict[str, Executor] = {} # region Core Validation Methods - def validate_workflow(self, edges: list[Edge], start_executor: Executor | str) -> None: + def validate_workflow(self, edge_groups: Sequence[EdgeGroup], start_executor: Executor | str) -> None: """Validate the entire workflow graph. Args: - edges: list of edges in the workflow + edge_groups: list of edge groups in the workflow start_executor: The starting executor (can be instance or ID) Raises: WorkflowValidationError: If any validation fails """ - self._edges = edges - self._executors = self._build_executor_map(edges) + self._executors = self._build_executor_map(edge_groups) + self._edges = [edge for group in edge_groups for edge in group.edges()] + self._edge_groups = edge_groups # Validate that start_executor exists in the graph # It should because we check for it in the WorkflowBuilder @@ -121,12 +123,13 @@ def validate_workflow(self, edges: list[Edge], start_executor: Executor | str) - self._validate_dead_ends() self._validate_cycles() - def _build_executor_map(self, edges: list[Edge]) -> dict[str, Executor]: + def _build_executor_map(self, edge_groups: Sequence[EdgeGroup]) -> dict[str, Executor]: """Build a map of executor IDs to executor instances.""" executors: dict[str, Executor] = {} - for edge in edges: - executors[edge.source_id] = edge.source - executors[edge.target_id] = edge.target + for group in edge_groups: + for executor in group.source_executors() + group.target_executors(): + executors[executor.id] = executor + return executors # endregion @@ -155,64 +158,80 @@ def _validate_type_compatibility(self) -> None: Raises: TypeCompatibilityError: If type incompatibility is detected """ - for edge in self._edges: - source_executor = edge.source - target_executor = edge.target - - # Get output types from source executor - source_output_types = self._get_executor_output_types(source_executor) - - # Get input types from target executor - target_input_types = self._get_executor_input_types(target_executor) - - # If either executor has no type information, log warning and skip validation - # This allows for dynamic typing scenarios but warns about reduced validation coverage - if not source_output_types or not target_input_types: - if not source_output_types: - logger.warning( - f"Executor '{source_executor.id}' has no output type annotations. " - f"Type compatibility validation will be skipped for edges from this executor. " - f"Consider adding output_types to @handler decorators for better validation." - ) - if not target_input_types: - logger.warning( - f"Executor '{target_executor.id}' has no input type annotations. " - f"Type compatibility validation will be skipped for edges to this executor. " - f"Consider adding type annotations to message handler parameters for better validation." - ) - continue - - # Check if any source output type is compatible with any target input type - compatible = False - compatible_pairs: list[tuple[type[Any], type[Any]]] = [] - - for source_type in source_output_types: - for target_type in target_input_types: - if edge.has_edge_group(): - # If the edge is part of an edge group, the target expects a list of data types - if self._is_type_compatible(list[source_type], target_type): - compatible = True - compatible_pairs.append((list[source_type], target_type)) - else: - if self._is_type_compatible(source_type, target_type): - compatible = True - compatible_pairs.append((source_type, target_type)) - - # Log successful type compatibility for debugging - if compatible: - logger.debug( - f"Type compatibility validated for edge '{source_executor.id}' -> '{target_executor.id}'. " - f"Compatible type pairs: {[(str(s), str(t)) for s, t in compatible_pairs]}" - ) + for edge_group in self._edge_groups: + for edge in edge_group.edges(): + self._validate_edge_type_compatibility(edge, edge_group) + + def _validate_edge_type_compatibility(self, edge: Edge, edge_group: EdgeGroup) -> None: + """Validate type compatibility for a specific edge. + + This checks that the output types of the source executor are compatible + with the input types expected by the target executor. - if not compatible: - # Enhanced error with more detailed information - raise TypeCompatibilityError( - source_executor.id, - target_executor.id, - source_output_types, - target_input_types, + Args: + edge: The edge to validate + edge_group: The edge group containing this edge + + Raises: + TypeCompatibilityError: If type incompatibility is detected + """ + source_executor = edge.source + target_executor = edge.target + + # Get output types from source executor + source_output_types = self._get_executor_output_types(source_executor) + + # Get input types from target executor + target_input_types = self._get_executor_input_types(target_executor) + + # If either executor has no type information, log warning and skip validation + # This allows for dynamic typing scenarios but warns about reduced validation coverage + if not source_output_types or not target_input_types: + if not source_output_types: + logger.warning( + f"Executor '{source_executor.id}' has no output type annotations. " + f"Type compatibility validation will be skipped for edges from this executor. " + f"Consider adding output_types to @handler decorators for better validation." ) + if not target_input_types: + logger.warning( + f"Executor '{target_executor.id}' has no input type annotations. " + f"Type compatibility validation will be skipped for edges to this executor. " + f"Consider adding type annotations to message handler parameters for better validation." + ) + return + + # Check if any source output type is compatible with any target input type + compatible = False + compatible_pairs: list[tuple[type[Any], type[Any]]] = [] + + for source_type in source_output_types: + for target_type in target_input_types: + if isinstance(edge_group, TargetEdgeGroup): + # If the edge is part of an edge group, the target expects a list of data types + if self._is_type_compatible(list[source_type], target_type): + compatible = True + compatible_pairs.append((list[source_type], target_type)) + else: + if self._is_type_compatible(source_type, target_type): + compatible = True + compatible_pairs.append((source_type, target_type)) + + # Log successful type compatibility for debugging + if compatible: + logger.debug( + f"Type compatibility validated for edge '{source_executor.id}' -> '{target_executor.id}'. " + f"Compatible type pairs: {[(str(s), str(t)) for s, t in compatible_pairs]}" + ) + + if not compatible: + # Enhanced error with more detailed information + raise TypeCompatibilityError( + source_executor.id, + target_executor.id, + source_output_types, + target_input_types, + ) def _get_executor_output_types(self, executor: Executor) -> list[type[Any]]: """Extract output types from an executor's message handlers. @@ -479,15 +498,15 @@ def _is_type_compatible(source_type: type[Any], target_type: type[Any]) -> bool: # endregion -def validate_workflow_graph(edges: list[Edge], start_executor: Executor | str) -> None: +def validate_workflow_graph(edge_groups: Sequence[EdgeGroup], start_executor: Executor | str) -> None: """Convenience function to validate a workflow graph. Args: - edges: list of edges in the workflow + edge_groups: list of edge groups in the workflow start_executor: The starting executor (can be instance or ID) Raises: WorkflowValidationError: If any validation fails """ validator = WorkflowGraphValidator() - validator.validate_workflow(edges, start_executor) + validator.validate_workflow(edge_groups, start_executor) diff --git a/python/packages/workflow/agent_framework_workflow/_workflow.py b/python/packages/workflow/agent_framework_workflow/_workflow.py index 84cc8178b5..63fc2b0c5e 100644 --- a/python/packages/workflow/agent_framework_workflow/_workflow.py +++ b/python/packages/workflow/agent_framework_workflow/_workflow.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, Callable, Sequence from typing import Any -from ._edge import Edge +from ._edge import EdgeGroup, SingleEdgeGroup, SourceEdgeGroup, TargetEdgeGroup from ._events import RequestInfoEvent, WorkflowCompletedEvent, WorkflowEvent from ._executor import Executor, RequestInfoExecutor from ._runner import DEFAULT_MAX_ITERATIONS, Runner @@ -57,7 +57,7 @@ class Workflow: def __init__( self, - edges: list[Edge], + edge_groups: list[EdgeGroup], start_executor: Executor | str, runner_context: RunnerContext, max_iterations: int, @@ -65,24 +65,22 @@ def __init__( """Initialize the workflow with a list of edges. Args: - edges: A list of directed edges representing the connections between nodes in the workflow. + edge_groups: A list of EdgeGroup instances that define the workflow edges. start_executor: The starting executor for the workflow, which can be an Executor instance or its ID. runner_context: The RunnerContext instance to be used during workflow execution. max_iterations: The maximum number of iterations the workflow will run for convergence. """ - self._edges = edges + self._edge_groups = edge_groups + self._executors = self._build_executor_map(edge_groups) self._start_executor = start_executor - self._executors = {edge.source_id: edge.source for edge in edges} | { - edge.target_id: edge.target for edge in edges - } self._shared_state = SharedState() - self._runner = Runner(self._edges, self._shared_state, runner_context, max_iterations=max_iterations) + self._runner = Runner(self._edge_groups, self._shared_state, runner_context, max_iterations=max_iterations) @property - def edges(self) -> list[Edge]: - """Get the list of edges in the workflow.""" - return self._edges + def edge_groups(self) -> list[EdgeGroup]: + """Get the list of edge groups in the workflow.""" + return self._edge_groups @property def start_executor(self) -> Executor: @@ -202,6 +200,22 @@ def _get_executor_by_id(self, executor_id: str) -> Executor: raise ValueError(f"Executor with ID {executor_id} not found.") return self._executors[executor_id] + def _build_executor_map(self, edge_groups: list[EdgeGroup]) -> dict[str, Executor]: + """Build the executor map from edge groups. + + Args: + edge_groups: A list of EdgeGroup instances. + + Returns: + A dictionary mapping executor IDs to Executor instances. + """ + executors: dict[str, Executor] = {} + for group in edge_groups: + for executor in group.source_executors() + group.target_executors(): + executors[executor.id] = executor + + return executors + class WorkflowBuilder: """A builder class for constructing workflows. @@ -211,7 +225,7 @@ class WorkflowBuilder: def __init__(self): """Initialize the WorkflowBuilder with an empty list of edges and no starting executor.""" - self._edges: list[Edge] = [] + self._edge_groups: list[EdgeGroup] = [] self._start_executor: Executor | str | None = None self._max_iterations: int = DEFAULT_MAX_ITERATIONS @@ -232,7 +246,7 @@ def add_edge( should be traversed based on the message type. """ # TODO(@taochen): Support executor factories for lazy initialization - self._edges.append(Edge(source, target, condition)) + self._edge_groups.append(SingleEdgeGroup(source, target, condition)) return self def add_fan_out_edges(self, source: Executor, targets: Sequence[Executor]) -> "Self": @@ -245,8 +259,8 @@ def add_fan_out_edges(self, source: Executor, targets: Sequence[Executor]) -> "S source: The source executor of the edges. targets: A list of target executors for the edges. """ - for target in targets: - self._edges.append(Edge(source, target)) + self._edge_groups.append(SourceEdgeGroup(source, targets)) + return self def add_fan_in_edges(self, sources: Sequence[Executor], target: Executor) -> "Self": @@ -283,16 +297,7 @@ def handle_message(self, message: Message) -> None: sources: A list of source executors for the edges. target: The target executor for the edges. """ - edges = [Edge(source, target) for source in sources] - - # Set the edge groups for the edges to ensure they are processed together. - for i, edge in enumerate(edges): - group_ids: list[str] = [] - group_ids.extend([e.id for e in edges[0:i]]) - group_ids.extend([e.id for e in edges[i + 1 :]]) - edge.set_edge_group(group_ids) - - self._edges.extend(edges) + self._edge_groups.append(TargetEdgeGroup(sources, target)) return self @@ -345,6 +350,6 @@ def build(self) -> Workflow: if not self._start_executor: raise ValueError("Starting executor must be set before building the workflow.") - validate_workflow_graph(self._edges, self._start_executor) + validate_workflow_graph(self._edge_groups, self._start_executor) - return Workflow(self._edges, self._start_executor, InProcRunnerContext(), self._max_iterations) + return Workflow(self._edge_groups, self._start_executor, InProcRunnerContext(), self._max_iterations) diff --git a/python/packages/workflow/tests/test_edge.py b/python/packages/workflow/tests/test_edge.py index b1c41c4470..0a9f974602 100644 --- a/python/packages/workflow/tests/test_edge.py +++ b/python/packages/workflow/tests/test_edge.py @@ -34,7 +34,6 @@ def test_create_edge(): assert edge.source_id == "source_executor" assert edge.target_id == "target_executor" assert edge.id == f"{edge.source_id}{Edge.ID_SEPARATOR}{edge.target_id}" - assert (edge.source_id, edge.target_id) == Edge.source_and_target_from_id(edge.id) def test_edge_can_handle(): diff --git a/python/packages/workflow/tests/test_runner.py b/python/packages/workflow/tests/test_runner.py index a4a1abb43c..d4e4ef79ad 100644 --- a/python/packages/workflow/tests/test_runner.py +++ b/python/packages/workflow/tests/test_runner.py @@ -6,7 +6,7 @@ import pytest from agent_framework.workflow import Executor, WorkflowCompletedEvent, WorkflowContext, WorkflowEvent, handler -from agent_framework_workflow._edge import Edge +from agent_framework_workflow._edge import SingleEdgeGroup from agent_framework_workflow._runner import Runner from agent_framework_workflow._runner_context import InProcRunnerContext, RunnerContext from agent_framework_workflow._shared_state import SharedState @@ -36,12 +36,12 @@ def test_create_runner(): executor_b = MockExecutor(id="executor_b") # Create a loop - edges = [ - Edge(source=executor_a, target=executor_b), - Edge(source=executor_b, target=executor_a), + edge_groups = [ + SingleEdgeGroup(executor_a, executor_b), + SingleEdgeGroup(executor_b, executor_a), ] - runner = Runner(edges, shared_state=SharedState(), ctx=InProcRunnerContext()) + runner = Runner(edge_groups, shared_state=SharedState(), ctx=InProcRunnerContext()) assert runner.context is not None and isinstance(runner.context, RunnerContext) @@ -53,8 +53,8 @@ async def test_runner_run_until_convergence(): # Create a loop edges = [ - Edge(source=executor_a, target=executor_b), - Edge(source=executor_b, target=executor_a), + SingleEdgeGroup(executor_a, executor_b), + SingleEdgeGroup(executor_b, executor_a), ] shared_state = SharedState() @@ -87,8 +87,8 @@ async def test_runner_run_until_convergence_not_completed(): # Create a loop edges = [ - Edge(source=executor_a, target=executor_b), - Edge(source=executor_b, target=executor_a), + SingleEdgeGroup(executor_a, executor_b), + SingleEdgeGroup(executor_b, executor_a), ] shared_state = SharedState() @@ -117,8 +117,8 @@ async def test_runner_already_running(): # Create a loop edges = [ - Edge(source=executor_a, target=executor_b), - Edge(source=executor_b, target=executor_a), + SingleEdgeGroup(executor_a, executor_b), + SingleEdgeGroup(executor_b, executor_a), ] shared_state = SharedState() diff --git a/python/packages/workflow/tests/test_validation.py b/python/packages/workflow/tests/test_validation.py index 23a14a4490..8cd6e78c2f 100644 --- a/python/packages/workflow/tests/test_validation.py +++ b/python/packages/workflow/tests/test_validation.py @@ -17,7 +17,7 @@ handler, validate_workflow_graph, ) -from agent_framework_workflow._edge import Edge +from agent_framework_workflow._edge import SingleEdgeGroup class StringExecutor(Executor): @@ -159,10 +159,13 @@ def test_graph_connectivity_isolated_executors(): executor3 = StringExecutor(id="executor3") # This will be isolated # Create edges that include an isolated executor (self-loop that's not connected to main graph) - edges = [Edge(executor1, executor2), Edge(executor3, executor3)] # Self-loop to include in graph + edge_groups = [ + SingleEdgeGroup(executor1, executor2), + SingleEdgeGroup(executor3, executor3), + ] # Self-loop to include in graph with pytest.raises(GraphConnectivityError) as exc_info: - validate_workflow_graph(edges, executor1) + validate_workflow_graph(edge_groups, executor1) assert "unreachable" in str(exc_info.value).lower() assert "executor3" in str(exc_info.value) @@ -239,15 +242,15 @@ async def handle_derived(self, message: str, ctx: WorkflowContext) -> None: def test_direct_validation_function(): executor1 = StringExecutor(id="executor1") executor2 = StringExecutor(id="executor2") - edges = [Edge(executor1, executor2)] + edge_groups = [SingleEdgeGroup(executor1, executor2)] # This should not raise any exceptions - validate_workflow_graph(edges, executor1) + validate_workflow_graph(edge_groups, executor1) # Test with invalid start executor executor3 = StringExecutor(id="executor3") with pytest.raises(GraphConnectivityError): - validate_workflow_graph(edges, executor3) + validate_workflow_graph(edge_groups, executor3) def test_fan_out_validation(): diff --git a/python/packages/workflow/tests/test_workflow_builder.py b/python/packages/workflow/tests/test_workflow_builder.py index 5135314485..55b3d32ec1 100644 --- a/python/packages/workflow/tests/test_workflow_builder.py +++ b/python/packages/workflow/tests/test_workflow_builder.py @@ -60,6 +60,6 @@ def test_workflow_builder_fluent_api(): .build() ) - assert len(workflow.edges) == 6 + assert len(workflow.edge_groups) == 4 assert workflow.start_executor.id == executor_a.id assert len(workflow.executors) == 6 diff --git a/python/samples/getting_started/workflow/step_06_map_reduce.py b/python/samples/getting_started/workflow/step_06_map_reduce.py index 8879bce982..929cb5f4c7 100644 --- a/python/samples/getting_started/workflow/step_06_map_reduce.py +++ b/python/samples/getting_started/workflow/step_06_map_reduce.py @@ -3,7 +3,6 @@ import ast import asyncio import os -import sys from collections import defaultdict from dataclasses import dataclass @@ -16,12 +15,6 @@ handler, ) -if sys.version_info >= (3, 12): - pass # pragma: no cover -else: - pass # pragma: no cover - - """ The following sample demonstrates a basic map reduce workflow that processes a large text file by splitting it into smaller chunks, @@ -119,7 +112,7 @@ async def map(self, _: SplitCompleted, ctx: WorkflowContext) -> None: data: An instance of SplitCompleted signaling the map step can be started. ctx: The execution context containing the shared state and other information. """ - # Retrieve the data to be processed from the shared state.# Define a key for the shared state to store the data to be processed + # Retrieve the data to be processed from the shared state. data_to_be_processed: list[str] = await ctx.get_shared_state(SHARED_STATE_DATA_KEY) chunk_start, chunk_end = await ctx.get_shared_state(self.id) From f10e6d8aed31bc5778cbb712d6f3e0dac39f98ce Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 8 Aug 2025 17:23:06 -0700 Subject: [PATCH 02/11] Add conditional and partitioning edge groups; next add samples and tests --- .../agent_framework_workflow/_edge.py | 166 ++++++++++++++++-- .../agent_framework_workflow/_workflow.py | 58 +++++- 2 files changed, 212 insertions(+), 12 deletions(-) diff --git a/python/packages/workflow/agent_framework_workflow/_edge.py b/python/packages/workflow/agent_framework_workflow/_edge.py index 1893e91978..3fdb2c3a74 100644 --- a/python/packages/workflow/agent_framework_workflow/_edge.py +++ b/python/packages/workflow/agent_framework_workflow/_edge.py @@ -72,6 +72,13 @@ def can_handle(self, message_data: Any) -> bool: """ return self.target.can_handle(message_data) + def should_route(self, data: Any) -> bool: + """Determine if message should be routed through this edge based on the condition.""" + if self._condition is None: + return True + + return self._condition(data) + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> None: """Send a message along this edge. @@ -84,18 +91,11 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R # Caller of this method should ensure that the edge can handle the data. raise RuntimeError(f"Edge {self.id} cannot handle data of type {type(message.data)}.") - if self._should_route(message.data): + if self.should_route(message.data): await self.target.execute( message.data, WorkflowContext(self.target.id, [self.source.id], shared_state, ctx) ) - def _should_route(self, data: Any) -> bool: - """Determine if message should be routed through this edge based on the condition.""" - if self._condition is None: - return True - - return self._condition(data) - class EdgeGroup(ABC): """Represents a group of edges that share some common properties and can be triggered together.""" @@ -144,7 +144,15 @@ class SingleEdgeGroup(EdgeGroup): """Represents a single edge group that contains only one edge.""" def __init__(self, source: Executor, target: Executor, condition: Callable[[Any], bool] | None = None) -> None: - """Initialize the single edge group with an edge.""" + """Initialize the single edge group with an edge. + + Args: + source (Executor): The source executor. + target (Executor): The target executor that the source executor can send messages to. + condition (Callable[[Any], bool], optional): A condition function that determines + if the edge will pass the data to the target executor. If None, the edge can + will always pass the data to the target executor. + """ self._edge = Edge(source=source, target=target, condition=condition) @override @@ -183,7 +191,12 @@ class SourceEdgeGroup(EdgeGroup): """ def __init__(self, source: Executor, targets: Sequence[Executor]) -> None: - """Initialize the source edge group with a list of edges.""" + """Initialize the source edge group with a list of edges. + + Args: + source (Executor): The source executor. + targets (Sequence[Executor]): A list of target executors that the source executor can send messages to. + """ if len(targets) <= 1: raise ValueError("SourceEdgeGroup must contain at least two targets.") self._edges = [Edge(source=source, target=target) for target in targets] @@ -227,7 +240,12 @@ class TargetEdgeGroup(EdgeGroup): """ def __init__(self, sources: Sequence[Executor], target: Executor) -> None: - """Initialize the target edge group with a list of edges.""" + """Initialize the target edge group with a list of edges. + + Args: + sources (Sequence[Executor]): A list of source executors that can send messages to the target executor. + target (Executor): The target executor that receives a list of messages aggregated from all sources. + """ if len(sources) <= 1: raise ValueError("TargetEdgeGroup must contain at least two sources.") self._edges = [Edge(source=source, target=target) for source in sources] @@ -279,3 +297,129 @@ def target_executors(self) -> list[Executor]: def edges(self) -> list[Edge]: """Get the edges in the group.""" return self._edges + + +class ConditionalEdgeGroup(SourceEdgeGroup): + """Represents a group of edges that assemble a conditional routing pattern. + + This is similar to a switch-case construct: + switch(data): + case condition_1: + edge_1 + break + case condition_2: + edge_2 + break + default: + edge_3 + break + Or equivalently an if-elif-else construct: + if condition_1: + edge_1 + elif condition_2: + edge_2 + else: + edge_4 + """ + + def __init__( + self, + source: Executor, + targets: Sequence[Executor], + conditions: Sequence[Callable[[Any], bool]], + ) -> None: + """Initialize the conditional edge group with a list of edges. + + Args: + source (Executor): The source executor. + targets (Sequence[Executor]): A list of target executors that the source executor can send messages to. + conditions (Sequence[Callable[[Any], bool]]): A list of condition functions that determine + which target executor to route the message to based on the data. The number of conditions + must be one less than the number of targets, as the last target is the default case. The + index of the condition corresponds to the index of the target executor. + """ + if len(targets) <= 1: + raise ValueError("ConditionalEdgeGroup must contain at least two targets.") + + if len(targets) != len(conditions) + 1: + raise ValueError("Number of targets must be one more than the number of conditions.") + + self._edges = [ + Edge(source, target, condition) for target, condition in zip(targets, [*conditions, None], strict=False) + ] + + @override + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through the conditional edge group.""" + if message.target_id: + # Find the index of the target edge in the edges list if target_id is specified + index = next((i for i, edge in enumerate(self._edges) if edge.target_id == message.target_id), None) + if index is None: + return False + if self._edges[index].can_handle(message.data) and self._edges[index].should_route(message.data): + await self._edges[index].send_message(message, shared_state, ctx) + return True + return False + + for edge in self._edges: + if edge.can_handle(message.data) and edge.should_route(message.data): + await edge.send_message(message, shared_state, ctx) + return True + + return False + + +class PartitioningEdgeGroup(SourceEdgeGroup): + """Represents a group of edges that can route messages based on a partitioning strategy. + + Messages from the source executor are routed to multiple target executors based on a partitioning function. + """ + + def __init__( + self, source: Executor, targets: Sequence[Executor], partition_func: Callable[[Any, int], list[int]] + ) -> None: + """Initialize the partitioning edge group with a list of edges. + + Args: + source (Executor): The source executor. + targets (Sequence[Executor]): A list of target executors that the source executor can send messages to. + partition_func (Callable[[Any, int], list[int]]): A partitioning function that determines which target + executors to route the message to based on the data. The function should take the message data and + the number of targets, and return a list of indices of the target executors to route the message to. + """ + self._edges = [Edge(source=source, target=target) for target in targets] + self._partition_func = partition_func + + @override + async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: + """Send a message through the partitioning edge group.""" + partition_result = self._partition_func(message.data, len(self._edges)) + if not self._validate_partition_result(partition_result): + raise RuntimeError( + f"Invalid partition result: {partition_result}. Expected indices in range [0, {len(self._edges) - 1}]." + ) + + if message.target_id: + # If the target ID is specified and the partition result contains it, send the message to that edge + has_target = message.target_id in [self._edges[index].target_id for index in partition_result] + if has_target: + edge = next((edge for edge in self._edges if edge.target_id == message.target_id), None) + if edge and edge.can_handle(message.data): + await edge.send_message(message, shared_state, ctx) + return True + return False + + async def send_to_edge(edge: Edge) -> bool: + """Send the message to the edge at the specified index.""" + if edge.can_handle(message.data): + await edge.send_message(message, shared_state, ctx) + return True + return False + + tasks = [send_to_edge(self._edges[index]) for index in partition_result] + results = await asyncio.gather(*tasks) + return any(results) + + def _validate_partition_result(self, partition_result: list[int]) -> bool: + """Validate the partition result to ensure all indices are within bounds.""" + return all(0 <= index < len(self._edges) for index in partition_result) diff --git a/python/packages/workflow/agent_framework_workflow/_workflow.py b/python/packages/workflow/agent_framework_workflow/_workflow.py index 63fc2b0c5e..cc85607cd6 100644 --- a/python/packages/workflow/agent_framework_workflow/_workflow.py +++ b/python/packages/workflow/agent_framework_workflow/_workflow.py @@ -5,7 +5,14 @@ from collections.abc import AsyncIterable, Callable, Sequence from typing import Any -from ._edge import EdgeGroup, SingleEdgeGroup, SourceEdgeGroup, TargetEdgeGroup +from ._edge import ( + ConditionalEdgeGroup, + EdgeGroup, + PartitioningEdgeGroup, + SingleEdgeGroup, + SourceEdgeGroup, + TargetEdgeGroup, +) from ._events import RequestInfoEvent, WorkflowCompletedEvent, WorkflowEvent from ._executor import Executor, RequestInfoExecutor from ._runner import DEFAULT_MAX_ITERATIONS, Runner @@ -263,6 +270,55 @@ def add_fan_out_edges(self, source: Executor, targets: Sequence[Executor]) -> "S return self + def add_conditional_fan_out_edges( + self, source: Executor, targets: Sequence[Executor], conditions: Sequence[Callable[[Any], bool]] + ) -> "Self": + """Add a conditional fan out group of edges to the workflow. + + The output types of the source and the input types of the targets must be compatible. + Messages from the source executor will be sent to one of the target executors based on + the provided conditions. + + Think of this as a switch statement where each target executor corresponds to a case. + Each condition function will be evaluated in order, and the first one that returns True + will determine which target executor receives the message. + + The number of targets must be one greater than the number of conditions. The last target + executor will receive messages that fall through all conditions (i.e., no condition matched). + + Args: + source: The source executor of the edges. + targets: A list of target executors for the edges. + conditions: A list of condition functions that determine whether each edge should be traversed. + """ + self._edge_groups.append(ConditionalEdgeGroup(source, targets, conditions)) + + return self + + def add_partitioning_fan_out_edges( + self, + source: Executor, + targets: Sequence[Executor], + partition_func: Callable[[Any, int], list[int]], + ) -> "Self": + """Add a partitioning fan out group of edges to the workflow. + + The output types of the source and the input types of the targets must be compatible. + Messages from the source executor will be sent to multiple target executors based on + the provided partition function. + + The partition function should take a message and the number of target executors, + and return a list of indices indicating which target executors should receive the message. + + Args: + source: The source executor of the edges. + targets: A list of target executors for the edges. + partition_func: A function that partitions messages to target executors. + """ + self._edge_groups.append(PartitioningEdgeGroup(source, targets, partition_func)) + + return self + def add_fan_in_edges(self, sources: Sequence[Executor], target: Executor) -> "Self": """Add multiple edges from sources to a single target executor. From edc3e96e695ff3e56acbe8040c7b4a71dd5a47bd Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 11 Aug 2025 10:22:41 -0700 Subject: [PATCH 03/11] Add unit tests --- .../agent_framework_workflow/_edge.py | 19 +- python/packages/workflow/tests/test_edge.py | 782 +++++++++++++++++- 2 files changed, 798 insertions(+), 3 deletions(-) diff --git a/python/packages/workflow/agent_framework_workflow/_edge.py b/python/packages/workflow/agent_framework_workflow/_edge.py index 3fdb2c3a74..519b884e9d 100644 --- a/python/packages/workflow/agent_framework_workflow/_edge.py +++ b/python/packages/workflow/agent_framework_workflow/_edge.py @@ -213,7 +213,20 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R return False # If no target ID, send the message to all edges in the group - await asyncio.gather(*(edge.send_message(message, shared_state, ctx) for edge in self._edges)) + if all(not edge.can_handle(message.data) for edge in self._edges): + return False + + await asyncio.gather( + *( + edge.send_message( + message, + shared_state, + ctx, + ) + for edge in self._edges + if edge.can_handle(message.data) + ), + ) return True @override @@ -387,6 +400,8 @@ def __init__( executors to route the message to based on the data. The function should take the message data and the number of targets, and return a list of indices of the target executors to route the message to. """ + if len(targets) <= 1: + raise ValueError("PartitioningEdgeGroup must contain at least two targets.") self._edges = [Edge(source=source, target=target) for target in targets] self._partition_func = partition_func @@ -396,7 +411,7 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R partition_result = self._partition_func(message.data, len(self._edges)) if not self._validate_partition_result(partition_result): raise RuntimeError( - f"Invalid partition result: {partition_result}. Expected indices in range [0, {len(self._edges) - 1}]." + f"Invalid partition result: {partition_result}. Expected indices in range: [0, {len(self._edges) - 1}]." ) if message.target_id: diff --git a/python/packages/workflow/tests/test_edge.py b/python/packages/workflow/tests/test_edge.py index 0a9f974602..43db3e7265 100644 --- a/python/packages/workflow/tests/test_edge.py +++ b/python/packages/workflow/tests/test_edge.py @@ -2,10 +2,19 @@ from dataclasses import dataclass from typing import Any +from unittest.mock import patch +import pytest from agent_framework.workflow import Executor, WorkflowContext, handler -from agent_framework_workflow._edge import Edge +from agent_framework_workflow._edge import ( + ConditionalEdgeGroup, + Edge, + PartitioningEdgeGroup, + SingleEdgeGroup, + SourceEdgeGroup, + TargetEdgeGroup, +) @dataclass @@ -15,6 +24,13 @@ class MockMessage: data: Any +@dataclass +class MockMessageSecondary: + """A secondary mock message for testing purposes.""" + + data: Any + + class MockExecutor(Executor): """A mock executor for testing purposes.""" @@ -24,6 +40,27 @@ async def mock_handler(self, message: MockMessage, ctx: WorkflowContext) -> None pass +class MockExecutorSecondary(Executor): + """A secondary mock executor for testing purposes.""" + + @handler + async def mock_handler_secondary(self, message: MockMessageSecondary, ctx: WorkflowContext) -> None: + """A secondary mock handler that does nothing.""" + pass + + +class MockAggregator(Executor): + """A mock aggregator for testing purposes.""" + + @handler + async def mock_aggregator_handler(self, message: list[MockMessage], ctx: WorkflowContext) -> None: + """A mock aggregator handler that does nothing.""" + pass + + +# region Edge + + def test_create_edge(): """Test creating an edge with a source and target executor.""" source = MockExecutor(id="source_executor") @@ -44,3 +81,746 @@ def test_edge_can_handle(): edge = Edge(source=source, target=target) assert edge.can_handle(MockMessage(data="test")) + + +# endregion Edge + +# region SingleEdgeGroup + + +def test_single_edge_group(): + """Test creating a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + assert edge_group.source_executors() == [source] + assert edge_group.target_executors() == [target] + assert edge_group.edges()[0].source_id == "source_executor" + assert edge_group.edges()[0].target_id == "target_executor" + + +def test_single_edge_group_with_condition(): + """Test creating a single edge group with a condition.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target, condition=lambda x: x.data == "test") + + assert edge_group.source_executors() == [source] + assert edge_group.target_executors() == [target] + assert edge_group.edges()[0].source_id == "source_executor" + assert edge_group.edges()[0].target_id == "target_executor" + assert edge_group.edges()[0]._condition is not None # type: ignore + + +async def test_single_edge_group_send_message(): + """Test sending a message through a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is True + + +async def test_single_edge_group_send_message_with_target(): + """Test sending a message through a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is True + + +async def test_single_edge_group_send_message_with_invalid_target(): + """Test sending a message through a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id="invalid_target") + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_single_edge_group_send_message_with_invalid_data(): + """Test sending a message through a single edge group.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + edge_group = SingleEdgeGroup(source=source, target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +# endregion SingleEdgeGroup + + +# region SourceEdgeGroup + + +def test_source_edge_group(): + """Test creating a source edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) + + assert edge_group.source_executors() == [source] + assert edge_group.target_executors() == [target1, target2] + assert len(edge_group.edges()) == 2 + assert edge_group.edges()[0].source_id == "source_executor" + assert edge_group.edges()[0].target_id == "target_executor_1" + assert edge_group.edges()[1].source_id == "source_executor" + assert edge_group.edges()[1].target_id == "target_executor_2" + + +def test_source_edge_group_invalid_number_of_targets(): + """Test creating a source edge group with an invalid number of targets.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + with pytest.raises(ValueError, match="SourceEdgeGroup must contain at least two targets"): + SourceEdgeGroup(source=source, targets=[target]) + + +async def test_source_edge_group_send_message(): + """Test sending a message through a source edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 2 + + +async def test_source_edge_group_send_message_with_target(): + """Test sending a message through a source edge group with a target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target1.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + assert mock_send.call_args[0][0].target_id == target1.id + + +async def test_source_edge_group_send_message_with_invalid_target(): + """Test sending a message through a source edge group with an invalid target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id="invalid_target") + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_source_edge_group_send_message_with_invalid_data(): + """Test sending a message through a source edge group with invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_source_edge_group_send_message_only_one_successful_send(): + """Test sending a message through a source edge group where only one edge can handle the message.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutorSecondary(id="target_executor_2") + + edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + + +# endregion SourceEdgeGroup + +# region TargetEdgeGroup + + +def test_target_edge_group(): + """Test creating a target edge group.""" + source1 = MockExecutor(id="source_executor_1") + source2 = MockExecutor(id="source_executor_2") + target = MockAggregator(id="target_executor") + + edge_group = TargetEdgeGroup(sources=[source1, source2], target=target) + + assert edge_group.source_executors() == [source1, source2] + assert edge_group.target_executors() == [target] + assert len(edge_group.edges()) == 2 + assert edge_group.edges()[0].source_id == "source_executor_1" + assert edge_group.edges()[0].target_id == "target_executor" + assert edge_group.edges()[1].source_id == "source_executor_2" + assert edge_group.edges()[1].target_id == "target_executor" + + +def test_target_edge_group_invalid_number_of_sources(): + """Test creating a target edge group with an invalid number of sources.""" + source = MockExecutor(id="source_executor") + target = MockAggregator(id="target_executor") + + with pytest.raises(ValueError, match="TargetEdgeGroup must contain at least two sources"): + TargetEdgeGroup(sources=[source], target=target) + + +async def test_target_edge_group_send_message_buffer(): + """Test sending a message through a target edge group with buffering.""" + source1 = MockExecutor(id="source_executor_1") + source2 = MockExecutor(id="source_executor_2") + target = MockAggregator(id="target_executor") + + edge_group = TargetEdgeGroup(sources=[source1, source2], target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message( + Message(data=data, source_id=source1.id), + shared_state, + ctx, + ) + + assert success is True + assert mock_send.call_count == 0 # The message should be buffered and wait for the second source + assert len(edge_group._buffer[source1.id]) == 1 # type: ignore + + success = await edge_group.send_message( + Message(data=data, source_id=source2.id), + shared_state, + ctx, + ) + assert success is True + assert mock_send.call_count == 1 # The message should be sent now that both sources have sent their messages + + # Buffer should be cleared after sending + assert not edge_group._buffer # type: ignore + + +async def test_target_edge_group_send_message_with_invalid_target(): + """Test sending a message through a target edge group with an invalid target.""" + source1 = MockExecutor(id="source_executor_1") + source2 = MockExecutor(id="source_executor_2") + target = MockAggregator(id="target_executor") + + edge_group = TargetEdgeGroup(sources=[source1, source2], target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source1.id, target_id="invalid_target") + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_target_edge_group_send_message_with_invalid_data(): + """Test sending a message through a target edge group with invalid data.""" + source1 = MockExecutor(id="source_executor_1") + source2 = MockExecutor(id="source_executor_2") + target = MockAggregator(id="target_executor") + + edge_group = TargetEdgeGroup(sources=[source1, source2], target=target) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source1.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +# endregion TargetEdgeGroup + +# region ConditionalEdgeGroup + + +def test_conditional_edge_group(): + """Test creating a conditional edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = ConditionalEdgeGroup( + source=source, + targets=[target1, target2], + conditions=[lambda x: x.data < 0], + ) + + assert edge_group.source_executors() == [source] + assert edge_group.target_executors() == [target1, target2] + assert len(edge_group.edges()) == 2 + assert edge_group.edges()[0].source_id == "source_executor" + assert edge_group.edges()[0].target_id == "target_executor_1" + assert edge_group.edges()[0]._condition is not None # type: ignore + assert edge_group.edges()[1].source_id == "source_executor" + assert edge_group.edges()[1].target_id == "target_executor_2" + assert edge_group.edges()[1]._condition is None # type: ignore + + +def test_conditional_edge_group_invalid_number_of_targets(): + """Test creating a conditional edge group with an invalid number of targets.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + with pytest.raises(ValueError, match="ConditionalEdgeGroup must contain at least two targets"): + ConditionalEdgeGroup( + source=source, + targets=[target], + conditions=[lambda x: x.data < 0], + ) + + +def test_conditional_edge_group_invalid_number_of_conditions(): + """Test creating a conditional edge group with an invalid number of conditions.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + with pytest.raises(ValueError, match="Number of targets must be one more than the number of conditions."): + ConditionalEdgeGroup( + source=source, + targets=[target1, target2], + conditions=[lambda x: x.data < 0, lambda x: x.data > 0], + ) + + +async def test_conditional_edge_group_send_message(): + """Test sending a message through a conditional edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = ConditionalEdgeGroup( + source=source, + targets=[target1, target2], + conditions=[lambda x: x.data < 0], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data=-1) + message = Message(data=data, source_id=source.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + + # Default condition should + data = MockMessage(data=1) + message = Message(data=data, source_id=source.id) + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + + +async def test_conditional_edge_group_send_message_with_invalid_target(): + """Test sending a message through a conditional edge group with an invalid target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = ConditionalEdgeGroup( + source=source, + targets=[target1, target2], + conditions=[lambda x: x.data < 0], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data=-1) + message = Message(data=data, source_id=source.id, target_id="invalid_target") + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_conditional_edge_group_send_message_with_valid_target(): + """Test sending a message through a conditional edge group with a target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = ConditionalEdgeGroup( + source=source, + targets=[target1, target2], + conditions=[lambda x: x.data < 0], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data=1) # Condition will fail + message = Message(data=data, source_id=source.id, target_id=target1.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + data = MockMessage(data=-1) # Condition will pass + message = Message(data=data, source_id=source.id, target_id=target1.id) + success = await edge_group.send_message(message, shared_state, ctx) + assert success is True + + +async def test_conditional_edge_group_send_message_with_invalid_data(): + """Test sending a message through a conditional edge group with invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = ConditionalEdgeGroup( + source=source, + targets=[target1, target2], + conditions=[lambda x: x.data < 0], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +# endregion ConditionalEdgeGroup + + +# region PartitioningEdgeGroup + + +def test_partitioning_edge_group(): + """Test creating a partitioning edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = PartitioningEdgeGroup( + source=source, + targets=[target1, target2], + partition_func=lambda data, num_edges: [0], + ) + + assert edge_group.source_executors() == [source] + assert edge_group.target_executors() == [target1, target2] + assert len(edge_group.edges()) == 2 + assert edge_group.edges()[0].source_id == "source_executor" + assert edge_group.edges()[0].target_id == "target_executor_1" + assert edge_group.edges()[1].source_id == "source_executor" + assert edge_group.edges()[1].target_id == "target_executor_2" + + +def test_partitioning_edge_group_invalid_number_of_targets(): + """Test creating a partitioning edge group with an invalid number of targets.""" + source = MockExecutor(id="source_executor") + target = MockExecutor(id="target_executor") + + with pytest.raises(ValueError, match="PartitioningEdgeGroup must contain at least two targets."): + PartitioningEdgeGroup( + source=source, + targets=[target], + partition_func=lambda data, num_edges: [0], + ) + + +async def test_partitioning_edge_group_send_message(): + """Test sending a message through a partitioning edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = PartitioningEdgeGroup( + source=source, + targets=[target1, target2], + partition_func=lambda data, num_edges: [0, 1], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 2 + + +async def test_partitioning_edge_group_send_message_with_invalid_partition_result(): + """Test sending a message through a partitioning edge group with an invalid partition result.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = PartitioningEdgeGroup( + source=source, + targets=[target1, target2], + partition_func=lambda data, num_edges: [0, 2], # Invalid index + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with pytest.raises(RuntimeError): + await edge_group.send_message(message, shared_state, ctx) + + +async def test_partitioning_edge_group_send_message_with_target(): + """Test sending a message through a partitioning edge group with a target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = PartitioningEdgeGroup( + source=source, + targets=[target1, target2], + partition_func=lambda data, num_edges: [0, 1], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target1.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + assert mock_send.call_args[0][0].target_id == target1.id + + +async def test_partitioning_edge_group_send_message_with_target_not_in_partition(): + """Test sending a message through a partitioning edge group with a target not in the partition.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = PartitioningEdgeGroup( + source=source, + targets=[target1, target2], + partition_func=lambda data, num_edges: [0], # Only target1 will receive the message + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target2.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_partitioning_edge_group_send_message_with_invalid_data(): + """Test sending a message through a partitioning edge group with invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = PartitioningEdgeGroup( + source=source, + targets=[target1, target2], + partition_func=lambda data, num_edges: [0, 1], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_partitioning_edge_group_send_message_with_target_invalid_data(): + """Test sending a message through a partitioning edge group with a target and invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = PartitioningEdgeGroup( + source=source, + targets=[target1, target2], + partition_func=lambda data, num_edges: [0, 1], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id, target_id=target1.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +# endregion PartitioningEdgeGroup From b07296e266916d189a1cf7851d95b452e04231f2 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 11 Aug 2025 11:55:38 -0700 Subject: [PATCH 04/11] Add samples --- .../agent_framework_workflow/_typing_utils.py | 4 + .../workflow/step_00_foundation_patterns.py | 287 ++++++++++++++++++ 2 files changed, 291 insertions(+) create mode 100644 python/samples/getting_started/workflow/step_00_foundation_patterns.py diff --git a/python/packages/workflow/agent_framework_workflow/_typing_utils.py b/python/packages/workflow/agent_framework_workflow/_typing_utils.py index f8547d886e..46500bdf52 100644 --- a/python/packages/workflow/agent_framework_workflow/_typing_utils.py +++ b/python/packages/workflow/agent_framework_workflow/_typing_utils.py @@ -13,6 +13,10 @@ def is_instance_of(data: Any, target_type: type) -> bool: Returns: bool: True if data is an instance of target_type, False otherwise. """ + # Case 0: target_type is Any - always return True + if target_type is Any: + return True + origin = get_origin(target_type) args = get_args(target_type) diff --git a/python/samples/getting_started/workflow/step_00_foundation_patterns.py b/python/samples/getting_started/workflow/step_00_foundation_patterns.py new file mode 100644 index 0000000000..30f2b38b04 --- /dev/null +++ b/python/samples/getting_started/workflow/step_00_foundation_patterns.py @@ -0,0 +1,287 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from typing import Any + +from agent_framework.workflow import Executor, WorkflowBuilder, WorkflowContext, handler + +""" +The following sample demonstrates the foundation patterns that the workflow framework supports. +These patterns include: +- Single connection +- Single connection with condition +- Fan-out and fan-in connections +- Conditional fan-out connections +- Partitioning fan-out connections + +The samples here use numbers and simple arithmetic operations to demonstrate the patterns. +""" + + +class AddOneExecutor(Executor): + """An executor that processes a number by adding one.""" + + @handler(output_types=[int]) + async def add_one(self, number: int, ctx: WorkflowContext) -> None: + """Execute the task by adding one to the input number.""" + result = number + 1 + + # Send the result to the next executor in the workflow. + await ctx.send_message(result) + + print("Adding one to the number:", number, "Result:", result) + + +class MultiplyByTwoExecutor(Executor): + """An executor that processes a number by multiplying it by two.""" + + @handler(output_types=[int]) + async def multiply_by_two(self, number: int, ctx: WorkflowContext) -> None: + """Execute the task by multiplying the input number by two.""" + result = number * 2 + + # Send the result to the next executor in the workflow. + await ctx.send_message(result) + + print("Multiplying the number by two:", number, "Result:", result) + + +class DivideByTwoExecutor(Executor): + """An executor that processes a number by dividing it by two.""" + + @handler(output_types=[float]) + async def divide_by_two(self, number: int, ctx: WorkflowContext) -> None: + """Execute the task by dividing the input number by two.""" + result = number / 2 + + # Send the result with a workflow completion event. + await ctx.send_message(result) + + print("Dividing the number by two:", number, "Result:", result) + + +class AggregateResultExecutor(Executor): + """An executor that receives results and prints them.""" + + @handler + async def aggregate_results(self, results: Any, ctx: WorkflowContext) -> None: + """Print whatever results are received.""" + print("Aggregating results:", results) + + +async def single_edge(): + """A sample to demonstrate a single directed connection between two executors. + + Three executors are connected in a sequence: AddOneExecutor -> AddOneExecutor -> AggregateResultExecutor. + + Expected output: + Adding one to the number: 1 Result: 2 + Adding one to the number: 2 Result: 3 + Aggregating results: 3 + """ + add_one_executor_a = AddOneExecutor() + add_one_executor_b = AddOneExecutor() + aggregate_result_executor = AggregateResultExecutor() + + workflow = ( + WorkflowBuilder() + .add_edge(add_one_executor_a, add_one_executor_b) + .add_edge(add_one_executor_b, aggregate_result_executor) + .set_start_executor(add_one_executor_a) + .build() + ) + + await workflow.run(1) + + +async def single_edge_with_condition(): + """A sample to demonstrate a single directed connection with a condition. + + Three executors are connected: AddOneExecutor -> AddOneExecutor, AggregateResultExecutor. + The AddOneExecutor will loop back to itself until the number reaches 10, then it will start + sending the result to AggregateResultExecutor when the number is greater than 8. + + Expected output: + Adding one to the number: 1 Result: 2 + Adding one to the number: 2 Result: 3 + Adding one to the number: 3 Result: 4 + Adding one to the number: 4 Result: 5 + Adding one to the number: 5 Result: 6 + Adding one to the number: 6 Result: 7 + Adding one to the number: 7 Result: 8 + Adding one to the number: 8 Result: 9 + Adding one to the number: 9 Result: 10 + Aggregating results: 9 + Adding one to the number: 10 Result: 11 + Aggregating results: 10 + Aggregating results: 11 + """ + add_one_executor_a = AddOneExecutor() + aggregate_result_executor = AggregateResultExecutor() + + workflow = ( + WorkflowBuilder() + .add_edge(add_one_executor_a, add_one_executor_a, condition=lambda x: x < 11) + .add_edge(add_one_executor_a, aggregate_result_executor, condition=lambda x: x > 8) + .set_start_executor(add_one_executor_a) + .build() + ) + + await workflow.run(1) + + +async def fan_out_fan_in_edges(): + """A sample to demonstrate a fan-out and fan-in connection between executors. + + Four executors are connected in a fan-out and fan-in pattern: + AddOneExecutor -> MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor. + The AddOneExecutor sends its output to both MultiplyByTwoExecutor and DivideByTwoExecutor, + and both of these executors send their results to AggregateResultExecutor. + + The target of the fan-in connection will wait for all the results from the sources before proceeding. + + Expected output: + Adding one to the number: 1 Result: 2 + Multiplying the number by two: 2 Result: 4 + Dividing the number by two: 2 Result: 1.0 + Aggregating results: [4, 1.0] + """ + add_one_executor = AddOneExecutor() + multiply_by_two_executor = MultiplyByTwoExecutor() + divide_by_two_executor = DivideByTwoExecutor() + aggregate_result_executor = AggregateResultExecutor() + + workflow = ( + WorkflowBuilder() + .add_fan_out_edges(add_one_executor, [multiply_by_two_executor, divide_by_two_executor]) + .add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor) + .set_start_executor(add_one_executor) + .build() + ) + + await workflow.run(1) + + +async def conditional_fan_out_edges(): + """A sample to demonstrate a conditional fan-out connection. + + Four executors are connected in a conditional fan-out pattern: + AddOneExecutor -> AddOneExecutor, MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor. + + The message from AddOneExecutor will be evaluated against the conditions one by one, and the first condition + that evaluates to True will determine the target executors. If no conditions match, the message will be sent + to the last targets. + + This pattern resembles a switch-case statement with a default case where the first matching case is executed. + + Expected output: + Adding one to the number: 1 Result: 2 + Adding one to the number: 2 Result: 3 + Adding one to the number: 3 Result: 4 + Adding one to the number: 4 Result: 5 + Adding one to the number: 5 Result: 6 + Adding one to the number: 6 Result: 7 + Adding one to the number: 7 Result: 8 + Adding one to the number: 8 Result: 9 + Adding one to the number: 9 Result: 10 + Adding one to the number: 10 Result: 11 + Multiplying the number by two: 11 Result: 22 + """ + add_one_executor = AddOneExecutor() + multiply_by_two_executor = MultiplyByTwoExecutor() + divide_by_two_executor = DivideByTwoExecutor() + aggregate_result_executor = AggregateResultExecutor() + + workflow = ( + WorkflowBuilder() + .set_start_executor(add_one_executor) + .add_conditional_fan_out_edges( + add_one_executor, + [add_one_executor, divide_by_two_executor, multiply_by_two_executor], + # Loop back to the add_one_executor if the number is less than 11 + # and to the multiply_by_two_executor when the number is larger than or equal to 11 and even. + # Otherwise, send to the divide_by_two_executor. + conditions=[lambda x: x < 11, lambda x: x % 2 == 0], + ) + .add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor) + .build() + ) + + await workflow.run(1) + + +async def partitioning_fan_out_edges(): + """A sample to demonstrate a partitioning fan-out connection. + + Four executors are connected in a partitioning fan-out pattern: + AddOneExecutor -> AddOneExecutor, MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor. + + The AddOneExecutor sends its output to one or more executors based on the partitioning function. + + Expected output: + Adding one to the number: 1 Result: 2 + Adding one to the number: 2 Result: 3 + Adding one to the number: 3 Result: 4 + Adding one to the number: 4 Result: 5 + Adding one to the number: 5 Result: 6 + Adding one to the number: 6 Result: 7 + Adding one to the number: 7 Result: 8 + Adding one to the number: 8 Result: 9 + Adding one to the number: 9 Result: 10 + Adding one to the number: 10 Result: 11 + Adding one to the number: 11 Result: 12 + Adding one to the number: 12 Result: 13 + Dividing the number by two: 12 Result: 6.0 + Multiplying the number by two: 13 Result: 26 + Aggregating results: [26, 6.0] + """ + add_one_executor = AddOneExecutor() + multiply_by_two_executor = MultiplyByTwoExecutor() + divide_by_two_executor = DivideByTwoExecutor() + aggregate_result_executor = AggregateResultExecutor() + + def partition_func(number: int, total_targets: int) -> list[int]: + """Partition function to determine which executor to send the number to.""" + if number < 12: + # Loop back to the add_one_executor if the number is less than 12 + return [0] + + if number % 2 == 0: + # Send it to the add_one_executor to add one more time and the + # divide_by_two_executor to divide the result by two. + return [0, 2] + + # Otherwise, send it to the multiply_by_two_executor to multiply the result by two. + return [1] + + workflow = ( + WorkflowBuilder() + .set_start_executor(add_one_executor) + .add_partitioning_fan_out_edges( + add_one_executor, + [add_one_executor, multiply_by_two_executor, divide_by_two_executor], + partition_func=partition_func, + ) + .add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor) + .build() + ) + + await workflow.run(1) + + +async def main(): + """Main function to run the workflows.""" + print("**Running single connection workflow**") + await single_edge() + print("**Running single connection with condition workflow**") + await single_edge_with_condition() + print("**Running fan-out and fan-in connection workflow**") + await fan_out_fan_in_edges() + print("**Running conditional fan-out connection workflow**") + await conditional_fan_out_edges() + print("**Running partitioning fan-out connection workflow**") + await partitioning_fan_out_edges() + + +if __name__ == "__main__": + asyncio.run(main()) From 720fd5bba731d0b53321f5cfb98abc5a376730ae Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 11 Aug 2025 13:50:17 -0700 Subject: [PATCH 05/11] Address comments 1 --- .../agent_framework_workflow/_edge.py | 19 ++++- .../agent_framework_workflow/_runner.py | 2 +- .../agent_framework_workflow/_validation.py | 6 +- .../agent_framework_workflow/_workflow.py | 2 +- python/packages/workflow/tests/test_edge.py | 78 +++++++++---------- 5 files changed, 61 insertions(+), 46 deletions(-) diff --git a/python/packages/workflow/agent_framework_workflow/_edge.py b/python/packages/workflow/agent_framework_workflow/_edge.py index 519b884e9d..f8b9ae089d 100644 --- a/python/packages/workflow/agent_framework_workflow/_edge.py +++ b/python/packages/workflow/agent_framework_workflow/_edge.py @@ -124,16 +124,19 @@ def id(self) -> str: """Get the unique ID of the edge group.""" return self._id + @property @abstractmethod def source_executors(self) -> list[Executor]: """Get the source executor IDs of the edges in the group.""" ... + @property @abstractmethod def target_executors(self) -> list[Executor]: """Get the target executor IDs of the edges in the group.""" ... + @property @abstractmethod def edges(self) -> list[Edge]: """Get the edges in the group.""" @@ -141,7 +144,10 @@ def edges(self) -> list[Edge]: class SingleEdgeGroup(EdgeGroup): - """Represents a single edge group that contains only one edge.""" + """Represents a single edge group that contains only one edge. + + A concrete implementation of EdgeGroup that represent a group containing exactly one edge. + """ def __init__(self, source: Executor, target: Executor, condition: Callable[[Any], bool] | None = None) -> None: """Initialize the single edge group with an edge. @@ -167,16 +173,19 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R return False + @property @override def source_executors(self) -> list[Executor]: """Get the source executor of the edge.""" return [self._edge.source] + @property @override def target_executors(self) -> list[Executor]: """Get the target executor of the edge.""" return [self._edge.target] + @property @override def edges(self) -> list[Edge]: """Get the edges in the group.""" @@ -229,16 +238,19 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R ) return True + @property @override def source_executors(self) -> list[Executor]: """Get the source executor of the edges in the group.""" return [self._edges[0].source] + @property @override def target_executors(self) -> list[Executor]: """Get the target executors of the edges in the group.""" return [edge.target for edge in self._edges] + @property @override def edges(self) -> list[Edge]: """Get the edges in the group.""" @@ -296,16 +308,19 @@ def _is_ready_to_send(self) -> bool: """Check if all edges in the group have data to send.""" return all(self._buffer[edge.source_id] for edge in self._edges) + @property @override def source_executors(self) -> list[Executor]: """Get the source executors of the edges in the group.""" return [edge.source for edge in self._edges] + @property @override def target_executors(self) -> list[Executor]: """Get the target executor of the edges in the group.""" return [self._edges[0].target] + @property @override def edges(self) -> list[Edge]: """Get the edges in the group.""" @@ -358,7 +373,7 @@ def __init__( raise ValueError("Number of targets must be one more than the number of conditions.") self._edges = [ - Edge(source, target, condition) for target, condition in zip(targets, [*conditions, None], strict=False) + Edge(source, target, condition) for target, condition in zip(targets, [*conditions, None], strict=True) ] @override diff --git a/python/packages/workflow/agent_framework_workflow/_runner.py b/python/packages/workflow/agent_framework_workflow/_runner.py index df4e39a564..5eb31586b1 100644 --- a/python/packages/workflow/agent_framework_workflow/_runner.py +++ b/python/packages/workflow/agent_framework_workflow/_runner.py @@ -104,7 +104,7 @@ def _parse_edge_groups(self, edge_groups: Sequence[EdgeGroup]) -> dict[str, list """ parsed: defaultdict[str, list[EdgeGroup]] = defaultdict(list) for group in edge_groups: - for source_executor in group.source_executors(): + for source_executor in group.source_executors: parsed[source_executor.id].append(group) return parsed diff --git a/python/packages/workflow/agent_framework_workflow/_validation.py b/python/packages/workflow/agent_framework_workflow/_validation.py index ad7b731d0d..5c5a352ac4 100644 --- a/python/packages/workflow/agent_framework_workflow/_validation.py +++ b/python/packages/workflow/agent_framework_workflow/_validation.py @@ -104,7 +104,7 @@ def validate_workflow(self, edge_groups: Sequence[EdgeGroup], start_executor: Ex WorkflowValidationError: If any validation fails """ self._executors = self._build_executor_map(edge_groups) - self._edges = [edge for group in edge_groups for edge in group.edges()] + self._edges = [edge for group in edge_groups for edge in group.edges] self._edge_groups = edge_groups # Validate that start_executor exists in the graph @@ -127,7 +127,7 @@ def _build_executor_map(self, edge_groups: Sequence[EdgeGroup]) -> dict[str, Exe """Build a map of executor IDs to executor instances.""" executors: dict[str, Executor] = {} for group in edge_groups: - for executor in group.source_executors() + group.target_executors(): + for executor in group.source_executors + group.target_executors: executors[executor.id] = executor return executors @@ -159,7 +159,7 @@ def _validate_type_compatibility(self) -> None: TypeCompatibilityError: If type incompatibility is detected """ for edge_group in self._edge_groups: - for edge in edge_group.edges(): + for edge in edge_group.edges: self._validate_edge_type_compatibility(edge, edge_group) def _validate_edge_type_compatibility(self, edge: Edge, edge_group: EdgeGroup) -> None: diff --git a/python/packages/workflow/agent_framework_workflow/_workflow.py b/python/packages/workflow/agent_framework_workflow/_workflow.py index cc85607cd6..b1bf708737 100644 --- a/python/packages/workflow/agent_framework_workflow/_workflow.py +++ b/python/packages/workflow/agent_framework_workflow/_workflow.py @@ -218,7 +218,7 @@ def _build_executor_map(self, edge_groups: list[EdgeGroup]) -> dict[str, Executo """ executors: dict[str, Executor] = {} for group in edge_groups: - for executor in group.source_executors() + group.target_executors(): + for executor in group.source_executors + group.target_executors: executors[executor.id] = executor return executors diff --git a/python/packages/workflow/tests/test_edge.py b/python/packages/workflow/tests/test_edge.py index 43db3e7265..5d6a436e7a 100644 --- a/python/packages/workflow/tests/test_edge.py +++ b/python/packages/workflow/tests/test_edge.py @@ -95,10 +95,10 @@ def test_single_edge_group(): edge_group = SingleEdgeGroup(source=source, target=target) - assert edge_group.source_executors() == [source] - assert edge_group.target_executors() == [target] - assert edge_group.edges()[0].source_id == "source_executor" - assert edge_group.edges()[0].target_id == "target_executor" + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target] + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor" def test_single_edge_group_with_condition(): @@ -108,11 +108,11 @@ def test_single_edge_group_with_condition(): edge_group = SingleEdgeGroup(source=source, target=target, condition=lambda x: x.data == "test") - assert edge_group.source_executors() == [source] - assert edge_group.target_executors() == [target] - assert edge_group.edges()[0].source_id == "source_executor" - assert edge_group.edges()[0].target_id == "target_executor" - assert edge_group.edges()[0]._condition is not None # type: ignore + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target] + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor" + assert edge_group.edges[0]._condition is not None # type: ignore async def test_single_edge_group_send_message(): @@ -209,13 +209,13 @@ def test_source_edge_group(): edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) - assert edge_group.source_executors() == [source] - assert edge_group.target_executors() == [target1, target2] - assert len(edge_group.edges()) == 2 - assert edge_group.edges()[0].source_id == "source_executor" - assert edge_group.edges()[0].target_id == "target_executor_1" - assert edge_group.edges()[1].source_id == "source_executor" - assert edge_group.edges()[1].target_id == "target_executor_2" + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target1, target2] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor_1" + assert edge_group.edges[1].source_id == "source_executor" + assert edge_group.edges[1].target_id == "target_executor_2" def test_source_edge_group_invalid_number_of_targets(): @@ -355,13 +355,13 @@ def test_target_edge_group(): edge_group = TargetEdgeGroup(sources=[source1, source2], target=target) - assert edge_group.source_executors() == [source1, source2] - assert edge_group.target_executors() == [target] - assert len(edge_group.edges()) == 2 - assert edge_group.edges()[0].source_id == "source_executor_1" - assert edge_group.edges()[0].target_id == "target_executor" - assert edge_group.edges()[1].source_id == "source_executor_2" - assert edge_group.edges()[1].target_id == "target_executor" + assert edge_group.source_executors == [source1, source2] + assert edge_group.target_executors == [target] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor_1" + assert edge_group.edges[0].target_id == "target_executor" + assert edge_group.edges[1].source_id == "source_executor_2" + assert edge_group.edges[1].target_id == "target_executor" def test_target_edge_group_invalid_number_of_sources(): @@ -471,15 +471,15 @@ def test_conditional_edge_group(): conditions=[lambda x: x.data < 0], ) - assert edge_group.source_executors() == [source] - assert edge_group.target_executors() == [target1, target2] - assert len(edge_group.edges()) == 2 - assert edge_group.edges()[0].source_id == "source_executor" - assert edge_group.edges()[0].target_id == "target_executor_1" - assert edge_group.edges()[0]._condition is not None # type: ignore - assert edge_group.edges()[1].source_id == "source_executor" - assert edge_group.edges()[1].target_id == "target_executor_2" - assert edge_group.edges()[1]._condition is None # type: ignore + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target1, target2] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor_1" + assert edge_group.edges[0]._condition is not None # type: ignore + assert edge_group.edges[1].source_id == "source_executor" + assert edge_group.edges[1].target_id == "target_executor_2" + assert edge_group.edges[1]._condition is None # type: ignore def test_conditional_edge_group_invalid_number_of_targets(): @@ -644,13 +644,13 @@ def test_partitioning_edge_group(): partition_func=lambda data, num_edges: [0], ) - assert edge_group.source_executors() == [source] - assert edge_group.target_executors() == [target1, target2] - assert len(edge_group.edges()) == 2 - assert edge_group.edges()[0].source_id == "source_executor" - assert edge_group.edges()[0].target_id == "target_executor_1" - assert edge_group.edges()[1].source_id == "source_executor" - assert edge_group.edges()[1].target_id == "target_executor_2" + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target1, target2] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor_1" + assert edge_group.edges[1].source_id == "source_executor" + assert edge_group.edges[1].target_id == "target_executor_2" def test_partitioning_edge_group_invalid_number_of_targets(): From 98d8b4d8d5c80bdcd731b7bb51ffdbcf1ff76513 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 11 Aug 2025 14:08:07 -0700 Subject: [PATCH 06/11] Address comments 2 --- .../workflow/agent_framework_workflow/_edge.py | 2 +- .../workflow/step_02_simple_workflow_condition.py | 11 +++-------- .../workflow/step_04_simple_group_chat.py | 3 +-- .../workflow/step_05_simple_group_chat_with_hil.py | 3 +-- 4 files changed, 6 insertions(+), 13 deletions(-) diff --git a/python/packages/workflow/agent_framework_workflow/_edge.py b/python/packages/workflow/agent_framework_workflow/_edge.py index f8b9ae089d..1568138e68 100644 --- a/python/packages/workflow/agent_framework_workflow/_edge.py +++ b/python/packages/workflow/agent_framework_workflow/_edge.py @@ -297,7 +297,7 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R self._buffer.clear() # Only trigger one edge to send the messages to avoid duplicate sends await self._edges[0].send_message( - Message([msg.data for msg in messages_to_send], self._edges[0].source_id), + Message([msg.data for msg in messages_to_send], self.__class__.__name__), shared_state, ctx, ) diff --git a/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py b/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py index 1d8625c02c..637e3cefa3 100644 --- a/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py +++ b/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py @@ -97,15 +97,10 @@ async def main(): workflow = ( WorkflowBuilder() .set_start_executor(spam_detector) - .add_edge( + .add_conditional_fan_out_edges( spam_detector, - send_response, - condition=lambda x: x.is_spam is False, - ) - .add_edge( - spam_detector, - remove_spam, - condition=lambda x: x.is_spam is True, + [send_response, remove_spam], + conditions=[lambda x: not x.is_spam], ) .build() ) diff --git a/python/samples/getting_started/workflow/step_04_simple_group_chat.py b/python/samples/getting_started/workflow/step_04_simple_group_chat.py index dd5ae88e3f..4cc043638d 100644 --- a/python/samples/getting_started/workflow/step_04_simple_group_chat.py +++ b/python/samples/getting_started/workflow/step_04_simple_group_chat.py @@ -121,8 +121,7 @@ async def main(): workflow = ( WorkflowBuilder() .set_start_executor(group_chat_manager) - .add_edge(group_chat_manager, writer) - .add_edge(group_chat_manager, reviewer) + .add_fan_out_edges(group_chat_manager, [writer, reviewer]) .add_edge(writer, group_chat_manager) .add_edge(reviewer, group_chat_manager) .build() diff --git a/python/samples/getting_started/workflow/step_05_simple_group_chat_with_hil.py b/python/samples/getting_started/workflow/step_05_simple_group_chat_with_hil.py index 4f0040a249..b9da1d65b1 100644 --- a/python/samples/getting_started/workflow/step_05_simple_group_chat_with_hil.py +++ b/python/samples/getting_started/workflow/step_05_simple_group_chat_with_hil.py @@ -167,8 +167,7 @@ async def main(): .set_start_executor(group_chat_manager) .add_edge(group_chat_manager, request_info_executor) .add_edge(request_info_executor, group_chat_manager) - .add_edge(group_chat_manager, writer) - .add_edge(group_chat_manager, reviewer) + .add_fan_out_edges(group_chat_manager, [writer, reviewer]) .add_edge(writer, group_chat_manager) .add_edge(reviewer, group_chat_manager) .build() From 6ec2f1c2c88e584e5bf1799135063adad5e7d344 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 12 Aug 2025 11:57:36 -0700 Subject: [PATCH 07/11] Update conditional edge group to take in cases and default --- .../main/agent_framework/workflow/__init__.py | 2 + .../agent_framework/workflow/__init__.pyi | 4 + .../agent_framework_workflow/__init__.py | 3 + .../agent_framework_workflow/_edge.py | 97 ++++++++++--------- .../agent_framework_workflow/_workflow.py | 11 +-- python/packages/workflow/tests/test_edge.py | 65 +++++++++---- .../workflow/step_00_foundation_patterns.py | 17 ++-- .../step_02_simple_workflow_condition.py | 18 +++- 8 files changed, 137 insertions(+), 80 deletions(-) diff --git a/python/packages/main/agent_framework/workflow/__init__.py b/python/packages/main/agent_framework/workflow/__init__.py index bbfd05c2a9..971b7fb004 100644 --- a/python/packages/main/agent_framework/workflow/__init__.py +++ b/python/packages/main/agent_framework/workflow/__init__.py @@ -32,6 +32,8 @@ "InMemoryCheckpointStorage", "CheckpointStorage", "WorkflowCheckpoint", + "Case", + "Default", ] diff --git a/python/packages/main/agent_framework/workflow/__init__.pyi b/python/packages/main/agent_framework/workflow/__init__.pyi index 6506d4a936..e36c6ee461 100644 --- a/python/packages/main/agent_framework/workflow/__init__.pyi +++ b/python/packages/main/agent_framework/workflow/__init__.pyi @@ -6,7 +6,9 @@ from agent_framework_workflow import ( AgentExecutorResponse, AgentRunEvent, AgentRunStreamingEvent, + Case, CheckpointStorage, + Default, Executor, ExecutorCompletedEvent, ExecutorEvent, @@ -34,7 +36,9 @@ __all__ = [ "AgentExecutorResponse", "AgentRunEvent", "AgentRunStreamingEvent", + "Case", "CheckpointStorage", + "Default", "Executor", "ExecutorCompletedEvent", "ExecutorEvent", diff --git a/python/packages/workflow/agent_framework_workflow/__init__.py b/python/packages/workflow/agent_framework_workflow/__init__.py index 1d59918525..8196d5f767 100644 --- a/python/packages/workflow/agent_framework_workflow/__init__.py +++ b/python/packages/workflow/agent_framework_workflow/__init__.py @@ -11,6 +11,7 @@ from ._const import ( DEFAULT_MAX_ITERATIONS, ) +from ._edge import Case, Default from ._events import ( AgentRunEvent, AgentRunStreamingEvent, @@ -60,7 +61,9 @@ "AgentExecutorResponse", "AgentRunEvent", "AgentRunStreamingEvent", + "Case", "CheckpointStorage", + "Default", "EdgeDuplicationError", "Executor", "ExecutorCompletedEvent", diff --git a/python/packages/workflow/agent_framework_workflow/_edge.py b/python/packages/workflow/agent_framework_workflow/_edge.py index 1568138e68..7e82dc2176 100644 --- a/python/packages/workflow/agent_framework_workflow/_edge.py +++ b/python/packages/workflow/agent_framework_workflow/_edge.py @@ -2,11 +2,10 @@ import asyncio import logging -import sys import uuid -from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Callable, Sequence +from dataclasses import dataclass from typing import Any, ClassVar from ._executor import Executor @@ -14,11 +13,6 @@ from ._shared_state import SharedState from ._workflow_context import WorkflowContext -if sys.version_info >= (3, 12): - from typing import override # pragma: no cover -else: - from typing_extensions import override # pragma: no cover - logger = logging.getLogger(__name__) @@ -97,14 +91,13 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R ) -class EdgeGroup(ABC): +class EdgeGroup: """Represents a group of edges that share some common properties and can be triggered together.""" def __init__(self) -> None: """Initialize the edge group.""" self._id = f"{self.__class__.__name__}/{uuid.uuid4()}" - @abstractmethod async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: """Send a message through the edge group. @@ -114,10 +107,15 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R ctx (RunnerContext): The context for the runner. Returns: - bool: True if the message was sent successfully, False otherwise. If a message can be delivered - but rejected due to a condition, it will still return True. + bool: True if the message was sent successfully, False if the target executor cannot handle the message. + If a message can be delivered but rejected due to a condition, it will still return True. + + Note: + Exception will not be raised if the target executor cannot handle the message. This is because + a source executor can be connected to multiple target executors, and not every target executor may + be able to handle all the messages sent by the source executor. """ - ... + raise NotImplementedError @property def id(self) -> str: @@ -125,22 +123,19 @@ def id(self) -> str: return self._id @property - @abstractmethod def source_executors(self) -> list[Executor]: """Get the source executor IDs of the edges in the group.""" - ... + raise NotImplementedError @property - @abstractmethod def target_executors(self) -> list[Executor]: """Get the target executor IDs of the edges in the group.""" - ... + raise NotImplementedError @property - @abstractmethod def edges(self) -> list[Edge]: """Get the edges in the group.""" - ... + raise NotImplementedError class SingleEdgeGroup(EdgeGroup): @@ -161,7 +156,6 @@ def __init__(self, source: Executor, target: Executor, condition: Callable[[Any] """ self._edge = Edge(source=source, target=target, condition=condition) - @override async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: """Send a message through the single edge.""" if message.target_id and message.target_id != self._edge.target_id: @@ -174,19 +168,16 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R return False @property - @override def source_executors(self) -> list[Executor]: """Get the source executor of the edge.""" return [self._edge.source] @property - @override def target_executors(self) -> list[Executor]: """Get the target executor of the edge.""" return [self._edge.target] @property - @override def edges(self) -> list[Edge]: """Get the edges in the group.""" return [self._edge] @@ -210,7 +201,6 @@ def __init__(self, source: Executor, targets: Sequence[Executor]) -> None: raise ValueError("SourceEdgeGroup must contain at least two targets.") self._edges = [Edge(source=source, target=target) for target in targets] - @override async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: """Send a message through all edges in the source edge group.""" if message.target_id: @@ -239,19 +229,16 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R return True @property - @override def source_executors(self) -> list[Executor]: """Get the source executor of the edges in the group.""" return [self._edges[0].source] @property - @override def target_executors(self) -> list[Executor]: """Get the target executors of the edges in the group.""" return [edge.target for edge in self._edges] @property - @override def edges(self) -> list[Edge]: """Get the edges in the group.""" return self._edges @@ -278,7 +265,6 @@ def __init__(self, sources: Sequence[Executor], target: Executor) -> None: # Key is the source executor ID, value is a list of messages self._buffer: dict[str, list[Message]] = defaultdict(list) - @override async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: """Send a message through all edges in the target edge group.""" if message.target_id and message.target_id != self._edges[0].target_id: @@ -309,24 +295,45 @@ def _is_ready_to_send(self) -> bool: return all(self._buffer[edge.source_id] for edge in self._edges) @property - @override def source_executors(self) -> list[Executor]: """Get the source executors of the edges in the group.""" return [edge.source for edge in self._edges] @property - @override def target_executors(self) -> list[Executor]: """Get the target executor of the edges in the group.""" return [self._edges[0].target] @property - @override def edges(self) -> list[Edge]: """Get the edges in the group.""" return self._edges +@dataclass +class Case: + """Represents a single case in the conditional edge group. + + Args: + condition (Callable[[Any], bool]): The condition function for the case. + target (Executor): The target executor for the case. + """ + + condition: Callable[[Any], bool] + target: Executor + + +@dataclass +class Default: + """Represents the default case in the conditional edge group. + + Args: + target (Executor): The target executor for the default case. + """ + + target: Executor + + class ConditionalEdgeGroup(SourceEdgeGroup): """Represents a group of edges that assemble a conditional routing pattern. @@ -353,30 +360,33 @@ class ConditionalEdgeGroup(SourceEdgeGroup): def __init__( self, source: Executor, - targets: Sequence[Executor], - conditions: Sequence[Callable[[Any], bool]], + cases: Sequence[Case | Default], ) -> None: """Initialize the conditional edge group with a list of edges. Args: source (Executor): The source executor. - targets (Sequence[Executor]): A list of target executors that the source executor can send messages to. - conditions (Sequence[Callable[[Any], bool]]): A list of condition functions that determine - which target executor to route the message to based on the data. The number of conditions - must be one less than the number of targets, as the last target is the default case. The - index of the condition corresponds to the index of the target executor. + cases (Sequence[Case | Default]): A list of cases for the conditional edge group. + There should be exactly one default case. """ - if len(targets) <= 1: - raise ValueError("ConditionalEdgeGroup must contain at least two targets.") + if len(cases) < 2: + raise ValueError("ConditionalEdgeGroup must contain at least two cases (including the default case).") - if len(targets) != len(conditions) + 1: - raise ValueError("Number of targets must be one more than the number of conditions.") + default_case = [isinstance(case, Default) for case in cases] + if sum(default_case) != 1: + raise ValueError("ConditionalEdgeGroup must contain exactly one default case.") + + if isinstance(cases[-1], Default): + logger.warning( + "Default case in the conditional edge group is not the last case. " + "This will result in unexpected behavior." + ) self._edges = [ - Edge(source, target, condition) for target, condition in zip(targets, [*conditions, None], strict=True) + Edge(source, case.target, case.condition) if isinstance(case, Case) else Edge(source, case.target, None) + for case in cases ] - @override async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: """Send a message through the conditional edge group.""" if message.target_id: @@ -420,7 +430,6 @@ def __init__( self._edges = [Edge(source=source, target=target) for target in targets] self._partition_func = partition_func - @override async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: """Send a message through the partitioning edge group.""" partition_result = self._partition_func(message.data, len(self._edges)) diff --git a/python/packages/workflow/agent_framework_workflow/_workflow.py b/python/packages/workflow/agent_framework_workflow/_workflow.py index 1cc5390ece..ade9c85ee4 100644 --- a/python/packages/workflow/agent_framework_workflow/_workflow.py +++ b/python/packages/workflow/agent_framework_workflow/_workflow.py @@ -10,7 +10,9 @@ from ._checkpoint import CheckpointStorage from ._const import DEFAULT_MAX_ITERATIONS from ._edge import ( + Case, ConditionalEdgeGroup, + Default, EdgeGroup, PartitioningEdgeGroup, SingleEdgeGroup, @@ -468,9 +470,7 @@ def add_fan_out_edges(self, source: Executor, targets: Sequence[Executor]) -> "S return self - def add_conditional_fan_out_edges( - self, source: Executor, targets: Sequence[Executor], conditions: Sequence[Callable[[Any], bool]] - ) -> "Self": + def add_conditional_edge_group(self, source: Executor, cases: Sequence[Case | Default]) -> "Self": """Add a conditional fan out group of edges to the workflow. The output types of the source and the input types of the targets must be compatible. @@ -486,10 +486,9 @@ def add_conditional_fan_out_edges( Args: source: The source executor of the edges. - targets: A list of target executors for the edges. - conditions: A list of condition functions that determine whether each edge should be traversed. + cases: A list of case objects that determine the target executor for each message. """ - self._edge_groups.append(ConditionalEdgeGroup(source, targets, conditions)) + self._edge_groups.append(ConditionalEdgeGroup(source, cases)) return self diff --git a/python/packages/workflow/tests/test_edge.py b/python/packages/workflow/tests/test_edge.py index 5d6a436e7a..527e5ee34c 100644 --- a/python/packages/workflow/tests/test_edge.py +++ b/python/packages/workflow/tests/test_edge.py @@ -8,7 +8,9 @@ from agent_framework.workflow import Executor, WorkflowContext, handler from agent_framework_workflow._edge import ( + Case, ConditionalEdgeGroup, + Default, Edge, PartitioningEdgeGroup, SingleEdgeGroup, @@ -467,8 +469,10 @@ def test_conditional_edge_group(): edge_group = ConditionalEdgeGroup( source=source, - targets=[target1, target2], - conditions=[lambda x: x.data < 0], + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], ) assert edge_group.source_executors == [source] @@ -482,30 +486,45 @@ def test_conditional_edge_group(): assert edge_group.edges[1]._condition is None # type: ignore -def test_conditional_edge_group_invalid_number_of_targets(): - """Test creating a conditional edge group with an invalid number of targets.""" +def test_conditional_edge_group_invalid_number_of_cases(): + """Test creating a conditional edge group with an invalid number of cases.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") - with pytest.raises(ValueError, match="ConditionalEdgeGroup must contain at least two targets"): + with pytest.raises( + ValueError, match=r"ConditionalEdgeGroup must contain at least two cases \(including the default case\)." + ): ConditionalEdgeGroup( source=source, - targets=[target], - conditions=[lambda x: x.data < 0], + cases=[ + Case(condition=lambda x: x.data < 0, target=target), + ], + ) + + with pytest.raises(ValueError, match="ConditionalEdgeGroup must contain exactly one default case."): + ConditionalEdgeGroup( + source=source, + cases=[ + Case(condition=lambda x: x.data < 0, target=target), + Case(condition=lambda x: x.data >= 0, target=target), + ], ) -def test_conditional_edge_group_invalid_number_of_conditions(): +def test_conditional_edge_group_invalid_number_of_default_cases(): """Test creating a conditional edge group with an invalid number of conditions.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - with pytest.raises(ValueError, match="Number of targets must be one more than the number of conditions."): + with pytest.raises(ValueError, match="ConditionalEdgeGroup must contain exactly one default case."): ConditionalEdgeGroup( source=source, - targets=[target1, target2], - conditions=[lambda x: x.data < 0, lambda x: x.data > 0], + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + Default(target=target2), + ], ) @@ -517,8 +536,10 @@ async def test_conditional_edge_group_send_message(): edge_group = ConditionalEdgeGroup( source=source, - targets=[target1, target2], - conditions=[lambda x: x.data < 0], + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], ) from agent_framework_workflow._runner_context import InProcRunnerContext, Message @@ -554,8 +575,10 @@ async def test_conditional_edge_group_send_message_with_invalid_target(): edge_group = ConditionalEdgeGroup( source=source, - targets=[target1, target2], - conditions=[lambda x: x.data < 0], + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], ) from agent_framework_workflow._runner_context import InProcRunnerContext, Message @@ -579,8 +602,10 @@ async def test_conditional_edge_group_send_message_with_valid_target(): edge_group = ConditionalEdgeGroup( source=source, - targets=[target1, target2], - conditions=[lambda x: x.data < 0], + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], ) from agent_framework_workflow._runner_context import InProcRunnerContext, Message @@ -609,8 +634,10 @@ async def test_conditional_edge_group_send_message_with_invalid_data(): edge_group = ConditionalEdgeGroup( source=source, - targets=[target1, target2], - conditions=[lambda x: x.data < 0], + cases=[ + Case(condition=lambda x: x.data < 0, target=target1), + Default(target=target2), + ], ) from agent_framework_workflow._runner_context import InProcRunnerContext, Message diff --git a/python/samples/getting_started/workflow/step_00_foundation_patterns.py b/python/samples/getting_started/workflow/step_00_foundation_patterns.py index 30f2b38b04..ae7f38af22 100644 --- a/python/samples/getting_started/workflow/step_00_foundation_patterns.py +++ b/python/samples/getting_started/workflow/step_00_foundation_patterns.py @@ -3,7 +3,7 @@ import asyncio from typing import Any -from agent_framework.workflow import Executor, WorkflowBuilder, WorkflowContext, handler +from agent_framework.workflow import Case, Default, Executor, WorkflowBuilder, WorkflowContext, handler """ The following sample demonstrates the foundation patterns that the workflow framework supports. @@ -195,13 +195,16 @@ async def conditional_fan_out_edges(): workflow = ( WorkflowBuilder() .set_start_executor(add_one_executor) - .add_conditional_fan_out_edges( + .add_conditional_edge_group( add_one_executor, - [add_one_executor, divide_by_two_executor, multiply_by_two_executor], - # Loop back to the add_one_executor if the number is less than 11 - # and to the multiply_by_two_executor when the number is larger than or equal to 11 and even. - # Otherwise, send to the divide_by_two_executor. - conditions=[lambda x: x < 11, lambda x: x % 2 == 0], + [ + # Loop back to the add_one_executor if the number is less than 11 + Case(condition=lambda x: x < 11, target=add_one_executor), + # multiply_by_two_executor when the number is larger than or equal to 11 and even. + Case(condition=lambda x: x % 2 == 0, target=multiply_by_two_executor), + # Otherwise, send to the divide_by_two_executor. + Default(target=divide_by_two_executor), + ], ) .add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor) .build() diff --git a/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py b/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py index 4f00c3c5a4..3da2495a9d 100644 --- a/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py +++ b/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py @@ -3,7 +3,15 @@ import asyncio from dataclasses import dataclass -from agent_framework.workflow import Executor, WorkflowBuilder, WorkflowCompletedEvent, WorkflowContext, handler +from agent_framework.workflow import ( + Case, + Default, + Executor, + WorkflowBuilder, + WorkflowCompletedEvent, + WorkflowContext, + handler, +) """ The following sample demonstrates a basic workflow with two executors @@ -91,10 +99,12 @@ async def main(): workflow = ( WorkflowBuilder() .set_start_executor(spam_detector) - .add_conditional_fan_out_edges( + .add_conditional_edge_group( spam_detector, - [send_response, remove_spam], - conditions=[lambda x: not x.is_spam], + [ + Case(condition=lambda x: x.is_spam, target=remove_spam), + Default(target=send_response), + ], ) .build() ) From 6fd81980aefcda7e4067f17cdd81dba650710550 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 12 Aug 2025 12:52:03 -0700 Subject: [PATCH 08/11] Minor updates to sample --- .../getting_started/workflow/step_00_foundation_patterns.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/samples/getting_started/workflow/step_00_foundation_patterns.py b/python/samples/getting_started/workflow/step_00_foundation_patterns.py index ae7f38af22..9156d0ef28 100644 --- a/python/samples/getting_started/workflow/step_00_foundation_patterns.py +++ b/python/samples/getting_started/workflow/step_00_foundation_patterns.py @@ -99,7 +99,8 @@ async def single_edge_with_condition(): Three executors are connected: AddOneExecutor -> AddOneExecutor, AggregateResultExecutor. The AddOneExecutor will loop back to itself until the number reaches 10, then it will start - sending the result to AggregateResultExecutor when the number is greater than 8. + sending the result to AggregateResultExecutor when the number is greater than 8. The workflow + stops when the number reaches 11. Expected output: Adding one to the number: 1 Result: 2 From 0eadc63b4f5942ce3d28bda1090da7c01b0f44ed Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 13 Aug 2025 15:03:17 -0700 Subject: [PATCH 09/11] Collapsing Paritioning Edge group and Conditional Edge group to source edge group --- .../agent_framework_workflow/_edge.py | 161 +++---- .../agent_framework_workflow/_workflow.py | 30 +- python/packages/workflow/tests/test_edge.py | 434 +++++++++--------- .../workflow/step_00_foundation_patterns.py | 38 +- .../step_02_simple_workflow_condition.py | 2 +- 5 files changed, 298 insertions(+), 367 deletions(-) diff --git a/python/packages/workflow/agent_framework_workflow/_edge.py b/python/packages/workflow/agent_framework_workflow/_edge.py index 7e82dc2176..2f7bc6939a 100644 --- a/python/packages/workflow/agent_framework_workflow/_edge.py +++ b/python/packages/workflow/agent_framework_workflow/_edge.py @@ -190,43 +190,59 @@ class SourceEdgeGroup(EdgeGroup): and send messages to their respective target executors. """ - def __init__(self, source: Executor, targets: Sequence[Executor]) -> None: + def __init__( + self, + source: Executor, + targets: Sequence[Executor], + selection_func: Callable[[Any, list[str]], list[str]] | None = None, + ) -> None: """Initialize the source edge group with a list of edges. Args: source (Executor): The source executor. targets (Sequence[Executor]): A list of target executors that the source executor can send messages to. + selection_func (Callable[[Any, list[str]], list[str]], optional): A function that selects which target + executors to send messages to. The function takes in the message data and a list of target executor + IDs, and returns a list of selected target executor IDs. """ if len(targets) <= 1: raise ValueError("SourceEdgeGroup must contain at least two targets.") self._edges = [Edge(source=source, target=target) for target in targets] + self._target_ids = [edge.target_id for edge in self._edges] + self._target_map = {edge.target_id: edge for edge in self._edges} + self._selection_func = selection_func async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: """Send a message through all edges in the source edge group.""" + selection_results = ( + self._selection_func(message.data, self._target_ids) if self._selection_func else self._target_ids + ) + if not self._validate_selection_result(selection_results): + raise RuntimeError( + f"Invalid selection result: {selection_results}. " + f"Expected selections to be a subset of valid target executor IDs: {self._target_ids}." + ) + if message.target_id: - # If the message has a target ID, send it to the specific target executor - target_edge = next((edge for edge in self._edges if edge.target_id == message.target_id), None) - if target_edge and target_edge.can_handle(message.data): - await target_edge.send_message(message, shared_state, ctx) - return True + # If the target ID is specified and the selection result contains it, send the message to that edge + if message.target_id in selection_results: + edge = next((edge for edge in self._edges if edge.target_id == message.target_id), None) + if edge and edge.can_handle(message.data): + await edge.send_message(message, shared_state, ctx) + return True return False - # If no target ID, send the message to all edges in the group - if all(not edge.can_handle(message.data) for edge in self._edges): + # If no target ID, send the message to the selected targets + async def send_to_edge(edge: Edge) -> bool: + """Send the message to the edge at the specified index.""" + if edge.can_handle(message.data): + await edge.send_message(message, shared_state, ctx) + return True return False - await asyncio.gather( - *( - edge.send_message( - message, - shared_state, - ctx, - ) - for edge in self._edges - if edge.can_handle(message.data) - ), - ) - return True + tasks = [send_to_edge(self._target_map[target_id]) for target_id in selection_results] + results = await asyncio.gather(*tasks) + return any(results) @property def source_executors(self) -> list[Executor]: @@ -243,6 +259,10 @@ def edges(self) -> list[Edge]: """Get the edges in the group.""" return self._edges + def _validate_selection_result(self, selection_results: list[str]) -> bool: + """Validate the selection results to ensure all IDs are valid target executor IDs.""" + return all(result in self._target_ids for result in selection_results) + class TargetEdgeGroup(EdgeGroup): """Represents a group of edges that share the same target executor. @@ -334,7 +354,7 @@ class Default: target: Executor -class ConditionalEdgeGroup(SourceEdgeGroup): +class SwitchCaseEdgeGroup(SourceEdgeGroup): """Represents a group of edges that assemble a conditional routing pattern. This is similar to a switch-case construct: @@ -370,11 +390,11 @@ def __init__( There should be exactly one default case. """ if len(cases) < 2: - raise ValueError("ConditionalEdgeGroup must contain at least two cases (including the default case).") + raise ValueError("SwitchCaseEdgeGroup must contain at least two cases (including the default case).") default_case = [isinstance(case, Default) for case in cases] if sum(default_case) != 1: - raise ValueError("ConditionalEdgeGroup must contain exactly one default case.") + raise ValueError("SwitchCaseEdgeGroup must contain exactly one default case.") if isinstance(cases[-1], Default): logger.warning( @@ -382,83 +402,18 @@ def __init__( "This will result in unexpected behavior." ) - self._edges = [ - Edge(source, case.target, case.condition) if isinstance(case, Case) else Edge(source, case.target, None) - for case in cases - ] - - async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: - """Send a message through the conditional edge group.""" - if message.target_id: - # Find the index of the target edge in the edges list if target_id is specified - index = next((i for i, edge in enumerate(self._edges) if edge.target_id == message.target_id), None) - if index is None: - return False - if self._edges[index].can_handle(message.data) and self._edges[index].should_route(message.data): - await self._edges[index].send_message(message, shared_state, ctx) - return True - return False - - for edge in self._edges: - if edge.can_handle(message.data) and edge.should_route(message.data): - await edge.send_message(message, shared_state, ctx) - return True - - return False - - -class PartitioningEdgeGroup(SourceEdgeGroup): - """Represents a group of edges that can route messages based on a partitioning strategy. - - Messages from the source executor are routed to multiple target executors based on a partitioning function. - """ - - def __init__( - self, source: Executor, targets: Sequence[Executor], partition_func: Callable[[Any, int], list[int]] - ) -> None: - """Initialize the partitioning edge group with a list of edges. - - Args: - source (Executor): The source executor. - targets (Sequence[Executor]): A list of target executors that the source executor can send messages to. - partition_func (Callable[[Any, int], list[int]]): A partitioning function that determines which target - executors to route the message to based on the data. The function should take the message data and - the number of targets, and return a list of indices of the target executors to route the message to. - """ - if len(targets) <= 1: - raise ValueError("PartitioningEdgeGroup must contain at least two targets.") - self._edges = [Edge(source=source, target=target) for target in targets] - self._partition_func = partition_func - - async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: - """Send a message through the partitioning edge group.""" - partition_result = self._partition_func(message.data, len(self._edges)) - if not self._validate_partition_result(partition_result): - raise RuntimeError( - f"Invalid partition result: {partition_result}. Expected indices in range: [0, {len(self._edges) - 1}]." - ) - - if message.target_id: - # If the target ID is specified and the partition result contains it, send the message to that edge - has_target = message.target_id in [self._edges[index].target_id for index in partition_result] - if has_target: - edge = next((edge for edge in self._edges if edge.target_id == message.target_id), None) - if edge and edge.can_handle(message.data): - await edge.send_message(message, shared_state, ctx) - return True - return False - - async def send_to_edge(edge: Edge) -> bool: - """Send the message to the edge at the specified index.""" - if edge.can_handle(message.data): - await edge.send_message(message, shared_state, ctx) - return True - return False - - tasks = [send_to_edge(self._edges[index]) for index in partition_result] - results = await asyncio.gather(*tasks) - return any(results) - - def _validate_partition_result(self, partition_result: list[int]) -> bool: - """Validate the partition result to ensure all indices are within bounds.""" - return all(0 <= index < len(self._edges) for index in partition_result) + def selection_func(data: Any, targets: list[str]) -> list[str]: + """Select the target executor based on the conditions.""" + for index, case in enumerate(cases): + if isinstance(case, Default): + return [case.target.id] + if isinstance(case, Case): + try: + if case.condition(data): + return [case.target.id] + except Exception as e: + logger.warning(f"Error occurred while evaluating condition for case {index}: {e}") + + raise RuntimeError("No matching case found in SwitchCaseEdgeGroup.") + + super().__init__(source, [case.target for case in cases], selection_func=selection_func) diff --git a/python/packages/workflow/agent_framework_workflow/_workflow.py b/python/packages/workflow/agent_framework_workflow/_workflow.py index ade9c85ee4..7060e2ed84 100644 --- a/python/packages/workflow/agent_framework_workflow/_workflow.py +++ b/python/packages/workflow/agent_framework_workflow/_workflow.py @@ -11,12 +11,11 @@ from ._const import DEFAULT_MAX_ITERATIONS from ._edge import ( Case, - ConditionalEdgeGroup, Default, EdgeGroup, - PartitioningEdgeGroup, SingleEdgeGroup, SourceEdgeGroup, + SwitchCaseEdgeGroup, TargetEdgeGroup, ) from ._events import RequestInfoEvent, WorkflowCompletedEvent, WorkflowEvent @@ -457,10 +456,9 @@ def add_edge( return self def add_fan_out_edges(self, source: Executor, targets: Sequence[Executor]) -> "Self": - """Add multiple edges to the workflow. + """Add multiple edges to the workflow where messages from the source will be sent to all target. The output types of the source and the input types of the targets must be compatible. - Messages from the source executor will be sent to all target executors. Args: source: The source executor of the edges. @@ -470,8 +468,8 @@ def add_fan_out_edges(self, source: Executor, targets: Sequence[Executor]) -> "S return self - def add_conditional_edge_group(self, source: Executor, cases: Sequence[Case | Default]) -> "Self": - """Add a conditional fan out group of edges to the workflow. + def add_switch_case_edge_group(self, source: Executor, cases: Sequence[Case | Default]) -> "Self": + """Add an edge group that represents a switch-case statement. The output types of the source and the input types of the targets must be compatible. Messages from the source executor will be sent to one of the target executors based on @@ -481,38 +479,38 @@ def add_conditional_edge_group(self, source: Executor, cases: Sequence[Case | De Each condition function will be evaluated in order, and the first one that returns True will determine which target executor receives the message. - The number of targets must be one greater than the number of conditions. The last target - executor will receive messages that fall through all conditions (i.e., no condition matched). + The last case (the default case) will receive messages that fall through all conditions + (i.e., no condition matched). Args: source: The source executor of the edges. cases: A list of case objects that determine the target executor for each message. """ - self._edge_groups.append(ConditionalEdgeGroup(source, cases)) + self._edge_groups.append(SwitchCaseEdgeGroup(source, cases)) return self - def add_partitioning_fan_out_edges( + def add_multi_selection_edge_group( self, source: Executor, targets: Sequence[Executor], - partition_func: Callable[[Any, int], list[int]], + selection_func: Callable[[Any, list[str]], list[str]], ) -> "Self": - """Add a partitioning fan out group of edges to the workflow. + """Add an edge group that represents a multi-selection execution model. The output types of the source and the input types of the targets must be compatible. Messages from the source executor will be sent to multiple target executors based on - the provided partition function. + the provided selection function. - The partition function should take a message and the number of target executors, + The selection function should take a message and the name of the target executors, and return a list of indices indicating which target executors should receive the message. Args: source: The source executor of the edges. targets: A list of target executors for the edges. - partition_func: A function that partitions messages to target executors. + selection_func: A function that selects target executors for messages. """ - self._edge_groups.append(PartitioningEdgeGroup(source, targets, partition_func)) + self._edge_groups.append(SourceEdgeGroup(source, targets, selection_func)) return self diff --git a/python/packages/workflow/tests/test_edge.py b/python/packages/workflow/tests/test_edge.py index 527e5ee34c..600b0624ec 100644 --- a/python/packages/workflow/tests/test_edge.py +++ b/python/packages/workflow/tests/test_edge.py @@ -9,12 +9,11 @@ from agent_framework_workflow._edge import ( Case, - ConditionalEdgeGroup, Default, Edge, - PartitioningEdgeGroup, SingleEdgeGroup, SourceEdgeGroup, + SwitchCaseEdgeGroup, TargetEdgeGroup, ) @@ -344,6 +343,180 @@ async def test_source_edge_group_send_message_only_one_successful_send(): assert mock_send.call_count == 1 +def test_source_edge_group_with_selection_func(): + """Test creating a partitioning edge group.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SourceEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id], + ) + + assert edge_group.source_executors == [source] + assert edge_group.target_executors == [target1, target2] + assert len(edge_group.edges) == 2 + assert edge_group.edges[0].source_id == "source_executor" + assert edge_group.edges[0].target_id == "target_executor_1" + assert edge_group.edges[1].source_id == "source_executor" + assert edge_group.edges[1].target_id == "target_executor_2" + + +async def test_source_edge_group_with_selection_func_send_message(): + """Test sending a message through a source edge group with a selection function.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SourceEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id, target2.id], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 2 + + +async def test_source_edge_group_with_selection_func_send_message_with_invalid_selection_result(): + """Test sending a message through a source edge group with a selection func with an invalid selection result.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SourceEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id, "invalid_target"], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id) + + with pytest.raises(RuntimeError): + await edge_group.send_message(message, shared_state, ctx) + + +async def test_source_edge_group_with_selection_func_send_message_with_target(): + """Test sending a message through a source edge group with a selection func with a target.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SourceEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id, target2.id], + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target1.id) + + with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: + success = await edge_group.send_message(message, shared_state, ctx) + + assert success is True + assert mock_send.call_count == 1 + assert mock_send.call_args[0][0].target_id == target1.id + + +async def test_source_edge_group_with_selection_func_send_message_with_target_not_in_selection(): + """Test sending a message through a source edge group with a selection func with a target not in the selection.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SourceEdgeGroup( + source=source, + targets=[target1, target2], + selection_func=lambda data, target_ids: [target1.id], # Only target1 will receive the message + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = MockMessage(data="test") + message = Message(data=data, source_id=source.id, target_id=target2.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_source_edge_group_with_selection_func_send_message_with_invalid_data(): + """Test sending a message through a source edge group with a selection func with invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SourceEdgeGroup( + source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, target2.id] + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + +async def test_source_edge_group_with_selection_func_send_message_with_target_invalid_data(): + """Test sending a message through a source edge group with a selection func with a target and invalid data.""" + source = MockExecutor(id="source_executor") + target1 = MockExecutor(id="target_executor_1") + target2 = MockExecutor(id="target_executor_2") + + edge_group = SourceEdgeGroup( + source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, target2.id] + ) + + from agent_framework_workflow._runner_context import InProcRunnerContext, Message + from agent_framework_workflow._shared_state import SharedState + + shared_state = SharedState() + ctx = InProcRunnerContext() + + data = "invalid_data" + message = Message(data=data, source_id=source.id, target_id=target1.id) + + success = await edge_group.send_message(message, shared_state, ctx) + assert success is False + + # endregion SourceEdgeGroup # region TargetEdgeGroup @@ -458,16 +631,16 @@ async def test_target_edge_group_send_message_with_invalid_data(): # endregion TargetEdgeGroup -# region ConditionalEdgeGroup +# region SwitchCaseEdgeGroup -def test_conditional_edge_group(): - """Test creating a conditional edge group.""" +def test_switch_case_edge_group(): + """Test creating a switch case edge group.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = ConditionalEdgeGroup( + edge_group = SwitchCaseEdgeGroup( source=source, cases=[ Case(condition=lambda x: x.data < 0, target=target1), @@ -480,29 +653,31 @@ def test_conditional_edge_group(): assert len(edge_group.edges) == 2 assert edge_group.edges[0].source_id == "source_executor" assert edge_group.edges[0].target_id == "target_executor_1" - assert edge_group.edges[0]._condition is not None # type: ignore assert edge_group.edges[1].source_id == "source_executor" assert edge_group.edges[1].target_id == "target_executor_2" - assert edge_group.edges[1]._condition is None # type: ignore + + assert edge_group._selection_func is not None # type: ignore + assert edge_group._selection_func(MockMessage(data=-1), [target1.id, target2.id]) == [target1.id] # type: ignore + assert edge_group._selection_func(MockMessage(data=1), [target1.id, target2.id]) == [target2.id] # type: ignore -def test_conditional_edge_group_invalid_number_of_cases(): - """Test creating a conditional edge group with an invalid number of cases.""" +def test_switch_case_edge_group_invalid_number_of_cases(): + """Test creating a switch case edge group with an invalid number of cases.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") with pytest.raises( - ValueError, match=r"ConditionalEdgeGroup must contain at least two cases \(including the default case\)." + ValueError, match=r"SwitchCaseEdgeGroup must contain at least two cases \(including the default case\)." ): - ConditionalEdgeGroup( + SwitchCaseEdgeGroup( source=source, cases=[ Case(condition=lambda x: x.data < 0, target=target), ], ) - with pytest.raises(ValueError, match="ConditionalEdgeGroup must contain exactly one default case."): - ConditionalEdgeGroup( + with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."): + SwitchCaseEdgeGroup( source=source, cases=[ Case(condition=lambda x: x.data < 0, target=target), @@ -511,14 +686,14 @@ def test_conditional_edge_group_invalid_number_of_cases(): ) -def test_conditional_edge_group_invalid_number_of_default_cases(): - """Test creating a conditional edge group with an invalid number of conditions.""" +def test_switch_case_edge_group_invalid_number_of_default_cases(): + """Test creating a switch case edge group with an invalid number of conditions.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - with pytest.raises(ValueError, match="ConditionalEdgeGroup must contain exactly one default case."): - ConditionalEdgeGroup( + with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."): + SwitchCaseEdgeGroup( source=source, cases=[ Case(condition=lambda x: x.data < 0, target=target1), @@ -528,13 +703,13 @@ def test_conditional_edge_group_invalid_number_of_default_cases(): ) -async def test_conditional_edge_group_send_message(): - """Test sending a message through a conditional edge group.""" +async def test_switch_case_edge_group_send_message(): + """Test sending a message through a switch case edge group.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = ConditionalEdgeGroup( + edge_group = SwitchCaseEdgeGroup( source=source, cases=[ Case(condition=lambda x: x.data < 0, target=target1), @@ -567,13 +742,13 @@ async def test_conditional_edge_group_send_message(): assert mock_send.call_count == 1 -async def test_conditional_edge_group_send_message_with_invalid_target(): - """Test sending a message through a conditional edge group with an invalid target.""" +async def test_switch_case_edge_group_send_message_with_invalid_target(): + """Test sending a message through a switch case edge group with an invalid target.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = ConditionalEdgeGroup( + edge_group = SwitchCaseEdgeGroup( source=source, cases=[ Case(condition=lambda x: x.data < 0, target=target1), @@ -594,13 +769,13 @@ async def test_conditional_edge_group_send_message_with_invalid_target(): assert success is False -async def test_conditional_edge_group_send_message_with_valid_target(): - """Test sending a message through a conditional edge group with a target.""" +async def test_switch_case_edge_group_send_message_with_valid_target(): + """Test sending a message through a switch case edge group with a target.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = ConditionalEdgeGroup( + edge_group = SwitchCaseEdgeGroup( source=source, cases=[ Case(condition=lambda x: x.data < 0, target=target1), @@ -626,13 +801,13 @@ async def test_conditional_edge_group_send_message_with_valid_target(): assert success is True -async def test_conditional_edge_group_send_message_with_invalid_data(): - """Test sending a message through a conditional edge group with invalid data.""" +async def test_switch_case_edge_group_send_message_with_invalid_data(): + """Test sending a message through a switch case edge group with invalid data.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = ConditionalEdgeGroup( + edge_group = SwitchCaseEdgeGroup( source=source, cases=[ Case(condition=lambda x: x.data < 0, target=target1), @@ -653,201 +828,4 @@ async def test_conditional_edge_group_send_message_with_invalid_data(): assert success is False -# endregion ConditionalEdgeGroup - - -# region PartitioningEdgeGroup - - -def test_partitioning_edge_group(): - """Test creating a partitioning edge group.""" - source = MockExecutor(id="source_executor") - target1 = MockExecutor(id="target_executor_1") - target2 = MockExecutor(id="target_executor_2") - - edge_group = PartitioningEdgeGroup( - source=source, - targets=[target1, target2], - partition_func=lambda data, num_edges: [0], - ) - - assert edge_group.source_executors == [source] - assert edge_group.target_executors == [target1, target2] - assert len(edge_group.edges) == 2 - assert edge_group.edges[0].source_id == "source_executor" - assert edge_group.edges[0].target_id == "target_executor_1" - assert edge_group.edges[1].source_id == "source_executor" - assert edge_group.edges[1].target_id == "target_executor_2" - - -def test_partitioning_edge_group_invalid_number_of_targets(): - """Test creating a partitioning edge group with an invalid number of targets.""" - source = MockExecutor(id="source_executor") - target = MockExecutor(id="target_executor") - - with pytest.raises(ValueError, match="PartitioningEdgeGroup must contain at least two targets."): - PartitioningEdgeGroup( - source=source, - targets=[target], - partition_func=lambda data, num_edges: [0], - ) - - -async def test_partitioning_edge_group_send_message(): - """Test sending a message through a partitioning edge group.""" - source = MockExecutor(id="source_executor") - target1 = MockExecutor(id="target_executor_1") - target2 = MockExecutor(id="target_executor_2") - - edge_group = PartitioningEdgeGroup( - source=source, - targets=[target1, target2], - partition_func=lambda data, num_edges: [0, 1], - ) - - from agent_framework_workflow._runner_context import InProcRunnerContext, Message - from agent_framework_workflow._shared_state import SharedState - - shared_state = SharedState() - ctx = InProcRunnerContext() - - data = MockMessage(data="test") - message = Message(data=data, source_id=source.id) - - with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: - success = await edge_group.send_message(message, shared_state, ctx) - - assert success is True - assert mock_send.call_count == 2 - - -async def test_partitioning_edge_group_send_message_with_invalid_partition_result(): - """Test sending a message through a partitioning edge group with an invalid partition result.""" - source = MockExecutor(id="source_executor") - target1 = MockExecutor(id="target_executor_1") - target2 = MockExecutor(id="target_executor_2") - - edge_group = PartitioningEdgeGroup( - source=source, - targets=[target1, target2], - partition_func=lambda data, num_edges: [0, 2], # Invalid index - ) - - from agent_framework_workflow._runner_context import InProcRunnerContext, Message - from agent_framework_workflow._shared_state import SharedState - - shared_state = SharedState() - ctx = InProcRunnerContext() - - data = MockMessage(data="test") - message = Message(data=data, source_id=source.id) - - with pytest.raises(RuntimeError): - await edge_group.send_message(message, shared_state, ctx) - - -async def test_partitioning_edge_group_send_message_with_target(): - """Test sending a message through a partitioning edge group with a target.""" - source = MockExecutor(id="source_executor") - target1 = MockExecutor(id="target_executor_1") - target2 = MockExecutor(id="target_executor_2") - - edge_group = PartitioningEdgeGroup( - source=source, - targets=[target1, target2], - partition_func=lambda data, num_edges: [0, 1], - ) - - from agent_framework_workflow._runner_context import InProcRunnerContext, Message - from agent_framework_workflow._shared_state import SharedState - - shared_state = SharedState() - ctx = InProcRunnerContext() - - data = MockMessage(data="test") - message = Message(data=data, source_id=source.id, target_id=target1.id) - - with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send: - success = await edge_group.send_message(message, shared_state, ctx) - - assert success is True - assert mock_send.call_count == 1 - assert mock_send.call_args[0][0].target_id == target1.id - - -async def test_partitioning_edge_group_send_message_with_target_not_in_partition(): - """Test sending a message through a partitioning edge group with a target not in the partition.""" - source = MockExecutor(id="source_executor") - target1 = MockExecutor(id="target_executor_1") - target2 = MockExecutor(id="target_executor_2") - - edge_group = PartitioningEdgeGroup( - source=source, - targets=[target1, target2], - partition_func=lambda data, num_edges: [0], # Only target1 will receive the message - ) - - from agent_framework_workflow._runner_context import InProcRunnerContext, Message - from agent_framework_workflow._shared_state import SharedState - - shared_state = SharedState() - ctx = InProcRunnerContext() - - data = MockMessage(data="test") - message = Message(data=data, source_id=source.id, target_id=target2.id) - - success = await edge_group.send_message(message, shared_state, ctx) - assert success is False - - -async def test_partitioning_edge_group_send_message_with_invalid_data(): - """Test sending a message through a partitioning edge group with invalid data.""" - source = MockExecutor(id="source_executor") - target1 = MockExecutor(id="target_executor_1") - target2 = MockExecutor(id="target_executor_2") - - edge_group = PartitioningEdgeGroup( - source=source, - targets=[target1, target2], - partition_func=lambda data, num_edges: [0, 1], - ) - - from agent_framework_workflow._runner_context import InProcRunnerContext, Message - from agent_framework_workflow._shared_state import SharedState - - shared_state = SharedState() - ctx = InProcRunnerContext() - - data = "invalid_data" - message = Message(data=data, source_id=source.id) - - success = await edge_group.send_message(message, shared_state, ctx) - assert success is False - - -async def test_partitioning_edge_group_send_message_with_target_invalid_data(): - """Test sending a message through a partitioning edge group with a target and invalid data.""" - source = MockExecutor(id="source_executor") - target1 = MockExecutor(id="target_executor_1") - target2 = MockExecutor(id="target_executor_2") - - edge_group = PartitioningEdgeGroup( - source=source, - targets=[target1, target2], - partition_func=lambda data, num_edges: [0, 1], - ) - - from agent_framework_workflow._runner_context import InProcRunnerContext, Message - from agent_framework_workflow._shared_state import SharedState - - shared_state = SharedState() - ctx = InProcRunnerContext() - - data = "invalid_data" - message = Message(data=data, source_id=source.id, target_id=target1.id) - - success = await edge_group.send_message(message, shared_state, ctx) - assert success is False - - -# endregion PartitioningEdgeGroup +# endregion SwitchCaseEdgeGroup diff --git a/python/samples/getting_started/workflow/step_00_foundation_patterns.py b/python/samples/getting_started/workflow/step_00_foundation_patterns.py index 9156d0ef28..cf1dc44f5f 100644 --- a/python/samples/getting_started/workflow/step_00_foundation_patterns.py +++ b/python/samples/getting_started/workflow/step_00_foundation_patterns.py @@ -131,7 +131,7 @@ async def single_edge_with_condition(): await workflow.run(1) -async def fan_out_fan_in_edges(): +async def fan_out_fan_in_edge_group(): """A sample to demonstrate a fan-out and fan-in connection between executors. Four executors are connected in a fan-out and fan-in pattern: @@ -163,10 +163,10 @@ async def fan_out_fan_in_edges(): await workflow.run(1) -async def conditional_fan_out_edges(): - """A sample to demonstrate a conditional fan-out connection. +async def switch_case_edge_group(): + """A sample to demonstrate a switch-case connection. - Four executors are connected in a conditional fan-out pattern: + Four executors are connected in a switch-case pattern: AddOneExecutor -> AddOneExecutor, MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor. The message from AddOneExecutor will be evaluated against the conditions one by one, and the first condition @@ -196,7 +196,7 @@ async def conditional_fan_out_edges(): workflow = ( WorkflowBuilder() .set_start_executor(add_one_executor) - .add_conditional_edge_group( + .add_switch_case_edge_group( add_one_executor, [ # Loop back to the add_one_executor if the number is less than 11 @@ -214,10 +214,10 @@ async def conditional_fan_out_edges(): await workflow.run(1) -async def partitioning_fan_out_edges(): - """A sample to demonstrate a partitioning fan-out connection. +async def multi_selection_edge_group(): + """A sample to demonstrate a multi-selection edge connection. - Four executors are connected in a partitioning fan-out pattern: + Four executors are connected in a multi-selection edge pattern: AddOneExecutor -> AddOneExecutor, MultiplyByTwoExecutor, DivideByTwoExecutor -> AggregateResultExecutor. The AddOneExecutor sends its output to one or more executors based on the partitioning function. @@ -244,27 +244,27 @@ async def partitioning_fan_out_edges(): divide_by_two_executor = DivideByTwoExecutor() aggregate_result_executor = AggregateResultExecutor() - def partition_func(number: int, total_targets: int) -> list[int]: - """Partition function to determine which executor to send the number to.""" + def selection_func(number: int, target_ids: list[str]) -> list[str]: + """Selection function to determine which executor to send the number to.""" if number < 12: # Loop back to the add_one_executor if the number is less than 12 - return [0] + return [add_one_executor.id] if number % 2 == 0: # Send it to the add_one_executor to add one more time and the # divide_by_two_executor to divide the result by two. - return [0, 2] + return [add_one_executor.id, divide_by_two_executor.id] # Otherwise, send it to the multiply_by_two_executor to multiply the result by two. - return [1] + return [multiply_by_two_executor.id] workflow = ( WorkflowBuilder() .set_start_executor(add_one_executor) - .add_partitioning_fan_out_edges( + .add_multi_selection_edge_group( add_one_executor, [add_one_executor, multiply_by_two_executor, divide_by_two_executor], - partition_func=partition_func, + selection_func=selection_func, ) .add_fan_in_edges([multiply_by_two_executor, divide_by_two_executor], aggregate_result_executor) .build() @@ -280,11 +280,11 @@ async def main(): print("**Running single connection with condition workflow**") await single_edge_with_condition() print("**Running fan-out and fan-in connection workflow**") - await fan_out_fan_in_edges() + await fan_out_fan_in_edge_group() print("**Running conditional fan-out connection workflow**") - await conditional_fan_out_edges() - print("**Running partitioning fan-out connection workflow**") - await partitioning_fan_out_edges() + await switch_case_edge_group() + print("**Running multi-selection edge group workflow**") + await multi_selection_edge_group() if __name__ == "__main__": diff --git a/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py b/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py index 3da2495a9d..3a526ce7b7 100644 --- a/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py +++ b/python/samples/getting_started/workflow/step_02_simple_workflow_condition.py @@ -99,7 +99,7 @@ async def main(): workflow = ( WorkflowBuilder() .set_start_executor(spam_detector) - .add_conditional_edge_group( + .add_switch_case_edge_group( spam_detector, [ Case(condition=lambda x: x.is_spam, target=remove_spam), From bbab69a98516b6f11baecf762195bf83d0d21222 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 13 Aug 2025 15:12:48 -0700 Subject: [PATCH 10/11] Improve sample clarity --- .../getting_started/workflow/step_00_foundation_patterns.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/samples/getting_started/workflow/step_00_foundation_patterns.py b/python/samples/getting_started/workflow/step_00_foundation_patterns.py index cf1dc44f5f..2827d84fab 100644 --- a/python/samples/getting_started/workflow/step_00_foundation_patterns.py +++ b/python/samples/getting_started/workflow/step_00_foundation_patterns.py @@ -197,8 +197,8 @@ async def switch_case_edge_group(): WorkflowBuilder() .set_start_executor(add_one_executor) .add_switch_case_edge_group( - add_one_executor, - [ + source=add_one_executor, + cases=[ # Loop back to the add_one_executor if the number is less than 11 Case(condition=lambda x: x < 11, target=add_one_executor), # multiply_by_two_executor when the number is larger than or equal to 11 and even. From a5301ace4f7162ca3505d5aa1dc33bc887db545e Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Thu, 14 Aug 2025 14:56:55 -0700 Subject: [PATCH 11/11] Name consolidation --- .../agent_framework_workflow/_edge.py | 18 ++-- .../agent_framework_workflow/_validation.py | 4 +- .../agent_framework_workflow/_workflow.py | 10 +-- python/packages/workflow/tests/test_edge.py | 90 +++++++++---------- 4 files changed, 61 insertions(+), 61 deletions(-) diff --git a/python/packages/workflow/agent_framework_workflow/_edge.py b/python/packages/workflow/agent_framework_workflow/_edge.py index 2f7bc6939a..fef5e3376d 100644 --- a/python/packages/workflow/agent_framework_workflow/_edge.py +++ b/python/packages/workflow/agent_framework_workflow/_edge.py @@ -183,7 +183,7 @@ def edges(self) -> list[Edge]: return [self._edge] -class SourceEdgeGroup(EdgeGroup): +class FanOutEdgeGroup(EdgeGroup): """Represents a group of edges that share the same source executor. Assembles a Fan-out pattern where multiple edges share the same source executor @@ -196,7 +196,7 @@ def __init__( targets: Sequence[Executor], selection_func: Callable[[Any, list[str]], list[str]] | None = None, ) -> None: - """Initialize the source edge group with a list of edges. + """Initialize the fan-out edge group with a list of edges. Args: source (Executor): The source executor. @@ -206,14 +206,14 @@ def __init__( IDs, and returns a list of selected target executor IDs. """ if len(targets) <= 1: - raise ValueError("SourceEdgeGroup must contain at least two targets.") + raise ValueError("FanOutEdgeGroup must contain at least two targets.") self._edges = [Edge(source=source, target=target) for target in targets] self._target_ids = [edge.target_id for edge in self._edges] self._target_map = {edge.target_id: edge for edge in self._edges} self._selection_func = selection_func async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: - """Send a message through all edges in the source edge group.""" + """Send a message through all edges in the fan-out edge group.""" selection_results = ( self._selection_func(message.data, self._target_ids) if self._selection_func else self._target_ids ) @@ -264,7 +264,7 @@ def _validate_selection_result(self, selection_results: list[str]) -> bool: return all(result in self._target_ids for result in selection_results) -class TargetEdgeGroup(EdgeGroup): +class FanInEdgeGroup(EdgeGroup): """Represents a group of edges that share the same target executor. Assembles a Fan-in pattern where multiple edges send messages to a single target executor. @@ -272,21 +272,21 @@ class TargetEdgeGroup(EdgeGroup): """ def __init__(self, sources: Sequence[Executor], target: Executor) -> None: - """Initialize the target edge group with a list of edges. + """Initialize the fan-in edge group with a list of edges. Args: sources (Sequence[Executor]): A list of source executors that can send messages to the target executor. target (Executor): The target executor that receives a list of messages aggregated from all sources. """ if len(sources) <= 1: - raise ValueError("TargetEdgeGroup must contain at least two sources.") + raise ValueError("FanInEdgeGroup must contain at least two sources.") self._edges = [Edge(source=source, target=target) for source in sources] # Buffer to hold messages before sending them to the target executor # Key is the source executor ID, value is a list of messages self._buffer: dict[str, list[Message]] = defaultdict(list) async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: - """Send a message through all edges in the target edge group.""" + """Send a message through all edges in the fan-in edge group.""" if message.target_id and message.target_id != self._edges[0].target_id: return False @@ -354,7 +354,7 @@ class Default: target: Executor -class SwitchCaseEdgeGroup(SourceEdgeGroup): +class SwitchCaseEdgeGroup(FanOutEdgeGroup): """Represents a group of edges that assemble a conditional routing pattern. This is similar to a switch-case construct: diff --git a/python/packages/workflow/agent_framework_workflow/_validation.py b/python/packages/workflow/agent_framework_workflow/_validation.py index 5c5a352ac4..3c4d8c12fd 100644 --- a/python/packages/workflow/agent_framework_workflow/_validation.py +++ b/python/packages/workflow/agent_framework_workflow/_validation.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Any, Union, get_args, get_origin -from ._edge import Edge, EdgeGroup, TargetEdgeGroup +from ._edge import Edge, EdgeGroup, FanInEdgeGroup from ._executor import Executor logger = logging.getLogger(__name__) @@ -207,7 +207,7 @@ def _validate_edge_type_compatibility(self, edge: Edge, edge_group: EdgeGroup) - for source_type in source_output_types: for target_type in target_input_types: - if isinstance(edge_group, TargetEdgeGroup): + if isinstance(edge_group, FanInEdgeGroup): # If the edge is part of an edge group, the target expects a list of data types if self._is_type_compatible(list[source_type], target_type): compatible = True diff --git a/python/packages/workflow/agent_framework_workflow/_workflow.py b/python/packages/workflow/agent_framework_workflow/_workflow.py index 7060e2ed84..f2b80ce953 100644 --- a/python/packages/workflow/agent_framework_workflow/_workflow.py +++ b/python/packages/workflow/agent_framework_workflow/_workflow.py @@ -13,10 +13,10 @@ Case, Default, EdgeGroup, + FanInEdgeGroup, + FanOutEdgeGroup, SingleEdgeGroup, - SourceEdgeGroup, SwitchCaseEdgeGroup, - TargetEdgeGroup, ) from ._events import RequestInfoEvent, WorkflowCompletedEvent, WorkflowEvent from ._executor import Executor, RequestInfoExecutor @@ -464,7 +464,7 @@ def add_fan_out_edges(self, source: Executor, targets: Sequence[Executor]) -> "S source: The source executor of the edges. targets: A list of target executors for the edges. """ - self._edge_groups.append(SourceEdgeGroup(source, targets)) + self._edge_groups.append(FanOutEdgeGroup(source, targets)) return self @@ -510,7 +510,7 @@ def add_multi_selection_edge_group( targets: A list of target executors for the edges. selection_func: A function that selects target executors for messages. """ - self._edge_groups.append(SourceEdgeGroup(source, targets, selection_func)) + self._edge_groups.append(FanOutEdgeGroup(source, targets, selection_func)) return self @@ -548,7 +548,7 @@ def handle_message(self, message: Message) -> None: sources: A list of source executors for the edges. target: The target executor for the edges. """ - self._edge_groups.append(TargetEdgeGroup(sources, target)) + self._edge_groups.append(FanInEdgeGroup(sources, target)) return self diff --git a/python/packages/workflow/tests/test_edge.py b/python/packages/workflow/tests/test_edge.py index 600b0624ec..ddef2cd649 100644 --- a/python/packages/workflow/tests/test_edge.py +++ b/python/packages/workflow/tests/test_edge.py @@ -11,10 +11,10 @@ Case, Default, Edge, + FanInEdgeGroup, + FanOutEdgeGroup, SingleEdgeGroup, - SourceEdgeGroup, SwitchCaseEdgeGroup, - TargetEdgeGroup, ) @@ -199,16 +199,16 @@ async def test_single_edge_group_send_message_with_invalid_data(): # endregion SingleEdgeGroup -# region SourceEdgeGroup +# region FanOutEdgeGroup def test_source_edge_group(): - """Test creating a source edge group.""" + """Test creating a fan-out group.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) assert edge_group.source_executors == [source] assert edge_group.target_executors == [target1, target2] @@ -220,21 +220,21 @@ def test_source_edge_group(): def test_source_edge_group_invalid_number_of_targets(): - """Test creating a source edge group with an invalid number of targets.""" + """Test creating a fan-out group with an invalid number of targets.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") - with pytest.raises(ValueError, match="SourceEdgeGroup must contain at least two targets"): - SourceEdgeGroup(source=source, targets=[target]) + with pytest.raises(ValueError, match="FanOutEdgeGroup must contain at least two targets"): + FanOutEdgeGroup(source=source, targets=[target]) async def test_source_edge_group_send_message(): - """Test sending a message through a source edge group.""" + """Test sending a message through a fan-out group.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) from agent_framework_workflow._runner_context import InProcRunnerContext, Message from agent_framework_workflow._shared_state import SharedState @@ -253,12 +253,12 @@ async def test_source_edge_group_send_message(): async def test_source_edge_group_send_message_with_target(): - """Test sending a message through a source edge group with a target.""" + """Test sending a message through a fan-out group with a target.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) from agent_framework_workflow._runner_context import InProcRunnerContext, Message from agent_framework_workflow._shared_state import SharedState @@ -278,12 +278,12 @@ async def test_source_edge_group_send_message_with_target(): async def test_source_edge_group_send_message_with_invalid_target(): - """Test sending a message through a source edge group with an invalid target.""" + """Test sending a message through a fan-out group with an invalid target.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) from agent_framework_workflow._runner_context import InProcRunnerContext, Message from agent_framework_workflow._shared_state import SharedState @@ -299,12 +299,12 @@ async def test_source_edge_group_send_message_with_invalid_target(): async def test_source_edge_group_send_message_with_invalid_data(): - """Test sending a message through a source edge group with invalid data.""" + """Test sending a message through a fan-out group with invalid data.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) from agent_framework_workflow._runner_context import InProcRunnerContext, Message from agent_framework_workflow._shared_state import SharedState @@ -320,12 +320,12 @@ async def test_source_edge_group_send_message_with_invalid_data(): async def test_source_edge_group_send_message_only_one_successful_send(): - """Test sending a message through a source edge group where only one edge can handle the message.""" + """Test sending a message through a fan-out group where only one edge can handle the message.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutorSecondary(id="target_executor_2") - edge_group = SourceEdgeGroup(source=source, targets=[target1, target2]) + edge_group = FanOutEdgeGroup(source=source, targets=[target1, target2]) from agent_framework_workflow._runner_context import InProcRunnerContext, Message from agent_framework_workflow._shared_state import SharedState @@ -349,7 +349,7 @@ def test_source_edge_group_with_selection_func(): target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = SourceEdgeGroup( + edge_group = FanOutEdgeGroup( source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id], @@ -365,12 +365,12 @@ def test_source_edge_group_with_selection_func(): async def test_source_edge_group_with_selection_func_send_message(): - """Test sending a message through a source edge group with a selection function.""" + """Test sending a message through a fan-out group with a selection function.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = SourceEdgeGroup( + edge_group = FanOutEdgeGroup( source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, target2.id], @@ -393,12 +393,12 @@ async def test_source_edge_group_with_selection_func_send_message(): async def test_source_edge_group_with_selection_func_send_message_with_invalid_selection_result(): - """Test sending a message through a source edge group with a selection func with an invalid selection result.""" + """Test sending a message through a fan-out group with a selection func with an invalid selection result.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = SourceEdgeGroup( + edge_group = FanOutEdgeGroup( source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, "invalid_target"], @@ -418,12 +418,12 @@ async def test_source_edge_group_with_selection_func_send_message_with_invalid_s async def test_source_edge_group_with_selection_func_send_message_with_target(): - """Test sending a message through a source edge group with a selection func with a target.""" + """Test sending a message through a fan-out group with a selection func with a target.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = SourceEdgeGroup( + edge_group = FanOutEdgeGroup( source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, target2.id], @@ -447,12 +447,12 @@ async def test_source_edge_group_with_selection_func_send_message_with_target(): async def test_source_edge_group_with_selection_func_send_message_with_target_not_in_selection(): - """Test sending a message through a source edge group with a selection func with a target not in the selection.""" + """Test sending a message through a fan-out group with a selection func with a target not in the selection.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = SourceEdgeGroup( + edge_group = FanOutEdgeGroup( source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id], # Only target1 will receive the message @@ -472,12 +472,12 @@ async def test_source_edge_group_with_selection_func_send_message_with_target_no async def test_source_edge_group_with_selection_func_send_message_with_invalid_data(): - """Test sending a message through a source edge group with a selection func with invalid data.""" + """Test sending a message through a fan-out group with a selection func with invalid data.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = SourceEdgeGroup( + edge_group = FanOutEdgeGroup( source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, target2.id] ) @@ -495,12 +495,12 @@ async def test_source_edge_group_with_selection_func_send_message_with_invalid_d async def test_source_edge_group_with_selection_func_send_message_with_target_invalid_data(): - """Test sending a message through a source edge group with a selection func with a target and invalid data.""" + """Test sending a message through a fan-out group with a selection func with a target and invalid data.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") - edge_group = SourceEdgeGroup( + edge_group = FanOutEdgeGroup( source=source, targets=[target1, target2], selection_func=lambda data, target_ids: [target1.id, target2.id] ) @@ -517,18 +517,18 @@ async def test_source_edge_group_with_selection_func_send_message_with_target_in assert success is False -# endregion SourceEdgeGroup +# endregion FanOutEdgeGroup -# region TargetEdgeGroup +# region FanInEdgeGroup def test_target_edge_group(): - """Test creating a target edge group.""" + """Test creating a fan-in edge group.""" source1 = MockExecutor(id="source_executor_1") source2 = MockExecutor(id="source_executor_2") target = MockAggregator(id="target_executor") - edge_group = TargetEdgeGroup(sources=[source1, source2], target=target) + edge_group = FanInEdgeGroup(sources=[source1, source2], target=target) assert edge_group.source_executors == [source1, source2] assert edge_group.target_executors == [target] @@ -540,21 +540,21 @@ def test_target_edge_group(): def test_target_edge_group_invalid_number_of_sources(): - """Test creating a target edge group with an invalid number of sources.""" + """Test creating a fan-in edge group with an invalid number of sources.""" source = MockExecutor(id="source_executor") target = MockAggregator(id="target_executor") - with pytest.raises(ValueError, match="TargetEdgeGroup must contain at least two sources"): - TargetEdgeGroup(sources=[source], target=target) + with pytest.raises(ValueError, match="FanInEdgeGroup must contain at least two sources"): + FanInEdgeGroup(sources=[source], target=target) async def test_target_edge_group_send_message_buffer(): - """Test sending a message through a target edge group with buffering.""" + """Test sending a message through a fan-in edge group with buffering.""" source1 = MockExecutor(id="source_executor_1") source2 = MockExecutor(id="source_executor_2") target = MockAggregator(id="target_executor") - edge_group = TargetEdgeGroup(sources=[source1, source2], target=target) + edge_group = FanInEdgeGroup(sources=[source1, source2], target=target) from agent_framework_workflow._runner_context import InProcRunnerContext, Message from agent_framework_workflow._shared_state import SharedState @@ -588,12 +588,12 @@ async def test_target_edge_group_send_message_buffer(): async def test_target_edge_group_send_message_with_invalid_target(): - """Test sending a message through a target edge group with an invalid target.""" + """Test sending a message through a fan-in edge group with an invalid target.""" source1 = MockExecutor(id="source_executor_1") source2 = MockExecutor(id="source_executor_2") target = MockAggregator(id="target_executor") - edge_group = TargetEdgeGroup(sources=[source1, source2], target=target) + edge_group = FanInEdgeGroup(sources=[source1, source2], target=target) from agent_framework_workflow._runner_context import InProcRunnerContext, Message from agent_framework_workflow._shared_state import SharedState @@ -609,12 +609,12 @@ async def test_target_edge_group_send_message_with_invalid_target(): async def test_target_edge_group_send_message_with_invalid_data(): - """Test sending a message through a target edge group with invalid data.""" + """Test sending a message through a fan-in edge group with invalid data.""" source1 = MockExecutor(id="source_executor_1") source2 = MockExecutor(id="source_executor_2") target = MockAggregator(id="target_executor") - edge_group = TargetEdgeGroup(sources=[source1, source2], target=target) + edge_group = FanInEdgeGroup(sources=[source1, source2], target=target) from agent_framework_workflow._runner_context import InProcRunnerContext, Message from agent_framework_workflow._shared_state import SharedState @@ -629,7 +629,7 @@ async def test_target_edge_group_send_message_with_invalid_data(): assert success is False -# endregion TargetEdgeGroup +# endregion FanInEdgeGroup # region SwitchCaseEdgeGroup