diff --git a/flo_ai/flo_ai/arium/arium.py b/flo_ai/flo_ai/arium/arium.py
index feb7a686..4ebab1ab 100644
--- a/flo_ai/flo_ai/arium/arium.py
+++ b/flo_ai/flo_ai/arium/arium.py
@@ -50,11 +50,7 @@ async def run(
             List of workflow execution results
         """
         if isinstance(inputs, str):
-            inputs = [
-                UserMessage(
-                    TextMessageContent(text=resolve_variables(inputs, variables))
-                )
-            ]
+            inputs = [UserMessage(content=resolve_variables(inputs, variables))]

         if not self.is_compiled:
             raise ValueError('Arium is not compiled')
@@ -475,13 +471,19 @@ async def _execute_node(
         elif isinstance(node, FunctionNode):
             result = await node.run(inputs, variables=None)
         elif isinstance(node, ForEachNode):
-            result = await node.run(
+            foreach_results: List[
+                MessageMemoryItem | BaseMessage
+            ] = await node.run(
                 inputs,
                 variables=variables,
             )
+            result = self._flatten_results(foreach_results)
         elif isinstance(node, AriumNode):
             # AriumNode execution
-            result = await node.run(inputs, variables=variables)
+            arium_result: List[MessageMemoryItem] = await node.run(
+                inputs, variables=variables
+            )
+            result = self._flatten_results(arium_result)
         elif isinstance(node, StartNode):
             result = None
         elif isinstance(node, EndNode):
@@ -556,12 +558,18 @@ async def _execute_node(
         elif isinstance(node, FunctionNode):
             result = await node.run(inputs, variables=None)
         elif isinstance(node, ForEachNode):
-            result = await node.run(
+            foreach_results: List[
+                MessageMemoryItem | BaseMessage
+            ] = await node.run(
                 inputs,
                 variables=variables,
             )
+            result = self._flatten_results(foreach_results)
         elif isinstance(node, AriumNode):
-            result = await node.run(inputs, variables=variables)
+            arium_result: List[MessageMemoryItem] = await node.run(
+                inputs, variables=variables
+            )
+            result = self._flatten_results(arium_result)
         elif isinstance(node, StartNode):
             result = None
         elif isinstance(node, EndNode):
@@ -602,6 +610,23 @@ async def _execute_node(
             # Re-raise the exception
             raise e

+    def _flatten_results(
+        self, sequence: List[MessageMemoryItem | BaseMessage | str]
+    ) -> List[BaseMessage | str]:
+        """
+        Flatten a sequence of results by extracting .result from MessageMemoryItem instances.
+
+        Args:
+            sequence: List of items that may be MessageMemoryItem, BaseMessage, or str
+
+        Returns:
+            List of BaseMessage or str with MessageMemoryItem layers removed
+        """
+        return [
+            item.result if isinstance(item, MessageMemoryItem) else item
+            for item in sequence
+        ]
+
     def _add_to_memory(self, message: MessageMemoryItem):
         """
         Store message in memory
diff --git a/flo_ai/flo_ai/arium/base.py b/flo_ai/flo_ai/arium/base.py
index 08f40691..19dbc900 100644
--- a/flo_ai/flo_ai/arium/base.py
+++ b/flo_ai/flo_ai/arium/base.py
@@ -6,6 +6,7 @@
 from flo_ai.tool.base_tool import Tool
 from flo_ai.utils.logger import logger
 from typing import List, Optional, Callable, Literal, get_origin, get_args, Dict
+from collections.abc import Awaitable as AwaitableABC
 from flo_ai.arium.models import StartNode, EndNode, Edge, default_router
 from pathlib import Path

@@ -59,9 +60,21 @@ def _check_router_return_type(self, router: Callable) -> Optional[List]:
         if return_annotation == inspect.Signature.empty:
             return None

-        # Check if the return type is a Literal
+        # Check if the return type is a Literal or Awaitable[Literal[...]]
         origin = get_origin(return_annotation)

+        # Handle Awaitable[Literal[...]] for async router functions
+        if origin is AwaitableABC:
+            # Unwrap the Awaitable to get the inner type
+            args = get_args(return_annotation)
+            if args:
+                inner_type = args[0]
+                inner_origin = get_origin(inner_type)
+                if inner_origin is Literal:
+                    # Extract the literal values from the inner Literal type
+                    literal_values = list(get_args(inner_type))
+                    return literal_values
+
         # In Python 3.8+, Literal types have get_origin() return typing.Literal
         if origin is Literal:
             # Extract the literal values
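Note: the Awaitable unwrapping added above can be sanity-checked in isolation. A minimal sketch, assuming an illustrative router function and option names that are not part of the library, showing how get_origin/get_args peel Awaitable[Literal[...]] apart exactly the way _check_router_return_type now does:

import inspect
from collections.abc import Awaitable as AwaitableABC
from typing import Awaitable, Literal, get_args, get_origin

async def demo_router(memory) -> Awaitable[Literal['analyst', 'writer']]:
    return 'analyst'  # illustrative only

annotation = inspect.signature(demo_router).return_annotation
origin = get_origin(annotation)   # collections.abc.Awaitable
assert origin is AwaitableABC
inner = get_args(annotation)[0]   # Literal['analyst', 'writer']
assert get_origin(inner) is Literal
print(list(get_args(inner)))      # ['analyst', 'writer']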
diff --git a/flo_ai/flo_ai/arium/llm_router.py b/flo_ai/flo_ai/arium/llm_router.py
index 24bc1b7f..ea721861 100644
--- a/flo_ai/flo_ai/arium/llm_router.py
+++ b/flo_ai/flo_ai/arium/llm_router.py
@@ -6,9 +6,14 @@
 """

 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Callable, Any, Union, get_args, List
+from typing import Dict, Optional, Callable, Any, Union, get_args, List, Awaitable
 from functools import wraps
-from flo_ai.arium.memory import BaseMemory, ExecutionPlan, StepStatus
+from flo_ai.arium.memory import (
+    ExecutionPlan,
+    StepStatus,
+    MessageMemory,
+    MessageMemoryItem,
+)
 from flo_ai.llm.base_llm import BaseLLM
 from flo_ai.llm import OpenAI
 from flo_ai.utils.logger import logger
@@ -57,7 +62,7 @@ def get_routing_options(self) -> Dict[str, str]:
     @abstractmethod
     def get_routing_prompt(
         self,
-        memory: BaseMemory,
+        memory: MessageMemory,
         options: Dict[str, str],
         execution_context: dict = None,
     ) -> str:
@@ -96,7 +101,7 @@ def get_fallback_route(self, options: Dict[str, str]) -> str:
         else:
             return routes[0]

-    async def route(self, memory: BaseMemory, execution_context: dict = None) -> str:
+    async def route(self, memory: MessageMemory, execution_context: dict = None) -> str:
         """
         Make a routing decision using the LLM.

@@ -180,19 +185,15 @@ def get_routing_options(self) -> Dict[str, str]:

     def get_routing_prompt(
         self,
-        memory: BaseMemory,
+        memory: MessageMemory,
         options: Dict[str, str],
         execution_context: dict = None,
     ) -> str:
-        conversation = memory.get()
+        conversation: List[MessageMemoryItem] = memory.get()

-        # Format conversation history with smart truncation
-        if isinstance(conversation, list):
-            # Start with last message and add more if we have space
-            messages = conversation[-5:]  # Last 5 messages
-            conversation_text = self._truncate_conversation_for_tokens(messages)
-        else:
-            conversation_text = str(conversation)
+        conversation_text = self._truncate_conversation_for_tokens(
+            [f'{item.node}: {item.result.content}' for item in conversation[-5:]]
+        )

         # Format options
         options_text = '\n'.join(
@@ -245,7 +246,7 @@ def get_routing_prompt(
         return prompt

     def _truncate_conversation_for_tokens(
-        self, messages: List[Any], max_tokens: int = 128000
+        self, messages: List[str], max_tokens: int = 128000
     ) -> str:
         """
         Intelligently truncate conversation to fit within token limits.
@@ -309,15 +310,15 @@ def get_routing_options(self) -> Dict[str, str]:

     def get_routing_prompt(
         self,
-        memory: BaseMemory,
+        memory: MessageMemory,
         options: Dict[str, str],
         execution_context: dict = None,
     ) -> str:
-        conversation = memory.get()
+        conversation: List[MessageMemoryItem] = memory.get()

         # Get the latest user input or task
         if isinstance(conversation, list) and conversation:
-            latest_task = str(conversation[-1])
+            latest_task = str(conversation[-1].result.content)
         else:
             latest_task = str(conversation)

@@ -445,16 +446,16 @@ def _get_next_step_in_pattern(self, execution_context: dict) -> Optional[str]:

     def get_routing_prompt(
         self,
-        memory: BaseMemory,
+        memory: MessageMemory,
         options: Dict[str, str],
         execution_context: dict = None,
     ) -> str:
-        conversation = memory.get()
+        conversation: List[MessageMemoryItem] = memory.get()

         # Format conversation history
         if isinstance(conversation, list):
             conversation_text = '\n'.join(
-                [str(msg) for msg in conversation[-3:]]
+                [msg.result.content for msg in conversation[-3:]]
             )  # Last 3 messages for flow context
         else:
             conversation_text = str(conversation)
@@ -575,16 +576,16 @@ def get_routing_options(self) -> Dict[str, str]:

     def get_routing_prompt(
         self,
-        memory: BaseMemory,
+        memory: MessageMemory,
         options: Dict[str, str],
         execution_context: dict = None,
     ) -> str:
-        conversation = memory.get()
+        conversation: List[MessageMemoryItem] = memory.get()

         # Format conversation history
         if isinstance(conversation, list):
             conversation_text = '\n'.join(
-                [str(msg) for msg in conversation[-3:]]
+                [msg.result.content for msg in conversation[-3:]]
             )  # Last 3 messages for context
         else:
             conversation_text = str(conversation)
@@ -760,17 +761,20 @@ def get_routing_options(self) -> Dict[str, str]:

     def get_routing_prompt(
         self,
-        memory: BaseMemory,
+        memory: MessageMemory,
         options: Dict[str, str],
         execution_context: dict = None,
     ) -> str:
-        conversation = memory.get()
+        conversation: List[MessageMemoryItem] = memory.get()

         # Analyze recent conversation
         if isinstance(conversation, list):
             recent_messages = conversation[-self.analysis_depth :]
             conversation_text = '\n'.join(
-                [f'Message {i+1}: {msg}' for i, msg in enumerate(recent_messages)]
+                [
+                    f'Message {i+1}: {msg.result.content}'
+                    for i, msg in enumerate(recent_messages)
+                ]
             )
         else:
             conversation_text = str(conversation)
@@ -825,7 +829,9 @@ def get_routing_prompt(
         return prompt


-def create_llm_router(router_type: str, **config) -> Callable[[BaseMemory], str]:
+def create_llm_router(
+    router_type: str, **config
+) -> Callable[[MessageMemory, Optional[dict]], Awaitable[str]]:
     """
     Factory function to create LLM-powered routers with different configurations.

@@ -937,12 +943,15 @@ def create_llm_router(router_type: str, **config) -> Callable[[BaseMemory], str]
        literal_type = Literal[option_names]

     # Return a function that can be used as a router
-    async def router_function(memory: BaseMemory, execution_context: dict = None):
+    async def router_function(memory: MessageMemory, execution_context: dict = None):
         """Generated router function that uses LLM for routing decisions"""
         return await router_instance.route(memory, execution_context)

     # Add proper type annotations for validation
-    router_function.__annotations__ = {'memory': BaseMemory, 'return': literal_type}
+    router_function.__annotations__ = {
+        'memory': MessageMemory,
+        'return': Awaitable[literal_type],
+    }

     # Transfer router instance attributes to the function for validation
     router_function.supports_self_reference = getattr(
@@ -976,7 +985,7 @@ def llm_router(
             "analyst": "Analyze data and perform calculations",
             "writer": "Create reports and summaries"
         })
-        def my_smart_router(memory: BaseMemory) -> Literal["researcher", "analyst", "writer"]:
+        def my_smart_router(memory: MessageMemory) -> Literal["researcher", "analyst", "writer"]:
             pass  # Implementation is provided by decorator
     """

@@ -1011,7 +1020,7 @@ def decorator(func):
         )

         @wraps(func)
-        async def wrapper(memory: BaseMemory, execution_context: dict = None):
+        async def wrapper(memory: MessageMemory, execution_context: dict = None):
             return await router_instance.route(memory, execution_context)

         # Preserve the original function's type annotations including return type
@@ -1034,7 +1043,7 @@ def create_research_analysis_router(
     analysis_agent: str = 'analyst',
     summary_agent: str = 'summarizer',
     llm: Optional[BaseLLM] = None,
-) -> Callable[[BaseMemory], str]:
+) -> Callable[[MessageMemory, Optional[dict]], Awaitable[str]]:
     """
     Create a router for common research -> analysis -> summary workflows.

@@ -1109,7 +1118,7 @@ def create_main_critic_reflection_router(
     final_agent: str = 'final_agent',
     allow_early_exit: bool = False,
     llm: Optional[BaseLLM] = None,
-) -> Callable[[BaseMemory], str]:
+) -> Callable[[MessageMemory, Optional[dict]], Awaitable[str]]:
     """
     Create a router for the A -> B -> A -> C reflection pattern (main -> critic -> main -> final).

@@ -1137,7 +1146,7 @@ def create_plan_execute_router(
     reviewer_agent: Optional[str] = None,
     additional_agents: Optional[Dict[str, str]] = None,
     llm: Optional[BaseLLM] = None,
-) -> Callable[[BaseMemory], str]:
+) -> Callable[[MessageMemory, Optional[dict]], Awaitable[str]]:
     """
     Create a router for plan-and-execute workflows like Cursor.

@@ -1179,7 +1188,7 @@ def create_main_critic_flow_router(
     final_agent: str = 'final_agent',
     allow_early_exit: bool = False,
     llm: Optional[BaseLLM] = None,
-) -> Callable[[BaseMemory], str]:
+) -> Callable[[MessageMemory, Optional[dict]], Awaitable[str]]:
     """
     DEPRECATED: Use create_main_critic_reflection_router instead.
     Create a router for the A -> B -> A -> C reflection pattern (main -> critic -> main -> final).
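Note: every get_routing_prompt above now reads node names and message content off typed memory items instead of calling str() on whole messages. A rough sketch of that access pattern with stand-in types (MessageMemoryItem's real fields appear in the memory.py hunk below; the message class here is simplified to just .content):

from dataclasses import dataclass

@dataclass
class Msg:  # stand-in for BaseMessage
    content: str

@dataclass
class Item:  # stand-in for MessageMemoryItem
    node: str
    result: Msg

history = [
    Item('researcher', Msg('Found three sources.')),
    Item('analyst', Msg('Two of them are credible.')),
]

# SmartRouter-style prompt line per item, last five items only:
print('\n'.join(f'{i.node}: {i.result.content}' for i in history[-5:]))
# researcher: Found three sources.
# analyst: Two of them are credible.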
diff --git a/flo_ai/flo_ai/arium/memory.py b/flo_ai/flo_ai/arium/memory.py
index af0db6d2..c6420986 100644
--- a/flo_ai/flo_ai/arium/memory.py
+++ b/flo_ai/flo_ai/arium/memory.py
@@ -20,12 +20,10 @@ class StepStatus(Enum):


 class MessageMemoryItem:
-    def __init__(
-        self, node: str, occurrence: int = 0, result: BaseMessage | str = None
-    ):
+    def __init__(self, node: str, occurrence: int = 0, result: BaseMessage = None):
         self.node: str = node
         self.occurrence: int = occurrence
-        self.result: BaseMessage | str = result
+        self.result: BaseMessage = result

     def to_dict(self) -> Dict[str, Any]:
         return {
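Note: narrowing result to BaseMessage (dropping the | str arm) is what makes the routers' unconditional item.result.content accesses safe, and it pairs with the new Arium._flatten_results above: nested ForEachNode/AriumNode runs hand back memory items, and the parent workflow unwraps them before passing results on. A standalone illustration with placeholder classes:

class Item:  # stand-in for MessageMemoryItem
    def __init__(self, node, result):
        self.node, self.result = node, result

def flatten(seq):
    # Same comprehension as Arium._flatten_results: unwrap memory items,
    # pass plain messages/strings through untouched
    return [x.result if isinstance(x, Item) else x for x in seq]

nested = [Item('step1', 'draft'), 'raw note', Item('step2', 'final')]
print(flatten(nested))  # ['draft', 'raw note', 'final']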
diff --git a/flo_ai/flo_ai/models/agent.py b/flo_ai/flo_ai/models/agent.py
index 092f3bd1..3af3fa35 100644
--- a/flo_ai/flo_ai/models/agent.py
+++ b/flo_ai/flo_ai/models/agent.py
@@ -9,6 +9,7 @@
     UserMessage,
     TextMessageContent,
     FunctionMessage,
+    SystemMessage,
 )
 from flo_ai.tool.base_tool import Tool, ToolExecutionError
 from flo_ai.models.agent_error import AgentError
@@ -139,10 +140,7 @@ async def _run_conversational(
             if self.reasoning_pattern == ReasoningPattern.COT
             else resolve_variables(self.system_prompt, variables)
         )
-        system_message = AssistantMessage(
-            role=MessageType.SYSTEM,
-            content=TextMessageContent(text=system_content),
-        )
+        system_message = SystemMessage(content=system_content)
         self.add_to_history(system_message)

         messages = await self._get_message_history(variables)
@@ -220,9 +218,7 @@ async def _run_with_tools(
         else:
             system_content = resolve_variables(self.system_prompt, variables)

-        system_message = AssistantMessage(
-            role=MessageType.SYSTEM, content=system_content
-        )
+        system_message = SystemMessage(content=system_content)
         self.add_to_history(system_message)

         messages = await self._get_message_history(variables)
@@ -401,11 +397,8 @@ async def _run_with_tools(
             )

         # Generate final response if we've hit the tool call limit or exited the loop
-        system_message = AssistantMessage(
-            role=MessageType.SYSTEM,
-            content=TextMessageContent(
-                text='Please provide a final answer based on all the tool results above.'
-            ),
+        system_message = SystemMessage(
+            content='Please provide a final answer based on all the tool results above.'
         )
         self.add_to_history(system_message)
         messages = await self._get_message_history(variables)
diff --git a/flo_ai/pyproject.toml b/flo_ai/pyproject.toml
index 453cc861..c80ad674 100644
--- a/flo_ai/pyproject.toml
+++ b/flo_ai/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "flo_ai"
-version = "1.1.0-rc2"
+version = "1.1.0-rc3"
 description = "A easy way to create structured AI agents"
 authors = [{ name = "rootflo", email = "*@rootflo.ai" }]
 requires-python = ">=3.10,<4.0"
diff --git a/flo_ai/tests/unit-tests/test_router_fix.py b/flo_ai/tests/unit-tests/test_router_fix.py
index 291fa5f7..187473be 100644
--- a/flo_ai/tests/unit-tests/test_router_fix.py
+++ b/flo_ai/tests/unit-tests/test_router_fix.py
@@ -5,6 +5,7 @@

 import inspect
 from typing import get_origin, get_args, Literal
+from collections.abc import Awaitable as AwaitableABC
 from flo_ai.arium import create_llm_router
 from flo_ai.llm import OpenAI

@@ -31,18 +32,43 @@ def test_router_type_annotation():
     print(f'Return annotation: {return_annotation}')
     print(f'Return annotation type: {type(return_annotation)}')

-    # Check if it's a Literal type
+    # Check if it's Awaitable[Literal[...]] or Literal type
     origin = get_origin(return_annotation)
     print(f'Origin: {origin}')
+    print(f'Is Awaitable: {origin is AwaitableABC}')
     print(f'Is Literal: {origin is Literal}')

-    if origin is Literal:
+    # Handle Awaitable[Literal[...]] for async router functions
+    if origin is AwaitableABC:
+        args = get_args(return_annotation)
+        if args:
+            inner_type = args[0]
+            inner_origin = get_origin(inner_type)
+            if inner_origin is Literal:
+                literal_values = list(get_args(inner_type))
+                print(f'Literal values (from Awaitable): {literal_values}')
+                assert (
+                    True
+                ), 'Router function has correct Awaitable[Literal] type annotation'
+            else:
+                print('❌ Awaitable contains non-Literal type!')
+                assert (
+                    False
+                ), 'Router function should have Awaitable[Literal] type annotation'
+        else:
+            print('❌ Awaitable has no args!')
+            assert (
+                False
+            ), 'Router function should have Awaitable[Literal] type annotation'
+    elif origin is Literal:
         literal_values = list(get_args(return_annotation))
         print(f'Literal values: {literal_values}')
         assert True, 'Router function has correct Literal type annotation'
     else:
-        print('❌ Not a Literal type!')
-        assert False, 'Router function should have Literal type annotation'
+        print('❌ Not an Awaitable[Literal] or Literal type!')
+        assert (
+            False
+        ), 'Router function should have Awaitable[Literal] or Literal type annotation'


 def test_validation_logic():
@@ -65,18 +91,42 @@ def test_validation_logic():
             print('❌ No return annotation')
             return False

-        # Check if the return type is a Literal
+        # Check if the return type is a Literal or Awaitable[Literal[...]]
         origin = get_origin(return_annotation)

+        # Handle Awaitable[Literal[...]] for async router functions
+        if origin is AwaitableABC:
+            # Unwrap the Awaitable to get the inner type
+            args = get_args(return_annotation)
+            if args:
+                inner_type = args[0]
+                inner_origin = get_origin(inner_type)
+                if inner_origin is Literal:
+                    # Extract the literal values from the inner Literal type
+                    literal_values = list(get_args(inner_type))
+                    print(
+                        f'✅ Validation passed! Literal values (from Awaitable): {literal_values}'
+                    )
+                    assert True, 'Validation logic works correctly'
+                else:
+                    print(
+                        f'❌ Validation failed! Awaitable contains {inner_origin}, not Literal'
+                    )
+                    assert False, f'Validation failed! Awaitable contains {inner_origin}, not Literal'
+            else:
+                print('❌ Validation failed! Awaitable has no args')
+                assert False, 'Validation failed! Awaitable has no args'
         # In Python 3.8+, Literal types have get_origin() return typing.Literal
-        if origin is Literal:
+        elif origin is Literal:
             # Extract the literal values
             literal_values = list(get_args(return_annotation))
             print(f'✅ Validation passed! Literal values: {literal_values}')
             assert True, 'Validation logic works correctly'
         else:
-            print(f'❌ Validation failed! Origin is {origin}, not Literal')
-            assert False, f'Validation failed! Origin is {origin}, not Literal'
+            print(f'❌ Validation failed! Origin is {origin}, not Awaitable or Literal')
+            assert (
+                False
+            ), f'Validation failed! Origin is {origin}, not Awaitable or Literal'

     except Exception as e:
         print(f'❌ Exception during validation: {e}')
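Note: end to end, the contract these tests pin down can be mimicked without the factory. A minimal sketch in which the router body and option names are illustrative; only the __annotations__ handling mirrors what create_llm_router attaches to its generated function:

from collections.abc import Awaitable as AwaitableABC
from typing import Awaitable, Literal, get_args, get_origin

async def my_router(memory, execution_context=None):
    return 'analyst'  # illustrative routing decision

# create_llm_router sets a return annotation of this shape on its router:
my_router.__annotations__['return'] = Awaitable[Literal['analyst', 'writer']]

ann = my_router.__annotations__['return']
assert get_origin(ann) is AwaitableABC
inner = get_args(ann)[0]
assert get_origin(inner) is Literal
assert list(get_args(inner)) == ['analyst', 'writer']
print('annotation unwraps the way base.py and the tests expect')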