diff --git a/flo_ai/flo_ai/arium/arium.py b/flo_ai/flo_ai/arium/arium.py
index 4f2cabff..2605a705 100644
--- a/flo_ai/flo_ai/arium/arium.py
+++ b/flo_ai/flo_ai/arium/arium.py
@@ -381,7 +381,7 @@ def _resolve_inputs(
                     )
                 )
             else:
-                # ImageMessage and DocumentMessage objects don't need variable resolution
+                # ImageMessageContent and DocumentMessage objects don't need variable resolution
                 resolved_inputs.append(input_item)

         return resolved_inputs
diff --git a/flo_ai/flo_ai/arium/builder.py b/flo_ai/flo_ai/arium/builder.py
index 0baabd94..b540ba2e 100644
--- a/flo_ai/flo_ai/arium/builder.py
+++ b/flo_ai/flo_ai/arium/builder.py
@@ -285,6 +285,7 @@ def from_yaml(
         tools: Optional[Dict[str, Tool]] = None,
         routers: Optional[Dict[str, Callable]] = None,
         base_llm: Optional[BaseLLM] = None,
+        **kwargs,
     ) -> 'AriumBuilder':
         """Create an AriumBuilder from a YAML configuration.

@@ -478,7 +479,7 @@ def from_yaml(
                 and 'yaml_file' not in agent_config
             ):
                 agent = cls._create_agent_from_direct_config(
-                    agent_config, base_llm, tools
+                    agent_config, base_llm, tools, **kwargs
                 )

             # Method 3: Inline YAML config
@@ -537,7 +538,7 @@ def from_yaml(
                 router_llm = None
                 if 'model' in router_config:
                     router_llm = cls._create_llm_from_config(
-                        router_config['model'], base_llm
+                        router_config['model'], base_llm, **kwargs
                     )
                 else:
                     router_llm = base_llm  # Use base LLM if no specific model configured
@@ -832,7 +833,9 @@ def _find_node(node_name: str):

     @staticmethod
     def _create_llm_from_config(
-        model_config: Dict[str, Any], base_llm: Optional[BaseLLM] = None
+        model_config: Dict[str, Any],
+        base_llm: Optional[BaseLLM] = None,
+        **kwargs,
     ) -> BaseLLM:
         """Create an LLM instance from model configuration.

@@ -843,33 +846,16 @@ def _create_llm_from_config(
         Returns:
             BaseLLM: Configured LLM instance
         """
-        from flo_ai.llm import OpenAI, Anthropic, Gemini, OllamaLLM
-
-        provider = model_config.get('provider', 'openai').lower()
-        model_name = model_config.get('name')
-        base_url = model_config.get('base_url')
-
-        if not model_name:
-            raise ValueError('Model name must be specified in model configuration')
-
-        if provider == 'openai':
-            llm = OpenAI(model=model_name, base_url=base_url)
-        elif provider == 'anthropic':
-            llm = Anthropic(model=model_name, base_url=base_url)
-        elif provider == 'gemini':
-            llm = Gemini(model=model_name, base_url=base_url)
-        elif provider == 'ollama':
-            llm = OllamaLLM(model=model_name, base_url=base_url)
-        else:
-            raise ValueError(f'Unsupported model provider: {provider}')
+        from flo_ai.helpers.llm_factory import create_llm_from_config

-        return llm
+        return create_llm_from_config(model_config, **kwargs)

     @staticmethod
     def _create_agent_from_direct_config(
         agent_config: Dict[str, Any],
         base_llm: Optional[BaseLLM] = None,
         available_tools: Optional[Dict[str, Tool]] = None,
+        **kwargs,
     ) -> Agent:
         """Create an Agent from direct YAML configuration.
@@ -891,7 +877,7 @@ def _create_agent_from_direct_config(

         # Configure LLM
         if 'model' in agent_config and base_llm is None:
-            llm = AriumBuilder._create_llm_from_config(agent_config['model'])
+            llm = AriumBuilder._create_llm_from_config(agent_config['model'], **kwargs)
         elif base_llm:
             llm = base_llm
         else:
diff --git a/flo_ai/flo_ai/builder/agent_builder.py b/flo_ai/flo_ai/builder/agent_builder.py
index f011fb01..31b8fa37 100644
--- a/flo_ai/flo_ai/builder/agent_builder.py
+++ b/flo_ai/flo_ai/builder/agent_builder.py
@@ -3,7 +3,7 @@
 import yaml
 from flo_ai.models.agent import Agent
 from flo_ai.models.base_agent import ReasoningPattern
-from flo_ai.llm import BaseLLM, OpenAI, Anthropic, Gemini, OllamaLLM, VertexAI
+from flo_ai.llm import BaseLLM
 from flo_ai.tool.base_tool import Tool
 from flo_ai.tool.tool_config import ToolConfig, create_tool_config
 from flo_ai.formatter.yaml_format_parser import FloYamlParser
@@ -190,6 +190,7 @@ def from_yaml(
         tools: Optional[List[Tool]] = None,
         base_llm: Optional[BaseLLM] = None,
         tool_registry: Optional[Dict[str, Tool]] = None,
+        **kwargs,
     ) -> 'AgentBuilder':
         """Create an agent builder from a YAML configuration string

@@ -219,35 +220,15 @@ def from_yaml(

         # Configure LLM based on model settings
         if 'model' in agent_config and base_llm is None:
-            base_url = agent_config.get('base_url', None)
+            from flo_ai.helpers.llm_factory import create_llm_from_config
+
             model_config: dict = agent_config['model']
-            provider = model_config.get('provider', 'openai').lower()
-            model_name = model_config.get('name')
-
-            if not model_name:
-                raise ValueError('Model name must be specified in YAML configuration')
-
-            if provider == 'openai':
-                builder.with_llm(OpenAI(model=model_name, base_url=base_url))
-            elif provider == 'anthropic':
-                builder.with_llm(Anthropic(model=model_name, base_url=base_url))
-            elif provider == 'gemini':
-                builder.with_llm(Gemini(model=model_name, base_url=base_url))
-            elif provider == 'ollama':
-                builder.with_llm(OllamaLLM(model=model_name, base_url=base_url))
-            elif provider == 'vertexai':
-                project = model_config.get('project')
-                location = model_config.get('location', 'asia-south1')
-                builder.with_llm(
-                    VertexAI(
-                        model=model_name,
-                        project=project,
-                        location=location,
-                        base_url=base_url,
-                    )
-                )
-            else:
-                raise ValueError(f'Unsupported model provider: {provider}')
+            # Merge base_url from agent_config if present and not in model_config
+            if 'base_url' in agent_config and 'base_url' not in model_config:
+                model_config = {**model_config, 'base_url': agent_config['base_url']}
+
+            llm = create_llm_from_config(model_config, **kwargs)
+            builder.with_llm(llm)
         else:
             if base_llm is None:
                 raise ValueError(
diff --git a/flo_ai/flo_ai/helpers/__init__.py b/flo_ai/flo_ai/helpers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/flo_ai/flo_ai/helpers/llm_factory.py b/flo_ai/flo_ai/helpers/llm_factory.py
new file mode 100644
index 00000000..27891d3d
--- /dev/null
+++ b/flo_ai/flo_ai/helpers/llm_factory.py
@@ -0,0 +1,220 @@
+"""
+LLM Factory - Centralized LLM creation from configuration.
+
+This module provides a unified factory function for creating LLM instances
+from configuration dictionaries, supporting all providers in the flo_ai ecosystem.
+"""
+
+import os
+from typing import Dict, Any, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from flo_ai.llm import BaseLLM
+
+
+class LLMFactory:
+    """Factory class for creating LLM instances from configuration."""
+
+    SUPPORTED_PROVIDERS = {
+        'openai',
+        'anthropic',
+        'gemini',
+        'ollama',
+        'vertexai',
+        'rootflo',
+    }
+
+    @staticmethod
+    def create_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM':
+        """Create an LLM instance from model configuration.
+
+        Args:
+            model_config: Dictionary containing model configuration with keys:
+                - provider (str): LLM provider name (default: 'openai')
+                - name (str): Model name (required for most providers)
+                - base_url (str, optional): Custom base URL
+                - model_id (str): For RootFlo provider
+                - project (str): For VertexAI provider
+                - location (str): For VertexAI provider (default: 'asia-south1')
+            **kwargs: Additional parameters that override config and env vars:
+                - base_url: Override base URL
+                - For RootFlo: app_key, app_secret, issuer, audience, access_token
+
+        Returns:
+            BaseLLM: Configured LLM instance
+
+        Raises:
+            ValueError: If provider is unsupported or required parameters are missing
+
+        Examples:
+            >>> # OpenAI
+            >>> llm = LLMFactory.create_llm({'provider': 'openai', 'name': 'gpt-4'})
+
+            >>> # VertexAI with project
+            >>> llm = LLMFactory.create_llm({
+            ...     'provider': 'vertexai',
+            ...     'name': 'gemini-pro',
+            ...     'project': 'my-project',
+            ...     'location': 'us-central1'
+            ... })
+
+            >>> # RootFlo with auth
+            >>> llm = LLMFactory.create_llm(
+            ...     {'provider': 'rootflo', 'model_id': 'model-123'},
+            ...     app_key='key', app_secret='secret', issuer='iss', audience='aud'
+            ... )
+        """
+        provider = model_config.get('provider', 'openai').lower()
+
+        if provider not in LLMFactory.SUPPORTED_PROVIDERS:
+            raise ValueError(
+                f'Unsupported model provider: {provider}. '
+                f'Supported providers: {", ".join(sorted(LLMFactory.SUPPORTED_PROVIDERS))}'
+            )
+
+        if provider == 'rootflo':
+            return LLMFactory._create_rootflo_llm(model_config, **kwargs)
+        elif provider == 'vertexai':
+            return LLMFactory._create_vertexai_llm(model_config, **kwargs)
+        else:
+            return LLMFactory._create_standard_llm(provider, model_config, **kwargs)
+
+    @staticmethod
+    def _create_standard_llm(
+        provider: str, model_config: Dict[str, Any], **kwargs
+    ) -> 'BaseLLM':
+        """Create standard LLM instances (OpenAI, Anthropic, Gemini, Ollama)."""
+        from flo_ai.llm import OpenAI, Anthropic, Gemini, OllamaLLM
+
+        model_name = model_config.get('name')
+        if not model_name:
+            raise ValueError(
+                f'{provider.title()} provider requires "name" parameter in model configuration'
+            )
+
+        # Priority: kwargs > model_config > None
+        base_url = kwargs.get('base_url') or model_config.get('base_url')
+
+        provider_map = {
+            'openai': OpenAI,
+            'anthropic': Anthropic,
+            'gemini': Gemini,
+            'ollama': OllamaLLM,
+        }
+
+        llm_class = provider_map[provider]
+        return llm_class(model=model_name, base_url=base_url)
+
+    @staticmethod
+    def _create_vertexai_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM':
+        """Create VertexAI LLM instance with project and location."""
+        from flo_ai.llm import VertexAI
+
+        model_name = model_config.get('name')
+        if not model_name:
+            raise ValueError(
+                'VertexAI provider requires "name" parameter in model configuration'
+            )
+
+        # Get VertexAI-specific parameters
+        project = kwargs.get('project') or model_config.get('project')
+        location = kwargs.get('location') or model_config.get('location', 'asia-south1')
+        base_url = kwargs.get('base_url') or model_config.get('base_url')
+
+        if not project:
+            raise ValueError(
+                'VertexAI provider requires "project" parameter. '
+                'Provide it in model_config or as a kwarg.'
+            )
+
+        return VertexAI(
+            model=model_name,
+            project=project,
+            location=location,
+            base_url=base_url,
+        )
+
+    @staticmethod
+    def _create_rootflo_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM':
+        """Create RootFlo LLM instance with authentication."""
+        from flo_ai.llm import RootFloLLM
+
+        model_id = model_config.get('model_id')
+        if not model_id:
+            raise ValueError(
+                'RootFlo provider requires "model_id" in model configuration'
+            )
+
+        # Gather RootFlo parameters from kwargs or environment
+        base_url = (
+            kwargs.get('base_url')
+            or model_config.get('base_url')
+            or os.getenv('ROOTFLO_BASE_URL')
+        )
+        app_key = kwargs.get('app_key') or os.getenv('ROOTFLO_APP_KEY')
+        app_secret = kwargs.get('app_secret') or os.getenv('ROOTFLO_APP_SECRET')
+        issuer = kwargs.get('issuer') or os.getenv('ROOTFLO_ISSUER')
+        audience = kwargs.get('audience') or os.getenv('ROOTFLO_AUDIENCE')
+        access_token = kwargs.get('access_token')  # Optional, from kwargs only
+
+        # Validate required parameters based on auth method
+        if not access_token:
+            # JWT auth flow - requires all parameters
+            required_params = {
+                'base_url': base_url,
+                'app_key': app_key,
+                'app_secret': app_secret,
+                'issuer': issuer,
+                'audience': audience,
+            }
+            missing = [k for k, v in required_params.items() if not v]
+
+            if missing:
+                raise ValueError(
+                    f'RootFlo configuration incomplete. Missing required parameters: {", ".join(missing)}. '
+                    f'These can be provided via kwargs or environment variables '
+                    f'(ROOTFLO_BASE_URL, ROOTFLO_APP_KEY, ROOTFLO_APP_SECRET, ROOTFLO_ISSUER, ROOTFLO_AUDIENCE).'
+                )
+        else:
+            # Access token flow - only needs base_url and app_key
+            required_params = {
+                'base_url': base_url,
+                'app_key': app_key,
+            }
+            missing = [k for k, v in required_params.items() if not v]
+
+            if missing:
+                raise ValueError(
+                    f'RootFlo configuration incomplete. Missing required parameters: {", ".join(missing)}. '
+                    f'These can be provided via kwargs or environment variables '
+                    f'(ROOTFLO_BASE_URL, ROOTFLO_APP_KEY).'
+                )
+
+        return RootFloLLM(
+            base_url=base_url,
+            model_id=model_id,
+            app_key=app_key,
+            app_secret=app_secret,
+            issuer=issuer,
+            audience=audience,
+            access_token=access_token,
+        )
+
+
+# Convenience function for direct import
+def create_llm_from_config(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM':
+    """
+    Convenience function to create an LLM instance from configuration.
+
+    This is a wrapper around LLMFactory.create_llm() for easier imports.
+
+    Args:
+        model_config: Dictionary containing model configuration
+        **kwargs: Additional parameters that override config and env vars
+
+    Returns:
+        BaseLLM: Configured LLM instance
+
+    See LLMFactory.create_llm() for detailed documentation.
+    """
+    return LLMFactory.create_llm(model_config, **kwargs)
diff --git a/flo_ai/flo_ai/llm/__init__.py b/flo_ai/flo_ai/llm/__init__.py
index 4578d7f2..934e4da4 100644
--- a/flo_ai/flo_ai/llm/__init__.py
+++ b/flo_ai/flo_ai/llm/__init__.py
@@ -5,6 +5,7 @@
 from .gemini_llm import Gemini
 from .openai_vllm import OpenAIVLLM
 from .vertexai_llm import VertexAI
+from .rootflo_llm import RootFloLLM

 __all__ = [
     'BaseLLM',
@@ -14,4 +15,5 @@
     'Gemini',
     'OpenAIVLLM',
     'VertexAI',
+    'RootFloLLM',
 ]
diff --git a/flo_ai/flo_ai/llm/rootflo_llm.py b/flo_ai/flo_ai/llm/rootflo_llm.py
index ccb792aa..9d59c38a 100644
--- a/flo_ai/flo_ai/llm/rootflo_llm.py
+++ b/flo_ai/flo_ai/llm/rootflo_llm.py
@@ -1,8 +1,10 @@
 from enum import Enum
 from typing import AsyncIterator, Dict, Any, List, Optional
 from datetime import datetime, timedelta
+from flo_ai.models.chat_message import ImageMessageContent
 import jwt
-from .base_llm import BaseLLM, ImageMessage
+import httpx
+from .base_llm import BaseLLM
 from .openai_llm import OpenAI
 from .gemini_llm import Gemini
 from .anthropic_llm import Anthropic
@@ -27,12 +29,11 @@ def __init__(
         self,
         base_url: str,
         model_id: str,
-        llm_model: str,
-        llm_provider: LLMProvider,
         app_key: str,
         app_secret: str,
         issuer: str,
         audience: str,
+        access_token: Optional[str] = None,
         temperature: float = 0.7,
         **kwargs,
     ):
@@ -41,28 +42,74 @@ def __init__(

         Args:
             base_url: The base URL of the proxy server
-            model_id: The model identifier
-            llm_provider: Type of LLM SDK to use (LLMProvider enum)
+            model_id: The model identifier (config_id)
             app_key: Application key for X-Rootflo-Key header
             app_secret: Application secret for JWT signing
             issuer: JWT issuer claim
             audience: JWT audience claim
+            access_token: Optional pre-generated access token (if provided, skips JWT generation)
             temperature: Temperature parameter for generation
             **kwargs: Additional parameters to pass to the underlying SDK
         """
-        # Generate JWT token
-        now = datetime.now()
-        payload = {
-            'iss': issuer,
-            'aud': audience,
-            'iat': int(now.timestamp()),
-            'exp': int((now + timedelta(seconds=3600)).timestamp()),
-            'role_id': 'floconsole-service',
-            'user_id': 'service',
-            'service_auth': True,
-        }
-        service_token = jwt.encode(payload, app_secret, algorithm='HS256')
-        api_token = f'fc_{service_token}'
+        # Validate required parameters
+
+        if not model_id:
+            raise ValueError('model_id is required')
+
+        if not base_url:
+            raise ValueError('base_url is required')
+
+        # Validate JWT credentials if access_token is not provided
+        if not access_token:
+            missing = []
+            if not app_key:
+                missing.append('app_key')
+            if not app_secret:
+                missing.append('app_secret')
+            if not issuer:
+                missing.append('issuer')
+            if not audience:
+                missing.append('audience')
+
+            if missing:
+                raise ValueError(
+                    f'Missing required parameters for JWT generation: {", ".join(missing)}. '
+                    f'Either provide these parameters or pass an access_token directly.'
+                )
+        else:  # app key is still required
+            if not app_key:
+                raise ValueError('app_key is required even when using access_token')
+
+        # Use provided access_token or generate JWT token
+        if access_token:
+            api_token = access_token
+        else:
+            now = datetime.now()
+            payload = {
+                'iss': issuer,
+                'aud': audience,
+                'iat': int(now.timestamp()),
+                'exp': int((now + timedelta(seconds=3600)).timestamp()),
+                'role_id': 'floconsole-service',
+                'user_id': 'service',
+                'service_auth': True,
+            }
+            service_token = jwt.encode(payload, app_secret, algorithm='HS256')
+            api_token = f'fc_{service_token}'
+
+        # Fetch LLM configuration from API
+        config = self._fetch_llm_config(base_url, model_id, api_token, app_key)
+        llm_model = config['llm_model']
+        llm_type = config['type']
+
+        # Map type string to LLMProvider enum
+        try:
+            llm_provider = LLMProvider(llm_type.lower())
+        except ValueError:
+            raise ValueError(
+                f'Unsupported LLM provider type from API: {llm_type}. '
+                f'Supported types: {[p.value for p in LLMProvider]}'
+            )

         super().__init__(
             model=llm_model, api_key=api_token, temperature=temperature, **kwargs
@@ -72,8 +119,8 @@ def __init__(
         self.model_id = model_id
         self.llm_provider = llm_provider

-        # Construct full URL
-        full_url = f'{base_url}/{model_id}'
+        # Construct full URL for LLM inference
+        full_url = f'{base_url}/v1/llm-inference/{model_id}'

         # Prepare custom headers for proxy authentication
         custom_headers = {'X-Rootflo-Key': app_key}
@@ -110,6 +157,60 @@ def __init__(
         else:
             raise ValueError(f'Unsupported LLM provider: {llm_provider}')

+    def _fetch_llm_config(
+        self, base_url: str, model_id: str, api_token: str, app_key: str
+    ) -> Dict[str, Any]:
+        """
+        Fetch LLM configuration from the API endpoint.
+
+        Args:
+            base_url: The base URL of the API server
+            model_id: The model identifier (config_id)
+            api_token: The JWT token for authorization
+            app_key: Application key for X-Rootflo-Key header
+
+        Returns:
+            Dict containing llm_model and type
+
+        Raises:
+            Exception: If API call fails or response is invalid
+        """
+        config_url = f'{base_url}/v1/llm-inference-configs/{model_id}'
+        headers = {
+            'Authorization': f'Bearer {api_token}',
+            'X-Rootflo-Key': app_key,
+        }
+
+        try:
+            with httpx.Client() as client:
+                response = client.get(config_url, headers=headers, timeout=30.0)
+                response.raise_for_status()
+
+                data = response.json()
+
+                config_data = data.get('data')
+                if not config_data:
+                    raise Exception('API response missing data field')
+
+                llm_model = config_data.get('llm_model')
+                llm_type = config_data.get('type')
+
+                if not llm_model or not llm_type:
+                    raise Exception(
+                        f'API response missing required fields: llm_model={llm_model}, type={llm_type}'
+                    )
+
+                return {'llm_model': llm_model, 'type': llm_type}
+
+        except httpx.HTTPStatusError as e:
+            raise Exception(
+                f'API request failed with status {e.response.status_code}: {e.response.text}'
+            ) from e
+        except httpx.RequestError as e:
+            raise Exception(f'API request failed: {str(e)}') from e
+        except Exception as e:
+            raise Exception(f'Failed to fetch LLM config: {str(e)}') from e
+
     async def generate(
         self,
         messages: List[Dict[str, str]],
@@ -144,6 +245,6 @@ def format_tools_for_llm(self, tools: List['Tool']) -> List[Dict[str, Any]]:
         """Format a list of tools for the specific LLM's API"""
         return self._llm.format_tools_for_llm(tools)

-    def format_image_in_message(self, image: ImageMessage) -> str:
+    def format_image_in_message(self, image: ImageMessageContent) -> str:
         """Format a image in the message"""
         return self._llm.format_image_in_message(image)
diff --git a/flo_ai/flo_ai/utils/variable_extractor.py b/flo_ai/flo_ai/utils/variable_extractor.py
index 3c8dae0d..9ea0266b 100644
--- a/flo_ai/flo_ai/utils/variable_extractor.py
+++ b/flo_ai/flo_ai/utils/variable_extractor.py
@@ -50,7 +50,7 @@ def extract_variables_from_inputs(inputs: List[BaseMessage]) -> Set[str]:
         Set of variable names found across all string inputs

     Note:
-        ImageMessage objects are skipped as they don't contain variable placeholders
+        ImageMessageContent objects are skipped as they don't contain variable placeholders
     """
     all_variables = set()
     for input_item in inputs:
@@ -61,7 +61,7 @@ def extract_variables_from_inputs(inputs: List[BaseMessage]) -> Set[str]:
         if isinstance(input_item, str):
            variables = extract_variables_from_text(input_item)
            all_variables.update(variables)
-        # Skip ImageMessage objects as they don't contain text variables
+        # Skip ImageMessageContent objects as they don't contain text variables

     return all_variables
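
Usage sketch (reviewer note, not part of the patch): the snippet below exercises the new factory entry point and the **kwargs pass-through added to the YAML builders, using only configuration keys handled in this change; the model names, ids, URL, and RootFlo credentials are placeholders.

    import os

    from flo_ai.helpers.llm_factory import create_llm_from_config

    # Standard providers: 'name' is required; base_url resolves kwargs > config > None.
    openai_llm = create_llm_from_config({'provider': 'openai', 'name': 'gpt-4'})

    # VertexAI: 'project' must come from the config dict or a kwarg;
    # 'location' defaults to 'asia-south1' when omitted.
    vertex_llm = create_llm_from_config(
        {'provider': 'vertexai', 'name': 'gemini-pro', 'project': 'my-project'},
        location='us-central1',
    )

    # RootFlo: auth material comes from kwargs or the ROOTFLO_* environment
    # variables. With an access_token, only base_url and app_key are still
    # required; the constructor then fetches llm_model/type from the
    # /v1/llm-inference-configs endpoint at base_url.
    os.environ.setdefault('ROOTFLO_BASE_URL', 'https://proxy.example.com')  # placeholder
    rootflo_llm = create_llm_from_config(
        {'provider': 'rootflo', 'model_id': 'model-123'},  # placeholder model_id
        app_key='rf-app-key',             # placeholder
        access_token='fc_example-token',  # placeholder; skips JWT generation
    )

    # AgentBuilder.from_yaml and AriumBuilder.from_yaml now accept the same extra
    # **kwargs and forward them to this factory when a 'model' block is configured.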