diff --git a/flo_ai/flo_ai/__init__.py b/flo_ai/flo_ai/__init__.py index d3078b4f..96739b87 100644 --- a/flo_ai/flo_ai/__init__.py +++ b/flo_ai/flo_ai/__init__.py @@ -11,6 +11,7 @@ ReasoningPattern, DocumentType, DocumentMessage, + MessageType, ) from .builder.agent_builder import AgentBuilder @@ -56,6 +57,7 @@ 'BaseAgent', 'AgentType', 'ReasoningPattern', + 'MessageType', # Utils 'FloUtils', # LLM diff --git a/flo_ai/flo_ai/builder/agent_builder.py b/flo_ai/flo_ai/builder/agent_builder.py index f6db72a9..d9a04a55 100644 --- a/flo_ai/flo_ai/builder/agent_builder.py +++ b/flo_ai/flo_ai/builder/agent_builder.py @@ -23,6 +23,7 @@ def __init__(self): self._reasoning_pattern = ReasoningPattern.DIRECT self._output_schema: Optional[Dict[str, Any]] = None self._role: Optional[str] = None + self._act_as: Optional[str] = None def with_name(self, name: str) -> 'AgentBuilder': """Set the agent's name""" @@ -153,6 +154,11 @@ def with_role(self, role: str) -> 'AgentBuilder': self._role = role return self + def with_actas(self, act_as: str) -> 'AgentBuilder': + """Set the message role the agent's replies are recorded as (act_as)""" + self._act_as = act_as + return self + def build(self) -> Agent: """Build and return the configured agent""" if not self._llm: @@ -167,6 +173,7 @@ def build(self) -> Agent: reasoning_pattern=self._reasoning_pattern, output_schema=self._output_schema, role=self._role, + act_as=self._act_as, ) @classmethod @@ -201,6 +208,7 @@ def from_yaml( builder.with_name(agent_config.get('name', 'AI Assistant')) builder.with_prompt(agent_config.get('job', 'You are a helpful AI assistant.')) builder.with_role(agent_config.get('role')) + builder.with_actas(agent_config.get('act_as')) # Configure LLM based on model settings if 'model' in agent_config and base_llm is None: diff --git a/flo_ai/flo_ai/models/__init__.py b/flo_ai/flo_ai/models/__init__.py index ff2b0df4..74abd897 100644 --- a/flo_ai/flo_ai/models/__init__.py +++ b/flo_ai/flo_ai/models/__init__.py @@ -2,7 +2,7 @@ Models package for flo_ai - Agent 
framework components """ -from .agent import Agent +from .agent import Agent, MessageType from .agent_error import AgentError from .base_agent import BaseAgent, AgentType, ReasoningPattern from .document import DocumentMessage, DocumentType @@ -16,5 +16,6 @@ 'ReasoningPattern', 'DocumentMessage', 'DocumentType', + 'MessageType', 'ChatMessage', ] diff --git a/flo_ai/flo_ai/models/agent.py b/flo_ai/flo_ai/models/agent.py index 3d133e2d..aff81a90 100644 --- a/flo_ai/flo_ai/models/agent.py +++ b/flo_ai/flo_ai/models/agent.py @@ -20,6 +20,13 @@ from flo_ai.telemetry import get_tracer +class MessageType: + USER = 'user' + ASSISTANT = 'assistant' + FUNCTION = 'function' + SYSTEM = 'system' + + class Agent(BaseAgent): def __init__( self, @@ -32,6 +39,7 @@ def __init__( reasoning_pattern: ReasoningPattern = ReasoningPattern.DIRECT, output_schema: Optional[Dict[str, Any]] = None, role: Optional[str] = None, + act_as: Optional[str] = MessageType.ASSISTANT, ): # Determine agent type based on tools agent_type = AgentType.TOOL_USING if tools else AgentType.CONVERSATIONAL @@ -54,6 +62,7 @@ def __init__( self.reasoning_pattern = reasoning_pattern self.output_schema = output_schema self.role = role + self.act_as = act_as or MessageType.ASSISTANT @trace_agent_execution() async def run( self, @@ -84,17 +93,19 @@ async def run( # Process inputs and resolve variables in string inputs for input in inputs: if isinstance(input, ImageMessage): - self.add_to_history('user', self.llm.format_image_in_message(input)) + self.add_to_history( + MessageType.USER, self.llm.format_image_in_message(input) + ) elif isinstance(input, DocumentMessage): formatted_doc = await self.llm.format_document_in_message(input) - self.add_to_history('user', formatted_doc) + self.add_to_history(MessageType.USER, formatted_doc) elif isinstance(input, ChatMessage): resolved_content = resolve_variables(input.content, variables) self.add_to_history(input.role, resolved_content) else: # Resolve variables in text input resolved_input = 
resolve_variables(input, variables) - self.add_to_history('user', resolved_input) + self.add_to_history(MessageType.USER, resolved_input) # after resolving agent system prompts and inputs, mark variables as resolved self.resolved_variables = True @@ -103,14 +114,16 @@ async def run( # Variables already resolved, process inputs without variable resolution for input in inputs: if isinstance(input, ImageMessage): - self.add_to_history('user', self.llm.format_image_in_message(input)) + self.add_to_history( + MessageType.USER, self.llm.format_image_in_message(input) + ) elif isinstance(input, DocumentMessage): formatted_doc = await self.llm.format_document_in_message(input) - self.add_to_history('user', formatted_doc) + self.add_to_history(MessageType.USER, formatted_doc) elif isinstance(input, ChatMessage): self.add_to_history(input.role, input.content) else: - self.add_to_history('user', input) + self.add_to_history(MessageType.USER, input) retry_count = 0 @@ -137,7 +150,7 @@ async def _run_conversational( ) messages = [ { - 'role': 'system', + 'role': MessageType.SYSTEM, 'content': system_content, } ] + self.conversation_history @@ -152,7 +165,7 @@ async def _run_conversational( logger.debug(f'Extracted message: {assistant_message}') if assistant_message: - self.add_to_history('assistant', assistant_message) + self.add_to_history(self.act_as, assistant_message) return assistant_message else: possible_tool_message = await self.llm.get_function_call(response) @@ -172,7 +185,7 @@ async def _run_conversational( if should_retry and retry_count <= self.max_retries: self.add_to_history( - 'system', f'Error occurred. Analysis: {analysis}' + MessageType.SYSTEM, f'Error occurred. 
Analysis: {analysis}' ) continue else: @@ -199,7 +212,7 @@ async def _run_with_tools( messages = [ { - 'role': 'system', + 'role': MessageType.SYSTEM, 'content': system_content, } ] + self.conversation_history @@ -226,7 +239,7 @@ async def _run_with_tools( assistant_message, tool_call_count, messages ) if is_final: - self.add_to_history('assistant', assistant_message) + self.add_to_history(self.act_as, assistant_message) return assistant_message else: # This is intermediate reasoning, add to context and continue @@ -238,17 +251,17 @@ async def _run_with_tools( logger.debug( f'Detected intermediate reasoning (not final answer): {msg_preview}...' ) - self.add_to_history('assistant', assistant_message) + self.add_to_history(self.act_as, assistant_message) messages.append( { - 'role': 'assistant', + 'role': self.act_as, 'content': assistant_message, } ) # Prompt the agent to take action messages.append( { - 'role': 'user', + 'role': MessageType.USER, 'content': 'Based on your reasoning, please proceed with the necessary tool calls to complete the task.', } ) @@ -295,7 +308,7 @@ async def _run_with_tools( # Add function call to history self.add_to_history( - 'function', + MessageType.FUNCTION, f'Tool response: {str(function_response)}', name=function_name, ) @@ -303,7 +316,7 @@ async def _run_with_tools( # Add the function response to messages for context messages.append( { - 'role': 'function', + 'role': MessageType.FUNCTION, 'name': function_name, 'content': f'Here is the result of the tool call: \n {str(function_response)}', } @@ -312,7 +325,7 @@ async def _run_with_tools( # Add a prompt to continue the reasoning messages.append( { - 'role': 'user', + 'role': MessageType.USER, 'content': 'Continue with your reasoning based on this result. 
What should be done next?', } ) @@ -336,7 +349,7 @@ async def _run_with_tools( ) self.add_to_history( - 'system', f'Tool execution error: {analysis}' + MessageType.SYSTEM, f'Tool execution error: {analysis}' ) continue raise AgentError( @@ -348,7 +361,7 @@ async def _run_with_tools( messages + [ { - 'role': 'system', + 'role': MessageType.SYSTEM, 'content': 'Please provide a final answer based on all the tool results above.', } ], @@ -357,7 +370,7 @@ async def _run_with_tools( assistant_message = self.llm.get_message_content(final_response) if assistant_message: - self.add_to_history('assistant', assistant_message) + self.add_to_history(self.act_as, assistant_message) return assistant_message return f'The final result based on the tool executions is: {function_response}' @@ -375,7 +388,7 @@ async def _run_with_tools( agent_metrics.record_retry(self.name, 'execution_error') self.add_to_history( - 'system', f'Error occurred. Analysis: {analysis}' + MessageType.SYSTEM, f'Error occurred. Analysis: {analysis}' ) continue @@ -528,10 +541,10 @@ async def _is_final_answer( try: analysis_messages = [ { - 'role': 'system', + 'role': MessageType.SYSTEM, 'content': 'You are a precise classification system. Respond with only FINAL or INTERMEDIATE.', }, - {'role': 'user', 'content': analysis_prompt}, + {'role': MessageType.USER, 'content': analysis_prompt}, ] analysis_response = await self.llm.generate(analysis_messages) analysis = self.llm.get_message_content(analysis_response).strip().upper()