Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@
" results = await oai_text_service.get_text_contents(prompt=prompt, settings=oai_text_prompt_execution_settings)\n",
"\n",
" for i, result in enumerate(results):\n",
" print(f\"Result {i+1}: {result}\")"
" print(f\"Result {i + 1}: {result}\")"
]
},
{
Expand All @@ -276,7 +276,7 @@
" results = await aoai_text_service.get_text_contents(prompt=prompt, settings=oai_text_prompt_execution_settings)\n",
"\n",
" for i, result in enumerate(results):\n",
" print(f\"Result {i+1}: {result}\")"
" print(f\"Result {i + 1}: {result}\")"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ async def download_file(
Returns:
BufferedReader: The data of the downloaded file.
"""
auth_token = await self.auth_callback()
auth_token = await self._ensure_auth_token()
self.http_client.headers.update(
{
"Authorization": f"Bearer {auth_token}",
Expand Down
20 changes: 11 additions & 9 deletions python/semantic_kernel/functions/kernel_function_from_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def __init__(
"stream_method": (
stream_method
if stream_method is not None
else method if isasyncgenfunction(method) or isgeneratorfunction(method) else None
else method
if isasyncgenfunction(method) or isgeneratorfunction(method)
else None
),
}

Expand Down Expand Up @@ -119,9 +121,7 @@ async def _invoke_internal_stream(self, context: FunctionInvocationContext) -> N
function_arguments = self.gather_function_parameters(context)
context.result = FunctionResult(function=self.metadata, value=self.stream_method(**function_arguments))

def gather_function_parameters(
self, context: FunctionInvocationContext
) -> dict[str, Any]:
def gather_function_parameters(self, context: FunctionInvocationContext) -> dict[str, Any]:
"""Gathers the function parameters from the arguments."""
function_arguments: dict[str, Any] = {}
for param in self.parameters:
Expand All @@ -141,8 +141,12 @@ def gather_function_parameters(
continue
if param.name in context.arguments:
value: Any = context.arguments[param.name]
if (param.type_ and "," not in param.type_ and
param.type_object and param.type_object is not inspect._empty):
if (
param.type_
and "," not in param.type_
and param.type_object
and param.type_object is not inspect._empty
):
if hasattr(param.type_object, "model_validate"):
try:
value = param.type_object.model_validate(value)
Expand All @@ -167,7 +171,5 @@ def gather_function_parameters(
raise FunctionExecutionException(
f"Parameter {param.name} is required but not provided in the arguments."
)
logger.debug(
f"Parameter {param.name} is not provided, using default value {param.default_value}"
)
logger.debug(f"Parameter {param.name} is not provided, using default value {param.default_value}")
return function_arguments
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) Microsoft. All rights reserved.


import pytest

from semantic_kernel.connectors.openai_plugin.openai_utils import OpenAIUtils
from semantic_kernel.exceptions import PluginInitializationError


def test_parse_openai_manifest_for_openapi_spec_url_valid():
plugin_json = {"api": {"type": "openapi", "url": "https://example.com/openapi.json"}}
result = OpenAIUtils.parse_openai_manifest_for_openapi_spec_url(plugin_json)
assert result == "https://example.com/openapi.json"


def test_parse_openai_manifest_for_openapi_spec_url_missing_api_type():
plugin_json = {"api": {}}
with pytest.raises(PluginInitializationError, match="OpenAI manifest is missing the API type."):
OpenAIUtils.parse_openai_manifest_for_openapi_spec_url(plugin_json)


def test_parse_openai_manifest_for_openapi_spec_url_invalid_api_type():
plugin_json = {"api": {"type": "other", "url": "https://example.com/openapi.json"}}
with pytest.raises(PluginInitializationError, match="OpenAI manifest is not of type OpenAPI."):
OpenAIUtils.parse_openai_manifest_for_openapi_spec_url(plugin_json)


def test_parse_openai_manifest_for_openapi_spec_url_missing_url():
plugin_json = {"api": {"type": "openapi"}}
with pytest.raises(PluginInitializationError, match="OpenAI manifest is missing the OpenAPI Spec URL."):
OpenAIUtils.parse_openai_manifest_for_openapi_spec_url(plugin_json)
11 changes: 11 additions & 0 deletions python/tests/unit/connectors/openapi/test_openapi_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from semantic_kernel.connectors.openapi_plugin.openapi_manager import (
_create_function_from_operation,
create_functions_from_openapi,
)
from semantic_kernel.exceptions import FunctionExecutionException
from semantic_kernel.functions.kernel_function_decorator import kernel_function
Expand Down Expand Up @@ -222,3 +223,13 @@ async def run_openapi_operation(kernel, **kwargs):
assert str(result) == "Operation Result"
run_operation_mock.assert_called_once()
assert runner.run_operation.call_args[0][1]["param1"] == "value1"


@pytest.mark.asyncio
@patch("semantic_kernel.connectors.openapi_plugin.openapi_parser.OpenApiParser.parse", return_value=None)
async def test_create_functions_from_openapi_raises_exception(mock_parse):
"""Test that an exception is raised when parsing fails."""
with pytest.raises(FunctionExecutionException, match="Error parsing OpenAPI document: test_openapi_document_path"):
create_functions_from_openapi(plugin_name="test_plugin", openapi_document_path="test_openapi_document_path")

mock_parse.assert_called_once_with("test_openapi_document_path")
159 changes: 158 additions & 1 deletion python/tests/unit/core_plugins/test_sessions_python_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@

import httpx
import pytest
from httpx import HTTPStatusError

from semantic_kernel.core_plugins.sessions_python_tool.sessions_python_plugin import SessionsPythonTool
from semantic_kernel.core_plugins.sessions_python_tool.sessions_python_plugin import (
SESSIONS_API_VERSION,
SessionsPythonTool,
)
from semantic_kernel.core_plugins.sessions_python_tool.sessions_remote_file_metadata import SessionsRemoteFileMetadata
from semantic_kernel.exceptions.function_exceptions import FunctionExecutionException, FunctionInitializationError
from semantic_kernel.kernel import Kernel

Expand All @@ -25,6 +30,53 @@ def test_validate_endpoint(aca_python_sessions_unit_test_env):
assert str(plugin.pool_management_endpoint) == aca_python_sessions_unit_test_env["ACA_POOL_MANAGEMENT_ENDPOINT"]


@pytest.mark.parametrize(
"base_url, endpoint, params, expected_url",
[
(
"http://example.com",
"api/resource",
{"param1": "value1", "param2": "value2"},
f"http://example.com/api/resource?param1=value1&param2=value2&api-version={SESSIONS_API_VERSION}",
),
(
"http://example.com/",
"api/resource",
{"param1": "value1"},
f"http://example.com/api/resource?param1=value1&api-version={SESSIONS_API_VERSION}",
),
(
"http://example.com",
"api/resource/",
{"param1": "value1", "param2": "value2"},
f"http://example.com/api/resource?param1=value1&param2=value2&api-version={SESSIONS_API_VERSION}",
),
(
"http://example.com/",
"api/resource/",
{"param1": "value1"},
f"http://example.com/api/resource?param1=value1&api-version={SESSIONS_API_VERSION}",
),
(
"http://example.com",
"api/resource",
{},
f"http://example.com/api/resource?api-version={SESSIONS_API_VERSION}",
),
(
"http://example.com/",
"api/resource",
{},
f"http://example.com/api/resource?api-version={SESSIONS_API_VERSION}",
),
],
)
def test_build_url_with_version(base_url, endpoint, params, expected_url, aca_python_sessions_unit_test_env):
plugin = SessionsPythonTool(auth_callback=auth_callback_test)
result = plugin._build_url_with_version(base_url, endpoint, params)
assert result == expected_url


@pytest.mark.parametrize(
"override_env_param_dict",
[
Expand Down Expand Up @@ -229,6 +281,37 @@ async def async_return(result):
mock_post.assert_awaited_once()


@pytest.mark.asyncio
@patch("httpx.AsyncClient.post")
async def test_upload_file_throws_exception(mock_post, aca_python_sessions_unit_test_env):
"""Test throwing exception during file upload."""

async def async_raise_http_error(*args, **kwargs):
mock_request = httpx.Request(method="POST", url="https://example.com/files/upload")
mock_response = httpx.Response(status_code=500, request=mock_request)
raise HTTPStatusError("Server Error", request=mock_request, response=mock_response)

with (
patch(
"semantic_kernel.core_plugins.sessions_python_tool.sessions_python_plugin.SessionsPythonTool._ensure_auth_token",
return_value="test_token",
),
patch("builtins.open", mock_open(read_data=b"file data")),
):
mock_post.side_effect = async_raise_http_error

plugin = SessionsPythonTool(
auth_callback=lambda: "sample_token",
env_file_path="test.env",
)

with pytest.raises(
FunctionExecutionException, match="Upload failed with status code 500 and error: Internal Server Error"
):
await plugin.upload_file(local_file_path="hello.py")
mock_post.assert_awaited_once()


@pytest.mark.parametrize(
"local_file_path, input_remote_file_path, expected_remote_file_path",
[
Expand Down Expand Up @@ -349,6 +432,36 @@ async def async_return(result):
mock_get.assert_awaited_once()


@pytest.mark.asyncio
@patch("httpx.AsyncClient.get")
async def test_list_files_throws_exception(mock_get, aca_python_sessions_unit_test_env):
"""Test throwing exception during list files."""

async def async_raise_http_error(*args, **kwargs):
mock_request = httpx.Request(method="GET", url="https://example.com/files?identifier=None")
mock_response = httpx.Response(status_code=500, request=mock_request)
raise HTTPStatusError("Server Error", request=mock_request, response=mock_response)

with (
patch(
"semantic_kernel.core_plugins.sessions_python_tool.sessions_python_plugin.SessionsPythonTool._ensure_auth_token",
return_value="test_token",
),
):
mock_get.side_effect = async_raise_http_error

plugin = SessionsPythonTool(
auth_callback=lambda: "sample_token",
env_file_path="test.env",
)

with pytest.raises(
FunctionExecutionException, match="List files failed with status code 500 and error: Internal Server Error"
):
await plugin.list_files()
mock_get.assert_awaited_once()


@pytest.mark.asyncio
@patch("httpx.AsyncClient.get")
async def test_download_file_to_local(mock_get, aca_python_sessions_unit_test_env):
Expand Down Expand Up @@ -417,6 +530,38 @@ async def mock_auth_callback():
mock_get.assert_awaited_once()


@pytest.mark.asyncio
@patch("httpx.AsyncClient.get")
async def test_download_file_throws_exception(mock_get, aca_python_sessions_unit_test_env):
"""Test throwing exception during download file."""

async def async_raise_http_error(*args, **kwargs):
mock_request = httpx.Request(
method="GET", url="https://example.com/files/content/remote_test.txt?identifier=None"
)
mock_response = httpx.Response(status_code=500, request=mock_request)
raise HTTPStatusError("Server Error", request=mock_request, response=mock_response)

with (
patch(
"semantic_kernel.core_plugins.sessions_python_tool.sessions_python_plugin.SessionsPythonTool._ensure_auth_token",
return_value="test_token",
),
):
mock_get.side_effect = async_raise_http_error

plugin = SessionsPythonTool(
auth_callback=lambda: "sample_token",
env_file_path="test.env",
)

with pytest.raises(
FunctionExecutionException, match="Download failed with status code 500 and error: Internal Server Error"
):
await plugin.download_file(remote_file_name="remote_test.txt")
mock_get.assert_awaited_once()


@pytest.mark.parametrize(
"input_code, expected_output",
[
Expand Down Expand Up @@ -466,3 +611,15 @@ async def token_cb():
FunctionExecutionException, match="Failed to retrieve the client auth token with messages: Could not get token."
):
await plugin._ensure_auth_token()


@pytest.mark.parametrize(
"filename, expected_full_path",
[
("/mnt/data/testfile.txt", "/mnt/data/testfile.txt"),
("testfile.txt", "/mnt/data/testfile.txt"),
],
)
def test_full_path(filename, expected_full_path):
metadata = SessionsRemoteFileMetadata(filename=filename, size_in_bytes=123)
assert metadata.full_path == expected_full_path
Loading