47 changes: 46 additions & 1 deletion flo_ai/flo_ai/llm/anthropic_llm.py
@@ -1,4 +1,4 @@
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any, List, Optional, AsyncIterator
from anthropic import AsyncAnthropic
import json
from .base_llm import BaseLLM, ImageMessage
@@ -77,6 +77,51 @@ async def generate(
except Exception as e:
raise Exception(f'Error in Claude API call: {str(e)}')

async def stream(
self,
messages: List[Dict[str, str]],
functions: Optional[List[Dict[str, Any]]] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""Stream partial responses from the LLM as they are generated"""
# Convert messages to Claude format
system_message = next(
(msg['content'] for msg in messages if msg['role'] == 'system'), None
)

conversation = []
for msg in messages:
if msg['role'] != 'system':
conversation.append(
{
'role': 'assistant' if msg['role'] == 'assistant' else 'user',
'content': msg['content'],
}
)

kwargs = {
'model': self.model,
'messages': conversation,
'temperature': self.temperature,
'max_tokens': self.kwargs.get('max_tokens', 1024),
**self.kwargs,
}

if system_message:
kwargs['system'] = system_message

if functions:
kwargs['tools'] = functions
# Use Anthropic SDK streaming API and yield text deltas
async with self.client.messages.stream(**kwargs) as stream:
async for event in stream:
if (
getattr(event, 'type', None) == 'content_block_delta'
and hasattr(event, 'delta')
and getattr(event.delta, 'type', None) == 'text_delta'
and hasattr(event.delta, 'text')
):
yield {'content': event.delta.text}

def get_message_content(self, response: Any) -> str:
"""Extract message content from response"""
if isinstance(response, dict):
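A minimal usage sketch for the new Anthropic streaming path. The class name AnthropicLLM and its constructor arguments are assumptions inferred from this diff (only self.model, self.temperature, self.kwargs, and self.client are visible here), not confirmed API:

import asyncio
from flo_ai.llm.anthropic_llm import AnthropicLLM  # class name assumed from the file name

async def main():
    # Constructor arguments are assumed; only model/temperature usage is visible above
    llm = AnthropicLLM(model='claude-3-5-sonnet-20241022', temperature=0.7)
    async for chunk in llm.stream(
        [
            {'role': 'system', 'content': 'Answer in one short sentence.'},
            {'role': 'user', 'content': 'What does streaming buy us?'},
        ]
    ):
        # Each chunk is a dict of the form {'content': '<text delta>'}
        print(chunk['content'], end='', flush=True)

asyncio.run(main())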
11 changes: 10 additions & 1 deletion flo_ai/flo_ai/llm/base_llm.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any, List, Optional, AsyncIterator
from flo_ai.tool.base_tool import Tool
from flo_ai.utils.document_processor import get_default_processor
from flo_ai.utils.logger import logger
@@ -34,6 +34,15 @@ async def generate(
"""Generate a response from the LLM"""
pass

@abstractmethod
async def stream(
# Review comment (Member): Add an example in Readme on how to use this
self,
messages: List[Dict[str, str]],
functions: Optional[List[Dict[str, Any]]] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""Stream partial responses from the LLM as they are generated"""
pass

async def get_function_call(
self, response: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
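The review comment above asks for a README example; a provider-agnostic sketch of consuming this interface could look like the following. It assumes an already-constructed BaseLLM subclass instance and relies only on the contract visible here: stream() is an async iterator of dicts carrying a 'content' key.

import asyncio
from flo_ai.llm.base_llm import BaseLLM

async def print_stream(llm: BaseLLM) -> str:
    """Consume any BaseLLM.stream() implementation and return the full text."""
    parts = []
    async for chunk in llm.stream([{'role': 'user', 'content': 'Tell me a joke.'}]):
        parts.append(chunk['content'])  # every implementation below yields {'content': ...}
        print(chunk['content'], end='', flush=True)
    return ''.join(parts)

# asyncio.run(print_stream(some_llm))  # some_llm: any concrete subclass from this PR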
44 changes: 43 additions & 1 deletion flo_ai/flo_ai/llm/gemini_llm.py
@@ -1,4 +1,4 @@
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any, List, Optional, AsyncIterator
from google import genai
from google.genai import types
from .base_llm import BaseLLM, ImageMessage
@@ -89,6 +89,48 @@ async def generate(
except Exception as e:
raise Exception(f'Error in Gemini API call: {str(e)}')

async def stream(
self,
messages: List[Dict[str, str]],
functions: Optional[List[Dict[str, Any]]] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""Stream partial responses from Gemini as they are generated"""
# Convert messages to Gemini format
contents = []
system_prompt = ''

for msg in messages:
role = msg['role']
message_content = msg['content']

if role == 'system':
system_prompt += f'{message_content}\n'
else:
contents.append(message_content)

# Prepare generation config
generation_config = types.GenerateContentConfig(
temperature=self.temperature,
system_instruction=system_prompt,
**self.kwargs,
)

# Add tools if functions are provided
if functions:
tools = types.Tool(function_declarations=functions)
generation_config.tools = [tools]

# Stream the API call
stream = self.client.models.generate_content_stream(
model=self.model,
contents=contents,
config=generation_config,
)

for chunk in stream:
if hasattr(chunk, 'text') and chunk.text:
yield {'content': chunk.text}

def get_message_content(self, response: Any) -> str:
"""Extract message content from response"""
if isinstance(response, dict):
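The Gemini loop above forwards only chunks that expose a non-empty .text attribute. A self-contained sketch of that filtering rule, with a hypothetical FakeChunk standing in for the google-genai SDK's stream chunks:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeChunk:
    """Hypothetical stand-in for google-genai stream chunks; only .text matters here."""
    text: Optional[str] = None

def forwarded(chunks):
    # Mirrors the filter above: skip chunks without a non-empty .text
    for chunk in chunks:
        if hasattr(chunk, 'text') and chunk.text:
            yield {'content': chunk.text}

print(list(forwarded([FakeChunk('Hel'), FakeChunk(None), FakeChunk('lo')])))
# -> [{'content': 'Hel'}, {'content': 'lo'}]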
58 changes: 57 additions & 1 deletion flo_ai/flo_ai/llm/ollama_llm.py
@@ -1,4 +1,4 @@
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any, List, Optional, AsyncIterator
import aiohttp
import json
from .base_llm import BaseLLM, ImageMessage
@@ -65,6 +65,62 @@ async def generate(
'function_call': result.get('function_call'),
}

async def stream(
self,
messages: List[Dict[str, str]],
functions: Optional[List[Dict[str, Any]]] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""Stream partial responses from the hosted Ollama service.

Note: For streaming, do not include the 'stream' flag in payload; the
service defaults to streamed output.
"""
# Convert messages to Ollama prompt format
prompt = ''
for msg in messages:
role = msg['role']
content = msg['content']
if role == 'system':
prompt += f'System: {content}\n'
elif role == 'user':
prompt += f'User: {content}\n'
elif role == 'assistant':
prompt += f'Assistant: {content}\n'

# Prepare request payload without 'stream' key for streaming
payload = {
'model': self.model,
'prompt': prompt,
'temperature': self.temperature,
**self.kwargs,
}

if functions:
payload['functions'] = functions

async with aiohttp.ClientSession() as session:
async with session.post(
f'{self.base_url}/api/generate', json=payload
) as response:
if response.status != 200:
raise Exception(f'Ollama API error: {await response.text()}')

async for raw_line in response.content:
line = raw_line.decode('utf-8').strip()
if not line:
continue
try:
data = json.loads(line)
except Exception:
# Skip non-JSON lines
continue

if 'response' in data and data['response']:
yield {'content': data['response']}

if data.get('done') is True:
break

def get_message_content(self, response: Any) -> str:
"""Extract message content from response"""
if isinstance(response, dict):
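Ollama's /api/generate streams newline-delimited JSON, and the parser above reads only the 'response' and 'done' fields. A runnable sketch of that line protocol with illustrative frames (any field beyond those two keys is an assumption):

import json

raw_lines = [
    b'{"response": "Hello", "done": false}\n',
    b'\n',            # blank keep-alive line: skipped
    b'not-json\n',    # unparsable line: skipped
    b'{"response": ", world", "done": false}\n',
    b'{"response": "", "done": true}\n',  # final frame ends the stream
]

for raw_line in raw_lines:
    line = raw_line.decode('utf-8').strip()
    if not line:
        continue
    try:
        data = json.loads(line)
    except Exception:
        continue
    if data.get('response'):
        print(data['response'], end='')  # prints "Hello, world"
    if data.get('done') is True:
        break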
34 changes: 33 additions & 1 deletion flo_ai/flo_ai/llm/openai_llm.py
@@ -1,4 +1,4 @@
-from typing import Dict, Any, List
+from typing import Dict, Any, List, AsyncIterator, Optional
from openai import AsyncOpenAI
from .base_llm import BaseLLM, ImageMessage
from flo_ai.tool.base_tool import Tool
@@ -65,6 +65,38 @@ async def generate(
# Return the full message object instead of just the content
return message

async def stream(
self,
messages: List[Dict[str, Any]],
functions: Optional[List[Dict[str, Any]]] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
"""Stream partial responses from OpenAI Chat Completions API."""
# Prepare OpenAI API parameters
openai_kwargs = {
'model': self.model,
'messages': messages,
'temperature': self.temperature,
'stream': True,
**kwargs,
**self.kwargs,
}

if functions:
openai_kwargs['functions'] = functions

# Stream the API call and yield content deltas
response = await self.client.chat.completions.create(**openai_kwargs)
async for chunk in response:
choices = getattr(chunk, 'choices', []) or []
for choice in choices:
delta = getattr(choice, 'delta', None)
if delta is None:
continue
content = getattr(delta, 'content', None)
if content:
yield {'content': content}

def get_message_content(self, response: Dict[str, Any]) -> str:
# Handle both string responses and message objects
if isinstance(response, str):
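A usage sketch for the OpenAI path that stitches the yielded deltas back into a full reply. The OpenAI class name is confirmed by the openai_vllm.py import below; the constructor arguments shown are assumptions based on the attributes the method reads:

import asyncio
from flo_ai.llm.openai_llm import OpenAI

async def main():
    llm = OpenAI(model='gpt-4o-mini', temperature=0.2)  # constructor args assumed
    parts = []
    async for chunk in llm.stream([{'role': 'user', 'content': 'Stream one sentence.'}]):
        parts.append(chunk['content'])  # deltas arrive in order
    print(''.join(parts))               # the reassembled message

asyncio.run(main())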
32 changes: 31 additions & 1 deletion flo_ai/flo_ai/llm/openai_vllm.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, AsyncIterator, Dict, List, Optional
from .openai_llm import OpenAI


@@ -18,6 +18,7 @@ def __init__(
base_url=base_url,
**kwargs,
)

# Store base_url attribute
self.base_url = base_url

@@ -65,3 +66,32 @@ async def generate(

# Return the full message object instead of just the content
return message

async def stream(
self,
messages: List[Dict[str, Any]],
functions: Optional[List[Dict[str, Any]]] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
"""Stream partial responses from vLLM-hosted OpenAI-compatible endpoint."""
vllm_openai_kwargs = {
'model': self.model,
'messages': messages,
'temperature': self.temperature,
'stream': True,
**kwargs,
**self.kwargs,
}

if functions:
vllm_openai_kwargs['functions'] = functions
response = await self.client.chat.completions.create(**vllm_openai_kwargs)
async for chunk in response:
choices = getattr(chunk, 'choices', []) or []
for choice in choices:
delta = getattr(choice, 'delta', None)
if delta is None:
continue
content = getattr(delta, 'content', None)
if content:
yield {'content': content}
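And a closing sketch for the vLLM-hosted endpoint. Only the base_url constructor parameter is visible in this diff; the class name OpenAIVLLM and the model identifier are hypothetical placeholders:

import asyncio
from flo_ai.llm.openai_vllm import OpenAIVLLM  # class name assumed

async def main():
    llm = OpenAIVLLM(
        model='meta-llama/Llama-3.1-8B-Instruct',  # illustrative model id
        base_url='http://localhost:8000/v1',       # typical vLLM OpenAI-compatible endpoint
    )
    async for chunk in llm.stream([{'role': 'user', 'content': 'ping'}]):
        print(chunk['content'], end='', flush=True)

asyncio.run(main())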