diff --git a/flo_ai/flo_ai/llm/gemini_llm.py b/flo_ai/flo_ai/llm/gemini_llm.py index 55065527..372c44ee 100644 --- a/flo_ai/flo_ai/llm/gemini_llm.py +++ b/flo_ai/flo_ai/llm/gemini_llm.py @@ -1,4 +1,5 @@ import base64 +import asyncio from typing import Dict, Any, List, Optional, AsyncIterator from google import genai from google.genai import types @@ -66,8 +67,9 @@ async def generate( generation_config.response_mime_type = 'application/json' generation_config.response_schema = output_schema - # Make the API call - response = self.client.models.generate_content( + # Make the API call (run in thread pool to avoid blocking event loop) + response = await asyncio.to_thread( + self.client.models.generate_content, model=self.model, contents=contents, config=generation_config, @@ -159,14 +161,26 @@ async def stream( tools = types.Tool(function_declarations=functions) generation_config.tools = [tools] - # Stream the API call - stream = self.client.models.generate_content_stream( + # Get stream in thread to avoid blocking the initial call + stream = await asyncio.to_thread( + self.client.models.generate_content_stream, model=self.model, contents=contents, config=generation_config, ) - for chunk in stream: + # Helper to get next chunk in thread pool + def get_next_chunk(): + try: + return next(stream) + except StopIteration: + return None + + # Iterate over synchronous stream without blocking event loop + while True: + chunk = await asyncio.to_thread(get_next_chunk) + if chunk is None: + break if hasattr(chunk, 'text') and chunk.text: yield {'content': chunk.text}