From 7eace251c408a8a8bbae39d3aa63dcf31c6af8df Mon Sep 17 00:00:00 2001 From: vishnu r kumar Date: Thu, 30 Oct 2025 21:53:58 +0530 Subject: [PATCH 1/7] feat: add tool node and upgrade message memory with input filters --- flo_ai/examples/arium_examples.py | 84 ++++++++++++++----- flo_ai/flo_ai/arium/arium.py | 40 ++++----- flo_ai/flo_ai/arium/base.py | 4 +- flo_ai/flo_ai/arium/builder.py | 100 +++++++++++++---------- flo_ai/flo_ai/arium/memory.py | 129 +++++++++++++++++++++++++----- flo_ai/flo_ai/arium/models.py | 4 +- flo_ai/flo_ai/arium/nodes.py | 46 ++++++++++- flo_ai/flo_ai/arium/protocols.py | 2 + flo_ai/flo_ai/models/agent.py | 2 + 9 files changed, 305 insertions(+), 106 deletions(-) diff --git a/flo_ai/examples/arium_examples.py b/flo_ai/examples/arium_examples.py index cbdacc44..73237dbc 100644 --- a/flo_ai/examples/arium_examples.py +++ b/flo_ai/examples/arium_examples.py @@ -5,7 +5,7 @@ from typing import Literal from flo_ai.arium import AriumBuilder, create_arium from flo_ai.models.agent import Agent -from flo_ai.tool.base_tool import Tool +from flo_ai.arium.nodes import ToolNode from flo_ai.arium.memory import MessageMemory @@ -15,18 +15,18 @@ async def example_linear_workflow(): # Create some example agents and tools (these would be your actual implementations) analyzer_agent = Agent(name='analyzer', prompt='Analyze the input') - processing_tool = Tool(name='processor') + processing_tool_node = ToolNode(name='processor', description='Process the input', function=lambda x: x) summarizer_agent = Agent(name='summarizer', prompt='Summarize the results') # Build and run the workflow result = await ( AriumBuilder() .add_agent(analyzer_agent) - .add_tool(processing_tool) + .add_tool_node(processing_tool_node) .add_agent(summarizer_agent) .start_with(analyzer_agent) - .connect(analyzer_agent, processing_tool) - .connect(processing_tool, summarizer_agent) + .connect(analyzer_agent, processing_tool_node) + .connect(processing_tool_node, summarizer_agent) 
.end_with(summarizer_agent) .build_and_run(['Analyze this text']) ) @@ -40,8 +40,8 @@ async def example_branching_workflow(): # Create agents and tools classifier_agent = Agent(name='classifier', prompt='Classify the input type') - text_processor = Tool(name='text_processor') - image_processor = Tool(name='image_processor') + text_processor_node = ToolNode(name='text_processor', description='Process text', function=lambda x: x) + image_processor_node = ToolNode(name='image_processor', description='Process image', function=lambda x: x) final_agent = Agent(name='final', prompt='Provide final response') # Router function for conditional branching @@ -56,20 +56,18 @@ def content_router(memory) -> Literal['text_processor', 'image_processor']: result = await ( AriumBuilder() .add_agent(classifier_agent) - .add_tool(text_processor) - .add_tool(image_processor) .add_agent(final_agent) + .add_tool_node(ToolNode(name='text_processor', description='Process text', function=lambda x: x)) + .add_tool_node(ToolNode(name='image_processor', description='Process image', function=lambda x: x)) .start_with(classifier_agent) - .add_edge(classifier_agent, [text_processor, image_processor], content_router) - .connect(text_processor, final_agent) - .connect(image_processor, final_agent) + .add_edge(classifier_agent, [text_processor_node, image_processor_node], content_router) + .connect(text_processor_node, final_agent) + .connect(image_processor_node, final_agent) .end_with(final_agent) .build_and_run(['Process this content']) ) - return result - # Example 3: Complex Multi-Agent Workflow async def example_complex_workflow(): """Example of a more complex workflow with multiple agents and tools""" @@ -80,8 +78,8 @@ async def example_complex_workflow(): analyzer_agent = Agent(name='analyzer', prompt='Analyze findings') writer_agent = Agent(name='writer', prompt='Write the final report') - search_tool = Tool(name='search_tool') - data_tool = Tool(name='data_processor') + search_tool_node = 
ToolNode(name='search_tool', description='Search the web', function=lambda x: x) + data_tool_node = ToolNode(name='data_processor', description='Process the data', function=lambda x: x) # Router for deciding next step after analysis def analysis_router(memory) -> Literal['writer', 'researcher']: @@ -93,13 +91,13 @@ def analysis_router(memory) -> Literal['writer', 'researcher']: arium = ( AriumBuilder() .add_agents([input_agent, researcher_agent, analyzer_agent, writer_agent]) - .add_tools([search_tool, data_tool]) + .add_tool_nodes([search_tool_node, data_tool_node]) .with_memory(MessageMemory()) .start_with(input_agent) .connect(input_agent, researcher_agent) - .connect(researcher_agent, search_tool) - .connect(search_tool, data_tool) - .connect(data_tool, analyzer_agent) + .connect(researcher_agent, search_tool_node) + .connect(search_tool_node, data_tool_node) + .connect(data_tool_node, analyzer_agent) .add_edge(analyzer_agent, [writer_agent, researcher_agent], analysis_router) .end_with(writer_agent) .build() @@ -151,6 +149,49 @@ async def example_build_and_reuse(): return result1, result2 +# Example 6: Four ToolNodes with input filtering (no agents) +async def example_tool_nodes_with_filters(): + """Workflow of only ToolNodes; each uses input_filter to read from specific nodes.""" + + # Define simple tool functions + async def pass_through(inputs=None, variables=None, **kwargs): + return inputs + + async def capitalize_last(inputs=None, variables=None, **kwargs): + if not inputs: + return 'No inputs' + last = str(inputs[-1]) + return last.capitalize() + + async def uppercase_all(inputs=None, variables=None, **kwargs): + if not inputs: + return 'No inputs' + return ' '.join([str(x).upper() for x in inputs]) + + async def summarize(inputs=None, variables=None, **kwargs): + return f"count={len(inputs or [])} last={(str(inputs[-1]) if inputs else '')}" + + # Create four ToolNodes with input filters + t1 = ToolNode(name='tool1', description='reads initial inputs', 
function=pass_through, input_filter=['input']) + t2 = ToolNode(name='tool2', description='reads tool1 only', function=capitalize_last, input_filter=['tool1']) + t3 = ToolNode(name='tool3', description='reads tool2 only', function=uppercase_all, input_filter=['tool2']) + t4 = ToolNode(name='tool4', description='reads tool1 & tool3', function=summarize, input_filter=['tool1', 'tool3']) + + # Build and run: tool1 -> tool2 -> tool3 -> tool4 + result = await ( + AriumBuilder() + .with_memory(MessageMemory()) + .add_tool_nodes([t1, t2, t3, t4]) + .start_with(t1) + .connect(t1, t2) + .connect(t2, t3) + .connect(t3, t4) + .end_with(t4) + .build_and_run(['hello world']) + ) + + return result + if __name__ == '__main__': import asyncio @@ -164,6 +205,9 @@ async def main(): # result3 = await example_complex_workflow() # result4 = await example_convenience_function() # result5 = await example_build_and_reuse() + result6 = await example_tool_nodes_with_filters() + + print(result6) print('Examples completed!') diff --git a/flo_ai/flo_ai/arium/arium.py b/flo_ai/flo_ai/arium/arium.py index cc969e71..9be33404 100644 --- a/flo_ai/flo_ai/arium/arium.py +++ b/flo_ai/flo_ai/arium/arium.py @@ -4,10 +4,9 @@ from flo_ai.models.document import DocumentMessage from typing import List, Dict, Any, Optional, Callable from flo_ai.models.agent import Agent -from flo_ai.tool.base_tool import Tool from flo_ai.arium.models import StartNode, EndNode from flo_ai.arium.events import AriumEventType, AriumEvent -from flo_ai.arium.nodes import AriumNode, ForEachNode +from flo_ai.arium.nodes import AriumNode, ForEachNode, ToolNode from flo_ai.utils.logger import logger from flo_ai.utils.variable_extractor import ( extract_variables_from_inputs, @@ -247,11 +246,11 @@ async def _execute_graph( if isinstance(result, List): # for each node will give results array for item in result: - # update each item in result to memory - self._add_to_memory(item) + # update each item in result to memory using new schema + 
self._add_to_memory({'node': current_node.name, 'output': item}) else: - # update results to memory - self._add_to_memory(result) + # update results to memory using new schema + self._add_to_memory({'node': current_node.name, 'output': result}) # find next node post current node # Prepare execution context for router functions @@ -383,7 +382,7 @@ def _resolve_agent_prompts(self, variables: Dict[str, Any]) -> None: async def _execute_node( self, - node: Agent | Tool | StartNode | EndNode, + node: Agent | ToolNode | ForEachNode | AriumNode | StartNode | EndNode, event_callback: Optional[Callable[[AriumEvent], None]] = None, events_filter: Optional[List[AriumEventType]] = None, variables: Optional[Dict[str, Any]] = None, @@ -402,7 +401,7 @@ async def _execute_node( # Determine node type for events if isinstance(node, Agent): node_type = 'agent' - elif isinstance(node, Tool): + elif isinstance(node, ToolNode): node_type = 'tool' elif isinstance(node, ForEachNode): node_type = 'foreach' @@ -429,6 +428,8 @@ async def _execute_node( # Start node telemetry tracing tracer = get_tracer() + memory_items = self.memory.get(getattr(node, 'input_filter', None)) if getattr(node, 'input_filter', None) else self.memory.get() + inputs = [item['output'] for item in memory_items] if tracer and node_type not in ['start', 'end']: with tracer.start_as_current_span( @@ -441,20 +442,21 @@ async def _execute_node( ) as node_span: try: # Execute the node based on its type + if isinstance(node, Agent): # Variables are already resolved, pass empty dict to avoid re-processing - result = await node.run(self.memory.get(), variables={}) - elif isinstance(node, Tool): - result = await node.run(inputs=[], variables=None) + result = await node.run(inputs, variables={}) + elif isinstance(node, ToolNode): + result = await node.run(inputs, variables=None) elif isinstance(node, ForEachNode): result = await node.run( - inputs=self.memory.get(), + inputs, variables=variables, ) elif isinstance(node, 
AriumNode): # AriumNode execution result = await node.run( - inputs=self.memory.get(), variables=variables + inputs, variables=variables ) elif isinstance(node, StartNode): result = None @@ -526,18 +528,16 @@ async def _execute_node( try: # Execute the node based on its type if isinstance(node, Agent): - result = await node.run(self.memory.get(), variables={}) - elif isinstance(node, Tool): - result = await node.run(inputs=[], variables=None) + result = await node.run(inputs, variables={}) + elif isinstance(node, ToolNode): + result = await node.run(inputs, variables=None) elif isinstance(node, ForEachNode): result = await node.run( - inputs=self.memory.get(), + inputs, variables=variables, ) elif isinstance(node, AriumNode): - result = await node.run( - inputs=self.memory.get(), variables=variables - ) + result = await node.run(inputs, variables=variables) elif isinstance(node, StartNode): result = None elif isinstance(node, EndNode): diff --git a/flo_ai/flo_ai/arium/base.py b/flo_ai/flo_ai/arium/base.py index 258dd850..e67b71af 100644 --- a/flo_ai/flo_ai/arium/base.py +++ b/flo_ai/flo_ai/arium/base.py @@ -14,8 +14,8 @@ class BaseArium: def __init__(self): self.start_node_name = '__start__' self.end_node_names: set = set() # Support multiple end nodes - self.nodes: Dict[str, ExecutableNode | StartNode | EndNode] = dict() - self.edges: Dict[str, Edge] = dict() + self.nodes: Dict[str, ExecutableNode | StartNode | EndNode] = dict[str, ExecutableNode | StartNode | EndNode]() + self.edges: Dict[str, Edge] = dict[str, Edge]() def add_nodes(self, agents: List[ExecutableNode | StartNode | EndNode]): self.nodes.update({agent.name: agent for agent in agents}) diff --git a/flo_ai/flo_ai/arium/builder.py b/flo_ai/flo_ai/arium/builder.py index cc66bdb8..4c636917 100644 --- a/flo_ai/flo_ai/arium/builder.py +++ b/flo_ai/flo_ai/arium/builder.py @@ -2,7 +2,7 @@ from flo_ai.arium.arium import Arium from flo_ai.arium.memory import MessageMemory, BaseMemory from 
flo_ai.arium.protocols import ExecutableNode -from flo_ai.arium.nodes import AriumNode, ForEachNode +from flo_ai.arium.nodes import AriumNode, ForEachNode, ToolNode from flo_ai.models.agent import Agent from flo_ai.tool.base_tool import Tool from flo_ai.llm.base_llm import ImageMessage @@ -21,23 +21,23 @@ class AriumBuilder: result = (AriumBuilder() .with_memory(my_memory) .add_agent(agent1) - .add_tool(tool1) + .add_tool_node(tool_node1) .start_with(agent1) - .add_edge(agent1, [tool1], router_fn) - .end_with(tool1) + .add_edge(agent1, [tool_node1], router_fn) + .end_with(tool_node1) .build_and_run(["Hello, world!"])) """ def __init__(self): self._memory: Optional[BaseMemory] = None self._agents: List[Agent] = [] - self._tools: List[Tool] = [] self._ariums: List[ AriumNode ] = [] # only those ariums which are part of main workflow self._foreach_nodes: List[ForEachNode] = [] self._start_node: Optional[ExecutableNode] = None self._end_nodes: List[ExecutableNode] = [] + self._tool_nodes: List[ToolNode] = [] self._edges: List[tuple] = [] # (from_node, to_nodes, router) self._arium: Optional[Arium] = None self._all_ariums: List[ @@ -59,14 +59,14 @@ def add_agents(self, agents: List[Agent]) -> 'AriumBuilder': self._agents.extend(agents) return self - def add_tool(self, tool: Tool) -> 'AriumBuilder': - """Add a tool to the Arium.""" - self._tools.append(tool) + def add_tool_node(self, tool_node: ToolNode) -> 'AriumBuilder': + """Add a tool node to the Arium.""" + self._tool_nodes.append(tool_node) return self - def add_tools(self, tools: List[Tool]) -> 'AriumBuilder': - """Add multiple tools to the Arium.""" - self._tools.extend(tools) + def add_tool_nodes(self, tool_nodes: List[ToolNode]) -> 'AriumBuilder': + """Add multiple tool nodes to the Arium.""" + self._tool_nodes.extend(tool_nodes) return self def add_arium( @@ -112,7 +112,7 @@ def add_foreach( # Resolve node reference if string name provided if isinstance(execute_node, str): # Search across all node types - 
all_nodes = self._agents + self._tools + self._ariums + self._foreach_nodes + all_nodes = self._agents + self._tool_nodes + self._ariums + self._foreach_nodes resolved_node = next((n for n in all_nodes if n.name == execute_node), None) if not resolved_node: raise ValueError(f"Node '{execute_node}' not found") @@ -130,7 +130,7 @@ def start_with(self, node: ExecutableNode | str) -> 'AriumBuilder': """Set the starting node for the Arium.""" if isinstance(node, str): # Search across all node types - all_nodes = self._agents + self._tools + self._ariums + self._foreach_nodes + all_nodes = self._agents + self._tool_nodes + self._ariums + self._foreach_nodes resolved_node = next((n for n in all_nodes if n.name == node), None) if not resolved_node: raise ValueError(f"Node '{node}' not found") @@ -163,7 +163,7 @@ def connect( if isinstance(from_node, str): # Search across all node types - all_nodes = self._agents + self._tools + self._ariums + self._foreach_nodes + all_nodes = self._agents + self._tool_nodes + self._ariums + self._foreach_nodes resolved_from_node = next( (n for n in all_nodes if n.name == from_node), None ) @@ -173,7 +173,7 @@ def connect( if isinstance(to_node, str): # Search across all node types - all_nodes = self._agents + self._tools + self._ariums + self._foreach_nodes + all_nodes = self._agents + self._tool_nodes + self._ariums + self._foreach_nodes resolved_to_node = next((n for n in all_nodes if n.name == to_node), None) if not resolved_to_node: raise ValueError(f"Node '{to_node}' not found") @@ -193,12 +193,12 @@ def build(self) -> Arium: # Add all nodes all_nodes = [] all_nodes.extend(self._agents) - all_nodes.extend(self._tools) + all_nodes.extend(self._tool_nodes) all_nodes.extend(self._ariums) all_nodes.extend(self._foreach_nodes) if not all_nodes: - raise ValueError('No agents or tools added to the Arium') + raise ValueError('No agents or tool nodes added to the Arium') arium.add_nodes(all_nodes) @@ -255,7 +255,7 @@ def reset(self) -> 
'AriumBuilder': """Reset the builder to start fresh.""" self._memory = None self._agents = [] - self._tools = [] + self._tool_nodes = [] self._ariums = [] self._foreach_nodes = [] self._start_node = None @@ -271,7 +271,7 @@ def from_yaml( yaml_file: Optional[str] = None, memory: Optional[BaseMemory] = None, agents: Optional[Dict[str, Agent]] = None, - tools: Optional[Dict[str, Tool]] = None, + tool_nodes: Optional[Dict[str, ToolNode]] = None, routers: Optional[Dict[str, Callable]] = None, base_llm: Optional[BaseLLM] = None, ) -> 'AriumBuilder': @@ -282,7 +282,7 @@ def from_yaml( yaml_file: Path to YAML file containing arium configuration memory: Memory instance to use for the workflow (defaults to MessageMemory) agents: Dictionary mapping agent names to pre-built Agent instances - tools: Dictionary mapping tool names to Tool instances + tool_nodes: Dictionary mapping tool names to ToolNode instances routers: Dictionary mapping router names to router functions base_llm: Base LLM to use for all agents if not specified in individual agent configs @@ -325,7 +325,7 @@ def from_yaml( - name: reporter yaml_file: "path/to/reporter.yaml" - tools: + tool_nodes: - name: tool1 - name: tool2 @@ -464,7 +464,7 @@ def from_yaml( and 'yaml_file' not in agent_config ): agent = cls._create_agent_from_direct_config( - agent_config, base_llm, tools + agent_config, base_llm ) # Method 3: Inline YAML config @@ -493,24 +493,40 @@ def from_yaml( agents_dict[agent_name] = agent builder.add_agent(agent) - # Process tools - tools_config = arium_config.get('tools', []) - tools_dict = {} + # Process tool nodes + tool_nodes_config = arium_config.get('tool_nodes', []) + tool_nodes_dict = {} - for tool_config in tools_config: - tool_name = tool_config['name'] + for tool_node_config in tool_nodes_config: + tool_node_name = tool_node_config['name'] - # Look up tool in provided tools dictionary - if tools and tool_name in tools: - tool = tools[tool_name] - tools_dict[tool_name] = tool - 
builder.add_tool(tool) + #Add a tool node from a pre-built tool node + if len(tool_node_config) == 1 and 'name' in tool_node_config: + if tool_nodes and tool_node_name in tool_nodes: + tool_node = tool_nodes[tool_node_name] + else: + raise ValueError( + f"ToolNode {tool_node_name} not found in provided tool_nodes dictionary. " + f"Available tool_nodes: {list(tool_nodes.keys()) if tool_nodes else []}. " + f"Either provide the ToolNode in the tool_nodes parameter or add configuration fields." + ) else: - raise ValueError( - f'Tool {tool_name} not found in provided tools dictionary. ' - f'Available tools: {list(tools.keys()) if tools else []}' + # Add a tool node from a direct ToolNode definition (function must be provided in code, YAML cannot define callables) + function = tool_node_config.get('function') + if function is None: + # Fallback: identity passthrough if not provided + function = (lambda inputs=None, variables=None, **kwargs: inputs) + + tool_node = ToolNode( + name=tool_node_name, + description=tool_node_config.get('description', ''), + function=function, + input_filter=tool_node_config.get('input_filter', None), ) + tool_nodes_dict[tool_node_name] = tool_node + builder.add_tool_node(tool_node) + # Process LLM routers (if defined in YAML) routers_config = arium_config.get('routers', []) yaml_routers = {} # Store routers created from YAML config @@ -630,7 +646,7 @@ def from_yaml( yaml_file=yaml_file_path, memory=None, agents=None, - tools=tools, # Nested can use parent's tools + tool_nodes=None, routers=None, base_llm=base_llm, ) @@ -642,7 +658,7 @@ def from_yaml( sub_config = { 'arium': { 'agents': arium_node_config.get('agents', []), - 'tools': arium_node_config.get('tools', []), + 'tool_nodes': arium_node_config.get('tool_nodes', []), 'routers': arium_node_config.get('routers', []), 'ariums': arium_node_config.get( 'ariums', [] @@ -656,7 +672,7 @@ def from_yaml( yaml_str=yaml.dump(sub_config), memory=None, agents=None, - tools=tools, + tool_nodes=None, 
routers=None, base_llm=base_llm, ) @@ -691,7 +707,7 @@ def from_yaml( # Find execute_node from ALL node types execute_node = ( agents_dict.get(execute_node_name) - or tools_dict.get(execute_node_name) + or tool_nodes_dict.get(execute_node_name) or arium_nodes_dict.get(execute_node_name) or foreach_nodes_dict.get(execute_node_name) ) @@ -699,7 +715,7 @@ def from_yaml( if not execute_node: all_nodes = ( list(agents_dict.keys()) - + list(tools_dict.keys()) + + list(tool_nodes_dict.keys()) + list(arium_nodes_dict.keys()) + list(foreach_nodes_dict.keys()) ) @@ -721,7 +737,7 @@ def from_yaml( def _find_node(node_name: str): return ( agents_dict.get(node_name) - or tools_dict.get(node_name) + or tool_nodes_dict.get(node_name) or arium_nodes_dict.get(node_name) or foreach_nodes_dict.get(node_name) ) @@ -735,7 +751,7 @@ def _find_node(node_name: str): if not start_node: all_available = ( list(agents_dict.keys()) - + list(tools_dict.keys()) + + list(tool_nodes_dict.keys()) + list(arium_nodes_dict.keys()) + list(foreach_nodes_dict.keys()) ) diff --git a/flo_ai/flo_ai/arium/memory.py b/flo_ai/flo_ai/arium/memory.py index 783ca4ce..15340f27 100644 --- a/flo_ai/flo_ai/arium/memory.py +++ b/flo_ai/flo_ai/arium/memory.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import TypeVar, Generic, List, Dict, Optional, Any +import json from dataclasses import dataclass, field from enum import Enum @@ -92,7 +93,7 @@ def add(self, m: T): pass @abstractmethod - def get(self) -> List[T]: + def get(self, include_nodes: Optional[List[str]] = None) -> List[T]: pass # Plan management methods (optional - only implemented by memory classes that support plans) @@ -113,32 +114,124 @@ def get_plan(self, plan_id: str) -> Optional[ExecutionPlan]: return None -class MessageMemory(BaseMemory[Dict[str, str]]): +class MessageMemory(BaseMemory[Dict[str, Any]]): def __init__(self): - self.messages = [] - - def add(self, message: Dict[str, str]): - self.messages.append(message) - - def 
get(self) -> List[Dict[str, str]]: - return self.messages - - -class PlanAwareMemory(BaseMemory[Dict[str, str]]): + self.messages: List[Dict[str, Any]] = [] + self._node_occurrences: Dict[str, int] = {} + + def _next_occurrence(self, node: str) -> int: + current = self._node_occurrences.get(node, 0) + 1 + self._node_occurrences[node] = current + return current + + def _normalize(self, message: Any) -> Optional[Dict[str, Any]]: + if message is None: + return None + + # New schema without occurrence: { 'node': str, 'output': Any } + if isinstance(message, dict) and {"node", "output"}.issubset(message.keys()) and "occurrence" not in message: + node = str(message["node"]) if message.get("node") is not None else "unknown" + occurrence = self._next_occurrence(node) + return {"node": node, "occurrence": occurrence, "output": message.get("output")} + + # Legacy schema produced by Arium: { 'node_name': str, 'result': Any } + if isinstance(message, dict) and { + "node_name", + "result", + }.issubset(message.keys()): + node = str(message["node_name"]) if message.get("node_name") is not None else "unknown" + occurrence = self._next_occurrence(node) + return {"node": node, "occurrence": occurrence, "output": message.get("result")} + + # Raw string input + if isinstance(message, str): + node = "input" + occurrence = self._next_occurrence(node) + return {"node": node, "occurrence": occurrence, "output": message} + + # Generic object: try to_dict, else str + if hasattr(message, "to_dict") and callable(getattr(message, "to_dict")): + node = "input" + occurrence = self._next_occurrence(node) + try: + data = message.to_dict() + except Exception: + data = str(message) + return {"node": node, "occurrence": occurrence, "output": data} + + # Fallback + node = "input" + occurrence = self._next_occurrence(node) + return {"node": node, "occurrence": occurrence, "output": message} + + def add(self, message: Any): + normalized = self._normalize(message) + if normalized is not None: + 
self.messages.append(normalized) + + def get(self, include_nodes: Optional[List[str]] = None) -> List[Dict[str, Any]]: + if not include_nodes: + return self.messages + include = set(include_nodes) + return [m for m in self.messages if isinstance(m, dict) and m.get('node') in include] + + +class PlanAwareMemory(BaseMemory[Dict[str, Any]]): """Enhanced memory that supports both messages and execution plans""" def __init__(self): - self.messages = [] + self.messages: List[Dict[str, Any]] = [] self.plans: Dict[str, ExecutionPlan] = {} self.current_plan_id: Optional[str] = None - - def add(self, message: Dict[str, str]): + self._node_occurrences: Dict[str, int] = {} + + def _next_occurrence(self, node: str) -> int: + current = self._node_occurrences.get(node, 0) + 1 + self._node_occurrences[node] = current + return current + + def _normalize(self, message: Any) -> Optional[Dict[str, Any]]: + if message is None: + return None + if isinstance(message, dict) and {"node", "occurrence", "output"}.issubset(message.keys()): + return message + # New schema without occurrence: { 'node': str, 'output': Any } + if isinstance(message, dict) and {"node", "output"}.issubset(message.keys()) and "occurrence" not in message: + node = str(message["node"]) if message.get("node") is not None else "unknown" + occurrence = self._next_occurrence(node) + return {"node": node, "occurrence": occurrence, "output": message.get("output")} + if isinstance(message, dict) and {"node_name", "result"}.issubset(message.keys()): + node = str(message["node_name"]) if message.get("node_name") is not None else "unknown" + occurrence = self._next_occurrence(node) + return {"node": node, "occurrence": occurrence, "output": message.get("result")} + if isinstance(message, str): + node = "input" + occurrence = self._next_occurrence(node) + return {"node": node, "occurrence": occurrence, "output": message} + if hasattr(message, "to_dict") and callable(getattr(message, "to_dict")): + node = "input" + occurrence = 
self._next_occurrence(node) + try: + data = message.to_dict() + except Exception: + data = str(message) + return {"node": node, "occurrence": occurrence, "output": data} + node = "input" + occurrence = self._next_occurrence(node) + return {"node": node, "occurrence": occurrence, "output": message} + + def add(self, message: Any): """Add a message to memory""" - self.messages.append(message) + normalized = self._normalize(message) + if normalized is not None: + self.messages.append(normalized) - def get(self) -> List[Dict[str, str]]: + def get(self, include_nodes: Optional[List[str]] = None) -> List[Dict[str, Any]]: """Get all messages""" - return self.messages + if not include_nodes: + return self.messages + include = set(include_nodes) + return [m for m in self.messages if isinstance(m, dict) and m.get('node') in include] # Plan management methods def add_plan(self, plan: ExecutionPlan): diff --git a/flo_ai/flo_ai/arium/models.py b/flo_ai/flo_ai/arium/models.py index e372c3a4..36c1f27b 100644 --- a/flo_ai/flo_ai/arium/models.py +++ b/flo_ai/flo_ai/arium/models.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable, List +from typing import Callable, List, Optional from functools import partial from flo_ai.arium.memory import BaseMemory @@ -14,11 +14,13 @@ def default_router( @dataclass class StartNode: name = '__start__' + input_filter: Optional[List[str]] = None @dataclass class EndNode: name = '__end__' + input_filter: Optional[List[str]] = None @dataclass diff --git a/flo_ai/flo_ai/arium/nodes.py b/flo_ai/flo_ai/arium/nodes.py index 81e0f510..6355a8de 100644 --- a/flo_ai/flo_ai/arium/nodes.py +++ b/flo_ai/flo_ai/arium/nodes.py @@ -1,7 +1,8 @@ from flo_ai.arium.protocols import ExecutableNode -from typing import List, Any, Dict, Optional, TYPE_CHECKING +from typing import List, Any, Dict, Optional, TYPE_CHECKING, Callable from flo_ai.utils.logger import logger from flo_ai.arium.memory import MessageMemory +import asyncio if 
TYPE_CHECKING: # need to have an optional import else will get circular dependency error as arium also has AriumNode reference from flo_ai.arium.arium import Arium @@ -12,7 +13,7 @@ class AriumNode: Wrapper to use an Arium as a node in another Arium workflow. """ - def __init__(self, name: str, arium: 'Arium', inherit_variables: bool = True): + def __init__(self, name: str, arium: 'Arium', inherit_variables: bool = True, input_filter: Optional[List[str]] = None): """ Args: name: Name for this node in the parent workflow @@ -22,6 +23,7 @@ def __init__(self, name: str, arium: 'Arium', inherit_variables: bool = True): self.name = name self.arium = arium self.inherit_variables = inherit_variables + self.input_filter: Optional[List[str]] = input_filter async def run( self, inputs: List[Any], variables: Optional[Dict[str, Any]] = None, **kwargs @@ -48,7 +50,7 @@ class ForEachNode: Supports only sequential execution for now. (parallel execution would be supported in future) """ - def __init__(self, name: str, execute_node: ExecutableNode): + def __init__(self, name: str, execute_node: ExecutableNode, input_filter: Optional[List[str]] = None): """ Args: name: Node name @@ -56,6 +58,7 @@ def __init__(self, name: str, execute_node: ExecutableNode): """ self.name = name self.execute_node = execute_node + self.input_filter: Optional[List[str]] = input_filter async def _execute_item( self, @@ -140,3 +143,40 @@ async def _execute_item_with_isolated_memory( if isinstance(result, list) and result: return result[-1] return result + +class ToolNode: + """ + Lightweight tool-as-node wrapper that conforms to ExecutableNode. + + Forwards inputs and variables to the provided function along with any kwargs. 
+ """ + + def __init__( + self, + name: str, + description: str, + function: Callable[..., Any], + input_filter: Optional[List[str]] = None, + ) -> None: + self.name = name + self.description = description + self.function = function + self.input_filter: Optional[List[str]] = input_filter + + async def run( + self, + inputs: List[Any] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + logger.info( + f"Executing ToolNode '{self.name}' with inputs: {inputs} variables: {variables} kwargs: {kwargs}" + ) + + if asyncio.iscoroutinefunction(self.function): + return await self.function(inputs=inputs, variables=variables, **kwargs) + + result = self.function(inputs=inputs, variables=variables, **kwargs) + if asyncio.iscoroutine(result): + return await result + return result diff --git a/flo_ai/flo_ai/arium/protocols.py b/flo_ai/flo_ai/arium/protocols.py index 9b4602c1..ef193573 100644 --- a/flo_ai/flo_ai/arium/protocols.py +++ b/flo_ai/flo_ai/arium/protocols.py @@ -16,6 +16,8 @@ class ExecutableNode(Protocol): name: str """Unique identifier for the node""" + input_filter: Optional[List[str]] = None + """List of input keys to include in the node's execution""" async def run( self, diff --git a/flo_ai/flo_ai/models/agent.py b/flo_ai/flo_ai/models/agent.py index aff81a90..517c7530 100644 --- a/flo_ai/flo_ai/models/agent.py +++ b/flo_ai/flo_ai/models/agent.py @@ -40,6 +40,7 @@ def __init__( output_schema: Optional[Dict[str, Any]] = None, role: Optional[str] = None, act_as: Optional[str] = MessageType.ASSISTANT, + input_filter: Optional[List[str]] = None, ): # Determine agent type based on tools agent_type = AgentType.TOOL_USING if tools else AgentType.CONVERSATIONAL @@ -63,6 +64,7 @@ def __init__( self.output_schema = output_schema self.role = role self.act_as = act_as + self.input_filter: Optional[List[str]] = input_filter @trace_agent_execution() async def run( From 8c51d8252f224494a8ae432a61b3fe67ca7b7871 Mon Sep 17 00:00:00 2001 From: vishnu r 
kumar Date: Mon, 17 Nov 2025 17:27:03 +0530 Subject: [PATCH 2/7] Store node results as message memory item for arium execution --- flo_ai/examples/arium_examples.py | 46 ++++++++------ flo_ai/flo_ai/arium/arium.py | 24 +++++--- flo_ai/flo_ai/arium/builder.py | 8 +-- flo_ai/flo_ai/arium/memory.py | 99 ++++++++++++------------------- flo_ai/flo_ai/arium/nodes.py | 27 +++++++-- 5 files changed, 103 insertions(+), 101 deletions(-) diff --git a/flo_ai/examples/arium_examples.py b/flo_ai/examples/arium_examples.py index 0eabb603..3ab75264 100644 --- a/flo_ai/examples/arium_examples.py +++ b/flo_ai/examples/arium_examples.py @@ -8,7 +8,9 @@ from flo_ai.models import TextMessageContent, UserMessage from flo_ai.models.agent import Agent from flo_ai.arium.nodes import ToolNode -from flo_ai.arium.memory import MessageMemory +from flo_ai.arium.memory import MessageMemory, MessageMemoryItem +from flo_ai.models import BaseMessage +from typing import List async def print_result(result: str) -> str: @@ -199,22 +201,19 @@ async def example_tool_nodes_with_filters(): """Workflow of only ToolNodes; each uses input_filter to read from specific nodes.""" # Define simple tool functions - async def pass_through(inputs=None, variables=None, **kwargs): - return inputs + async def pass_through(inputs: List[BaseMessage] | str, variables=None, **kwargs): + return inputs[-1].content - async def capitalize_last(inputs=None, variables=None, **kwargs): - if not inputs: - return 'No inputs' - last = str(inputs[-1]) - return last.capitalize() + async def capitalize_last( + inputs: List[BaseMessage] | str, variables=None, **kwargs + ): + return str(inputs[-1].content).capitalize() - async def uppercase_all(inputs=None, variables=None, **kwargs): - if not inputs: - return 'No inputs' - return ' '.join([str(x).upper() for x in inputs]) + async def uppercase_all(inputs: List[BaseMessage] | str, variables=None, **kwargs): + return ' '.join([str(x.content).upper() for x in inputs]) - async def 
summarize(inputs=None, variables=None, **kwargs): - return f"count={len(inputs or [])} last={(str(inputs[-1]) if inputs else '')}" + async def summarize(inputs: List[BaseMessage] | str, variables=None, **kwargs): + return f"count={len(inputs or [])} last={(str(inputs[-1].content) if inputs else '')}" # Create four ToolNodes with input filters t1 = ToolNode( @@ -243,14 +242,25 @@ async def summarize(inputs=None, variables=None, **kwargs): ) # Build and run: tool1 -> tool2 -> tool3 -> tool4 + state = 1 + + def router(memory: MessageMemory) -> Literal['tool2', 'tool4']: + nonlocal state + if state == 1: + state = 2 + return 'tool2' + else: + state = 1 + return 'tool4' + result = await ( AriumBuilder() .with_memory(MessageMemory()) .add_tool_nodes([t1, t2, t3, t4]) .start_with(t1) - .connect(t1, t2) + .add_edge(t1, [t2, t4], router=router) .connect(t2, t3) - .connect(t3, t4) + .connect(t3, t1) .end_with(t4) .build_and_run(['hello world']) ) @@ -271,9 +281,9 @@ async def main(): # result3 = await example_complex_workflow() # result4 = await example_convenience_function() # result5 = await example_build_and_reuse() - result6 = await example_tool_nodes_with_filters() + result6: List[MessageMemoryItem] = await example_tool_nodes_with_filters() - print(result6) + print(result6[-1].result.content) print('Examples completed!') diff --git a/flo_ai/flo_ai/arium/arium.py b/flo_ai/flo_ai/arium/arium.py index cdbfdd08..ae8871c5 100644 --- a/flo_ai/flo_ai/arium/arium.py +++ b/flo_ai/flo_ai/arium/arium.py @@ -1,5 +1,5 @@ from flo_ai.arium.base import BaseArium -from flo_ai.arium.memory import MessageMemory, BaseMemory +from flo_ai.arium.memory import MessageMemory, BaseMemory, MessageMemoryItem from flo_ai.models import BaseMessage, UserMessage, TextMessageContent from typing import List, Dict, Any, Optional, Callable from flo_ai.models.agent import Agent @@ -204,7 +204,10 @@ async def _execute_graph( events_filter: Optional[List[AriumEventType]] = None, variables: 
Optional[Dict[str, Any]] = None, ): - [self.memory.add(msg) for msg in inputs] + [ + self.memory.add(MessageMemoryItem(node='input', occurrence=0, result=msg)) + for msg in inputs + ] current_node = self.nodes[self.start_node_name] current_edge = self.edges[self.start_node_name] @@ -251,11 +254,15 @@ async def _execute_graph( ) if isinstance(result, List): # for each node will give results array - self._add_to_memory(result[-1]) + self._add_to_memory( + MessageMemoryItem(node=current_node.name, result=result[-1]) + ) else: # update results to memory if result: - self._add_to_memory(result) + self._add_to_memory( + MessageMemoryItem(node=current_node.name, result=result) + ) # find next node post current node # Prepare execution context for router functions @@ -448,7 +455,7 @@ async def _execute_node( if getattr(node, 'input_filter', None) else self.memory.get() ) - inputs = [item['output'] for item in memory_items] + inputs = [item.result for item in memory_items] if tracer and node_type not in ['start', 'end']: with tracer.start_as_current_span( @@ -595,9 +602,8 @@ async def _execute_node( # Re-raise the exception raise e - def _add_to_memory(self, result: BaseMessage): + def _add_to_memory(self, message: MessageMemoryItem): """ - Store result in memory, converting strings to AssistantMessage if needed. - Agent responses should be stored as AssistantMessage, not UserMessage. 
+ Store message in memory """ - self.memory.add(result) + self.memory.add(message) diff --git a/flo_ai/flo_ai/arium/builder.py b/flo_ai/flo_ai/arium/builder.py index f3eb4e4c..3286aef8 100644 --- a/flo_ai/flo_ai/arium/builder.py +++ b/flo_ai/flo_ai/arium/builder.py @@ -3,7 +3,7 @@ from flo_ai.arium.memory import MessageMemory, BaseMemory from flo_ai.arium.protocols import ExecutableNode from flo_ai.arium.nodes import AriumNode, ForEachNode -from flo_ai.models import BaseMessage, TextMessageContent, UserMessage +from flo_ai.models import BaseMessage, UserMessage from flo_ai.models.agent import Agent, resolve_variables from flo_ai.tool.base_tool import Tool import yaml @@ -250,11 +250,7 @@ async def build_and_run( new_inputs = [] for input in inputs: if isinstance(input, str): - new_inputs.append( - UserMessage( - TextMessageContent(text=resolve_variables(input, variables)) - ) - ) + new_inputs.append(UserMessage(resolve_variables(input, variables))) elif isinstance(input, BaseMessage): new_inputs.append(input) else: diff --git a/flo_ai/flo_ai/arium/memory.py b/flo_ai/flo_ai/arium/memory.py index ca05bfbe..af0db6d2 100644 --- a/flo_ai/flo_ai/arium/memory.py +++ b/flo_ai/flo_ai/arium/memory.py @@ -19,6 +19,22 @@ class StepStatus(Enum): SKIPPED = 'skipped' +class MessageMemoryItem: + def __init__( + self, node: str, occurrence: int = 0, result: BaseMessage | str = None + ): + self.node: str = node + self.occurrence: int = occurrence + self.result: BaseMessage | str = result + + def to_dict(self) -> Dict[str, Any]: + return { + 'node': self.node, + 'occurrence': self.occurrence, + 'result': self.result, + } + + @dataclass class PlanStep: """Represents a single step in an execution plan""" @@ -115,16 +131,27 @@ def get_plan(self, plan_id: str) -> Optional[ExecutionPlan]: return None -class MessageMemory(BaseMemory[BaseMessage]): +class MessageMemory(BaseMemory[MessageMemoryItem]): def __init__(self): - self.messages: List[Dict[str, Any]] = [] + self.messages: 
List[MessageMemoryItem] = [] self._node_occurrences: Dict[str, int] = {} - def add(self, message: BaseMessage): + def _next_occurrence(self, node: str) -> int: + current = self._node_occurrences.get(node, 0) + 1 + self._node_occurrences[node] = current + return current + + def add(self, message: MessageMemoryItem): + # Update occurrence count for the node + occurrence = self._next_occurrence(message.node) + message.occurrence = occurrence self.messages.append(message) - def get(self) -> List[BaseMessage]: - return self.messages + def get(self, include_nodes: Optional[List[str]] = None) -> List[MessageMemoryItem]: + if not include_nodes: + return self.messages + include = set[str](include_nodes) + return [m for m in self.messages if m.node in include] class PlanAwareMemory(BaseMemory[Dict[str, Any]]): @@ -141,63 +168,11 @@ def _next_occurrence(self, node: str) -> int: self._node_occurrences[node] = current return current - def _normalize(self, message: Any) -> Optional[Dict[str, Any]]: - if message is None: - return None - if isinstance(message, dict) and {'node', 'occurrence', 'output'}.issubset( - message.keys() - ): - return message - # New schema without occurrence: { 'node': str, 'output': Any } - if ( - isinstance(message, dict) - and {'node', 'output'}.issubset(message.keys()) - and 'occurrence' not in message - ): - node = ( - str(message['node']) if message.get('node') is not None else 'unknown' - ) - occurrence = self._next_occurrence(node) - return { - 'node': node, - 'occurrence': occurrence, - 'output': message.get('output'), - } - if isinstance(message, dict) and {'node_name', 'result'}.issubset( - message.keys() - ): - node = ( - str(message['node_name']) - if message.get('node_name') is not None - else 'unknown' - ) - occurrence = self._next_occurrence(node) - return { - 'node': node, - 'occurrence': occurrence, - 'output': message.get('result'), - } - if isinstance(message, str): - node = 'input' - occurrence = self._next_occurrence(node) - return 
{'node': node, 'occurrence': occurrence, 'output': message} - if hasattr(message, 'to_dict') and callable(getattr(message, 'to_dict')): - node = 'input' - occurrence = self._next_occurrence(node) - try: - data = message.to_dict() - except Exception: - data = str(message) - return {'node': node, 'occurrence': occurrence, 'output': data} - node = 'input' - occurrence = self._next_occurrence(node) - return {'node': node, 'occurrence': occurrence, 'output': message} - - def add(self, message: Any): - """Add a message to memory""" - normalized = self._normalize(message) - if normalized is not None: - self.messages.append(normalized) + def add(self, message: MessageMemoryItem): + # Update occurrence count for the node + occurrence = self._next_occurrence(message.node) + message.occurrence = occurrence + self.messages.append(message) def get(self, include_nodes: Optional[List[str]] = None) -> List[Dict[str, Any]]: """Get all messages""" diff --git a/flo_ai/flo_ai/arium/nodes.py b/flo_ai/flo_ai/arium/nodes.py index 6355a8de..d29a59ea 100644 --- a/flo_ai/flo_ai/arium/nodes.py +++ b/flo_ai/flo_ai/arium/nodes.py @@ -2,6 +2,7 @@ from typing import List, Any, Dict, Optional, TYPE_CHECKING, Callable from flo_ai.utils.logger import logger from flo_ai.arium.memory import MessageMemory +from flo_ai.models import BaseMessage, UserMessage import asyncio if TYPE_CHECKING: # need to have an optional import else will get circular dependency error as arium also has AriumNode reference @@ -13,7 +14,13 @@ class AriumNode: Wrapper to use an Arium as a node in another Arium workflow. """ - def __init__(self, name: str, arium: 'Arium', inherit_variables: bool = True, input_filter: Optional[List[str]] = None): + def __init__( + self, + name: str, + arium: 'Arium', + inherit_variables: bool = True, + input_filter: Optional[List[str]] = None, + ): """ Args: name: Name for this node in the parent workflow @@ -50,7 +57,12 @@ class ForEachNode: Supports only sequential execution for now. 
(parallel execution would be supported in future) """ - def __init__(self, name: str, execute_node: ExecutableNode, input_filter: Optional[List[str]] = None): + def __init__( + self, + name: str, + execute_node: ExecutableNode, + input_filter: Optional[List[str]] = None, + ): """ Args: name: Node name @@ -144,6 +156,7 @@ async def _execute_item_with_isolated_memory( return result[-1] return result + class ToolNode: """ Lightweight tool-as-node wrapper that conforms to ExecutableNode. @@ -165,7 +178,7 @@ def __init__( async def run( self, - inputs: List[Any] = None, + inputs: List[BaseMessage] | str, variables: Optional[Dict[str, Any]] = None, **kwargs, ) -> Any: @@ -174,9 +187,11 @@ async def run( ) if asyncio.iscoroutinefunction(self.function): - return await self.function(inputs=inputs, variables=variables, **kwargs) + result = await self.function(inputs=inputs, variables=variables, **kwargs) + return UserMessage(content=result) result = self.function(inputs=inputs, variables=variables, **kwargs) if asyncio.iscoroutine(result): - return await result - return result + content = await result + return UserMessage(content=content) + return UserMessage(content=result) From 33b7bc625cb4e4ce4832cc58fe97012cd96e8538 Mon Sep 17 00:00:00 2001 From: vishnu r kumar Date: Mon, 17 Nov 2025 17:40:58 +0530 Subject: [PATCH 3/7] Rename ToolNode as FunctionNode --- flo_ai/examples/arium_examples.py | 98 +++++++++++++++---------------- flo_ai/flo_ai/arium/arium.py | 12 ++-- flo_ai/flo_ai/arium/builder.py | 96 +++++++++++++++--------------- flo_ai/flo_ai/arium/nodes.py | 9 +-- 4 files changed, 108 insertions(+), 107 deletions(-) diff --git a/flo_ai/examples/arium_examples.py b/flo_ai/examples/arium_examples.py index 3ab75264..51745916 100644 --- a/flo_ai/examples/arium_examples.py +++ b/flo_ai/examples/arium_examples.py @@ -7,7 +7,7 @@ from flo_ai.llm import OpenAI from flo_ai.models import TextMessageContent, UserMessage from flo_ai.models.agent import Agent -from flo_ai.arium.nodes 
import ToolNode +from flo_ai.arium.nodes import FunctionNode from flo_ai.arium.memory import MessageMemory, MessageMemoryItem from flo_ai.models import BaseMessage from typing import List @@ -20,9 +20,9 @@ async def print_result(result: str) -> str: # Example 1: Simple Linear Workflow async def example_linear_workflow(): - """Example of a simple linear workflow: Agent -> Tool -> Agent""" + """Example of a simple linear workflow: Agent -> FunctionNode -> Agent""" - # Create some example agents and tools (these would be your actual implementations) + # Create some example agents and function nodes (these would be your actual implementations) analyzer_agent = Agent( name='analyzer', system_prompt='Analyze the input', @@ -33,7 +33,7 @@ async def example_linear_workflow(): system_prompt='Summarize the results', llm=OpenAI(model='gpt-4o-mini'), ) - processing_tool_node = ToolNode( + processing_function_node = FunctionNode( name='processor', description='Process the input', function=print_result ) @@ -41,11 +41,11 @@ async def example_linear_workflow(): result = await ( AriumBuilder() .add_agent(analyzer_agent) - .add_tool_node(processing_tool_node) + .add_function_node(processing_function_node) .add_agent(summarizer_agent) .start_with(analyzer_agent) - .connect(analyzer_agent, processing_tool_node) - .connect(processing_tool_node, summarizer_agent) + .connect(analyzer_agent, processing_function_node) + .connect(processing_function_node, summarizer_agent) .end_with(summarizer_agent) .build_and_run([UserMessage(TextMessageContent(text='Analyze this text'))]) ) @@ -57,12 +57,12 @@ async def example_linear_workflow(): async def example_branching_workflow(): """Example of a branching workflow with conditional routing""" - # Create agents and tools + # Create agents and function nodes classifier_agent = Agent(name='classifier', prompt='Classify the input type') - text_processor_node = ToolNode( + text_processor_node = FunctionNode( name='text_processor', description='Process 
text', function=lambda x: x ) - image_processor_node = ToolNode( + image_processor_node = FunctionNode( name='image_processor', description='Process image', function=lambda x: x ) final_agent = Agent(name='final', prompt='Provide final response') @@ -80,13 +80,13 @@ def content_router(memory) -> Literal['text_processor', 'image_processor']: AriumBuilder() .add_agent(classifier_agent) .add_agent(final_agent) - .add_tool_node( - ToolNode( + .add_function_node( + FunctionNode( name='text_processor', description='Process text', function=lambda x: x ) ) - .add_tool_node( - ToolNode( + .add_function_node( + FunctionNode( name='image_processor', description='Process image', function=lambda x: x, @@ -108,19 +108,19 @@ def content_router(memory) -> Literal['text_processor', 'image_processor']: # Example 3: Complex Multi-Agent Workflow async def example_complex_workflow(): - """Example of a more complex workflow with multiple agents and tools""" + """Example of a more complex workflow with multiple agents and function nodes""" - # Create multiple agents and tools + # Create multiple agents and function nodes input_agent = Agent(name='input_handler', prompt='Handle initial input') researcher_agent = Agent(name='researcher', prompt='Research the topic') analyzer_agent = Agent(name='analyzer', prompt='Analyze findings') writer_agent = Agent(name='writer', prompt='Write the final report') - search_tool_node = ToolNode( - name='search_tool', description='Search the web', function=lambda x: x + search_function_node = FunctionNode( + name='search_function', description='Search the web', function=lambda x: x ) - data_tool_node = ToolNode( - name='data_processor', description='Process the data', function=lambda x: x + data_function_node = FunctionNode( + name='data_function', description='Process the data', function=lambda x: x ) # Router for deciding next step after analysis @@ -133,13 +133,13 @@ def analysis_router(memory) -> Literal['writer', 'researcher']: arium = ( 
AriumBuilder() .add_agents([input_agent, researcher_agent, analyzer_agent, writer_agent]) - .add_tool_nodes([search_tool_node, data_tool_node]) + .add_function_nodes([search_function_node, data_function_node]) .with_memory(MessageMemory()) .start_with(input_agent) .connect(input_agent, researcher_agent) - .connect(researcher_agent, search_tool_node) - .connect(search_tool_node, data_tool_node) - .connect(data_tool_node, analyzer_agent) + .connect(researcher_agent, search_function_node) + .connect(search_function_node, data_function_node) + .connect(data_function_node, analyzer_agent) .add_edge(analyzer_agent, [writer_agent, researcher_agent], analysis_router) .end_with(writer_agent) .build() @@ -196,11 +196,11 @@ async def example_build_and_reuse(): return result1, result2 -# Example 6: Four ToolNodes with input filtering (no agents) -async def example_tool_nodes_with_filters(): - """Workflow of only ToolNodes; each uses input_filter to read from specific nodes.""" +# Example 6: Four FunctionNodes with input filtering (no agents) +async def example_function_nodes_with_filters(): + """Workflow of only FunctionNodes; each uses input_filter to read from specific nodes.""" - # Define simple tool functions + # Define simple functions as nodes async def pass_through(inputs: List[BaseMessage] | str, variables=None, **kwargs): return inputs[-1].content @@ -215,48 +215,48 @@ async def uppercase_all(inputs: List[BaseMessage] | str, variables=None, **kwarg async def summarize(inputs: List[BaseMessage] | str, variables=None, **kwargs): return f"count={len(inputs or [])} last={(str(inputs[-1].content) if inputs else '')}" - # Create four ToolNodes with input filters - t1 = ToolNode( - name='tool1', + # Create four FunctionNodes with input filters + t1 = FunctionNode( + name='function1', description='reads initial inputs', function=pass_through, input_filter=['input'], ) - t2 = ToolNode( - name='tool2', - description='reads tool1 only', + t2 = FunctionNode( + name='function2', + 
description='reads function1 only', function=capitalize_last, - input_filter=['tool1'], + input_filter=['function1'], ) - t3 = ToolNode( - name='tool3', - description='reads tool2 only', + t3 = FunctionNode( + name='function3', + description='reads function2 only', function=uppercase_all, - input_filter=['tool2'], + input_filter=['function2'], ) - t4 = ToolNode( - name='tool4', - description='reads tool1 & tool3', + t4 = FunctionNode( + name='function4', + description='reads function1 & function3', function=summarize, - input_filter=['tool1', 'tool3'], + input_filter=['function1', 'function3'], ) - # Build and run: tool1 -> tool2 -> tool3 -> tool4 + # Build and run: function1 -> function2 -> function3 -> function4 state = 1 - def router(memory: MessageMemory) -> Literal['tool2', 'tool4']: + def router(memory: MessageMemory) -> Literal['function2', 'function4']: nonlocal state if state == 1: state = 2 - return 'tool2' + return 'function2' else: state = 1 - return 'tool4' + return 'function4' result = await ( AriumBuilder() .with_memory(MessageMemory()) - .add_tool_nodes([t1, t2, t3, t4]) + .add_function_nodes([t1, t2, t3, t4]) .start_with(t1) .add_edge(t1, [t2, t4], router=router) .connect(t2, t3) @@ -281,7 +281,7 @@ async def main(): # result3 = await example_complex_workflow() # result4 = await example_convenience_function() # result5 = await example_build_and_reuse() - result6: List[MessageMemoryItem] = await example_tool_nodes_with_filters() + result6: List[MessageMemoryItem] = await example_function_nodes_with_filters() print(result6[-1].result.content) diff --git a/flo_ai/flo_ai/arium/arium.py b/flo_ai/flo_ai/arium/arium.py index ae8871c5..88bc9005 100644 --- a/flo_ai/flo_ai/arium/arium.py +++ b/flo_ai/flo_ai/arium/arium.py @@ -5,7 +5,7 @@ from flo_ai.models.agent import Agent from flo_ai.arium.models import StartNode, EndNode from flo_ai.arium.events import AriumEventType, AriumEvent -from flo_ai.arium.nodes import AriumNode, ForEachNode, ToolNode +from 
flo_ai.arium.nodes import AriumNode, ForEachNode, FunctionNode from flo_ai.utils.logger import logger from flo_ai.utils.variable_extractor import ( extract_variables_from_inputs, @@ -404,7 +404,7 @@ def _resolve_agent_prompts(self, variables: Dict[str, Any]) -> None: async def _execute_node( self, - node: Agent | ToolNode | ForEachNode | AriumNode | StartNode | EndNode, + node: Agent | FunctionNode | ForEachNode | AriumNode | StartNode | EndNode, event_callback: Optional[Callable[[AriumEvent], None]] = None, events_filter: Optional[List[AriumEventType]] = None, variables: Optional[Dict[str, Any]] = None, @@ -423,8 +423,8 @@ async def _execute_node( # Determine node type for events if isinstance(node, Agent): node_type = 'agent' - elif isinstance(node, ToolNode): - node_type = 'tool' + elif isinstance(node, FunctionNode): + node_type = 'function' elif isinstance(node, ForEachNode): node_type = 'foreach' elif isinstance(node, AriumNode): @@ -472,7 +472,7 @@ async def _execute_node( if isinstance(node, Agent): # Variables are already resolved, pass empty dict to avoid re-processing result = await node.run(inputs, variables={}) - elif isinstance(node, ToolNode): + elif isinstance(node, FunctionNode): result = await node.run(inputs, variables=None) elif isinstance(node, ForEachNode): result = await node.run( @@ -553,7 +553,7 @@ async def _execute_node( # Execute the node based on its type if isinstance(node, Agent): result = await node.run(inputs, variables={}) - elif isinstance(node, ToolNode): + elif isinstance(node, FunctionNode): result = await node.run(inputs, variables=None) elif isinstance(node, ForEachNode): result = await node.run( diff --git a/flo_ai/flo_ai/arium/builder.py b/flo_ai/flo_ai/arium/builder.py index 3286aef8..950d5d61 100644 --- a/flo_ai/flo_ai/arium/builder.py +++ b/flo_ai/flo_ai/arium/builder.py @@ -10,7 +10,7 @@ from flo_ai.builder.agent_builder import AgentBuilder from flo_ai.llm import BaseLLM from flo_ai.arium.llm_router import 
create_llm_router -from flo_ai.arium.nodes import ToolNode +from flo_ai.arium.nodes import FunctionNode class AriumBuilder: @@ -21,10 +21,10 @@ class AriumBuilder: result = (AriumBuilder() .with_memory(my_memory) .add_agent(agent1) - .add_tool_node(tool_node1) + .add_function_node(function_node1) .start_with(agent1) - .add_edge(agent1, [tool_node1], router_fn) - .end_with(tool_node1) + .add_edge(agent1, [function_node1], router_fn) + .end_with(function_node1) .build_and_run(["Hello, world!"])) """ @@ -37,7 +37,7 @@ def __init__(self): self._foreach_nodes: List[ForEachNode] = [] self._start_node: Optional[ExecutableNode] = None self._end_nodes: List[ExecutableNode] = [] - self._tool_nodes: List[ToolNode] = [] + self._function_nodes: List[FunctionNode] = [] self._edges: List[tuple] = [] # (from_node, to_nodes, router) self._arium: Optional[Arium] = None self._all_ariums: List[ @@ -59,14 +59,14 @@ def add_agents(self, agents: List[Agent]) -> 'AriumBuilder': self._agents.extend(agents) return self - def add_tool_node(self, tool_node: ToolNode) -> 'AriumBuilder': + def add_function_node(self, function_node: FunctionNode) -> 'AriumBuilder': """Add a tool node to the Arium.""" - self._tool_nodes.append(tool_node) + self._function_nodes.append(function_node) return self - def add_tool_nodes(self, tool_nodes: List[ToolNode]) -> 'AriumBuilder': + def add_function_nodes(self, function_nodes: List[FunctionNode]) -> 'AriumBuilder': """Add multiple tool nodes to the Arium.""" - self._tool_nodes.extend(tool_nodes) + self._function_nodes.extend(function_nodes) return self def add_arium( @@ -113,7 +113,7 @@ def add_foreach( if isinstance(execute_node, str): # Search across all node types all_nodes = ( - self._agents + self._tool_nodes + self._ariums + self._foreach_nodes + self._agents + self._function_nodes + self._ariums + self._foreach_nodes ) resolved_node = next((n for n in all_nodes if n.name == execute_node), None) if not resolved_node: @@ -133,7 +133,7 @@ def 
start_with(self, node: ExecutableNode | str) -> 'AriumBuilder': if isinstance(node, str): # Search across all node types all_nodes = ( - self._agents + self._tool_nodes + self._ariums + self._foreach_nodes + self._agents + self._function_nodes + self._ariums + self._foreach_nodes ) resolved_node = next((n for n in all_nodes if n.name == node), None) if not resolved_node: @@ -168,7 +168,7 @@ def connect( if isinstance(from_node, str): # Search across all node types all_nodes = ( - self._agents + self._tool_nodes + self._ariums + self._foreach_nodes + self._agents + self._function_nodes + self._ariums + self._foreach_nodes ) resolved_from_node = next( (n for n in all_nodes if n.name == from_node), None @@ -180,7 +180,7 @@ def connect( if isinstance(to_node, str): # Search across all node types all_nodes = ( - self._agents + self._tool_nodes + self._ariums + self._foreach_nodes + self._agents + self._function_nodes + self._ariums + self._foreach_nodes ) resolved_to_node = next((n for n in all_nodes if n.name == to_node), None) if not resolved_to_node: @@ -201,12 +201,12 @@ def build(self) -> Arium: # Add all nodes all_nodes = [] all_nodes.extend(self._agents) - all_nodes.extend(self._tool_nodes) + all_nodes.extend(self._function_nodes) all_nodes.extend(self._ariums) all_nodes.extend(self._foreach_nodes) if not all_nodes: - raise ValueError('No agents or tool nodes added to the Arium') + raise ValueError('No agents or function nodes added to the Arium') arium.add_nodes(all_nodes) @@ -271,7 +271,7 @@ def reset(self) -> 'AriumBuilder': """Reset the builder to start fresh.""" self._memory = None self._agents = [] - self._tool_nodes = [] + self._function_nodes = [] self._ariums = [] self._foreach_nodes = [] self._start_node = None @@ -287,7 +287,7 @@ def from_yaml( yaml_file: Optional[str] = None, memory: Optional[BaseMemory] = None, agents: Optional[Dict[str, Agent]] = None, - tool_nodes: Optional[Dict[str, ToolNode]] = None, + function_nodes: Optional[Dict[str, 
FunctionNode]] = None, routers: Optional[Dict[str, Callable]] = None, base_llm: Optional[BaseLLM] = None, ) -> 'AriumBuilder': @@ -298,7 +298,7 @@ def from_yaml( yaml_file: Path to YAML file containing arium configuration memory: Memory instance to use for the workflow (defaults to MessageMemory) agents: Dictionary mapping agent names to pre-built Agent instances - tool_nodes: Dictionary mapping tool names to ToolNode instances + function_nodes: Dictionary mapping function names to FunctionNode instances routers: Dictionary mapping router names to router functions base_llm: Base LLM to use for all agents if not specified in individual agent configs @@ -341,9 +341,9 @@ def from_yaml( - name: reporter yaml_file: "path/to/reporter.yaml" - tool_nodes: - - name: tool1 - - name: tool2 + function_nodes: + - name: function1 + - name: function2 # LLM Router definitions (NEW) routers: @@ -511,39 +511,39 @@ def from_yaml( builder.add_agent(agent) # Process tool nodes - tool_nodes_config = arium_config.get('tool_nodes', []) - tool_nodes_dict = {} + function_nodes_config = arium_config.get('function_nodes', []) + function_nodes_dict = {} - for tool_node_config in tool_nodes_config: - tool_node_name = tool_node_config['name'] + for function_node_config in function_nodes_config: + function_node_name = function_node_config['name'] - # Add a tool node from a pre-built tool node - if len(tool_node_config) == 1 and 'name' in tool_node_config: - if tool_nodes and tool_node_name in tool_nodes: - tool_node = tool_nodes[tool_node_name] + # Add a function node from pre-built function nodes + if len(function_node_config) == 1 and 'name' in function_node_config: + if function_nodes and function_node_name in function_nodes: + function_node = function_nodes[function_node_name] else: raise ValueError( - f'ToolNode {tool_node_name} not found in provided tool_nodes dictionary. ' - f'Available tool_nodes: {list(tool_nodes.keys()) if tool_nodes else []}. 
' - f'Either provide the ToolNode in the tool_nodes parameter or add configuration fields.' + f'FunctionNode {function_node_name} not found in provided function_nodes dictionary. ' + f'Available function_nodes: {list(function_nodes.keys()) if function_nodes else []}. ' + f'Either provide the FunctionNode in the function_nodes parameter or add configuration fields.' ) else: - # Add a tool node from a direct ToolNode definition (function must be provided in code, YAML cannot define callables) - function = tool_node_config.get('function') + # Add a function node from a direct FunctionNode definition (function must be provided in code, YAML cannot define callables) + function = function_node_config.get('function') if function is None: ValueError( - f'Function for ToolNode {tool_node_name} is not provided' + f'Function for FunctionNode {function_node_name} is not provided' ) - tool_node = ToolNode( - name=tool_node_name, - description=tool_node_config.get('description', ''), + function_node = FunctionNode( + name=function_node_name, + description=function_node_config.get('description', ''), function=function, - input_filter=tool_node_config.get('input_filter', None), + input_filter=function_node_config.get('input_filter', None), ) - tool_nodes_dict[tool_node_name] = tool_node - builder.add_tool_node(tool_node) + function_nodes_dict[function_node_name] = function_node + builder.add_function_node(function_node) # Process LLM routers (if defined in YAML) routers_config = arium_config.get('routers', []) @@ -664,7 +664,7 @@ def from_yaml( yaml_file=yaml_file_path, memory=None, agents=None, - tool_nodes=None, + function_nodes=None, routers=None, base_llm=base_llm, ) @@ -676,7 +676,7 @@ def from_yaml( sub_config = { 'arium': { 'agents': arium_node_config.get('agents', []), - 'tool_nodes': arium_node_config.get('tool_nodes', []), + 'function_nodes': arium_node_config.get('function_nodes', []), 'routers': arium_node_config.get('routers', []), 'ariums': arium_node_config.get( 
'ariums', [] @@ -690,7 +690,7 @@ def from_yaml( yaml_str=yaml.dump(sub_config), memory=None, agents=None, - tool_nodes=None, + function_nodes=None, routers=None, base_llm=base_llm, ) @@ -725,7 +725,7 @@ def from_yaml( # Find execute_node from ALL node types execute_node = ( agents_dict.get(execute_node_name) - or tool_nodes_dict.get(execute_node_name) + or function_nodes_dict.get(execute_node_name) or arium_nodes_dict.get(execute_node_name) or foreach_nodes_dict.get(execute_node_name) ) @@ -733,7 +733,7 @@ def from_yaml( if not execute_node: all_nodes = ( list(agents_dict.keys()) - + list(tool_nodes_dict.keys()) + + list(function_nodes_dict.keys()) + list(arium_nodes_dict.keys()) + list(foreach_nodes_dict.keys()) ) @@ -755,7 +755,7 @@ def from_yaml( def _find_node(node_name: str): return ( agents_dict.get(node_name) - or tool_nodes_dict.get(node_name) + or function_nodes_dict.get(node_name) or arium_nodes_dict.get(node_name) or foreach_nodes_dict.get(node_name) ) @@ -769,7 +769,7 @@ def _find_node(node_name: str): if not start_node: all_available = ( list(agents_dict.keys()) - + list(tool_nodes_dict.keys()) + + list(function_nodes_dict.keys()) + list(arium_nodes_dict.keys()) + list(foreach_nodes_dict.keys()) ) diff --git a/flo_ai/flo_ai/arium/nodes.py b/flo_ai/flo_ai/arium/nodes.py index d29a59ea..79c89ea9 100644 --- a/flo_ai/flo_ai/arium/nodes.py +++ b/flo_ai/flo_ai/arium/nodes.py @@ -157,9 +157,9 @@ async def _execute_item_with_isolated_memory( return result -class ToolNode: +class FunctionNode: """ - Lightweight tool-as-node wrapper that conforms to ExecutableNode. + Lightweight function-as-node wrapper that conforms to ExecutableNode. Forwards inputs and variables to the provided function along with any kwargs. 
""" @@ -183,15 +183,16 @@ async def run( **kwargs, ) -> Any: logger.info( - f"Executing ToolNode '{self.name}' with inputs: {inputs} variables: {variables} kwargs: {kwargs}" + f"Executing FunctionNode '{self.name}' with inputs: {inputs} variables: {variables} kwargs: {kwargs}" ) if asyncio.iscoroutinefunction(self.function): result = await self.function(inputs=inputs, variables=variables, **kwargs) return UserMessage(content=result) - result = self.function(inputs=inputs, variables=variables, **kwargs) if asyncio.iscoroutine(result): content = await result return UserMessage(content=content) + + result = self.function(inputs=inputs, variables=variables, **kwargs) return UserMessage(content=result) From 5d2b1ee96f0a6ca7e5ac5143937f55078c8bb40f Mon Sep 17 00:00:00 2001 From: vishnu r kumar Date: Mon, 17 Nov 2025 19:33:47 +0530 Subject: [PATCH 4/7] fix tests after function node changes --- flo_ai/tests/unit-tests/test_arium_builder.py | 46 +++--- flo_ai/tests/unit-tests/test_arium_yaml.py | 153 ++++++++++-------- flo_ai/tests/unit-tests/test_llm_router.py | 17 +- 3 files changed, 120 insertions(+), 96 deletions(-) diff --git a/flo_ai/tests/unit-tests/test_arium_builder.py b/flo_ai/tests/unit-tests/test_arium_builder.py index 4e6b1308..d85cae39 100644 --- a/flo_ai/tests/unit-tests/test_arium_builder.py +++ b/flo_ai/tests/unit-tests/test_arium_builder.py @@ -7,7 +7,7 @@ from flo_ai.arium.builder import AriumBuilder, create_arium from flo_ai.arium.memory import MessageMemory from flo_ai.models.agent import Agent -from flo_ai.tool.base_tool import Tool +from flo_ai.arium.nodes import FunctionNode class TestAriumBuilder: @@ -16,7 +16,7 @@ def test_builder_initialization(self): builder = AriumBuilder() assert builder._memory is None assert builder._agents == [] - assert builder._tools == [] + assert builder._function_nodes == [] assert builder._start_node is None assert builder._end_nodes == [] assert builder._edges == [] @@ -48,26 +48,28 @@ def test_add_agents(self): 
assert result is builder assert all(agent in builder._agents for agent in agents) - def test_add_tool(self): - """Test adding a single tool""" + def test_add_function_node(self): + """Test adding a single function node""" builder = AriumBuilder() - tool = Mock(spec=Tool) - tool.name = 'test_tool' + function_node = Mock(spec=FunctionNode) + function_node.name = 'test_function_node' - result = builder.add_tool(tool) + result = builder.add_function_node(function_node) assert result is builder - assert tool in builder._tools + assert function_node in builder._function_nodes - def test_add_tools(self): - """Test adding multiple tools""" + def test_add_function_nodes(self): + """Test adding multiple function nodes""" builder = AriumBuilder() - tools = [Mock(spec=Tool) for _ in range(3)] - for i, tool in enumerate(tools): - tool.name = f'tool_{i}' + function_nodes = [Mock(spec=FunctionNode) for _ in range(3)] + for i, function_node in enumerate(function_nodes): + function_node.name = f'function_node_{i}' - result = builder.add_tools(tools) + result = builder.add_function_nodes(function_nodes) assert result is builder - assert all(tool in builder._tools for tool in tools) + assert all( + function_node in builder._function_nodes for function_node in function_nodes + ) def test_with_memory(self): """Test setting custom memory""" @@ -144,7 +146,7 @@ def test_reset(self): assert result is builder assert builder._memory is None assert builder._agents == [] - assert builder._tools == [] + assert builder._function_nodes == [] assert builder._start_node is None assert builder._end_nodes == [] assert builder._edges == [] @@ -154,7 +156,7 @@ def test_build_validation_no_nodes(self): """Test that build fails when no nodes are added""" builder = AriumBuilder() - with pytest.raises(ValueError, match='No agents or tools added'): + with pytest.raises(ValueError, match='No agents or function nodes added'): builder.build() def test_build_validation_no_start_node(self): @@ -183,18 +185,18 
@@ def test_method_chaining(self): builder = AriumBuilder() agent = Mock(spec=Agent) agent.name = 'test_agent' - tool = Mock(spec=Tool) - tool.name = 'test_tool' + function_node = Mock(spec=FunctionNode) + function_node.name = 'test_function_node' memory = Mock(spec=MessageMemory) # This should not raise any errors and should work with chaining result = ( builder.with_memory(memory) .add_agent(agent) - .add_tool(tool) + .add_function_node(function_node) .start_with(agent) - .connect(agent, tool) - .end_with(tool) + .connect(agent, function_node) + .end_with(function_node) .reset() ) diff --git a/flo_ai/tests/unit-tests/test_arium_yaml.py b/flo_ai/tests/unit-tests/test_arium_yaml.py index 3b8c22a7..24caac76 100644 --- a/flo_ai/tests/unit-tests/test_arium_yaml.py +++ b/flo_ai/tests/unit-tests/test_arium_yaml.py @@ -8,7 +8,7 @@ from flo_ai.arium.builder import AriumBuilder from flo_ai.arium.memory import MessageMemory, BaseMemory from flo_ai.models.agent import Agent -from flo_ai.tool.base_tool import Tool +from flo_ai.arium.nodes import FunctionNode from flo_ai.llm import OpenAI @@ -156,8 +156,8 @@ def test_from_yaml_default_memory(self): assert builder._memory is not None assert isinstance(builder._memory, MessageMemory) - def test_from_yaml_with_tools(self): - """Test YAML configuration with tools.""" + def test_from_yaml_with_function_nodes(self): + """Test YAML configuration with function nodes.""" yaml_config = """ arium: agents: @@ -165,28 +165,28 @@ def test_from_yaml_with_tools(self): yaml_config: | agent: name: test_agent - job: "Test agent with tools" + job: "Test agent" model: provider: openai name: gpt-4o-mini - tools: - - name: test_tool + function_nodes: + - name: test_function_node workflow: start: test_agent edges: - from: test_agent - to: [test_tool] - - from: test_tool + to: [test_function_node] + - from: test_function_node to: [end] - end: [test_tool] + end: [test_function_node] """ - # Create mock tool - mock_tool = Mock(spec=Tool) - 
mock_tool.name = 'test_tool' - tools = {'test_tool': mock_tool} + # Create mock function node + mock_function_node = Mock(spec=FunctionNode) + mock_function_node.name = 'test_function_node' + function_nodes = {'test_function_node': mock_function_node} with patch('flo_ai.arium.builder.AgentBuilder') as mock_agent_builder: mock_agent = Mock(spec=Agent) @@ -196,11 +196,13 @@ def test_from_yaml_with_tools(self): mock_builder_instance.build.return_value = mock_agent mock_agent_builder.from_yaml.return_value = mock_builder_instance - builder = AriumBuilder.from_yaml(yaml_str=yaml_config, tools=tools) + builder = AriumBuilder.from_yaml( + yaml_str=yaml_config, function_nodes=function_nodes + ) - # Verify tools were added - assert len(builder._tools) == 1 - assert builder._tools[0] == mock_tool + # Verify function nodes were added + assert len(builder._function_nodes) == 1 + assert builder._function_nodes[0] == mock_function_node def test_from_yaml_with_routers(self): """Test YAML configuration with custom routers.""" @@ -260,8 +262,8 @@ def test_router(memory: BaseMemory) -> str: assert to_nodes == [mock_agent2] assert router == test_router - def test_from_yaml_missing_tool_error(self): - """Test error when referenced tool is not provided.""" + def test_from_yaml_missing_function_node_error(self): + """Test error when referenced function node is not provided.""" yaml_config = """ arium: agents: @@ -274,25 +276,25 @@ def test_from_yaml_missing_tool_error(self): provider: openai name: gpt-4o-mini - tools: - - name: missing_tool + function_nodes: + - name: missing_function_node workflow: start: test_agent edges: - from: test_agent - to: [missing_tool] - - from: missing_tool + to: [missing_function_node] + - from: missing_function_node to: [end] - end: [missing_tool] + end: [missing_function_node] """ with patch('flo_ai.arium.builder.AgentBuilder'): with pytest.raises( ValueError, - match='Tool missing_tool not found in provided tools dictionary', + match='FunctionNode 
missing_function_node not found in provided function_nodes dictionary', ): - AriumBuilder.from_yaml(yaml_str=yaml_config, tools={}) + AriumBuilder.from_yaml(yaml_str=yaml_config, function_nodes={}) def test_from_yaml_missing_router_error(self): """Test error when referenced router is not provided.""" @@ -479,7 +481,7 @@ def test_from_yaml_with_base_llm(self): assert 'job: "Test agent"' in call_kwargs['yaml_str'] def test_from_yaml_complex_workflow(self): - """Test complex workflow with multiple agents, tools, and routers.""" + """Test complex workflow with multiple agents, function nodes, and routers.""" yaml_config = """ metadata: name: complex-workflow @@ -515,19 +517,19 @@ def test_from_yaml_complex_workflow(self): provider: openai name: gpt-4o-mini - tools: - - name: data_tool - - name: analysis_tool + function_nodes: + - name: data_function_node + - name: analysis_function_node workflow: start: dispatcher edges: - from: dispatcher - to: [data_tool, analysis_tool, processor] + to: [data_function_node, analysis_function_node, processor] router: dispatch_router - - from: data_tool + - from: data_function_node to: [summarizer] - - from: analysis_tool + - from: analysis_function_node to: [summarizer] - from: processor to: [summarizer] @@ -540,12 +542,15 @@ def test_from_yaml_complex_workflow(self): def dispatch_router(memory: BaseMemory) -> str: return 'processor' - mock_data_tool = Mock(spec=Tool) - mock_data_tool.name = 'data_tool' - mock_analysis_tool = Mock(spec=Tool) - mock_analysis_tool.name = 'analysis_tool' + mock_data_function_node = Mock(spec=FunctionNode) + mock_data_function_node.name = 'data_function_node' + mock_analysis_function_node = Mock(spec=FunctionNode) + mock_analysis_function_node.name = 'analysis_function_node' - tools = {'data_tool': mock_data_tool, 'analysis_tool': mock_analysis_tool} + function_nodes = { + 'data_function_node': mock_data_function_node, + 'analysis_function_node': mock_analysis_function_node, + } routers = 
{'dispatch_router': dispatch_router} with patch('flo_ai.arium.builder.AgentBuilder') as mock_agent_builder: @@ -565,12 +570,12 @@ def dispatch_router(memory: BaseMemory) -> str: mock_agent_builder.from_yaml.return_value = mock_builder_instance builder = AriumBuilder.from_yaml( - yaml_str=yaml_config, tools=tools, routers=routers + yaml_str=yaml_config, function_nodes=function_nodes, routers=routers ) # Verify all components were configured assert len(builder._agents) == 3 - assert len(builder._tools) == 2 + assert len(builder._function_nodes) == 2 assert len(builder._edges) == 4 # 4 edge definitions assert builder._start_node == mock_dispatcher assert mock_summarizer in builder._end_nodes @@ -657,21 +662,20 @@ def test_from_yaml_direct_agent_configuration(self): ) assert mock_llm.temperature == 0.5 - def test_from_yaml_direct_config_with_tools(self): - """Test direct agent configuration with tools.""" + def test_from_yaml_direct_config_with_function_nodes(self): + """Test direct agent configuration with function nodes.""" yaml_config = """ arium: agents: - name: test_agent - job: "Test agent with tools" + job: "Test agent with function nodes" model: provider: openai name: gpt-4o-mini - tools: ["calculator", "web_search"] - tools: - - name: calculator - - name: web_search + function_nodes: + - name: calculator_function_node + - name: web_search_function_node workflow: start: test_agent @@ -681,26 +685,30 @@ def test_from_yaml_direct_config_with_tools(self): end: [test_agent] """ - # Create mock tools - mock_calculator = Mock(spec=Tool) - mock_calculator.name = 'calculator' - mock_web_search = Mock(spec=Tool) - mock_web_search.name = 'web_search' + # Create mock function nodes + mock_calculator_function_node = Mock(spec=FunctionNode) + mock_calculator_function_node.name = 'calculator_function_node' + mock_web_search_function_node = Mock(spec=FunctionNode) + mock_web_search_function_node.name = 'web_search_function_node' - tools = {'calculator': mock_calculator, 
'web_search': mock_web_search} + function_nodes = { + 'calculator_function_node': mock_calculator_function_node, + 'web_search_function_node': mock_web_search_function_node, + } with patch('flo_ai.llm.OpenAI') as mock_openai: mock_llm = Mock() mock_openai.return_value = mock_llm - builder = AriumBuilder.from_yaml(yaml_str=yaml_config, tools=tools) + builder = AriumBuilder.from_yaml( + yaml_str=yaml_config, function_nodes=function_nodes + ) - # Verify agent was configured with tools + # Verify agent was configured with function nodes assert len(builder._agents) == 1 - agent = builder._agents[0] - assert len(agent.tools) == 2 - assert mock_calculator in agent.tools - assert mock_web_search in agent.tools + assert len(builder._function_nodes) == 2 + assert mock_calculator_function_node in builder._function_nodes + assert mock_web_search_function_node in builder._function_nodes def test_from_yaml_direct_config_with_parser(self): """Test direct agent configuration with structured output parser.""" @@ -1109,24 +1117,24 @@ def test_from_yaml_prebuilt_agents_parameter_validation(self): with pytest.raises(ValueError, match='Agent test_agent must have either'): AriumBuilder.from_yaml(yaml_str=yaml_config) - def test_from_yaml_prebuilt_agents_with_tools_and_routers(self): - """Test pre-built agents working together with tools and routers.""" + def test_from_yaml_prebuilt_agents_with_function_nodes_and_routers(self): + """Test pre-built agents working together with function nodes and routers.""" yaml_config = """ arium: agents: - name: dispatcher - name: processor - tools: - - name: calculator + function_nodes: + - name: calculator_function_node workflow: start: dispatcher edges: - from: dispatcher - to: [calculator, processor] + to: [calculator_function_node, processor] router: smart_router - - from: calculator + - from: calculator_function_node to: [processor] - from: processor to: [end] @@ -1139,27 +1147,30 @@ def test_from_yaml_prebuilt_agents_with_tools_and_routers(self): 
mock_processor = Mock(spec=Agent) mock_processor.name = 'processor' - mock_calculator = Mock(spec=Tool) - mock_calculator.name = 'calculator' + mock_calculator_function_node = Mock(spec=FunctionNode) + mock_calculator_function_node.name = 'calculator_function_node' def smart_router(memory): return 'processor' prebuilt_agents = {'dispatcher': mock_dispatcher, 'processor': mock_processor} - tools = {'calculator': mock_calculator} + function_nodes = {'calculator_function_node': mock_calculator_function_node} routers = {'smart_router': smart_router} builder = AriumBuilder.from_yaml( - yaml_str=yaml_config, agents=prebuilt_agents, tools=tools, routers=routers + yaml_str=yaml_config, + agents=prebuilt_agents, + function_nodes=function_nodes, + routers=routers, ) # Verify everything was configured correctly assert len(builder._agents) == 2 - assert len(builder._tools) == 1 + assert len(builder._function_nodes) == 1 assert len(builder._edges) == 2 assert mock_dispatcher in builder._agents assert mock_processor in builder._agents - assert mock_calculator in builder._tools + assert mock_calculator_function_node in builder._function_nodes if __name__ == '__main__': diff --git a/flo_ai/tests/unit-tests/test_llm_router.py b/flo_ai/tests/unit-tests/test_llm_router.py index 42fa78b8..198797a2 100644 --- a/flo_ai/tests/unit-tests/test_llm_router.py +++ b/flo_ai/tests/unit-tests/test_llm_router.py @@ -13,8 +13,9 @@ create_llm_router, llm_router, ) -from flo_ai.arium.memory import MessageMemory +from flo_ai.arium.memory import MessageMemory, MessageMemoryItem from flo_ai.llm.base_llm import BaseLLM +from flo_ai.models import UserMessage class MockLLM(BaseLLM): @@ -55,8 +56,18 @@ def format_image_in_message(self, image): def mock_memory(): """Create a mock memory with sample conversation""" memory = MessageMemory() - memory.add('I need to research market trends for renewable energy') - memory.add('Please analyze the data and provide insights') + memory.add( + MessageMemoryItem( + 
node='researcher', + result=UserMessage('I need to research market trends for renewable energy'), + ) + ) + memory.add( + MessageMemoryItem( + node='analyst', + result=UserMessage('Please analyze the data and provide insights'), + ) + ) return memory From 5094e042952c7653b2b9b770cc931ce0afa55e0c Mon Sep 17 00:00:00 2001 From: vishnu r kumar Date: Tue, 18 Nov 2025 16:37:18 +0530 Subject: [PATCH 5/7] update tests, examples -- function nodes from yaml --- flo_ai/examples/arium_yaml_example.py | 379 ++++++++++++++++-- flo_ai/examples/chat_history.py | 38 +- flo_ai/flo_ai/arium/builder.py | 66 ++- flo_ai/flo_ai/arium/nodes.py | 3 + flo_ai/flo_ai/llm/base_llm.py | 2 + flo_ai/flo_ai/models/agent.py | 29 +- flo_ai/tests/unit-tests/test_arium_builder.py | 23 -- flo_ai/tests/unit-tests/test_arium_yaml.py | 88 ++-- 8 files changed, 459 insertions(+), 169 deletions(-) diff --git a/flo_ai/examples/arium_yaml_example.py b/flo_ai/examples/arium_yaml_example.py index 46e3cc59..51c67aa1 100644 --- a/flo_ai/examples/arium_yaml_example.py +++ b/flo_ai/examples/arium_yaml_example.py @@ -6,12 +6,14 @@ """ import asyncio -from typing import Dict, Literal +from typing import Dict, Literal, Callable, Optional from flo_ai.arium.builder import AriumBuilder -from flo_ai.models import TextMessageContent, UserMessage +from flo_ai.arium.memory import MessageMemoryItem +from flo_ai.models import UserMessage from flo_ai.tool.base_tool import Tool from flo_ai.llm import OpenAI from flo_ai.arium.memory import BaseMemory +from typing import List # Example YAML configuration for a simple linear workflow (using direct agent definition) @@ -139,6 +141,141 @@ """ +# Example YAML configuration with function nodes +FUNCTION_NODES_WORKFLOW_YAML = """ +metadata: + name: function-nodes-workflow + version: 1.0.0 + description: "A workflow demonstrating function nodes for data processing" + +arium: + agents: + - name: data_analyzer + role: Data Analyzer + job: > + You are a data analyzer. 
Analyze the input data and extract key information. + Pass the analyzed data to the next processing step. + model: + provider: openai + name: gpt-4o-mini + settings: + temperature: 0.3 + + - name: report_formatter + role: Report Formatter + job: > + You are a report formatter. Take the processed data and format it into + a well-structured final report. + model: + provider: openai + name: gpt-4o-mini + settings: + temperature: 0.2 + + function_nodes: + - name: data_validator + function_name: validate_data + description: "Validates and cleans input data" + + - name: data_transformer + function_name: transform_data + description: "Transforms data into a structured format" + + workflow: + start: data_analyzer + edges: + - from: data_analyzer + to: [data_validator] + - from: data_validator + to: [data_transformer] + - from: data_transformer + to: [report_formatter] + - from: report_formatter + to: [end] + end: [report_formatter] +""" + + +# Example YAML configuration mixing agents, function nodes, and routing +MIXED_NODES_WORKFLOW_YAML = """ +metadata: + name: mixed-nodes-workflow + version: 1.0.0 + description: "A workflow mixing agents, function nodes, and conditional routing" + +arium: + agents: + - name: dispatcher + role: Workflow Dispatcher + job: > + You are a workflow dispatcher. Analyze the input and determine the appropriate + processing path: mathematical operations, text processing, or direct summarization. + model: + provider: openai + name: gpt-4o-mini + settings: + reasoning_pattern: REACT + + - name: math_agent + role: Mathematics Agent + job: > + You are a mathematics agent. Perform mathematical analysis and calculations + on the provided data. + model: + provider: openai + name: gpt-4o-mini + tools: ["calculator"] + settings: + reasoning_pattern: REACT + + - name: text_agent + role: Text Processing Agent + job: > + You are a text processing agent. Analyze and process text content. 
+ model: + provider: openai + name: gpt-4o-mini + tools: ["text_processor"] + settings: + reasoning_pattern: REACT + + - name: final_summarizer + role: Final Summarizer + job: > + You are the final summarizer. Create a comprehensive summary of all processing results. + model: + provider: openai + name: gpt-4o-mini + + function_nodes: + - name: preprocessor + function_name: preprocess_input + description: "Preprocesses input data before agent processing" + + - name: result_aggregator + function_name: aggregate_results + description: "Aggregates results from multiple processing paths" + + workflow: + start: dispatcher + edges: + - from: dispatcher + to: [preprocessor] + - from: preprocessor + to: [math_agent, text_agent, final_summarizer] + router: dispatch_router + - from: math_agent + to: [result_aggregator] + - from: text_agent + to: [result_aggregator] + - from: result_aggregator + to: [final_summarizer] + - from: final_summarizer + to: [end] + end: [final_summarizer] +""" + + # Example showing mixed configuration approaches MIXED_CONFIG_YAML = """ metadata: @@ -244,6 +381,31 @@ def research_router( return 'final_summarizer' +# Custom router function for the mixed nodes workflow +def dispatch_router( + memory: BaseMemory, +) -> Literal['math_agent', 'text_agent', 'final_summarizer']: + """ + Custom router for the mixed nodes workflow that decides the next step. 
+ """ + memory_content = memory.get() + latest_message = memory_content[-1] if memory_content else {} + content_text = str(latest_message).lower() + + if any( + keyword in content_text + for keyword in ['calculate', 'math', 'number', 'compute', 'add', 'multiply'] + ): + return 'math_agent' + elif any( + keyword in content_text + for keyword in ['text', 'analyze', 'process', 'parse', 'word'] + ): + return 'text_agent' + else: + return 'final_summarizer' + + async def create_example_tools() -> Dict[str, Tool]: """Create example tools for the workflow.""" @@ -309,6 +471,67 @@ async def process_text(text: str, operation: str = 'analyze') -> str: } +async def create_example_functions() -> Dict[str, Callable]: + """Create example functions for function nodes.""" + from typing import List, Any + + async def validate_data( + inputs: List[Any], variables: Optional[Dict[str, Any]] = None, **kwargs + ) -> str: + """Validates and cleans input data.""" + text = str(inputs[-1]) if inputs else '' + + cleaned = ' '.join(text.split()) + if not cleaned: + return 'Error: Empty input data' + + validation_result = ( + f'✓ Data validated: {len(cleaned)} characters, {len(cleaned.split())} words' + ) + return validation_result + + async def transform_data( + inputs: List[Any], variables: Optional[Dict[str, Any]] = None, **kwargs + ) -> str: + """Transforms data into a structured format.""" + text = str(inputs[-1]) if inputs else '' + + words = text.split() + transformed = f'STRUCTURED DATA:\n- Word count: {len(words)}\n- Character count: {len(text)}\n- Content: {text[:100]}...' 
+ return transformed + + async def preprocess_input( + inputs: List[Any], variables: Optional[Dict[str, Any]] = None, **kwargs + ) -> str: + """Preprocesses input data before agent processing.""" + text = str(inputs[-1]) if inputs else '' + + normalized = text.strip().lower() + preprocessed = f'[PREPROCESSED] {normalized}' + return preprocessed + + async def aggregate_results( + inputs: List[Any], variables: Optional[Dict[str, Any]] = None, **kwargs + ) -> str: + """Aggregates results from multiple processing paths.""" + results = [] + for inp in inputs: + if hasattr(inp, 'content'): + results.append(str(inp.content)) + else: + results.append(str(inp)) + + aggregated = 'AGGREGATED RESULTS:\n' + '\n'.join(f'- {r}' for r in results) + return aggregated + + return { + 'validate_data': validate_data, + 'transform_data': transform_data, + 'preprocess_input': preprocess_input, + 'aggregate_results': aggregate_results, + } + + async def run_simple_example(): """Run the simple workflow example.""" print('=' * 60) @@ -319,22 +542,20 @@ async def run_simple_example(): builder = AriumBuilder.from_yaml(yaml_str=SIMPLE_WORKFLOW_YAML) # Run the workflow - result = await builder.build_and_run( + result: List[MessageMemoryItem] = await builder.build_and_run( [ UserMessage( - TextMessageContent( - text='Machine learning is transforming healthcare by enabling predictive analytics, ' - 'personalized treatment recommendations, and automated medical imaging analysis. ' - 'However, challenges include data privacy concerns, the need for regulatory approval, ' - 'and ensuring AI systems are transparent and unbiased in their decision-making.', - ) - ), + 'Machine learning is transforming healthcare by enabling predictive analytics, ' + 'personalized treatment recommendations, and automated medical imaging analysis. 
' + 'However, challenges include data privacy concerns, the need for regulatory approval, ' + 'and ensuring AI systems are transparent and unbiased in their decision-making.', + ) ] ) print('Result:') - for i, message in enumerate(result): - print(f'{i+1}. {message}') + for i, message in enumerate[MessageMemoryItem](result): + print(f'{i+1}. {message.result.content}') return result @@ -353,7 +574,7 @@ async def run_complex_example(): # Create builder from YAML builder = AriumBuilder.from_yaml( - yaml_str=COMPLEX_WORKFLOW_YAML, tools=tools, routers=routers + yaml_str=COMPLEX_WORKFLOW_YAML, tool_registry=tools, routers=routers ) # Test with mathematical content @@ -363,34 +584,30 @@ async def run_complex_example(): ) print('Result:') - for i, message in enumerate(result1): - print(f'{i+1}. {message}') + for i, message in enumerate[MessageMemoryItem](result1): + print(f'{i+1}. {message.result.content}') # Reset and test with text content print('\nTesting with text content:') builder.reset() builder = AriumBuilder.from_yaml( - yaml_str=COMPLEX_WORKFLOW_YAML, tools=tools, routers=routers + yaml_str=COMPLEX_WORKFLOW_YAML, tool_registry=tools, routers=routers ) result2 = await builder.build_and_run( [ UserMessage( - TextMessageContent( - text="Please analyze this text and process it: 'The quick brown fox jumps over the lazy dog. ", - ) + "Please analyze this text and process it: 'The quick brown fox jumps over the lazy dog.'" ), UserMessage( - TextMessageContent( - text="This sentence contains every letter of the alphabet at least once.'", - ) + "This sentence contains every letter of the alphabet at least once.'" ), ] ) print('Result:') - for i, message in enumerate(result2): - print(f'{i+1}. {message}') + for i, message in enumerate[MessageMemoryItem](result2): + print(f'{i+1}. 
{message.result.content}') return result1, result2 @@ -408,23 +625,107 @@ async def run_mixed_config_example(): result = await builder.build_and_run( [ UserMessage( - TextMessageContent( - text='Please analyze this business report: Our Q3 revenue increased by 15% compared to Q2, ' - 'driven primarily by strong performance in the software division. However, hardware sales ' - 'declined by 8%. Customer satisfaction scores improved to 4.2/5.0. We recommend focusing ' - 'on digital transformation initiatives and reconsidering the hardware product line.', - ) + 'Please analyze this business report: Our Q3 revenue increased by 15% compared to Q2, ' + 'driven primarily by strong performance in the software division. However, hardware sales ' + 'declined by 8%. Customer satisfaction scores improved to 4.2/5.0. We recommend focusing ' + 'on digital transformation initiatives and reconsidering the hardware product line.' + ), + ] + ) + + print('Result:') + for i, message in enumerate[MessageMemoryItem](result): + print(f'{i+1}. {message.result.content}') + + return result + + +async def run_function_nodes_example(): + """Run the function nodes workflow example.""" + print('\n' + '=' * 60) + print('RUNNING FUNCTION NODES WORKFLOW EXAMPLE') + print('=' * 60) + + # Create function registry + functions = await create_example_functions() + + # Create builder from YAML with function registry + builder = AriumBuilder.from_yaml( + yaml_str=FUNCTION_NODES_WORKFLOW_YAML, function_registry=functions + ) + + # Run the workflow + result = await builder.build_and_run( + [ + UserMessage( + 'Sample data for processing: Customer satisfaction scores show 85% positive feedback, ' + 'with response times averaging 2.3 minutes. Revenue increased by 12% this quarter.' ), ] ) print('Result:') - for i, message in enumerate(result): - print(f'{i+1}. {message}') + for i, message in enumerate[MessageMemoryItem](result): + print(f'{i+1}. 
{message.result.content}') return result +async def run_mixed_nodes_example(): + """Run the mixed nodes workflow example with agents, function nodes, and routing.""" + print('\n' + '=' * 60) + print('RUNNING MIXED NODES WORKFLOW EXAMPLE') + print('=' * 60) + + # Create tools and functions + tools = await create_example_tools() + functions = await create_example_functions() + routers = {'dispatch_router': dispatch_router} + + # Create builder from YAML + builder = AriumBuilder.from_yaml( + yaml_str=MIXED_NODES_WORKFLOW_YAML, + tool_registry=tools, + function_registry=functions, + routers=routers, + ) + + # Test with mathematical content + print('\nTesting with mathematical content:') + result1 = await builder.build_and_run( + ['Please calculate the sum of 15 and 27, then multiply by 2.'] + ) + + print('Result:') + for i, message in enumerate[MessageMemoryItem](result1): + print(f'{i+1}. {message.result.content}') + + # Reset and test with text content + print('\nTesting with text content:') + builder.reset() + builder = AriumBuilder.from_yaml( + yaml_str=MIXED_NODES_WORKFLOW_YAML, + tool_registry=tools, + function_registry=functions, + routers=routers, + ) + + result2 = await builder.build_and_run( + [ + UserMessage( + 'Please analyze and process this text: Machine learning algorithms ' + 'are transforming how we process and understand data.' + ), + ] + ) + + print('Result:') + for i, message in enumerate[MessageMemoryItem](result2): + print(f'{i+1}. {message.result.content}') + + return result1, result2 + + async def run_prebuilt_agents_example(): """Run example using pre-built agents with YAML workflow.""" print('\n' + '=' * 60) @@ -530,8 +831,8 @@ async def run_prebuilt_agents_example(): ) print('Result:') - for i, message in enumerate(result): - print(f'{i+1}. {message}') + for i, message in enumerate[MessageMemoryItem](result): + print(f'{i+1}. 
{message.result.content}') return result @@ -548,6 +849,12 @@ async def main(): # Run mixed configuration example await run_mixed_config_example() + # Run function nodes example + await run_function_nodes_example() + + # Run mixed nodes example + await run_mixed_nodes_example() + # Run pre-built agents example await run_prebuilt_agents_example() @@ -558,7 +865,9 @@ async def main(): print(' • Simple linear workflow with direct agent configuration') print(' • Complex workflow with tools and conditional routing') print(' • Mixed configuration approaches') - print(' • Pre-built agents with YAML workflow (NEW!)') + print(' • Function nodes workflow (NEW!)') + print(' • Mixed nodes workflow with agents and function nodes (NEW!)') + print(' • Pre-built agents with YAML workflow') except Exception as e: print(f'Error running examples: {e}') diff --git a/flo_ai/examples/chat_history.py b/flo_ai/examples/chat_history.py index 4009f525..12c32fab 100644 --- a/flo_ai/examples/chat_history.py +++ b/flo_ai/examples/chat_history.py @@ -1,13 +1,9 @@ import asyncio -from typing import Any +from typing import List from flo_ai.builder.agent_builder import AgentBuilder from flo_ai.llm import Gemini from flo_ai.models.agent import Agent -from flo_ai.models import ( - AssistantMessage, - UserMessage, - TextMessageContent, -) +from flo_ai.models import AssistantMessage, UserMessage, BaseMessage from flo_ai.tool import flo_tool @@ -34,37 +30,21 @@ async def main() -> None: .build() ) - response: Any = await agent.run( + response: List[BaseMessage] = await agent.run( [ - UserMessage( - TextMessageContent( - text='What is the formula for the area of a circle?' - ), - ), - AssistantMessage( - TextMessageContent( - text='The formula for the area of a circle is πr^2.' - ), - ), - UserMessage( - TextMessageContent( - text='What is the formula for the area of a rectangle?' 
- ) - ), + UserMessage('What is the formula for the area of a circle?'), + AssistantMessage('The formula for the area of a circle is πr^2.'), + UserMessage('What is the formula for the area of a rectangle?'), AssistantMessage( - TextMessageContent( - text='The formula for the area of a rectangle is length * width.', - ), + 'The formula for the area of a rectangle is length * width.' ), UserMessage( - TextMessageContent( - text='What is the area of a rectable of length and breadth ', - ), + 'What is the area of a rectable of length and breadth ' ), ], variables={'length': 10, 'breadth': 70}, ) - print(f'Response: {response}') + print(f'Response: {response[-1].content}') asyncio.run(main()) diff --git a/flo_ai/flo_ai/arium/builder.py b/flo_ai/flo_ai/arium/builder.py index 950d5d61..1da67333 100644 --- a/flo_ai/flo_ai/arium/builder.py +++ b/flo_ai/flo_ai/arium/builder.py @@ -1,6 +1,6 @@ from typing import List, Optional, Callable, Union, Dict, Any from flo_ai.arium.arium import Arium -from flo_ai.arium.memory import MessageMemory, BaseMemory +from flo_ai.arium.memory import MessageMemory, BaseMemory, MessageMemoryItem from flo_ai.arium.protocols import ExecutableNode from flo_ai.arium.nodes import AriumNode, ForEachNode from flo_ai.models import BaseMessage, UserMessage @@ -60,12 +60,12 @@ def add_agents(self, agents: List[Agent]) -> 'AriumBuilder': return self def add_function_node(self, function_node: FunctionNode) -> 'AriumBuilder': - """Add a tool node to the Arium.""" + """Add a function node to the Arium.""" self._function_nodes.append(function_node) return self def add_function_nodes(self, function_nodes: List[FunctionNode]) -> 'AriumBuilder': - """Add multiple tool nodes to the Arium.""" + """Add multiple function nodes to the Arium.""" self._function_nodes.extend(function_nodes) return self @@ -244,7 +244,7 @@ async def build_and_run( self, inputs: List[BaseMessage] | str, variables: Optional[Dict[str, Any]] = None, - ) -> List[dict]: + ) -> 
List[MessageMemoryItem]: """Build the Arium and run it with the given inputs and optional runtime variables.""" arium = self.build() new_inputs = [] @@ -287,9 +287,10 @@ def from_yaml( yaml_file: Optional[str] = None, memory: Optional[BaseMemory] = None, agents: Optional[Dict[str, Agent]] = None, - function_nodes: Optional[Dict[str, FunctionNode]] = None, routers: Optional[Dict[str, Callable]] = None, base_llm: Optional[BaseLLM] = None, + function_registry: Optional[Dict[str, Callable]] = None, + tool_registry: Optional[Dict[str, Tool]] = None, ) -> 'AriumBuilder': """Create an AriumBuilder from a YAML configuration. @@ -298,10 +299,10 @@ def from_yaml( yaml_file: Path to YAML file containing arium configuration memory: Memory instance to use for the workflow (defaults to MessageMemory) agents: Dictionary mapping agent names to pre-built Agent instances - function_nodes: Dictionary mapping function names to FunctionNode instances routers: Dictionary mapping router names to router functions base_llm: Base LLM to use for all agents if not specified in individual agent configs - + function_registry: Dictionary mapping function names to function objects + tool_registry: Dictionary mapping tool names to Tool objects Returns: AriumBuilder: Configured builder instance @@ -482,19 +483,25 @@ def from_yaml( and 'yaml_config' not in agent_config and 'yaml_file' not in agent_config ): - agent = cls._create_agent_from_direct_config(agent_config, base_llm) + agent = cls._create_agent_from_direct_config( + agent_config, base_llm, tool_registry + ) # Method 3: Inline YAML config elif 'yaml_config' in agent_config: agent_builder = AgentBuilder.from_yaml( - yaml_str=agent_config['yaml_config'], base_llm=base_llm + yaml_str=agent_config['yaml_config'], + base_llm=base_llm, + tool_registry=tool_registry, ) agent = agent_builder.build() # Method 4: External file reference elif 'yaml_file' in agent_config: agent_builder: AgentBuilder = AgentBuilder.from_yaml( - 
yaml_file=agent_config['yaml_file'], base_llm=base_llm + yaml_file=agent_config['yaml_file'], + base_llm=base_llm, + tool_registry=tool_registry, ) agent = agent_builder.build() @@ -510,38 +517,29 @@ def from_yaml( agents_dict[agent_name] = agent builder.add_agent(agent) - # Process tool nodes + # Process function nodes function_nodes_config = arium_config.get('function_nodes', []) function_nodes_dict = {} for function_node_config in function_nodes_config: function_node_name = function_node_config['name'] + function_name = function_node_config['function_name'] + function = function_registry.get(function_name) - # Add a function node from pre-built function nodes - if len(function_node_config) == 1 and 'name' in function_node_config: - if function_nodes and function_node_name in function_nodes: - function_node = function_nodes[function_node_name] - else: - raise ValueError( - f'FunctionNode {function_node_name} not found in provided function_nodes dictionary. ' - f'Available function_nodes: {list(function_nodes.keys()) if function_nodes else []}. ' - f'Either provide the FunctionNode in the function_nodes parameter or add configuration fields.' - ) - else: - # Add a function node from a direct FunctionNode definition (function must be provided in code, YAML cannot define callables) - function = function_node_config.get('function') - if function is None: - ValueError( - f'Function for FunctionNode {function_node_name} is not provided' - ) - - function_node = FunctionNode( - name=function_node_name, - description=function_node_config.get('description', ''), - function=function, - input_filter=function_node_config.get('input_filter', None), + if function is None: + raise ValueError( + f'Function {function_name} not found in provided function_registry dictionary. ' + f'Available functions: {list[str](function_registry.keys()) if function_registry else []}. ' + f'Either provide the function in the function_registry parameter or add configuration fields.' 
) + function_node = FunctionNode( + name=function_node_name, + description=function_node_config.get('description', ''), + function=function, + input_filter=function_node_config.get('input_filter', None), + ) + function_nodes_dict[function_node_name] = function_node builder.add_function_node(function_node) diff --git a/flo_ai/flo_ai/arium/nodes.py b/flo_ai/flo_ai/arium/nodes.py index 79c89ea9..e463cb27 100644 --- a/flo_ai/flo_ai/arium/nodes.py +++ b/flo_ai/flo_ai/arium/nodes.py @@ -187,12 +187,15 @@ async def run( ) if asyncio.iscoroutinefunction(self.function): + logger.info(f"Executing FunctionNode '{self.name}' as a coroutine function") result = await self.function(inputs=inputs, variables=variables, **kwargs) return UserMessage(content=result) if asyncio.iscoroutine(result): + logger.info(f"Executing FunctionNode '{self.name}' as a coroutine") content = await result return UserMessage(content=content) + logger.info(f"Executing FunctionNode '{self.name}' as a regular function") result = self.function(inputs=inputs, variables=variables, **kwargs) return UserMessage(content=result) diff --git a/flo_ai/flo_ai/llm/base_llm.py b/flo_ai/flo_ai/llm/base_llm.py index 15a68fd6..75864bc6 100644 --- a/flo_ai/flo_ai/llm/base_llm.py +++ b/flo_ai/flo_ai/llm/base_llm.py @@ -20,6 +20,7 @@ async def generate( self, messages: List[Dict[str, str]], functions: Optional[List[Dict[str, Any]]] = None, + output_schema: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Generate a response from the LLM""" pass @@ -29,6 +30,7 @@ async def stream( self, messages: List[Dict[str, str]], functions: Optional[List[Dict[str, Any]]] = None, + output_schema: Optional[Dict[str, Any]] = None, ) -> AsyncIterator[Dict[str, Any]]: """Stream partial responses from the LLM as they are generated""" pass diff --git a/flo_ai/flo_ai/models/agent.py b/flo_ai/flo_ai/models/agent.py index 660fc2be..092f3bd1 100644 --- a/flo_ai/flo_ai/models/agent.py +++ b/flo_ai/flo_ai/models/agent.py @@ -155,24 +155,31 
@@ async def _run_conversational( assistant_message = self.llm.get_message_content(response) logger.debug(f'Extracted message: {assistant_message}') + # Ensure act_as is not None (default to 'assistant' if missing) + role = self.act_as if self.act_as is not None else MessageType.ASSISTANT + if assistant_message: - # Ensure act_as is not None (default to 'assistant' if missing) - role = ( - self.act_as - if self.act_as is not None - else MessageType.ASSISTANT - ) self.add_to_history( AssistantMessage(role=role, content=assistant_message) ) - - return self.conversation_history else: possible_tool_message = await self.llm.get_function_call(response) if possible_tool_message: - return possible_tool_message['arguments'] - logger.debug('Warning: No message content found in response') - return None + self.add_to_history( + AssistantMessage( + role=role, content=possible_tool_message['arguments'] + ) + ) + else: + logger.debug('Warning: No message content found in response') + self.add_to_history( + AssistantMessage( + role=role, + content='No message content found in response', + ) + ) + + return self.conversation_history except Exception as e: retry_count += 1 diff --git a/flo_ai/tests/unit-tests/test_arium_builder.py b/flo_ai/tests/unit-tests/test_arium_builder.py index d85cae39..edd85c59 100644 --- a/flo_ai/tests/unit-tests/test_arium_builder.py +++ b/flo_ai/tests/unit-tests/test_arium_builder.py @@ -180,33 +180,10 @@ def test_build_validation_no_end_nodes(self): with pytest.raises(ValueError, match='No end nodes specified'): builder.build() - def test_method_chaining(self): - """Test that all methods return self for chaining""" - builder = AriumBuilder() - agent = Mock(spec=Agent) - agent.name = 'test_agent' - function_node = Mock(spec=FunctionNode) - function_node.name = 'test_function_node' - memory = Mock(spec=MessageMemory) - - # This should not raise any errors and should work with chaining - result = ( - builder.with_memory(memory) - .add_agent(agent) - 
.add_function_node(function_node) - .start_with(agent) - .connect(agent, function_node) - .end_with(function_node) - .reset() - ) - - assert result is builder - if __name__ == '__main__': # Run a simple test test_builder = TestAriumBuilder() test_builder.test_builder_initialization() test_builder.test_add_agent() - test_builder.test_method_chaining() print('Basic tests passed!') diff --git a/flo_ai/tests/unit-tests/test_arium_yaml.py b/flo_ai/tests/unit-tests/test_arium_yaml.py index 24caac76..d47b011e 100644 --- a/flo_ai/tests/unit-tests/test_arium_yaml.py +++ b/flo_ai/tests/unit-tests/test_arium_yaml.py @@ -8,7 +8,6 @@ from flo_ai.arium.builder import AriumBuilder from flo_ai.arium.memory import MessageMemory, BaseMemory from flo_ai.models.agent import Agent -from flo_ai.arium.nodes import FunctionNode from flo_ai.llm import OpenAI @@ -157,7 +156,7 @@ def test_from_yaml_default_memory(self): assert isinstance(builder._memory, MessageMemory) def test_from_yaml_with_function_nodes(self): - """Test YAML configuration with function nodes.""" + """Test YAML configuration with function nodes using function_registry.""" yaml_config = """ arium: agents: @@ -172,6 +171,8 @@ def test_from_yaml_with_function_nodes(self): function_nodes: - name: test_function_node + function_name: test_function + description: "Test function node" workflow: start: test_agent @@ -183,10 +184,11 @@ def test_from_yaml_with_function_nodes(self): end: [test_function_node] """ - # Create mock function node - mock_function_node = Mock(spec=FunctionNode) - mock_function_node.name = 'test_function_node' - function_nodes = {'test_function_node': mock_function_node} + # Create mock function + async def test_function(inputs): + return 'processed' + + function_registry = {'test_function': test_function} with patch('flo_ai.arium.builder.AgentBuilder') as mock_agent_builder: mock_agent = Mock(spec=Agent) @@ -197,12 +199,13 @@ def test_from_yaml_with_function_nodes(self): 
mock_agent_builder.from_yaml.return_value = mock_builder_instance builder = AriumBuilder.from_yaml( - yaml_str=yaml_config, function_nodes=function_nodes + yaml_str=yaml_config, function_registry=function_registry ) # Verify function nodes were added assert len(builder._function_nodes) == 1 - assert builder._function_nodes[0] == mock_function_node + assert builder._function_nodes[0].name == 'test_function_node' + assert builder._function_nodes[0].function == test_function def test_from_yaml_with_routers(self): """Test YAML configuration with custom routers.""" @@ -262,8 +265,8 @@ def test_router(memory: BaseMemory) -> str: assert to_nodes == [mock_agent2] assert router == test_router - def test_from_yaml_missing_function_node_error(self): - """Test error when referenced function node is not provided.""" + def test_from_yaml_missing_function_in_registry_error(self): + """Test error when referenced function is not in function_registry.""" yaml_config = """ arium: agents: @@ -278,6 +281,7 @@ def test_from_yaml_missing_function_node_error(self): function_nodes: - name: missing_function_node + function_name: missing_function workflow: start: test_agent @@ -292,9 +296,9 @@ def test_from_yaml_missing_function_node_error(self): with patch('flo_ai.arium.builder.AgentBuilder'): with pytest.raises( ValueError, - match='FunctionNode missing_function_node not found in provided function_nodes dictionary', + match='Function missing_function not found in provided function_registry dictionary', ): - AriumBuilder.from_yaml(yaml_str=yaml_config, function_nodes={}) + AriumBuilder.from_yaml(yaml_str=yaml_config, function_registry={}) def test_from_yaml_missing_router_error(self): """Test error when referenced router is not provided.""" @@ -434,7 +438,7 @@ def test_from_yaml_external_file_reference(self): # Verify AgentBuilder.from_yaml was called with yaml_file mock_agent_builder.from_yaml.assert_called_with( - yaml_file='path/to/agent.yaml', base_llm=None + 
yaml_file='path/to/agent.yaml', base_llm=None, tool_registry=None ) def test_from_yaml_with_base_llm(self): @@ -519,7 +523,9 @@ def test_from_yaml_complex_workflow(self): function_nodes: - name: data_function_node + function_name: data_function - name: analysis_function_node + function_name: analysis_function workflow: start: dispatcher @@ -542,14 +548,15 @@ def test_from_yaml_complex_workflow(self): def dispatch_router(memory: BaseMemory) -> str: return 'processor' - mock_data_function_node = Mock(spec=FunctionNode) - mock_data_function_node.name = 'data_function_node' - mock_analysis_function_node = Mock(spec=FunctionNode) - mock_analysis_function_node.name = 'analysis_function_node' + async def data_function(inputs): + return 'data processed' + + async def analysis_function(inputs): + return 'analysis done' - function_nodes = { - 'data_function_node': mock_data_function_node, - 'analysis_function_node': mock_analysis_function_node, + function_registry = { + 'data_function': data_function, + 'analysis_function': analysis_function, } routers = {'dispatch_router': dispatch_router} @@ -570,7 +577,9 @@ def dispatch_router(memory: BaseMemory) -> str: mock_agent_builder.from_yaml.return_value = mock_builder_instance builder = AriumBuilder.from_yaml( - yaml_str=yaml_config, function_nodes=function_nodes, routers=routers + yaml_str=yaml_config, + function_registry=function_registry, + routers=routers, ) # Verify all components were configured @@ -675,7 +684,9 @@ def test_from_yaml_direct_config_with_function_nodes(self): function_nodes: - name: calculator_function_node + function_name: calculator - name: web_search_function_node + function_name: web_search workflow: start: test_agent @@ -685,15 +696,16 @@ def test_from_yaml_direct_config_with_function_nodes(self): end: [test_agent] """ - # Create mock function nodes - mock_calculator_function_node = Mock(spec=FunctionNode) - mock_calculator_function_node.name = 'calculator_function_node' - mock_web_search_function_node = 
Mock(spec=FunctionNode) - mock_web_search_function_node.name = 'web_search_function_node' + # Create mock functions + async def calculator(inputs): + return 'calculated' + + async def web_search(inputs): + return 'searched' - function_nodes = { - 'calculator_function_node': mock_calculator_function_node, - 'web_search_function_node': mock_web_search_function_node, + function_registry = { + 'calculator': calculator, + 'web_search': web_search, } with patch('flo_ai.llm.OpenAI') as mock_openai: @@ -701,14 +713,15 @@ def test_from_yaml_direct_config_with_function_nodes(self): mock_openai.return_value = mock_llm builder = AriumBuilder.from_yaml( - yaml_str=yaml_config, function_nodes=function_nodes + yaml_str=yaml_config, function_registry=function_registry ) # Verify agent was configured with function nodes assert len(builder._agents) == 1 assert len(builder._function_nodes) == 2 - assert mock_calculator_function_node in builder._function_nodes - assert mock_web_search_function_node in builder._function_nodes + function_node_names = [fn.name for fn in builder._function_nodes] + assert 'calculator_function_node' in function_node_names + assert 'web_search_function_node' in function_node_names def test_from_yaml_direct_config_with_parser(self): """Test direct agent configuration with structured output parser.""" @@ -1127,6 +1140,7 @@ def test_from_yaml_prebuilt_agents_with_function_nodes_and_routers(self): function_nodes: - name: calculator_function_node + function_name: calculator workflow: start: dispatcher @@ -1147,20 +1161,20 @@ def test_from_yaml_prebuilt_agents_with_function_nodes_and_routers(self): mock_processor = Mock(spec=Agent) mock_processor.name = 'processor' - mock_calculator_function_node = Mock(spec=FunctionNode) - mock_calculator_function_node.name = 'calculator_function_node' + async def calculator(inputs): + return 'calculated' def smart_router(memory): return 'processor' prebuilt_agents = {'dispatcher': mock_dispatcher, 'processor': mock_processor} - 
function_nodes = {'calculator_function_node': mock_calculator_function_node} + function_registry = {'calculator': calculator} routers = {'smart_router': smart_router} builder = AriumBuilder.from_yaml( yaml_str=yaml_config, agents=prebuilt_agents, - function_nodes=function_nodes, + function_registry=function_registry, routers=routers, ) @@ -1170,7 +1184,7 @@ def smart_router(memory): assert len(builder._edges) == 2 assert mock_dispatcher in builder._agents assert mock_processor in builder._agents - assert mock_calculator_function_node in builder._function_nodes + assert builder._function_nodes[0].name == 'calculator_function_node' if __name__ == '__main__': From 4b5edfcea894230d8f6fa22c701c91ddd5fa54ab Mon Sep 17 00:00:00 2001 From: vishnu r kumar Date: Tue, 18 Nov 2025 16:56:53 +0530 Subject: [PATCH 6/7] fix ruff formatting --- flo_ai/flo_ai/arium/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flo_ai/flo_ai/arium/base.py b/flo_ai/flo_ai/arium/base.py index e67b71af..08f40691 100644 --- a/flo_ai/flo_ai/arium/base.py +++ b/flo_ai/flo_ai/arium/base.py @@ -14,7 +14,9 @@ class BaseArium: def __init__(self): self.start_node_name = '__start__' self.end_node_names: set = set() # Support multiple end nodes - self.nodes: Dict[str, ExecutableNode | StartNode | EndNode] = dict[str, ExecutableNode | StartNode | EndNode]() + self.nodes: Dict[str, ExecutableNode | StartNode | EndNode] = dict[ + str, ExecutableNode | StartNode | EndNode + ]() self.edges: Dict[str, Edge] = dict[str, Edge]() def add_nodes(self, agents: List[ExecutableNode | StartNode | EndNode]): From feec144e3db55ed2337a33d6709b6e560993e3d8 Mon Sep 17 00:00:00 2001 From: vishnu r kumar Date: Tue, 18 Nov 2025 17:05:57 +0530 Subject: [PATCH 7/7] add tool registry and function registry in nested builder --- flo_ai/flo_ai/arium/builder.py | 10 ++++++++-- flo_ai/pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/flo_ai/flo_ai/arium/builder.py 
b/flo_ai/flo_ai/arium/builder.py index 6855d216..8ee349da 100644 --- a/flo_ai/flo_ai/arium/builder.py +++ b/flo_ai/flo_ai/arium/builder.py @@ -345,7 +345,11 @@ def from_yaml( function_nodes: - name: function1 + function_name: function1 - name: function2 + function_name: function2 + description: "Function 2" + input_filter: ["input1", "input2"] # LLM Router definitions (NEW) routers: @@ -663,9 +667,10 @@ def from_yaml( yaml_file=yaml_file_path, memory=None, agents=None, - function_nodes=None, routers=None, base_llm=base_llm, + function_registry=None, + tool_registry=None, ) nested_arium = nested_builder.build() @@ -689,9 +694,10 @@ def from_yaml( yaml_str=yaml.dump(sub_config), memory=None, agents=None, - function_nodes=None, routers=None, base_llm=base_llm, + function_registry=None, + tool_registry=None, ) nested_arium = nested_builder.build() diff --git a/flo_ai/pyproject.toml b/flo_ai/pyproject.toml index 0fbe710b..7334a634 100644 --- a/flo_ai/pyproject.toml +++ b/flo_ai/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "flo_ai" -version = "1.0.7-rc5" +version = "1.0.8-rc3" description = "A easy way to create structured AI agents" authors = [{ name = "rootflo", email = "*@rootflo.ai" }] requires-python = ">=3.10,<4.0"