Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

from typing import Literal
from typing import Any, Literal

from pydantic import Field

Expand Down Expand Up @@ -30,6 +30,9 @@ class AzureAIInferencePromptExecutionSettings(PromptExecutionSettings):
class AzureAIInferenceChatPromptExecutionSettings(AzureAIInferencePromptExecutionSettings):
"""Azure AI Inference Chat Prompt Execution Settings."""

tools: list[dict[str, Any]] | None = Field(None, max_length=64)
tool_choice: str | None = None


@experimental_class
class AzureAIInferenceEmbeddingPromptExecutionSettings(PromptExecutionSettings):
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
from collections.abc import Callable

from azure.ai.inference.models import (
AssistantMessage,
ChatCompletionsFunctionToolCall,
ChatRequestMessage,
FunctionCall,
ImageContentItem,
ImageDetailLevel,
ImageUrl,
SystemMessage,
TextContentItem,
ToolMessage,
UserMessage,
)

from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.image_content import ImageContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole

logger: logging.Logger = logging.getLogger(__name__)


def _format_system_message(message: ChatMessageContent) -> SystemMessage:
    """Convert a system-role chat message into the client's SystemMessage type.

    Args:
        message: The system message.

    Returns:
        The formatted system message.
    """
    # System messages carry plain text only, so the string content is passed through as-is.
    text = message.content
    return SystemMessage(content=text)


def _format_user_message(message: ChatMessageContent) -> UserMessage:
    """Format a user message to the expected object for the client.

    If there are any image items in the message, we need to create a list of content items,
    otherwise we need to just pass in the content as a string or it will error.

    Args:
        message: The user message.

    Returns:
        The formatted user message.
    """
    # Fast path: no image items present, so the API accepts a plain string payload.
    if not any(isinstance(item, ImageContent) for item in message.items):
        return UserMessage(content=message.content)

    content_items = []
    for item in message.items:
        if isinstance(item, TextContent):
            content_items.append(TextContentItem(text=item.text))
        elif isinstance(item, ImageContent) and (item.data_uri or item.uri):
            # Prefer the inline data URI when available; fall back to the remote URI.
            content_items.append(
                ImageContentItem(image_url=ImageUrl(url=item.data_uri or str(item.uri), detail=ImageDetailLevel.Auto))
            )
        else:
            # Skip items the Azure AI Inference API cannot represent rather than failing the request.
            logger.warning(
                "Unsupported item type in User message while formatting chat history for Azure AI"
                f" Inference: {type(item)}"
            )

    return UserMessage(content=content_items)


def _format_assistant_message(message: ChatMessageContent) -> AssistantMessage:
    """Format an assistant message to the expected object for the client.

    Args:
        message: The assistant message.

    Returns:
        The formatted assistant message.
    """
    content_items = []
    tool_calls = []

    for item in message.items:
        if isinstance(item, TextContent):
            content_items.append(TextContentItem(text=item.text))
        elif isinstance(item, FunctionCallContent):
            tool_calls.append(
                ChatCompletionsFunctionToolCall(
                    id=item.id, function=FunctionCall(name=item.name, arguments=item.arguments)
                )
            )
        else:
            # Skip items the Azure AI Inference API cannot represent rather than failing the request.
            logger.warning(
                "Unsupported item type in Assistant message while formatting chat history for Azure AI"
                f" Inference: {type(item)}"
            )

    # tool_calls cannot be an empty list, so we need to set it to None if it is empty
    return AssistantMessage(content=content_items, tool_calls=tool_calls or None)


def _format_tool_message(message: ChatMessageContent) -> ToolMessage:
    """Convert a tool-role chat message into the client's ToolMessage type.

    Args:
        message: The tool message.

    Returns:
        The formatted tool message.
    """
    if len(message.items) != 1:
        # Warn but continue: only the first item is forwarded to the API.
        logger.warning(
            "Unsupported number of items in Tool message while formatting chat history for Azure AI"
            f" Inference: {len(message.items)}"
        )

    # NOTE(review): indexing items[0] raises IndexError when items is empty — confirm
    # callers always supply exactly one FunctionResultContent item.
    first_item = message.items[0]
    if not isinstance(first_item, FunctionResultContent):
        logger.warning(
            "Unsupported item type in Tool message while formatting chat history for Azure AI"
            f" Inference: {type(first_item)}"
        )

    # The API expects the result to be a string, so we need to convert it to a string
    return ToolMessage(content=str(first_item.result), tool_call_id=first_item.id)


