@@ -20,6 +20,7 @@
MistralAIChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.mistral_ai.settings.mistral_ai_settings import MistralAISettings
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
@@ -95,20 +96,24 @@ def __init__(
async def get_chat_message_contents(
self,
chat_history: "ChatHistory",
settings: "MistralAIChatPromptExecutionSettings", # type: ignore[override]
settings: "PromptExecutionSettings",
**kwargs: Any,
) -> list["ChatMessageContent"]:
"""Executes a chat completion request and returns the result.

Args:
chat_history (ChatHistory): The chat history to use for the chat completion.
settings (MistralAIChatPromptExecutionSettings): The settings to use
settings (PromptExecutionSettings): The settings to use
for the chat completion request.
kwargs (Dict[str, Any]): The optional arguments.

Returns:
List[ChatMessageContent]: The completion result(s).
"""
if not isinstance(settings, MistralAIChatPromptExecutionSettings):
settings = self.get_prompt_execution_settings_from_settings(settings)
assert isinstance(settings, MistralAIChatPromptExecutionSettings) # nosec

if not settings.ai_model_id:
settings.ai_model_id = self.ai_model_id

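With the signature change above, callers no longer need the Mistral-specific settings class; any `PromptExecutionSettings` is converted internally via `get_prompt_execution_settings_from_settings`. A minimal caller-side sketch of that pattern (the model id and API key below are placeholders, and `extension_data` is assumed to be the base class's catch-all field for generic values):

```python
import asyncio

from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_chat_completion import MistralAIChatCompletion
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.contents.chat_history import ChatHistory


async def main() -> None:
    # Placeholder model id and API key; substitute real values to run this.
    service = MistralAIChatCompletion(ai_model_id="mistral-small-latest", api_key="...")

    history = ChatHistory()
    history.add_user_message("Hello!")

    # Generic base-class settings: the service converts them to
    # MistralAIChatPromptExecutionSettings before sending the request.
    settings = PromptExecutionSettings(extension_data={"temperature": 0.2})

    result = await service.get_chat_message_contents(history, settings)
    print(result[0].content)


asyncio.run(main())
```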
@@ -128,21 +133,25 @@ async def get_chat_message_contents(
async def get_streaming_chat_message_contents(
self,
chat_history: ChatHistory,
settings: MistralAIChatPromptExecutionSettings, # type: ignore[override]
settings: PromptExecutionSettings,
**kwargs: Any,
) -> AsyncGenerator[list[StreamingChatMessageContent], Any]:
"""Executes a streaming chat completion request and returns the result.

Args:
chat_history (ChatHistory): The chat history to use for the chat completion.
settings (MistralAIChatPromptExecutionSettings): The settings to use
settings (PromptExecutionSettings): The settings to use
for the chat completion request.
kwargs (Dict[str, Any]): The optional arguments.

Yields:
List[StreamingChatMessageContent]: A stream of
StreamingChatMessageContent objects.
"""
if not isinstance(settings, MistralAIChatPromptExecutionSettings):
settings = self.get_prompt_execution_settings_from_settings(settings)
assert isinstance(settings, MistralAIChatPromptExecutionSettings) # nosec

if not settings.ai_model_id:
settings.ai_model_id = self.ai_model_id

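The streaming overload applies the same conversion, so the sketch above extends naturally (same assumptions):

```python
# Continuing the sketch: chunks arrive as lists of StreamingChatMessageContent.
async for chunks in service.get_streaming_chat_message_contents(history, settings):
    for chunk in chunks:
        print(chunk.content or "", end="")
```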
@@ -9,41 +9,34 @@
MistralAIChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_chat_completion import MistralAIChatCompletion
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
OpenAIChatPromptExecutionSettings,
)
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.exceptions import ServiceInitializationError, ServiceResponseException
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.kernel import Kernel


@pytest.mark.asyncio
async def test_complete_chat_contents(kernel: Kernel):
chat_history = MagicMock()
settings = MagicMock()
settings.number_of_responses = 1
arguments = KernelArguments()
@pytest.fixture
def mock_settings() -> MistralAIChatPromptExecutionSettings:
return MistralAIChatPromptExecutionSettings()


@pytest.fixture
def mock_mistral_ai_client_completion() -> MistralAsyncClient:
client = MagicMock(spec=MistralAsyncClient)
chat_completion_response = AsyncMock()
choices = [
MagicMock(finish_reason="stop", message=MagicMock(role="assistant", content="Test"))
]
chat_completion_response.choices = choices
client.chat.return_value = chat_completion_response
chat_completion_base = MistralAIChatCompletion(
ai_model_id="test_model_id", service_id="test", api_key="", async_client=client
)

content: list[ChatMessageContent] = await chat_completion_base.get_chat_message_contents(
chat_history, settings, kernel=kernel, arguments=arguments
)
assert content is not None
return client


@pytest.mark.asyncio
async def test_complete_chat_stream_contents(kernel: Kernel):
chat_history = MagicMock()
settings = MagicMock()
settings.ai_model_id = None
arguments = KernelArguments()
@pytest.fixture
def mock_mistral_ai_client_completion_stream() -> MistralAsyncClient:
client = MagicMock(spec=MistralAsyncClient)
chat_completion_response = MagicMock()
choices = [
@@ -56,42 +49,70 @@ async def test_complete_chat_stream_contents(
generator_mock = MagicMock()
generator_mock.__aiter__.return_value = [chat_completion_response_empty, chat_completion_response]
client.chat_stream.return_value = generator_mock
return client


@pytest.mark.asyncio
async def test_complete_chat_contents(
kernel: Kernel,
mock_settings: MistralAIChatPromptExecutionSettings,
mock_mistral_ai_client_completion: MistralAsyncClient
):
chat_history = MagicMock()
arguments = KernelArguments()
chat_completion_base = MistralAIChatCompletion(
ai_model_id="test_model_id", service_id="test", api_key="", async_client=client
ai_model_id="test_model_id", service_id="test", api_key="", async_client=mock_mistral_ai_client_completion
)

content: list[ChatMessageContent] = await chat_completion_base.get_chat_message_contents(
chat_history, mock_settings, kernel=kernel, arguments=arguments
)
assert content is not None


@pytest.mark.asyncio
async def test_complete_chat_stream_contents(
kernel: Kernel,
mock_settings: MistralAIChatPromptExecutionSettings,
mock_mistral_ai_client_completion_stream: MistralAsyncClient
):
chat_history = MagicMock()
arguments = KernelArguments()

chat_completion_base = MistralAIChatCompletion(
ai_model_id="test_model_id",
service_id="test", api_key="",
async_client=mock_mistral_ai_client_completion_stream
)

async for content in chat_completion_base.get_streaming_chat_message_contents(
chat_history, settings, kernel=kernel, arguments=arguments
chat_history, mock_settings, kernel=kernel, arguments=arguments
):
assert content is not None


@pytest.mark.asyncio
async def test_mistral_ai_sdk_exception(kernel: Kernel):
async def test_mistral_ai_sdk_exception(kernel: Kernel, mock_settings: MistralAIChatPromptExecutionSettings):
chat_history = MagicMock()
settings = MagicMock()
settings.ai_model_id = None
arguments = KernelArguments()
client = MagicMock(spec=MistralAsyncClient)
client.chat.side_effect = Exception("Test Exception")
client.chat_stream.side_effect = Exception("Test Exception")

chat_completion_base = MistralAIChatCompletion(
ai_model_id="test_model_id", service_id="test", api_key="", async_client=client
ai_model_id="test_model_id",
service_id="test", api_key="",
async_client=client
)

with pytest.raises(ServiceResponseException):
await chat_completion_base.get_chat_message_contents(
chat_history, settings, kernel=kernel, arguments=arguments
chat_history, mock_settings, kernel=kernel, arguments=arguments
)


@pytest.mark.asyncio
async def test_mistral_ai_sdk_exception_streaming(kernel: Kernel):
async def test_mistral_ai_sdk_exception_streaming(kernel: Kernel, mock_settings: MistralAIChatPromptExecutionSettings):
chat_history = MagicMock()
settings = MagicMock()
settings.number_of_responses = 1
arguments = KernelArguments()
client = MagicMock(spec=MistralAsyncClient)
client.chat_stream.side_effect = Exception("Test Exception")
@@ -102,7 +123,7 @@ async def test_mistral_ai_sdk_exception_streaming(kernel: Kernel):

with pytest.raises(ServiceResponseException):
async for content in chat_completion_base.get_streaming_chat_message_contents(
chat_history, settings, kernel=kernel, arguments=arguments
chat_history, mock_settings, kernel=kernel, arguments=arguments
):
assert content is not None

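Both exception tests hinge on the connector wrapping any SDK failure in `ServiceResponseException`. A rough sketch of that wrapping, inferred from the tests rather than copied from the connector (`_send_chat_request` is a hypothetical helper name; `prepare_settings_dict` is the base-class method assumed to build the request kwargs):

```python
from semantic_kernel.exceptions import ServiceResponseException


async def _send_chat_request(self, settings):
    """Illustrative only: assumed shape of the connector's request helper."""
    try:
        # MistralAsyncClient.chat is the call the mocked client replaces above.
        return await self.async_client.chat(**settings.prepare_settings_dict())
    except Exception as ex:
        # Any SDK error surfaces as ServiceResponseException, which is what
        # pytest.raises asserts in the two tests above.
        raise ServiceResponseException(
            f"{type(self)} service failed to complete the chat request", ex
        ) from ex
```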
@@ -138,3 +159,46 @@ def test_prompt_execution_settings_class(mistralai_unit_test_env):
mistral_ai_chat_completion = MistralAIChatCompletion()
prompt_execution_settings = mistral_ai_chat_completion.get_prompt_execution_settings_class()
assert prompt_execution_settings == MistralAIChatPromptExecutionSettings


@pytest.mark.asyncio
async def test_with_different_execution_settings(
kernel: Kernel,
mock_mistral_ai_client_completion: MagicMock
):
chat_history = MagicMock()
settings = OpenAIChatPromptExecutionSettings(temperature=0.2, seed=2)
arguments = KernelArguments()
chat_completion_base = MistralAIChatCompletion(
ai_model_id="test_model_id",
service_id="test", api_key="",
async_client=mock_mistral_ai_client_completion
)

await chat_completion_base.get_chat_message_contents(
chat_history, settings, kernel=kernel, arguments=arguments
)
assert mock_mistral_ai_client_completion.chat.call_args.kwargs["temperature"] == 0.2
assert mock_mistral_ai_client_completion.chat.call_args.kwargs["seed"] == 2


@pytest.mark.asyncio
async def test_with_different_execution_settings_stream(
kernel: Kernel,
mock_mistral_ai_client_completion_stream: MagicMock
):
chat_history = MagicMock()
settings = OpenAIChatPromptExecutionSettings(temperature=0.2, seed=2)
arguments = KernelArguments()
chat_completion_base = MistralAIChatCompletion(
ai_model_id="test_model_id",
service_id="test", api_key="",
async_client=mock_mistral_ai_client_completion_stream
)

async for chunk in chat_completion_base.get_streaming_chat_message_contents(
chat_history, settings, kernel=kernel, arguments=arguments
):
continue
assert mock_mistral_ai_client_completion_stream.chat_stream.call_args.kwargs["temperature"] == 0.2
assert mock_mistral_ai_client_completion_stream.chat_stream.call_args.kwargs["seed"] == 2
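The last two tests pin down the cross-connector behaviour this PR enables: settings built for the OpenAI connector still drive a Mistral request, with the overlapping fields surviving conversion. A sketch of that conversion under stated assumptions (`from_prompt_execution_settings` is the base-class converter, and the Mistral settings module path is assumed to mirror the OpenAI one imported in the tests):

```python
from semantic_kernel.connectors.ai.mistral_ai.prompt_execution_settings.mistral_ai_prompt_execution_settings import (
    MistralAIChatPromptExecutionSettings,  # assumed module path
)
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
    OpenAIChatPromptExecutionSettings,
)

openai_settings = OpenAIChatPromptExecutionSettings(temperature=0.2, seed=2)

# Assumed mechanics: the generic dump is re-validated as the Mistral-specific
# class, so fields both classes share carry over; the tests above assert that
# temperature and seed then reach MistralAsyncClient.chat as keyword arguments.
mistral_settings = MistralAIChatPromptExecutionSettings.from_prompt_execution_settings(
    openai_settings
)
print(mistral_settings.temperature)  # 0.2, carried over from the OpenAI settings
```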