Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions flo_ai/flo_ai/arium/arium.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,7 @@ async def run(
List of workflow execution results
"""
if isinstance(inputs, str):
inputs = [
UserMessage(
TextMessageContent(text=resolve_variables(inputs, variables))
)
]
inputs = [UserMessage(content=resolve_variables(inputs, variables))]

if not self.is_compiled:
raise ValueError('Arium is not compiled')
Expand Down Expand Up @@ -475,13 +471,19 @@ async def _execute_node(
elif isinstance(node, FunctionNode):
result = await node.run(inputs, variables=None)
elif isinstance(node, ForEachNode):
result = await node.run(
foreach_results: List[
MessageMemoryItem | BaseMessage
] = await node.run(
inputs,
variables=variables,
)
result = self._flatten_results(foreach_results)
elif isinstance(node, AriumNode):
# AriumNode execution
result = await node.run(inputs, variables=variables)
arium_result: List[MessageMemoryItem] = await node.run(
inputs, variables=variables
)
result = self._flatten_results(arium_result)
elif isinstance(node, StartNode):
result = None
elif isinstance(node, EndNode):
Expand Down Expand Up @@ -556,12 +558,18 @@ async def _execute_node(
elif isinstance(node, FunctionNode):
result = await node.run(inputs, variables=None)
elif isinstance(node, ForEachNode):
result = await node.run(
foreach_results: List[
MessageMemoryItem | BaseMessage
] = await node.run(
inputs,
variables=variables,
)
result = self._flatten_results(foreach_results)
elif isinstance(node, AriumNode):
result = await node.run(inputs, variables=variables)
arium_result: List[MessageMemoryItem] = await node.run(
inputs, variables=variables
)
result = self._flatten_results(arium_result)
elif isinstance(node, StartNode):
result = None
elif isinstance(node, EndNode):
Expand Down Expand Up @@ -602,6 +610,23 @@ async def _execute_node(
# Re-raise the exception
raise e

def _flatten_results(
    self, sequence: List[MessageMemoryItem | BaseMessage | str]
) -> List[BaseMessage | str]:
    """
    Unwrap MessageMemoryItem entries, keeping all other items untouched.

    Args:
        sequence: Mixed list whose items may be MessageMemoryItem,
            BaseMessage, or str

    Returns:
        The same items in the same order, with every MessageMemoryItem
        replaced by its .result payload (a BaseMessage or str)
    """
    flattened: List[BaseMessage | str] = []
    for entry in sequence:
        if isinstance(entry, MessageMemoryItem):
            flattened.append(entry.result)
        else:
            flattened.append(entry)
    return flattened

def _add_to_memory(self, message: MessageMemoryItem):
"""
Store message in memory
Expand Down
15 changes: 14 additions & 1 deletion flo_ai/flo_ai/arium/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flo_ai.tool.base_tool import Tool
from flo_ai.utils.logger import logger
from typing import List, Optional, Callable, Literal, get_origin, get_args, Dict
from collections.abc import Awaitable as AwaitableABC
from flo_ai.arium.models import StartNode, EndNode, Edge, default_router
from pathlib import Path

Expand Down Expand Up @@ -59,9 +60,21 @@ def _check_router_return_type(self, router: Callable) -> Optional[List]:
if return_annotation == inspect.Signature.empty:
return None

# Check if the return type is a Literal
# Check if the return type is a Literal or Awaitable[Literal[...]]
origin = get_origin(return_annotation)

# Handle Awaitable[Literal[...]] for async router functions
if origin is AwaitableABC:
# Unwrap the Awaitable to get the inner type
args = get_args(return_annotation)
if args:
inner_type = args[0]
inner_origin = get_origin(inner_type)
if inner_origin is Literal:
# Extract the literal values from the inner Literal type
literal_values = list(get_args(inner_type))
return literal_values

# In Python 3.8+, Literal types have get_origin() return typing.Literal
if origin is Literal:
# Extract the literal values
Expand Down
79 changes: 44 additions & 35 deletions flo_ai/flo_ai/arium/llm_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
"""

from abc import ABC, abstractmethod
from typing import Dict, Optional, Callable, Any, Union, get_args, List
from typing import Dict, Optional, Callable, Any, Union, get_args, List, Awaitable
from functools import wraps
from flo_ai.arium.memory import BaseMemory, ExecutionPlan, StepStatus
from flo_ai.arium.memory import (
ExecutionPlan,
StepStatus,
MessageMemory,
MessageMemoryItem,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
from flo_ai.llm.base_llm import BaseLLM
from flo_ai.llm import OpenAI
from flo_ai.utils.logger import logger
Expand Down Expand Up @@ -57,7 +62,7 @@ def get_routing_options(self) -> Dict[str, str]:
@abstractmethod
def get_routing_prompt(
self,
memory: BaseMemory,
memory: MessageMemory,
options: Dict[str, str],
execution_context: dict = None,
) -> str:
Expand Down Expand Up @@ -96,7 +101,7 @@ def get_fallback_route(self, options: Dict[str, str]) -> str:
else:
return routes[0]

async def route(self, memory: BaseMemory, execution_context: dict = None) -> str:
async def route(self, memory: MessageMemory, execution_context: dict = None) -> str:
"""
Make a routing decision using the LLM.

Expand Down Expand Up @@ -180,19 +185,15 @@ def get_routing_options(self) -> Dict[str, str]:

def get_routing_prompt(
self,
memory: BaseMemory,
memory: MessageMemory,
options: Dict[str, str],
execution_context: dict = None,
) -> str:
conversation = memory.get()
conversation: List[MessageMemoryItem] = memory.get()

# Format conversation history with smart truncation
if isinstance(conversation, list):
# Start with last message and add more if we have space
messages = conversation[-5:] # Last 5 messages
conversation_text = self._truncate_conversation_for_tokens(messages)
else:
conversation_text = str(conversation)
conversation_text = self._truncate_conversation_for_tokens(
[f'{item.node}: {item.result.content}' for item in conversation[-5:]]
)

# Format options
options_text = '\n'.join(
Expand Down Expand Up @@ -245,7 +246,7 @@ def get_routing_prompt(
return prompt

def _truncate_conversation_for_tokens(
self, messages: List[Any], max_tokens: int = 128000
self, messages: List[str], max_tokens: int = 128000
) -> str:
"""
Intelligently truncate conversation to fit within token limits.
Expand Down Expand Up @@ -309,15 +310,15 @@ def get_routing_options(self) -> Dict[str, str]:

def get_routing_prompt(
self,
memory: BaseMemory,
memory: MessageMemory,
options: Dict[str, str],
execution_context: dict = None,
) -> str:
conversation = memory.get()
conversation: List[MessageMemoryItem] = memory.get()

# Get the latest user input or task
if isinstance(conversation, list) and conversation:
latest_task = str(conversation[-1])
latest_task = str(conversation[-1].result.content)
else:
latest_task = str(conversation)

Expand Down Expand Up @@ -445,16 +446,16 @@ def _get_next_step_in_pattern(self, execution_context: dict) -> Optional[str]:

def get_routing_prompt(
self,
memory: BaseMemory,
memory: MessageMemory,
options: Dict[str, str],
execution_context: dict = None,
) -> str:
conversation = memory.get()
conversation: List[MessageMemoryItem] = memory.get()

# Format conversation history
if isinstance(conversation, list):
conversation_text = '\n'.join(
[str(msg) for msg in conversation[-3:]]
[msg.result.content for msg in conversation[-3:]]
) # Last 3 messages for flow context
else:
conversation_text = str(conversation)
Expand Down Expand Up @@ -575,16 +576,16 @@ def get_routing_options(self) -> Dict[str, str]:

def get_routing_prompt(
self,
memory: BaseMemory,
memory: MessageMemory,
options: Dict[str, str],
execution_context: dict = None,
) -> str:
conversation = memory.get()
conversation: List[MessageMemoryItem] = memory.get()

# Format conversation history
if isinstance(conversation, list):
conversation_text = '\n'.join(
[str(msg) for msg in conversation[-3:]]
[msg.result.content for msg in conversation[-3:]]
) # Last 3 messages for context
else:
conversation_text = str(conversation)
Expand Down Expand Up @@ -760,17 +761,20 @@ def get_routing_options(self) -> Dict[str, str]:

def get_routing_prompt(
self,
memory: BaseMemory,
memory: MessageMemory,
options: Dict[str, str],
execution_context: dict = None,
) -> str:
conversation = memory.get()
conversation: List[MessageMemoryItem] = memory.get()

# Analyze recent conversation
if isinstance(conversation, list):
recent_messages = conversation[-self.analysis_depth :]
conversation_text = '\n'.join(
[f'Message {i+1}: {msg}' for i, msg in enumerate(recent_messages)]
[
f'Message {i+1}: {msg.result.content}'
for i, msg in enumerate(recent_messages)
]
)
else:
conversation_text = str(conversation)
Expand Down Expand Up @@ -825,7 +829,9 @@ def get_routing_prompt(
return prompt


def create_llm_router(router_type: str, **config) -> Callable[[BaseMemory], str]:
def create_llm_router(
router_type: str, **config
) -> Callable[[MessageMemory, Optional[dict]], Awaitable[str]]:
"""
Factory function to create LLM-powered routers with different configurations.

Expand Down Expand Up @@ -937,12 +943,15 @@ def create_llm_router(router_type: str, **config) -> Callable[[BaseMemory], str]
literal_type = Literal[option_names]

# Return a function that can be used as a router
async def router_function(memory: BaseMemory, execution_context: dict = None):
async def router_function(memory: MessageMemory, execution_context: dict = None):
"""Generated router function that uses LLM for routing decisions"""
return await router_instance.route(memory, execution_context)

# Add proper type annotations for validation
router_function.__annotations__ = {'memory': BaseMemory, 'return': literal_type}
router_function.__annotations__ = {
'memory': MessageMemory,
'return': Awaitable[literal_type],
}

# Transfer router instance attributes to the function for validation
router_function.supports_self_reference = getattr(
Expand Down Expand Up @@ -976,7 +985,7 @@ def llm_router(
"analyst": "Analyze data and perform calculations",
"writer": "Create reports and summaries"
})
def my_smart_router(memory: BaseMemory) -> Literal["researcher", "analyst", "writer"]:
def my_smart_router(memory: MessageMemory) -> Literal["researcher", "analyst", "writer"]:
pass # Implementation is provided by decorator
"""

Expand Down Expand Up @@ -1011,7 +1020,7 @@ def decorator(func):
)

@wraps(func)
async def wrapper(memory: BaseMemory, execution_context: dict = None):
async def wrapper(memory: MessageMemory, execution_context: dict = None):
return await router_instance.route(memory, execution_context)

# Preserve the original function's type annotations including return type
Expand All @@ -1034,7 +1043,7 @@ def create_research_analysis_router(
analysis_agent: str = 'analyst',
summary_agent: str = 'summarizer',
llm: Optional[BaseLLM] = None,
) -> Callable[[BaseMemory], str]:
) -> Callable[[MessageMemory, Optional[dict]], Awaitable[str]]:
"""
Create a router for common research -> analysis -> summary workflows.

Expand Down Expand Up @@ -1109,7 +1118,7 @@ def create_main_critic_reflection_router(
final_agent: str = 'final_agent',
allow_early_exit: bool = False,
llm: Optional[BaseLLM] = None,
) -> Callable[[BaseMemory], str]:
) -> Callable[[MessageMemory, Optional[dict]], Awaitable[str]]:
"""
Create a router for the A -> B -> A -> C reflection pattern (main -> critic -> main -> final).

Expand Down Expand Up @@ -1137,7 +1146,7 @@ def create_plan_execute_router(
reviewer_agent: Optional[str] = None,
additional_agents: Optional[Dict[str, str]] = None,
llm: Optional[BaseLLM] = None,
) -> Callable[[BaseMemory], str]:
) -> Callable[[MessageMemory, Optional[dict]], Awaitable[str]]:
"""
Create a router for plan-and-execute workflows like Cursor.

Expand Down Expand Up @@ -1179,7 +1188,7 @@ def create_main_critic_flow_router(
final_agent: str = 'final_agent',
allow_early_exit: bool = False,
llm: Optional[BaseLLM] = None,
) -> Callable[[BaseMemory], str]:
) -> Callable[[MessageMemory, Optional[dict]], Awaitable[str]]:
"""
DEPRECATED: Use create_main_critic_reflection_router instead.
Create a router for the A -> B -> A -> C reflection pattern (main -> critic -> main -> final).
Expand Down
6 changes: 2 additions & 4 deletions flo_ai/flo_ai/arium/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ class StepStatus(Enum):


class MessageMemoryItem:
def __init__(
self, node: str, occurrence: int = 0, result: BaseMessage | str = None
):
def __init__(self, node: str, occurrence: int = 0, result: BaseMessage | None = None):
    """Record a single node execution result in memory.

    Args:
        node: Name of the node that produced this entry.
        occurrence: Zero-based count of how many times the node has run
            (0 = first execution).
        result: The message the node produced; None until a result is set.
    """
    # Identifier of the node this memory entry belongs to.
    self.node: str = node
    # Which execution of the node this entry captures.
    self.occurrence: int = occurrence
    # Annotated as Optional because the default (and pre-run state) is None.
    self.result: BaseMessage | None = result

def to_dict(self) -> Dict[str, Any]:
return {
Expand Down
Loading