# Dispatch table mapping each chat author role to the formatter that converts a
# ChatMessageContent of that role into the Azure AI Inference request message type.
MESSAGE_CONVERTERS: dict[AuthorRole, Callable[[ChatMessageContent], ChatRequestMessage]] = {
    AuthorRole.SYSTEM: _format_system_message,
    AuthorRole.USER: _format_user_message,
    AuthorRole.ASSISTANT: _format_assistant_message,
    AuthorRole.TOOL: _format_tool_message,
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _prepare_chat_history_for_request(
chat_history: "ChatHistory",
role_key: str = "role",
content_key: str = "content",
) -> list[dict[str, str | None]]:
) -> Any:
Comment thread
TaoChenOSU marked this conversation as resolved.
"""Prepare the chat history for a request.

Allowing customization of the key names for role/author, and optionally overriding the role.
Expand All @@ -68,12 +68,14 @@ def _prepare_chat_history_for_request(
They require a "tool_call_id" and (function) "name" key, and the "metadata" key should
be removed. The "encoding" key should also be removed.

Override this method to customize the formatting of the chat history for a request.

Args:
chat_history (ChatHistory): The chat history to prepare.
role_key (str): The key name for the role/author.
content_key (str): The key name for the content/message.

Returns:
List[Dict[str, Optional[str]]]: The prepared chat history.
prepared_chat_history (Any): The prepared chat history for a request.
"""
return [message.to_dict(role_key=role_key, content_key=content_key) for message in chat_history.messages]
28 changes: 10 additions & 18 deletions python/semantic_kernel/connectors/ai/function_calling_utils.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,23 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
from typing import TYPE_CHECKING, Any
from typing import Any

from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
OpenAIChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionCallChoiceConfiguration
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata

if TYPE_CHECKING:
from semantic_kernel.connectors.ai.function_choice_behavior import (
FunctionCallChoiceConfiguration,
)
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
OpenAIChatPromptExecutionSettings,
)

logger = logging.getLogger(__name__)


def update_settings_from_function_call_configuration(
function_choice_configuration: "FunctionCallChoiceConfiguration",
settings: "OpenAIChatPromptExecutionSettings",
function_choice_configuration: FunctionCallChoiceConfiguration,
settings: PromptExecutionSettings,
type: str,
) -> None:
"""Update the settings from a FunctionChoiceConfiguration."""
if function_choice_configuration.available_functions:
if (
function_choice_configuration.available_functions
and hasattr(settings, "tool_choice")
and hasattr(settings, "tools")
):
settings.tool_choice = type
settings.tools = [
kernel_function_metadata_to_function_call_format(f)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@

from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.function_call_behavior import FunctionCallBehavior
from semantic_kernel.connectors.ai.function_calling_utils import (
update_settings_from_function_call_configuration,
)
from semantic_kernel.connectors.ai.function_choice_behavior import (
FunctionChoiceBehavior,
)
from semantic_kernel.connectors.ai.function_calling_utils import update_settings_from_function_call_configuration
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
OpenAIChatPromptExecutionSettings,
)
Expand All @@ -33,10 +29,7 @@
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.contents.utils.finish_reason import FinishReason
from semantic_kernel.exceptions import (
ServiceInvalidExecutionSettingsError,
ServiceInvalidResponseError,
)
from semantic_kernel.exceptions import ServiceInvalidExecutionSettingsError, ServiceInvalidResponseError
from semantic_kernel.filters.auto_function_invocation.auto_function_invocation_context import (
AutoFunctionInvocationContext,
)
Expand Down
52 changes: 52 additions & 0 deletions python/tests/integration/completions/test_chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,58 @@ def services() -> dict[str, tuple[ChatCompletionClientBase, type[PromptExecution
["house", "germany"],
id="azure_ai_inference_image_input_file",
),
pytest.param(
"azure_ai_inference",
{
"function_choice_behavior": FunctionChoiceBehavior.Auto(
auto_invoke=True, filters={"excluded_plugins": ["chat"]}
)
},
[
ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="What is 3+345?")]),
],
["348"],
id="azure_ai_inference_tool_call_auto",
),
pytest.param(
"azure_ai_inference",
{
"function_choice_behavior": FunctionChoiceBehavior.Auto(
auto_invoke=False, filters={"excluded_plugins": ["chat"]}
)
},
[
ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="What is 3+345?")]),
],
["348"],
id="azure_ai_inference_tool_call_non_auto",
),
pytest.param(
"azure_ai_inference",
{},
[
[
ChatMessageContent(
role=AuthorRole.USER,
items=[TextContent(text="What was our 2024 revenue?")],
),
ChatMessageContent(
role=AuthorRole.ASSISTANT,
items=[
FunctionCallContent(
id="fin", name="finance-search", arguments='{"company": "contoso", "year": 2024}'
)
],
),
ChatMessageContent(
role=AuthorRole.TOOL,
items=[FunctionResultContent(id="fin", name="finance-search", result="1.2B")],
),
],
],
["1.2"],
id="azure_ai_inference_tool_call_flow",
),
pytest.param(
"mistral_ai",
{},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ async def test_process_tool_calls_with_continuation_on_malformed_arguments():
ai_model_id="test_model_id", service_id="test", client=MagicMock(spec=AsyncOpenAI)
)

with patch("semantic_kernel.connectors.ai.function_calling_utils.logger", autospec=True):
with patch("semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion_base.logger", autospec=True):
await chat_completion_base._process_function_call(
tool_call_mock,
chat_history_mock,
Expand Down