diff --git a/flo_ai/flo_ai/llm/anthropic_llm.py b/flo_ai/flo_ai/llm/anthropic_llm.py
index aa3e9c2f..308d177d 100644
--- a/flo_ai/flo_ai/llm/anthropic_llm.py
+++ b/flo_ai/flo_ai/llm/anthropic_llm.py
@@ -1,4 +1,4 @@
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any, List, Optional, AsyncIterator
 from anthropic import AsyncAnthropic
 import json
 from .base_llm import BaseLLM, ImageMessage
@@ -77,6 +77,51 @@ async def generate(
         except Exception as e:
             raise Exception(f'Error in Claude API call: {str(e)}')

+    async def stream(
+        self,
+        messages: List[Dict[str, str]],
+        functions: Optional[List[Dict[str, Any]]] = None,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        """Stream partial responses from the LLM as they are generated"""
+        # Convert messages to Claude format
+        system_message = next(
+            (msg['content'] for msg in messages if msg['role'] == 'system'), None
+        )
+
+        conversation = []
+        for msg in messages:
+            if msg['role'] != 'system':
+                conversation.append(
+                    {
+                        'role': 'assistant' if msg['role'] == 'assistant' else 'user',
+                        'content': msg['content'],
+                    }
+                )
+
+        kwargs = {
+            'model': self.model,
+            'messages': conversation,
+            'temperature': self.temperature,
+            'max_tokens': self.kwargs.get('max_tokens', 1024),
+            **self.kwargs,
+        }
+
+        if system_message:
+            kwargs['system'] = system_message
+
+        if functions:
+            kwargs['tools'] = functions
+        # Use Anthropic SDK streaming API and yield text deltas
+        async with self.client.messages.stream(**kwargs) as stream:
+            async for event in stream:
+                if (
+                    getattr(event, 'type', None) == 'content_block_delta'
+                    and hasattr(event, 'delta')
+                    and getattr(event.delta, 'type', None) == 'text_delta'
+                    and hasattr(event.delta, 'text')
+                ):
+                    yield {'content': event.delta.text}
+
     def get_message_content(self, response: Any) -> str:
         """Extract message content from response"""
         if isinstance(response, dict):
diff --git a/flo_ai/flo_ai/llm/base_llm.py b/flo_ai/flo_ai/llm/base_llm.py
index d7910e54..6719088c 100644
--- a/flo_ai/flo_ai/llm/base_llm.py
+++ b/flo_ai/flo_ai/llm/base_llm.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any, List, Optional, AsyncIterator
 from flo_ai.tool.base_tool import Tool
 from flo_ai.utils.document_processor import get_default_processor
 from flo_ai.utils.logger import logger
@@ -34,6 +34,15 @@ async def generate(
         """Generate a response from the LLM"""
         pass

+    @abstractmethod
+    async def stream(
+        self,
+        messages: List[Dict[str, str]],
+        functions: Optional[List[Dict[str, Any]]] = None,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        """Stream partial responses from the LLM as they are generated"""
+        pass
+
     async def get_function_call(
         self, response: Dict[str, Any]
     ) -> Optional[Dict[str, Any]]:
diff --git a/flo_ai/flo_ai/llm/gemini_llm.py b/flo_ai/flo_ai/llm/gemini_llm.py
index 55885659..5f891903 100644
--- a/flo_ai/flo_ai/llm/gemini_llm.py
+++ b/flo_ai/flo_ai/llm/gemini_llm.py
@@ -1,4 +1,4 @@
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any, List, Optional, AsyncIterator
 from google import genai
 from google.genai import types
 from .base_llm import BaseLLM, ImageMessage
@@ -89,6 +89,48 @@ async def generate(
         except Exception as e:
             raise Exception(f'Error in Gemini API call: {str(e)}')

+    async def stream(
+        self,
+        messages: List[Dict[str, str]],
+        functions: Optional[List[Dict[str, Any]]] = None,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        """Stream partial responses from Gemini as they are generated"""
+        # Convert messages to Gemini format
+        contents = []
+        system_prompt = ''
+
+        for msg in messages:
+            role = msg['role']
+            message_content = msg['content']
+
+            if role == 'system':
+                system_prompt += f'{message_content}\n'
+            else:
+                contents.append(message_content)
+
+        # Prepare generation config
+        generation_config = types.GenerateContentConfig(
+            temperature=self.temperature,
+            system_instruction=system_prompt,
+            **self.kwargs,
+        )
+
+        # Add tools if functions are provided
+        if functions:
+            tools = types.Tool(function_declarations=functions)
+            generation_config.tools = [tools]
+
+        # Stream the API call
+        stream = self.client.models.generate_content_stream(
+            model=self.model,
+            contents=contents,
+            config=generation_config,
+        )
+
+        for chunk in stream:
+            if hasattr(chunk, 'text') and chunk.text:
+                yield {'content': chunk.text}
+
     def get_message_content(self, response: Any) -> str:
         """Extract message content from response"""
         if isinstance(response, dict):
diff --git a/flo_ai/flo_ai/llm/ollama_llm.py b/flo_ai/flo_ai/llm/ollama_llm.py
index d4e12602..8e7015be 100644
--- a/flo_ai/flo_ai/llm/ollama_llm.py
+++ b/flo_ai/flo_ai/llm/ollama_llm.py
@@ -1,4 +1,4 @@
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any, List, Optional, AsyncIterator
 import aiohttp
 import json
 from .base_llm import BaseLLM, ImageMessage
@@ -65,6 +65,62 @@ async def generate(
             'function_call': result.get('function_call'),
         }

+    async def stream(
+        self,
+        messages: List[Dict[str, str]],
+        functions: Optional[List[Dict[str, Any]]] = None,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        """Stream partial responses from the hosted Ollama service.
+
+        Note: For streaming, do not include the 'stream' flag in payload; the
+        service defaults to streamed output.
+        """
+        # Convert messages to Ollama prompt format
+        prompt = ''
+        for msg in messages:
+            role = msg['role']
+            content = msg['content']
+            if role == 'system':
+                prompt += f'System: {content}\n'
+            elif role == 'user':
+                prompt += f'User: {content}\n'
+            elif role == 'assistant':
+                prompt += f'Assistant: {content}\n'
+
+        # Prepare request payload without 'stream' key for streaming
+        payload = {
+            'model': self.model,
+            'prompt': prompt,
+            'temperature': self.temperature,
+            **self.kwargs,
+        }
+
+        if functions:
+            payload['functions'] = functions
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                f'{self.base_url}/api/generate', json=payload
+            ) as response:
+                if response.status != 200:
+                    raise Exception(f'Ollama API error: {await response.text()}')
+
+                async for raw_line in response.content:
+                    line = raw_line.decode('utf-8').strip()
+                    if not line:
+                        continue
+                    try:
+                        data = json.loads(line)
+                    except Exception:
+                        # Skip non-JSON lines
+                        continue
+
+                    if 'response' in data and data['response']:
+                        yield {'content': data['response']}
+
+                    if data.get('done') is True:
+                        break
+
     def get_message_content(self, response: Any) -> str:
         """Extract message content from response"""
         if isinstance(response, dict):
diff --git a/flo_ai/flo_ai/llm/openai_llm.py b/flo_ai/flo_ai/llm/openai_llm.py
index 383ed916..3531f500 100644
--- a/flo_ai/flo_ai/llm/openai_llm.py
+++ b/flo_ai/flo_ai/llm/openai_llm.py
@@ -1,4 +1,4 @@
-from typing import Dict, Any, List
+from typing import Dict, Any, List, AsyncIterator, Optional
 from openai import AsyncOpenAI
 from .base_llm import BaseLLM, ImageMessage
 from flo_ai.tool.base_tool import Tool
@@ -65,6 +65,38 @@ async def generate(
         # Return the full message object instead of just the content
         return message

+    async def stream(
+        self,
+        messages: List[Dict[str, Any]],
+        functions: Optional[List[Dict[str, Any]]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        """Stream partial responses from OpenAI Chat Completions API."""
+        # Prepare OpenAI API parameters
+        openai_kwargs = {
+            'model': self.model,
+            'messages': messages,
+            'temperature': self.temperature,
+            'stream': True,
+            **kwargs,
+            **self.kwargs,
+        }
+
+        if functions:
+            openai_kwargs['functions'] = functions
+
+        # Stream the API call and yield content deltas
+        response = await self.client.chat.completions.create(**openai_kwargs)
+        async for chunk in response:
+            choices = getattr(chunk, 'choices', []) or []
+            for choice in choices:
+                delta = getattr(choice, 'delta', None)
+                if delta is None:
+                    continue
+                content = getattr(delta, 'content', None)
+                if content:
+                    yield {'content': content}
+
     def get_message_content(self, response: Dict[str, Any]) -> str:
         # Handle both string responses and message objects
         if isinstance(response, str):
diff --git a/flo_ai/flo_ai/llm/openai_vllm.py b/flo_ai/flo_ai/llm/openai_vllm.py
index eb4a7c0b..37ffa068 100644
--- a/flo_ai/flo_ai/llm/openai_vllm.py
+++ b/flo_ai/flo_ai/llm/openai_vllm.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, AsyncIterator, Dict, List, Optional
 from .openai_llm import OpenAI


@@ -18,6 +18,7 @@ def __init__(
             base_url=base_url,
             **kwargs,
         )
+        # Store base_url attribute
         self.base_url = base_url


@@ -65,3 +66,32 @@ async def generate(

         # Return the full message object instead of just the content
         return message
+
+    async def stream(
+        self,
+        messages: List[Dict[str, Any]],
+        functions: Optional[List[Dict[str, Any]]] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        """Stream partial responses from vLLM-hosted OpenAI-compatible endpoint."""
+        vllm_openai_kwargs = {
+            'model': self.model,
+            'messages': messages,
+            'temperature': self.temperature,
+            'stream': True,
+            **kwargs,
+            **self.kwargs,
+        }
+
+        if functions:
+            vllm_openai_kwargs['functions'] = functions
+        response = await self.client.chat.completions.create(**vllm_openai_kwargs)
+        async for chunk in response:
+            choices = getattr(chunk, 'choices', []) or []
+            for choice in choices:
+                delta = getattr(choice, 'delta', None)
+                if delta is None:
+                    continue
+                content = getattr(delta, 'content', None)
+                if content:
+                    yield {'content': content}
diff --git a/flo_ai/tests/test_anthropic_llm.py b/flo_ai/tests/test_anthropic_llm.py
index 975b29f9..657c08b2 100644
--- a/flo_ai/tests/test_anthropic_llm.py
+++ b/flo_ai/tests/test_anthropic_llm.py
@@ -405,3 +405,111 @@ def test_anthropic_base_url_handling(self):
         # Test without base URL
         llm = Anthropic()
         assert not hasattr(llm, 'base_url')
+
+    @pytest.mark.asyncio
+    async def test_anthropic_stream_basic(self):
+        """Test basic stream method without functions."""
+        llm = Anthropic(model='claude-3-5-sonnet-20240620')
+
+        # Mock streaming events
+        mock_delta1 = Mock()
+        mock_delta1.type = 'text_delta'
+        mock_delta1.text = 'Hello'
+
+        mock_delta2 = Mock()
+        mock_delta2.type = 'text_delta'
+        mock_delta2.text = ', world!'
+
+        mock_event1 = Mock()
+        mock_event1.type = 'content_block_delta'
+        mock_event1.delta = mock_delta1
+
+        mock_event2 = Mock()
+        mock_event2.type = 'content_block_delta'
+        mock_event2.delta = mock_delta2
+
+        # Create a proper async iterator
+        async def async_iter():
+            yield mock_event1
+            yield mock_event2
+
+        # Mock the streaming context manager
+        mock_stream = AsyncMock()
+        mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
+        mock_stream.__aexit__ = AsyncMock(return_value=None)
+        mock_stream.__aiter__ = Mock(return_value=async_iter())
+
+        llm.client = Mock()
+        llm.client.messages.stream = Mock(return_value=mock_stream)
+
+        messages = [{'role': 'user', 'content': 'Hello'}]
+
+        # Collect streaming results
+        results = []
+        async for chunk in llm.stream(messages):
+            results.append(chunk)
+
+        # Verify the API call
+        llm.client.messages.stream.assert_called_once()
+        call_args = llm.client.messages.stream.call_args[1]
+
+        assert call_args['model'] == 'claude-3-5-sonnet-20240620'
+        assert call_args['messages'] == messages
+        assert call_args['temperature'] == 0.7
+        assert call_args['max_tokens'] == 1024
+
+        # Verify the streaming results
+        assert len(results) == 2
+        assert results[0] == {'content': 'Hello'}
+        assert results[1] == {'content': ', world!'}
+
+    @pytest.mark.asyncio
+    async def test_anthropic_stream_with_functions(self):
+        """Test stream method with functions (tools)."""
+        llm = Anthropic(model='claude-3-5-sonnet-20240620')
+
+        functions = [
+            {
+                'type': 'custom',
+                'name': 'test_function',
+                'description': 'A test function',
+                'input_schema': {'type': 'object'},
+            }
+        ]
+
+        # Mock streaming events
+        mock_delta = Mock()
+        mock_delta.type = 'text_delta'
+        mock_delta.text = 'I will use the function'
+
+        mock_event = Mock()
+        mock_event.type = 'content_block_delta'
+        mock_event.delta = mock_delta
+
+        # Create a proper async iterator
+        async def async_iter():
+            yield mock_event
+
+        # Mock the streaming context manager
+        mock_stream = AsyncMock()
+        mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
+        mock_stream.__aexit__ = AsyncMock(return_value=None)
+        mock_stream.__aiter__ = Mock(return_value=async_iter())
+
+        llm.client = Mock()
+        llm.client.messages.stream = Mock(return_value=mock_stream)
+
+        messages = [{'role': 'user', 'content': 'Use the function'}]
+
+        # Collect streaming results
+        results = []
+        async for chunk in llm.stream(messages, functions=functions):
+            results.append(chunk)
+
+        # Verify functions were passed correctly
+        call_args = llm.client.messages.stream.call_args[1]
+        assert call_args['tools'] == functions
+
+        # Verify the streaming results
+        assert len(results) == 1
+        assert results[0] == {'content': 'I will use the function'}
diff --git a/flo_ai/tests/test_base_llm.py b/flo_ai/tests/test_base_llm.py
index 147e3457..c291cbe0 100644
--- a/flo_ai/tests/test_base_llm.py
+++ b/flo_ai/tests/test_base_llm.py
@@ -33,6 +33,12 @@ async def get_function_call(self, response):
             }
         return None

+    async def stream(self, messages, functions=None):
+        async def generator():
+            yield {'response': self.response_text}
+
+        return generator()
+
     def get_message_content(self, response):
         if isinstance(response, dict):
             return response.get('content', '')
diff --git a/flo_ai/tests/test_gemini_llm.py b/flo_ai/tests/test_gemini_llm.py
index 50922502..6c53976f 100644
--- a/flo_ai/tests/test_gemini_llm.py
+++ b/flo_ai/tests/test_gemini_llm.py
@@ -488,3 +488,102 @@ def test_gemini_generation_config_creation(self):
         # This would normally be called in generate method
         # For testing, we'll just verify the config class exists
         assert mock_config is not None
+
+    @pytest.mark.asyncio
+    @patch('flo_ai.llm.gemini_llm.types.GenerateContentConfig')
+    async def test_gemini_stream_basic(self, mock_config_class):
+        """Test basic stream method without functions."""
+        llm = Gemini(model='gemini-2.5-flash')
+
+        # Mock the config
+        mock_config = Mock()
+        mock_config_class.return_value = mock_config
+
+        # Mock streaming chunks
+        mock_chunk1 = Mock()
+        mock_chunk1.text = 'Hello'
+
+        mock_chunk2 = Mock()
+        mock_chunk2.text = ', world!'
+
+        # Create a regular iterator (Gemini API returns regular iterator, not async)
+        def regular_iter():
+            yield mock_chunk1
+            yield mock_chunk2
+
+        # Mock the client response
+        llm.client = Mock()
+        llm.client.models.generate_content_stream = Mock(return_value=regular_iter())
+
+        messages = [{'role': 'user', 'content': 'Hello'}]
+
+        # Collect streaming results
+        results = []
+        async for chunk in llm.stream(messages):
+            results.append(chunk)
+
+        # Verify the API call
+        llm.client.models.generate_content_stream.assert_called_once()
+        call_args = llm.client.models.generate_content_stream.call_args
+
+        assert call_args[1]['model'] == 'gemini-2.5-flash'
+        assert call_args[1]['contents'] == ['Hello']
+        assert call_args[1]['config'] == mock_config
+
+        # Verify the streaming results
+        assert len(results) == 2
+        assert results[0] == {'content': 'Hello'}
+        assert results[1] == {'content': ', world!'}
+
+    @pytest.mark.asyncio
+    @patch('flo_ai.llm.gemini_llm.types.Tool')
+    @patch('flo_ai.llm.gemini_llm.types.GenerateContentConfig')
+    async def test_gemini_stream_with_functions(
+        self, mock_config_class, mock_tool_class
+    ):
+        """Test stream method with functions (tools)."""
+        llm = Gemini(model='gemini-2.5-flash')
+
+        functions = [
+            {
+                'name': 'test_function',
+                'description': 'A test function',
+                'parameters': {'type': 'object'},
+            }
+        ]
+
+        # Mock the tool and config
+        mock_tool = Mock()
+        mock_tool_class.return_value = mock_tool
+
+        mock_config = Mock()
+        mock_config_class.return_value = mock_config
+
+        # Mock streaming chunks
+        mock_chunk = Mock()
+        mock_chunk.text = 'I will use the function'
+
+        # Create a regular iterator (Gemini API returns regular iterator, not async)
+        def regular_iter():
+            yield mock_chunk
+
+        # Mock the client response
+        llm.client = Mock()
+        llm.client.models.generate_content_stream = Mock(return_value=regular_iter())
+
+        messages = [{'role': 'user', 'content': 'Use the function'}]
+
+        # Collect streaming results
+        results = []
+        async for chunk in llm.stream(messages, functions=functions):
+            results.append(chunk)
+
+        # Verify tool was created with function declarations
+        mock_tool_class.assert_called_once_with(function_declarations=functions)
+
+        # Verify tools were added to config
+        assert mock_config.tools == [mock_tool]
+
+        # Verify the streaming results
+        assert len(results) == 1
+        assert results[0] == {'content': 'I will use the function'}
diff --git a/flo_ai/tests/test_llm_router.py b/flo_ai/tests/test_llm_router.py
index 728ebdc9..42fa78b8 100644
--- a/flo_ai/tests/test_llm_router.py
+++ b/flo_ai/tests/test_llm_router.py
@@ -29,6 +29,12 @@ async def generate(self, messages, **kwargs):
         self.call_count += 1
         return {'response': self.response_text}

+    async def stream(self, messages, functions=None):
+        async def generator():
+            yield {'response': self.response_text}
+
+        return generator()
+
     def get_message_content(self, response):
         return response.get('response', 'researcher')
diff --git a/flo_ai/tests/test_openai_llm.py b/flo_ai/tests/test_openai_llm.py
index 3a069f5a..4a66df21 100644
--- a/flo_ai/tests/test_openai_llm.py
+++ b/flo_ai/tests/test_openai_llm.py
@@ -323,3 +323,120 @@ def test_openai_base_url_handling(self):
         # Test without base URL
         llm = OpenAI()
         assert not hasattr(llm, 'base_url')
+
+    @pytest.mark.asyncio
+    async def test_openai_stream_basic(self):
+        """Test basic stream method without functions."""
+        llm = OpenAI(model='gpt-4o-mini')
+
+        # Mock streaming chunks
+        mock_delta1 = Mock()
+        mock_delta1.content = 'Hello'
+
+        mock_delta2 = Mock()
+        mock_delta2.content = ', world!'
+
+        mock_choice1 = Mock()
+        mock_choice1.delta = mock_delta1
+
+        mock_choice2 = Mock()
+        mock_choice2.delta = mock_delta2
+
+        mock_chunk1 = Mock()
+        mock_chunk1.choices = [mock_choice1]
+
+        mock_chunk2 = Mock()
+        mock_chunk2.choices = [mock_choice2]
+
+        # Create a proper async iterator
+        async def async_iter():
+            yield mock_chunk1
+            yield mock_chunk2
+
+        # Mock the client response
+        llm.client = Mock()
+        llm.client.chat.completions.create = AsyncMock(return_value=async_iter())
+
+        messages = [{'role': 'user', 'content': 'Hello'}]
+
+        # Collect streaming results
+        results = []
+        async for chunk in llm.stream(messages):
+            results.append(chunk)
+
+        # Verify the API call
+        llm.client.chat.completions.create.assert_called_once()
+        call_args = llm.client.chat.completions.create.call_args
+
+        assert call_args[1]['model'] == 'gpt-4o-mini'
+        assert call_args[1]['messages'] == messages
+        assert call_args[1]['temperature'] == 0.7
+        assert call_args[1]['stream'] is True
+
+        # Verify the streaming results
+        assert len(results) == 2
+        assert results[0] == {'content': 'Hello'}
+        assert results[1] == {'content': ', world!'}
+
+    @pytest.mark.asyncio
+    async def test_openai_stream_with_functions(self):
+        """Test stream method with functions."""
+        llm = OpenAI(model='gpt-4o-mini')
+
+        functions = [
+            {
+                'name': 'test_function',
+                'description': 'A test function',
+                'parameters': {'type': 'object'},
+            }
+        ]
+
+        # Mock streaming chunks
+        mock_delta = Mock()
+        mock_delta.content = 'I will use the function'
+
+        mock_choice = Mock()
+        mock_choice.delta = mock_delta
+
+        mock_chunk = Mock()
+        mock_chunk.choices = [mock_choice]
+
+        # Create a proper async iterator
+        async def async_iter():
+            yield mock_chunk
+
+        # Mock the client response
+        llm.client = Mock()
+        llm.client.chat.completions.create = AsyncMock(return_value=async_iter())
+
+        messages = [{'role': 'user', 'content': 'Use the function'}]
+
+        # Collect streaming results
+        results = []
+        async for chunk in llm.stream(messages, functions=functions):
+            results.append(chunk)
+
+        # Verify functions were passed correctly
+        call_args = llm.client.chat.completions.create.call_args
+        assert call_args[1]['functions'] == functions
+
+        # Verify the streaming results
+        assert len(results) == 1
+        assert results[0] == {'content': 'I will use the function'}
+
+    @pytest.mark.asyncio
+    async def test_openai_stream_error_handling(self):
+        """Test error handling in stream method."""
+        llm = OpenAI(model='gpt-4o-mini')
+
+        # Mock client to raise an exception
+        llm.client = Mock()
+        llm.client.chat.completions.create = AsyncMock(
+            side_effect=Exception('Streaming API Error')
+        )
+
+        messages = [{'role': 'user', 'content': 'Hello'}]
+
+        with pytest.raises(Exception, match='Streaming API Error'):
+            async for chunk in llm.stream(messages):
+                pass
diff --git a/flo_ai/tests/test_openai_vllm.py b/flo_ai/tests/test_openai_vllm.py
index a7938c08..409ee412 100644
--- a/flo_ai/tests/test_openai_vllm.py
+++ b/flo_ai/tests/test_openai_vllm.py
@@ -573,3 +573,129 @@ def test_openai_vllm_initialization_order(self, mock_async_openai):
         assert llm.model == 'test-model'
         assert llm.base_url == 'https://test.vllm.com'
         assert llm.client == mock_client
+
+    @pytest.mark.asyncio
+    @patch('flo_ai.llm.openai_llm.AsyncOpenAI')
+    async def test_openai_vllm_stream_basic(self, mock_async_openai):
+        """Test basic stream method without functions."""
+        mock_client = Mock()
+        mock_async_openai.return_value = mock_client
+
+        llm = OpenAIVLLM(base_url='https://api.vllm.com', model='gpt-4o-mini')
+
+        # Mock streaming chunks
+        mock_delta1 = Mock()
+        mock_delta1.content = 'Hello'
+
+        mock_delta2 = Mock()
+        mock_delta2.content = ', world!'
+
+        mock_choice1 = Mock()
+        mock_choice1.delta = mock_delta1
+
+        mock_choice2 = Mock()
+        mock_choice2.delta = mock_delta2
+
+        mock_chunk1 = Mock()
+        mock_chunk1.choices = [mock_choice1]
+
+        mock_chunk2 = Mock()
+        mock_chunk2.choices = [mock_choice2]
+
+        # Create a proper async iterator
+        async def async_iter():
+            yield mock_chunk1
+            yield mock_chunk2
+
+        # Mock the client response
+        llm.client.chat.completions.create = AsyncMock(return_value=async_iter())
+
+        messages = [{'role': 'user', 'content': 'Hello'}]
+
+        # Collect streaming results
+        results = []
+        async for chunk in llm.stream(messages):
+            results.append(chunk)
+
+        # Verify the API call
+        llm.client.chat.completions.create.assert_called_once()
+        call_args = llm.client.chat.completions.create.call_args
+
+        assert call_args[1]['model'] == 'gpt-4o-mini'
+        assert call_args[1]['messages'] == messages
+        assert call_args[1]['temperature'] == 0.7
+        assert call_args[1]['stream'] is True
+
+        # Verify the streaming results
+        assert len(results) == 2
+        assert results[0] == {'content': 'Hello'}
+        assert results[1] == {'content': ', world!'}
+
+    @pytest.mark.asyncio
+    @patch('flo_ai.llm.openai_llm.AsyncOpenAI')
+    async def test_openai_vllm_stream_with_functions(self, mock_async_openai):
+        """Test stream method with functions."""
+        mock_client = Mock()
+        mock_async_openai.return_value = mock_client
+
+        llm = OpenAIVLLM(base_url='https://api.vllm.com', model='gpt-4o-mini')
+
+        functions = [
+            {
+                'name': 'test_function',
+                'description': 'A test function',
+                'parameters': {'type': 'object'},
+            }
+        ]
+
+        # Mock streaming chunks
+        mock_delta = Mock()
+        mock_delta.content = 'I will use the function'
+
+        mock_choice = Mock()
+        mock_choice.delta = mock_delta
+
+        mock_chunk = Mock()
+        mock_chunk.choices = [mock_choice]
+
+        # Create a proper async iterator
+        async def async_iter():
+            yield mock_chunk
+
+        # Mock the client response
+        llm.client.chat.completions.create = AsyncMock(return_value=async_iter())
+
+        messages = [{'role': 'user', 'content': 'Use the function'}]
+
+        # Collect streaming results
+        results = []
+        async for chunk in llm.stream(messages, functions=functions):
+            results.append(chunk)
+
+        # Verify functions were passed correctly
+        call_args = llm.client.chat.completions.create.call_args
+        assert call_args[1]['functions'] == functions
+
+        # Verify the streaming results
+        assert len(results) == 1
+        assert results[0] == {'content': 'I will use the function'}
+
+    @pytest.mark.asyncio
+    @patch('flo_ai.llm.openai_llm.AsyncOpenAI')
+    async def test_openai_vllm_stream_error_handling(self, mock_async_openai):
+        """Test error handling in stream method."""
+        mock_client = Mock()
+        mock_async_openai.return_value = mock_client
+
+        llm = OpenAIVLLM(base_url='https://api.vllm.com', model='gpt-4o-mini')
+
+        # Mock client to raise an exception
+        llm.client.chat.completions.create = AsyncMock(
+            side_effect=Exception('Streaming API Error')
+        )
+
+        messages = [{'role': 'user', 'content': 'Hello'}]
+
+        with pytest.raises(Exception, match='Streaming API Error'):
+            async for chunk in llm.stream(messages):
+                pass
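
For reference, a minimal sketch of how a caller might consume the new stream() interface added by this patch. It is not part of the diff: the import path and OpenAI(model=...) constructor mirror the code and tests above, while the model name is illustrative and provider credentials are assumed to be configured via the environment.

import asyncio

from flo_ai.llm.openai_llm import OpenAI  # any BaseLLM subclass implementing stream()


async def main() -> None:
    # Illustrative model name; substitute whatever your deployment supports.
    llm = OpenAI(model='gpt-4o-mini')

    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'Say hello in one short sentence.'},
    ]

    # Each provider implementation yields dicts shaped like {'content': '<text delta>'},
    # so the consumer loop is the same regardless of the backing LLM.
    async for chunk in llm.stream(messages):
        print(chunk['content'], end='', flush=True)
    print()


if __name__ == '__main__':
    asyncio.run(main())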