diff --git a/flo_ai/flo_ai/helpers/llm_factory.py b/flo_ai/flo_ai/helpers/llm_factory.py
index 27891d3d..1dd21f28 100644
--- a/flo_ai/flo_ai/helpers/llm_factory.py
+++ b/flo_ai/flo_ai/helpers/llm_factory.py
@@ -22,6 +22,7 @@ class LLMFactory:
         'ollama',
         'vertexai',
         'rootflo',
+        'openai_vllm',
     }
 
     @staticmethod
@@ -63,6 +64,14 @@ def create_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM':
         ...     {'provider': 'rootflo', 'model_id': 'model-123'},
         ...     app_key='key', app_secret='secret', issuer='iss', audience='aud'
         ... )
+
+        >>> # OpenAI vLLM with base_url
+        >>> llm = LLMFactory.create_llm({
+        ...     'provider': 'openai_vllm',
+        ...     'name': 'microsoft/phi-4',
+        ...     'base_url': 'http://localhost:8000/v1',
+        ...     'api_key': 'vllm-key'
+        ... })
         """
         provider = model_config.get('provider', 'openai').lower()
 
@@ -76,6 +85,8 @@ def create_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM':
             return LLMFactory._create_rootflo_llm(model_config, **kwargs)
         elif provider == 'vertexai':
             return LLMFactory._create_vertexai_llm(model_config, **kwargs)
+        elif provider == 'openai_vllm':
+            return LLMFactory._create_openai_vllm_llm(model_config, **kwargs)
         else:
             return LLMFactory._create_standard_llm(provider, model_config, **kwargs)
 
@@ -134,6 +145,39 @@ def _create_vertexai_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM':
             base_url=base_url,
         )
 
+    @staticmethod
+    def _create_openai_vllm_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM':
+        """Create OpenAI vLLM instance with base_url handling."""
+        from flo_ai.llm import OpenAIVLLM
+
+        model_name = model_config.get('name')
+        if not model_name:
+            raise ValueError(
+                'openai_vllm provider requires "name" parameter in model configuration'
+            )
+
+        # Priority: kwargs > model_config > None
+        base_url = kwargs.get('base_url') or model_config.get('base_url')
+        if not base_url:
+            raise ValueError(
+                'openai_vllm provider requires "base_url" parameter. '
+                'Provide it in model_config or as a kwarg.'
+            )
+
+        # Optional parameters
+        api_key = kwargs.get('api_key') or model_config.get('api_key')
+        temperature = kwargs.get(
+            'temperature',
+            model_config.get('temperature', 0.7),
+        )
+
+        return OpenAIVLLM(
+            model=model_name,
+            base_url=base_url,
+            api_key=api_key,
+            temperature=temperature,
+        )
+
     @staticmethod
     def _create_rootflo_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM':
         """Create RootFlo LLM instance with authentication."""
diff --git a/flo_ai/flo_ai/llm/rootflo_llm.py b/flo_ai/flo_ai/llm/rootflo_llm.py
index 9d59c38a..6aa0ad5e 100644
--- a/flo_ai/flo_ai/llm/rootflo_llm.py
+++ b/flo_ai/flo_ai/llm/rootflo_llm.py
@@ -4,10 +4,12 @@
 from flo_ai.models.chat_message import ImageMessageContent
 import jwt
 import httpx
+import asyncio
 from .base_llm import BaseLLM
 from .openai_llm import OpenAI
 from .gemini_llm import Gemini
 from .anthropic_llm import Anthropic
+from .openai_vllm import OpenAIVLLM
 from flo_ai.tool.base_tool import Tool
 
 
@@ -17,12 +19,13 @@ class LLMProvider(Enum):
     OPENAI = 'openai'
     GEMINI = 'gemini'
     ANTHROPIC = 'anthropic'
+    VLLM = 'vllm'
 
 
 class RootFloLLM(BaseLLM):
     """
     Proxy LLM class that routes to different SDK implementations based on type.
-    Acts as a unified interface to OpenAI, Gemini, and Anthropic SDKs via a proxy URL.
+    Acts as a unified interface to the OpenAI, Gemini, Anthropic, and vLLM backends via a proxy URL.
     """
 
     def __init__(
@@ -50,9 +53,12 @@ def __init__(
             access_token: Optional pre-generated access token (if provided, skips JWT generation)
             temperature: Temperature parameter for generation
             **kwargs: Additional parameters to pass to the underlying SDK
+
+        Note:
+            The actual LLM configuration is fetched lazily on first use (generate/stream)
+            to avoid blocking HTTP calls during initialization.
         """
         # Validate required parameters
-
         if not model_id:
             raise ValueError('model_id is required')
 
@@ -80,88 +86,41 @@ def __init__(
             if not app_key:
                 raise ValueError('app_key is required even when using access_token')
 
-        # Use provided access_token or generate JWT token
-        if access_token:
-            api_token = access_token
-        else:
-            now = datetime.now()
-            payload = {
-                'iss': issuer,
-                'aud': audience,
-                'iat': int(now.timestamp()),
-                'exp': int((now + timedelta(seconds=3600)).timestamp()),
-                'role_id': 'floconsole-service',
-                'user_id': 'service',
-                'service_auth': True,
-            }
-            service_token = jwt.encode(payload, app_secret, algorithm='HS256')
-            api_token = f'fc_{service_token}'
-
-        # Fetch LLM configuration from API
-        config = self._fetch_llm_config(base_url, model_id, api_token, app_key)
-        llm_model = config['llm_model']
-        llm_type = config['type']
-
-        # Map type string to LLMProvider enum
-        try:
-            llm_provider = LLMProvider(llm_type.lower())
-        except ValueError:
-            raise ValueError(
-                f'Unsupported LLM provider type from API: {llm_type}. '
-                f'Supported types: {[p.value for p in LLMProvider]}'
-            )
+        # Store initialization parameters for lazy initialization
+        self._base_url = base_url
+        self._model_id = model_id
+        self._app_key = app_key
+        self._app_secret = app_secret
+        self._issuer = issuer
+        self._audience = audience
+        self._access_token = access_token
+        self._temperature = temperature
+        self._kwargs = kwargs
+
+        # Lazy initialization state
+        self._llm = None
+        self._initialized = False
+        self._init_lock = asyncio.Lock()
+
+        # Will be set during initialization
+        self.base_url = base_url
+        self.model_id = model_id
+        self.llm_provider = None
 
+        # Call parent __init__ with minimal parameters
+        # Actual model will be set during lazy initialization
         super().__init__(
-            model=llm_model, api_key=api_token, temperature=temperature, **kwargs
+            model='',
+            api_key='',
+            temperature=temperature,
+            **kwargs,
         )
 
-        self.base_url = base_url
-        self.model_id = model_id
-        self.llm_provider = llm_provider
-
-        # Construct full URL for LLM inference
-        full_url = f'{base_url}/v1/llm-inference/{model_id}'
-
-        # Prepare custom headers for proxy authentication
-        custom_headers = {'X-Rootflo-Key': app_key}
-
-        # Instantiate appropriate SDK wrapper based on llm_provider
-        if llm_provider == LLMProvider.OPENAI:
-            self._llm = OpenAI(
-                model=llm_model,
-                base_url=full_url,
-                api_key=api_token,
-                temperature=temperature,
-                custom_headers=custom_headers,
-                **kwargs,
-            )
-        elif llm_provider == LLMProvider.ANTHROPIC:
-            self._llm = Anthropic(
-                model=llm_model,
-                base_url=full_url,
-                api_key=api_token,
-                temperature=temperature,
-                custom_headers=custom_headers,
-                **kwargs,
-            )
-        elif llm_provider == LLMProvider.GEMINI:
-            # Gemini SDK - pass base_url which will be handled via http_options
-            self._llm = Gemini(
-                model=llm_model,
-                api_key=api_token,
-                temperature=temperature,
-                base_url=full_url,
-                custom_headers=custom_headers,
-                **kwargs,
-            )
-        else:
-            raise ValueError(f'Unsupported LLM provider: {llm_provider}')
-
-    def _fetch_llm_config(
+    async def _fetch_llm_config_async(
         self, base_url: str, model_id: str, api_token: str, app_key: str
     ) -> Dict[str, Any]:
         """
-        Fetch LLM configuration from the API endpoint.
+        Fetch LLM configuration from the API endpoint asynchronously.
 
         Args:
             base_url: The base URL of the API server
@@ -182,8 +141,8 @@ def _fetch_llm_config(
         }
 
         try:
-            with httpx.Client() as client:
-                response = client.get(config_url, headers=headers, timeout=30.0)
+            async with httpx.AsyncClient() as client:
+                response = await client.get(config_url, headers=headers, timeout=30.0)
                 response.raise_for_status()
                 data = response.json()
 
@@ -211,6 +170,110 @@
         except Exception as e:
             raise Exception(f'Failed to fetch LLM config: {str(e)}') from e
 
+    async def _ensure_initialized(self):
+        """
+        Ensure the LLM is initialized by fetching config on first use.
+        Uses double-checked locking to prevent race conditions in concurrent scenarios.
+        """
+        # Fast path: already initialized
+        if self._initialized:
+            return
+
+        # Acquire lock for initialization
+        async with self._init_lock:
+            # Double-check: another task might have initialized while we waited
+            if self._initialized:
+                return
+
+            # Generate or use provided access token
+            if self._access_token:
+                api_token = self._access_token
+            else:
+                now = datetime.now()
+                payload = {
+                    'iss': self._issuer,
+                    'aud': self._audience,
+                    'iat': int(now.timestamp()),
+                    'exp': int((now + timedelta(seconds=3600)).timestamp()),
+                    'role_id': 'floconsole-service',
+                    'user_id': 'service',
+                    'service_auth': True,
+                }
+                service_token = jwt.encode(payload, self._app_secret, algorithm='HS256')
+                api_token = f'fc_{service_token}'
+
+            # Fetch LLM configuration from API
+            config = await self._fetch_llm_config_async(
+                self._base_url, self._model_id, api_token, self._app_key
+            )
+            llm_model = config['llm_model']
+            llm_type = config['type']
+
+            # Map type string to LLMProvider enum
+            try:
+                llm_provider = LLMProvider(llm_type.lower())
+            except ValueError:
+                raise ValueError(
+                    f'Unsupported LLM provider type from API: {llm_type}. '
+                    f'Supported types: {[p.value for p in LLMProvider]}'
+                )
+
+            # Update instance attributes
+            self.llm_provider = llm_provider
+            self.model = llm_model
+            self.api_key = api_token
+
+            # Construct full URL for LLM inference
+            full_url = f'{self._base_url}/v1/llm-inference/{self._model_id}'
+
+            # Prepare custom headers for proxy authentication
+            custom_headers = {'X-Rootflo-Key': self._app_key}
+
+            # Instantiate appropriate SDK wrapper based on llm_provider
+            if llm_provider == LLMProvider.OPENAI:
+                self._llm = OpenAI(
+                    model=llm_model,
+                    base_url=full_url,
+                    api_key=api_token,
+                    temperature=self._temperature,
+                    custom_headers=custom_headers,
+                    **self._kwargs,
+                )
+            elif llm_provider == LLMProvider.ANTHROPIC:
+                self._llm = Anthropic(
+                    model=llm_model,
+                    base_url=full_url,
+                    api_key=api_token,
+                    temperature=self._temperature,
+                    custom_headers=custom_headers,
+                    **self._kwargs,
+                )
+            elif llm_provider == LLMProvider.GEMINI:
+                # Gemini SDK - pass base_url which will be handled via http_options
+                self._llm = Gemini(
+                    model=llm_model,
+                    api_key=api_token,
+                    temperature=self._temperature,
+                    base_url=full_url,
+                    custom_headers=custom_headers,
+                    **self._kwargs,
+                )
+            elif llm_provider == LLMProvider.VLLM:
+                # vLLM via OpenAI-compatible API
+                self._llm = OpenAIVLLM(
+                    model=llm_model,
+                    base_url=full_url,
+                    api_key=api_token,
+                    temperature=self._temperature,
+                    custom_headers=custom_headers,
+                    **self._kwargs,
+                )
+            else:
+                raise ValueError(f'Unsupported LLM provider: {llm_provider}')
+
+            # Mark as initialized
+            self._initialized = True
+
     async def generate(
         self,
         messages: List[Dict[str, str]],
@@ -219,6 +282,7 @@
         **kwargs,
     ) -> Dict[str, Any]:
         """Generate a response from the LLM"""
+        await self._ensure_initialized()
         return await self._llm.generate(
             messages, functions=functions, output_schema=output_schema, **kwargs
         )
@@ -230,11 +294,16 @@ async def stream(
         **kwargs: Any,
     ) -> AsyncIterator[Dict[str, Any]]:
         """Generate a streaming response from the LLM"""
+        await self._ensure_initialized()
         async for chunk in self._llm.stream(messages, functions=functions, **kwargs):
             yield chunk
 
     def get_message_content(self, response: Any) -> str:
         """Extract message content from response"""
+        if not getattr(self, '_initialized', False) or self._llm is None:
+            raise RuntimeError(
+                'RootFloLLM is not initialized yet; call generate() or stream() first.'
+            )
         return self._llm.get_message_content(response)
 
     def format_tool_for_llm(self, tool: 'Tool') -> Dict[str, Any]:
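The diff touches two call paths: `LLMFactory.create_llm` gains an `openai_vllm` provider that builds an `OpenAIVLLM` client from `name` and `base_url`, and `RootFloLLM` now defers its config fetch, JWT minting, and SDK-wrapper selection (including the new `LLMProvider.VLLM` branch) to the first awaited `generate()`/`stream()` call. Below is a minimal sketch of the factory path against a locally served vLLM endpoint; the import path, the chat-message shape, and the availability of `get_message_content()` on the returned instance are assumptions inferred from this diff, not verified against the rest of the codebase.

```python
import asyncio

# Assumed import path: the factory lives at flo_ai/flo_ai/helpers/llm_factory.py in this diff.
from flo_ai.helpers.llm_factory import LLMFactory


async def main() -> None:
    # Mirrors the doctest added above: 'name' and 'base_url' are required for the
    # 'openai_vllm' provider; 'api_key' and 'temperature' are optional.
    llm = LLMFactory.create_llm(
        {
            'provider': 'openai_vllm',
            'name': 'microsoft/phi-4',
            'base_url': 'http://localhost:8000/v1',  # hypothetical local vLLM server
            'api_key': 'vllm-key',
        }
    )

    # Message format and response handling are assumptions based on the RootFloLLM
    # methods visible in this diff: generate() is async, and get_message_content()
    # extracts the text from the provider-specific response object.
    response = await llm.generate([{'role': 'user', 'content': 'Say hello in one word.'}])
    print(llm.get_message_content(response))


if __name__ == '__main__':
    asyncio.run(main())
```

On the proxy path, constructing `RootFloLLM` is now cheap and synchronous; `_ensure_initialized()` runs on the first awaited call, and the `asyncio.Lock` plus the double-checked `_initialized` flag mean concurrent first calls fetch the configuration only once. A hedged sketch with placeholder credentials, again assuming the module path and constructor keywords shown in this diff:

```python
import asyncio

from flo_ai.llm.rootflo_llm import RootFloLLM  # assumed module path from this diff


async def proxy_example() -> None:
    # No HTTP call happens here anymore; __init__ only validates and stores parameters.
    llm = RootFloLLM(
        base_url='https://rootflo.example.com',  # placeholder proxy URL
        model_id='model-123',                    # placeholder model id
        app_key='key',
        app_secret='secret',
        issuer='iss',
        audience='aud',
    )

    # The first awaited call runs _ensure_initialized(): token generation, config fetch,
    # and wrapper construction happen once, even when the first calls arrive concurrently.
    prompts = [[{'role': 'user', 'content': f'Question {i}'}] for i in range(3)]
    responses = await asyncio.gather(*(llm.generate(p) for p in prompts))
    for r in responses:
        print(llm.get_message_content(r))


if __name__ == '__main__':
    asyncio.run(proxy_example())
```

Both sketches assume the `flo_ai` package exposes these modules under the paths shown; adjust the imports to the actual package layout if it differs.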