Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 10 additions & 29 deletions flo_ai/flo_ai/helpers/llm_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _create_vertexai_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM':
model=model_name,
project=project,
location=location,
base_url=base_url,
base_url=str(base_url),
)

@staticmethod
Expand Down Expand Up @@ -173,8 +173,8 @@ def _create_openai_vllm_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM'

return OpenAIVLLM(
model=model_name,
base_url=base_url,
api_key=api_key,
base_url=str(base_url),
api_key=str(api_key),
temperature=temperature,
)

Expand All @@ -201,34 +201,15 @@ def _create_rootflo_llm(model_config: Dict[str, Any], **kwargs) -> 'BaseLLM':
audience = kwargs.get('audience') or os.getenv('ROOTFLO_AUDIENCE')
access_token = kwargs.get('access_token') # Optional, from kwargs only

# Validate required parameters based on auth method
if not access_token:
# JWT auth flow - requires all parameters
required_params = {
'base_url': base_url,
'app_key': app_key,
'app_secret': app_secret,
'issuer': issuer,
'audience': audience,
}
missing = [k for k, v in required_params.items() if not v]

if missing:
raise ValueError(
f'RootFlo configuration incomplete. Missing required parameters: {", ".join(missing)}. '
f'These can be provided via kwargs or environment variables '
f'(ROOTFLO_BASE_URL, ROOTFLO_APP_KEY, ROOTFLO_APP_SECRET, ROOTFLO_ISSUER, ROOTFLO_AUDIENCE).'
)
else:
# Access token flow - only needs base_url
if not base_url:
raise ValueError(
'RootFlo configuration incomplete. Missing required parameter: base_url. '
'Provide it in model_config, as a kwarg, or via ROOTFLO_BASE_URL environment variable.'
)
# Access token flow - only needs base_url
if not base_url:
raise ValueError(
'RootFlo configuration incomplete. Missing required parameter: base_url. '
'Provide it in model_config, as a kwarg, or via ROOTFLO_BASE_URL environment variable.'
)

return RootFloLLM(
base_url=base_url,
base_url=str(base_url),
model_id=model_id,
app_key=app_key,
app_secret=app_secret,
Expand Down
1 change: 1 addition & 0 deletions flo_ai/flo_ai/llm/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ async def stream(
messages: List[Dict[str, str]],
functions: Optional[List[Dict[str, Any]]] = None,
output_schema: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
"""Stream partial responses from the LLM as they are generated"""
pass
Expand Down
38 changes: 11 additions & 27 deletions flo_ai/flo_ai/llm/rootflo_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,24 +65,6 @@ def __init__(
if not base_url:
raise ValueError('base_url is required')

# Validate JWT credentials if access_token is not provided
if not access_token:
missing = []
if not app_key:
missing.append('app_key')
if not app_secret:
missing.append('app_secret')
if not issuer:
missing.append('issuer')
if not audience:
missing.append('audience')

if missing:
raise ValueError(
f'Missing required parameters for JWT generation: {", ".join(missing)}. '
f'Either provide these parameters or pass an access_token directly.'
)

# Store initialization parameters for lazy initialization
self._base_url = base_url
self._model_id = model_id
Expand Down Expand Up @@ -117,7 +99,7 @@ async def _fetch_llm_config_async(
self,
base_url: str,
model_id: str,
api_token: str,
api_token: str | None = None,
app_key: Optional[str] = None,
) -> Dict[str, Any]:
"""
Expand All @@ -136,9 +118,10 @@ async def _fetch_llm_config_async(
Exception: If API call fails or response is invalid
"""
config_url = f'{base_url}/v1/llm-inference-configs/{model_id}'
headers = {
'Authorization': f'Bearer {api_token}',
}

headers = {}
if api_token:
headers['Authorization'] = f'Bearer {api_token}'

# Only add X-Rootflo-Key header if app_key is provided
if app_key:
Expand Down Expand Up @@ -189,10 +172,11 @@ async def _ensure_initialized(self):
if self._initialized:
return

api_token = None
# Generate or use provided access token
if self._access_token:
api_token = self._access_token
else:
elif self._app_key and self._app_secret:
now = datetime.now()
payload = {
'iss': self._issuer,
Expand Down Expand Up @@ -238,7 +222,7 @@ async def _ensure_initialized(self):
self._llm = OpenAI(
model=llm_model,
base_url=full_url,
api_key=api_token,
api_key=api_token or 'no_token',
temperature=self._temperature,
custom_headers=custom_headers,
**self._kwargs,
Expand All @@ -247,7 +231,7 @@ async def _ensure_initialized(self):
self._llm = Anthropic(
model=llm_model,
base_url=full_url,
api_key=api_token,
api_key=api_token or 'no_token',
temperature=self._temperature,
custom_headers=custom_headers,
**self._kwargs,
Expand All @@ -256,7 +240,7 @@ async def _ensure_initialized(self):
# Gemini SDK - pass base_url which will be handled via http_options
self._llm = Gemini(
model=llm_model,
api_key=api_token,
api_key=api_token or 'no_token',
temperature=self._temperature,
base_url=full_url,
custom_headers=custom_headers,
Expand All @@ -267,7 +251,7 @@ async def _ensure_initialized(self):
self._llm = OpenAIVLLM(
model=llm_model,
base_url=full_url,
api_key=api_token,
api_key=api_token or 'no_token',
temperature=self._temperature,
custom_headers=custom_headers,
**self._kwargs,
Expand Down