28 changes: 19 additions & 9 deletions flo_ai/flo_ai/llm/anthropic_llm.py
@@ -20,17 +20,24 @@ def __init__(
temperature: float = 0.7,
api_key: Optional[str] = None,
base_url: str = None,
custom_headers: Optional[Dict[str, str]] = None,
**kwargs,
):
super().__init__(model, api_key, temperature, **kwargs)
self.client = AsyncAnthropic(api_key=self.api_key, base_url=base_url)
# Add custom headers if base_url is provided (proxy scenario)
client_kwargs = {'api_key': self.api_key, 'base_url': base_url}
if base_url and custom_headers:
client_kwargs['default_headers'] = custom_headers

self.client = AsyncAnthropic(**client_kwargs)

@trace_llm_call(provider='anthropic')
async def generate(
self,
messages: List[Dict[str, str]],
functions: Optional[List[Dict[str, Any]]] = None,
output_schema: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Dict[str, Any]:
# Convert messages to Claude format
system_message = next(
@@ -54,21 +61,22 @@ async def generate(
)

try:
kwargs = {
anthropic_kwargs = {
'model': self.model,
'messages': conversation,
'temperature': self.temperature,
'max_tokens': self.kwargs.get('max_tokens', 1024),
**self.kwargs,
**kwargs,
}

if system_message:
kwargs['system'] = system_message
anthropic_kwargs['system'] = system_message

if functions:
kwargs['tools'] = functions
anthropic_kwargs['tools'] = functions

response = await self.client.messages.create(**kwargs)
response = await self.client.messages.create(**anthropic_kwargs)

# Record token usage if available
if hasattr(response, 'usage') and response.usage:
@@ -117,6 +125,7 @@ async def stream(
self,
messages: List[Dict[str, str]],
functions: Optional[List[Dict[str, Any]]] = None,
**kwargs,
) -> AsyncIterator[Dict[str, Any]]:
"""Stream partial responses from the LLM as they are generated"""
# Convert messages to Claude format
@@ -134,21 +143,22 @@ async def stream(
}
)

kwargs = {
anthropic_kwargs = {
'model': self.model,
'messages': conversation,
'temperature': self.temperature,
'max_tokens': self.kwargs.get('max_tokens', 1024),
**self.kwargs,
**kwargs,
}

if system_message:
kwargs['system'] = system_message
anthropic_kwargs['system'] = system_message

if functions:
kwargs['tools'] = functions
anthropic_kwargs['tools'] = functions
# Use Anthropic SDK streaming API and yield text deltas
async with self.client.messages.stream(**kwargs) as stream:
async with self.client.messages.stream(**anthropic_kwargs) as stream:
async for event in stream:
if (
getattr(event, 'type', None) == 'content_block_delta'
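A minimal usage sketch of the updated Anthropic wrapper under the new proxy path; the endpoint, token, and header values below are placeholders for illustration, not part of this PR.

from flo_ai.llm.anthropic_llm import Anthropic

# Hypothetical proxy endpoint and credentials, shown only to exercise the new arguments.
llm = Anthropic(
    model='claude-3-5-sonnet-20241022',
    api_key='proxy-issued-token',
    base_url='https://llm-proxy.example.com/anthropic',
    custom_headers={'X-Rootflo-Key': 'app-key'},
    max_tokens=2048,
)
# custom_headers are forwarded as default_headers only when base_url is set (proxy scenario);
# a direct Anthropic connection ignores them.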
30 changes: 25 additions & 5 deletions flo_ai/flo_ai/llm/gemini_llm.py
@@ -22,19 +22,34 @@ def __init__(
temperature: float = 0.7,
api_key: Optional[str] = None,
base_url: str = None,
custom_headers: Optional[Dict[str, str]] = None,
**kwargs,
):
super().__init__(model, api_key, temperature, **kwargs)
self.client = (
genai.Client(api_key=self.api_key) if self.api_key else genai.Client()
)
# Configure http_options for proxy or custom base_url
http_options = {'base_url': base_url} if base_url else {}
if base_url and self.api_key:
# For custom base_url (proxy), set Authorization header explicitly
http_options['headers'] = {'Authorization': f'Bearer {self.api_key}'}
# Merge custom headers if provided (proxy scenario)
if custom_headers:
http_options['headers'].update(custom_headers)

# Initialize client based on configuration
if http_options:
self.client = genai.Client(http_options=http_options)
elif self.api_key:
self.client = genai.Client(api_key=self.api_key)
else:
self.client = genai.Client()

@trace_llm_call(provider='gemini')
async def generate(
self,
messages: List[Dict[str, str]],
functions: Optional[List[Dict[str, Any]]] = None,
output_schema: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Dict[str, Any]:
# Convert messages to Gemini format
contents = []
@@ -51,10 +66,12 @@ async def generate(

try:
# Prepare generation config
# Merge instance kwargs with method kwargs
config_kwargs = {**self.kwargs, **kwargs}
generation_config = types.GenerateContentConfig(
temperature=self.temperature,
system_instruction=system_prompt,
**self.kwargs,
**config_kwargs,
)

# Add tools if functions are provided
@@ -134,6 +151,7 @@ async def stream(
self,
messages: List[Dict[str, str]],
functions: Optional[List[Dict[str, Any]]] = None,
**kwargs,
) -> AsyncIterator[Dict[str, Any]]:
"""Stream partial responses from Gemini as they are generated"""
# Convert messages to Gemini format
@@ -150,10 +168,12 @@ async def stream(
contents.append(message_content)

# Prepare generation config
# Merge instance kwargs with method kwargs
config_kwargs = {**self.kwargs, **kwargs}
generation_config = types.GenerateContentConfig(
temperature=self.temperature,
system_instruction=system_prompt,
**self.kwargs,
**config_kwargs,
)

# Add tools if functions are provided
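A rough sketch of driving the reworked Gemini constructor through a proxy; the URL and header names are assumptions for illustration only.

from flo_ai.llm.gemini_llm import Gemini

llm = Gemini(
    model='gemini-2.0-flash',
    api_key='proxy-issued-token',
    base_url='https://llm-proxy.example.com/gemini',
    custom_headers={'X-Rootflo-Key': 'app-key'},
)
# With base_url set, the client is built as genai.Client(http_options={...}), where the
# headers carry 'Authorization: Bearer <api_key>' plus the merged custom_headers;
# without base_url the previous api_key-only construction is kept.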
25 changes: 20 additions & 5 deletions flo_ai/flo_ai/llm/openai_llm.py
@@ -19,21 +19,33 @@ def __init__(
api_key: str = None,
temperature: float = 0.7,
base_url: str = None,
custom_headers: Optional[Dict[str, str]] = None,
**kwargs,
):
super().__init__(
model=model, api_key=api_key, temperature=temperature, **kwargs
)
self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
# Add custom headers if base_url is provided (proxy scenario)
client_kwargs = {'api_key': api_key, 'base_url': base_url}
if base_url and custom_headers:
client_kwargs['default_headers'] = custom_headers

self.client = AsyncOpenAI(**client_kwargs)
self.model = model
self.kwargs = kwargs

@trace_llm_call(provider='openai')
async def generate(
self, messages: list[dict], output_schema: dict = None, **kwargs
self,
messages: list[dict],
functions: Optional[List[Dict[str, Any]]] = None,
output_schema: dict = None,
**kwargs,
) -> Any:
# Convert output_schema to OpenAI format if provided
# Handle structured output vs tool calling
# Priority: output_schema takes precedence over functions for structured output
if output_schema:
# Convert output_schema to OpenAI format for structured output
kwargs['response_format'] = {'type': 'json_object'}
kwargs['functions'] = [
{
@@ -57,14 +69,17 @@ async def generate(
'content': 'Please provide your response in JSON format according to the specified schema.',
},
)
elif functions:
# Use functions for tool calling when output_schema is not provided
kwargs['functions'] = functions

# Prepare OpenAI API parameters
openai_kwargs = {
'model': self.model,
'messages': messages,
'temperature': self.temperature,
**kwargs,
**self.kwargs,
**kwargs,
}

# Make the API call
@@ -112,8 +127,8 @@ async def stream(
'messages': messages,
'temperature': self.temperature,
'stream': True,
**kwargs,
**self.kwargs,
**kwargs,
}

if functions:
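A short sketch of the OpenAI wrapper after this change; values are illustrative, and the generate call is shown as a comment since it needs an async context and real messages/tools.

from flo_ai.llm.openai_llm import OpenAI

llm = OpenAI(
    model='gpt-4o-mini',
    api_key='proxy-issued-token',
    base_url='https://llm-proxy.example.com/openai',
    custom_headers={'X-Rootflo-Key': 'app-key'},
    max_tokens=512,  # constructor-level default
)

# Per-call kwargs are now merged after self.kwargs, so a call like
#   await llm.generate(messages, functions=tools, max_tokens=2048)
# overrides the constructor default; when both output_schema and functions are
# supplied, output_schema still takes precedence.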
149 changes: 149 additions & 0 deletions flo_ai/flo_ai/llm/rootflo_llm.py
@@ -0,0 +1,149 @@
from enum import Enum
from typing import AsyncIterator, Dict, Any, List, Optional
from datetime import datetime, timedelta
import jwt
from .base_llm import BaseLLM, ImageMessage
from .openai_llm import OpenAI
from .gemini_llm import Gemini
from .anthropic_llm import Anthropic
from flo_ai.tool.base_tool import Tool


class LLMProvider(Enum):
"""Enum for supported LLM providers"""

OPENAI = 'openai'
GEMINI = 'gemini'
ANTHROPIC = 'anthropic'


class RootFloLLM(BaseLLM):
"""
Proxy LLM class that routes to different SDK implementations based on type.
Acts as a unified interface to OpenAI, Gemini, and Anthropic SDKs via a proxy URL.
"""

def __init__(
self,
base_url: str,
model_id: str,
llm_model: str,
llm_provider: LLMProvider,
app_key: str,
app_secret: str,
issuer: str,
audience: str,
temperature: float = 0.7,
**kwargs,
):
"""
Initialize RootFloLLM proxy.

Args:
base_url: The base URL of the proxy server
model_id: The model identifier appended to the proxy base URL
llm_model: The model name passed to the underlying provider SDK
llm_provider: Type of LLM SDK to use (LLMProvider enum)
app_key: Application key for X-Rootflo-Key header
app_secret: Application secret for JWT signing
issuer: JWT issuer claim
audience: JWT audience claim
temperature: Temperature parameter for generation
**kwargs: Additional parameters to pass to the underlying SDK
"""
# Generate JWT token
now = datetime.now()
payload = {
'iss': issuer,
'aud': audience,
'iat': int(now.timestamp()),
'exp': int((now + timedelta(seconds=3600)).timestamp()),
'role_id': 'floconsole-service',
'user_id': 'service',
'service_auth': True,
}
service_token = jwt.encode(payload, app_secret, algorithm='HS256')
api_token = f'fc_{service_token}'

super().__init__(
model=llm_model, api_key=api_token, temperature=temperature, **kwargs
)

self.base_url = base_url
self.model_id = model_id
self.llm_provider = llm_provider

# Construct full URL
full_url = f'{base_url}/{model_id}'

# Prepare custom headers for proxy authentication
custom_headers = {'X-Rootflo-Key': app_key}

# Instantiate appropriate SDK wrapper based on llm_provider
if llm_provider == LLMProvider.OPENAI:
self._llm = OpenAI(
model=llm_model,
base_url=full_url,
api_key=api_token,
temperature=temperature,
custom_headers=custom_headers,
**kwargs,
)
elif llm_provider == LLMProvider.ANTHROPIC:
self._llm = Anthropic(
model=llm_model,
base_url=full_url,
api_key=api_token,
temperature=temperature,
custom_headers=custom_headers,
**kwargs,
)
elif llm_provider == LLMProvider.GEMINI:
# Gemini SDK - pass base_url which will be handled via http_options
self._llm = Gemini(
model=llm_model,
api_key=api_token,
temperature=temperature,
base_url=full_url,
custom_headers=custom_headers,
**kwargs,
)
else:
raise ValueError(f'Unsupported LLM provider: {llm_provider}')

async def generate(
self,
messages: List[Dict[str, str]],
functions: Optional[List[Dict[str, Any]]] = None,
output_schema: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Dict[str, Any]:
"""Generate a response from the LLM"""
return await self._llm.generate(
messages, functions=functions, output_schema=output_schema, **kwargs
)

async def stream(
self,
messages: List[Dict[str, Any]],
functions: Optional[List[Dict[str, Any]]] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
"""Generate a streaming response from the LLM"""
async for chunk in self._llm.stream(messages, functions=functions, **kwargs):
yield chunk

def get_message_content(self, response: Any) -> str:
"""Extract message content from response"""
return self._llm.get_message_content(response)

def format_tool_for_llm(self, tool: 'Tool') -> Dict[str, Any]:
"""Format a tool for the specific LLM's API"""
return self._llm.format_tool_for_llm(tool)

def format_tools_for_llm(self, tools: List['Tool']) -> List[Dict[str, Any]]:
"""Format a list of tools for the specific LLM's API"""
return self._llm.format_tools_for_llm(tools)

def format_image_in_message(self, image: ImageMessage) -> str:
"""Format a image in the message"""
return self._llm.format_image_in_message(image)
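A minimal sketch of wiring the new RootFloLLM proxy class; every value below (URLs, keys, model id, issuer, audience) is a placeholder, not a real credential or endpoint.

from flo_ai.llm.rootflo_llm import RootFloLLM, LLMProvider

llm = RootFloLLM(
    base_url='https://llm-proxy.example.com/models',
    model_id='mdl_123',
    llm_model='gpt-4o-mini',
    llm_provider=LLMProvider.OPENAI,
    app_key='app-key',
    app_secret='app-secret',
    issuer='rootflo',
    audience='rootflo-llm-proxy',
    temperature=0.2,
)
# The constructor signs a one-hour HS256 JWT with app_secret, prefixes it with 'fc_',
# and passes it as the api_key of the selected SDK wrapper; requests target
# f'{base_url}/{model_id}' and carry the X-Rootflo-Key header for proxy authentication.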