diff --git a/flo_ai/flo_ai/arium/arium.py b/flo_ai/flo_ai/arium/arium.py index a86243a6..ebd21bdb 100644 --- a/flo_ai/flo_ai/arium/arium.py +++ b/flo_ai/flo_ai/arium/arium.py @@ -7,6 +7,7 @@ from flo_ai.tool.base_tool import Tool from flo_ai.arium.models import StartNode, EndNode from flo_ai.arium.events import AriumEventType, AriumEvent +from flo_ai.arium.nodes import AriumNode, ForEachNode from flo_ai.utils.logger import logger from flo_ai.utils.variable_extractor import ( extract_variables_from_inputs, @@ -73,7 +74,7 @@ async def run( # Execute the workflow with event support result = await self._execute_graph( - resolved_inputs, event_callback, events_filter + resolved_inputs, event_callback, events_filter, variables ) # Emit workflow completed event @@ -81,6 +82,8 @@ async def run( AriumEventType.WORKFLOW_COMPLETED, event_callback, events_filter ) + self.memory = MessageMemory() # cleanup the graph (if used as AriumNode multiple times in graph, then the same instance is used for now hence we need to cleanup memory) + return result except Exception as e: @@ -118,6 +121,7 @@ async def _execute_graph( inputs: List[str | ImageMessage | DocumentMessage], event_callback: Optional[Callable[[AriumEvent], None]] = None, events_filter: Optional[List[AriumEventType]] = None, + variables: Optional[Dict[str, Any]] = None, ): [self.memory.add(msg) for msg in inputs] @@ -162,11 +166,16 @@ async def _execute_graph( ) # execute current node result = await self._execute_node( - current_node, event_callback, events_filter + current_node, event_callback, events_filter, variables ) - # update results to memory - self._add_to_memory(result) + if isinstance(result, List): # for each node will give results array + for item in result: + # update each item in result to memory + self._add_to_memory(item) + else: + # update results to memory + self._add_to_memory(result) # find next node post current node # Prepare execution context for router functions @@ -301,6 +310,7 @@ async def 
_execute_node( node: Agent | Tool | StartNode | EndNode, event_callback: Optional[Callable[[AriumEvent], None]] = None, events_filter: Optional[List[AriumEventType]] = None, + variables: Optional[Dict[str, Any]] = None, ): """ Execute a single node with optional event emission. @@ -318,6 +328,10 @@ async def _execute_node( node_type = 'agent' elif isinstance(node, Tool): node_type = 'tool' + elif isinstance(node, ForEachNode): + node_type = 'foreach' + elif isinstance(node, AriumNode): + node_type = 'arium' elif isinstance(node, StartNode): node_type = 'start' elif isinstance(node, EndNode): @@ -342,7 +356,16 @@ async def _execute_node( # Variables are already resolved, pass empty dict to avoid re-processing result = await node.run(self.memory.get(), variables={}) elif isinstance(node, Tool): - result = await node.execute() + # result = await node.execute() # as Tool is also an ExecutableNode now + result = await node.run(inputs=[], variables=None) + elif isinstance(node, ForEachNode): + result = await node.run( + inputs=self.memory.get(), + variables=variables, + ) + elif isinstance(node, AriumNode): + # AriumNode execution + result = await node.run(inputs=self.memory.get(), variables=variables) elif isinstance(node, StartNode): result = None elif isinstance(node, EndNode): diff --git a/flo_ai/flo_ai/arium/base.py b/flo_ai/flo_ai/arium/base.py index c3458354..258dd850 100644 --- a/flo_ai/flo_ai/arium/base.py +++ b/flo_ai/flo_ai/arium/base.py @@ -1,5 +1,7 @@ import inspect from functools import partial +from flo_ai.arium.nodes import AriumNode, ForEachNode +from flo_ai.arium.protocols import ExecutableNode from flo_ai.models.agent import Agent from flo_ai.tool.base_tool import Tool from flo_ai.utils.logger import logger @@ -12,13 +14,13 @@ class BaseArium: def __init__(self): self.start_node_name = '__start__' self.end_node_names: set = set() # Support multiple end nodes - self.nodes: Dict[str, Agent | Tool | StartNode | EndNode] = dict() + self.nodes: Dict[str, 
ExecutableNode | StartNode | EndNode] = dict() self.edges: Dict[str, Edge] = dict() - def add_nodes(self, agents: List[Agent | Tool | StartNode | EndNode]): + def add_nodes(self, agents: List[ExecutableNode | StartNode | EndNode]): self.nodes.update({agent.name: agent for agent in agents}) - def start_at(self, node: Agent | Tool | StartNode | EndNode): + def start_at(self, node: ExecutableNode): start_node = StartNode() if start_node.name in self.nodes: raise ValueError(f'Start node {start_node.name} already exists') @@ -27,7 +29,7 @@ def start_at(self, node: Agent | Tool | StartNode | EndNode): router_fn=partial(default_router, to_node=node.name), to_nodes=[node.name] ) - def add_end_to(self, node: Agent | Tool | StartNode | EndNode): + def add_end_to(self, node: ExecutableNode): # Create a unique end node name for this specific node end_node_name = f'__end__{node.name}__' end_node = EndNode() @@ -378,5 +380,9 @@ def _get_node_type(self, node) -> str: return 'agent' elif isinstance(node, Tool): return 'tool' + elif isinstance(node, ForEachNode): + return 'foreach' + elif isinstance(node, AriumNode): + return 'arium' else: return 'unknown' diff --git a/flo_ai/flo_ai/arium/builder.py b/flo_ai/flo_ai/arium/builder.py index 03ca6c80..cc66bdb8 100644 --- a/flo_ai/flo_ai/arium/builder.py +++ b/flo_ai/flo_ai/arium/builder.py @@ -1,6 +1,8 @@ from typing import List, Optional, Callable, Union, Dict, Any from flo_ai.arium.arium import Arium from flo_ai.arium.memory import MessageMemory, BaseMemory +from flo_ai.arium.protocols import ExecutableNode +from flo_ai.arium.nodes import AriumNode, ForEachNode from flo_ai.models.agent import Agent from flo_ai.tool.base_tool import Tool from flo_ai.llm.base_llm import ImageMessage @@ -30,10 +32,17 @@ def __init__(self): self._memory: Optional[BaseMemory] = None self._agents: List[Agent] = [] self._tools: List[Tool] = [] - self._start_node: Optional[Union[Agent, Tool]] = None - self._end_nodes: List[Union[Agent, Tool]] = [] + 
self._ariums: List[ + AriumNode + ] = [] # only those ariums which are part of main workflow + self._foreach_nodes: List[ForEachNode] = [] + self._start_node: Optional[ExecutableNode] = None + self._end_nodes: List[ExecutableNode] = [] self._edges: List[tuple] = [] # (from_node, to_nodes, router) self._arium: Optional[Arium] = None + self._all_ariums: List[ + AriumNode + ] = [] # all the ariums either of main workflow or when used as a node in foreachnode or any sub workflow def with_memory(self, memory: BaseMemory) -> 'AriumBuilder': """Set the memory for the Arium.""" @@ -60,12 +69,76 @@ def add_tools(self, tools: List[Tool]) -> 'AriumBuilder': self._tools.extend(tools) return self - def start_with(self, node: Union[Agent, Tool]) -> 'AriumBuilder': + def add_arium( + self, arium: Arium, name: Optional[str] = None, inherit_variables: bool = True + ) -> 'AriumBuilder': + """ + Add an Arium workflow as a node. + + Args: + arium: The Arium to add as a node + name: Name for this node (defaults to arium's name or auto-generated) + inherit_variables: Whether to inherit parent variables + + Returns: + AriumBuilder: Self for method chaining + """ + # Generate name if not provided + node_name = name or getattr(arium, 'name', f'arium_node_{len(self._ariums)}') + + arium_node = AriumNode( + name=node_name, arium=arium, inherit_variables=inherit_variables + ) + self._ariums.append(arium_node) + self._all_ariums.append(arium_node) + return self + + def add_foreach( + self, name: str, execute_node: Union[ExecutableNode, str] + ) -> 'AriumBuilder': + """ + Add a ForEach node for batch processing. + + The ForEach node will iterate over all items in memory when executed. + Execution is sequential only for now (future - parallel would also be supported). 
+ + Args: + name: Name for the ForEach node + execute_node: Node to execute on each item (node object or name string) + + Returns: + AriumBuilder: Self for method chaining + """ + # Resolve node reference if string name provided + if isinstance(execute_node, str): + # Search across all node types + all_nodes = self._agents + self._tools + self._ariums + self._foreach_nodes + resolved_node = next((n for n in all_nodes if n.name == execute_node), None) + if not resolved_node: + raise ValueError(f"Node '{execute_node}' not found") + execute_node = resolved_node + + foreach = ForEachNode(name=name, execute_node=execute_node) + + self._foreach_nodes.append(foreach) + if isinstance(execute_node, AriumNode): + if execute_node not in self._all_ariums: + self._all_ariums.append(execute_node) + return self + + def start_with(self, node: ExecutableNode | str) -> 'AriumBuilder': """Set the starting node for the Arium.""" + if isinstance(node, str): + # Search across all node types + all_nodes = self._agents + self._tools + self._ariums + self._foreach_nodes + resolved_node = next((n for n in all_nodes if n.name == node), None) + if not resolved_node: + raise ValueError(f"Node '{node}' not found") + node = resolved_node self._start_node = node return self - def end_with(self, node: Union[Agent, Tool]) -> 'AriumBuilder': + def end_with(self, node: ExecutableNode) -> 'AriumBuilder': """Add an ending node to the Arium.""" if node not in self._end_nodes: self._end_nodes.append(node) @@ -73,8 +146,8 @@ def end_with(self, node: Union[Agent, Tool]) -> 'AriumBuilder': def add_edge( self, - from_node: Union[Agent, Tool], - to_nodes: List[Union[Agent, Tool]], + from_node: ExecutableNode, + to_nodes: List[ExecutableNode], router: Optional[Callable] = None, ) -> 'AriumBuilder': """Add an edge between nodes with an optional router function.""" @@ -83,10 +156,29 @@ def add_edge( def connect( self, - from_node: Union[Agent, Tool], - to_node: Union[Agent, Tool], + from_node: ExecutableNode | 
str, + to_node: ExecutableNode | str, ) -> 'AriumBuilder': """Simple connection between two nodes without a router.""" + + if isinstance(from_node, str): + # Search across all node types + all_nodes = self._agents + self._tools + self._ariums + self._foreach_nodes + resolved_from_node = next( + (n for n in all_nodes if n.name == from_node), None + ) + if not resolved_from_node: + raise ValueError(f"Node '{from_node}' not found") + from_node = resolved_from_node + + if isinstance(to_node, str): + # Search across all node types + all_nodes = self._agents + self._tools + self._ariums + self._foreach_nodes + resolved_to_node = next((n for n in all_nodes if n.name == to_node), None) + if not resolved_to_node: + raise ValueError(f"Node '{to_node}' not found") + to_node = resolved_to_node + return self.add_edge(from_node, [to_node]) def build(self) -> Arium: @@ -102,6 +194,8 @@ def build(self) -> Arium: all_nodes = [] all_nodes.extend(self._agents) all_nodes.extend(self._tools) + all_nodes.extend(self._ariums) + all_nodes.extend(self._foreach_nodes) if not all_nodes: raise ValueError('No agents or tools added to the Arium') @@ -127,6 +221,11 @@ def build(self) -> Arium: for end_node in self._end_nodes: arium.add_end_to(end_node) + # Compile all Arium Nodes before compiling parent + for arium_node in self._all_ariums: + if not arium_node.arium.is_compiled: + arium_node.arium.compile() + # Compile the Arium arium.compile() @@ -157,6 +256,8 @@ def reset(self) -> 'AriumBuilder': self._memory = None self._agents = [] self._tools = [] + self._ariums = [] + self._foreach_nodes = [] self._start_node = None self._end_nodes = [] self._edges = [] @@ -263,8 +364,41 @@ def from_yaml( executor_agent: developer reviewer_agent: reviewer + # AriumNode definitions (nested Arium workflows) + ariums: + # Method 1: Inline nested Arium definition + - name: document_processor + inherit_variables: true # optional, default: true + agents: + - name: classifier + job: "Classify documents" + 
model: + provider: openai + name: gpt-4o-mini + - name: specialist + job: "Process classified documents" + model: + provider: openai + name: gpt-4o-mini + workflow: + start: classifier + edges: + - from: classifier + to: [specialist] + end: [specialist] + + # Method 2: External YAML file reference + - name: complex_processor + yaml_file: "workflows/document_classifier.yaml" + inherit_variables: false + + # ForEachNode definitions + iterators: + - name: batch_processor + execute_node: document_processor # Can reference any node type + workflow: - start: content_analyst + start: batch_processor # Can reference any node type including foreach/arium nodes edges: - from: content_analyst to: [validator, summarizer] @@ -480,20 +614,139 @@ def from_yaml( all_routers.update(routers) all_routers.update(yaml_routers) + # Process AriumNodes (nested Arium workflows) + arium_nodes_config = arium_config.get('ariums', []) + arium_nodes_dict = {} + + for arium_node_config in arium_nodes_config: + node_name = arium_node_config['name'] + inherit_vars = arium_node_config.get('inherit_variables', True) + + # Method 1: External YAML file reference + if 'yaml_file' in arium_node_config: + yaml_file_path = arium_node_config['yaml_file'] + + nested_builder = cls.from_yaml( + yaml_file=yaml_file_path, + memory=None, + agents=None, + tools=tools, # Nested can use parent's tools + routers=None, + base_llm=base_llm, + ) + nested_arium = nested_builder.build() + + # Method 2: Inline definition + else: + # Build sub-config from inline definition + sub_config = { + 'arium': { + 'agents': arium_node_config.get('agents', []), + 'tools': arium_node_config.get('tools', []), + 'routers': arium_node_config.get('routers', []), + 'ariums': arium_node_config.get( + 'ariums', [] + ), # Support nesting! 
+ 'iterators': arium_node_config.get('iterators', []), + 'workflow': arium_node_config['workflow'], + } + } + + nested_builder = cls.from_yaml( + yaml_str=yaml.dump(sub_config), + memory=None, + agents=None, + tools=tools, + routers=None, + base_llm=base_llm, + ) + nested_arium = nested_builder.build() + + # Wrap in AriumNode + arium_node = AriumNode( + name=node_name, arium=nested_arium, inherit_variables=inherit_vars + ) + + arium_nodes_dict[node_name] = arium_node + builder._all_ariums.append(arium_node) + # Don't add to builder yet - will add during workflow processing if actually used + + # Process ForEachNodes (store configs, resolve later) + foreach_nodes_config = arium_config.get('iterators', []) + foreach_nodes_dict = {} + + for foreach_config in foreach_nodes_config: + foreach_name = foreach_config['name'] + execute_node_name = foreach_config['execute_node'] + + foreach_nodes_dict[foreach_name] = { + 'name': foreach_name, + 'execute_node_name': execute_node_name, + } + + # Resolve ForEachNode references now that all nodes exist + for foreach_name, foreach_config in foreach_nodes_dict.items(): + execute_node_name = foreach_config['execute_node_name'] + + # Find execute_node from ALL node types + execute_node = ( + agents_dict.get(execute_node_name) + or tools_dict.get(execute_node_name) + or arium_nodes_dict.get(execute_node_name) + or foreach_nodes_dict.get(execute_node_name) + ) + + if not execute_node: + all_nodes = ( + list(agents_dict.keys()) + + list(tools_dict.keys()) + + list(arium_nodes_dict.keys()) + + list(foreach_nodes_dict.keys()) + ) + raise ValueError( + f"ForEachNode '{foreach_name}': execute_node '{execute_node_name}' not found. 
" + f'Available nodes: {all_nodes}' + ) + + # Create ForEachNode + foreach_node = ForEachNode(name=foreach_name, execute_node=execute_node) + + foreach_nodes_dict[foreach_name] = foreach_node + builder._foreach_nodes.append(foreach_node) + # Process workflow workflow_config = arium_config.get('workflow', {}) + # Helper function to find node from all sources + def _find_node(node_name: str): + return ( + agents_dict.get(node_name) + or tools_dict.get(node_name) + or arium_nodes_dict.get(node_name) + or foreach_nodes_dict.get(node_name) + ) + # Set start node start_node_name = workflow_config.get('start') if not start_node_name: raise ValueError('Workflow must specify a start node') - start_node = agents_dict.get(start_node_name) or tools_dict.get(start_node_name) + start_node = _find_node(start_node_name) if not start_node: + all_available = ( + list(agents_dict.keys()) + + list(tools_dict.keys()) + + list(arium_nodes_dict.keys()) + + list(foreach_nodes_dict.keys()) + ) raise ValueError( - f'Start node {start_node_name} not found in agents or tools' + f'Start node {start_node_name} not found. 
Available nodes: {all_available}' ) + # Add AriumNode to builder if it's used in main workflow + if isinstance(start_node, AriumNode): + builder._ariums.append(start_node) + builder.start_with(start_node) # Process edges @@ -505,12 +758,14 @@ def from_yaml( router_name = edge_config.get('router') # Find from node - from_node = agents_dict.get(from_node_name) or tools_dict.get( - from_node_name - ) + from_node = _find_node(from_node_name) if not from_node: raise ValueError(f'From node {from_node_name} not found') + # Add AriumNode to builder if it's used in main workflow and not already added + if isinstance(from_node, AriumNode) and from_node not in builder._ariums: + builder._ariums.append(from_node) + # Find to nodes (handle special 'end' case) to_nodes = [] for to_node_name in to_nodes_names: @@ -518,9 +773,14 @@ def from_yaml( # 'end' will be handled in end nodes processing continue - to_node = agents_dict.get(to_node_name) or tools_dict.get(to_node_name) + to_node = _find_node(to_node_name) if not to_node: raise ValueError(f'To node {to_node_name} not found') + + # Add AriumNode to builder if it's used in main workflow and not already added + if isinstance(to_node, AriumNode) and to_node not in builder._ariums: + builder._ariums.append(to_node) + to_nodes.append(to_node) # Find router function @@ -544,9 +804,14 @@ def from_yaml( raise ValueError('Workflow must specify end nodes') for end_node_name in end_nodes_names: - end_node = agents_dict.get(end_node_name) or tools_dict.get(end_node_name) + end_node = _find_node(end_node_name) if not end_node: raise ValueError(f'End node {end_node_name} not found') + + # Add AriumNode to builder if it's used in main workflow and not already added + if isinstance(end_node, AriumNode) and end_node not in builder._ariums: + builder._ariums.append(end_node) + builder.end_with(end_node) return builder diff --git a/flo_ai/flo_ai/arium/nodes.py b/flo_ai/flo_ai/arium/nodes.py new file mode 100644 index 00000000..fb6a2ffb --- 
/dev/null +++ b/flo_ai/flo_ai/arium/nodes.py @@ -0,0 +1,95 @@ +from flo_ai.arium.protocols import ExecutableNode +from typing import List, Any, Dict, Optional, TYPE_CHECKING +from flo_ai.utils.logger import logger + +if TYPE_CHECKING: # need to have an optional import else will get circular dependency error as arium also has AriumNode reference + from flo_ai.arium.arium import Arium + + +class AriumNode: + """ + Wrapper to use an Arium as a node in another Arium workflow. + """ + + def __init__(self, name: str, arium: 'Arium', inherit_variables: bool = True): + """ + Args: + name: Name for this node in the parent workflow + arium: The Arium workflow to execute + inherit_variables: Whether to pass parent variables to sub-workflow + """ + self.name = name + self.arium = arium + self.inherit_variables = inherit_variables + + async def run( + self, inputs: List[Any], variables: Optional[Dict[str, Any]] = None, **kwargs + ) -> Any: + """Execute the nested Arium workflow with isolated memory""" + + # Handle variable inheritance + execution_variables = ( + variables.copy() if (self.inherit_variables and variables) else None + ) + + # Execute the nested Arium with isolated memory + result = await self.arium.run( + inputs=inputs, + variables=execution_variables, + ) + return result + + +class ForEachNode: + """ + Execute a node on each item in a collection. + + Supports only sequential execution for now. 
(parallel execution would be supported in future) + """ + + def __init__(self, name: str, execute_node: ExecutableNode): + """ + Args: + name: Node name + execute_node: Node to execute on each item + """ + self.name = name + self.execute_node = execute_node + + async def _execute_item( + self, + item: Any, + index: int, + variables: Optional[Dict[str, Any]] = None, + ) -> Any: + """Execute the node on a single item""" + logger.info(f"ForEach '{self.name}': Processing item {index + 1}") + + # Create execution variables with item context + item_variables = (variables or {}).copy() + + # Execute the node + result = await self.execute_node.run( + inputs=[item], + variables=item_variables, + ) + + # Return last item if result is a list, otherwise return as-is + if isinstance(result, list) and result: + return result[-1] + return result + + async def run( + self, inputs: List[Any], variables: Optional[Dict[str, Any]] = None, **kwargs + ) -> List[Any]: + """Execute the node on all items""" + + # Sequential execution + results = [] + for i, item in enumerate(inputs): + result = await self._execute_item(item, i, variables) + results.append(result) + + logger.info(f"ForEach '{self.name}': Completed processing {len(results)} items") + + return results diff --git a/flo_ai/flo_ai/arium/protocols.py b/flo_ai/flo_ai/arium/protocols.py new file mode 100644 index 00000000..9b4602c1 --- /dev/null +++ b/flo_ai/flo_ai/arium/protocols.py @@ -0,0 +1,37 @@ +from typing import Protocol, runtime_checkable, List, Any, Dict, Optional + + +@runtime_checkable +class ExecutableNode(Protocol): + """ + Protocol defining the interface for any node that can be executed + within an Arium workflow. + + Any class implementing this protocol can be used as a node: + - Agent (already implements) + - Tool (already implements) + - Arium (already implements!) 
+ - Custom node types + """ + + name: str + """Unique identifier for the node""" + + async def run( + self, + inputs: List[str | Any], + variables: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """ + Execute the node and return results. + + Args: + inputs: Input data for execution + variables: Optional variable substitutions + **kwargs: Additional execution parameters + + Returns: + Execution result (type depends on node implementation) + """ + ... diff --git a/flo_ai/flo_ai/models/agent.py b/flo_ai/flo_ai/models/agent.py index 0cd9243f..e970b504 100644 --- a/flo_ai/flo_ai/models/agent.py +++ b/flo_ai/flo_ai/models/agent.py @@ -223,7 +223,10 @@ async def _run_with_tools( function_args = function_call['arguments'] tool = self.tools_dict[function_name] - function_response = await tool.execute(**function_args) + # function_response = await tool.execute(**function_args) + function_response = await tool.run( + inputs=[], variables=None, **function_args + ) tool_call_count += 1 # Add function call to history diff --git a/flo_ai/flo_ai/tool/base_tool.py b/flo_ai/flo_ai/tool/base_tool.py index a2b338d7..bdcb64c2 100644 --- a/flo_ai/flo_ai/tool/base_tool.py +++ b/flo_ai/flo_ai/tool/base_tool.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Callable +from typing import Dict, Any, Callable, List, Optional from flo_ai.models.agent_error import AgentError from flo_ai.utils.logger import logger @@ -40,3 +40,10 @@ async def execute(self, **kwargs) -> Any: raise ToolExecutionError( f'Error executing tool {self.name}: {str(e)}', original_error=e ) + + async def run( + self, inputs: List[Any], variables: Optional[Dict[str, Any]] = None, **kwargs + ) -> Any: + # Ignore inputs and variables + # Just pass kwargs to execute() + return await self.execute(**kwargs)