66 changes: 52 additions & 14 deletions flo_ai/flo_ai/agent/base_agent.py
@@ -1,3 +1,4 @@
import asyncio
from typing import Dict, Any, List, Tuple, cast, Optional
from abc import ABC, abstractmethod
from enum import Enum
@@ -9,6 +10,7 @@
FunctionMessage,
)
from flo_ai.utils.variable_extractor import resolve_variables
from flo_ai.utils.profiler import aprofile


class AgentType(Enum):
@@ -85,43 +87,79 @@ def clear_history(self):
self.conversation_history = []

async def _get_message_history(self, variables: Optional[Dict[str, Any]] = None):
async with aprofile(f'agent.{self.name}.get_message_history'):
return await self._get_message_history_impl(variables)

async def _get_message_history_impl(
self, variables: Optional[Dict[str, Any]] = None
):
"""Build the message list passed to the LLM from the conversation history.

Document formatting (the expensive step — PDF rasterization or
extraction) is dispatched concurrently via ``asyncio.gather`` and
cached on the ``DocumentMessageContent`` instance by the underlying
LLM, so the same document is formatted at most once per LLM across
all nodes and retries in a workflow.
"""
variables = variables if variables is not None else {}

# First pass: kick off one formatting coroutine per *unique* document
# instance. If the same DocumentMessageContent is referenced at
# multiple indices, we share the single in-flight task so we never
# rasterize it twice concurrently.
doc_tasks_by_id: Dict[int, 'asyncio.Future[Any]'] = {}
doc_id_by_idx: Dict[int, int] = {}
for idx, input in enumerate(self.conversation_history):
if (
not isinstance(input, FunctionMessage)
and isinstance(input.content, MediaMessageContent)
and input.content.type == 'document'
):
doc_id = id(input.content)
doc_id_by_idx[idx] = doc_id
if doc_id not in doc_tasks_by_id:
doc_tasks_by_id[doc_id] = asyncio.ensure_future(
self.llm.format_document_in_message(input.content) # type: ignore[arg-type]
)

if doc_tasks_by_id:
formatted_docs = await asyncio.gather(*doc_tasks_by_id.values())
formatted_by_doc_id: Dict[int, Any] = dict(
zip(doc_tasks_by_id.keys(), formatted_docs)
)
else:
formatted_by_doc_id = {}

# Second pass: assemble the provider-ready message list.
message_history: List[Dict[str, Any]] = []
for idx, input in enumerate(self.conversation_history):
if isinstance(input, FunctionMessage):
message_history.append(
{'role': input.role, 'name': input.name, 'content': input.content}
)
# CRITICAL: Check content type FIRST, before message type
# This ensures TextMessageContent objects are converted to strings
elif isinstance(input.content, TextMessageContent):
resolved_content = resolve_variables(input.content.text, variables)
message_history.append(
{'role': input.role, 'content': resolved_content}
)
elif isinstance(input.content, MediaMessageContent):
if input.content.type == 'image':
# Format image message and add to history
formatted_content = self.llm.format_image_in_message(input.content)  # type: ignore[arg-type]
message_history.append(
{'role': input.role, 'content': formatted_content}
)

elif input.content.type == 'document':
# Use the document formatted concurrently in the first pass
message_history.append(
{
'role': input.role,
'content': formatted_by_doc_id[doc_id_by_idx[idx]],
}
)
else:
raise ValueError(
f'Invalid media message content type: {input.content.type}'
)
elif isinstance(input.content, str):
# Handle other messages with string content (UserMessage, SystemMessage, etc.)
resolved_content = resolve_variables(input.content, variables)
message_history.append(
{'role': input.role, 'content': resolved_content}
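The two-pass pattern introduced above generalizes beyond documents: key the in-flight tasks by object identity so a value referenced at several positions is processed exactly once, then join the results back positionally. A minimal self-contained sketch of that pattern (slow_format and the sample items are hypothetical stand-ins, not flo_ai APIs):

import asyncio
from typing import Any, Dict, List

async def slow_format(item: Any) -> str:
    # Stand-in for an expensive step such as PDF rasterization.
    await asyncio.sleep(0.1)
    return f'formatted:{id(item)}'

async def format_unique(items: List[Any]) -> List[str]:
    # First pass: one task per unique object, shared across duplicate references.
    tasks: Dict[int, 'asyncio.Task[str]'] = {}
    for item in items:
        if id(item) not in tasks:
            tasks[id(item)] = asyncio.ensure_future(slow_format(item))
    # Await everything concurrently, then map results back by identity.
    results = dict(zip(tasks.keys(), await asyncio.gather(*tasks.values())))
    return [results[id(item)] for item in items]

doc = object()  # the same instance referenced twice is formatted only once
print(asyncio.run(format_unique([doc, doc, object()])))

Keying by id() is safe here because the input list keeps every object alive while the tasks run; if items could be garbage-collected mid-flight, CPython may reuse ids and the mapping would be wrong.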
107 changes: 56 additions & 51 deletions flo_ai/flo_ai/arium/arium.py
@@ -16,6 +16,7 @@
)
from flo_ai.telemetry.instrumentation import workflow_metrics
from flo_ai.telemetry import get_tracer
from flo_ai.utils.profiler import aprofile, record as _profile_record
from opentelemetry.trace import Status, StatusCode
import asyncio
import time
@@ -203,6 +204,20 @@ async def _execute_graph(
event_callback: Optional[Callable[[AriumEvent], None]] = None,
events_filter: Optional[List[AriumEventType]] = None,
variables: Optional[Dict[str, Any]] = None,
):
async with aprofile(
f'arium.execute_graph[{getattr(self, "name", "unnamed_workflow")}]'
):
return await self._execute_graph_impl(
inputs, event_callback, events_filter, variables
)

async def _execute_graph_impl(
self,
inputs: List[BaseMessage],
event_callback: Optional[Callable[[AriumEvent], None]] = None,
events_filter: Optional[List[AriumEventType]] = None,
variables: Optional[Dict[str, Any]] = None,
):
variables = variables if variables is not None else {}
[
@@ -464,33 +479,9 @@ async def _execute_node(
},
) as node_span:
try:
result = await self._dispatch_node_run(
node, node_type, inputs, variables
)

# Calculate execution time
execution_time = time.time() - start_time
@@ -503,6 +494,7 @@
workflow_metrics.record_node_latency(
execution_time_ms, workflow_name, node.name, node_type
)
_profile_record(f'node.{node.name}[{node_type}]', execution_time)

node_span.set_status(Status(StatusCode.OK))
node_span.set_attribute('node.execution_time_ms', execution_time_ms)
@@ -553,33 +545,13 @@ async def _execute_node(
else:
# No telemetry or start/end node, execute without tracing
try:
result = await self._dispatch_node_run(
node, node_type, inputs, variables
)

# Calculate execution time
execution_time = time.time() - start_time
_profile_record(f'node.{node.name}[{node_type}]', execution_time)

# Emit node completed event
self._emit_event(
@@ -611,6 +583,39 @@ async def _execute_node(
# Re-raise the exception
raise e

async def _dispatch_node_run(
self,
node: AriumNodeType,
node_type: str,
inputs: List[BaseMessage],
variables: Dict[str, Any],
):
"""Dispatch a node's ``run`` invocation under a profiler scope.

Keeps the dispatch logic in one place so both the telemetry and
non-telemetry branches of ``_execute_node`` get consistent profiling.
"""
if node_type in ('start', 'end'):
return None

async with aprofile(f'node.{node.name}[{node_type}]'):
if isinstance(node, Agent):
return await node.run(inputs, variables={})
if isinstance(node, FunctionNode):
return await node.run(inputs, variables=None)
if isinstance(node, ForEachNode):
foreach_results: List[MessageMemoryItem | BaseMessage] = await node.run(
inputs,
variables=variables,
)
return self._flatten_results(foreach_results)
if isinstance(node, AriumNode):
arium_result: List[MessageMemoryItem] = await node.run(
inputs, variables=variables
)
return self._flatten_results(arium_result)
return None

def _flatten_results(
self, sequence: List[MessageMemoryItem | BaseMessage | str]
) -> List[BaseMessage | str]:
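Both call sites in this file assume the same small profiler surface: aprofile(name) as an async context manager that times its body, and record(name, seconds) for durations measured externally (the time.time() deltas in _execute_node). A minimal stand-in with that interface, sketched here as an assumption since the actual flo_ai.utils.profiler implementation is not shown in this diff:

import time
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import AsyncIterator, DefaultDict, List

_timings: DefaultDict[str, List[float]] = defaultdict(list)

def record(name: str, seconds: float) -> None:
    # Store a duration that was measured by the caller.
    _timings[name].append(seconds)

@asynccontextmanager
async def aprofile(name: str) -> AsyncIterator[None]:
    # Time the awaited body and record it under `name`, even on error.
    start = time.perf_counter()
    try:
        yield
    finally:
        record(name, time.perf_counter() - start)

Routing both branches of _execute_node through _dispatch_node_run means the node.{name}[{type}] scope is opened identically whether or not telemetry is enabled, so profile output stays comparable across configurations.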
56 changes: 54 additions & 2 deletions flo_ai/flo_ai/formatter/yaml_format_parser.py
@@ -66,6 +66,7 @@ def __get_field_type_annotation(
'bool': bool,
'float': float,
'literal': self.__create_literal_type,
'enum': self.__create_enum_type,
'object': lambda f: self.__create_nested_model(f, model_name),
'array': lambda f: self.__create_array_type(f, model_name),
}
@@ -78,7 +79,7 @@

return (
type_handler(field)
if field_type in ['literal', 'enum', 'object', 'array']
else type_handler
)

@@ -89,10 +90,50 @@ def __create_literal_type(self, field: Dict[str, Any]) -> Any:
raise ValueError(
f"Field '{field['name']}' of type 'literal' must specify 'values'."
)
literals = tuple(
self.__extract_enum_value(v, field['name']) for v in literal_values
)
# Construct Literal type dynamically at runtime
return Literal.__getitem__(literals)

def __create_enum_type(self, field: Dict[str, Any]) -> Any:
"""Creates a Literal type for an enum field.

Enum values may be plain primitives (str/int/float) or dicts with a
``value`` key (same shape as literal values). At the JSON-schema level
this emits ``{"enum": [...]}`` which every supported LLM backend handles
(OpenAI/Azure via response_format + function schema, Gemini via
response_schema, Anthropic/Ollama/Bedrock/vLLM via inlined schema in the
system prompt).
"""
raw_values = field.get('values', [])
if not raw_values:
raise ValueError(
f"Field '{field['name']}' of type 'enum' must specify 'values'."
)
literals = tuple(
self.__extract_enum_value(v, field['name']) for v in raw_values
)
return Literal.__getitem__(literals)

@staticmethod
def __extract_enum_value(value: Any, field_name: str) -> Any:
"""Normalize an enum/literal value entry to its primitive value."""
if isinstance(value, dict):
if 'value' not in value:
raise ValueError(
f"Field '{field_name}' has an object-style value entry "
"missing the required 'value' key."
)
return value['value']
if isinstance(value, (str, int, float, bool)):
return value
raise ValueError(
f"Field '{field_name}' has an unsupported enum value of type "
f'{type(value).__name__}. Expected str, int, float, bool, or an '
"object with a 'value' key."
)

def __create_array_type(self, field: Dict[str, Any], model_name: str) -> Any:
"""Creates a List type from field definition"""
inner_type = self.__get_field_type_annotation(
@@ -121,6 +162,17 @@ def __create_contract_from_json(self) -> BaseModel:
This should be one of the values in the `value` column in the above csv.
{default_prompt}
"""
elif field['type'] == 'enum':
enum_values = [
self.__extract_enum_value(v, field['name'])
for v in field.get('values', [])
]
default_prompt = field.get('default_value_prompt', '')
field_description = (
f"{field['description']}\n"
f'Must be exactly one of: {enum_values}.'
+ (f'\n{default_prompt}' if default_prompt else '')
)
else:
field_description = field['description']

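For reference, the new enum type accepts either plain primitives or object-style entries in values, and both normalize to the same Literal. A sketch of the field shape the parser consumes (the sentiment field itself is illustrative; the normalization mirrors __extract_enum_value above):

from typing import Any, Dict, Literal

# Illustrative field definition using both accepted value styles.
field: Dict[str, Any] = {
    'name': 'sentiment',
    'type': 'enum',
    'description': 'Overall sentiment of the text.',
    'values': ['positive', 'negative', {'value': 'neutral'}],
}

def extract(value: Any) -> Any:
    # Unwrap {'value': ...} entries, pass primitives through.
    return value['value'] if isinstance(value, dict) else value

literals = tuple(extract(v) for v in field['values'])
SentimentType = Literal.__getitem__(literals)  # Literal['positive', 'negative', 'neutral']

The contract builder above then folds the same values into the field description sent to the model, e.g. "Must be exactly one of: ['positive', 'negative', 'neutral'].", so backends that only see the schema through the system prompt still receive the constraint.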