From 1808e246e7dc182b2e6aeaa6b7281d53dd85ffd2 Mon Sep 17 00:00:00 2001 From: vishnu r kumar Date: Fri, 12 Dec 2025 15:19:09 +0530 Subject: [PATCH 1/3] refactor(flo-ai): follow arium directory structure for agents --- TOOLS.md | 8 +- documentation/development.mdx | 2 +- documentation/essentials/agents.mdx | 14 +- documentation/essentials/arium.mdx | 30 +- documentation/essentials/code.mdx | 26 +- documentation/essentials/telemetry.mdx | 76 +-- documentation/essentials/yaml-agents.mdx | 106 +-- documentation/quickstart.mdx | 43 +- flo_ai/README.md | 10 +- flo_ai/docs/arium_yaml_guide.md | 126 ++-- flo_ai/examples/agent_builder_usage.py | 4 +- flo_ai/examples/arium_examples.py | 2 +- flo_ai/examples/arium_linear_usage.py | 2 +- flo_ai/examples/arium_yaml_example.py | 4 +- flo_ai/examples/chat_history.py | 4 +- flo_ai/examples/cot_agent_example.py | 4 +- flo_ai/examples/cot_conversational_example.py | 5 +- flo_ai/examples/custom_plan_execute_demo.py | 2 +- .../examples/document_processing_example.py | 2 +- .../examples/example_graph_visualization.py | 8 +- flo_ai/examples/flo_tool_example.py | 4 +- flo_ai/examples/llm_router_example.py | 4 +- flo_ai/examples/multi_tool_example.py | 4 +- flo_ai/examples/ollama_agent_example.py | 4 +- flo_ai/examples/output_formatter.py | 4 +- flo_ai/examples/partial_tool_example.py | 2 +- flo_ai/examples/simple_flow_router_demo.py | 2 +- flo_ai/examples/simple_plan_execute_demo.py | 2 +- .../examples/simple_reflection_router_demo.py | 2 +- flo_ai/examples/simple_working_demo.py | 4 +- flo_ai/examples/tool_usage.py | 4 +- flo_ai/examples/tool_using_agent.py | 6 +- flo_ai/examples/tools_quickstart.py | 2 +- flo_ai/examples/usage_claude.py | 4 +- flo_ai/examples/variables_workflow_example.py | 4 +- .../variables_workflow_yaml_example.py | 2 +- flo_ai/examples/vertexai_agent_example.py | 6 +- flo_ai/examples/vllm_agent_usage.py | 4 +- flo_ai/examples/yaml_agent_example.py | 2 +- flo_ai/examples/yaml_tool_config_example.py | 2 +- flo_ai/flo_ai/__init__.py | 16 +- flo_ai/flo_ai/agent/__init__.py | 14 + flo_ai/flo_ai/agent/agent.py | 643 ++++++++++++++++++ flo_ai/flo_ai/{models => agent}/base_agent.py | 0 .../agent_builder.py => agent/builder.py} | 4 +- .../flo_ai/{models => agent}/plan_agents.py | 2 +- flo_ai/flo_ai/arium/arium.py | 2 +- flo_ai/flo_ai/arium/base.py | 2 +- flo_ai/flo_ai/arium/builder.py | 6 +- flo_ai/flo_ai/models/__init__.py | 9 +- flo_ai/flo_ai/models/agent.py | 643 ------------------ .../unit-tests/test_agent_builder_tools.py | 2 +- flo_ai/tests/unit-tests/test_arium_builder.py | 2 +- flo_ai/tests/unit-tests/test_arium_yaml.py | 2 +- flo_ai/tests/unit-tests/test_base_llm.py | 1 + flo_ai/tests/unit-tests/test_llm_router.py | 1 + .../tests/unit-tests/test_yaml_tool_config.py | 2 +- 57 files changed, 964 insertions(+), 933 deletions(-) create mode 100644 flo_ai/flo_ai/agent/__init__.py create mode 100644 flo_ai/flo_ai/agent/agent.py rename flo_ai/flo_ai/{models => agent}/base_agent.py (100%) rename flo_ai/flo_ai/{builder/agent_builder.py => agent/builder.py} (99%) rename flo_ai/flo_ai/{models => agent}/plan_agents.py (99%) diff --git a/TOOLS.md b/TOOLS.md index 6e5df5bd..f12aa53e 100644 --- a/TOOLS.md +++ b/TOOLS.md @@ -65,7 +65,7 @@ Every tool has these properties: ### Using Tools in Agents ```python -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI agent = (AgentBuilder() @@ -95,7 +95,7 @@ Partial tools allow you to pre-fill some parameters during agent building, hidin #### 
Method 1: Using AgentBuilder.add_tool() ```python -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder # BigQuery tool with multiple parameters @flo_tool(description="Query BigQuery database") @@ -304,7 +304,7 @@ agent: ### Using YAML Configuration ```python -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder # Create tool registry tool_registry = { @@ -331,7 +331,7 @@ agent = AgentBuilder.from_yaml( ```python import asyncio from flo_ai.tool.flo_tool import flo_tool -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI # Define tools diff --git a/documentation/development.mdx b/documentation/development.mdx index 84937f32..caa58912 100644 --- a/documentation/development.mdx +++ b/documentation/development.mdx @@ -55,7 +55,7 @@ Test your installation with a simple agent: ```python import asyncio -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI async def test_installation(): diff --git a/documentation/essentials/agents.mdx b/documentation/essentials/agents.mdx index 1cdd79c7..74441499 100644 --- a/documentation/essentials/agents.mdx +++ b/documentation/essentials/agents.mdx @@ -1,7 +1,7 @@ --- -title: 'Agents' -description: 'Learn how to create and configure AI agents with Flo AI' -icon: 'robot' +title: "Agents" +description: "Learn how to create and configure AI agents with Flo AI" +icon: "robot" --- ## Creating Agents @@ -13,7 +13,7 @@ Agents are the core building blocks of Flo AI. They represent AI-powered entitie Create a simple conversational agent: ```python -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI agent = ( @@ -180,7 +180,7 @@ well_prompted_agent = ( 1. Review code for bugs, security issues, and best practices 2. Suggest improvements and optimizations 3. Provide constructive feedback - + Always be specific about issues and provide actionable suggestions. Focus on code quality, performance, and maintainability. ''') @@ -263,10 +263,10 @@ from flo_ai.arium.memory import BaseMemory class CustomMemory(BaseMemory): def __init__(self): self.messages = [] - + def add(self, message): self.messages.append(message) - + def get(self): return self.messages diff --git a/documentation/essentials/arium.mdx b/documentation/essentials/arium.mdx index 44388d06..0bd95b76 100644 --- a/documentation/essentials/arium.mdx +++ b/documentation/essentials/arium.mdx @@ -1,7 +1,7 @@ --- -title: 'Arium Workflows' -description: 'Create complex multi-agent workflows with Arium orchestration' -icon: 'sitemap' +title: "Arium Workflows" +description: "Create complex multi-agent workflows with Arium orchestration" +icon: "sitemap" --- ## What is Arium? 
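For reviewers tracking the rename, the import-path mapping this patch applies is sketched below. The paths are taken from the rename entries and import rewrites in this diff; the old modules are removed rather than aliased, so downstream code must move to the new paths.

```python
# Old location (before this patch)             New location (after this patch)
# flo_ai.builder.agent_builder.AgentBuilder -> flo_ai.agent.AgentBuilder
# flo_ai.models.agent.Agent                 -> flo_ai.agent.Agent
# flo_ai.models.base_agent                  -> flo_ai.agent  (BaseAgent, AgentType, ReasoningPattern)
# flo_ai.models.plan_agents                 -> flo_ai.agent  (PlannerAgent, ExecutorAgent)

# New-style imports, as used throughout the updated docs and examples:
from flo_ai.agent import Agent, AgentBuilder, ReasoningPattern

# The package-level re-exports are kept, so this continues to work unchanged:
from flo_ai import Agent, AgentBuilder
```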
@@ -16,25 +16,25 @@ Create a linear workflow with multiple agents: ```python from flo_ai.arium import AriumBuilder -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.llm import OpenAI async def simple_chain(): llm = OpenAI(model='gpt-4o-mini') - + # Create agents analyst = Agent( name='content_analyst', system_prompt='Analyze the input and extract key insights.', llm=llm ) - + summarizer = Agent( - name='summarizer', + name='summarizer', system_prompt='Create a concise summary based on the analysis.', llm=llm ) - + # Build and run workflow result = await ( AriumBuilder() @@ -44,7 +44,7 @@ async def simple_chain(): .end_with(summarizer) .build_and_run(["Analyze this complex business report..."]) ) - + return result ``` @@ -59,7 +59,7 @@ def route_by_type(memory: BaseMemory) -> str: """Route based on classification result""" messages = memory.get() last_message = str(messages[-1]) if messages else "" - + if "technical" in last_message.lower(): return "tech_specialist" else: @@ -96,7 +96,7 @@ arium: model: provider: "openai" name: "gpt-4o-mini" - + - name: "summarizer" role: "Content Summarizer" job: "Create a concise summary based on the analysis." @@ -130,7 +130,7 @@ Use LLMs for intelligent routing decisions: ```yaml routers: - name: "content_type_router" - type: "smart" # Uses LLM for intelligent routing + type: "smart" # Uses LLM for intelligent routing routing_options: technical_writer: "Technical content, documentation, tutorials" creative_writer: "Creative writing, storytelling, fiction" @@ -148,7 +148,7 @@ For A→B→A→C feedback patterns: routers: - name: "reflection_router" type: "reflection" - flow_pattern: [writer, critic, writer] # A → B → A pattern + flow_pattern: [writer, critic, writer] # A → B → A pattern model: provider: "openai" name: "gpt-4o-mini" @@ -248,10 +248,10 @@ from flo_ai.arium.memory import BaseMemory class CustomMemory(BaseMemory): def __init__(self): self.data = {} - + def add(self, key, value): self.data[key] = value - + def get(self, key): return self.data.get(key) diff --git a/documentation/essentials/code.mdx b/documentation/essentials/code.mdx index 503c253c..e83aeab6 100644 --- a/documentation/essentials/code.mdx +++ b/documentation/essentials/code.mdx @@ -1,7 +1,7 @@ --- -title: 'Code Examples' -description: 'Flo AI code examples and syntax highlighting' -icon: 'code' +title: "Code Examples" +description: "Flo AI code examples and syntax highlighting" +icon: "code" --- ## Basic Agent Creation @@ -10,7 +10,7 @@ Here's how to create a simple conversational agent with Flo AI: ```python Simple Agent import asyncio -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI async def main(): @@ -34,7 +34,7 @@ Create agents that can use custom tools: ```python Tool-Using Agent import asyncio -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.tool import flo_tool from flo_ai.llm import Anthropic @@ -72,7 +72,7 @@ Use Pydantic models for structured agent responses: ```python Structured Output import asyncio from pydantic import BaseModel, Field -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI class MathSolution(BaseModel): @@ -102,25 +102,25 @@ Create complex workflows with multiple agents: ```python Multi-Agent Workflow import asyncio from flo_ai.arium import AriumBuilder -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent 
from flo_ai.llm import OpenAI async def content_analysis_workflow(): llm = OpenAI(model='gpt-4o-mini') - + # Create specialized agents analyst = Agent( name='content_analyst', system_prompt='Analyze the input and extract key insights.', llm=llm ) - + summarizer = Agent( - name='summarizer', + name='summarizer', system_prompt='Create a concise summary based on the analysis.', llm=llm ) - + # Build and run workflow result = await ( AriumBuilder() @@ -130,7 +130,7 @@ async def content_analysis_workflow(): .end_with(summarizer) .build_and_run(["Analyze this complex business report..."]) ) - + return result asyncio.run(content_analysis_workflow()) @@ -154,7 +154,7 @@ arium: model: provider: "openai" name: "gpt-4o-mini" - + - name: "summarizer" role: "Content Summarizer" job: "Create a concise summary based on the analysis." diff --git a/documentation/essentials/telemetry.mdx b/documentation/essentials/telemetry.mdx index c0e7f2a4..ed9ec242 100644 --- a/documentation/essentials/telemetry.mdx +++ b/documentation/essentials/telemetry.mdx @@ -1,7 +1,7 @@ --- -title: 'Telemetry & Observability' -description: 'Monitor and observe your AI agents with OpenTelemetry integration' -icon: 'chart-line' +title: "Telemetry & Observability" +description: "Monitor and observe your AI agents with OpenTelemetry integration" +icon: "chart-line" --- ## OpenTelemetry Integration @@ -48,7 +48,7 @@ export JAEGER_AGENT_PORT="14268" ### Automatic Instrumentation ```python -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI # Agents are automatically instrumented @@ -75,13 +75,13 @@ async def custom_agent_operation(): with tracer.start_as_current_span("custom_operation") as span: span.set_attribute("operation.type", "data_processing") span.set_attribute("operation.input_size", 1000) - + # Your custom logic here result = await process_data() - + span.set_attribute("operation.output_size", len(result)) span.set_attribute("operation.success", True) - + return result ``` @@ -127,15 +127,15 @@ workflow_duration = meter.create_histogram( # Record metrics async def execute_workflow(): start_time = time.time() - + workflow_counter.add(1, {"workflow_type": "analysis"}) - + # Execute workflow result = await workflow.build_and_run(["data"]) - + duration = time.time() - start_time workflow_duration.record(duration, {"workflow_type": "analysis"}) - + return result ``` @@ -168,26 +168,26 @@ agent_metrics = { # Record agent metrics async def execute_agent_with_metrics(agent, input_data): start_time = time.time() - + agent_metrics["agent_executions_total"].add(1, { "agent_name": agent.name, "model": agent.llm.model }) - + response = await agent.run(input_data) - + duration = time.time() - start_time agent_metrics["agent_duration_seconds"].record(duration, { "agent_name": agent.name }) - + # Record token usage if available if hasattr(response, 'usage'): agent_metrics["agent_tokens_used"].add( response.usage.total_tokens, {"agent_name": agent.name} ) - + return response ``` @@ -217,30 +217,30 @@ llm_metrics = { # Record LLM metrics async def llm_request_with_metrics(llm, prompt): start_time = time.time() - + llm_metrics["llm_requests_total"].add(1, { "provider": llm.provider, "model": llm.model }) - + try: response = await llm.generate(prompt) - + duration = time.time() - start_time llm_metrics["llm_request_duration"].record(duration, { "provider": llm.provider, "model": llm.model }) - + # Record token usage if hasattr(response, 'usage'): 
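            # usage.total_tokens is the aggregate (prompt + completion) count;
            # the provider/model attributes keep the metric's cardinality bounded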
llm_metrics["llm_tokens_total"].add( response.usage.total_tokens, {"provider": llm.provider, "model": llm.model} ) - + return response - + except Exception as e: llm_metrics["llm_errors_total"].add(1, { "provider": llm.provider, @@ -260,15 +260,15 @@ from opentelemetry.trace import get_current_span async def distributed_workflow(): # Get current trace context current_span = get_current_span() - + if current_span: # Add custom attributes current_span.set_attribute("workflow.id", "analysis_workflow") current_span.set_attribute("workflow.version", "1.0.0") - + # Execute workflow - trace context is automatically propagated result = await workflow.build_and_run(["data"]) - + return result ``` @@ -302,21 +302,21 @@ async def custom_operation(): # Set span attributes span.set_attribute("operation.name", "data_processing") span.set_attribute("operation.input_size", 1000) - + # Add events span.add_event("Processing started") - + # Your custom logic result = await process_data() - + span.add_event("Processing completed", { "output_size": len(result), "processing_time": time.time() - start_time }) - + # Set status span.set_status(trace.Status(trace.StatusCode.OK)) - + return result ``` @@ -341,15 +341,15 @@ custom_histogram = meter.create_histogram( # Record custom metrics async def custom_operation_with_metrics(): start_time = time.time() - + custom_counter.add(1, {"operation_type": "data_processing"}) - + # Your operation result = await custom_operation() - + duration = time.time() - start_time custom_histogram.record(duration, {"operation_type": "data_processing"}) - + return result ``` @@ -425,7 +425,7 @@ async def monitored_operation(): result = await risky_operation() span.set_status(trace.Status(trace.StatusCode.OK)) return result - + except Exception as e: # Record exception span.record_exception(e) @@ -468,7 +468,7 @@ latency_histogram = meter.create_histogram( async def monitored_operation(): start_time = time.time() - + try: result = await operation() return result @@ -521,13 +521,13 @@ configure_telemetry( def sanitize_span_attributes(attributes): sensitive_keys = ['password', 'token', 'key', 'secret'] sanitized = {} - + for key, value in attributes.items(): if any(sensitive in key.lower() for sensitive in sensitive_keys): sanitized[key] = "[REDACTED]" else: sanitized[key] = value - + return sanitized ``` diff --git a/documentation/essentials/yaml-agents.mdx b/documentation/essentials/yaml-agents.mdx index 0452f713..a50b40c2 100644 --- a/documentation/essentials/yaml-agents.mdx +++ b/documentation/essentials/yaml-agents.mdx @@ -1,7 +1,7 @@ --- -title: 'YAML-Based Agents' -description: 'Create and configure agents using YAML configuration files' -icon: 'file-code' +title: "YAML-Based Agents" +description: "Create and configure agents using YAML configuration files" +icon: "file-code" --- ## YAML Agent Configuration @@ -30,7 +30,7 @@ agent: ``` ```python Load YAML Agent -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder # Load agent from YAML agent = AgentBuilder.from_yaml('agent.yaml') @@ -53,7 +53,7 @@ agent: provider: "anthropic" name: "claude-3-5-sonnet-20240620" temperature: 0.3 - + tools: - name: "calculate" description: "Perform mathematical calculations" @@ -84,7 +84,7 @@ agent: provider: "openai" name: "gpt-4o" temperature: 0.2 - + output_schema: type: "pydantic" model: "AnalysisResult" @@ -115,7 +115,7 @@ agent: model: provider: "openai" name: "gpt-4o-mini" - + variables: - name: "user_name" description: "User's name" @@ -135,58 +135,58 
@@ agent: ```yaml metadata: - name: "agent-name" # Required: Unique agent identifier - version: "1.0.0" # Required: Semantic version - description: "Agent description" # Optional: Human-readable description - author: "Your Name" # Optional: Agent author - tags: ["tag1", "tag2"] # Optional: Categorization tags + name: "agent-name" # Required: Unique agent identifier + version: "1.0.0" # Required: Semantic version + description: "Agent description" # Optional: Human-readable description + author: "Your Name" # Optional: Agent author + tags: ["tag1", "tag2"] # Optional: Categorization tags ``` ### Agent Configuration ```yaml agent: - name: "Agent Display Name" # Required: Human-readable name - prompt: "System prompt" # Required: Agent's system prompt - model: # Required: LLM configuration - provider: "openai" # Required: openai, anthropic, google, ollama - name: "gpt-4o-mini" # Required: Model name - temperature: 0.7 # Optional: 0.0 to 2.0 - max_tokens: 1000 # Optional: Maximum response length - timeout: 30 # Optional: Request timeout in seconds - + name: "Agent Display Name" # Required: Human-readable name + prompt: "System prompt" # Required: Agent's system prompt + model: # Required: LLM configuration + provider: "openai" # Required: openai, anthropic, google, ollama + name: "gpt-4o-mini" # Required: Model name + temperature: 0.7 # Optional: 0.0 to 2.0 + max_tokens: 1000 # Optional: Maximum response length + timeout: 30 # Optional: Request timeout in seconds + # Optional configurations - retries: 3 # Number of retry attempts - cache_ttl: 3600 # Cache time-to-live in seconds - memory_type: "message" # Memory type: message, custom - tools: [] # List of tools (see tools section) - variables: [] # List of variables (see variables section) - output_schema: {} # Structured output configuration + retries: 3 # Number of retry attempts + cache_ttl: 3600 # Cache time-to-live in seconds + memory_type: "message" # Memory type: message, custom + tools: [] # List of tools (see tools section) + variables: [] # List of variables (see variables section) + output_schema: {} # Structured output configuration ``` ### Tools Configuration ```yaml tools: - - name: "tool_name" # Required: Tool identifier - description: "Tool description" # Required: Tool description - function: "function_name" # Required: Python function name - parameters: # Required: Function parameters + - name: "tool_name" # Required: Tool identifier + description: "Tool description" # Required: Tool description + function: "function_name" # Required: Python function name + parameters: # Required: Function parameters param1: - type: "string" # Parameter type + type: "string" # Parameter type description: "Parameter description" - required: true # Whether parameter is required - default: "value" # Default value + required: true # Whether parameter is required + default: "value" # Default value ``` ### Variables Configuration ```yaml variables: - - name: "variable_name" # Required: Variable name + - name: "variable_name" # Required: Variable name description: "Description" # Required: Variable description - type: "string" # Optional: string, number, boolean - required: true # Optional: Whether variable is required + type: "string" # Optional: string, number, boolean + required: true # Optional: Whether variable is required default: "default_value" # Optional: Default value ``` @@ -195,7 +195,7 @@ variables: ### Basic Loading ```python -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder # Load from file 
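# (loading from a string is also supported via from_yaml(yaml_str=...),
#  as shown in flo_ai/docs/arium_yaml_guide.md later in this patch)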
agent = AgentBuilder.from_yaml('agent.yaml') @@ -270,11 +270,11 @@ agent: prompt: "Concise and effective prompt" model: provider: "openai" - name: "gpt-4o-mini" # Use faster model for simple tasks - temperature: 0.3 # Lower temperature for consistency - max_tokens: 500 # Limit response length - cache_ttl: 3600 # Cache responses for 1 hour - retries: 2 # Limit retries to avoid costs + name: "gpt-4o-mini" # Use faster model for simple tasks + temperature: 0.3 # Lower temperature for consistency + max_tokens: 500 # Limit response length + cache_ttl: 3600 # Cache responses for 1 hour + retries: 2 # Limit retries to avoid costs ``` ### Security Considerations @@ -291,9 +291,9 @@ agent: model: provider: "openai" name: "gpt-4o" - temperature: 0.1 # Lower temperature for consistency - max_tokens: 200 # Limit response length - timeout: 10 # Short timeout for security + temperature: 0.1 # Lower temperature for consistency + max_tokens: 200 # Limit response length + timeout: 10 # Short timeout for security ``` ## Validation and Testing @@ -301,7 +301,7 @@ agent: ### Schema Validation ```python -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder import yaml # Validate YAML structure @@ -309,13 +309,13 @@ def validate_agent_yaml(file_path): try: with open(file_path, 'r') as f: config = yaml.safe_load(f) - + # Check required fields assert 'agent' in config assert 'name' in config['agent'] assert 'prompt' in config['agent'] assert 'model' in config['agent'] - + print("✅ YAML configuration is valid") return True except Exception as e: @@ -332,12 +332,12 @@ import asyncio async def test_yaml_agent(): agent = AgentBuilder.from_yaml('agent.yaml') - + # Test basic functionality response = await agent.run('Hello!') assert response is not None print(f"✅ Agent responds: {response}") - + # Test with variables if 'variables' in agent.config: variables = {'test_var': 'test_value'} @@ -365,7 +365,7 @@ agent: 2. Provide accurate and helpful information 3. Escalate complex issues when necessary 4. Maintain a friendly and professional tone - + Always be patient, understanding, and solution-oriented. model: provider: "openai" @@ -399,7 +399,7 @@ agent: type: "pydantic" model: "AnalysisResult" fields: - summary: + summary: type: "string" description: "Executive summary" insights: diff --git a/documentation/quickstart.mdx b/documentation/quickstart.mdx index 756eb9a1..b23cfee2 100644 --- a/documentation/quickstart.mdx +++ b/documentation/quickstart.mdx @@ -11,31 +11,15 @@ Build your first AI agent and understand the core concepts of Flo AI. - Install Flo AI using pip or poetry: - - ```bash - pip install flo-ai - # or using poetry - poetry add flo-ai - # or using uv - uv add flo-ai - ``` - + Install Flo AI using pip or poetry: ```bash pip install flo-ai # or using + poetry poetry add flo-ai # or using uv uv add flo-ai ``` Flo AI requires Python 3.10 or higher. - Configure your LLM provider API keys: - - ```bash - # For OpenAI - export OPENAI_API_KEY="your-openai-key" - - # For Anthropic - export ANTHROPIC_API_KEY="your-anthropic-key" - - # For Google Gemini - export GOOGLE_API_KEY="your-google-key" - ``` + Configure your LLM provider API keys: ```bash # For OpenAI export + OPENAI_API_KEY="your-openai-key" # For Anthropic export + ANTHROPIC_API_KEY="your-anthropic-key" # For Google Gemini export + GOOGLE_API_KEY="your-google-key" ``` @@ -46,7 +30,7 @@ Build your first AI agent and understand the core concepts of Flo AI. 
```python import asyncio -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI async def main(): @@ -75,7 +59,7 @@ asyncio.run(main()) ```python import asyncio -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.tool import flo_tool from flo_ai.llm import Anthropic @@ -105,6 +89,7 @@ async def main(): asyncio.run(main()) ``` + ## Next Steps @@ -117,7 +102,11 @@ Now that you have your first agent running, explore these key features: Learn how to create complex multi-agent workflows with Arium. - + Define entire agent architectures in YAML for easy management. @@ -132,5 +121,7 @@ Now that you have your first agent running, explore these key features: - **Need help?** Check out our [examples](https://github.com/rootflo/flo-ai/tree/main/flo_ai/examples) or join our [community](https://github.com/rootflo/flo-ai/discussions). + **Need help?** Check out our + [examples](https://github.com/rootflo/flo-ai/tree/main/flo_ai/examples) or + join our [community](https://github.com/rootflo/flo-ai/discussions). diff --git a/flo_ai/README.md b/flo_ai/README.md index 486cd064..a89f82ce 100644 --- a/flo_ai/README.md +++ b/flo_ai/README.md @@ -85,7 +85,7 @@ uv add flo-ai ```python import asyncio -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI async def main(): @@ -108,7 +108,7 @@ asyncio.run(main()) ```python import asyncio -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.tool import flo_tool from flo_ai.llm import Anthropic @@ -144,7 +144,7 @@ asyncio.run(main()) ```python import asyncio from pydantic import BaseModel, Field -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI class MathSolution(BaseModel): @@ -359,7 +359,7 @@ Arium is Flo AI's powerful workflow orchestration engine for creating complex mu ```python from flo_ai.arium import AriumBuilder -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.llm import OpenAI async def simple_chain(): @@ -597,4 +597,4 @@ Flo AI is [MIT Licensed](LICENSE). Built with ❤️ by the rootflo team
CommunityDocumentation - \ No newline at end of file + diff --git a/flo_ai/docs/arium_yaml_guide.md b/flo_ai/docs/arium_yaml_guide.md index 419444a5..b138da3a 100644 --- a/flo_ai/docs/arium_yaml_guide.md +++ b/flo_ai/docs/arium_yaml_guide.md @@ -25,9 +25,9 @@ metadata: arium: agents: # Method 1: Reference pre-built agents (cleanest approach) - - name: content_analyst # Must exist in agents parameter - - name: summarizer # Must exist in agents parameter - + - name: content_analyst # Must exist in agents parameter + - name: summarizer # Must exist in agents parameter + # Method 2: Direct configuration - name: validator role: "Quality Validator" @@ -38,15 +38,15 @@ arium: settings: temperature: 0.1 reasoning_pattern: DIRECT - + tools: - name: tool1 - name: tool2 - + # LLM Router definitions (NEW!) routers: - name: content_router - type: smart # smart, task_classifier, conversation_analysis + type: smart # smart, task_classifier, conversation_analysis routing_options: technical_writer: "Handle technical documentation tasks" creative_writer: "Handle creative writing tasks" @@ -65,7 +65,7 @@ arium: to: [summarizer] - from: summarizer to: [validator, tool1] - router: content_router # References router defined above + router: content_router # References router defined above - from: validator to: [end] - from: tool1 @@ -77,10 +77,10 @@ arium: ```yaml metadata: - name: workflow-name # Required: Unique workflow identifier - version: 1.0.0 # Optional: Semantic version - description: "Description" # Optional: Human-readable description - tags: ["tag1", "tag2"] # Optional: Classification tags + name: workflow-name # Required: Unique workflow identifier + version: 1.0.0 # Optional: Semantic version + description: "Description" # Optional: Human-readable description + tags: ["tag1", "tag2"] # Optional: Classification tags ``` ### Arium Section @@ -101,6 +101,7 @@ builder = AriumBuilder.from_yaml(yaml_str=config, memory=custom_memory) ``` **Why memory is handled as a parameter:** + - ✅ **Cleaner YAML**: Focuses on workflow structure, not implementation details - ✅ **Runtime Flexibility**: Same workflow can use different memory implementations - ✅ **Better Separation**: Memory is an execution concern, not a workflow definition concern @@ -156,11 +157,13 @@ routers: ``` **Router Types:** + - **`smart`**: General-purpose routing based on content analysis - **`task_classifier`**: Routes based on task categorization with keywords and examples - **`conversation_analysis`**: Routes based on conversation context analysis **Key benefits of YAML LLM routers:** + - ✅ **Declarative Configuration**: No code needed to create routers - ✅ **Easy Modification**: Change routing logic without code changes - ✅ **Version Control**: Track router changes in YAML files @@ -172,28 +175,30 @@ routers: You can define agents in four ways: **1. Reference Pre-built Agents (New):** + ```yaml agents: - - name: content_analyst # Must exist in agents parameter - - name: summarizer # Must exist in agents parameter + - name: content_analyst # Must exist in agents parameter + - name: summarizer # Must exist in agents parameter ``` **2. 
Direct Configuration (Recommended):** + ```yaml agents: - name: my_agent - role: Assistant # optional + role: Assistant # optional job: "You are a helpful assistant" model: provider: openai name: gpt-4o-mini - base_url: "https://api.openai.com/v1" # optional + base_url: "https://api.openai.com/v1" # optional settings: temperature: 0.7 max_retries: 3 - reasoning_pattern: DIRECT # DIRECT, REACT, COT - tools: ["calculator", "web_search"] # optional - parser: # optional for structured output + reasoning_pattern: DIRECT # DIRECT, REACT, COT + tools: ["calculator", "web_search"] # optional + parser: # optional for structured output name: MyParser fields: - name: result @@ -202,6 +207,7 @@ agents: ``` **3. Inline YAML Configuration:** + ```yaml agents: - name: my_agent @@ -220,6 +226,7 @@ agents: ``` **4. External File Reference:** + ```yaml agents: - name: my_agent @@ -240,6 +247,7 @@ tools: #### Comparison of Agent Configuration Methods **Reference Pre-built Agents (New):** + - ✅ Maximum reusability across workflows - ✅ Programmatic agent building with YAML workflows - ✅ Clean separation of agent definition and workflow @@ -247,6 +255,7 @@ tools: - ⚠️ Requires separate agent building step **Direct Configuration (Recommended):** + - ✅ Clean, flat structure - ✅ No nested YAML-in-YAML - ✅ IDE syntax highlighting and validation @@ -254,12 +263,14 @@ tools: - ✅ Supports all agent features directly **Inline YAML Configuration:** + - ⚠️ Requires nested YAML string - ⚠️ Limited IDE support for nested content - ✅ Maintains existing workflow compatibility - ✅ Good for complex parser configurations **External File Reference:** + - ✅ Best for reusable agent definitions - ✅ Supports modular architecture - ✅ Version control friendly @@ -269,16 +280,16 @@ tools: ```yaml workflow: - start: agent1 # Starting node name + start: agent1 # Starting node name edges: - - from: agent1 # Source node - to: [tool1, agent2] # Target nodes - router: decision_router # Optional custom router + - from: agent1 # Source node + to: [tool1, agent2] # Target nodes + router: decision_router # Optional custom router - from: tool1 - to: [end] # 'end' is a special keyword + to: [end] # 'end' is a special keyword - from: agent2 to: [end] - end: [tool1, agent2] # End nodes + end: [tool1, agent2] # End nodes ``` ## Usage Examples @@ -292,7 +303,7 @@ from flo_ai.arium.builder import AriumBuilder yaml_config = """ metadata: name: simple-workflow - + arium: agents: - name: analyzer @@ -303,7 +314,7 @@ arium: name: gpt-4o-mini settings: temperature: 0.3 - + - name: summarizer role: Summary Generator job: "Create a concise summary based on the analysis" @@ -312,7 +323,7 @@ arium: name: gpt-4o-mini settings: temperature: 0.2 - + workflow: start: analyzer edges: @@ -332,6 +343,7 @@ asyncio.run(main()) ``` **Note**: To use custom memory, pass it as a parameter: + ```python from flo_ai.arium.memory import MessageMemory @@ -380,7 +392,7 @@ routers = {'smart_router': smart_router} yaml_config = """ metadata: name: complex-workflow - + arium: agents: - name: dispatcher @@ -392,7 +404,7 @@ arium: settings: temperature: 0.1 reasoning_pattern: REACT - + - name: summarizer role: Final Summarizer job: "Create final summary from all previous results" @@ -402,10 +414,10 @@ arium: settings: temperature: 0.3 tools: ["calculator"] # This agent can also use tools - + tools: - name: calculator - + workflow: start: dispatcher edges: @@ -425,7 +437,7 @@ async def main(): tools=tools, routers=routers ) - + result = await builder.build_and_run([ "Please calculate 
15 + 25 and then summarize the result" ]) @@ -442,7 +454,7 @@ You can also mix different configuration approaches in the same workflow: yaml_config = """ metadata: name: mixed-workflow - + arium: agents: # Direct configuration @@ -454,8 +466,8 @@ arium: name: gpt-4o-mini settings: temperature: 0.1 - - # Inline YAML configuration + + # Inline YAML configuration - name: data_analyzer yaml_config: | agent: @@ -468,11 +480,11 @@ arium: settings: temperature: 0.3 reasoning_pattern: COT - + # External file reference - name: report_generator yaml_file: "agents/report_generator.yaml" - + workflow: start: input_processor edges: @@ -493,13 +505,13 @@ You can build agents separately using `AgentBuilder` and then reference them in ```python import asyncio from flo_ai.arium.builder import AriumBuilder -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI # Build agents separately from YAML files async def create_agents(): llm = OpenAI(model="gpt-4o-mini") - + # Agent 1: Built from YAML file content_analyst_yaml = """ agent: @@ -515,9 +527,9 @@ async def create_agents(): temperature: 0.3 reasoning_pattern: COT """ - + content_analyst = AgentBuilder.from_yaml(yaml_str=content_analyst_yaml).build() - + # Agent 2: Built programmatically summarizer = (AgentBuilder() .with_name("summarizer") @@ -526,10 +538,10 @@ async def create_agents(): .with_llm(llm) .with_reasoning(ReasoningPattern.DIRECT) .build()) - + # Agent 3: Built from external file # reporter = AgentBuilder.from_yaml(yaml_file="agents/reporter.yaml").build() - + return { 'content_analyst': content_analyst, 'summarizer': summarizer, @@ -548,7 +560,7 @@ arium: # Reference pre-built agents by name only - name: content_analyst - name: summarizer - + # You can also mix with other configuration methods - name: validator role: Content Validator @@ -558,7 +570,7 @@ arium: name: gpt-4o-mini settings: temperature: 0.1 - + workflow: start: content_analyst edges: @@ -574,20 +586,20 @@ arium: async def main(): # Build agents separately agents = await create_agents() - + # Create workflow with pre-built agents builder = AriumBuilder.from_yaml( yaml_str=workflow_yaml, agents=agents # Pass pre-built agents ) - + result = await builder.build_and_run([ "Artificial Intelligence is revolutionizing industries across the globe. " "From healthcare diagnostics to financial trading, AI systems are providing " "unprecedented capabilities. However, challenges include ethical considerations, " "data privacy, and the need for human oversight in critical decisions." ]) - + print(result) asyncio.run(main()) @@ -630,7 +642,7 @@ def router_function(memory: BaseMemory) -> str: """ Args: memory: Current workflow memory containing conversation history - + Returns: str: Name of the next node to execute """ @@ -732,34 +744,44 @@ except Exception as e: Common errors and their solutions: ### Missing Required Sections + ``` ValueError: YAML must contain an "arium" section ``` + **Solution**: Ensure your YAML has the required `arium` section. ### Agent Configuration Errors + ``` ValueError: Agent {name} must have either yaml_config or yaml_file ``` + **Solution**: Each agent must specify either inline `yaml_config` or `yaml_file`. ### Tool Not Found + ``` ValueError: Tool {name} not found in provided tools dictionary ``` + **Solution**: Ensure all referenced tools are provided in the `tools` parameter. 
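A minimal sketch of that wiring, assuming a `calculator` `Tool` instance built elsewhere (the dictionary key must match the tool name referenced in the YAML):

```python
from flo_ai.arium.builder import AriumBuilder

builder = AriumBuilder.from_yaml(
    yaml_str=yaml_config,
    tools={'calculator': calculator_tool},  # key == tool name used in the YAML
)
```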
### Router Not Found + ``` ValueError: Router {name} not found in provided routers dictionary ``` + **Solution**: Ensure all referenced routers are provided in the `routers` parameter. ### Invalid Workflow Structure + ``` ValueError: Workflow must specify a start node ValueError: Workflow must specify end nodes ``` + **Solution**: Ensure `workflow` section has both `start` and `end` specifications. ## Migration from Programmatic Builder @@ -767,6 +789,7 @@ ValueError: Workflow must specify end nodes To convert existing programmatic builder code to YAML: **Before:** + ```python builder = (AriumBuilder() .add_agent(agent1) @@ -777,6 +800,7 @@ builder = (AriumBuilder() ``` **After:** + ```yaml arium: agents: @@ -786,7 +810,7 @@ arium: - name: agent2 yaml_config: | # agent2 configuration - + workflow: start: agent1 edges: @@ -812,4 +836,4 @@ Planned features for future versions: - YAML-based tool definitions using function references - Configuration validation schemas - Hot-reloading of configurations -- Workflow debugging and visualization tools \ No newline at end of file +- Workflow debugging and visualization tools diff --git a/flo_ai/examples/agent_builder_usage.py b/flo_ai/examples/agent_builder_usage.py index c0dbdd9e..e874e2dd 100644 --- a/flo_ai/examples/agent_builder_usage.py +++ b/flo_ai/examples/agent_builder_usage.py @@ -1,8 +1,8 @@ import asyncio from flo_ai import UserMessage -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.tool.base_tool import Tool -from flo_ai.models.base_agent import ReasoningPattern +from flo_ai.agent import ReasoningPattern from flo_ai.llm.openai_llm import OpenAI from flo_ai.llm.anthropic_llm import Anthropic diff --git a/flo_ai/examples/arium_examples.py b/flo_ai/examples/arium_examples.py index b1810a5f..c6a0d69f 100644 --- a/flo_ai/examples/arium_examples.py +++ b/flo_ai/examples/arium_examples.py @@ -9,7 +9,7 @@ from typing import Literal from flo_ai.arium import AriumBuilder, create_arium from flo_ai.llm import OpenAI -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.arium.nodes import FunctionNode from flo_ai.arium.memory import MessageMemory, MessageMemoryItem from flo_ai.models import BaseMessage diff --git a/flo_ai/examples/arium_linear_usage.py b/flo_ai/examples/arium_linear_usage.py index ead8d12e..dc7974a3 100644 --- a/flo_ai/examples/arium_linear_usage.py +++ b/flo_ai/examples/arium_linear_usage.py @@ -1,6 +1,6 @@ import asyncio from flo_ai.arium.builder import AriumBuilder -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.llm.openai_llm import OpenAI diff --git a/flo_ai/examples/arium_yaml_example.py b/flo_ai/examples/arium_yaml_example.py index 5d18b540..a60200c8 100644 --- a/flo_ai/examples/arium_yaml_example.py +++ b/flo_ai/examples/arium_yaml_example.py @@ -733,8 +733,8 @@ async def run_prebuilt_agents_example(): print('=' * 60) # Build agents separately using AgentBuilder - from flo_ai.builder.agent_builder import AgentBuilder - from flo_ai.models.base_agent import ReasoningPattern + from flo_ai.agent import AgentBuilder + from flo_ai.agent import ReasoningPattern llm = OpenAI(model='gpt-4o-mini') diff --git a/flo_ai/examples/chat_history.py b/flo_ai/examples/chat_history.py index 12c32fab..cf54db16 100644 --- a/flo_ai/examples/chat_history.py +++ b/flo_ai/examples/chat_history.py @@ -1,8 +1,8 @@ import asyncio from typing import List -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import 
AgentBuilder from flo_ai.llm import Gemini -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.models import AssistantMessage, UserMessage, BaseMessage from flo_ai.tool import flo_tool diff --git a/flo_ai/examples/cot_agent_example.py b/flo_ai/examples/cot_agent_example.py index f29844d6..5c2a68dd 100644 --- a/flo_ai/examples/cot_agent_example.py +++ b/flo_ai/examples/cot_agent_example.py @@ -5,8 +5,8 @@ import asyncio from flo_ai import UserMessage -from flo_ai.models.agent import Agent -from flo_ai.models.base_agent import ReasoningPattern +from flo_ai.agent import Agent +from flo_ai.agent import ReasoningPattern from flo_ai.llm.openai_llm import OpenAI from flo_ai.tool.base_tool import Tool import os diff --git a/flo_ai/examples/cot_conversational_example.py b/flo_ai/examples/cot_conversational_example.py index 7ae22e42..80cca997 100644 --- a/flo_ai/examples/cot_conversational_example.py +++ b/flo_ai/examples/cot_conversational_example.py @@ -4,8 +4,9 @@ """ import asyncio -from flo_ai.models.agent import Agent, UserMessage -from flo_ai.models.base_agent import ReasoningPattern +from flo_ai.agent import Agent +from flo_ai.models import UserMessage +from flo_ai.agent import ReasoningPattern from flo_ai.llm.openai_llm import OpenAI import os diff --git a/flo_ai/examples/custom_plan_execute_demo.py b/flo_ai/examples/custom_plan_execute_demo.py index f4902c37..fa5a3998 100644 --- a/flo_ai/examples/custom_plan_execute_demo.py +++ b/flo_ai/examples/custom_plan_execute_demo.py @@ -11,7 +11,7 @@ from flo_ai.arium.memory import PlanAwareMemory from flo_ai.arium.llm_router import create_plan_execute_router from flo_ai.arium import AriumBuilder -from flo_ai.models.plan_agents import PlannerAgent, ExecutorAgent +from flo_ai.agent import PlannerAgent, ExecutorAgent async def main(): diff --git a/flo_ai/examples/document_processing_example.py b/flo_ai/examples/document_processing_example.py index 98c19bbc..1999f207 100644 --- a/flo_ai/examples/document_processing_example.py +++ b/flo_ai/examples/document_processing_example.py @@ -18,7 +18,7 @@ from reportlab.lib.styles import getSampleStyleSheet from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.arium import AriumBuilder from flo_ai.llm import OpenAI, Gemini from flo_ai.models.document import DocumentType diff --git a/flo_ai/examples/example_graph_visualization.py b/flo_ai/examples/example_graph_visualization.py index af3abe86..6bae270b 100644 --- a/flo_ai/examples/example_graph_visualization.py +++ b/flo_ai/examples/example_graph_visualization.py @@ -5,11 +5,11 @@ """ from flo_ai.arium.base import BaseArium -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.tool.flo_tool import flo_tool from flo_ai.llm.openai_llm import OpenAI from flo_ai.llm.base_llm import BaseLLM -from typing import Literal +from typing import Literal, Any # Create a simple validation tool using the @flo_tool decorator @@ -48,7 +48,9 @@ def __init__(self, model): async def generate(self, messages, functions=None, output_schema=None): return {'content': 'Mock response'} - async def stream(self, messages, functions=None, output_schema=None): + async def stream( + self, messages, functions=None, output_schema=None, **kwargs: Any + ): yield {'content': 'Mock response'} def get_message_content(self, response): diff --git a/flo_ai/examples/flo_tool_example.py b/flo_ai/examples/flo_tool_example.py index 
2ff92a46..1e055d57 100644 --- a/flo_ai/examples/flo_tool_example.py +++ b/flo_ai/examples/flo_tool_example.py @@ -1,7 +1,7 @@ import asyncio -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.tool import flo_tool, create_tool_from_function -from flo_ai.models.base_agent import ReasoningPattern +from flo_ai.agent import ReasoningPattern from flo_ai.llm.openai_llm import OpenAI from flo_ai.llm.anthropic_llm import Anthropic from flo_ai.llm.base_llm import BaseLLM diff --git a/flo_ai/examples/llm_router_example.py b/flo_ai/examples/llm_router_example.py index 85a84b7e..9f4a4444 100644 --- a/flo_ai/examples/llm_router_example.py +++ b/flo_ai/examples/llm_router_example.py @@ -11,9 +11,9 @@ from typing import Literal from flo_ai.arium import AriumBuilder, create_llm_router, llm_router -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI -from flo_ai.models.base_agent import ReasoningPattern +from flo_ai.agent import ReasoningPattern from flo_ai.arium.memory import BaseMemory diff --git a/flo_ai/examples/multi_tool_example.py b/flo_ai/examples/multi_tool_example.py index 084217d3..eb9966ee 100644 --- a/flo_ai/examples/multi_tool_example.py +++ b/flo_ai/examples/multi_tool_example.py @@ -1,7 +1,7 @@ import asyncio -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.tool.base_tool import Tool -from flo_ai.models.base_agent import ReasoningPattern +from flo_ai.agent import ReasoningPattern from flo_ai.llm import Anthropic, Gemini, OpenAI, BaseLLM diff --git a/flo_ai/examples/ollama_agent_example.py b/flo_ai/examples/ollama_agent_example.py index 28ef3851..f9973bce 100644 --- a/flo_ai/examples/ollama_agent_example.py +++ b/flo_ai/examples/ollama_agent_example.py @@ -1,7 +1,7 @@ import asyncio -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm.ollama_llm import OllamaLLM -from flo_ai.models.base_agent import ReasoningPattern +from flo_ai.agent import ReasoningPattern from flo_ai.tool.base_tool import Tool diff --git a/flo_ai/examples/output_formatter.py b/flo_ai/examples/output_formatter.py index 25b60d54..3ae61213 100644 --- a/flo_ai/examples/output_formatter.py +++ b/flo_ai/examples/output_formatter.py @@ -3,8 +3,8 @@ from pydantic import BaseModel, Field from flo_ai.llm.openai_llm import OpenAI from flo_ai.llm.anthropic_llm import Anthropic -from flo_ai.models.agent import Agent as ToolAgent -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import Agent as ToolAgent +from flo_ai.agent import AgentBuilder # Define the output schema using Pydantic diff --git a/flo_ai/examples/partial_tool_example.py b/flo_ai/examples/partial_tool_example.py index cf559795..28ac8432 100644 --- a/flo_ai/examples/partial_tool_example.py +++ b/flo_ai/examples/partial_tool_example.py @@ -8,7 +8,7 @@ import asyncio from flo_ai.tool.flo_tool import flo_tool from flo_ai.tool.partial_tool import create_partial_tool -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI diff --git a/flo_ai/examples/simple_flow_router_demo.py b/flo_ai/examples/simple_flow_router_demo.py index 8b917b60..033f750f 100644 --- a/flo_ai/examples/simple_flow_router_demo.py +++ b/flo_ai/examples/simple_flow_router_demo.py @@ -8,7 +8,7 @@ import asyncio from flo_ai.arium.builder import AriumBuilder from flo_ai.arium.memory 
import MessageMemory -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.llm import OpenAI from flo_ai.arium.llm_router import create_main_critic_reflection_router diff --git a/flo_ai/examples/simple_plan_execute_demo.py b/flo_ai/examples/simple_plan_execute_demo.py index d57a5172..9ca0814c 100644 --- a/flo_ai/examples/simple_plan_execute_demo.py +++ b/flo_ai/examples/simple_plan_execute_demo.py @@ -10,7 +10,7 @@ from flo_ai.arium.builder import AriumBuilder from flo_ai.arium.memory import PlanAwareMemory, ExecutionPlan, PlanStep, StepStatus from flo_ai.llm import OpenAI -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.arium.llm_router import create_plan_execute_router diff --git a/flo_ai/examples/simple_reflection_router_demo.py b/flo_ai/examples/simple_reflection_router_demo.py index 7c4c2e40..adca8d37 100644 --- a/flo_ai/examples/simple_reflection_router_demo.py +++ b/flo_ai/examples/simple_reflection_router_demo.py @@ -8,7 +8,7 @@ import asyncio from flo_ai.arium.builder import AriumBuilder from flo_ai.arium.memory import MessageMemory -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.llm import OpenAI from flo_ai.arium.llm_router import create_main_critic_reflection_router diff --git a/flo_ai/examples/simple_working_demo.py b/flo_ai/examples/simple_working_demo.py index d88c4faf..c2672e0d 100644 --- a/flo_ai/examples/simple_working_demo.py +++ b/flo_ai/examples/simple_working_demo.py @@ -9,12 +9,12 @@ import asyncio import os -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.llm import OpenAI from flo_ai.arium.memory import MessageMemory, MessageMemoryItem from flo_ai.arium.llm_router import create_plan_execute_router from flo_ai.arium import AriumBuilder -from flo_ai.models.agent import UserMessage +from flo_ai.models import UserMessage async def simple_working_demo(): diff --git a/flo_ai/examples/tool_usage.py b/flo_ai/examples/tool_usage.py index cea389d7..722f9c89 100644 --- a/flo_ai/examples/tool_usage.py +++ b/flo_ai/examples/tool_usage.py @@ -1,10 +1,10 @@ import asyncio from flo_ai.models import UserMessage -from flo_ai.models.agent import Agent as ToolAgent +from flo_ai.agent import Agent as ToolAgent from flo_ai.llm.openai_llm import OpenAI from flo_ai.tool.base_tool import Tool from flo_ai.models.agent_error import AgentError -from flo_ai.models.base_agent import ReasoningPattern +from flo_ai.agent import ReasoningPattern # Example of using ToolAgent as a conversational agent diff --git a/flo_ai/examples/tool_using_agent.py b/flo_ai/examples/tool_using_agent.py index 13d99eef..3c759c86 100644 --- a/flo_ai/examples/tool_using_agent.py +++ b/flo_ai/examples/tool_using_agent.py @@ -1,9 +1,9 @@ import asyncio from typing import Any -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.tool.base_tool import Tool -from flo_ai.models.base_agent import ReasoningPattern -from flo_ai.models.agent import Agent +from flo_ai.agent import ReasoningPattern +from flo_ai.agent import Agent from flo_ai.llm import Anthropic diff --git a/flo_ai/examples/tools_quickstart.py b/flo_ai/examples/tools_quickstart.py index 02561bef..d1f964c0 100644 --- a/flo_ai/examples/tools_quickstart.py +++ b/flo_ai/examples/tools_quickstart.py @@ -10,7 +10,7 @@ import asyncio from flo_ai.tool.flo_tool import flo_tool -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import 
OpenAI diff --git a/flo_ai/examples/usage_claude.py b/flo_ai/examples/usage_claude.py index 5183d01a..bedb73f2 100644 --- a/flo_ai/examples/usage_claude.py +++ b/flo_ai/examples/usage_claude.py @@ -1,7 +1,7 @@ import asyncio import os -from flo_ai.models.base_agent import ReasoningPattern -from flo_ai.models.agent import Agent as ToolAgent +from flo_ai.agent import ReasoningPattern +from flo_ai.agent import Agent as ToolAgent from flo_ai.llm.anthropic_llm import Anthropic from flo_ai.tool.base_tool import Tool from flo_ai.models.agent_error import AgentError diff --git a/flo_ai/examples/variables_workflow_example.py b/flo_ai/examples/variables_workflow_example.py index e7887a29..d75d0de2 100644 --- a/flo_ai/examples/variables_workflow_example.py +++ b/flo_ai/examples/variables_workflow_example.py @@ -8,9 +8,9 @@ import os import asyncio -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.arium.builder import AriumBuilder -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.llm.gemini_llm import Gemini from dotenv import load_dotenv diff --git a/flo_ai/examples/variables_workflow_yaml_example.py b/flo_ai/examples/variables_workflow_yaml_example.py index 8c69d20d..7fc53c2d 100644 --- a/flo_ai/examples/variables_workflow_yaml_example.py +++ b/flo_ai/examples/variables_workflow_yaml_example.py @@ -10,7 +10,7 @@ import os import asyncio -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.arium.builder import AriumBuilder from flo_ai.llm.gemini_llm import Gemini diff --git a/flo_ai/examples/vertexai_agent_example.py b/flo_ai/examples/vertexai_agent_example.py index 253aec69..2ad54776 100644 --- a/flo_ai/examples/vertexai_agent_example.py +++ b/flo_ai/examples/vertexai_agent_example.py @@ -7,10 +7,10 @@ import asyncio import os -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm.vertexai_llm import VertexAI -from flo_ai.models.agent import Agent as ToolAgent -from flo_ai.models.base_agent import ReasoningPattern +from flo_ai.agent import Agent as ToolAgent +from flo_ai.agent import ReasoningPattern from flo_ai.tool.base_tool import Tool from flo_ai.models.agent_error import AgentError diff --git a/flo_ai/examples/vllm_agent_usage.py b/flo_ai/examples/vllm_agent_usage.py index 3b3f219f..dc9d05bc 100644 --- a/flo_ai/examples/vllm_agent_usage.py +++ b/flo_ai/examples/vllm_agent_usage.py @@ -1,8 +1,8 @@ import asyncio import os -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.tool.base_tool import Tool -from flo_ai.models.base_agent import ReasoningPattern +from flo_ai.agent import ReasoningPattern from flo_ai.llm.openai_vllm import OpenAIVLLM from dotenv import load_dotenv diff --git a/flo_ai/examples/yaml_agent_example.py b/flo_ai/examples/yaml_agent_example.py index e306bfd5..2ab40e6f 100644 --- a/flo_ai/examples/yaml_agent_example.py +++ b/flo_ai/examples/yaml_agent_example.py @@ -1,4 +1,4 @@ -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder # Example YAML configuration yaml_config = """ diff --git a/flo_ai/examples/yaml_tool_config_example.py b/flo_ai/examples/yaml_tool_config_example.py index d2ff0168..e5d920e6 100644 --- a/flo_ai/examples/yaml_tool_config_example.py +++ b/flo_ai/examples/yaml_tool_config_example.py @@ -7,7 +7,7 @@ import asyncio from flo_ai.tool.flo_tool import flo_tool -from 
flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder # Define some example tools diff --git a/flo_ai/flo_ai/__init__.py b/flo_ai/flo_ai/__init__.py index 900a6d1c..5b5d8558 100644 --- a/flo_ai/flo_ai/__init__.py +++ b/flo_ai/flo_ai/__init__.py @@ -4,13 +4,8 @@ # Models package - Agent framework components from .models import ( - Agent, AgentError, - BaseAgent, - AgentType, - ReasoningPattern, DocumentType, - MessageType, SystemMessage, UserMessage, AssistantMessage, @@ -20,10 +15,9 @@ TextMessageContent, ImageMessageContent, DocumentMessageContent, + MessageType, ) -from .builder.agent_builder import AgentBuilder - # LLM package - Language model integrations from .llm import BaseLLM, Anthropic, OpenAI, OllamaLLM, Gemini, OpenAIVLLM @@ -47,6 +41,14 @@ create_arium, ) +from .agent import ( + Agent, + BaseAgent, + AgentType, + ReasoningPattern, + AgentBuilder, +) + # Utils package - Utility functions from .utils import FloUtils diff --git a/flo_ai/flo_ai/agent/__init__.py b/flo_ai/flo_ai/agent/__init__.py new file mode 100644 index 00000000..5381bbaa --- /dev/null +++ b/flo_ai/flo_ai/agent/__init__.py @@ -0,0 +1,14 @@ +from .agent import Agent +from .base_agent import BaseAgent, AgentType, ReasoningPattern +from .plan_agents import PlannerAgent, ExecutorAgent +from .builder import AgentBuilder + +__all__ = [ + 'Agent', + 'BaseAgent', + 'AgentType', + 'ReasoningPattern', + 'PlannerAgent', + 'ExecutorAgent', + 'AgentBuilder', +] diff --git a/flo_ai/flo_ai/agent/agent.py b/flo_ai/flo_ai/agent/agent.py new file mode 100644 index 00000000..e8a27a58 --- /dev/null +++ b/flo_ai/flo_ai/agent/agent.py @@ -0,0 +1,643 @@ +import json +from typing import Dict, Any, List, Optional +from flo_ai.agent.base_agent import BaseAgent, AgentType, ReasoningPattern +from flo_ai.llm.base_llm import BaseLLM +from flo_ai.models.chat_message import ( + AssistantMessage, + BaseMessage, + MessageType, + UserMessage, + TextMessageContent, + FunctionMessage, + SystemMessage, +) +from flo_ai.tool.base_tool import Tool, ToolExecutionError +from flo_ai.models.agent_error import AgentError +from flo_ai.utils.logger import logger +from flo_ai.utils.variable_extractor import ( + extract_variables_from_inputs, + extract_agent_variables, + validate_multi_agent_variables, + resolve_variables, +) +from flo_ai.telemetry.instrumentation import ( + trace_agent_execution, + agent_metrics, +) +from flo_ai.telemetry import get_tracer + + +class Agent(BaseAgent): + def __init__( + self, + name: str, + system_prompt: str | AssistantMessage, + llm: BaseLLM, + tools: Optional[List[Tool]] = None, + max_retries: int = 0, + max_tool_calls: int = 5, + reasoning_pattern: ReasoningPattern = ReasoningPattern.DIRECT, + output_schema: Optional[Dict[str, Any]] = None, + role: Optional[str] = None, + act_as: Optional[str] = MessageType.ASSISTANT, + input_filter: Optional[List[str]] = None, + ): + # Determine agent type based on tools + agent_type = AgentType.TOOL_USING if tools else AgentType.CONVERSATIONAL + + # Enhance system prompt with role if provided + enhanced_prompt = system_prompt + if role: + if isinstance(system_prompt, str): + enhanced_prompt = f'You are {role}. {system_prompt}' + elif isinstance(system_prompt, AssistantMessage): + enhanced_prompt = f'You are {role}. 
+                enhanced_prompt = f'You are {role}. {system_prompt.content}'
+
+        super().__init__(
+            name=name,
+            system_prompt=str(enhanced_prompt),
+            agent_type=agent_type,
+            llm=llm,
+            max_retries=max_retries,
+            max_tool_calls=max_tool_calls,
+        )
+        self.tools = tools or []
+        self.tools_dict = {tool.name: tool for tool in self.tools}
+        self.reasoning_pattern = reasoning_pattern
+        self.output_schema = output_schema
+        self.role = role
+        self.act_as = act_as
+        self.input_filter: Optional[List[str]] = input_filter
+
+    @trace_agent_execution()
+    async def run(
+        self,
+        inputs: List[BaseMessage] | str,
+        variables: Optional[Dict[str, Any]] = None,
+    ) -> List[BaseMessage]:
+        variables = variables or {}
+        if isinstance(inputs, str):
+            inputs = [UserMessage(TextMessageContent(text=inputs))]
+
+        # Perform runtime variable validation if not already resolved (single agent usage)
+        if not self.resolved_variables:
+            # Extract variables from inputs and system prompt
+            input_variables = extract_variables_from_inputs(inputs)
+            agent_variables = extract_agent_variables(self)
+            all_required_variables = input_variables.union(agent_variables)
+
+            # Validate that all required variables are provided
+            if all_required_variables:
+                agents_variables = {self.name: all_required_variables}
+                validate_multi_agent_variables(agents_variables, variables)
+
+            # Resolve variables and mark as resolved
+            self.system_prompt = resolve_variables(self.system_prompt, variables)
+
+            # Process inputs and resolve variables in string inputs
+            for input in inputs:
+                if isinstance(input, BaseMessage):
+                    # checking whether the TextMessageContent is resolved
+                    if isinstance(input.content, TextMessageContent) and variables:
+                        input.content.text = resolve_variables(
+                            input.content.text, variables
+                        )
+                    self.add_to_history(input)
+                else:
+                    raise ValueError(f'Invalid input type: {type(input)}')
+            # after resolving agent system prompts and inputs, mark variables as resolved
+            self.resolved_variables = True
+
+        else:
+            # Variables already resolved, process inputs without variable resolution
+            for input in inputs:
+                if isinstance(input, BaseMessage):
+                    # Handle InputMessage - check content type
+                    self.add_to_history(input)
+                else:
+                    raise ValueError(f'Invalid input type: {type(input)}')
+
+        retry_count = 0
+
+        # If no tools, act as conversational agent
+        if not self.tools:
+            return await self._run_conversational(retry_count, variables)
+
+        # Otherwise, run as tool agent
+        return await self._run_with_tools(retry_count, variables)
+
+    async def _run_conversational(
+        self, retry_count: int, variables: Optional[Dict[str, Any]] = None
+    ) -> List[BaseMessage]:
+        """Run as a conversational agent when no tools are provided"""
+        variables = variables or {}
+
+        while retry_count <= self.max_retries:
+            try:
+                # Resolve variables in system prompt
+                system_content = (
+                    self._get_cot_prompt(variables)
+                    if self.reasoning_pattern == ReasoningPattern.COT
+                    else resolve_variables(self.system_prompt, variables)
+                )
+                system_message = SystemMessage(content=system_content)
+                self.add_to_history(system_message)
+                messages = await self._get_message_history(variables)
+
+                logger.debug(f'Sending messages to LLM: {messages}')
+                response = await self.llm.generate(
+                    messages, output_schema=self.output_schema
+                )
+                logger.debug(f'Raw LLM Response: {response}')
+
+                assistant_message = self.llm.get_message_content(response)
+                logger.debug(f'Extracted message: {assistant_message}')
+
+                # Ensure act_as is not None (default to 'assistant' if missing)
+                role = self.act_as if self.act_as is not None else MessageType.ASSISTANT
+
+                if assistant_message:
+                    self.add_to_history(
+                        AssistantMessage(role=role, content=assistant_message)
+                    )
+                else:
+                    possible_tool_message = await self.llm.get_function_call(response)
+                    if possible_tool_message:
+                        self.add_to_history(
+                            AssistantMessage(
+                                role=role, content=possible_tool_message['arguments']
+                            )
+                        )
+                    else:
+                        logger.debug('Warning: No message content found in response')
+                        self.add_to_history(
+                            AssistantMessage(
+                                role=role,
+                                content='No message content found in response',
+                            )
+                        )
+
+                return self.conversation_history
+
+            except Exception as e:
+                retry_count += 1
+                context = {
+                    'conversation_history': self.conversation_history,
+                    'attempt': retry_count,
+                }
+
+                should_retry, analysis = await self.handle_error(e, context)
+
+                if should_retry and retry_count <= self.max_retries:
+                    self.add_to_history(
+                        AssistantMessage(
+                            content=f'Error occurred. Analysis: {analysis}'
+                        )
+                    )
+                    continue
+                else:
+                    raise AgentError(
+                        f'Failed after {retry_count} attempts. Last error: {analysis}',
+                        original_error=e,
+                    )
+
+        # return conversation history if we exit the loop without returning
+        return self.conversation_history
+
+    async def _run_with_tools(
+        self, retry_count: int = 0, variables: Optional[Dict[str, Any]] = None
+    ) -> List[BaseMessage]:
+        """Run as a tool-using agent when tools are provided"""
+        variables = variables or {}
+        logger.debug('Running with tools')
+
+        while retry_count <= self.max_retries:
+            try:
+                # Resolve variables in system prompt based on reasoning pattern
+                if self.reasoning_pattern == ReasoningPattern.REACT:
+                    system_content = self._get_react_prompt(variables)
+                elif self.reasoning_pattern == ReasoningPattern.COT:
+                    system_content = self._get_cot_prompt(variables)
+                else:
+                    system_content = resolve_variables(self.system_prompt, variables)
+
+                system_message = SystemMessage(content=system_content)
+                self.add_to_history(system_message)
+                messages = await self._get_message_history(variables)
+
+                # Keep executing tools until we get a final answer
+                tool_call_count = 0
+                function_response = None
+                function_name = None
+                while tool_call_count < self.max_tool_calls:
+                    formatted_tools = self.llm.format_tools_for_llm(self.tools)
+                    response = await self.llm.generate(
+                        messages,
+                        functions=formatted_tools,
+                        output_schema=self.output_schema,
+                    )
+
+                    # Handle ReACT and CoT patterns
+                    function_call = await self.llm.get_function_call(response)
+
+                    # If no function call, check if this is truly a final answer
+                    if not function_call:
+                        assistant_message = self.llm.get_message_content(response)
+                        if assistant_message:
+                            # Check if this is a final answer or just intermediate reasoning
+                            is_final = await self._is_final_answer(
+                                assistant_message, tool_call_count, messages
+                            )
+                            if is_final:
+                                # Ensure act_as is not None (default to 'assistant' if missing)
+                                role = (
+                                    self.act_as
+                                    if self.act_as is not None
+                                    else MessageType.ASSISTANT
+                                )
+                                self.add_to_history(
+                                    AssistantMessage(
+                                        role=role, content=assistant_message
+                                    )
+                                )
+                                return self.conversation_history
+                            else:
+                                # This is intermediate reasoning, add to context and continue
+                                msg_preview = (
+                                    assistant_message[:100]
+                                    if len(assistant_message) > 100
+                                    else assistant_message
+                                )
+                                logger.debug(
+                                    f'Detected intermediate reasoning (not final answer): {msg_preview}...'
+                                )
+                                # Ensure act_as is not None (default to 'assistant' if missing)
+                                role = (
+                                    self.act_as
+                                    if self.act_as is not None
+                                    else MessageType.ASSISTANT
+                                )
+                                self.add_to_history(
+                                    AssistantMessage(
+                                        role=role, content=assistant_message
+                                    )
+                                )
+                                self.add_to_history(
+                                    UserMessage(
+                                        content='Based on your reasoning, please proceed with the necessary tool calls to complete the task.',
+                                    )
+                                )
+                                continue
+                        break
+
+                    # If there's a function call, add the assistant's response
+                    # LLM-specific implementations handle special formatting (e.g., Claude's raw_content)
+                    assistant_message_content = (
+                        self.llm.get_assistant_message_for_tool_call(response)
+                    )
+                    if assistant_message_content:
+                        # LLM returned special formatting (e.g., Claude's raw_content)
+                        messages.append(
+                            {
+                                'role': self.act_as,
+                                'content': assistant_message_content,
+                            }
+                        )
+                    else:
+                        # Use default text content extraction
+                        assistant_text = self.llm.get_message_content(response)
+                        if assistant_text:
+                            messages.append(
+                                {
+                                    'role': self.act_as,
+                                    'content': assistant_text,
+                                }
+                            )
+
+                    # Execute the tool
+                    try:
+                        function_name = function_call['name']
+                        # Get tool_use_id if available (LLM-specific, e.g., Claude)
+                        tool_use_id = self.llm.get_tool_use_id(function_call)
+                        if isinstance(function_call['arguments'], str):
+                            function_args = json.loads(function_call['arguments'])
+                        else:
+                            function_args = function_call['arguments']
+
+                        tool = self.tools_dict[function_name]
+
+                        # Track tool execution with telemetry
+                        tracer = get_tracer()
+
+                        if tracer:
+                            with tracer.start_as_current_span(
+                                f'agent.tool.{function_name}',
+                                attributes={
+                                    'tool.name': function_name,
+                                    'agent.name': self.name,
+                                },
+                            ) as tool_span:
+                                function_response = await tool.run(
+                                    inputs=[], variables=None, **function_args
+                                )
+                                tool_span.set_attribute(
+                                    'tool.result.length', len(str(function_response))
+                                )
+                        else:
+                            function_response = await tool.run(
+                                inputs=[], variables=None, **function_args
+                            )
+
+                        agent_metrics.record_tool_call(
+                            self.name, function_name, 'success'
+                        )
+
+                        tool_call_count += 1
+
+                        # Add function call result to history using OpenAI's "function" role format
+                        # According to OpenAI API: {"role": "function", "name": "", "content": ""}
+                        self.add_to_history(
+                            FunctionMessage(
+                                content=str(
+                                    'Here is the result of the tool call: \n'
+                                    + str(function_response)
+                                ),
+                                name=function_name,
+                            )
+                        )
+
+                        # Add the function response to messages for context
+                        # LLM-specific implementations format the message appropriately
+                        function_result_msg = self.llm.format_function_result_message(
+                            function_name, str(function_response), tool_use_id
+                        )
+                        messages.append(function_result_msg)
+
+                    except (json.JSONDecodeError, KeyError, ToolExecutionError) as e:
+                        # Record tool call failure
+                        # Safely extract function_name from function_call if available
+                        error_function_name: str = function_name or 'unknown'
+                        if error_function_name == 'unknown' and function_call:
+                            error_function_name = function_call.get('name') or 'unknown'
+                        agent_metrics.record_tool_call(
+                            self.name, error_function_name, 'error'
+                        )
+
+                        retry_count += 1
+                        context = {
+                            'function_call': function_call,
+                            'attempt': retry_count,
+                        }
+                        should_retry, analysis = await self.handle_error(e, context)
+                        if should_retry and retry_count <= self.max_retries:
+                            # Record retry
+                            agent_metrics.record_retry(
+                                self.name, 'tool_execution_error'
+                            )
+
+                            self.add_to_history(
+                                AssistantMessage(
+                                    content=f'Tool execution error: {analysis}'
+                                )
+                            )
+                            continue
+                        raise AgentError(
+                            f'Tool execution failed: {analysis}', original_error=e
+                        )
+
+                # If no tools were called, return conversation history
+                if tool_call_count == 0:
+                    return self.conversation_history
+
+                # Generate final response if we've hit the tool call limit or exited the loop
+                system_message = SystemMessage(
+                    content='Please provide a final answer based on all the tool results above.'
+                )
+                self.add_to_history(system_message)
+                messages = await self._get_message_history(variables)
+
+                final_response = await self.llm.generate(
+                    messages,
+                    output_schema=self.output_schema,
+                )
+
+                assistant_message = self.llm.get_message_content(final_response)
+                if assistant_message:
+                    # Ensure act_as is not None (default to 'assistant' if missing)
+                    role = (
+                        self.act_as
+                        if self.act_as is not None
+                        else MessageType.ASSISTANT
+                    )
+                    self.add_to_history(AssistantMessage(role=role, content=assistant_message))
+                    return self.conversation_history
+
+                # Fallback: return function message only if we have valid tool execution data
+                if function_response is not None and function_name is not None:
+                    return [
+                        FunctionMessage(
+                            content=str(
+                                'The final result based on the tool executions is: \n'
+                                + str(function_response)
+                            ),
+                            name=function_name,
+                        )
+                    ]
+                else:
+                    # No tools were executed and no assistant message, return safe fallback
+                    return self.conversation_history
+
+            except Exception as e:
+                retry_count += 1
+                context = {
+                    'conversation_history': self.conversation_history,
+                    'attempt': retry_count,
+                }
+
+                should_retry, analysis = await self.handle_error(e, context)
+                if should_retry and retry_count <= self.max_retries:
+                    # Record retry
+                    agent_metrics.record_retry(self.name, 'execution_error')
+
+                    self.add_to_history(
+                        AssistantMessage(
+                            content=f'Error occurred. Analysis: {analysis}'
+                        )
+                    )
+                    continue
+
+                raise AgentError(
+                    f'Failed after {retry_count} attempts. Last error: {analysis}',
+                    original_error=e,
+                )
+
+        raise AgentError(f'Failed after maximum {self.max_retries} attempts.')
+
+    def _get_react_prompt(self, variables: Optional[Dict[str, Any]] = None) -> str:
+        """Get system prompt modified for ReACT pattern"""
+        variables = variables or {}
+
+        tools_desc = '\n'.join(
+            [f'- {tool.name}: {tool.description}' for tool in self.tools]
+        )
+
+        # Resolve variables in the base system prompt
+        resolved_system_prompt = resolve_variables(self.system_prompt, variables)
+
+        react_prompt = f"""{resolved_system_prompt}
+        When solving tasks, follow this format:
+
+        Thought: Analyze the situation and think about what to do
+        Action: Use available tools in the format: tool_name(param1: "value1", param2: "value2")
+        Observation: The result of the action
+        ... (repeat Thought/Action/Observation if needed)
+        Final Answer: [Your complete answer to the user's question]
+
+        Available tools:
+        {tools_desc}
+
+        Remember to:
+        1. Think carefully about what needs to be done
+        2. Use tools when needed
+        3. Make observations about tool results
+        4. Conclude with a final answer when the task is complete
+
+        IMPORTANT: When you have enough information to answer the user's question, you MUST prefix your response with "Final Answer:" to indicate completion."""
+
+        return react_prompt
+
+    def _get_cot_prompt(self, variables: Optional[Dict[str, Any]] = None) -> str:
+        """Get system prompt modified for Chain of Thought pattern"""
+        variables = variables or {}
+
+        tools_desc = '\n'.join(
+            [f'- {tool.name}: {tool.description}' for tool in self.tools]
+        )
+
+        # Resolve variables in the base system prompt
+        resolved_system_prompt = resolve_variables(self.system_prompt, variables)
+
+        cot_prompt = f"""{resolved_system_prompt}
+        When solving tasks, follow this Chain of Thought reasoning format:
+
+        Let me think through this step by step:
+        1. First, I need to understand what is being asked...
+        2. Then, I should consider what information or tools I need.... Use available tools in the format: tool_name(param1: "value1", param2: "value2")
+        3. Next, I'll analyze the available options...
+        4. Finally, I'll provide a well-reasoned answer...
+
+        Available tools:
+        {tools_desc}
+
+        Remember to:
+        1. Break down complex problems into smaller steps
+        2. Think through each step logically
+        3. Use tools when needed to gather information
+        4. Provide clear reasoning for your conclusions
+        5. End with a final, well-justified answer
+
+        IMPORTANT: When you have gathered all necessary information and are ready to provide your complete answer, you MUST prefix your response with "Final Answer:" to indicate completion."""
+
+        return cot_prompt
+
+    async def _is_final_answer(
+        self, message: str, tool_call_count: int, messages: List[Dict[str, Any]]
+    ) -> bool:
+        """
+        Determine if a message is a final answer or intermediate reasoning.
+        Uses structured token detection (like LangChain's ReAct) with LLM fallback.
+
+        Approach inspired by LangChain/CrewAI:
+        1. Primary: Check for explicit "Final Answer:" token
+        2. Fallback: Use LLM-based classification for robustness
+        """
+        message_stripped = message.strip()
+        message_lower = message_stripped.lower()
+
+        # Primary Detection: Explicit "Final Answer:" token (ReAct pattern)
+        # This is the most reliable method used by LangChain and similar frameworks
+        if message_stripped.startswith('Final Answer:') or message_lower.startswith(
+            'final answer:'
+        ):
+            logger.debug('Explicit "Final Answer:" token detected - this is final')
+            return True
+
+        # Check if "Final Answer:" appears anywhere in the response
+        # (agent might add context before the token)
+        if 'final answer:' in message_lower:
+            logger.debug('"Final Answer:" token found in response - treating as final')
+            return True
+
+        # Secondary Detection: Use LLM-based analysis for cases without explicit tokens
+        # This handles:
+        # - Agents not following the format perfectly
+        # - Direct mode (without ReAct/CoT patterns)
+        # - Edge cases where the agent provides answer without token
+
+        analysis_prompt = f"""You are a classifier that determines if an AI agent's response is a FINAL ANSWER or INTERMEDIATE REASONING.
+
+Agent's Response:
+"{message_stripped}"
+
+Context:
+- Tool calls executed so far: {tool_call_count}
+- Total conversation turns: {len(messages)}
+
+Classification Criteria:
+
+FINAL ANSWER - The response is final if it:
+✓ Directly answers the user's original question with concrete information
+✓ Provides specific data, results, or conclusions
+✓ Does not suggest or request additional actions
+✓ Reads like a complete, standalone answer
+✓ Contains synthesis of information already gathered
+
+INTERMEDIATE REASONING - The response is intermediate if it:
+✗ Describes plans or intentions for what to do next
+✗ Expresses need to gather more information
+✗ Contains thinking/reasoning WITHOUT providing the actual answer
+✗ Poses questions or expresses uncertainty about next steps
+✗ Mentions specific tools it wants to use
+
+Examples of INTERMEDIATE:
+- "I need to query the database schema first"
+- "Let me check the table structure"
+- "First, I should examine..."
+
+Examples of FINAL:
+- "Based on the query results, the table contains 1,245 records..."
+- "The analysis shows that revenue increased by 23%..."
+- "After examining the data, the answer is..."
+
+Respond with EXACTLY one word: "FINAL" or "INTERMEDIATE"
+"""
+
+        try:
+            analysis_messages = [
+                {
+                    'role': MessageType.SYSTEM,
+                    'content': 'You are a precise classification system. Respond with only FINAL or INTERMEDIATE.',
+                },
+                {'role': MessageType.USER, 'content': analysis_prompt},
+            ]
+            analysis_response = await self.llm.generate(analysis_messages)
+            analysis = self.llm.get_message_content(analysis_response).strip().upper()
+
+            is_final = 'FINAL' in analysis
+            msg_preview = (
+                message_stripped[:80]
+                if len(message_stripped) > 80
+                else message_stripped
+            )
+            logger.debug(
+                f'LLM classifier: "{analysis}" -> is_final={is_final} (message preview: "{msg_preview}...")'
+            )
+            return is_final
+
+        except Exception as e:
+            logger.warning(
+                f'LLM classification failed: {e}. Defaulting to final=False to allow continuation.'
+            )
+            # Conservative default: treat as intermediate to avoid premature exit
+            # This is safer as it allows the agent to continue rather than stopping too early
+            return False
diff --git a/flo_ai/flo_ai/models/base_agent.py b/flo_ai/flo_ai/agent/base_agent.py
similarity index 100%
rename from flo_ai/flo_ai/models/base_agent.py
rename to flo_ai/flo_ai/agent/base_agent.py
diff --git a/flo_ai/flo_ai/builder/agent_builder.py b/flo_ai/flo_ai/agent/builder.py
similarity index 99%
rename from flo_ai/flo_ai/builder/agent_builder.py
rename to flo_ai/flo_ai/agent/builder.py
index 482bd21b..a94ccf83 100644
--- a/flo_ai/flo_ai/builder/agent_builder.py
+++ b/flo_ai/flo_ai/agent/builder.py
@@ -1,8 +1,8 @@
 from typing import List, Optional, Dict, Any, Union, Type
 from flo_ai.models import AssistantMessage
 import yaml
-from flo_ai.models.agent import Agent
-from flo_ai.models.base_agent import ReasoningPattern
+from flo_ai.agent import Agent
+from flo_ai.agent.base_agent import ReasoningPattern
 from flo_ai.llm import BaseLLM
 from flo_ai.tool.base_tool import Tool
 from flo_ai.tool.tool_config import ToolConfig, create_tool_config
diff --git a/flo_ai/flo_ai/models/plan_agents.py b/flo_ai/flo_ai/agent/plan_agents.py
similarity index 99%
rename from flo_ai/flo_ai/models/plan_agents.py
rename to flo_ai/flo_ai/agent/plan_agents.py
index d49df782..e2b947fb 100644
--- a/flo_ai/flo_ai/models/plan_agents.py
+++ b/flo_ai/flo_ai/agent/plan_agents.py
@@ -6,7 +6,7 @@
 """
 
 from typing import List, Optional
-from flo_ai.models.agent import Agent
+from flo_ai.agent import Agent
 from flo_ai.llm.base_llm import BaseLLM
 from flo_ai.arium.memory import PlanAwareMemory
 from flo_ai.tool.plan_tool import PlanTool, StepTool, PlanStatusTool
diff --git a/flo_ai/flo_ai/arium/arium.py b/flo_ai/flo_ai/arium/arium.py
index 2d6efbe6..715fcece 100644
--- a/flo_ai/flo_ai/arium/arium.py
+++ b/flo_ai/flo_ai/arium/arium.py
@@ -2,7 +2,7 @@
 from flo_ai.arium.memory import MessageMemory, MessageMemoryItem
 from flo_ai.models import BaseMessage, UserMessage, TextMessageContent
 from typing import List, Dict, Any, Optional, Callable
-from flo_ai.models.agent import Agent
+from flo_ai.agent.agent import Agent
 from flo_ai.arium.base import AriumNodeType
 from flo_ai.arium.models import StartNode, EndNode
 from flo_ai.arium.events import AriumEventType, AriumEvent
diff --git a/flo_ai/flo_ai/arium/base.py b/flo_ai/flo_ai/arium/base.py
index 33bacca4..d49b92e0 100644
--- a/flo_ai/flo_ai/arium/base.py
+++ b/flo_ai/flo_ai/arium/base.py
@@ -2,7 +2,7 @@
 from functools import partial
 from flo_ai.arium.nodes import AriumNode, ForEachNode, FunctionNode
 from flo_ai.arium.protocols import ExecutableNode
-from flo_ai.models.agent import Agent
+from flo_ai.agent import Agent
 from flo_ai.tool.base_tool import Tool
 from flo_ai.utils.logger import logger
 from typing import List, Optional, Callable, Literal, get_origin, get_args, Dict, Union
diff --git a/flo_ai/flo_ai/arium/builder.py b/flo_ai/flo_ai/arium/builder.py
index f76a3c5d..3f0e98b0 100644
--- a/flo_ai/flo_ai/arium/builder.py
+++ b/flo_ai/flo_ai/arium/builder.py
@@ -3,10 +3,10 @@
 from flo_ai.arium.memory import MessageMemory, MessageMemoryItem
 from flo_ai.arium.nodes import AriumNode, ForEachNode
 from flo_ai.models import BaseMessage, UserMessage
-from flo_ai.models.agent import Agent, resolve_variables
+from flo_ai.agent.agent import Agent, resolve_variables
 from flo_ai.tool.base_tool import Tool
 import yaml
-from flo_ai.builder.agent_builder import AgentBuilder
+from flo_ai.agent import AgentBuilder
 from flo_ai.llm import BaseLLM
 from flo_ai.arium.llm_router import create_llm_router
 from flo_ai.arium.nodes import FunctionNode
@@ -906,7 +906,7 @@ def _create_agent_from_direct_config(
     Returns:
         Agent: Configured agent instance
     """
-    from flo_ai.models.base_agent import ReasoningPattern
+    from flo_ai.agent.base_agent import ReasoningPattern
 
     # from flo_ai.llm import OpenAI, Anthropic, Gemini, OllamaLLM
 
     # Extract basic configuration
diff --git a/flo_ai/flo_ai/models/__init__.py b/flo_ai/flo_ai/models/__init__.py
index 599cab6f..1467adcc 100644
--- a/flo_ai/flo_ai/models/__init__.py
+++ b/flo_ai/flo_ai/models/__init__.py
@@ -2,9 +2,7 @@
 Models package for flo_ai - Agent framework components
 """
 
-from .agent import Agent, MessageType
 from .agent_error import AgentError
-from .base_agent import BaseAgent, AgentType, ReasoningPattern
 from .document import DocumentType
 from .chat_message import (
     SystemMessage,
@@ -16,16 +14,12 @@
     TextMessageContent,
     ImageMessageContent,
     DocumentMessageContent,
+    MessageType,
 )
 
 __all__ = [
-    'Agent',
     'AgentError',
-    'BaseAgent',
-    'AgentType',
-    'ReasoningPattern',
     'DocumentType',
-    'MessageType',
     'SystemMessage',
     'UserMessage',
     'AssistantMessage',
@@ -35,4 +29,5 @@
     'TextMessageContent',
     'ImageMessageContent',
     'DocumentMessageContent',
+    'MessageType',
 ]
diff --git a/flo_ai/flo_ai/models/agent.py b/flo_ai/flo_ai/models/agent.py
index b773865f..e69de29b 100644
--- a/flo_ai/flo_ai/models/agent.py
+++ b/flo_ai/flo_ai/models/agent.py
@@ -1,643 +0,0 @@
-import json
-from typing import Dict, Any, List, Optional
-from flo_ai.models.base_agent import BaseAgent, AgentType, ReasoningPattern
-from flo_ai.llm.base_llm import BaseLLM
-from flo_ai.models.chat_message import (
-    AssistantMessage,
-    BaseMessage,
-    MessageType,
-    UserMessage,
-    TextMessageContent,
-    FunctionMessage,
-    SystemMessage,
-)
-from flo_ai.tool.base_tool import Tool, ToolExecutionError
-from flo_ai.models.agent_error import AgentError
-from flo_ai.utils.logger import logger
-from flo_ai.utils.variable_extractor import (
-    extract_variables_from_inputs,
-    extract_agent_variables,
-    validate_multi_agent_variables,
-    resolve_variables,
-)
-from flo_ai.telemetry.instrumentation import (
-    trace_agent_execution,
-    agent_metrics,
-)
-from flo_ai.telemetry import get_tracer
-
-
-class Agent(BaseAgent):
-    def __init__(
-        self,
-        name: str,
-        system_prompt: str | AssistantMessage,
-        llm: BaseLLM,
-        tools: Optional[List[Tool]] = None,
-        max_retries: int = 0,
-        max_tool_calls: int = 5,
-        reasoning_pattern: ReasoningPattern = ReasoningPattern.DIRECT,
-        output_schema: Optional[Dict[str, Any]] = None,
-        role: Optional[str] = None,
-        act_as: Optional[str] = MessageType.ASSISTANT,
-        input_filter: Optional[List[str]] = None,
-    ):
-        # Determine agent type based on tools
-        agent_type = AgentType.TOOL_USING if tools else AgentType.CONVERSATIONAL
-
-        # Enhance system prompt with role if provided
-        enhanced_prompt = system_prompt
-        if role:
-            if isinstance(system_prompt, str):
-                enhanced_prompt = f'You are {role}. {system_prompt}'
-            elif isinstance(system_prompt, AssistantMessage):
-                enhanced_prompt = f'You are {role}. {system_prompt.content}'
-
-        super().__init__(
-            name=name,
-            system_prompt=str(enhanced_prompt),
-            agent_type=agent_type,
-            llm=llm,
-            max_retries=max_retries,
-            max_tool_calls=max_tool_calls,
-        )
-        self.tools = tools or []
-        self.tools_dict = {tool.name: tool for tool in self.tools}
-        self.reasoning_pattern = reasoning_pattern
-        self.output_schema = output_schema
-        self.role = role
-        self.act_as = act_as
-        self.input_filter: Optional[List[str]] = input_filter
-
-    @trace_agent_execution()
-    async def run(
-        self,
-        inputs: List[BaseMessage] | str,
-        variables: Optional[Dict[str, Any]] = None,
-    ) -> List[BaseMessage]:
-        variables = variables or {}
-        if isinstance(inputs, str):
-            inputs = [UserMessage(TextMessageContent(text=inputs))]
-
-        # Perform runtime variable validation if not already resolved (single agent usage)
-        if not self.resolved_variables:
-            # Extract variables from inputs and system prompt
-            input_variables = extract_variables_from_inputs(inputs)
-            agent_variables = extract_agent_variables(self)
-            all_required_variables = input_variables.union(agent_variables)
-
-            # Validate that all required variables are provided
-            if all_required_variables:
-                agents_variables = {self.name: all_required_variables}
-                validate_multi_agent_variables(agents_variables, variables)
-
-            # Resolve variables and mark as resolved
-            self.system_prompt = resolve_variables(self.system_prompt, variables)
-
-            # Process inputs and resolve variables in string inputs
-            for input in inputs:
-                if isinstance(input, BaseMessage):
-                    # checking whether the TextMessageContent is resolved
-                    if isinstance(input.content, TextMessageContent) and variables:
-                        input.content.text = resolve_variables(
-                            input.content.text, variables
-                        )
-                    self.add_to_history(input)
-                else:
-                    raise ValueError(f'Invalid input type: {type(input)}')
-            # after resolving agent system prompts and inputs, mark variables as resolved
-            self.resolved_variables = True
-
-        else:
-            # Variables already resolved, process inputs without variable resolution
-            for input in inputs:
-                if isinstance(input, BaseMessage):
-                    # Handle InputMessage - check content type
-                    self.add_to_history(input)
-                else:
-                    raise ValueError(f'Invalid input type: {type(input)}')
-
-        retry_count = 0
-
-        # If no tools, act as conversational agent
-        if not self.tools:
-            return await self._run_conversational(retry_count, variables)
-
-        # Otherwise, run as tool agent
-        return await self._run_with_tools(retry_count, variables)
-
-    async def _run_conversational(
-        self, retry_count: int, variables: Optional[Dict[str, Any]] = None
-    ) -> List[BaseMessage]:
-        """Run as a conversational agent when no tools are provided"""
-        variables = variables or {}
-
-        while retry_count <= self.max_retries:
-            try:
-                # Resolve variables in system prompt
-                system_content = (
-                    self._get_cot_prompt(variables)
-                    if self.reasoning_pattern == ReasoningPattern.COT
-                    else resolve_variables(self.system_prompt, variables)
-                )
-                system_message = SystemMessage(content=system_content)
-                self.add_to_history(system_message)
-                messages = await self._get_message_history(variables)
-
-                logger.debug(f'Sending messages to LLM: {messages}')
-                response = await self.llm.generate(
-                    messages, output_schema=self.output_schema
-                )
-                logger.debug(f'Raw LLM Response: {response}')
-
-                assistant_message = self.llm.get_message_content(response)
-                logger.debug(f'Extracted message: {assistant_message}')
-
-                # Ensure act_as is not None (default to 'assistant' if missing)
-                role = self.act_as if self.act_as is not None else MessageType.ASSISTANT
-
-                if assistant_message:
-                    self.add_to_history(
-                        AssistantMessage(role=role, content=assistant_message)
-                    )
-                else:
-                    possible_tool_message = await self.llm.get_function_call(response)
-                    if possible_tool_message:
-                        self.add_to_history(
-                            AssistantMessage(
-                                role=role, content=possible_tool_message['arguments']
-                            )
-                        )
-                    else:
-                        logger.debug('Warning: No message content found in response')
-                        self.add_to_history(
-                            AssistantMessage(
-                                role=role,
-                                content='No message content found in response',
-                            )
-                        )
-
-                return self.conversation_history
-
-            except Exception as e:
-                retry_count += 1
-                context = {
-                    'conversation_history': self.conversation_history,
-                    'attempt': retry_count,
-                }
-
-                should_retry, analysis = await self.handle_error(e, context)
-
-                if should_retry and retry_count <= self.max_retries:
-                    self.add_to_history(
-                        AssistantMessage(
-                            content=f'Error occurred. Analysis: {analysis}'
-                        )
-                    )
-                    continue
-                else:
-                    raise AgentError(
-                        f'Failed after {retry_count} attempts. Last error: {analysis}',
-                        original_error=e,
-                    )
-
-        # return conversation history if we exit the loop without returning
-        return self.conversation_history
-
-    async def _run_with_tools(
-        self, retry_count: int = 0, variables: Optional[Dict[str, Any]] = None
-    ) -> List[BaseMessage]:
-        """Run as a tool-using agent when tools are provided"""
-        variables = variables or {}
-        print('running with tools')
-
-        while retry_count <= self.max_retries:
-            try:
-                # Resolve variables in system prompt based on reasoning pattern
-                if self.reasoning_pattern == ReasoningPattern.REACT:
-                    system_content = self._get_react_prompt(variables)
-                elif self.reasoning_pattern == ReasoningPattern.COT:
-                    system_content = self._get_cot_prompt(variables)
-                else:
-                    system_content = resolve_variables(self.system_prompt, variables)
-
-                system_message = SystemMessage(content=system_content)
-                self.add_to_history(system_message)
-                messages = await self._get_message_history(variables)
-
-                # Keep executing tools until we get a final answer
-                tool_call_count = 0
-                function_response = None
-                function_name = None
-                while tool_call_count < self.max_tool_calls:
-                    formatted_tools = self.llm.format_tools_for_llm(self.tools)
-                    response = await self.llm.generate(
-                        messages,
-                        functions=formatted_tools,
-                        output_schema=self.output_schema,
-                    )
-
-                    # Handle ReACT and CoT patterns
-                    function_call = await self.llm.get_function_call(response)
-
-                    # If no function call, check if this is truly a final answer
-                    if not function_call:
-                        assistant_message = self.llm.get_message_content(response)
-                        if assistant_message:
-                            # Check if this is a final answer or just intermediate reasoning
-                            is_final = await self._is_final_answer(
-                                assistant_message, tool_call_count, messages
-                            )
-                            if is_final:
-                                # Ensure act_as is not None (default to 'assistant' if missing)
-                                role = (
-                                    self.act_as
-                                    if self.act_as is not None
-                                    else MessageType.ASSISTANT
-                                )
-                                self.add_to_history(
-                                    AssistantMessage(
-                                        role=role, content=assistant_message
-                                    )
-                                )
-                                return self.conversation_history
-                            else:
-                                # This is intermediate reasoning, add to context and continue
-                                msg_preview = (
-                                    assistant_message[:100]
-                                    if len(assistant_message) > 100
-                                    else assistant_message
-                                )
-                                logger.debug(
-                                    f'Detected intermediate reasoning (not final answer): {msg_preview}...'
-                                )
-                                # Ensure act_as is not None (default to 'assistant' if missing)
-                                role = (
-                                    self.act_as
-                                    if self.act_as is not None
-                                    else MessageType.ASSISTANT
-                                )
-                                self.add_to_history(
-                                    AssistantMessage(
-                                        role=role, content=assistant_message
-                                    )
-                                )
-                                self.add_to_history(
-                                    UserMessage(
-                                        content='Based on your reasoning, please proceed with the necessary tool calls to complete the task.',
-                                    )
-                                )
-                                continue
-                        break
-
-                    # If there's a function call, add the assistant's response
-                    # LLM-specific implementations handle special formatting (e.g., Claude's raw_content)
-                    assistant_message_content = (
-                        self.llm.get_assistant_message_for_tool_call(response)
-                    )
-                    if assistant_message_content:
-                        # LLM returned special formatting (e.g., Claude's raw_content)
-                        messages.append(
-                            {
-                                'role': self.act_as,
-                                'content': assistant_message_content,
-                            }
-                        )
-                    else:
-                        # Use default text content extraction
-                        assistant_text = self.llm.get_message_content(response)
-                        if assistant_text:
-                            messages.append(
-                                {
-                                    'role': self.act_as,
-                                    'content': assistant_text,
-                                }
-                            )
-
-                    # Execute the tool
-                    try:
-                        function_name = function_call['name']
-                        # Get tool_use_id if available (LLM-specific, e.g., Claude)
-                        tool_use_id = self.llm.get_tool_use_id(function_call)
-                        if isinstance(function_call['arguments'], str):
-                            function_args = json.loads(function_call['arguments'])
-                        else:
-                            function_args = function_call['arguments']
-
-                        tool = self.tools_dict[function_name]
-
-                        # Track tool execution with telemetry
-                        tracer = get_tracer()
-
-                        if tracer:
-                            with tracer.start_as_current_span(
-                                f'agent.tool.{function_name}',
-                                attributes={
-                                    'tool.name': function_name,
-                                    'agent.name': self.name,
-                                },
-                            ) as tool_span:
-                                function_response = await tool.run(
-                                    inputs=[], variables=None, **function_args
-                                )
-                                tool_span.set_attribute(
-                                    'tool.result.length', len(str(function_response))
-                                )
-                        else:
-                            function_response = await tool.run(
-                                inputs=[], variables=None, **function_args
-                            )
-
-                        agent_metrics.record_tool_call(
-                            self.name, function_name, 'success'
-                        )
-
-                        tool_call_count += 1
-
-                        # Add function call result to history using OpenAI's "function" role format
-                        # According to OpenAI API: {"role": "function", "name": "", "content": ""}
-                        self.add_to_history(
-                            FunctionMessage(
-                                content=str(
-                                    'Here is the result of the tool call: \n'
-                                    + str(function_response)
-                                ),
-                                name=function_name,
-                            )
-                        )
-
-                        # Add the function response to messages for context
-                        # LLM-specific implementations format the message appropriately
-                        function_result_msg = self.llm.format_function_result_message(
-                            function_name, str(function_response), tool_use_id
-                        )
-                        messages.append(function_result_msg)
-
-                    except (json.JSONDecodeError, KeyError, ToolExecutionError) as e:
-                        # Record tool call failure
-                        # Safely extract function_name from function_call if available
-                        error_function_name: str = function_name or 'unknown'
-                        if error_function_name == 'unknown' and function_call:
-                            error_function_name = function_call.get('name') or 'unknown'
-                        agent_metrics.record_tool_call(
-                            self.name, error_function_name, 'error'
-                        )
-
-                        retry_count += 1
-                        context = {
-                            'function_call': function_call,
-                            'attempt': retry_count,
-                        }
-                        should_retry, analysis = await self.handle_error(e, context)
-                        if should_retry and retry_count <= self.max_retries:
-                            # Record retry
-                            agent_metrics.record_retry(
-                                self.name, 'tool_execution_error'
-                            )
-
-                            self.add_to_history(
-                                AssistantMessage(
-                                    content=f'Tool execution error: {analysis}'
-                                )
-                            )
-                            continue
-                        raise AgentError(
-                            f'Tool execution failed: {analysis}', original_error=e
-                        )
-
-                # If no tools were called, return conversation history
-                if tool_call_count == 0:
-                    return self.conversation_history
-
-                # Generate final response if we've hit the tool call limit or exited the loop
-                system_message = SystemMessage(
-                    content='Please provide a final answer based on all the tool results above.'
-                )
-                self.add_to_history(system_message)
-                messages = await self._get_message_history(variables)
-
-                final_response = await self.llm.generate(
-                    messages,
-                    output_schema=self.output_schema,
-                )
-
-                assistant_message = self.llm.get_message_content(final_response)
-                if assistant_message:
-                    # Ensure act_as is not None (default to 'assistant' if missing)
-                    role = (
-                        self.act_as
-                        if self.act_as is not None
-                        else MessageType.ASSISTANT
-                    )
-                    self.add_to_history(AssistantMessage(content=assistant_message))
-                    return self.conversation_history
-
-                # Fallback: return function message only if we have valid tool execution data
-                if function_response is not None and function_name is not None:
-                    return [
-                        FunctionMessage(
-                            content=str(
-                                'The final result based on the tool executions is: \n'
-                                + str(function_response)
-                            ),
-                            name=function_name,
-                        )
-                    ]
-                else:
-                    # No tools were executed and no assistant message, return safe fallback
-                    return self.conversation_history
-
-            except Exception as e:
-                retry_count += 1
-                context = {
-                    'conversation_history': self.conversation_history,
-                    'attempt': retry_count,
-                }
-
-                should_retry, analysis = await self.handle_error(e, context)
-                if should_retry and retry_count <= self.max_retries:
-                    # Record retry
-                    agent_metrics.record_retry(self.name, 'execution_error')
-
-                    self.add_to_history(
-                        AssistantMessage(
-                            content=f'Error occurred. Analysis: {analysis}'
-                        )
-                    )
-                    continue
-
-                raise AgentError(
-                    f'Failed after {retry_count} attempts. Last error: {analysis}',
-                    original_error=e,
-                )
-
-        raise AgentError(f'Failed after maximum {self.max_retries} attempts.')
-
-    def _get_react_prompt(self, variables: Optional[Dict[str, Any]] = None) -> str:
-        """Get system prompt modified for ReACT pattern"""
-        variables = variables or {}
-
-        tools_desc = '\n'.join(
-            [f'- {tool.name}: {tool.description}' for tool in self.tools]
-        )
-
-        # Resolve variables in the base system prompt
-        resolved_system_prompt = resolve_variables(self.system_prompt, variables)
-
-        react_prompt = f"""{resolved_system_prompt}
-        When solving tasks, follow this format:
-
-        Thought: Analyze the situation and think about what to do
-        Action: Use available tools in the format: tool_name(param1: "value1", param2: "value2")
-        Observation: The result of the action
-        ... (repeat Thought/Action/Observation if needed)
-        Final Answer: [Your complete answer to the user's question]
-
-        Available tools:
-        {tools_desc}
-
-        Remember to:
-        1. Think carefully about what needs to be done
-        2. Use tools when needed
-        3. Make observations about tool results
-        4. Conclude with a final answer when the task is complete
-
-        IMPORTANT: When you have enough information to answer the user's question, you MUST prefix your response with "Final Answer:" to indicate completion."""
-
-        return react_prompt
-
-    def _get_cot_prompt(self, variables: Optional[Dict[str, Any]] = None) -> str:
-        """Get system prompt modified for Chain of Thought pattern"""
-        variables = variables or {}
-
-        tools_desc = '\n'.join(
-            [f'- {tool.name}: {tool.description}' for tool in self.tools]
-        )
-
-        # Resolve variables in the base system prompt
-        resolved_system_prompt = resolve_variables(self.system_prompt, variables)
-
-        cot_prompt = f"""{resolved_system_prompt}
-        When solving tasks, follow this Chain of Thought reasoning format:
-
-        Let me think through this step by step:
-        1. First, I need to understand what is being asked...
-        2. Then, I should consider what information or tools I need.... Use available tools in the format: tool_name(param1: "value1", param2: "value2")
-        3. Next, I'll analyze the available options...
-        4. Finally, I'll provide a well-reasoned answer...
-
-        Available tools:
-        {tools_desc}
-
-        Remember to:
-        1. Break down complex problems into smaller steps
-        2. Think through each step logically
-        3. Use tools when needed to gather information
-        4. Provide clear reasoning for your conclusions
-        5. End with a final, well-justified answer
-
-        IMPORTANT: When you have gathered all necessary information and are ready to provide your complete answer, you MUST prefix your response with "Final Answer:" to indicate completion."""
-
-        return cot_prompt
-
-    async def _is_final_answer(
-        self, message: str, tool_call_count: int, messages: List[Dict[str, Any]]
-    ) -> bool:
-        """
-        Determine if a message is a final answer or intermediate reasoning.
-        Uses structured token detection (like LangChain's ReAct) with LLM fallback.
-
-        Approach inspired by LangChain/CrewAI:
-        1. Primary: Check for explicit "Final Answer:" token
-        2. Fallback: Use LLM-based classification for robustness
-        """
-        message_stripped = message.strip()
-        message_lower = message_stripped.lower()
-
-        # Primary Detection: Explicit "Final Answer:" token (ReAct pattern)
-        # This is the most reliable method used by LangChain and similar frameworks
-        if message_stripped.startswith('Final Answer:') or message_lower.startswith(
-            'final answer:'
-        ):
-            logger.debug('Explicit "Final Answer:" token detected - this is final')
-            return True
-
-        # Check if "Final Answer:" appears anywhere in the response
-        # (agent might add context before the token)
-        if 'final answer:' in message_lower:
-            logger.debug('"Final Answer:" token found in response - treating as final')
-            return True
-
-        # Secondary Detection: Use LLM-based analysis for cases without explicit tokens
-        # This handles:
-        # - Agents not following the format perfectly
-        # - Direct mode (without ReAct/CoT patterns)
-        # - Edge cases where the agent provides answer without token
-
-        analysis_prompt = f"""You are a classifier that determines if an AI agent's response is a FINAL ANSWER or INTERMEDIATE REASONING.
-
-Agent's Response:
-"{message_stripped}"
-
-Context:
-- Tool calls executed so far: {tool_call_count}
-- Total conversation turns: {len(messages)}
-
-Classification Criteria:
-
-FINAL ANSWER - The response is final if it:
-✓ Directly answers the user's original question with concrete information
-✓ Provides specific data, results, or conclusions
-✓ Does not suggest or request additional actions
-✓ Reads like a complete, standalone answer
-✓ Contains synthesis of information already gathered
-
-INTERMEDIATE REASONING - The response is intermediate if it:
-✗ Describes plans or intentions for what to do next
-✗ Expresses need to gather more information
-✗ Contains thinking/reasoning WITHOUT providing the actual answer
-✗ Poses questions or expresses uncertainty about next steps
-✗ Mentions specific tools it wants to use
-
-Examples of INTERMEDIATE:
-- "I need to query the database schema first"
-- "Let me check the table structure"
-- "First, I should examine..."
-
-Examples of FINAL:
-- "Based on the query results, the table contains 1,245 records..."
-- "The analysis shows that revenue increased by 23%..."
-- "After examining the data, the answer is..."
-
-Respond with EXACTLY one word: "FINAL" or "INTERMEDIATE"
-"""
-
-        try:
-            analysis_messages = [
-                {
-                    'role': MessageType.SYSTEM,
-                    'content': 'You are a precise classification system. Respond with only FINAL or INTERMEDIATE.',
-                },
-                {'role': MessageType.USER, 'content': analysis_prompt},
-            ]
-            analysis_response = await self.llm.generate(analysis_messages)
-            analysis = self.llm.get_message_content(analysis_response).strip().upper()
-
-            is_final = 'FINAL' in analysis
-            msg_preview = (
-                message_stripped[:80]
-                if len(message_stripped) > 80
-                else message_stripped
-            )
-            logger.debug(
-                f'LLM classifier: "{analysis}" -> is_final={is_final} (message preview: "{msg_preview}...")'
-            )
-            return is_final
-
-        except Exception as e:
-            logger.warning(
-                f'LLM classification failed: {e}. Defaulting to final=False to allow continuation.'
- ) - # Conservative default: treat as intermediate to avoid premature exit - # This is safer as it allows the agent to continue rather than stopping too early - return False diff --git a/flo_ai/tests/unit-tests/test_agent_builder_tools.py b/flo_ai/tests/unit-tests/test_agent_builder_tools.py index 8a86dcc6..c295aae7 100644 --- a/flo_ai/tests/unit-tests/test_agent_builder_tools.py +++ b/flo_ai/tests/unit-tests/test_agent_builder_tools.py @@ -2,7 +2,7 @@ from unittest.mock import Mock, AsyncMock from flo_ai.tool.base_tool import Tool from flo_ai.tool.tool_config import ToolConfig -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI diff --git a/flo_ai/tests/unit-tests/test_arium_builder.py b/flo_ai/tests/unit-tests/test_arium_builder.py index bac36c7d..4cf78707 100644 --- a/flo_ai/tests/unit-tests/test_arium_builder.py +++ b/flo_ai/tests/unit-tests/test_arium_builder.py @@ -6,7 +6,7 @@ from unittest.mock import Mock from flo_ai.arium.builder import AriumBuilder, create_arium from flo_ai.arium.memory import MessageMemory -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.arium.nodes import FunctionNode diff --git a/flo_ai/tests/unit-tests/test_arium_yaml.py b/flo_ai/tests/unit-tests/test_arium_yaml.py index 442e338e..e3492f4e 100644 --- a/flo_ai/tests/unit-tests/test_arium_yaml.py +++ b/flo_ai/tests/unit-tests/test_arium_yaml.py @@ -7,7 +7,7 @@ from flo_ai.arium.builder import AriumBuilder from flo_ai.arium.memory import MessageMemory, BaseMemory -from flo_ai.models.agent import Agent +from flo_ai.agent import Agent from flo_ai.llm import OpenAI diff --git a/flo_ai/tests/unit-tests/test_base_llm.py b/flo_ai/tests/unit-tests/test_base_llm.py index 34c275da..79b980a0 100644 --- a/flo_ai/tests/unit-tests/test_base_llm.py +++ b/flo_ai/tests/unit-tests/test_base_llm.py @@ -51,6 +51,7 @@ async def stream( messages: List[Dict[str, str]], functions: Optional[List[Dict[str, Any]]] = None, output_schema: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> AsyncIterator[Dict[str, Any]]: async def generator(): yield {'response': self.response_text} diff --git a/flo_ai/tests/unit-tests/test_llm_router.py b/flo_ai/tests/unit-tests/test_llm_router.py index e820ee4c..7532ab4a 100644 --- a/flo_ai/tests/unit-tests/test_llm_router.py +++ b/flo_ai/tests/unit-tests/test_llm_router.py @@ -40,6 +40,7 @@ async def stream( messages: List[Dict[str, str]], functions: Optional[List[Dict[str, Any]]] = None, output_schema: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> AsyncIterator[Dict[str, Any]]: async def generator(): yield {'response': self.response_text} diff --git a/flo_ai/tests/unit-tests/test_yaml_tool_config.py b/flo_ai/tests/unit-tests/test_yaml_tool_config.py index c1fe3b45..8b55c621 100644 --- a/flo_ai/tests/unit-tests/test_yaml_tool_config.py +++ b/flo_ai/tests/unit-tests/test_yaml_tool_config.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import Mock, AsyncMock from flo_ai.tool.base_tool import Tool -from flo_ai.builder.agent_builder import AgentBuilder +from flo_ai.agent import AgentBuilder from flo_ai.llm import OpenAI From 0cedc1df5d1b5de403d6ebc8cff7f17e5706dad7 Mon Sep 17 00:00:00 2001 From: vishnu r kumar Date: Fri, 12 Dec 2025 22:04:44 +0530 Subject: [PATCH 2/3] feat(flo-ai): add field level validation for agent/arium yaml --- flo_ai/examples/arium_yaml_example.py | 8 +- flo_ai/examples/simple_yaml_workflow.py | 2 +- flo_ai/examples/yaml_agent_example.py | 2 +- 
 flo_ai/flo_ai/agent/builder.py                |  99 +-
 flo_ai/flo_ai/arium/builder.py                | 429 ++++----
 flo_ai/flo_ai/formatter/yaml_format_parser.py |   4 +-
 flo_ai/flo_ai/helpers/llm_factory.py          |  80 +-
 flo_ai/flo_ai/helpers/yaml_validation.py      |  53 +
 flo_ai/flo_ai/models/agent.py                 | 256 +++++
 flo_ai/flo_ai/models/arium.py                 | 365 +++++++
 .../unit-tests/test_agent_yaml_validation.py  | 970 ++++++++++++++++++
 flo_ai/tests/unit-tests/test_arium_yaml.py    | 158 +--
 .../unit-tests/test_arium_yaml_validation.py  | 837 +++++++++++++++
 flo_ai/uv.lock                                |   2 +-
 .../models/telephony_schemas.py               |   6 +-
 15 files changed, 2791 insertions(+), 480 deletions(-)
 create mode 100644 flo_ai/flo_ai/helpers/yaml_validation.py
 create mode 100644 flo_ai/flo_ai/models/arium.py
 create mode 100644 flo_ai/tests/unit-tests/test_agent_yaml_validation.py
 create mode 100644 flo_ai/tests/unit-tests/test_arium_yaml_validation.py

diff --git a/flo_ai/examples/arium_yaml_example.py b/flo_ai/examples/arium_yaml_example.py
index a60200c8..a5ea181d 100644
--- a/flo_ai/examples/arium_yaml_example.py
+++ b/flo_ai/examples/arium_yaml_example.py
@@ -329,14 +329,18 @@
               description: "Executive summary"
             - name: findings
               type: array
+              description: "Key findings list"
              items:
+                name: finding
                 type: str
-                description: "Key findings list"
+                description: "A single finding"
             - name: recommendations
               type: array
+              description: "Recommended actions"
               items:
+                name: recommendation
                 type: str
-                description: "Recommended actions"
+                description: "A single recommendation"
 
     workflow:
       start: input_validator
diff --git a/flo_ai/examples/simple_yaml_workflow.py b/flo_ai/examples/simple_yaml_workflow.py
index 4a9fc58d..8df498e2 100644
--- a/flo_ai/examples/simple_yaml_workflow.py
+++ b/flo_ai/examples/simple_yaml_workflow.py
@@ -125,7 +125,7 @@ async def main():
     print('-' * 30)
     if isinstance(result, list):
         for i, message in enumerate(result):
-            print(f'{i+1}. {message}')
+            print(f'{i+1}. {message.result.content}')
     else:
         print(result)
 
diff --git a/flo_ai/examples/yaml_agent_example.py b/flo_ai/examples/yaml_agent_example.py
index 2ab40e6f..7810853a 100644
--- a/flo_ai/examples/yaml_agent_example.py
+++ b/flo_ai/examples/yaml_agent_example.py
@@ -137,7 +137,7 @@ async def main():
 
     # Process the email thread
     result = await agent.run(email_thread)
-    print('Analysis Result:', result)
+    print('Analysis Result:', result[-1].content)
 
 
 if __name__ == '__main__':
diff --git a/flo_ai/flo_ai/agent/builder.py b/flo_ai/flo_ai/agent/builder.py
index a94ccf83..cde16d1d 100644
--- a/flo_ai/flo_ai/agent/builder.py
+++ b/flo_ai/flo_ai/agent/builder.py
@@ -7,7 +7,9 @@
 from flo_ai.tool.base_tool import Tool
 from flo_ai.tool.tool_config import ToolConfig, create_tool_config
 from flo_ai.formatter.yaml_format_parser import FloYamlParser
-from pydantic import BaseModel
+from flo_ai.models.agent import AgentYamlModel, LLMConfigModel
+from flo_ai.helpers.yaml_validation import format_validation_error_path
+from pydantic import BaseModel, ValidationError
 
 
 class AgentBuilder:
@@ -183,6 +185,36 @@
             act_as=self._act_as,
         )
 
+    @staticmethod
+    def _validate_yaml_config(config: Dict[str, Any]) -> AgentYamlModel:
+        """Validate YAML configuration using Pydantic models.
+
+        Args:
+            config: Dictionary containing YAML configuration
+
+        Returns:
+            AgentYamlModel: Validated configuration model
+
+        Raises:
+            ValueError: If validation fails with formatted error messages
+        """
+        try:
+            validated_config = AgentYamlModel(**config)
+        except ValidationError as e:
+            # Format validation errors for better readability
+            error_messages = []
+            for error in e.errors():
+                field_path = format_validation_error_path(error['loc'], config)
+                error_msg = f"{field_path}: {error['msg']}"
+                if 'ctx' in error:
+                    error_msg += f" (context: {error['ctx']})"
+                error_messages.append(error_msg)
+            raise ValueError(
+                'YAML validation failed:\n'
+                + '\n'.join(f'  - {msg}' for msg in error_messages)
+            ) from e
+        return validated_config
+
     @classmethod
     def from_yaml(
         cls,
@@ -220,26 +252,31 @@
         with open(yaml_file, 'r') as f:
             config = yaml.safe_load(f)
 
-        if 'agent' not in config:
-            raise ValueError('YAML must contain an "agent" section')
+        validated_config = cls._validate_yaml_config(config)
 
-        agent_config = config['agent']
+        # Use validated config directly - convert to dict only when needed for compatibility
+        config = validated_config.model_dump(exclude_none=True)
+        agent = validated_config.agent
 
         builder = cls()
 
-        # Set basic properties
-        builder.with_name(agent_config.get('name', 'AI Assistant'))
-        builder.with_prompt(agent_config.get('job', 'You are a helpful AI assistant.'))
-        builder.with_role(agent_config.get('role'))
-        builder.with_actas(agent_config.get('act_as'))
+        builder.with_name(agent.name or 'AI Assistant')
+        # Handle both 'job' and 'prompt' fields (job takes precedence)
+        prompt = agent.job or agent.prompt or 'You are a helpful AI assistant.'
+        builder.with_prompt(prompt)
+        builder.with_role(agent.role)
+        builder.with_actas(agent.act_as)
 
         # Configure LLM based on model settings
-        if 'model' in agent_config and base_llm is None:
+        if agent.model is not None and base_llm is None:
             from flo_ai.helpers.llm_factory import create_llm_from_config
 
-            model_config: dict = agent_config['model']
-            # Merge base_url from agent_config if present and not in model_config
-            if 'base_url' in agent_config and 'base_url' not in model_config:
-                model_config = {**model_config, 'base_url': agent_config['base_url']}
+            # Merge base_url from agent if present and not in model_config
+            model_config: LLMConfigModel = agent.model
+            if agent.base_url is not None and model_config.base_url is None:
+                # Create a new model instance with merged base_url using model_copy
+                model_config = model_config.model_copy(
+                    update={'base_url': agent.base_url}
+                )
 
             llm = create_llm_from_config(model_config, **kwargs)
             builder.with_llm(llm)
@@ -250,29 +287,33 @@
             )
             builder.with_llm(base_llm)
 
-        # Handle tools configuration
-        if 'tools' in agent_config:
-            # Process tools from YAML configuration
-            yaml_tools = cls._process_yaml_tools(agent_config['tools'], tool_registry)
+        if agent.tools is not None:
+            tools_list = []
+            for tool in agent.tools:
+                if isinstance(tool, str):
+                    tools_list.append(tool)
+                else:
+                    # ToolConfigModel - convert to dict
+                    tools_list.append(tool.model_dump(exclude_none=True))
+
+            yaml_tools = cls._process_yaml_tools(tools_list, tool_registry)
             builder.with_tools(yaml_tools)
         elif tools:
             # Use provided tools
             builder.with_tools(tools)
 
-        # Set parser if present
-        if 'parser' in agent_config:
+        if agent.parser is not None:
             parser = FloYamlParser.create(yaml_dict=config)
             builder.with_output_schema(parser.get_format())
 
-        # Apply settings if present
-        if 'settings' in agent_config:
-            settings = agent_config['settings']
-            if 'temperature' in settings:
-                builder._llm.temperature = settings['temperature']
-            if 'max_retries' in settings:
-                builder.with_retries(settings['max_retries'])
-            if 'reasoning_pattern' in settings:
-                builder.with_reasoning(ReasoningPattern[settings['reasoning_pattern']])
+        if agent.settings is not None:
+            settings = agent.settings
+            if settings.temperature is not None:
+                builder._llm.temperature = settings.temperature
+            if settings.max_retries is not None:
+                builder.with_retries(settings.max_retries)
+            if settings.reasoning_pattern is not None:
+                builder.with_reasoning(ReasoningPattern[settings.reasoning_pattern])
 
         return builder
diff --git a/flo_ai/flo_ai/arium/builder.py b/flo_ai/flo_ai/arium/builder.py
index 3f0e98b0..05fcd86a 100644
--- a/flo_ai/flo_ai/arium/builder.py
+++ b/flo_ai/flo_ai/arium/builder.py
@@ -11,6 +11,10 @@
 from flo_ai.arium.llm_router import create_llm_router
 from flo_ai.arium.nodes import FunctionNode
 from flo_ai.arium.base import AriumNodeType
+from flo_ai.models.arium import AriumYamlModel, AriumAgentConfigModel
+from flo_ai.models.agent import LLMConfigModel
+from flo_ai.helpers.yaml_validation import format_validation_error_path
+from pydantic import ValidationError
 
 
 class AriumBuilder:
@@ -249,6 +253,7 @@
         variables = variables if variables is not None else {}
         arium = self.build()
         new_inputs = []
+
        if isinstance(inputs, list):
             for input in inputs:
                 if isinstance(input, str):
@@ -283,6 +288,36 @@
         self._arium = None
         return self
 
+    @staticmethod
+    def _validate_yaml_config(config: Dict[str, Any]) -> AriumYamlModel:
+        """Validate YAML configuration using Pydantic models.
+
+        Args:
+            config: Dictionary containing YAML configuration
+
+        Returns:
+            AriumYamlModel: Validated configuration model
+
+        Raises:
+            ValueError: If validation fails with formatted error messages
+        """
+        try:
+            validated_config = AriumYamlModel(**config)
+        except ValidationError as e:
+            # Format validation errors for better readability
+            error_messages = []
+            for error in e.errors():
+                field_path = format_validation_error_path(error['loc'], config)
+                error_msg = f"{field_path}: {error['msg']}"
+                if 'ctx' in error:
+                    error_msg += f" (context: {error['ctx']})"
+                error_messages.append(error_msg)
+            raise ValueError(
+                'YAML validation failed:\n'
+                + '\n'.join(f'  - {msg}' for msg in error_messages)
+            ) from e
+        return validated_config
+
     @classmethod
     def from_yaml(
         cls,
@@ -310,136 +345,8 @@
         Returns:
             AriumBuilder: Configured builder instance
 
-        Example YAML structure:
-        metadata:
-          name: my-workflow
-          version: 1.0.0
-          description: "Example workflow"
-
-        arium:
-          agents:
-            # Method 1: Reference pre-built agents
-            - name: content_analyst  # Must exist in agents parameter
-            - name: summarizer  # Must exist in agents parameter
-
-            # Method 2: Direct agent definition
-            - name: validator
-              role: "Data Validator"
-              job: "You are a data validator"
-              model:
-                provider: openai
-                name: gpt-4o-mini
-              settings:
-                temperature: 0.1
-
-            # Method 3: Inline YAML configuration
-            - name: processor
-              yaml_config: |
-                agent:
-                  name: processor
-                  job: "You are a data processor"
-                  model:
-                    provider: openai
-                    name: gpt-4o-mini
-
-            # Method 4: External file reference
-            - name: reporter
-              yaml_file: "path/to/reporter.yaml"
-
-          function_nodes:
-            - name: function1
-              function_name: function1
-            - name: function2
-              function_name: function2
-              description: "Function 2"
-              input_filter: ["input1", "input2"]
-              prefilled_params:
"value1" - param2: "value2" - # LLM Router definitions (NEW) - routers: - - name: content_router - type: smart # smart, task_classifier, conversation_analysis, reflection, plan_execute - routing_options: - technical_writer: "Handle technical documentation tasks" - creative_writer: "Handle creative writing tasks" - editor: "Handle editing and review tasks" - model: - provider: openai - name: gpt-4o-mini - settings: - temperature: 0.3 - fallback_strategy: first - - # Reflection router for A -> B -> A -> C patterns - - name: main_critic_reflection - type: reflection - flow_pattern: [main_agent, critic, main_agent, final_agent] - settings: - allow_early_exit: false - - # Plan-Execute router for Cursor-style workflows - - name: plan_execute_router - type: plan_execute - agents: - planner: "Creates detailed execution plans" - developer: "Implements code and features" - tester: "Tests implementations" - reviewer: "Reviews final results" - settings: - planner_agent: planner - executor_agent: developer - reviewer_agent: reviewer - - # AriumNode definitions (nested Arium workflows) - arium_nodes: - # Method 1: Inline nested Arium definition - - name: document_processor - inherit_variables: true # optional, default: true - agents: - - name: classifier - job: "Classify documents" - model: - provider: openai - name: gpt-4o-mini - - name: specialist - job: "Process classified documents" - model: - provider: openai - name: gpt-4o-mini - workflow: - start: classifier - edges: - - from: classifier - to: [specialist] - end: [specialist] - - # Method 2: External YAML file reference - - name: complex_processor - yaml_file: "workflows/document_classifier.yaml" - inherit_variables: false - - # ForEachNode definitions - foreach_nodes: - - name: batch_processor - execute_node: document_processor # Can reference any node type - - workflow: - start: batch_processor # Can reference any node type including foreach/arium nodes - edges: - - from: content_analyst - to: [validator, summarizer] - router: content_router # References router defined above - - from: validator - to: [processor] - - from: summarizer - to: [reporter] - - from: processor - to: [end] - - from: reporter - to: [end] - end: [processor, reporter] """ + if yaml_str is None and yaml_file is None: raise ValueError('Either yaml_str or yaml_file must be provided') @@ -455,10 +362,11 @@ def from_yaml( with open(yaml_file, 'r') as f: config = yaml.safe_load(f) - if 'arium' not in config: - raise ValueError('YAML must contain an "arium" section') + validated_config = cls._validate_yaml_config(config) - arium_config = config['arium'] + # Use validated config directly - convert to dict only when needed for compatibility + config = validated_config.model_dump(exclude_none=True, by_alias=True) + arium = validated_config.arium builder = cls() # Configure memory - use provided memory or default to MessageMemory @@ -468,14 +376,21 @@ def from_yaml( builder.with_memory(MessageMemory()) # Process agents - agents_config = arium_config.get('agents', []) + agents_list = arium.agents or [] agents_dict = {} - for agent_config in agents_config: - agent_name = agent_config['name'] + for agent_config in agents_list: + agent_name = agent_config.name # Method 1: Reference pre-built agent - if len(agent_config) == 1 and 'name' in agent_config: + # Check if only name is provided (no other config fields) + if ( + agent_config.job is None + and agent_config.prompt is None + and agent_config.yaml_config is None + and agent_config.yaml_file is None + and agent_config.model is None + 
): # Only has name field, so it's a reference to a pre-built agent if agents and agent_name in agents: agent = agents[agent_name] @@ -491,27 +406,27 @@ def from_yaml( # Method 2: Direct agent definition elif ( - 'job' in agent_config - and 'yaml_config' not in agent_config - and 'yaml_file' not in agent_config + (agent_config.job is not None or agent_config.prompt is not None) + and agent_config.yaml_config is None + and agent_config.yaml_file is None ): agent = cls._create_agent_from_direct_config( agent_config, base_llm, tool_registry, **kwargs ) # Method 3: Inline YAML config - elif 'yaml_config' in agent_config: + elif agent_config.yaml_config is not None: agent_builder = AgentBuilder.from_yaml( - yaml_str=agent_config['yaml_config'], + yaml_str=agent_config.yaml_config, base_llm=base_llm, tool_registry=tool_registry, ) agent = agent_builder.build() # Method 4: External file reference - elif 'yaml_file' in agent_config: + elif agent_config.yaml_file is not None: agent_builder: AgentBuilder = AgentBuilder.from_yaml( - yaml_file=agent_config['yaml_file'], + yaml_file=agent_config.yaml_file, base_llm=base_llm, tool_registry=tool_registry, ) @@ -530,126 +445,124 @@ def from_yaml( builder.add_agent(agent) # Process function nodes - function_nodes_config = arium_config.get('function_nodes', []) + function_nodes_list = arium.function_nodes or [] function_nodes_dict = {} - for function_node_config in function_nodes_config: - function_node_name = function_node_config['name'] - function_name = function_node_config['function_name'] - prefilled_params = function_node_config.get('prefilled_params', None) - description = function_node_config.get('description', None) - input_filter = function_node_config.get('input_filter', None) + for function_node in function_nodes_list: function = ( - function_registry.get(function_name) + function_registry.get(function_node.function_name) if function_registry is not None else None ) if function is None: raise ValueError( - f'Function {function_name} not found in provided function_registry dictionary. ' + f'Function {function_node.function_name} not found in provided function_registry dictionary. ' f'Available functions: {list[str](function_registry.keys()) if function_registry else []}. ' f'Either provide the function in the function_registry parameter or add configuration fields.' 
) function_node = FunctionNode( - name=function_node_name, - description=description, + name=function_node.name, + description=function_node.description, function=function, - input_filter=input_filter, - prefilled_params=prefilled_params, + input_filter=function_node.input_filter, + prefilled_params=function_node.prefilled_params, ) - function_nodes_dict[function_node_name] = function_node + function_nodes_dict[function_node.name] = function_node builder.add_function_node(function_node) # Process LLM routers (if defined in YAML) - routers_config = arium_config.get('routers', []) + routers_list = arium.routers or [] yaml_routers = {} # Store routers created from YAML config - for router_config in routers_config: - router_name = router_config['name'] - router_type = router_config.get('type', 'smart') + for router in routers_list: + router_type = router.type or 'smart' # Create LLM instance for router router_llm = None - if 'model' in router_config: + if router.model is not None: router_llm = cls._create_llm_from_config( - router_config['model'], base_llm, **kwargs + router.model, base_llm, **kwargs ) else: router_llm = base_llm # Use base LLM if no specific model configured # Extract router-specific settings - settings = router_config.get('settings', {}) + settings = ( + router.settings.model_dump(exclude_none=True) if router.settings else {} + ) # Create router based on type if router_type == 'smart': - routing_options = router_config.get('routing_options', {}) - if not routing_options: + if not router.routing_options: raise ValueError( - f'Smart router {router_name} must specify routing_options' + f'Smart router {router.name} must specify routing_options' ) router_fn = create_llm_router( router_type='smart', - routing_options=routing_options, + routing_options=router.routing_options, llm=router_llm, **settings, ) elif router_type == 'task_classifier': - task_categories = router_config.get('task_categories', {}) - if not task_categories: + if not router.task_categories: raise ValueError( - f'Task classifier router {router_name} must specify task_categories' + f'Task classifier router {router.name} must specify task_categories' ) + # Convert TaskCategoryModel dict to regular dict + task_categories_dict = { + k: v.model_dump(exclude_none=True) + if hasattr(v, 'model_dump') + else v + for k, v in router.task_categories.items() + } router_fn = create_llm_router( router_type='task_classifier', - task_categories=task_categories, + task_categories=task_categories_dict, llm=router_llm, **settings, ) elif router_type == 'conversation_analysis': - routing_logic = router_config.get('routing_logic', {}) - if not routing_logic: + if not router.routing_logic: raise ValueError( - f'Conversation analysis router {router_name} must specify routing_logic' + f'Conversation analysis router {router.name} must specify routing_logic' ) router_fn = create_llm_router( router_type='conversation_analysis', - routing_logic=routing_logic, + routing_logic=router.routing_logic, llm=router_llm, **settings, ) elif router_type == 'reflection': - flow_pattern = router_config.get('flow_pattern', []) - if not flow_pattern: + if not router.flow_pattern: raise ValueError( - f'Reflection router {router_name} must specify flow_pattern' + f'Reflection router {router.name} must specify flow_pattern' ) router_fn = create_llm_router( router_type='reflection', - flow_pattern=flow_pattern, + flow_pattern=router.flow_pattern, llm=router_llm, **settings, ) elif router_type == 'plan_execute': - agents = router_config.get('agents', {}) - if 
not agents: + if not router.agents: raise ValueError( - f'Plan-Execute router {router_name} must specify agents' + f'Plan-Execute router {router.name} must specify agents' ) router_fn = create_llm_router( router_type='plan_execute', - agents=agents, + agents=router.agents, llm=router_llm, **settings, ) @@ -658,7 +571,7 @@ def from_yaml( f'Unknown router type: {router_type}. Supported types: smart, task_classifier, conversation_analysis, reflection, plan_execute' ) - yaml_routers[router_name] = router_fn + yaml_routers[router.name] = router_fn # Merge YAML routers with provided routers all_routers = {} @@ -667,19 +580,20 @@ def from_yaml( all_routers.update(yaml_routers) # Process AriumNodes (nested Arium workflows) - arium_nodes_config = arium_config.get('ariums', []) + arium_nodes_list = arium.ariums or [] arium_nodes_dict = {} - for arium_node_config in arium_nodes_config: - node_name = arium_node_config['name'] - inherit_vars = arium_node_config.get('inherit_variables', True) + for arium_node in arium_nodes_list: + inherit_vars = ( + arium_node.inherit_variables + if arium_node.inherit_variables is not None + else True + ) # Method 1: External YAML file reference - if 'yaml_file' in arium_node_config: - yaml_file_path = arium_node_config['yaml_file'] - + if arium_node.yaml_file is not None: nested_builder = cls.from_yaml( - yaml_file=yaml_file_path, + yaml_file=arium_node.yaml_file, memory=None, agents=None, routers=None, @@ -694,14 +608,29 @@ def from_yaml( # Build sub-config from inline definition sub_config = { 'arium': { - 'agents': arium_node_config.get('agents', []), - 'function_nodes': arium_node_config.get('function_nodes', []), - 'routers': arium_node_config.get('routers', []), - 'ariums': arium_node_config.get( - 'ariums', [] - ), # Support nesting! - 'iterators': arium_node_config.get('iterators', []), - 'workflow': arium_node_config['workflow'], + 'agents': [ + agent.model_dump(exclude_none=True, by_alias=True) + for agent in (arium_node.agents or []) + ], + 'function_nodes': [ + fn.model_dump(exclude_none=True, by_alias=True) + for fn in (arium_node.function_nodes or []) + ], + 'routers': [ + router.model_dump(exclude_none=True, by_alias=True) + for router in (arium_node.routers or []) + ], + 'ariums': [ + a.model_dump(exclude_none=True, by_alias=True) + for a in (arium_node.ariums or []) + ], # Support nesting! 
+ 'iterators': [ + it.model_dump(exclude_none=True, by_alias=True) + for it in (arium_node.iterators or []) + ], + 'workflow': arium_node.workflow.model_dump( + exclude_none=True, by_alias=True + ), } } @@ -718,20 +647,20 @@ def from_yaml( # Wrap in AriumNode arium_node = AriumNode( - name=node_name, arium=nested_arium, inherit_variables=inherit_vars + name=arium_node.name, arium=nested_arium, inherit_variables=inherit_vars ) - arium_nodes_dict[node_name] = arium_node + arium_nodes_dict[arium_node.name] = arium_node builder._all_ariums.append(arium_node) # Don't add to builder yet - will add during workflow processing if actually used # Process ForEachNodes (store configs, resolve later) - foreach_nodes_config = arium_config.get('iterators', []) + foreach_nodes_list = arium.iterators or [] foreach_nodes_dict = {} - for foreach_config in foreach_nodes_config: - foreach_name = foreach_config['name'] - execute_node_name = foreach_config['execute_node'] + for foreach_config in foreach_nodes_list: + foreach_name = foreach_config.name + execute_node_name = foreach_config.execute_node foreach_nodes_dict[foreach_name] = { 'name': foreach_name, @@ -769,7 +698,7 @@ def from_yaml( builder._foreach_nodes.append(foreach_node) # Process workflow - workflow_config = arium_config.get('workflow', {}) + workflow = arium.workflow # Helper function to find node from all sources def _find_node(node_name: str): @@ -781,7 +710,7 @@ def _find_node(node_name: str): ) # Set start node - start_node_name = workflow_config.get('start') + start_node_name = workflow.start if not start_node_name: raise ValueError('Workflow must specify a start node') @@ -804,12 +733,12 @@ def _find_node(node_name: str): builder.start_with(start_node) # Process edges - edges_config = workflow_config.get('edges', []) + edges_list = workflow.edges - for edge_config in edges_config: - from_node_name = edge_config['from'] - to_nodes_names = edge_config['to'] - router_name = edge_config.get('router') + for edge_config in edges_list: + from_node_name = edge_config.from_ + to_nodes_names = edge_config.to + router_name = edge_config.router # Find from node from_node = _find_node(from_node_name) @@ -853,7 +782,7 @@ def _find_node(node_name: str): builder.add_edge(from_node, to_nodes, router_fn) # Set end nodes - end_nodes_names = workflow_config.get('end', []) + end_nodes_names = workflow.end if not end_nodes_names: raise ValueError('Workflow must specify end nodes') @@ -872,14 +801,14 @@ def _find_node(node_name: str): @staticmethod def _create_llm_from_config( - model_config: Dict[str, Any], + model_config: LLMConfigModel, base_llm: Optional[BaseLLM] = None, **kwargs, ) -> BaseLLM: """Create an LLM instance from model configuration. Args: - model_config: Dictionary containing model configuration + model_config: LLMConfigModel instance containing model configuration base_llm: Base LLM to use as fallback Returns: @@ -891,7 +820,7 @@ def _create_llm_from_config( @staticmethod def _create_agent_from_direct_config( - agent_config: Dict[str, Any], + agent_config: AriumAgentConfigModel, base_llm: Optional[BaseLLM] = None, available_tools: Optional[Dict[str, Tool]] = None, **kwargs, @@ -899,7 +828,7 @@ def _create_agent_from_direct_config( """Create an Agent from direct YAML configuration. 
Args: - agent_config: Dictionary containing agent configuration + agent_config: AriumAgentConfigModel instance containing agent configuration base_llm: Base LLM to use if not specified in config available_tools: Available tools dictionary for tool lookup @@ -910,17 +839,23 @@ def _create_agent_from_direct_config( # from flo_ai.llm import OpenAI, Anthropic, Gemini, OllamaLLM # Extract basic configuration - name = agent_config['name'] - job = agent_config['job'] - role: str = ( - str(agent_config.get('role')) - if agent_config.get('role') is not None - else '' + name = agent_config.name + job = ( + agent_config.job or agent_config.prompt or 'You are a helpful AI assistant.' ) + role: str = str(agent_config.role) if agent_config.role is not None else '' + act_as = agent_config.act_as # Configure LLM - if 'model' in agent_config and base_llm is None: - llm = AriumBuilder._create_llm_from_config(agent_config['model'], **kwargs) + if agent_config.model is not None and base_llm is None: + # Merge base_url from agent_config if present and not in model_config + model_config = agent_config.model + if agent_config.base_url is not None and model_config.base_url is None: + # Create a new model instance with merged base_url using model_copy + model_config = model_config.model_copy( + update={'base_url': agent_config.base_url} + ) + llm = AriumBuilder._create_llm_from_config(model_config, **kwargs) elif base_llm: llm = base_llm else: @@ -929,10 +864,16 @@ def _create_agent_from_direct_config( ) # Extract settings - settings = agent_config.get('settings', {}) - temperature = settings.get('temperature') - max_retries = settings.get('max_retries', 3) - reasoning_pattern_str = settings.get('reasoning_pattern', 'DIRECT') + settings = agent_config.settings + temperature = settings.temperature if settings else None + max_retries = ( + settings.max_retries if settings and settings.max_retries is not None else 3 + ) + reasoning_pattern_str = ( + settings.reasoning_pattern + if settings and settings.reasoning_pattern + else 'DIRECT' + ) # Convert reasoning pattern string to enum try: @@ -946,9 +887,16 @@ def _create_agent_from_direct_config( # Extract and resolve tools agent_tools = [] - tool_names = agent_config.get('tools', []) - if tool_names and available_tools: - for tool_name in tool_names: + tool_configs = agent_config.tools or [] + if tool_configs and available_tools: + for tool_item in tool_configs: + # Handle both string tool names and ToolConfigModel instances + if isinstance(tool_item, str): + tool_name = tool_item + else: + # ToolConfigModel - extract name + tool_name = tool_item.name + if tool_name in available_tools: agent_tools.append(available_tools[tool_name]) else: @@ -959,15 +907,16 @@ def _create_agent_from_direct_config( # Handle parser configuration if present output_schema: Optional[Dict[str, Any]] = None - if 'parser' in agent_config: + if agent_config.parser is not None: from flo_ai.formatter.yaml_format_parser import FloYamlParser # Convert agent_config to the format expected by FloYamlParser - parser_config = {'agent': {'parser': agent_config['parser']}} + parser_dict = agent_config.parser.model_dump(exclude_none=True) + parser_config = {'agent': {'parser': parser_dict}} parser = FloYamlParser.create(yaml_dict=parser_config) output_schema = parser.get_format() - agent = ( + builder = ( AgentBuilder() .with_name(name) .with_prompt(job) @@ -977,9 +926,13 @@ def _create_agent_from_direct_config( .with_reasoning(reasoning_pattern) .with_output_schema(output_schema if output_schema is not 
None else {}) .with_role(role) - .build() ) + if act_as is not None: + builder.with_actas(act_as) + + agent = builder.build() + return agent diff --git a/flo_ai/flo_ai/formatter/yaml_format_parser.py b/flo_ai/flo_ai/formatter/yaml_format_parser.py index 62a3c607..90029eb0 100644 --- a/flo_ai/flo_ai/formatter/yaml_format_parser.py +++ b/flo_ai/flo_ai/formatter/yaml_format_parser.py @@ -91,7 +91,7 @@ def __create_literal_type(self, field: Dict[str, Any]) -> Any: ) literals = tuple(literal_value['value'] for literal_value in literal_values) # Construct Literal type dynamically at runtime - return Literal.__class_getitem__(literals) # type: ignore + return Literal.__getitem__(literals) def __create_array_type(self, field: Dict[str, Any], model_name: str) -> Any: """Creates a List type from field definition""" @@ -190,7 +190,7 @@ def __init__( raise ValueError( 'Either yaml_dict or yaml_path is required to build a FloYamlParser' ) - self.yaml_dict = yaml_dict + self.yaml_dict: dict[str, Any] | None = yaml_dict self.yaml_path = yaml_path def build(self): diff --git a/flo_ai/flo_ai/helpers/llm_factory.py b/flo_ai/flo_ai/helpers/llm_factory.py index 3673b97e..9b2e4a64 100644 --- a/flo_ai/flo_ai/helpers/llm_factory.py +++ b/flo_ai/flo_ai/helpers/llm_factory.py @@ -2,15 +2,17 @@ LLM Factory - Centralized LLM creation from configuration. This module provides a unified factory function for creating LLM instances -from configuration dictionaries, supporting all providers in the flo_ai ecosystem. +from configuration models, supporting all providers in the flo_ai ecosystem. """ import os -from typing import Dict, Any, TYPE_CHECKING +from typing import TYPE_CHECKING if TYPE_CHECKING: from flo_ai.llm import BaseLLM +from flo_ai.models.agent import LLMConfigModel + class LLMFactory: """Factory class for creating LLM instances from configuration.""" @@ -26,17 +28,11 @@ class LLMFactory: } @staticmethod - def create_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM': + def create_llm(model_config: LLMConfigModel, **kwargs) -> 'BaseLLM': """Create an LLM instance from model configuration. Args: - model_config: Dictionary containing model configuration with keys: - - provider (str): LLM provider name (default: 'openai') - - name (str): Model name (required for most providers) - - base_url (str, optional): Custom base URL - - model_id (str): For RootFlo provider - - project (str): For VertexAI provider - - location (str): For VertexAI provider (default: 'asia-south1') + model_config: LLMConfigModel instance containing model configuration **kwargs: Additional parameters that override config and env vars: - base_url: Override base URL - For RootFlo: app_key, app_secret, issuer, audience, access_token @@ -46,34 +42,8 @@ def create_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM': Raises: ValueError: If provider is unsupported or required parameters are missing - - Examples: - >>> # OpenAI - >>> llm = LLMFactory.create_llm({'provider': 'openai', 'name': 'gpt-4'}) - - >>> # VertexAI with project - >>> llm = LLMFactory.create_llm({ - ... 'provider': 'vertexai', - ... 'name': 'gemini-pro', - ... 'project': 'my-project', - ... 'location': 'us-central1' - ... }) - - >>> # RootFlo with auth - >>> llm = LLMFactory.create_llm( - ... {'provider': 'rootflo', 'model_id': 'model-123'}, - ... app_key='key', app_secret='secret', issuer='iss', audience='aud' - ... ) - - >>> # OpenAI vLLM with base_url - >>> llm = LLMFactory.create_llm({ - ... 'provider': 'openai_vllm', - ... 'name': 'microsoft/phi-4', - ... 
'base_url': 'http://localhost:8000/v1', - ... 'api_key': 'vllm-key' - ... }) """ - provider = model_config.get('provider', 'openai').lower() + provider = model_config.provider.lower() if provider not in LLMFactory.SUPPORTED_PROVIDERS: raise ValueError( @@ -92,19 +62,19 @@ def create_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM': @staticmethod def _create_standard_llm( - provider: str, model_config: Dict[str, Any], **kwargs + provider: str, model_config: LLMConfigModel, **kwargs ) -> 'BaseLLM': """Create standard LLM instances (OpenAI, Anthropic, Gemini, Ollama).""" from flo_ai.llm import OpenAI, Anthropic, Gemini, OllamaLLM - model_name = model_config.get('name') + model_name = model_config.name if not model_name: raise ValueError( f'{provider.title()} provider requires "name" parameter in model configuration' ) # Priority: kwargs > model_config > None - base_url = kwargs.get('base_url') or model_config.get('base_url') + base_url = kwargs.get('base_url') or model_config.base_url provider_map = { 'openai': OpenAI, @@ -117,20 +87,20 @@ def _create_standard_llm( return llm_class(model=model_name, base_url=base_url) @staticmethod - def _create_vertexai_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM': + def _create_vertexai_llm(model_config: LLMConfigModel, **kwargs) -> 'BaseLLM': """Create VertexAI LLM instance with project and location.""" from flo_ai.llm import VertexAI - model_name = model_config.get('name') + model_name = model_config.name if not model_name: raise ValueError( 'VertexAI provider requires "name" parameter in model configuration' ) # Get VertexAI-specific parameters - project = kwargs.get('project') or model_config.get('project') - location = kwargs.get('location') or model_config.get('location', 'asia-south1') - base_url = kwargs.get('base_url') or model_config.get('base_url') + project = kwargs.get('project') or model_config.project + location = kwargs.get('location') or model_config.location or 'asia-south1' + base_url = kwargs.get('base_url') or model_config.base_url if not project: raise ValueError( @@ -152,18 +122,18 @@ def _create_vertexai_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM': ) @staticmethod - def _create_openai_vllm_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM': + def _create_openai_vllm_llm(model_config: LLMConfigModel, **kwargs) -> 'BaseLLM': """Create OpenAI vLLM instance with base_url handling.""" from flo_ai.llm import OpenAIVLLM - model_name = model_config.get('name') + model_name = model_config.name if not model_name: raise ValueError( 'openai_vllm provider requires "name" parameter in model configuration' ) # Priority: kwargs > model_config > None - base_url = kwargs.get('base_url') or model_config.get('base_url') + base_url = kwargs.get('base_url') or model_config.base_url if not base_url: raise ValueError( 'openai_vllm provider requires "base_url" parameter. ' @@ -171,7 +141,7 @@ def _create_openai_vllm_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM' ) # Optional parameters - api_key = kwargs.get('api_key') or model_config.get('api_key') + api_key = kwargs.get('api_key') or model_config.api_key if not api_key: raise ValueError( 'openai_vllm provider requires "api_key" parameter. 
' @@ -179,7 +149,7 @@ def _create_openai_vllm_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM' ) temperature = kwargs.get( 'temperature', - model_config.get('temperature', 0.7), + model_config.temperature if model_config.temperature is not None else 0.7, ) return OpenAIVLLM( @@ -190,11 +160,11 @@ def _create_openai_vllm_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM' ) @staticmethod - def _create_rootflo_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM': + def _create_rootflo_llm(model_config: LLMConfigModel, **kwargs) -> 'BaseLLM': """Create RootFlo LLM instance with authentication.""" from flo_ai.llm import RootFloLLM - model_id = model_config.get('model_id') + model_id = model_config.model_id if not model_id: raise ValueError( 'RootFlo provider requires "model_id" in model configuration' @@ -203,7 +173,7 @@ def _create_rootflo_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM': # Gather RootFlo parameters from kwargs or environment base_url = ( kwargs.get('base_url') - or model_config.get('base_url') + or model_config.base_url or os.getenv('ROOTFLO_BASE_URL') ) app_key = kwargs.get('app_key') or os.getenv('ROOTFLO_APP_KEY') @@ -231,14 +201,14 @@ def _create_rootflo_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM': # Convenience function for direct import -def create_llm_from_config(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM': +def create_llm_from_config(model_config: LLMConfigModel, **kwargs) -> 'BaseLLM': """ Convenience function to create an LLM instance from configuration. This is a wrapper around LLMFactory.create_llm() for easier imports. Args: - model_config: Dictionary containing model configuration + model_config: LLMConfigModel instance containing model configuration **kwargs: Additional parameters that override config and env vars Returns: diff --git a/flo_ai/flo_ai/helpers/yaml_validation.py b/flo_ai/flo_ai/helpers/yaml_validation.py new file mode 100644 index 00000000..b3ef9608 --- /dev/null +++ b/flo_ai/flo_ai/helpers/yaml_validation.py @@ -0,0 +1,53 @@ +"""Helper functions for YAML validation error formatting.""" + +from typing import Dict, Any, Tuple + + +def format_validation_error_path(loc: Tuple, config: Dict[str, Any]) -> str: + """Format error location path, using section names instead of indices when available. + + This function improves error message readability by replacing numeric indices + with meaningful section names (e.g., agent names, router names) when available + in the YAML configuration. 
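+    Items that carry no 'name' field simply keep their numeric index, so the
+    formatted path stays printable even for anonymous list entries.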
+ + Args: + loc: Location tuple from Pydantic validation error (e.g., ('arium', 'agents', 0, 'name')) + config: Original YAML configuration dictionary + + Returns: + Formatted path string with names instead of indices where possible + + Example: + Instead of: "arium -> agents -> 0 -> job: Field required" + Returns: "arium -> agents -> my_agent -> job: Field required" + """ + path_parts = [] + current = config + + for part in loc: + # If part is an integer, try to find the name of the item at that index + if isinstance(part, int): + # Check if current is a list and we can access the item + if isinstance(current, list) and 0 <= part < len(current): + item = current[part] + # If the item has a 'name' field, use that instead of the index + if isinstance(item, dict) and 'name' in item: + path_parts.append(f"{item['name']}") + else: + path_parts.append(str(part)) + current = item + else: + path_parts.append(str(part)) + else: + # String key - use it directly + path_parts.append(str(part)) + # Navigate deeper into the config + if isinstance(current, dict) and part in current: + current = current[part] + elif isinstance(current, list): + # If we're at a list level, we can't navigate further by key + current = None + else: + current = None + + return ' -> '.join(path_parts) diff --git a/flo_ai/flo_ai/models/agent.py b/flo_ai/flo_ai/models/agent.py index e69de29b..95225b2e 100644 --- a/flo_ai/flo_ai/models/agent.py +++ b/flo_ai/flo_ai/models/agent.py @@ -0,0 +1,256 @@ +""" +Pydantic models for validating agent YAML configurations. + +These models ensure that YAML configurations are properly structured +and validated before being used to create agents. +""" + +from typing import List, Optional, Dict, Any, Union, Literal +from pydantic import BaseModel, Field, field_validator + + +class MetadataModel(BaseModel): + """Metadata section for agent YAML configuration.""" + + name: Optional[str] = None + version: Optional[str] = None + description: Optional[str] = None + author: Optional[str] = None + tags: Optional[List[str]] = None + + +class LiteralValueModel(BaseModel): + """A single value in a literal type field.""" + + value: str = Field(..., description='The literal value') + description: str = Field(..., description='Description of this value') + examples: Optional[List[str]] = Field( + None, description='Example strings for this value' + ) + + +class ParserFieldModel(BaseModel): + """A field definition in a parser configuration.""" + + name: str = Field(..., description='Field name') + type: Literal['str', 'int', 'bool', 'float', 'literal', 'object', 'array'] = Field( + ..., description='Field type' + ) + description: str = Field(..., description='Field description') + required: Optional[bool] = Field(None, description='Whether field is required') + values: Optional[List[LiteralValueModel]] = Field( + None, description='Values for literal type fields' + ) + items: Optional['ParserFieldModel'] = Field( + None, description='Item type for array fields' + ) + fields: Optional[List['ParserFieldModel']] = Field( + None, description='Nested fields for object type fields' + ) + default_value_prompt: Optional[str] = Field( + None, description='Default value prompt for literal fields' + ) + + def model_post_init(self, __context): + """Validate that literal type fields have values.""" + if self.type == 'literal' and not self.values: + raise ValueError( + f"Field '{self.name}' of type 'literal' must specify 'values'." 
+ ) + if self.type == 'array' and not self.items: + raise ValueError( + f"Field '{self.name}' of type 'array' must specify 'items'." + ) + if self.type == 'object' and not self.fields: + raise ValueError( + f"Field '{self.name}' of type 'object' must specify 'fields'." + ) + + +class ParserModel(BaseModel): + """Parser configuration for structured output.""" + + name: str = Field(..., description='Parser name') + version: Optional[str] = Field(None, description='Parser version') + description: Optional[str] = Field(None, description='Parser description') + fields: List[ParserFieldModel] = Field(..., description='Parser field definitions') + + +class ExampleModel(BaseModel): + """Example input/output pair for the agent.""" + + input: str = Field(..., description='Example input') + output: Union[Dict[str, Any], str] = Field(..., description='Example output') + + +class LLMConfigModel(BaseModel): + """LLM model configuration.""" + + provider: Literal[ + 'openai', + 'anthropic', + 'claude', # Alias for anthropic + 'gemini', + 'google', # Alias for gemini + 'ollama', + 'vertexai', + 'rootflo', + 'openai_vllm', + ] = Field(..., description='LLM provider') + name: Optional[str] = Field( + None, description='Model name (required for most providers)' + ) + base_url: Optional[str] = Field(None, description='Custom base URL') + temperature: Optional[float] = Field( + None, ge=0.0, le=2.0, description='Temperature setting' + ) + max_tokens: Optional[int] = Field(None, gt=0, description='Maximum tokens') + timeout: Optional[int] = Field(None, gt=0, description='Request timeout in seconds') + # VertexAI specific + project: Optional[str] = Field(None, description='GCP project ID (for VertexAI)') + location: Optional[str] = Field(None, description='GCP location (for VertexAI)') + # RootFlo specific + model_id: Optional[str] = Field(None, description='Model ID (for RootFlo)') + # OpenAI vLLM specific + api_key: Optional[str] = Field(None, description='API key (for openai_vllm)') + + def model_post_init(self, __context): + """Validate provider-specific requirements.""" + provider = self.provider.lower() + + # Most providers require 'name' + if provider in ['openai', 'anthropic', 'claude', 'gemini', 'google', 'ollama']: + if not self.name: + raise ValueError( + f'{provider.title()} provider requires "name" parameter in model configuration' + ) + + # VertexAI requires name, project, and base_url + if provider == 'vertexai': + if not self.name: + raise ValueError('VertexAI provider requires "name" parameter') + if not self.project: + raise ValueError('VertexAI provider requires "project" parameter') + if not self.base_url: + raise ValueError('VertexAI provider requires "base_url" parameter') + + # RootFlo requires model_id + if provider == 'rootflo': + if not self.model_id: + raise ValueError( + 'RootFlo provider requires "model_id" in model configuration' + ) + + # OpenAI vLLM requires name, base_url, and api_key + if provider == 'openai_vllm': + if not self.name: + raise ValueError('openai_vllm provider requires "name" parameter') + if not self.base_url: + raise ValueError('openai_vllm provider requires "base_url" parameter') + if not self.api_key: + raise ValueError('openai_vllm provider requires "api_key" parameter') + + +class SettingsModel(BaseModel): + """Agent settings configuration.""" + + temperature: Optional[float] = Field( + None, ge=0.0, le=2.0, description='Temperature setting' + ) + max_retries: Optional[int] = Field( + None, ge=0, description='Maximum number of retries' + ) + 
reasoning_pattern: Optional[Literal['DIRECT', 'REACT', 'COT']] = Field( + None, description='Reasoning pattern' + ) + + +class ToolConfigModel(BaseModel): + """Tool configuration in YAML.""" + + name: str = Field(..., description='Tool name (must exist in tool registry)') + prefilled_params: Optional[Dict[str, Any]] = Field( + None, description='Pre-filled parameters for the tool' + ) + name_override: Optional[str] = Field( + None, description='Custom name override for the tool' + ) + description_override: Optional[str] = Field( + None, description='Custom description override for the tool' + ) + + +class AgentConfigModel(BaseModel): + """Main agent configuration model.""" + + name: str = Field(..., description='Agent name') + job: Optional[str] = Field(None, description='System prompt/job description') + prompt: Optional[str] = Field(None, description='System prompt (alias for job)') + role: Optional[str] = Field(None, description='Agent role') + act_as: Optional[str] = Field(None, description='Agent act_as setting') + model: Optional[LLMConfigModel] = Field(None, description='LLM model configuration') + base_url: Optional[str] = Field( + None, description='Base URL (can be at agent or model level)' + ) + settings: Optional[SettingsModel] = Field(None, description='Agent settings') + tools: Optional[List[Union[str, ToolConfigModel]]] = Field( + None, description='List of tools (strings or tool configs)' + ) + parser: Optional[ParserModel] = Field( + None, description='Parser configuration for structured output' + ) + examples: Optional[List[ExampleModel]] = Field( + None, description='Example input/output pairs' + ) + + def model_post_init(self, __context): + """Ensure either job or prompt is provided.""" + if not self.job and not self.prompt: + raise ValueError( + "Agent configuration must have either 'job' or 'prompt' field" + ) + # If both are provided, prefer 'job' and ignore 'prompt' + if self.job and self.prompt: + # Keep job, prompt will be ignored in favor of job + pass + + @field_validator('tools', mode='before') + @classmethod + def validate_tools(cls, v): + """Validate tools configuration.""" + if v is None: + return v + for tool in v: + if isinstance(tool, str): + # String reference - valid + continue + elif isinstance(tool, dict): + # Should be validated as ToolConfigModel + if 'name' not in tool: + raise ValueError("Tool configuration must have a 'name' field") + elif isinstance(tool, ToolConfigModel): + # Already validated as ToolConfigModel - valid + continue + else: + raise ValueError( + f'Invalid tool configuration type: {type(tool)}. ' + "Must be string or dict with 'name' field." + ) + return v + + +class AgentYamlModel(BaseModel): + """Root model for agent YAML configuration.""" + + metadata: Optional[MetadataModel] = Field(None, description='Metadata section') + agent: AgentConfigModel = Field(..., description='Agent configuration') + + def model_post_init(self, __context): + """Validate that agent has model config or will receive base_llm.""" + # This validation is handled in the builder, but we can note it here + # The actual check happens in from_yaml when base_llm is None + pass + + +# Update forward references for recursive types +ParserFieldModel.model_rebuild() diff --git a/flo_ai/flo_ai/models/arium.py b/flo_ai/flo_ai/models/arium.py new file mode 100644 index 00000000..53c25d87 --- /dev/null +++ b/flo_ai/flo_ai/models/arium.py @@ -0,0 +1,365 @@ +""" +Pydantic models for validating arium YAML configurations. 
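+Shared building blocks (MetadataModel, LLMConfigModel, AgentConfigModel) are
+reused from flo_ai.models.agent rather than redefined here.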
+ +These models ensure that YAML configurations are properly structured +and validated before being used to create arium workflows. +""" + +from typing import List, Optional, Dict, Any, Literal +from pydantic import BaseModel, Field, field_validator, ConfigDict + +# Import shared models from agent.py +from flo_ai.models.agent import ( + MetadataModel, + LLMConfigModel, + AgentConfigModel, +) + + +class AriumAgentConfigModel(AgentConfigModel): + """Agent configuration within an arium workflow. + + Extends AgentConfigModel to support arium-specific configuration methods: + - Name-only reference to pre-built agent (allowed in arium, not in standalone agent) + - Inline yaml_config string + - External yaml_file reference + """ + + yaml_config: Optional[str] = Field( + None, description='Inline YAML configuration string for agent' + ) + yaml_file: Optional[str] = Field( + None, description='Path to YAML file containing agent configuration' + ) + + def model_post_init(self, __context): + """Validate agent configuration methods for arium context. + + Overrides parent validation to allow name-only references (for pre-built agents) + and to validate arium-specific configuration methods (yaml_config, yaml_file). + """ + # Count how many configuration methods are provided + config_methods = [ + self.job or self.prompt, + self.yaml_config, + self.yaml_file, + ] + provided_methods = sum(1 for method in config_methods if method is not None) + + # If model is provided, it indicates direct configuration attempt + # In this case, we need job/prompt (unless using yaml_config/yaml_file) + if self.model is not None: + if not self.yaml_config and not self.yaml_file: + # Model provided but no yaml_config/yaml_file means direct config + # Must have job or prompt + if not self.job and not self.prompt: + raise ValueError( + f"Agent '{self.name}' has 'model' specified but is missing 'job' or 'prompt' field. " + "When using direct configuration with a model, 'job' or 'prompt' is required." + ) + + # If only name is provided (no model, no config methods), it's a reference to a pre-built agent (valid in arium) + if provided_methods == 0 and self.model is None: + # This is a reference to a pre-built agent - valid in arium context + # Skip parent validation which requires job/prompt + return + + # If multiple methods are provided, that's ambiguous + if provided_methods > 1: + methods = [] + if self.job or self.prompt: + methods.append('job/prompt') + if self.yaml_config: + methods.append('yaml_config') + if self.yaml_file: + methods.append('yaml_file') + raise ValueError( + f"Agent '{self.name}' has multiple configuration methods: {', '.join(methods)}. " + "Only one method should be provided." 
+ ) + + # If using direct config (job/prompt), ensure at least one is provided + # This mirrors the parent validation but only when using direct config + if not self.yaml_config and not self.yaml_file: + if not self.job and not self.prompt: + raise ValueError( + "Agent configuration must have either 'job' or 'prompt' field when using direct configuration" + ) + + +class FunctionNodeConfigModel(BaseModel): + """Function node configuration in arium workflow.""" + + name: str = Field(..., description='Function node name') + function_name: str = Field(..., description='Name of function in function registry') + description: Optional[str] = Field(None, description='Function description') + input_filter: Optional[List[str]] = Field( + None, description='List of input keys to filter' + ) + prefilled_params: Optional[Dict[str, Any]] = Field( + None, description='Pre-filled parameters for the function' + ) + + +class RouterSettingsModel(BaseModel): + """Settings for router configuration.""" + + temperature: Optional[float] = Field( + None, ge=0.0, le=2.0, description='Temperature setting' + ) + fallback_strategy: Optional[Literal['first', 'random', 'all']] = Field( + None, description='Fallback strategy for routing' + ) + allow_early_exit: Optional[bool] = Field( + None, description='Allow early exit (for reflection router)' + ) + planner_agent: Optional[str] = Field( + None, description='Planner agent name (for plan_execute router)' + ) + executor_agent: Optional[str] = Field( + None, description='Executor agent name (for plan_execute router)' + ) + reviewer_agent: Optional[str] = Field( + None, description='Reviewer agent name (for plan_execute router)' + ) + + +class TaskCategoryModel(BaseModel): + """Task category configuration for task_classifier router.""" + + description: str = Field( + ..., description='Description of what this category handles' + ) + keywords: Optional[List[str]] = Field( + None, description='Optional keywords for this category' + ) + examples: Optional[List[str]] = Field( + None, description='Optional example tasks for this category' + ) + + +class RouterConfigModel(BaseModel): + """Router configuration in arium workflow.""" + + name: str = Field(..., description='Router name') + type: Literal[ + 'smart', + 'task_classifier', + 'conversation_analysis', + 'reflection', + 'plan_execute', + ] = Field(..., description='Router type') + model: Optional[LLMConfigModel] = Field( + None, description='LLM model configuration for router' + ) + settings: Optional[RouterSettingsModel] = Field(None, description='Router settings') + # Smart router fields + routing_options: Optional[Dict[str, str]] = Field( + None, description='Routing options for smart router (agent_name: description)' + ) + # Task classifier router fields + task_categories: Optional[Dict[str, TaskCategoryModel]] = Field( + None, description='Task categories for task_classifier router' + ) + # Conversation analysis router fields + routing_logic: Optional[Dict[str, str]] = Field( + None, description='Routing logic for conversation_analysis router' + ) + # Reflection router fields + flow_pattern: Optional[List[str]] = Field( + None, description='Flow pattern for reflection router (list of agent names)' + ) + # Plan-execute router fields + agents: Optional[Dict[str, str]] = Field( + None, description='Agent descriptions for plan_execute router' + ) + + def model_post_init(self, __context): + """Validate router type-specific requirements.""" + if self.type == 'smart': + if not self.routing_options: + raise ValueError( + 
f"Smart router '{self.name}' must specify 'routing_options'" + ) + elif self.type == 'task_classifier': + if not self.task_categories: + raise ValueError( + f"Task classifier router '{self.name}' must specify 'task_categories'" + ) + elif self.type == 'conversation_analysis': + if not self.routing_logic: + raise ValueError( + f"Conversation analysis router '{self.name}' must specify 'routing_logic'" + ) + elif self.type == 'reflection': + if not self.flow_pattern: + raise ValueError( + f"Reflection router '{self.name}' must specify 'flow_pattern'" + ) + elif self.type == 'plan_execute': + if not self.agents: + raise ValueError( + f"Plan-Execute router '{self.name}' must specify 'agents'" + ) + + +class EdgeConfigModel(BaseModel): + """Edge configuration in arium workflow.""" + + model_config = ConfigDict(populate_by_name=True) # Allow 'from' alias + + from_: str = Field(..., alias='from', description='Source node name') + to: List[str] = Field(..., description='Target node names') + router: Optional[str] = Field(None, description='Router name to use for this edge') + + +class WorkflowConfigModel(BaseModel): + """Workflow configuration in arium.""" + + start: str = Field(..., description='Start node name') + edges: List[EdgeConfigModel] = Field(..., description='List of edges') + end: List[str] = Field(..., description='List of end node names') + + @field_validator('edges', mode='before') + @classmethod + def validate_edges(cls, v): + """Validate edges configuration.""" + if not isinstance(v, list): + raise ValueError('Edges must be a list') + return v + + +class AriumNodeConfigModel(BaseModel): + """Nested arium node configuration.""" + + name: str = Field(..., description='Arium node name') + inherit_variables: Optional[bool] = Field( + True, description='Whether to inherit parent variables' + ) + yaml_file: Optional[str] = Field( + None, description='Path to YAML file containing nested arium configuration' + ) + # Inline nested arium configuration + agents: Optional[List[AriumAgentConfigModel]] = Field( + None, description='List of agents for nested arium' + ) + function_nodes: Optional[List[FunctionNodeConfigModel]] = Field( + None, description='List of function nodes for nested arium' + ) + routers: Optional[List['RouterConfigModel']] = Field( + None, description='List of routers for nested arium' + ) + ariums: Optional[List['AriumNodeConfigModel']] = Field( + None, description='Nested arium nodes (supports nesting)' + ) + iterators: Optional[List['ForEachNodeConfigModel']] = Field( + None, description='List of foreach nodes for nested arium' + ) + workflow: Optional[WorkflowConfigModel] = Field( + None, description='Workflow configuration for nested arium' + ) + + def model_post_init(self, __context): + """Validate that either yaml_file or inline config is provided.""" + has_yaml_file = self.yaml_file is not None + has_inline_config = ( + self.agents is not None + or self.function_nodes is not None + or self.routers is not None + or self.ariums is not None + or self.iterators is not None + or self.workflow is not None + ) + + if not has_yaml_file and not has_inline_config: + raise ValueError( + f"Arium node '{self.name}' must have either 'yaml_file' or inline configuration" + ) + + if has_yaml_file and has_inline_config: + raise ValueError( + f"Arium node '{self.name}' cannot have both 'yaml_file' and inline configuration" + ) + + if has_inline_config and not self.workflow: + raise ValueError( + f"Arium node '{self.name}' with inline configuration must specify 'workflow'" + ) + + 
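+# Illustrative YAML for the two AriumNodeConfigModel configuration methods.
+# This is a sketch adapted from the example that previously lived in
+# AriumBuilder.from_yaml's docstring; agent names and model values below are
+# placeholders, not required values.
+#
+#   ariums:
+#     - name: document_processor          # Method 1: inline definition
+#       inherit_variables: true           # optional, default: true
+#       agents:
+#         - name: classifier
+#           job: "Classify documents"
+#           model:
+#             provider: openai
+#             name: gpt-4o-mini
+#       workflow:
+#         start: classifier
+#         edges:
+#           - from: classifier
+#             to: [end]
+#         end: [classifier]
+#
+#     - name: complex_processor           # Method 2: external file reference
+#       yaml_file: "workflows/document_classifier.yaml"
+#       inherit_variables: false
+
+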
+class ForEachNodeConfigModel(BaseModel):
+    """ForEach node configuration in arium workflow."""
+
+    name: str = Field(..., description='ForEach node name')
+    execute_node: str = Field(..., description='Name of node to execute on each item')
+
+
+class AriumConfigModel(BaseModel):
+    """Main arium configuration model."""
+
+    agents: Optional[List[AriumAgentConfigModel]] = Field(
+        None, description='List of agents in the workflow'
+    )
+    function_nodes: Optional[List[FunctionNodeConfigModel]] = Field(
+        None, description='List of function nodes in the workflow'
+    )
+    routers: Optional[List[RouterConfigModel]] = Field(
+        None, description='List of routers in the workflow'
+    )
+    ariums: Optional[List[AriumNodeConfigModel]] = Field(
+        None, description='List of nested arium nodes'
+    )
+    iterators: Optional[List[ForEachNodeConfigModel]] = Field(
+        None, description="List of foreach nodes (aliased as 'iterators' in YAML)"
+    )
+    foreach_nodes: Optional[List[ForEachNodeConfigModel]] = Field(
+        None, description='List of foreach nodes (alternative name)'
+    )
+    workflow: WorkflowConfigModel = Field(..., description='Workflow configuration')
+
+    @field_validator('iterators', 'foreach_nodes', mode='before')
+    @classmethod
+    def validate_foreach_nodes(cls, v):
+        """Pass-through validator; alias normalization happens in model_post_init."""
+        return v
+
+    def model_post_init(self, __context):
+        """Normalize the iterators/foreach_nodes aliases and ensure nodes exist."""
+        # 'iterators' and 'foreach_nodes' are aliases; normalize onto 'iterators'.
+        # foreach_nodes wins when both are given (the iterators value is discarded).
+        if self.foreach_nodes:
+            self.iterators = self.foreach_nodes
+            self.foreach_nodes = None
+
+        # Validate that at least one node type is defined
+        has_nodes = (
+            (self.agents and len(self.agents) > 0)
+            or (self.function_nodes and len(self.function_nodes) > 0)
+            or (self.ariums and len(self.ariums) > 0)
+            or (self.iterators and len(self.iterators) > 0)
+        )
+
+        if not has_nodes:
+            raise ValueError(
+                'Arium configuration must have at least one of: agents, function_nodes, ariums, or iterators'
+            )
+
+
+class AriumYamlModel(BaseModel):
+    """Root model for arium YAML configuration."""
+
+    metadata: Optional[MetadataModel] = Field(None, description='Metadata section')
+    arium: AriumConfigModel = Field(..., description='Arium configuration')
+
+    def model_post_init(self, __context):
+        """Additional validation if needed."""
+        pass
+
+
+# Update forward references for recursive types
+AriumNodeConfigModel.model_rebuild()
+RouterConfigModel.model_rebuild()
+ForEachNodeConfigModel.model_rebuild()
diff --git a/flo_ai/tests/unit-tests/test_agent_yaml_validation.py b/flo_ai/tests/unit-tests/test_agent_yaml_validation.py
new file mode 100644
index 00000000..4a1a7837
--- /dev/null
+++ b/flo_ai/tests/unit-tests/test_agent_yaml_validation.py
@@ -0,0 +1,970 @@
+"""
+Tests for agent YAML validation using Pydantic models.
+
+This module tests all validation rules defined in flo_ai.models.agent,
+ensuring that YAML configurations are properly validated before being used.
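+
+Run with pytest, e.g.:
+
+    pytest flo_ai/tests/unit-tests/test_agent_yaml_validation.py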
+""" + +import pytest +import yaml +from pydantic import ValidationError + +from typing import cast + +from flo_ai.models.agent import ( + AgentYamlModel, + AgentConfigModel, + LiteralValueModel, + MetadataModel, + LLMConfigModel, + SettingsModel, + ParserModel, + ParserFieldModel, + ExampleModel, + ToolConfigModel, +) + + +class TestMetadataModel: + """Test cases for MetadataModel validation.""" + + def test_valid_metadata(self): + """Test valid metadata configuration.""" + metadata = MetadataModel( + name='test-agent', + version='1.0.0', + description='A test agent', + author='Test Author', + tags=['test', 'agent'], + ) + assert metadata.name == 'test-agent' + assert metadata.version == '1.0.0' + assert metadata.description == 'A test agent' + assert metadata.author == 'Test Author' + assert metadata.tags == ['test', 'agent'] + + def test_empty_metadata(self): + """Test that all metadata fields are optional.""" + metadata = MetadataModel() + assert metadata.name is None + assert metadata.version is None + assert metadata.description is None + assert metadata.author is None + assert metadata.tags is None + + def test_partial_metadata(self): + """Test metadata with only some fields.""" + metadata = MetadataModel(name='test-agent', version='1.0.0') + assert metadata.name == 'test-agent' + assert metadata.version == '1.0.0' + assert metadata.description is None + + +class TestParserFieldModel: + """Test cases for ParserFieldModel validation.""" + + def test_valid_string_field(self): + """Test valid string field.""" + field = ParserFieldModel( + name='query', + type='str', + description='A query string', + required=True, + ) + assert field.name == 'query' + assert field.type == 'str' + assert field.description == 'A query string' + assert field.required is True + + def test_valid_literal_field(self): + """Test valid literal field with values.""" + field = ParserFieldModel( + name='status', + type='literal', + description='Status value', + values=[ + LiteralValueModel(value='active', description='Active status'), + LiteralValueModel(value='inactive', description='Inactive status'), + ], + ) + assert field.type == 'literal' + assert field.values is not None + assert len(field.values) == 2 + assert field.values[0].value == 'active' + + def test_literal_field_missing_values(self): + """Test that literal fields must have values.""" + with pytest.raises(ValueError, match="must specify 'values'"): + ParserFieldModel( + name='status', + type='literal', + description='Status value', + ) + + def test_valid_array_field(self): + """Test valid array field with items.""" + field = ParserFieldModel( + name='items', + type='array', + description='List of items', + items=ParserFieldModel( + name='item', + type='str', + description='An item', + ), + ) + assert field.type == 'array' + assert field.items is not None + assert field.items.type == 'str' + + def test_array_field_missing_items(self): + """Test that array fields must have items.""" + with pytest.raises(ValueError, match="must specify 'items'"): + ParserFieldModel( + name='items', + type='array', + description='List of items', + ) + + def test_valid_object_field(self): + """Test valid object field with nested fields.""" + field = ParserFieldModel( + name='user', + type='object', + description='User object', + fields=[ + ParserFieldModel( + name='name', + type='str', + description='User name', + ), + ParserFieldModel( + name='age', + type='int', + description='User age', + ), + ], + ) + assert field.type == 'object' + assert field.fields is not None + 
assert len(field.fields) == 2 + + def test_object_field_missing_fields(self): + """Test that object fields must have fields.""" + with pytest.raises(ValueError, match="must specify 'fields'"): + ParserFieldModel( + name='user', + type='object', + description='User object', + ) + + def test_nested_object_field(self): + """Test nested object fields.""" + field = ParserFieldModel( + name='address', + type='object', + description='Address object', + fields=[ + ParserFieldModel( + name='street', + type='str', + description='Street name', + ), + ParserFieldModel( + name='city', + type='object', + description='City object', + fields=[ + ParserFieldModel( + name='name', + type='str', + description='City name', + ), + ], + ), + ], + ) + assert field.type == 'object' + assert field.fields is not None + assert field.fields[1].type == 'object' + assert field.fields[1].fields is not None + assert field.fields[1].fields[0].type == 'str' + + +class TestParserModel: + """Test cases for ParserModel validation.""" + + def test_valid_parser(self): + """Test valid parser configuration.""" + parser = ParserModel( + name='test_parser', + version='1.0.0', + description='A test parser', + fields=[ + ParserFieldModel( + name='query', + type='str', + description='Query string', + ), + ], + ) + assert parser.name == 'test_parser' + assert parser.version == '1.0.0' + assert len(parser.fields) == 1 + + def test_parser_missing_required_fields(self): + """Test that parser must have name and fields.""" + with pytest.raises(ValidationError): + ParserModel( + name='test_parser', + # Missing fields + ) + + def test_parser_optional_fields(self): + """Test parser with only required fields.""" + parser = ParserModel( + name='test_parser', + fields=[ + ParserFieldModel( + name='query', + type='str', + description='Query string', + ), + ], + ) + assert parser.name == 'test_parser' + assert parser.version is None + assert parser.description is None + + +class TestExampleModel: + """Test cases for ExampleModel validation.""" + + def test_valid_example_with_string_output(self): + """Test valid example with string output.""" + example = ExampleModel( + input='Hello', + output='Hi there!', + ) + assert example.input == 'Hello' + assert example.output == 'Hi there!' 
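+
+    # A hedged, illustrative addition (not part of the original suite):
+    # model_dump() on an ExampleModel returns the plain field values unchanged.
+    def test_example_model_dump_round_trip(self):
+        """model_dump() preserves input and output as given."""
+        example = ExampleModel(input='Hello', output='Hi there!')
+        assert example.model_dump() == {'input': 'Hello', 'output': 'Hi there!'}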
+ + def test_valid_example_with_dict_output(self): + """Test valid example with dictionary output.""" + example = ExampleModel( + input='Get user info', + output={'name': 'John', 'age': 30}, + ) + assert example.input == 'Get user info' + assert example.output == {'name': 'John', 'age': 30} + + def test_example_missing_input(self): + """Test that example must have input.""" + with pytest.raises(ValidationError): + ExampleModel(output='Hi there!') + + def test_example_missing_output(self): + """Test that example must have output.""" + with pytest.raises(ValidationError): + ExampleModel(input='Hello') + + +class TestLLMConfigModel: + """Test cases for LLMConfigModel validation.""" + + def test_valid_openai_config(self): + """Test valid OpenAI model configuration.""" + config = LLMConfigModel( + provider='openai', + name='gpt-4', + temperature=0.7, + max_tokens=1000, + ) + assert config.provider == 'openai' + assert config.name == 'gpt-4' + assert config.temperature == 0.7 + assert config.max_tokens == 1000 + + def test_openai_missing_name(self): + """Test that OpenAI requires name.""" + with pytest.raises(ValueError, match='requires "name" parameter'): + LLMConfigModel(provider='openai') + + def test_valid_anthropic_config(self): + """Test valid Anthropic model configuration.""" + config = LLMConfigModel( + provider='anthropic', + name='claude-3-opus-20240229', + temperature=0.5, + ) + assert config.provider == 'anthropic' + assert config.name == 'claude-3-opus-20240229' + + def test_anthropic_missing_name(self): + """Test that Anthropic requires name.""" + with pytest.raises(ValueError, match='requires "name" parameter'): + LLMConfigModel(provider='anthropic') + + def test_valid_claude_alias(self): + """Test that 'claude' alias works for Anthropic.""" + config = LLMConfigModel( + provider='claude', + name='claude-3-opus-20240229', + ) + assert config.provider == 'claude' + + def test_valid_gemini_config(self): + """Test valid Gemini model configuration.""" + config = LLMConfigModel( + provider='gemini', + name='gemini-pro', + temperature=0.8, + ) + assert config.provider == 'gemini' + assert config.name == 'gemini-pro' + + def test_gemini_missing_name(self): + """Test that Gemini requires name.""" + with pytest.raises(ValueError, match='requires "name" parameter'): + LLMConfigModel(provider='gemini') + + def test_valid_google_alias(self): + """Test that 'google' alias works for Gemini.""" + config = LLMConfigModel( + provider='google', + name='gemini-pro', + ) + assert config.provider == 'google' + + def test_valid_ollama_config(self): + """Test valid Ollama model configuration.""" + config = LLMConfigModel( + provider='ollama', + name='llama2', + base_url='http://localhost:11434', + ) + assert config.provider == 'ollama' + assert config.name == 'llama2' + + def test_ollama_missing_name(self): + """Test that Ollama requires name.""" + with pytest.raises(ValueError, match='requires "name" parameter'): + LLMConfigModel(provider='ollama') + + def test_valid_vertexai_config(self): + """Test valid VertexAI model configuration.""" + config = LLMConfigModel( + provider='vertexai', + name='gemini-pro', + project='my-project', + base_url='https://us-central1-aiplatform.googleapis.com', + location='us-central1', + ) + assert config.provider == 'vertexai' + assert config.project == 'my-project' + assert config.base_url == 'https://us-central1-aiplatform.googleapis.com' + + def test_vertexai_missing_name(self): + """Test that VertexAI requires name.""" + with pytest.raises(ValueError, match='requires 
"name" parameter'): + LLMConfigModel( + provider='vertexai', + project='my-project', + base_url='https://us-central1-aiplatform.googleapis.com', + ) + + def test_vertexai_missing_project(self): + """Test that VertexAI requires project.""" + with pytest.raises(ValueError, match='requires "project" parameter'): + LLMConfigModel( + provider='vertexai', + name='gemini-pro', + base_url='https://us-central1-aiplatform.googleapis.com', + ) + + def test_vertexai_missing_base_url(self): + """Test that VertexAI requires base_url.""" + with pytest.raises(ValueError, match='requires "base_url" parameter'): + LLMConfigModel( + provider='vertexai', + name='gemini-pro', + project='my-project', + ) + + def test_valid_rootflo_config(self): + """Test valid RootFlo model configuration.""" + config = LLMConfigModel( + provider='rootflo', + model_id='model-123', + ) + assert config.provider == 'rootflo' + assert config.model_id == 'model-123' + + def test_rootflo_missing_model_id(self): + """Test that RootFlo requires model_id.""" + with pytest.raises(ValueError, match='requires "model_id"'): + LLMConfigModel(provider='rootflo') + + def test_valid_openai_vllm_config(self): + """Test valid OpenAI vLLM model configuration.""" + config = LLMConfigModel( + provider='openai_vllm', + name='gpt-4', + base_url='http://localhost:8000/v1', + api_key='sk-test', + ) + assert config.provider == 'openai_vllm' + assert config.name == 'gpt-4' + assert config.base_url == 'http://localhost:8000/v1' + assert config.api_key == 'sk-test' + + def test_openai_vllm_missing_name(self): + """Test that OpenAI vLLM requires name.""" + with pytest.raises(ValueError, match='requires "name" parameter'): + LLMConfigModel( + provider='openai_vllm', + base_url='http://localhost:8000/v1', + api_key='sk-test', + ) + + def test_openai_vllm_missing_base_url(self): + """Test that OpenAI vLLM requires base_url.""" + with pytest.raises(ValueError, match='requires "base_url" parameter'): + LLMConfigModel( + provider='openai_vllm', + name='gpt-4', + api_key='sk-test', + ) + + def test_openai_vllm_missing_api_key(self): + """Test that OpenAI vLLM requires api_key.""" + with pytest.raises(ValueError, match='requires "api_key" parameter'): + LLMConfigModel( + provider='openai_vllm', + name='gpt-4', + base_url='http://localhost:8000/v1', + ) + + def test_temperature_range_validation(self): + """Test temperature range validation (0.0 to 2.0).""" + # Valid temperatures + config1 = LLMConfigModel(provider='openai', name='gpt-4', temperature=0.0) + assert config1.temperature == 0.0 + + config2 = LLMConfigModel(provider='openai', name='gpt-4', temperature=2.0) + assert config2.temperature == 2.0 + + # Invalid temperatures + with pytest.raises(ValidationError): + LLMConfigModel(provider='openai', name='gpt-4', temperature=-0.1) + + with pytest.raises(ValidationError): + LLMConfigModel(provider='openai', name='gpt-4', temperature=2.1) + + def test_max_tokens_validation(self): + """Test max_tokens must be greater than 0.""" + config = LLMConfigModel(provider='openai', name='gpt-4', max_tokens=100) + assert config.max_tokens == 100 + + with pytest.raises(ValidationError): + LLMConfigModel(provider='openai', name='gpt-4', max_tokens=0) + + with pytest.raises(ValidationError): + LLMConfigModel(provider='openai', name='gpt-4', max_tokens=-1) + + def test_timeout_validation(self): + """Test timeout must be greater than 0.""" + config = LLMConfigModel(provider='openai', name='gpt-4', timeout=30) + assert config.timeout == 30 + + with pytest.raises(ValidationError): + 
LLMConfigModel(provider='openai', name='gpt-4', timeout=0) + + with pytest.raises(ValidationError): + LLMConfigModel(provider='openai', name='gpt-4', timeout=-1) + + +class TestSettingsModel: + """Test cases for SettingsModel validation.""" + + def test_valid_settings(self): + """Test valid settings configuration.""" + settings = SettingsModel( + temperature=0.7, + max_retries=3, + reasoning_pattern='REACT', + ) + assert settings.temperature == 0.7 + assert settings.max_retries == 3 + assert settings.reasoning_pattern == 'REACT' + + def test_empty_settings(self): + """Test that all settings fields are optional.""" + settings = SettingsModel() + assert settings.temperature is None + assert settings.max_retries is None + assert settings.reasoning_pattern is None + + def test_temperature_range_validation(self): + """Test temperature range validation (0.0 to 2.0).""" + settings1 = SettingsModel(temperature=0.0) + assert settings1.temperature == 0.0 + + settings2 = SettingsModel(temperature=2.0) + assert settings2.temperature == 2.0 + + with pytest.raises(ValidationError): + SettingsModel(temperature=-0.1) + + with pytest.raises(ValidationError): + SettingsModel(temperature=2.1) + + def test_max_retries_validation(self): + """Test max_retries must be >= 0.""" + settings = SettingsModel(max_retries=0) + assert settings.max_retries == 0 + + settings = SettingsModel(max_retries=5) + assert settings.max_retries == 5 + + with pytest.raises(ValidationError): + SettingsModel(max_retries=-1) + + def test_reasoning_pattern_validation(self): + """Test reasoning_pattern must be one of the allowed values.""" + from typing import Literal + + for pattern in ['DIRECT', 'REACT', 'COT']: + settings = SettingsModel( + reasoning_pattern=cast(Literal['DIRECT', 'REACT', 'COT'], pattern) + ) + assert settings.reasoning_pattern == pattern + + with pytest.raises(ValidationError): + SettingsModel(reasoning_pattern='INVALID') # type: ignore[arg-type] + + +class TestToolConfigModel: + """Test cases for ToolConfigModel validation.""" + + def test_valid_tool_config(self): + """Test valid tool configuration.""" + tool_config = ToolConfigModel( + name='test_tool', + prefilled_params={'param1': 'value1'}, + name_override='custom_tool', + description_override='Custom description', + ) + assert tool_config.name == 'test_tool' + assert tool_config.prefilled_params == {'param1': 'value1'} + assert tool_config.name_override == 'custom_tool' + assert tool_config.description_override == 'Custom description' + + def test_tool_config_missing_name(self): + """Test that tool config must have name.""" + with pytest.raises(ValidationError): + ToolConfigModel( + prefilled_params={'param1': 'value1'}, + ) + + def test_tool_config_minimal(self): + """Test tool config with only required field.""" + tool_config = ToolConfigModel(name='test_tool') + assert tool_config.name == 'test_tool' + assert tool_config.prefilled_params is None + assert tool_config.name_override is None + assert tool_config.description_override is None + + +class TestAgentConfigModel: + """Test cases for AgentConfigModel validation.""" + + def test_valid_agent_with_job(self): + """Test valid agent configuration with job field.""" + agent = AgentConfigModel( + name='Test Agent', + job='You are a helpful assistant', + ) + assert agent.name == 'Test Agent' + assert agent.job == 'You are a helpful assistant' + + def test_valid_agent_with_prompt(self): + """Test valid agent configuration with prompt field.""" + agent = AgentConfigModel( + name='Test Agent', + prompt='You are a 
helpful assistant', + ) + assert agent.name == 'Test Agent' + assert agent.prompt == 'You are a helpful assistant' + + def test_agent_missing_job_and_prompt(self): + """Test that agent must have either job or prompt.""" + with pytest.raises( + ValueError, match="must have either 'job' or 'prompt' field" + ): + AgentConfigModel(name='Test Agent') + + def test_agent_with_both_job_and_prompt(self): + """Test agent with both job and prompt (job takes precedence).""" + agent = AgentConfigModel( + name='Test Agent', + job='Job description', + prompt='Prompt description', + ) + assert agent.job == 'Job description' + assert agent.prompt == 'Prompt description' + + def test_agent_missing_name(self): + """Test that agent must have name.""" + with pytest.raises(ValidationError): + AgentConfigModel(job='You are a helpful assistant') + + def test_agent_with_model(self): + """Test agent with model configuration.""" + agent = AgentConfigModel( + name='Test Agent', + job='You are a helpful assistant', + model=LLMConfigModel(provider='openai', name='gpt-4'), + ) + assert agent.model is not None + assert agent.model.provider == 'openai' + assert agent.model.name == 'gpt-4' + + def test_agent_with_settings(self): + """Test agent with settings.""" + agent = AgentConfigModel( + name='Test Agent', + job='You are a helpful assistant', + settings=SettingsModel(temperature=0.7, max_retries=3), + ) + assert agent.settings is not None + assert agent.settings.temperature == 0.7 + assert agent.settings.max_retries == 3 + + def test_agent_with_tools_string_list(self): + """Test agent with tools as string list.""" + agent = AgentConfigModel( + name='Test Agent', + job='You are a helpful assistant', + tools=['tool1', 'tool2'], + ) + assert agent.tools == ['tool1', 'tool2'] + + def test_agent_with_tools_config_list(self): + """Test agent with tools as ToolConfigModel list.""" + agent = AgentConfigModel( + name='Test Agent', + job='You are a helpful assistant', + tools=[ + ToolConfigModel(name='tool1'), + ToolConfigModel(name='tool2', prefilled_params={'param': 'value'}), + ], + ) + assert agent.tools is not None + assert len(agent.tools) == 2 + assert isinstance(agent.tools[0], ToolConfigModel) + assert isinstance(agent.tools[1], ToolConfigModel) + assert agent.tools[0].name == 'tool1' + assert agent.tools[1].name == 'tool2' + + def test_agent_with_mixed_tools(self): + """Test agent with mixed tool types (strings and configs).""" + agent = AgentConfigModel( + name='Test Agent', + job='You are a helpful assistant', + tools=[ + 'tool1', + ToolConfigModel(name='tool2'), + ], + ) + assert agent.tools is not None + assert len(agent.tools) == 2 + assert agent.tools[0] == 'tool1' + assert isinstance(agent.tools[1], ToolConfigModel) + assert agent.tools[1].name == 'tool2' + + def test_agent_tools_missing_name(self): + """Test that tool configs in tools list must have name.""" + with pytest.raises(ValueError, match="must have a 'name' field"): + AgentConfigModel( + name='Test Agent', + job='You are a helpful assistant', + tools=[{'prefilled_params': {'param': 'value'}}], # type: ignore[arg-type] + ) + + def test_agent_tools_invalid_type(self): + """Test that tools must be strings or dicts with name.""" + with pytest.raises(ValueError, match='Invalid tool configuration type'): + AgentConfigModel( + name='Test Agent', + job='You are a helpful assistant', + tools=[123], # type: ignore[arg-type] + ) + + def test_agent_with_parser(self): + """Test agent with parser configuration.""" + agent = AgentConfigModel( + name='Test Agent', + 
job='You are a helpful assistant', + parser=ParserModel( + name='test_parser', + fields=[ + ParserFieldModel( + name='query', + type='str', + description='Query string', + ), + ], + ), + ) + assert agent.parser is not None + assert agent.parser.name == 'test_parser' + + def test_agent_with_examples(self): + """Test agent with examples.""" + agent = AgentConfigModel( + name='Test Agent', + job='You are a helpful assistant', + examples=[ + ExampleModel(input='Hello', output='Hi there!'), + ExampleModel(input='How are you?', output={'status': 'good'}), + ], + ) + assert agent.examples is not None + assert len(agent.examples) == 2 + assert agent.examples[0].input == 'Hello' + assert agent.examples[1].output == {'status': 'good'} + + +class TestAgentYamlModel: + """Test cases for AgentYamlModel validation.""" + + def test_valid_yaml_minimal(self): + """Test valid minimal YAML configuration.""" + yaml_data = { + 'agent': { + 'name': 'Test Agent', + 'job': 'You are a helpful assistant', + } + } + config = AgentYamlModel(**yaml_data) # type: ignore[arg-type] + assert config.agent.name == 'Test Agent' + assert config.agent.job == 'You are a helpful assistant' + assert config.metadata is None + + def test_valid_yaml_with_metadata(self): + """Test valid YAML with metadata.""" + yaml_data = { + 'metadata': { + 'name': 'test-agent', + 'version': '1.0.0', + 'description': 'A test agent', + }, + 'agent': { + 'name': 'Test Agent', + 'job': 'You are a helpful assistant', + }, + } + config = AgentYamlModel(**yaml_data) # type: ignore[arg-type] + assert config.metadata is not None + assert config.metadata.name == 'test-agent' + assert config.metadata.version == '1.0.0' + assert config.agent.name == 'Test Agent' + + def test_valid_yaml_full_config(self): + """Test valid YAML with all fields.""" + yaml_data = { + 'metadata': { + 'name': 'test-agent', + 'version': '1.0.0', + 'description': 'A test agent', + 'author': 'Test Author', + 'tags': ['test', 'agent'], + }, + 'agent': { + 'name': 'Test Agent', + 'job': 'You are a helpful assistant', + 'role': 'assistant', + 'act_as': 'helpful AI', + 'model': { + 'provider': 'openai', + 'name': 'gpt-4', + 'temperature': 0.7, + 'max_tokens': 1000, + }, + 'settings': { + 'temperature': 0.8, + 'max_retries': 3, + 'reasoning_pattern': 'REACT', + }, + 'tools': [ + 'tool1', + {'name': 'tool2', 'prefilled_params': {'param': 'value'}}, + ], + 'parser': { + 'name': 'test_parser', + 'fields': [ + { + 'name': 'query', + 'type': 'str', + 'description': 'Query string', + }, + ], + }, + 'examples': [ + {'input': 'Hello', 'output': 'Hi there!'}, + ], + }, + } + config = AgentYamlModel(**yaml_data) # type: ignore[arg-type] + assert config.metadata is not None + assert config.metadata.name == 'test-agent' + assert config.agent.name == 'Test Agent' + assert config.agent.model is not None + assert config.agent.model.provider == 'openai' + assert config.agent.settings is not None + assert config.agent.settings.max_retries == 3 + assert config.agent.tools is not None + assert len(config.agent.tools) == 2 + assert config.agent.parser is not None + assert config.agent.parser.name == 'test_parser' + assert config.agent.examples is not None + assert len(config.agent.examples) == 1 + + def test_yaml_missing_agent(self): + """Test that YAML must have agent section.""" + yaml_data = { + 'metadata': { + 'name': 'test-agent', + }, + } + with pytest.raises(ValidationError): + AgentYamlModel(**yaml_data) # type: ignore[arg-type] + + def test_yaml_from_string(self): + """Test parsing YAML from 
string.""" + yaml_str = """ +metadata: + name: test-agent + version: 1.0.0 + +agent: + name: Test Agent + job: You are a helpful assistant + model: + provider: openai + name: gpt-4 +""" + yaml_data = yaml.safe_load(yaml_str) + config = AgentYamlModel(**yaml_data) # type: ignore[arg-type] + assert config.metadata is not None + assert config.metadata.name == 'test-agent' + assert config.agent.name == 'Test Agent' + assert config.agent.model is not None + assert config.agent.model.provider == 'openai' + + def test_yaml_complex_parser(self): + """Test YAML with complex nested parser.""" + yaml_str = """ +agent: + name: Test Agent + job: You are a helpful assistant + parser: + name: complex_parser + fields: + - name: user + type: object + description: User object + fields: + - name: name + type: str + description: User name + - name: addresses + type: array + description: User addresses + items: + name: address + type: object + description: Address object + fields: + - name: street + type: str + description: Street name + - name: city + type: str + description: City name +""" + yaml_data = yaml.safe_load(yaml_str) + config = AgentYamlModel(**yaml_data) # type: ignore[arg-type] + assert config.agent.parser is not None + assert config.agent.parser.fields[0].type == 'object' + assert config.agent.parser.fields[0].fields is not None + assert config.agent.parser.fields[0].fields[1].type == 'array' + assert config.agent.parser.fields[0].fields[1].items is not None + assert config.agent.parser.fields[0].fields[1].items.type == 'object' + + def test_yaml_literal_parser_field(self): + """Test YAML with literal parser field.""" + yaml_str = """ +agent: + name: Test Agent + job: You are a helpful assistant + parser: + name: status_parser + fields: + - name: status + type: literal + description: Status value + values: + - value: active + description: Active status + - value: inactive + description: Inactive status +""" + yaml_data = yaml.safe_load(yaml_str) + config = AgentYamlModel(**yaml_data) # type: ignore[arg-type] + assert config.agent.parser is not None + assert config.agent.parser.fields[0].type == 'literal' + assert config.agent.parser.fields[0].values is not None + assert len(config.agent.parser.fields[0].values) == 2 + assert config.agent.parser.fields[0].values[0].value == 'active' + + def test_yaml_all_providers(self): + """Test YAML with different model providers.""" + providers = [ + ('openai', {'name': 'gpt-4'}), + ('anthropic', {'name': 'claude-3-opus-20240229'}), + ('claude', {'name': 'claude-3-opus-20240229'}), + ('gemini', {'name': 'gemini-pro'}), + ('google', {'name': 'gemini-pro'}), + ('ollama', {'name': 'llama2', 'base_url': 'http://localhost:11434'}), + ( + 'vertexai', + { + 'name': 'gemini-pro', + 'project': 'my-project', + 'base_url': 'https://us-central1-aiplatform.googleapis.com', + }, + ), + ('rootflo', {'model_id': 'model-123'}), + ( + 'openai_vllm', + { + 'name': 'gpt-4', + 'base_url': 'http://localhost:8000/v1', + 'api_key': 'sk-test', + }, + ), + ] + + for provider, provider_config in providers: + yaml_data = { + 'agent': { + 'name': 'Test Agent', + 'job': 'You are a helpful assistant', + 'model': {'provider': provider, **provider_config}, + } + } + config = AgentYamlModel(**yaml_data) # type: ignore[arg-type] + assert config.agent.model is not None + assert config.agent.model.provider == provider diff --git a/flo_ai/tests/unit-tests/test_arium_yaml.py b/flo_ai/tests/unit-tests/test_arium_yaml.py index e3492f4e..f62518e2 100644 --- 
a/flo_ai/tests/unit-tests/test_arium_yaml.py +++ b/flo_ai/tests/unit-tests/test_arium_yaml.py @@ -1,5 +1,11 @@ """ -Tests for YAML-based Arium workflow construction. +Integration tests for YAML-based Arium workflow construction. + +This module tests the actual building and integration of arium workflows from YAML, +including function nodes, routers, pre-built agents, and complex workflows. + +Note: YAML structure validation is tested in test_arium_yaml_validation.py. +This file focuses on integration and runtime behavior. """ import pytest @@ -28,15 +34,6 @@ def test_from_yaml_validation_both_params(self): ): AriumBuilder.from_yaml(yaml_str='test', yaml_file='test.yaml') - def test_from_yaml_validation_missing_arium_section(self): - """Test that from_yaml fails when YAML doesn't contain arium section.""" - yaml_config = """ - metadata: - name: test - """ - with pytest.raises(ValueError, match='YAML must contain an "arium" section'): - AriumBuilder.from_yaml(yaml_str=yaml_config) - def test_from_yaml_simple_configuration(self): """Test basic YAML configuration parsing.""" yaml_config = """ @@ -406,76 +403,6 @@ def test_from_yaml_missing_router_error(self): ): AriumBuilder.from_yaml(yaml_str=yaml_config, routers={}) - def test_from_yaml_missing_start_node_error(self): - """Test error when workflow doesn't specify start node.""" - yaml_config = """ - arium: - agents: - - name: test_agent - yaml_config: | - agent: - name: test_agent - job: "Test agent" - model: - provider: openai - name: gpt-4o-mini - - workflow: - edges: [] - end: [test_agent] - """ - - with patch('flo_ai.arium.builder.AgentBuilder'): - with pytest.raises(ValueError, match='Workflow must specify a start node'): - AriumBuilder.from_yaml(yaml_str=yaml_config) - - def test_from_yaml_missing_end_nodes_error(self): - """Test error when workflow doesn't specify end nodes.""" - yaml_config = """ - arium: - agents: - - name: test_agent - yaml_config: | - agent: - name: test_agent - job: "Test agent" - model: - provider: openai - name: gpt-4o-mini - - workflow: - start: test_agent - edges: - - from: test_agent - to: [end] - """ - - with patch('flo_ai.arium.builder.AgentBuilder'): - with pytest.raises(ValueError, match='Workflow must specify end nodes'): - AriumBuilder.from_yaml(yaml_str=yaml_config) - - def test_from_yaml_invalid_agent_config_error(self): - """Test error when agent doesn't have yaml_config or yaml_file.""" - yaml_config = """ - arium: - agents: - - name: invalid_agent - # Missing yaml_config or yaml_file - - workflow: - start: invalid_agent - edges: - - from: invalid_agent - to: [end] - end: [invalid_agent] - """ - - with pytest.raises( - ValueError, - match='Agent invalid_agent not found in provided agents dictionary', - ): - AriumBuilder.from_yaml(yaml_str=yaml_config) - def test_from_yaml_external_file_reference(self): """Test YAML configuration with external agent file reference.""" yaml_config = """ @@ -933,51 +860,9 @@ def test_from_yaml_mixed_configuration_methods(self): assert any(a.name == 'yaml_agent' for a in builder._agents) assert any(a.name == 'file_agent' for a in builder._agents) - def test_from_yaml_direct_config_validation_errors(self): - """Test validation errors for direct agent configuration.""" - - # Test missing required field - yaml_config_missing_job = """ - arium: - agents: - - name: test_agent - role: Test Agent - # missing job field - model: - provider: openai - name: gpt-4o-mini - workflow: - start: test_agent - edges: [] - end: [test_agent] - """ - - with pytest.raises(ValueError, 
match='Agent test_agent must have either'): - AriumBuilder.from_yaml(yaml_str=yaml_config_missing_job) - - # Test invalid reasoning pattern - yaml_config_invalid_pattern = """ - arium: - agents: - - name: test_agent - job: "Test agent" - model: - provider: openai - name: gpt-4o-mini - settings: - reasoning_pattern: INVALID_PATTERN - workflow: - start: test_agent - edges: [] - end: [test_agent] - """ - - with patch('flo_ai.llm.OpenAI'): - with pytest.raises(ValueError, match='Invalid reasoning pattern'): - AriumBuilder.from_yaml(yaml_str=yaml_config_invalid_pattern) - - # Test missing model when no base_llm provided - yaml_config_missing_model = """ + def test_from_yaml_direct_config_missing_model_error(self): + """Test error when agent is missing model and no base_llm is provided.""" + yaml_config = """ arium: agents: - name: test_agent @@ -990,7 +875,7 @@ def test_from_yaml_direct_config_validation_errors(self): """ with pytest.raises(ValueError, match='Model must be specified'): - AriumBuilder.from_yaml(yaml_str=yaml_config_missing_model) + AriumBuilder.from_yaml(yaml_str=yaml_config) def test_from_yaml_direct_config_with_base_llm(self): """Test direct agent configuration with base LLM override.""" @@ -1175,27 +1060,6 @@ def test_from_yaml_mixed_prebuilt_and_configured_agents(self): # Check YAML agent was added assert mock_yaml_agent in builder._agents - def test_from_yaml_prebuilt_agents_parameter_validation(self): - """Test parameter validation for pre-built agents.""" - yaml_config = """ - arium: - agents: - - name: test_agent - # Has additional fields, so not a pure reference - role: "Some Role" - - workflow: - start: test_agent - edges: [] - end: [test_agent] - """ - - # This should not be treated as a pre-built agent reference - # because it has additional fields beyond just 'name' - with patch('flo_ai.llm.OpenAI'): - with pytest.raises(ValueError, match='Agent test_agent must have either'): - AriumBuilder.from_yaml(yaml_str=yaml_config) - def test_from_yaml_prebuilt_agents_with_function_nodes_and_routers(self): """Test pre-built agents working together with function nodes and routers.""" yaml_config = """ diff --git a/flo_ai/tests/unit-tests/test_arium_yaml_validation.py b/flo_ai/tests/unit-tests/test_arium_yaml_validation.py new file mode 100644 index 00000000..dbd7012e --- /dev/null +++ b/flo_ai/tests/unit-tests/test_arium_yaml_validation.py @@ -0,0 +1,837 @@ +""" +Tests for arium YAML validation using Pydantic models. + +This module tests all validation rules defined in flo_ai.models.arium, +ensuring that YAML configurations are properly validated before being used. 
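+
+Builder-level integration and runtime behavior of these workflows is covered
+separately in test_arium_yaml.py; this module is limited to structural validation.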
+""" + +import pytest +import yaml +from pydantic import ValidationError + +from flo_ai.models.arium import ( + AriumYamlModel, + AriumConfigModel, + AriumAgentConfigModel, + FunctionNodeConfigModel, + RouterConfigModel, + RouterSettingsModel, + EdgeConfigModel, + WorkflowConfigModel, + AriumNodeConfigModel, + ForEachNodeConfigModel, +) +from flo_ai.models.agent import LLMConfigModel + + +class TestAriumAgentConfigModel: + """Test cases for AriumAgentConfigModel validation.""" + + def test_valid_agent_with_direct_config(self): + """Test valid agent with direct configuration.""" + agent = AriumAgentConfigModel( + name='test_agent', + role='Test Role', + job='You are a test agent', + model=LLMConfigModel(provider='openai', name='gpt-4o-mini'), + ) + assert agent.name == 'test_agent' + assert agent.role == 'Test Role' + assert agent.job == 'You are a test agent' + assert agent.model is not None + + def test_valid_agent_reference_only(self): + """Test valid agent with only name (reference to pre-built agent).""" + agent = AriumAgentConfigModel(name='prebuilt_agent') + assert agent.name == 'prebuilt_agent' + assert agent.job is None + assert agent.yaml_config is None + assert agent.yaml_file is None + + def test_valid_agent_with_yaml_config(self): + """Test valid agent with inline yaml_config.""" + agent = AriumAgentConfigModel( + name='yaml_agent', + yaml_config='agent:\n name: yaml_agent\n job: Test job', + ) + assert agent.name == 'yaml_agent' + assert agent.yaml_config is not None + assert agent.job is None + + def test_valid_agent_with_yaml_file(self): + """Test valid agent with yaml_file reference.""" + agent = AriumAgentConfigModel( + name='file_agent', + yaml_file='path/to/agent.yaml', + ) + assert agent.name == 'file_agent' + assert agent.yaml_file == 'path/to/agent.yaml' + assert agent.job is None + + def test_agent_multiple_config_methods(self): + """Test that agent cannot have multiple configuration methods.""" + with pytest.raises(ValueError, match='multiple configuration methods'): + AriumAgentConfigModel( + name='invalid_agent', + job='Direct job', + yaml_config='agent:\n name: test', + ) + + def test_agent_job_and_prompt_alias(self): + """Test that job and prompt are aliases.""" + agent1 = AriumAgentConfigModel( + name='agent1', + job='Test job', + model=LLMConfigModel(provider='openai', name='gpt-4o-mini'), + ) + agent2 = AriumAgentConfigModel( + name='agent2', + prompt='Test job', + model=LLMConfigModel(provider='openai', name='gpt-4o-mini'), + ) + assert agent1.job == 'Test job' + assert agent2.prompt == 'Test job' + + +class TestFunctionNodeConfigModel: + """Test cases for FunctionNodeConfigModel validation.""" + + def test_valid_function_node(self): + """Test valid function node configuration.""" + node = FunctionNodeConfigModel( + name='preprocessor', + function_name='preprocess_data', + description='Preprocesses input data', + input_filter=['input1', 'input2'], + prefilled_params={'param1': 'value1'}, + ) + assert node.name == 'preprocessor' + assert node.function_name == 'preprocess_data' + assert node.description == 'Preprocesses input data' + assert node.input_filter == ['input1', 'input2'] + assert node.prefilled_params == {'param1': 'value1'} + + def test_minimal_function_node(self): + """Test function node with only required fields.""" + node = FunctionNodeConfigModel( + name='simple_node', + function_name='simple_function', + ) + assert node.name == 'simple_node' + assert node.function_name == 'simple_function' + assert node.description is None + assert 
node.input_filter is None + assert node.prefilled_params is None + + +class TestRouterConfigModel: + """Test cases for RouterConfigModel validation.""" + + def test_valid_smart_router(self): + """Test valid smart router configuration.""" + router = RouterConfigModel( + name='content_router', + type='smart', + routing_options={ + 'technical_writer': 'Handle technical documentation', + 'creative_writer': 'Handle creative writing', + }, + model=LLMConfigModel(provider='openai', name='gpt-4o-mini'), + ) + assert router.name == 'content_router' + assert router.type == 'smart' + assert router.routing_options is not None + assert len(router.routing_options) == 2 + + def test_valid_reflection_router(self): + """Test valid reflection router configuration.""" + router = RouterConfigModel( + name='reflection_router', + type='reflection', + flow_pattern=['agent1', 'critic', 'agent1', 'final'], + settings=RouterSettingsModel(allow_early_exit=False), + ) + assert router.type == 'reflection' + assert router.flow_pattern == ['agent1', 'critic', 'agent1', 'final'] + + def test_valid_plan_execute_router(self): + """Test valid plan-execute router configuration.""" + router = RouterConfigModel( + name='plan_execute_router', + type='plan_execute', + agents={ + 'planner': 'Creates plans', + 'developer': 'Implements code', + 'tester': 'Tests code', + }, + settings=RouterSettingsModel( + planner_agent='planner', + executor_agent='developer', + reviewer_agent='tester', + ), + ) + assert router.type == 'plan_execute' + assert router.agents is not None + assert len(router.agents) == 3 + + def test_smart_router_missing_routing_options(self): + """Test that smart router must have routing_options.""" + with pytest.raises(ValueError, match="must specify 'routing_options'"): + RouterConfigModel( + name='invalid_router', + type='smart', + ) + + def test_reflection_router_missing_flow_pattern(self): + """Test that reflection router must have flow_pattern.""" + with pytest.raises(ValueError, match="must specify 'flow_pattern'"): + RouterConfigModel( + name='invalid_router', + type='reflection', + ) + + def test_plan_execute_router_missing_agents(self): + """Test that plan_execute router must have agents.""" + with pytest.raises(ValueError, match="must specify 'agents'"): + RouterConfigModel( + name='invalid_router', + type='plan_execute', + ) + + +class TestEdgeConfigModel: + """Test cases for EdgeConfigModel validation.""" + + def test_valid_edge(self): + """Test valid edge configuration.""" + edge = EdgeConfigModel.model_validate( + {'from': 'agent1', 'to': ['agent2', 'agent3'], 'router': 'content_router'} + ) + assert edge.from_ == 'agent1' + assert edge.to == ['agent2', 'agent3'] + assert edge.router == 'content_router' + + def test_edge_with_from_alias(self): + """Test edge using 'from' and 'to' aliases.""" + edge_dict = {'from': 'agent1', 'to': ['agent2']} + # When using dict, only 'from' and 'to' aliases are allowed (populate_by_name=False) + edge = EdgeConfigModel.model_validate(edge_dict) + assert edge.from_ == 'agent1' + assert edge.to == ['agent2'] + + def test_edge_without_router(self): + """Test edge without router (direct connection).""" + edge = EdgeConfigModel.model_validate({'from': 'agent1', 'to': ['agent2']}) + assert edge.router is None + + +class TestWorkflowConfigModel: + """Test cases for WorkflowConfigModel validation.""" + + def test_valid_workflow(self): + """Test valid workflow configuration.""" + workflow = WorkflowConfigModel( + start='agent1', + edges=[ + EdgeConfigModel.model_validate({'from': 
'agent1', 'to': ['agent2']}), + EdgeConfigModel.model_validate({'from': 'agent2', 'to': ['agent3']}), + ], + end=['agent3'], + ) + assert workflow.start == 'agent1' + assert len(workflow.edges) == 2 + assert workflow.end == ['agent3'] + + def test_workflow_with_from_alias_in_edges(self): + """Test workflow with edges using 'from' and 'to' aliases.""" + workflow_dict = { + 'start': 'agent1', + 'edges': [ + {'from': 'agent1', 'to': ['agent2']}, + {'from': 'agent2', 'to': ['agent3']}, + ], + 'end': ['agent3'], + } + workflow = WorkflowConfigModel.model_validate(workflow_dict) + assert workflow.start == 'agent1' + assert len(workflow.edges) == 2 + assert workflow.edges[0].from_ == 'agent1' + + +class TestAriumNodeConfigModel: + """Test cases for AriumNodeConfigModel validation.""" + + def test_valid_arium_node_with_yaml_file(self): + """Test valid arium node with yaml_file reference.""" + node = AriumNodeConfigModel( + name='nested_workflow', + yaml_file='path/to/workflow.yaml', + inherit_variables=True, + ) + assert node.name == 'nested_workflow' + assert node.yaml_file == 'path/to/workflow.yaml' + assert node.inherit_variables is True + + def test_valid_arium_node_with_inline_config(self): + """Test valid arium node with inline configuration.""" + node = AriumNodeConfigModel( + name='inline_workflow', + agents=[ + AriumAgentConfigModel( + name='agent1', + job='Test job', + model=LLMConfigModel(provider='openai', name='gpt-4o-mini'), + ) + ], + workflow=WorkflowConfigModel( + start='agent1', + edges=[], + end=['agent1'], + ), + ) + assert node.name == 'inline_workflow' + assert node.agents is not None + assert len(node.agents) == 1 + assert node.workflow is not None + + def test_arium_node_missing_config(self): + """Test that arium node must have either yaml_file or inline config.""" + with pytest.raises( + ValueError, match="must have either 'yaml_file' or inline configuration" + ): + AriumNodeConfigModel( + name='invalid_node', + ) + + def test_arium_node_both_configs(self): + """Test that arium node cannot have both yaml_file and inline config.""" + with pytest.raises( + ValueError, match="cannot have both 'yaml_file' and inline configuration" + ): + AriumNodeConfigModel( + name='invalid_node', + yaml_file='path/to/file.yaml', + agents=[ + AriumAgentConfigModel( + name='agent1', + job='Test', + model=LLMConfigModel(provider='openai', name='gpt-4o-mini'), + ) + ], + workflow=WorkflowConfigModel( + start='agent1', + edges=[], + end=['agent1'], + ), + ) + + def test_arium_node_inline_config_missing_workflow(self): + """Test that inline arium node must have workflow.""" + with pytest.raises(ValueError, match="must specify 'workflow'"): + AriumNodeConfigModel( + name='invalid_node', + agents=[ + AriumAgentConfigModel( + name='agent1', + job='Test', + model=LLMConfigModel(provider='openai', name='gpt-4o-mini'), + ) + ], + ) + + +class TestForEachNodeConfigModel: + """Test cases for ForEachNodeConfigModel validation.""" + + def test_valid_foreach_node(self): + """Test valid foreach node configuration.""" + node = ForEachNodeConfigModel( + name='batch_processor', + execute_node='processor_agent', + ) + assert node.name == 'batch_processor' + assert node.execute_node == 'processor_agent' + + +class TestAriumConfigModel: + """Test cases for AriumConfigModel validation.""" + + def test_valid_arium_config_with_agents(self): + """Test valid arium configuration with agents.""" + config = AriumConfigModel( + agents=[ + AriumAgentConfigModel( + name='agent1', + job='Test job', + 
model=LLMConfigModel(provider='openai', name='gpt-4o-mini'),
+                )
+            ],
+            workflow=WorkflowConfigModel(
+                start='agent1',
+                edges=[],
+                end=['agent1'],
+            ),
+        )
+        assert config.agents is not None
+        assert len(config.agents) == 1
+        assert config.workflow is not None
+
+    def test_valid_arium_config_with_function_nodes(self):
+        """Test valid arium configuration with function nodes."""
+        config = AriumConfigModel(
+            function_nodes=[
+                FunctionNodeConfigModel(
+                    name='func1',
+                    function_name='function1',
+                )
+            ],
+            workflow=WorkflowConfigModel(
+                start='func1',
+                edges=[],
+                end=['func1'],
+            ),
+        )
+        assert config.function_nodes is not None
+        assert len(config.function_nodes) == 1
+
+    def test_valid_arium_config_with_iterators(self):
+        """Test valid arium configuration with iterators."""
+        config = AriumConfigModel(
+            agents=[
+                AriumAgentConfigModel(
+                    name='processor',
+                    job='Process',
+                    model=LLMConfigModel(provider='openai', name='gpt-4o-mini'),
+                )
+            ],
+            iterators=[
+                ForEachNodeConfigModel(
+                    name='batch_processor',
+                    execute_node='processor',
+                )
+            ],
+            workflow=WorkflowConfigModel(
+                start='batch_processor',
+                edges=[],
+                end=['batch_processor'],
+            ),
+        )
+        assert config.iterators is not None
+        assert len(config.iterators) == 1
+
+    def test_arium_config_foreach_nodes_alias(self):
+        """Test that foreach_nodes and iterators are aliases."""
+        config_dict = {
+            'agents': [
+                {
+                    'name': 'processor',
+                    'job': 'Process',
+                    'model': {'provider': 'openai', 'name': 'gpt-4o-mini'},
+                }
+            ],
+            'foreach_nodes': [{'name': 'batch', 'execute_node': 'processor'}],
+            'workflow': {
+                'start': 'batch',
+                'edges': [],
+                'end': ['batch'],
+            },
+        }
+        config = AriumConfigModel.model_validate(config_dict)
+        # foreach_nodes should be merged into iterators
+        assert config.iterators is not None
+        assert len(config.iterators) == 1
+
+    def test_arium_config_no_nodes(self):
+        """Test that arium config must have at least one node type."""
+        with pytest.raises(ValueError, match='must have at least one of'):
+            AriumConfigModel(
+                workflow=WorkflowConfigModel(
+                    start='agent1',
+                    edges=[],
+                    end=['agent1'],
+                ),
+            )
+
+
+class TestAriumYamlModel:
+    """Test cases for AriumYamlModel validation."""
+
+    def test_valid_arium_yaml_minimal(self):
+        """Test valid minimal arium YAML."""
+        yaml_data = {
+            'arium': {
+                'agents': [
+                    {
+                        'name': 'agent1',
+                        'job': 'Test job',
+                        'model': {'provider': 'openai', 'name': 'gpt-4o-mini'},
+                    }
+                ],
+                'workflow': {
+                    'start': 'agent1',
+                    'edges': [],
+                    'end': ['agent1'],
+                },
+            }
+        }
+        config = AriumYamlModel.model_validate(yaml_data)
+        assert config.arium is not None
+        assert config.arium.agents is not None
+        assert len(config.arium.agents) == 1
+
+    def test_valid_arium_yaml_with_metadata(self):
+        """Test valid arium YAML with metadata."""
+        yaml_data = {
+            'metadata': {
+                'name': 'test-workflow',
+                'version': '1.0.0',
+                'description': 'Test workflow',
+            },
+            'arium': {
+                'agents': [
+                    {
+                        'name': 'agent1',
+                        'job': 'Test job',
+                        'model': {'provider': 'openai', 'name': 'gpt-4o-mini'},
+                    }
+                ],
+                'workflow': {
+                    'start': 'agent1',
+                    'edges': [],
+                    'end': ['agent1'],
+                },
+            },
+        }
+        config = AriumYamlModel.model_validate(yaml_data)
+        assert config.metadata is not None
+        assert config.metadata.name == 'test-workflow'
+        assert config.arium is not None
+
+    def test_valid_arium_yaml_complex(self):
+        """Test valid complex arium YAML with all components."""
+        yaml_data = {
+            'metadata': {
+                'name': 'complex-workflow',
+                'version': '1.0.0',
+            },
+            'arium': {
+                'agents': [
+                    {
+                        'name': 'agent1',
+                        'job': 'Agent 1',
+                        'model': {'provider': 'openai', 'name': 'gpt-4o-mini'},
+                    },
+                    {'name': 'agent2'},  # Reference to pre-built agent
+                ],
+                'function_nodes': [
+                    {
+                        'name': 'func1',
+                        'function_name': 'function1',
+                    }
+                ],
+                'routers': [
+                    {
+                        'name': 'router1',
+                        'type': 'smart',
+                        'routing_options': {
+                            'agent1': 'Handle type 1',
+                            'agent2': 'Handle type 2',
+                        },
+                    }
+                ],
+                'workflow': {
+                    'start': 'agent1',
+                    'edges': [
+                        {
+                            'from': 'agent1',
+                            'to': ['agent2', 'func1'],
+                            'router': 'router1',
+                        },
+                        {'from': 'agent2', 'to': ['end']},
+                        {'from': 'func1', 'to': ['end']},
+                    ],
+                    'end': ['agent2', 'func1'],
+                },
+            },
+        }
+        config = AriumYamlModel.model_validate(yaml_data)
+        assert config.arium.agents is not None
+        assert len(config.arium.agents) == 2
+        assert config.arium.function_nodes is not None
+        assert config.arium.routers is not None
+        assert len(config.arium.routers) == 1
+        assert len(config.arium.workflow.edges) == 3
+
+    def test_arium_yaml_missing_arium_section(self):
+        """Test that YAML must have arium section."""
+        with pytest.raises(ValidationError):
+            AriumYamlModel.model_validate({})
+
+    def test_arium_yaml_from_string(self):
+        """Test parsing arium YAML from string."""
+        yaml_str = """
+metadata:
+  name: test-workflow
+  version: 1.0.0
+
+arium:
+  agents:
+    - name: agent1
+      job: Test job
+      model:
+        provider: openai
+        name: gpt-4o-mini
+  workflow:
+    start: agent1
+    edges:
+      - from: agent1
+        to: [end]
+    end: [agent1]
+"""
+        yaml_data = yaml.safe_load(yaml_str)
+        config = AriumYamlModel.model_validate(yaml_data)
+        assert config.metadata is not None
+        assert config.metadata.name == 'test-workflow'
+        assert config.arium.agents is not None
+        assert len(config.arium.agents) == 1
+
+
+class TestAriumBuilderValidation:
+    """Test cases for AriumBuilder YAML validation integration."""
+
+    def test_builder_validation_method(self):
+        """Test that AriumBuilder._validate_yaml_config works correctly."""
+        from flo_ai.arium.builder import AriumBuilder
+
+        valid_config = {
+            'arium': {
+                'agents': [
+                    {
+                        'name': 'agent1',
+                        'job': 'Test job',
+                        'model': {'provider': 'openai', 'name': 'gpt-4o-mini'},
+                    }
+                ],
+                'workflow': {
+                    'start': 'agent1',
+                    'edges': [],
+                    'end': ['agent1'],
+                },
+            }
+        }
+        # Should not raise validation error
+        validated = AriumBuilder._validate_yaml_config(valid_config)
+        assert validated is not None
+        assert validated.arium is not None
+        assert validated.arium.agents is not None
+        assert len(validated.arium.agents) == 1
+
+    def test_builder_validation_error_formatting(self):
+        """Test that validation errors are properly formatted."""
+        from flo_ai.arium.builder import AriumBuilder
+
+        # Invalid router missing routing_options
+        invalid_config = {
+            'arium': {
+                'routers': [
+                    {
+                        'name': 'invalid_router',
+                        'type': 'smart',
+                        # Missing routing_options
+                    }
+                ],
+                'workflow': {
+                    'start': 'agent1',
+                    'edges': [],
+                    'end': ['agent1'],
+                },
+            }
+        }
+        with pytest.raises(ValueError, match='YAML validation failed'):
+            AriumBuilder._validate_yaml_config(invalid_config)
+
+    def test_builder_validates_workflow_structure(self):
+        """Test that builder validates workflow structure."""
+        from flo_ai.arium.builder import AriumBuilder
+
+        # Missing workflow section
+        invalid_config = {
+            'arium': {
+                'agents': [
+                    {
+                        'name': 'agent1',
+                        'job': 'Test',
+                        'model': {'provider': 'openai', 'name': 'gpt-4o-mini'},
+                    }
+                ],
+                # Missing 
workflow + } + } + with pytest.raises(ValueError, match='YAML validation failed'): + AriumBuilder._validate_yaml_config(invalid_config) + + def test_builder_validates_edge_structure(self): + """Test that builder validates edge structure.""" + from flo_ai.arium.builder import AriumBuilder + + # Edge missing 'to' field + invalid_config = { + 'arium': { + 'agents': [ + { + 'name': 'agent1', + 'job': 'Test', + 'model': {'provider': 'openai', 'name': 'gpt-4o-mini'}, + } + ], + 'workflow': { + 'start': 'agent1', + 'edges': [ + {'from': 'agent1'} # Missing 'to' field + ], + 'end': ['agent1'], + }, + } + } + with pytest.raises(ValueError, match='YAML validation failed'): + AriumBuilder._validate_yaml_config(invalid_config) + + def test_builder_validates_missing_start_node(self): + """Test that builder validates missing start node.""" + from flo_ai.arium.builder import AriumBuilder + + invalid_config = { + 'arium': { + 'agents': [ + { + 'name': 'agent1', + 'job': 'Test', + 'model': {'provider': 'openai', 'name': 'gpt-4o-mini'}, + } + ], + 'workflow': { + 'edges': [], + 'end': ['agent1'], + # Missing start + }, + } + } + with pytest.raises(ValueError, match='YAML validation failed'): + AriumBuilder._validate_yaml_config(invalid_config) + + def test_builder_validates_missing_end_nodes(self): + """Test that builder validates missing end nodes.""" + from flo_ai.arium.builder import AriumBuilder + + invalid_config = { + 'arium': { + 'agents': [ + { + 'name': 'agent1', + 'job': 'Test', + 'model': {'provider': 'openai', 'name': 'gpt-4o-mini'}, + } + ], + 'workflow': { + 'start': 'agent1', + 'edges': [], + # Missing end + }, + } + } + with pytest.raises(ValueError, match='YAML validation failed'): + AriumBuilder._validate_yaml_config(invalid_config) + + def test_builder_validates_agent_missing_job_when_direct_config(self): + """Test that builder validates agent missing job when using direct config.""" + from flo_ai.arium.builder import AriumBuilder + + invalid_config = { + 'arium': { + 'agents': [ + { + 'name': 'agent1', + 'role': 'Test Agent', + 'model': {'provider': 'openai', 'name': 'gpt-4o-mini'}, + # Missing job/prompt + } + ], + 'workflow': { + 'start': 'agent1', + 'edges': [], + 'end': ['agent1'], + }, + } + } + with pytest.raises(ValueError, match='YAML validation failed'): + AriumBuilder._validate_yaml_config(invalid_config) + + def test_builder_validates_agent_multiple_config_methods(self): + """Test that builder validates agent cannot have multiple config methods.""" + from flo_ai.arium.builder import AriumBuilder + + invalid_config = { + 'arium': { + 'agents': [ + { + 'name': 'agent1', + 'job': 'Test job', + 'yaml_config': 'agent:\n name: test', + # Multiple config methods + } + ], + 'workflow': { + 'start': 'agent1', + 'edges': [], + 'end': ['agent1'], + }, + } + } + with pytest.raises(ValueError, match='YAML validation failed'): + AriumBuilder._validate_yaml_config(invalid_config) + + def test_builder_validates_function_node_missing_function_name(self): + """Test that builder validates function node has function_name.""" + from flo_ai.arium.builder import AriumBuilder + + invalid_config = { + 'arium': { + 'function_nodes': [ + { + 'name': 'func1', + # Missing function_name + } + ], + 'workflow': { + 'start': 'func1', + 'edges': [], + 'end': ['func1'], + }, + } + } + with pytest.raises(ValueError, match='YAML validation failed'): + AriumBuilder._validate_yaml_config(invalid_config) + + def test_builder_validates_router_type_specific_requirements(self): + """Test that builder validates router 
type-specific requirements.""" + from flo_ai.arium.builder import AriumBuilder + + # Test task_classifier missing task_categories + invalid_config = { + 'arium': { + 'routers': [ + { + 'name': 'router1', + 'type': 'task_classifier', + # Missing task_categories + } + ], + 'workflow': { + 'start': 'agent1', + 'edges': [], + 'end': ['agent1'], + }, + } + } + with pytest.raises(ValueError, match='YAML validation failed'): + AriumBuilder._validate_yaml_config(invalid_config) diff --git a/flo_ai/uv.lock b/flo_ai/uv.lock index 7b8155ef..f3dc7384 100644 --- a/flo_ai/uv.lock +++ b/flo_ai/uv.lock @@ -897,7 +897,7 @@ wheels = [ [[package]] name = "flo-ai" -version = "1.1.0rc3" +version = "1.1.0rc5" source = { editable = "." } dependencies = [ { name = "aiohttp" }, diff --git a/wavefront/server/modules/voice_agents_module/voice_agents_module/models/telephony_schemas.py b/wavefront/server/modules/voice_agents_module/voice_agents_module/models/telephony_schemas.py index 95a45cff..46daf998 100644 --- a/wavefront/server/modules/voice_agents_module/voice_agents_module/models/telephony_schemas.py +++ b/wavefront/server/modules/voice_agents_module/voice_agents_module/models/telephony_schemas.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field from typing import Optional, Union, Any, Dict, List, Literal from enum import Enum from datetime import datetime @@ -91,12 +91,10 @@ class CreateTelephonyConfigPayload(BaseModel): None, description='SIP configuration (required for SIP connection type)' ) - @model_validator(mode='after') - def validate_connection_type_requirements(self): + def model_post_init(self, __context): """Validate connection type specific requirements""" if self.connection_type == ConnectionType.SIP and not self.sip_config: raise ValueError('sip_config is required for SIP connection type') - return self class UpdateTelephonyConfigPayload(BaseModel): From 453eb2070852b78acf1be5246d675e622f1a5aff Mon Sep 17 00:00:00 2001 From: vishnu r kumar Date: Fri, 12 Dec 2025 22:59:52 +0530 Subject: [PATCH 3/3] fix: review comments --- documentation/quickstart.mdx | 23 +++++-- flo_ai/docs/arium_yaml_guide.md | 2 +- .../examples/example_graph_visualization.py | 4 +- flo_ai/flo_ai/agent/agent.py | 61 +++++++++++++------ flo_ai/flo_ai/agent/builder.py | 4 +- flo_ai/flo_ai/arium/builder.py | 11 ++-- flo_ai/flo_ai/models/agent.py | 9 +++ 7 files changed, 78 insertions(+), 36 deletions(-) diff --git a/documentation/quickstart.mdx b/documentation/quickstart.mdx index b23cfee2..5d432e1e 100644 --- a/documentation/quickstart.mdx +++ b/documentation/quickstart.mdx @@ -11,15 +11,26 @@ Build your first AI agent and understand the core concepts of Flo AI. - Install Flo AI using pip or poetry: ```bash pip install flo-ai # or using - poetry poetry add flo-ai # or using uv uv add flo-ai ``` + Install Flo AI using pip or poetry: + +```bash +pip install flo-ai +poetry add flo-ai +uv add flo-ai +``` + Flo AI requires Python 3.10 or higher. 
+ - Configure your LLM provider API keys: ```bash # For OpenAI export - OPENAI_API_KEY="your-openai-key" # For Anthropic export - ANTHROPIC_API_KEY="your-anthropic-key" # For Google Gemini export - GOOGLE_API_KEY="your-google-key" ``` + Configure your LLM provider API keys: + +```bash +export OPENAI_API_KEY="your-openai-key" +export ANTHROPIC_API_KEY="your-anthropic-key" +export GOOGLE_API_KEY="your-google-key" +``` + diff --git a/flo_ai/docs/arium_yaml_guide.md b/flo_ai/docs/arium_yaml_guide.md index b138da3a..0db39315 100644 --- a/flo_ai/docs/arium_yaml_guide.md +++ b/flo_ai/docs/arium_yaml_guide.md @@ -264,7 +264,7 @@ tools: **Inline YAML Configuration:** -- ⚠️ Requires nested YAML string +- ⚠️ Requires YAML to be nested as a string - ⚠️ Limited IDE support for nested content - ✅ Maintains existing workflow compatibility - ✅ Good for complex parser configurations diff --git a/flo_ai/examples/example_graph_visualization.py b/flo_ai/examples/example_graph_visualization.py index 6bae270b..e70791a2 100644 --- a/flo_ai/examples/example_graph_visualization.py +++ b/flo_ai/examples/example_graph_visualization.py @@ -45,7 +45,9 @@ class MockLLM(BaseLLM): def __init__(self, model): super().__init__(model=model) - async def generate(self, messages, functions=None, output_schema=None): + async def generate( + self, messages, functions=None, output_schema=None, **kwargs: Any + ): return {'content': 'Mock response'} async def stream( diff --git a/flo_ai/flo_ai/agent/agent.py b/flo_ai/flo_ai/agent/agent.py index e8a27a58..89977035 100644 --- a/flo_ai/flo_ai/agent/agent.py +++ b/flo_ai/flo_ai/agent/agent.py @@ -132,16 +132,11 @@ async def _run_conversational( """Run as a conversational agent when no tools are provided""" variables = variables or {} + # Create and add system message once before the retry loop + self._setup_system_message(variables, include_react=False) + while retry_count <= self.max_retries: try: - # Resolve variables in system prompt - system_content = ( - self._get_cot_prompt(variables) - if self.reasoning_pattern == ReasoningPattern.COT - else resolve_variables(self.system_prompt, variables) - ) - system_message = SystemMessage(content=system_content) - self.add_to_history(system_message) messages = await self._get_message_history(variables) logger.debug(f'Sending messages to LLM: {messages}') @@ -209,20 +204,13 @@ async def _run_with_tools( ) -> List[BaseMessage]: """Run as a tool-using agent when tools are provided""" variables = variables or {} - print('running with tools') + logger.debug('Running agent with tools') + + # Create and add system message once before the retry loop + self._setup_system_message(variables, include_react=True) while retry_count <= self.max_retries: try: - # Resolve variables in system prompt based on reasoning pattern - if self.reasoning_pattern == ReasoningPattern.REACT: - system_content = self._get_react_prompt(variables) - elif self.reasoning_pattern == ReasoningPattern.COT: - system_content = self._get_cot_prompt(variables) - else: - system_content = resolve_variables(self.system_prompt, variables) - - system_message = SystemMessage(content=system_content) - self.add_to_history(system_message) messages = await self._get_message_history(variables) # Keep executing tools until we get a final answer @@ -429,7 +417,9 @@ async def _run_with_tools( if self.act_as is not None else MessageType.ASSISTANT ) - self.add_to_history(AssistantMessage(content=assistant_message)) + self.add_to_history( + AssistantMessage(role=role, content=assistant_message) + ) 
return self.conversation_history # Fallback: return function message only if we have valid tool execution data @@ -473,6 +463,37 @@ async def _run_with_tools( raise AgentError(f'Failed after maximum {self.max_retries} attempts.') + def _setup_system_message( + self, variables: Optional[Dict[str, Any]] = None, include_react: bool = False + ) -> None: + """ + Create and add system message once before retry loop. + Removes any existing system messages to avoid duplicates on retries. + + Args: + variables: Optional variables for resolving system prompt + include_react: Whether to check for REACT reasoning pattern (only for tool-using agents) + """ + variables = variables or {} + + # Remove any existing system messages to avoid duplicates on retries + self.conversation_history = [ + msg + for msg in self.conversation_history + if not isinstance(msg, SystemMessage) + ] + + # Resolve variables in system prompt based on reasoning pattern + if include_react and self.reasoning_pattern == ReasoningPattern.REACT: + system_content = self._get_react_prompt(variables) + elif self.reasoning_pattern == ReasoningPattern.COT: + system_content = self._get_cot_prompt(variables) + else: + system_content = resolve_variables(self.system_prompt, variables) + + system_message = SystemMessage(content=system_content) + self.add_to_history(system_message) + def _get_react_prompt(self, variables: Optional[Dict[str, Any]] = None) -> str: """Get system prompt modified for ReACT pattern""" variables = variables or {} diff --git a/flo_ai/flo_ai/agent/builder.py b/flo_ai/flo_ai/agent/builder.py index cde16d1d..4dac3ec1 100644 --- a/flo_ai/flo_ai/agent/builder.py +++ b/flo_ai/flo_ai/agent/builder.py @@ -253,9 +253,6 @@ def from_yaml( config = yaml.safe_load(f) validated_config = cls._validate_yaml_config(config) - - # Use validated config directly - convert to dict only when needed for compatibility - config = validated_config.model_dump(exclude_none=True) agent = validated_config.agent builder = cls() @@ -303,6 +300,7 @@ def from_yaml( builder.with_tools(tools) if agent.parser is not None: + config = agent.parser.model_dump(exclude_none=True) parser = FloYamlParser.create(yaml_dict=config) builder.with_output_schema(parser.get_format()) diff --git a/flo_ai/flo_ai/arium/builder.py b/flo_ai/flo_ai/arium/builder.py index 05fcd86a..e1a8ed94 100644 --- a/flo_ai/flo_ai/arium/builder.py +++ b/flo_ai/flo_ai/arium/builder.py @@ -363,9 +363,6 @@ def from_yaml( config = yaml.safe_load(f) validated_config = cls._validate_yaml_config(config) - - # Use validated config directly - convert to dict only when needed for compatibility - config = validated_config.model_dump(exclude_none=True, by_alias=True) arium = validated_config.arium builder = cls() @@ -458,7 +455,7 @@ def from_yaml( if function is None: raise ValueError( f'Function {function_node.function_name} not found in provided function_registry dictionary. ' - f'Available functions: {list[str](function_registry.keys()) if function_registry else []}. ' + f'Available functions: {list(function_registry.keys()) if function_registry else []}. ' f'Either provide the function in the function_registry parameter or add configuration fields.' ) @@ -843,7 +840,11 @@ def _create_agent_from_direct_config( job = ( agent_config.job or agent_config.prompt or 'You are a helpful AI assistant.' 
) - role: str = str(agent_config.role) if agent_config.role is not None else '' + role: str = ( + '' + if agent_config.role is None or agent_config.role == 'None' + else agent_config.role + ) act_as = agent_config.act_as # Configure LLM diff --git a/flo_ai/flo_ai/models/agent.py b/flo_ai/flo_ai/models/agent.py index 95225b2e..184f20ff 100644 --- a/flo_ai/flo_ai/models/agent.py +++ b/flo_ai/flo_ai/models/agent.py @@ -220,6 +220,15 @@ def validate_tools(cls, v): """Validate tools configuration.""" if v is None: return v + + # Normalize singletons (common YAML mistake) or fail fast with a clear error + if isinstance(v, (str, dict, ToolConfigModel)): + v = [v] + if not isinstance(v, list): + raise ValueError( + 'Tools must be a list of tool names or tool config objects' + ) + for tool in v: if isinstance(tool, str): # String reference - valid