diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py index 39af4b01e8f8..ffd6bc2594ad 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py @@ -20,6 +20,7 @@ MistralAIChatPromptExecutionSettings, ) from semantic_kernel.connectors.ai.mistral_ai.settings.mistral_ai_settings import MistralAISettings +from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.function_call_content import FunctionCallContent @@ -95,20 +96,24 @@ def __init__( async def get_chat_message_contents( self, chat_history: "ChatHistory", - settings: "MistralAIChatPromptExecutionSettings", # type: ignore[override] + settings: "PromptExecutionSettings", **kwargs: Any, ) -> list["ChatMessageContent"]: """Executes a chat completion request and returns the result. Args: chat_history (ChatHistory): The chat history to use for the chat completion. - settings (MistralAIChatPromptExecutionSettings): The settings to use + settings (PromptExecutionSettings): The settings to use for the chat completion request. kwargs (Dict[str, Any]): The optional arguments. Returns: List[ChatMessageContent]: The completion result(s). """ + if not isinstance(settings, MistralAIChatPromptExecutionSettings): + settings = self.get_prompt_execution_settings_from_settings(settings) + assert isinstance(settings, MistralAIChatPromptExecutionSettings) # nosec + if not settings.ai_model_id: settings.ai_model_id = self.ai_model_id @@ -128,14 +133,14 @@ async def get_chat_message_contents( async def get_streaming_chat_message_contents( self, chat_history: ChatHistory, - settings: MistralAIChatPromptExecutionSettings, # type: ignore[override] + settings: PromptExecutionSettings, **kwargs: Any, ) -> AsyncGenerator[list[StreamingChatMessageContent], Any]: """Executes a streaming chat completion request and returns the result. Args: chat_history (ChatHistory): The chat history to use for the chat completion. - settings (MistralAIChatPromptExecutionSettings): The settings to use + settings (PromptExecutionSettings): The settings to use for the chat completion request. kwargs (Dict[str, Any]): The optional arguments. @@ -143,6 +148,10 @@ async def get_streaming_chat_message_contents( List[StreamingChatMessageContent]: A stream of StreamingChatMessageContent when using Azure. """ + if not isinstance(settings, MistralAIChatPromptExecutionSettings): + settings = self.get_prompt_execution_settings_from_settings(settings) + assert isinstance(settings, MistralAIChatPromptExecutionSettings) # nosec + if not settings.ai_model_id: settings.ai_model_id = self.ai_model_id diff --git a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py index 8510fbae3ea5..ba1b0b51aa7b 100644 --- a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py +++ b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py @@ -9,18 +9,22 @@ MistralAIChatPromptExecutionSettings, ) from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_chat_completion import MistralAIChatCompletion +from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import ( + OpenAIChatPromptExecutionSettings, +) from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.exceptions import ServiceInitializationError, ServiceResponseException from semantic_kernel.functions.kernel_arguments import KernelArguments from semantic_kernel.kernel import Kernel -@pytest.mark.asyncio -async def test_complete_chat_contents(kernel: Kernel): - chat_history = MagicMock() - settings = MagicMock() - settings.number_of_responses = 1 - arguments = KernelArguments() +@pytest.fixture +def mock_settings() -> MistralAIChatPromptExecutionSettings: + return MistralAIChatPromptExecutionSettings() + + +@pytest.fixture +def mock_mistral_ai_client_completion() -> MistralAsyncClient: client = MagicMock(spec=MistralAsyncClient) chat_completion_response = AsyncMock() choices = [ @@ -28,22 +32,11 @@ async def test_complete_chat_contents(kernel: Kernel): ] chat_completion_response.choices = choices client.chat.return_value = chat_completion_response - chat_completion_base = MistralAIChatCompletion( - ai_model_id="test_model_id", service_id="test", api_key="", async_client=client - ) - - content: list[ChatMessageContent] = await chat_completion_base.get_chat_message_contents( - chat_history, settings, kernel=kernel, arguments=arguments - ) - assert content is not None + return client -@pytest.mark.asyncio -async def test_complete_chat_stream_contents(kernel: Kernel): - chat_history = MagicMock() - settings = MagicMock() - settings.ai_model_id = None - arguments = KernelArguments() +@pytest.fixture +def mock_mistral_ai_client_completion_stream() -> MistralAsyncClient: client = MagicMock(spec=MistralAsyncClient) chat_completion_response = MagicMock() choices = [ @@ -56,42 +49,70 @@ async def test_complete_chat_stream_contents(kernel: Kernel): generator_mock = MagicMock() generator_mock.__aiter__.return_value = [chat_completion_response_empty, chat_completion_response] client.chat_stream.return_value = generator_mock + return client + +@pytest.mark.asyncio +async def test_complete_chat_contents( + kernel: Kernel, + mock_settings: MistralAIChatPromptExecutionSettings, + mock_mistral_ai_client_completion: MistralAsyncClient +): + chat_history = MagicMock() + arguments = KernelArguments() chat_completion_base = MistralAIChatCompletion( - ai_model_id="test_model_id", service_id="test", api_key="", async_client=client + ai_model_id="test_model_id", service_id="test", api_key="", async_client=mock_mistral_ai_client_completion + ) + + content: list[ChatMessageContent] = await chat_completion_base.get_chat_message_contents( + chat_history, mock_settings, kernel=kernel, arguments=arguments + ) + assert content is not None + + +@pytest.mark.asyncio +async def test_complete_chat_stream_contents( + kernel: Kernel, + mock_settings: MistralAIChatPromptExecutionSettings, + mock_mistral_ai_client_completion_stream: MistralAsyncClient +): + chat_history = MagicMock() + arguments = KernelArguments() + + chat_completion_base = MistralAIChatCompletion( + ai_model_id="test_model_id", + service_id="test", api_key="", + async_client=mock_mistral_ai_client_completion_stream ) async for content in chat_completion_base.get_streaming_chat_message_contents( - chat_history, settings, kernel=kernel, arguments=arguments + chat_history, mock_settings, kernel=kernel, arguments=arguments ): assert content is not None @pytest.mark.asyncio -async def test_mistral_ai_sdk_exception(kernel: Kernel): +async def test_mistral_ai_sdk_exception(kernel: Kernel, mock_settings: MistralAIChatPromptExecutionSettings): chat_history = MagicMock() - settings = MagicMock() - settings.ai_model_id = None arguments = KernelArguments() client = MagicMock(spec=MistralAsyncClient) client.chat.side_effect = Exception("Test Exception") - client.chat_stream.side_effect = Exception("Test Exception") chat_completion_base = MistralAIChatCompletion( - ai_model_id="test_model_id", service_id="test", api_key="", async_client=client + ai_model_id="test_model_id", + service_id="test", api_key="", + async_client=client ) with pytest.raises(ServiceResponseException): await chat_completion_base.get_chat_message_contents( - chat_history, settings, kernel=kernel, arguments=arguments + chat_history, mock_settings, kernel=kernel, arguments=arguments ) @pytest.mark.asyncio -async def test_mistral_ai_sdk_exception_streaming(kernel: Kernel): +async def test_mistral_ai_sdk_exception_streaming(kernel: Kernel, mock_settings: MistralAIChatPromptExecutionSettings): chat_history = MagicMock() - settings = MagicMock() - settings.number_of_responses = 1 arguments = KernelArguments() client = MagicMock(spec=MistralAsyncClient) client.chat_stream.side_effect = Exception("Test Exception") @@ -102,7 +123,7 @@ async def test_mistral_ai_sdk_exception_streaming(kernel: Kernel): with pytest.raises(ServiceResponseException): async for content in chat_completion_base.get_streaming_chat_message_contents( - chat_history, settings, kernel=kernel, arguments=arguments + chat_history, mock_settings, kernel=kernel, arguments=arguments ): assert content is not None @@ -138,3 +159,46 @@ def test_prompt_execution_settings_class(mistralai_unit_test_env): mistral_ai_chat_completion = MistralAIChatCompletion() prompt_execution_settings = mistral_ai_chat_completion.get_prompt_execution_settings_class() assert prompt_execution_settings == MistralAIChatPromptExecutionSettings + + +@pytest.mark.asyncio +async def test_with_different_execution_settings( + kernel: Kernel, + mock_mistral_ai_client_completion: MagicMock +): + chat_history = MagicMock() + settings = OpenAIChatPromptExecutionSettings(temperature=0.2, seed=2) + arguments = KernelArguments() + chat_completion_base = MistralAIChatCompletion( + ai_model_id="test_model_id", + service_id="test", api_key="", + async_client=mock_mistral_ai_client_completion + ) + + await chat_completion_base.get_chat_message_contents( + chat_history, settings, kernel=kernel, arguments=arguments + ) + assert mock_mistral_ai_client_completion.chat.call_args.kwargs["temperature"] == 0.2 + assert mock_mistral_ai_client_completion.chat.call_args.kwargs["seed"] == 2 + + +@pytest.mark.asyncio +async def test_with_different_execution_settings_stream( + kernel: Kernel, + mock_mistral_ai_client_completion_stream: MagicMock +): + chat_history = MagicMock() + settings = OpenAIChatPromptExecutionSettings(temperature=0.2, seed=2) + arguments = KernelArguments() + chat_completion_base = MistralAIChatCompletion( + ai_model_id="test_model_id", + service_id="test", api_key="", + async_client=mock_mistral_ai_client_completion_stream + ) + + async for chunk in chat_completion_base.get_streaming_chat_message_contents( + chat_history, settings, kernel=kernel, arguments=arguments + ): + continue + assert mock_mistral_ai_client_completion_stream.chat_stream.call_args.kwargs["temperature"] == 0.2 + assert mock_mistral_ai_client_completion_stream.chat_stream.call_args.kwargs["seed"] == 2