From bdd542d83806d18261f1b040593d3e60a7002bcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20M=C3=B6ller?= Date: Wed, 10 Jul 2024 09:18:34 +0200 Subject: [PATCH 1/3] fixed liskov problem in mypy --- .../services/mistral_ai_chat_completion.py | 13 +++-- .../test_mistralai_chat_completion.py | 54 +++++++++++++------ 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py index 39af4b01e8f8..73fd58b43472 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py @@ -20,6 +20,7 @@ MistralAIChatPromptExecutionSettings, ) from semantic_kernel.connectors.ai.mistral_ai.settings.mistral_ai_settings import MistralAISettings +from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.function_call_content import FunctionCallContent @@ -95,20 +96,22 @@ def __init__( async def get_chat_message_contents( self, chat_history: "ChatHistory", - settings: "MistralAIChatPromptExecutionSettings", # type: ignore[override] + settings: "PromptExecutionSettings", **kwargs: Any, ) -> list["ChatMessageContent"]: """Executes a chat completion request and returns the result. Args: chat_history (ChatHistory): The chat history to use for the chat completion. - settings (MistralAIChatPromptExecutionSettings): The settings to use + settings (PromptExecutionSettings): The settings to use for the chat completion request. kwargs (Dict[str, Any]): The optional arguments. Returns: List[ChatMessageContent]: The completion result(s). """ + settings = self.get_prompt_execution_settings_from_settings(settings) + if not settings.ai_model_id: settings.ai_model_id = self.ai_model_id @@ -128,14 +131,14 @@ async def get_chat_message_contents( async def get_streaming_chat_message_contents( self, chat_history: ChatHistory, - settings: MistralAIChatPromptExecutionSettings, # type: ignore[override] + settings: PromptExecutionSettings, **kwargs: Any, ) -> AsyncGenerator[list[StreamingChatMessageContent], Any]: """Executes a streaming chat completion request and returns the result. Args: chat_history (ChatHistory): The chat history to use for the chat completion. - settings (MistralAIChatPromptExecutionSettings): The settings to use + settings (PromptExecutionSettings): The settings to use for the chat completion request. kwargs (Dict[str, Any]): The optional arguments. @@ -143,6 +146,8 @@ async def get_streaming_chat_message_contents( List[StreamingChatMessageContent]: A stream of StreamingChatMessageContent when using Azure. """ + settings = self.get_prompt_execution_settings_from_settings(settings) + if not settings.ai_model_id: settings.ai_model_id = self.ai_model_id diff --git a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py index 8510fbae3ea5..10336f362513 100644 --- a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py +++ b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py @@ -9,17 +9,21 @@ MistralAIChatPromptExecutionSettings, ) from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_chat_completion import MistralAIChatCompletion +from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.exceptions import ServiceInitializationError, ServiceResponseException from semantic_kernel.functions.kernel_arguments import KernelArguments from semantic_kernel.kernel import Kernel +@pytest.fixture +def mock_settings() -> MistralAIChatPromptExecutionSettings: + return MistralAIChatPromptExecutionSettings() + + @pytest.mark.asyncio -async def test_complete_chat_contents(kernel: Kernel): +async def test_complete_chat_contents(kernel: Kernel, mock_settings: MistralAIChatPromptExecutionSettings): chat_history = MagicMock() - settings = MagicMock() - settings.number_of_responses = 1 arguments = KernelArguments() client = MagicMock(spec=MistralAsyncClient) chat_completion_response = AsyncMock() @@ -33,16 +37,14 @@ async def test_complete_chat_contents(kernel: Kernel): ) content: list[ChatMessageContent] = await chat_completion_base.get_chat_message_contents( - chat_history, settings, kernel=kernel, arguments=arguments + chat_history, mock_settings, kernel=kernel, arguments=arguments ) assert content is not None @pytest.mark.asyncio -async def test_complete_chat_stream_contents(kernel: Kernel): +async def test_complete_chat_stream_contents(kernel: Kernel, mock_settings: MistralAIChatPromptExecutionSettings): chat_history = MagicMock() - settings = MagicMock() - settings.ai_model_id = None arguments = KernelArguments() client = MagicMock(spec=MistralAsyncClient) chat_completion_response = MagicMock() @@ -62,16 +64,14 @@ async def test_complete_chat_stream_contents(kernel: Kernel): ) async for content in chat_completion_base.get_streaming_chat_message_contents( - chat_history, settings, kernel=kernel, arguments=arguments + chat_history, mock_settings, kernel=kernel, arguments=arguments ): assert content is not None @pytest.mark.asyncio -async def test_mistral_ai_sdk_exception(kernel: Kernel): +async def test_mistral_ai_sdk_exception(kernel: Kernel, mock_settings: MistralAIChatPromptExecutionSettings): chat_history = MagicMock() - settings = MagicMock() - settings.ai_model_id = None arguments = KernelArguments() client = MagicMock(spec=MistralAsyncClient) client.chat.side_effect = Exception("Test Exception") @@ -83,15 +83,13 @@ async def test_mistral_ai_sdk_exception(kernel: Kernel): with pytest.raises(ServiceResponseException): await chat_completion_base.get_chat_message_contents( - chat_history, settings, kernel=kernel, arguments=arguments + chat_history, mock_settings, kernel=kernel, arguments=arguments ) @pytest.mark.asyncio -async def test_mistral_ai_sdk_exception_streaming(kernel: Kernel): +async def test_mistral_ai_sdk_exception_streaming(kernel: Kernel, mock_settings: MistralAIChatPromptExecutionSettings): chat_history = MagicMock() - settings = MagicMock() - settings.number_of_responses = 1 arguments = KernelArguments() client = MagicMock(spec=MistralAsyncClient) client.chat_stream.side_effect = Exception("Test Exception") @@ -102,7 +100,7 @@ async def test_mistral_ai_sdk_exception_streaming(kernel: Kernel): with pytest.raises(ServiceResponseException): async for content in chat_completion_base.get_streaming_chat_message_contents( - chat_history, settings, kernel=kernel, arguments=arguments + chat_history, mock_settings, kernel=kernel, arguments=arguments ): assert content is not None @@ -138,3 +136,27 @@ def test_prompt_execution_settings_class(mistralai_unit_test_env): mistral_ai_chat_completion = MistralAIChatCompletion() prompt_execution_settings = mistral_ai_chat_completion.get_prompt_execution_settings_class() assert prompt_execution_settings == MistralAIChatPromptExecutionSettings + + +@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.asyncio +async def test_with_prompt_execution_settings(kernel: Kernel, stream: bool): + chat_history = MagicMock() + arguments = KernelArguments() + client = MagicMock(spec=MistralAsyncClient) + mock_settings = PromptExecutionSettings() + + chat_completion_base = MistralAIChatCompletion( + ai_model_id="test_model_id", service_id="test", api_key="", async_client=client + ) + + if stream: + async for content in chat_completion_base.get_streaming_chat_message_contents( + chat_history, mock_settings, kernel=kernel, arguments=arguments + ): + assert content is not None + else: + content = await chat_completion_base.get_chat_message_contents( + chat_history, mock_settings, kernel=kernel, arguments=arguments + ) + assert content is not None From 5751a092a480c7f3d89ce453b96e16d76e5c7a03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20M=C3=B6ller?= Date: Thu, 11 Jul 2024 08:36:14 +0200 Subject: [PATCH 2/3] added assert statement --- .../ai/mistral_ai/services/mistral_ai_chat_completion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py index 73fd58b43472..890447c46b5e 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py @@ -111,6 +111,7 @@ async def get_chat_message_contents( List[ChatMessageContent]: The completion result(s). """ settings = self.get_prompt_execution_settings_from_settings(settings) + assert isinstance(settings, MistralAIChatPromptExecutionSettings) # nosec if not settings.ai_model_id: settings.ai_model_id = self.ai_model_id @@ -147,6 +148,7 @@ async def get_streaming_chat_message_contents( StreamingChatMessageContent when using Azure. """ settings = self.get_prompt_execution_settings_from_settings(settings) + assert isinstance(settings, MistralAIChatPromptExecutionSettings) # nosec if not settings.ai_model_id: settings.ai_model_id = self.ai_model_id From 91f1dd64920262b8137f97adf5d6b43f4d146167 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20M=C3=B6ller?= Date: Thu, 11 Jul 2024 10:22:31 +0200 Subject: [PATCH 3/3] adjusted testcases to test settings conversion --- .../services/mistral_ai_chat_completion.py | 6 +- .../test_mistralai_chat_completion.py | 112 ++++++++++++------ 2 files changed, 81 insertions(+), 37 deletions(-) diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py index 890447c46b5e..ffd6bc2594ad 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_chat_completion.py @@ -110,7 +110,8 @@ async def get_chat_message_contents( Returns: List[ChatMessageContent]: The completion result(s). """ - settings = self.get_prompt_execution_settings_from_settings(settings) + if not isinstance(settings, MistralAIChatPromptExecutionSettings): + settings = self.get_prompt_execution_settings_from_settings(settings) assert isinstance(settings, MistralAIChatPromptExecutionSettings) # nosec if not settings.ai_model_id: @@ -147,7 +148,8 @@ async def get_streaming_chat_message_contents( List[StreamingChatMessageContent]: A stream of StreamingChatMessageContent when using Azure. """ - settings = self.get_prompt_execution_settings_from_settings(settings) + if not isinstance(settings, MistralAIChatPromptExecutionSettings): + settings = self.get_prompt_execution_settings_from_settings(settings) assert isinstance(settings, MistralAIChatPromptExecutionSettings) # nosec if not settings.ai_model_id: diff --git a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py index 10336f362513..ba1b0b51aa7b 100644 --- a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py +++ b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py @@ -9,7 +9,9 @@ MistralAIChatPromptExecutionSettings, ) from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_chat_completion import MistralAIChatCompletion -from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings +from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import ( + OpenAIChatPromptExecutionSettings, +) from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.exceptions import ServiceInitializationError, ServiceResponseException from semantic_kernel.functions.kernel_arguments import KernelArguments @@ -21,10 +23,8 @@ def mock_settings() -> MistralAIChatPromptExecutionSettings: return MistralAIChatPromptExecutionSettings() -@pytest.mark.asyncio -async def test_complete_chat_contents(kernel: Kernel, mock_settings: MistralAIChatPromptExecutionSettings): - chat_history = MagicMock() - arguments = KernelArguments() +@pytest.fixture +def mock_mistral_ai_client_completion() -> MistralAsyncClient: client = MagicMock(spec=MistralAsyncClient) chat_completion_response = AsyncMock() choices = [ @@ -32,20 +32,11 @@ async def test_complete_chat_contents(kernel: Kernel, mock_settings: MistralAICh ] chat_completion_response.choices = choices client.chat.return_value = chat_completion_response - chat_completion_base = MistralAIChatCompletion( - ai_model_id="test_model_id", service_id="test", api_key="", async_client=client - ) - - content: list[ChatMessageContent] = await chat_completion_base.get_chat_message_contents( - chat_history, mock_settings, kernel=kernel, arguments=arguments - ) - assert content is not None + return client -@pytest.mark.asyncio -async def test_complete_chat_stream_contents(kernel: Kernel, mock_settings: MistralAIChatPromptExecutionSettings): - chat_history = MagicMock() - arguments = KernelArguments() +@pytest.fixture +def mock_mistral_ai_client_completion_stream() -> MistralAsyncClient: client = MagicMock(spec=MistralAsyncClient) chat_completion_response = MagicMock() choices = [ @@ -58,9 +49,40 @@ async def test_complete_chat_stream_contents(kernel: Kernel, mock_settings: Mist generator_mock = MagicMock() generator_mock.__aiter__.return_value = [chat_completion_response_empty, chat_completion_response] client.chat_stream.return_value = generator_mock + return client + +@pytest.mark.asyncio +async def test_complete_chat_contents( + kernel: Kernel, + mock_settings: MistralAIChatPromptExecutionSettings, + mock_mistral_ai_client_completion: MistralAsyncClient +): + chat_history = MagicMock() + arguments = KernelArguments() chat_completion_base = MistralAIChatCompletion( - ai_model_id="test_model_id", service_id="test", api_key="", async_client=client + ai_model_id="test_model_id", service_id="test", api_key="", async_client=mock_mistral_ai_client_completion + ) + + content: list[ChatMessageContent] = await chat_completion_base.get_chat_message_contents( + chat_history, mock_settings, kernel=kernel, arguments=arguments + ) + assert content is not None + + +@pytest.mark.asyncio +async def test_complete_chat_stream_contents( + kernel: Kernel, + mock_settings: MistralAIChatPromptExecutionSettings, + mock_mistral_ai_client_completion_stream: MistralAsyncClient +): + chat_history = MagicMock() + arguments = KernelArguments() + + chat_completion_base = MistralAIChatCompletion( + ai_model_id="test_model_id", + service_id="test", api_key="", + async_client=mock_mistral_ai_client_completion_stream ) async for content in chat_completion_base.get_streaming_chat_message_contents( @@ -75,10 +97,11 @@ async def test_mistral_ai_sdk_exception(kernel: Kernel, mock_settings: MistralAI arguments = KernelArguments() client = MagicMock(spec=MistralAsyncClient) client.chat.side_effect = Exception("Test Exception") - client.chat_stream.side_effect = Exception("Test Exception") chat_completion_base = MistralAIChatCompletion( - ai_model_id="test_model_id", service_id="test", api_key="", async_client=client + ai_model_id="test_model_id", + service_id="test", api_key="", + async_client=client ) with pytest.raises(ServiceResponseException): @@ -138,25 +161,44 @@ def test_prompt_execution_settings_class(mistralai_unit_test_env): assert prompt_execution_settings == MistralAIChatPromptExecutionSettings -@pytest.mark.parametrize("stream", [True, False]) @pytest.mark.asyncio -async def test_with_prompt_execution_settings(kernel: Kernel, stream: bool): +async def test_with_different_execution_settings( + kernel: Kernel, + mock_mistral_ai_client_completion: MagicMock +): chat_history = MagicMock() + settings = OpenAIChatPromptExecutionSettings(temperature=0.2, seed=2) arguments = KernelArguments() - client = MagicMock(spec=MistralAsyncClient) - mock_settings = PromptExecutionSettings() + chat_completion_base = MistralAIChatCompletion( + ai_model_id="test_model_id", + service_id="test", api_key="", + async_client=mock_mistral_ai_client_completion + ) + + await chat_completion_base.get_chat_message_contents( + chat_history, settings, kernel=kernel, arguments=arguments + ) + assert mock_mistral_ai_client_completion.chat.call_args.kwargs["temperature"] == 0.2 + assert mock_mistral_ai_client_completion.chat.call_args.kwargs["seed"] == 2 + +@pytest.mark.asyncio +async def test_with_different_execution_settings_stream( + kernel: Kernel, + mock_mistral_ai_client_completion_stream: MagicMock +): + chat_history = MagicMock() + settings = OpenAIChatPromptExecutionSettings(temperature=0.2, seed=2) + arguments = KernelArguments() chat_completion_base = MistralAIChatCompletion( - ai_model_id="test_model_id", service_id="test", api_key="", async_client=client + ai_model_id="test_model_id", + service_id="test", api_key="", + async_client=mock_mistral_ai_client_completion_stream ) - if stream: - async for content in chat_completion_base.get_streaming_chat_message_contents( - chat_history, mock_settings, kernel=kernel, arguments=arguments - ): - assert content is not None - else: - content = await chat_completion_base.get_chat_message_contents( - chat_history, mock_settings, kernel=kernel, arguments=arguments - ) - assert content is not None + async for chunk in chat_completion_base.get_streaming_chat_message_contents( + chat_history, settings, kernel=kernel, arguments=arguments + ): + continue + assert mock_mistral_ai_client_completion_stream.chat_stream.call_args.kwargs["temperature"] == 0.2 + assert mock_mistral_ai_client_completion_stream.chat_stream.call_args.kwargs["seed"] == 2