From 8a59f0a8ea9d18cd8dd2210748bd3c3a84f70b36 Mon Sep 17 00:00:00 2001 From: SergeyMenshykh Date: Fri, 28 Jun 2024 17:27:10 +0100 Subject: [PATCH 1/6] feature(azure-open-ai): 1. Add services collection and kernel builder extension methods to create and register chat completion service. 2. Add integration tests for the Azure chat completion service. --- ...eOpenAIServiceCollectionExtensionsTests.cs | 63 ++ ...enAIServiceKernelBuilderExtensionsTests.cs | 63 ++ .../ChatHistoryExtensionsTests.cs | 2 +- .../Connectors.AzureOpenAI/Core/ClientCore.cs | 2 +- .../AzureOpenAIServiceCollectionExtensions.cs | 249 ++++++ .../AzureOpenAIChatCompletionTests.cs | 273 ++++++ ...enAIChatCompletion_FunctionCallingTests.cs | 781 ++++++++++++++++++ ...eOpenAIChatCompletion_NonStreamingTests.cs | 180 ++++ ...zureOpenAIChatCompletion_StreamingTests.cs | 174 ++++ .../IntegrationTestsV2.csproj | 1 + dotnet/src/IntegrationTestsV2/TestHelpers.cs | 55 ++ .../TestSettings/AzureOpenAIConfiguration.cs | 19 + 12 files changed, 1860 insertions(+), 2 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceCollectionExtensionsTests.cs create mode 100644 dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceKernelBuilderExtensionsTests.cs rename dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/{ => Extensions}/ChatHistoryExtensionsTests.cs (95%) create mode 100644 dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIServiceCollectionExtensions.cs create mode 100644 dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletionTests.cs create mode 100644 dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_FunctionCallingTests.cs create mode 100644 dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NonStreamingTests.cs create mode 100644 
dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_StreamingTests.cs create mode 100644 dotnet/src/IntegrationTestsV2/TestHelpers.cs create mode 100644 dotnet/src/IntegrationTestsV2/TestSettings/AzureOpenAIConfiguration.cs diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceCollectionExtensionsTests.cs new file mode 100644 index 000000000000..15aab3daf8b0 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceCollectionExtensionsTests.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Azure.AI.OpenAI; +using Azure.Core; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using Microsoft.SemanticKernel.TextGeneration; + +namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Extensions; + +/// +/// Unit tests for class. 
+/// +public sealed class AzureOpenAIServiceCollectionExtensionsTests +{ + #region Chat completion + + [Theory] + [InlineData(InitializationType.ApiKey)] + [InlineData(InitializationType.TokenCredential)] + [InlineData(InitializationType.OpenAIClientInline)] + [InlineData(InitializationType.OpenAIClientInServiceProvider)] + public void ServiceCollectionAddAzureOpenAIChatCompletionAddsValidService(InitializationType type) + { + // Arrange + var credentials = DelegatedTokenCredential.Create((_, _) => new AccessToken()); + var client = new AzureOpenAIClient(new Uri("http://localhost"), "key"); + var builder = Kernel.CreateBuilder(); + + builder.Services.AddSingleton(client); + + // Act + IServiceCollection collection = type switch + { + InitializationType.ApiKey => builder.Services.AddAzureOpenAIChatCompletion("deployment-name", "https://endpoint", "api-key"), + InitializationType.TokenCredential => builder.Services.AddAzureOpenAIChatCompletion("deployment-name", "https://endpoint", credentials), + InitializationType.OpenAIClientInline => builder.Services.AddAzureOpenAIChatCompletion("deployment-name", client), + InitializationType.OpenAIClientInServiceProvider => builder.Services.AddAzureOpenAIChatCompletion("deployment-name"), + _ => builder.Services + }; + + // Assert + var chatCompletionService = builder.Build().GetRequiredService(); + Assert.True(chatCompletionService is AzureOpenAIChatCompletionService); + + var textGenerationService = builder.Build().GetRequiredService(); + Assert.True(textGenerationService is AzureOpenAIChatCompletionService); + } + + #endregion + + public enum InitializationType + { + ApiKey, + TokenCredential, + OpenAIClientInline, + OpenAIClientInServiceProvider, + OpenAIClientEndpoint, + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceKernelBuilderExtensionsTests.cs new 
file mode 100644 index 000000000000..2ecd92151871 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceKernelBuilderExtensionsTests.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Azure.AI.OpenAI; +using Azure.Core; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using Microsoft.SemanticKernel.TextGeneration; + +namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Extensions; + +/// +/// Unit tests for class. +/// +public sealed class AzureOpenAIServiceKernelBuilderExtensionsTests +{ + #region Chat completion + + [Theory] + [InlineData(InitializationType.ApiKey)] + [InlineData(InitializationType.TokenCredential)] + [InlineData(InitializationType.OpenAIClientInline)] + [InlineData(InitializationType.OpenAIClientInServiceProvider)] + public void KernelBuilderAddAzureOpenAIChatCompletionAddsValidService(InitializationType type) + { + // Arrange + var credentials = DelegatedTokenCredential.Create((_, _) => new AccessToken()); + var client = new AzureOpenAIClient(new Uri("http://localhost"), "key"); + var builder = Kernel.CreateBuilder(); + + builder.Services.AddSingleton(client); + + // Act + builder = type switch + { + InitializationType.ApiKey => builder.AddAzureOpenAIChatCompletion("deployment-name", "https://endpoint", "api-key"), + InitializationType.TokenCredential => builder.AddAzureOpenAIChatCompletion("deployment-name", "https://endpoint", credentials), + InitializationType.OpenAIClientInline => builder.AddAzureOpenAIChatCompletion("deployment-name", client), + InitializationType.OpenAIClientInServiceProvider => builder.AddAzureOpenAIChatCompletion("deployment-name"), + _ => builder + }; + + // Assert + var chatCompletionService = builder.Build().GetRequiredService(); + Assert.True(chatCompletionService is 
AzureOpenAIChatCompletionService); + + var textGenerationService = builder.Build().GetRequiredService(); + Assert.True(textGenerationService is AzureOpenAIChatCompletionService); + } + + #endregion + + public enum InitializationType + { + ApiKey, + TokenCredential, + OpenAIClientInline, + OpenAIClientInServiceProvider, + OpenAIClientEndpoint, + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/ChatHistoryExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/ChatHistoryExtensionsTests.cs similarity index 95% rename from dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/ChatHistoryExtensionsTests.cs rename to dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/ChatHistoryExtensionsTests.cs index a0579f6d6c72..94fc1e5d1a5c 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/ChatHistoryExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/ChatHistoryExtensionsTests.cs @@ -7,7 +7,7 @@ using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.AzureOpenAI; -namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests; +namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Extensions; public class ChatHistoryExtensionsTests { [Fact] diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI/Core/ClientCore.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI/Core/ClientCore.cs index 6486d7348144..4152f2137409 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI/Core/ClientCore.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI/Core/ClientCore.cs @@ -126,7 +126,7 @@ internal ClientCore(ILogger? logger = null) private static Dictionary GetChatChoiceMetadata(OpenAIChatCompletion completions) { #pragma warning disable AOAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. 
- return new Dictionary(12) + return new Dictionary(8) { { nameof(completions.Id), completions.Id }, { nameof(completions.CreatedAt), completions.CreatedAt }, diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIServiceCollectionExtensions.cs new file mode 100644 index 000000000000..782889c4542c --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIServiceCollectionExtensions.cs @@ -0,0 +1,249 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using Azure; +using Azure.AI.OpenAI; +using Azure.Core; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using Microsoft.SemanticKernel.Http; +using Microsoft.SemanticKernel.TextGeneration; + +#pragma warning disable IDE0039 // Use local function + +namespace Microsoft.SemanticKernel; + +/// +/// Provides extension methods for and related classes to configure Azure OpenAI connectors. +/// +public static class AzureOpenAIServiceCollectionExtensions +{ + #region Chat Completion + + /// + /// Adds the Azure OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource + /// Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// A local identifier for the given AI service + /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// The HttpClient to use with this service. + /// The same instance as . 
+ public static IKernelBuilder AddAzureOpenAIChatCompletion( + this IKernelBuilder builder, + string deploymentName, + string endpoint, + string apiKey, + string? serviceId = null, + string? modelId = null, + HttpClient? httpClient = null) + { + Verify.NotNull(builder); + Verify.NotNullOrWhiteSpace(deploymentName); + Verify.NotNullOrWhiteSpace(endpoint); + Verify.NotNullOrWhiteSpace(apiKey); + + Func factory = (serviceProvider, _) => + { + AzureOpenAIClient client = CreateAzureOpenAIClient( + endpoint, + new AzureKeyCredential(apiKey), + HttpClientProvider.GetHttpClient(httpClient, serviceProvider)); + + return new(deploymentName, client, modelId, serviceProvider.GetService()); + }; + + builder.Services.AddKeyedSingleton(serviceId, factory); + builder.Services.AddKeyedSingleton(serviceId, factory); + + return builder; + } + + /// + /// Adds the Azure OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource + /// Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// A local identifier for the given AI service + /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// The same instance as . + public static IServiceCollection AddAzureOpenAIChatCompletion( + this IServiceCollection services, + string deploymentName, + string endpoint, + string apiKey, + string? serviceId = null, + string? 
modelId = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(deploymentName); + Verify.NotNullOrWhiteSpace(endpoint); + Verify.NotNullOrWhiteSpace(apiKey); + + Func factory = (serviceProvider, _) => + { + AzureOpenAIClient client = CreateAzureOpenAIClient( + endpoint, + new AzureKeyCredential(apiKey), + HttpClientProvider.GetHttpClient(serviceProvider)); + + return new(deploymentName, client, modelId, serviceProvider.GetService()); + }; + + services.AddKeyedSingleton(serviceId, factory); + services.AddKeyedSingleton(serviceId, factory); + + return services; + } + + /// + /// Adds the Azure OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource + /// Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. + /// A local identifier for the given AI service + /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// The HttpClient to use with this service. + /// The same instance as . + public static IKernelBuilder AddAzureOpenAIChatCompletion( + this IKernelBuilder builder, + string deploymentName, + string endpoint, + TokenCredential credentials, + string? serviceId = null, + string? modelId = null, + HttpClient? 
httpClient = null) + { + Verify.NotNull(builder); + Verify.NotNullOrWhiteSpace(deploymentName); + Verify.NotNullOrWhiteSpace(endpoint); + Verify.NotNull(credentials); + + Func factory = (serviceProvider, _) => + { + AzureOpenAIClient client = CreateAzureOpenAIClient( + endpoint, + credentials, + HttpClientProvider.GetHttpClient(httpClient, serviceProvider)); + + return new(deploymentName, client, modelId, serviceProvider.GetService()); + }; + + builder.Services.AddKeyedSingleton(serviceId, factory); + builder.Services.AddKeyedSingleton(serviceId, factory); + + return builder; + } + + /// + /// Adds the Azure OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource + /// Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. + /// A local identifier for the given AI service + /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// The same instance as . + public static IServiceCollection AddAzureOpenAIChatCompletion( + this IServiceCollection services, + string deploymentName, + string endpoint, + TokenCredential credentials, + string? serviceId = null, + string? 
modelId = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(deploymentName); + Verify.NotNullOrWhiteSpace(endpoint); + Verify.NotNull(credentials); + + Func factory = (serviceProvider, _) => + { + AzureOpenAIClient client = CreateAzureOpenAIClient( + endpoint, + credentials, + HttpClientProvider.GetHttpClient(serviceProvider)); + + return new(deploymentName, client, modelId, serviceProvider.GetService()); + }; + + services.AddKeyedSingleton(serviceId, factory); + services.AddKeyedSingleton(serviceId, factory); + + return services; + } + + /// + /// Adds the Azure OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// The same instance as . + public static IKernelBuilder AddAzureOpenAIChatCompletion( + this IKernelBuilder builder, + string deploymentName, + AzureOpenAIClient? azureOpenAIClient = null, + string? serviceId = null, + string? modelId = null) + { + Verify.NotNull(builder); + Verify.NotNullOrWhiteSpace(deploymentName); + + Func factory = (serviceProvider, _) => + new(deploymentName, azureOpenAIClient ?? serviceProvider.GetRequiredService(), modelId, serviceProvider.GetService()); + + builder.Services.AddKeyedSingleton(serviceId, factory); + builder.Services.AddKeyedSingleton(serviceId, factory); + + return builder; + } + + /// + /// Adds the Azure OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource + /// to use for the service. 
If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// The same instance as . + public static IServiceCollection AddAzureOpenAIChatCompletion( + this IServiceCollection services, + string deploymentName, + AzureOpenAIClient? azureOpenAIClient = null, + string? serviceId = null, + string? modelId = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(deploymentName); + + Func factory = (serviceProvider, _) => + new(deploymentName, azureOpenAIClient ?? serviceProvider.GetRequiredService(), modelId, serviceProvider.GetService()); + + services.AddKeyedSingleton(serviceId, factory); + services.AddKeyedSingleton(serviceId, factory); + + return services; + } + + #endregion + + private static AzureOpenAIClient CreateAzureOpenAIClient(string endpoint, AzureKeyCredential credentials, HttpClient? httpClient) => + new(new Uri(endpoint), credentials, ClientCore.GetOpenAIClientOptions(httpClient)); + + private static AzureOpenAIClient CreateAzureOpenAIClient(string endpoint, TokenCredential credentials, HttpClient? httpClient) => + new(new Uri(endpoint), credentials, ClientCore.GetOpenAIClientOptions(httpClient)); +} diff --git a/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletionTests.cs b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletionTests.cs new file mode 100644 index 000000000000..ce4641f380cf --- /dev/null +++ b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletionTests.cs @@ -0,0 +1,273 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Http.Resilience; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using OpenAI.Chat; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; + +namespace SemanticKernel.IntegrationTestsV2.Connectors.AzureOpenAI; + +#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. + +public sealed class AzureOpenAIChatCompletionTests() +{ + [Fact] + //[Fact(Skip = "Skipping while we investigate issue with GitHub actions.")] + public async Task ItCanUseAzureOpenAiChatForTextGenerationAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var func = kernel.CreateFunctionFromPrompt( + "List the two planets after '{{$input}}', excluding moons, using bullet points.", + new AzureOpenAIPromptExecutionSettings()); + + // Act + var result = await func.InvokeAsync(kernel, new() { [InputParameterName] = "Jupiter" }); + + // Assert + Assert.NotNull(result); + Assert.Contains("Saturn", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Uranus", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task AzureOpenAIStreamingTestAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); + + StringBuilder fullResult = new(); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + await foreach (var content in kernel.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt 
})) + { + fullResult.Append(content); + } + + // Assert + Assert.Contains("Pike Place", fullResult.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task AzureOpenAIHttpRetryPolicyTestAsync() + { + // Arrange + List statusCodes = []; + + var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get(); + + this._kernelBuilder.AddAzureOpenAIChatCompletion( + deploymentName: azureOpenAIConfiguration!.ChatDeploymentName!, + modelId: azureOpenAIConfiguration.ChatModelId, + endpoint: azureOpenAIConfiguration.Endpoint, + apiKey: "INVALID_KEY"); + + this._kernelBuilder.Services.ConfigureHttpClientDefaults(c => + { + // Use a standard resiliency policy, augmented to retry on 401 Unauthorized for this example + c.AddStandardResilienceHandler().Configure(o => + { + o.Retry.ShouldHandle = args => ValueTask.FromResult(args.Outcome.Result?.StatusCode is HttpStatusCode.Unauthorized); + o.Retry.OnRetry = args => + { + statusCodes.Add(args.Outcome.Result?.StatusCode); + return ValueTask.CompletedTask; + }; + }); + }); + + var target = this._kernelBuilder.Build(); + + var plugins = TestHelpers.ImportSamplePlugins(target, "SummarizePlugin"); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + var exception = await Assert.ThrowsAsync(() => target.InvokeAsync(plugins["SummarizePlugin"]["Summarize"], new() { [InputParameterName] = prompt })); + + // Assert + Assert.All(statusCodes, s => Assert.Equal(HttpStatusCode.Unauthorized, s)); + Assert.Equal(HttpStatusCode.Unauthorized, ((HttpOperationException)exception).StatusCode); + } + + [Fact] + public async Task AzureOpenAIShouldReturnMetadataAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "FunPlugin"); + + // Act + var result = await kernel.InvokeAsync(plugins["FunPlugin"]["Limerick"]); + + // Assert + Assert.NotNull(result.Metadata); + + // Usage + 
Assert.True(result.Metadata.TryGetValue("Usage", out object? usageObject)); + Assert.NotNull(usageObject); + + var jsonObject = JsonSerializer.SerializeToElement(usageObject); + Assert.True(jsonObject.TryGetProperty("InputTokens", out JsonElement promptTokensJson)); + Assert.True(promptTokensJson.TryGetInt32(out int promptTokens)); + Assert.NotEqual(0, promptTokens); + + Assert.True(jsonObject.TryGetProperty("OutputTokens", out JsonElement completionTokensJson)); + Assert.True(completionTokensJson.TryGetInt32(out int completionTokens)); + Assert.NotEqual(0, completionTokens); + + // ContentFilterResults + Assert.True(result.Metadata.ContainsKey("ContentFilterResults")); + } + + [Theory(Skip = "This test is for manual verification.")] + [InlineData("\n")] + [InlineData("\r\n")] + public async Task CompletionWithDifferentLineEndingsAsync(string lineEnding) + { + // Arrange + var prompt = + "Given a json input and a request. Apply the request on the json input and return the result. " + + $"Put the result in between tags{lineEnding}" + + $$"""Input:{{lineEnding}}{"name": "John", "age": 30}{{lineEnding}}{{lineEnding}}Request:{{lineEnding}}name"""; + + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); + + // Act + FunctionResult actual = await kernel.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }); + + // Assert + Assert.Contains("John", actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ChatSystemPromptIsNotIgnoredAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var settings = new AzureOpenAIPromptExecutionSettings { ChatSystemPrompt = "Reply \"I don't know\" to every question." 
}; + + // Act + var result = await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?", new(settings)); + + // Assert + Assert.Contains("I don't know", result.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task SemanticKernelVersionHeaderIsSentAsync() + { + // Arrange + using var defaultHandler = new HttpClientHandler(); + using var httpHeaderHandler = new HttpHeaderHandler(defaultHandler); + using var httpClient = new HttpClient(httpHeaderHandler); + + var kernel = this.CreateAndInitializeKernel(httpClient); + + // Act + var result = await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?"); + + // Assert + Assert.NotNull(httpHeaderHandler.RequestHeaders); + Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values)); + } + + //[Theory(Skip = "This test is for manual verification.")] + [Theory] + [InlineData(null, null)] + [InlineData(false, null)] + [InlineData(true, 2)] + [InlineData(true, 5)] + public async Task LogProbsDataIsReturnedWhenRequestedAsync(bool? logprobs, int? topLogprobs) + { + // Arrange + var settings = new AzureOpenAIPromptExecutionSettings { Logprobs = logprobs, TopLogprobs = topLogprobs }; + + var kernel = this.CreateAndInitializeKernel(); + + // Act + var result = await kernel.InvokePromptAsync("Hi, can you help me today?", new(settings)); + + var logProbabilityInfo = result.Metadata?["LogProbabilityInfo"] as IReadOnlyList; + + // Assert + Assert.NotNull(logProbabilityInfo); + + if (logprobs is true) + { + Assert.NotNull(logProbabilityInfo); + Assert.Equal(topLogprobs, logProbabilityInfo[0].TopLogProbabilities.Count); + } + else + { + Assert.Empty(logProbabilityInfo); + } + } + + #region internals + + private Kernel CreateAndInitializeKernel(HttpClient? 
httpClient = null) + { + var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get(); + Assert.NotNull(azureOpenAIConfiguration); + Assert.NotNull(azureOpenAIConfiguration.ChatDeploymentName); + Assert.NotNull(azureOpenAIConfiguration.ApiKey); + Assert.NotNull(azureOpenAIConfiguration.Endpoint); + Assert.NotNull(azureOpenAIConfiguration.ServiceId); + + this._kernelBuilder.AddAzureOpenAIChatCompletion( + deploymentName: azureOpenAIConfiguration.ChatDeploymentName, + modelId: azureOpenAIConfiguration.ChatModelId, + endpoint: azureOpenAIConfiguration.Endpoint, + apiKey: azureOpenAIConfiguration.ApiKey, + serviceId: azureOpenAIConfiguration.ServiceId, + httpClient: httpClient); + + return this._kernelBuilder.Build(); + } + + private const string InputParameterName = "input"; + private readonly IKernelBuilder _kernelBuilder = Kernel.CreateBuilder(); + + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + private sealed class HttpHeaderHandler(HttpMessageHandler innerHandler) : DelegatingHandler(innerHandler) + { + public System.Net.Http.Headers.HttpRequestHeaders? 
RequestHeaders { get; private set; } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + this.RequestHeaders = request.Headers; + return await base.SendAsync(request, cancellationToken); + } + } + + #endregion +} diff --git a/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_FunctionCallingTests.cs b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_FunctionCallingTests.cs new file mode 100644 index 000000000000..5bbbd60c9005 --- /dev/null +++ b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_FunctionCallingTests.cs @@ -0,0 +1,781 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using OpenAI.Chat; +using SemanticKernel.IntegrationTests.TestSettings; +using SemanticKernel.IntegrationTestsV2.Connectors.AzureOpenAI; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.AzureOpenAI; + +public sealed class AzureOpenAIChatCompletionFunctionCallingTests +{ + [Fact] + public async Task CanAutoInvokeKernelFunctionsAsync() + { + // Arrange + var invokedFunctions = new List(); + + var filter = new FakeFunctionFilter(async (context, next) => + { + invokedFunctions.Add($"{context.Function.Name}({string.Join(", ", context.Arguments)})"); + await next(context); + }); + + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + kernel.FunctionInvocationFilters.Add(filter); + + AzureOpenAIPromptExecutionSettings settings = new() { ToolCallBehavior = AzureOpenAIToolCallBehavior.AutoInvokeKernelFunctions }; + + // Act + var result = await kernel.InvokePromptAsync("Given the 
current time of day and weather, what is the likely color of the sky in Boston?", new(settings)); + + // Assert + Assert.Contains("rain", result.GetValue(), StringComparison.InvariantCulture); + Assert.Contains("GetCurrentUtcTime()", invokedFunctions); + Assert.Contains("Get_Weather_For_City([cityName, Boston])", invokedFunctions); + } + + [Fact] + public async Task CanAutoInvokeKernelFunctionsStreamingAsync() + { + // Arrange + var invokedFunctions = new List(); + + var filter = new FakeFunctionFilter(async (context, next) => + { + invokedFunctions.Add($"{context.Function.Name}({string.Join(", ", context.Arguments)})"); + await next(context); + }); + + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + kernel.FunctionInvocationFilters.Add(filter); + + AzureOpenAIPromptExecutionSettings settings = new() { ToolCallBehavior = AzureOpenAIToolCallBehavior.AutoInvokeKernelFunctions }; + + var stringBuilder = new StringBuilder(); + + // Act + await foreach (var update in kernel.InvokePromptStreamingAsync("Given the current time of day and weather, what is the likely color of the sky in Boston?", new(settings))) + { + stringBuilder.Append(update); + } + + // Assert + Assert.Contains("rain", stringBuilder.ToString(), StringComparison.InvariantCulture); + Assert.Contains("GetCurrentUtcTime()", invokedFunctions); + Assert.Contains("Get_Weather_For_City([cityName, Boston])", invokedFunctions); + } + + [Fact] + public async Task CanAutoInvokeKernelFunctionsWithComplexTypeParametersAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + AzureOpenAIPromptExecutionSettings settings = new() { ToolCallBehavior = AzureOpenAIToolCallBehavior.AutoInvokeKernelFunctions }; + + // Act + var result = await kernel.InvokePromptAsync("What is the current temperature in Dublin, Ireland, in Fahrenheit?", new(settings)); + + // Assert + Assert.NotNull(result); + Assert.Contains("42.8", result.GetValue(), 
StringComparison.InvariantCulture); // The WeatherPlugin always returns 42.8 for Dublin, Ireland. + } + + [Fact] + public async Task CanAutoInvokeKernelFunctionsWithPrimitiveTypeParametersAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + AzureOpenAIPromptExecutionSettings settings = new() { ToolCallBehavior = AzureOpenAIToolCallBehavior.AutoInvokeKernelFunctions }; + + // Act + var result = await kernel.InvokePromptAsync("Convert 50 degrees Fahrenheit to Celsius.", new(settings)); + + // Assert + Assert.NotNull(result); + Assert.Contains("10", result.GetValue(), StringComparison.InvariantCulture); + } + + [Fact] + public async Task CanAutoInvokeKernelFunctionsWithEnumTypeParametersAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + AzureOpenAIPromptExecutionSettings settings = new() { ToolCallBehavior = AzureOpenAIToolCallBehavior.AutoInvokeKernelFunctions }; + + // Act + var result = await kernel.InvokePromptAsync("Given the current time of day and weather, what is the likely color of the sky in Boston?", new(settings)); + + // Assert + Assert.NotNull(result); + Assert.Contains("rain", result.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task CanAutoInvokeKernelFunctionFromPromptAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var promptFunction = KernelFunctionFactory.CreateFromPrompt( + "Your role is always to return this text - 'A Game-Changer for the Transportation Industry'. 
Don't ask for more details or context.", + functionName: "FindLatestNews", + description: "Searches for the latest news."); + + kernel.Plugins.Add(KernelPluginFactory.CreateFromFunctions( + "NewsProvider", + "Delivers up-to-date news content.", + [promptFunction])); + + AzureOpenAIPromptExecutionSettings settings = new() { ToolCallBehavior = AzureOpenAIToolCallBehavior.AutoInvokeKernelFunctions }; + + // Act + var result = await kernel.InvokePromptAsync("Show me the latest news as they are.", new(settings)); + + // Assert + Assert.NotNull(result); + Assert.Contains("Transportation", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task CanAutoInvokeKernelFunctionFromPromptStreamingAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var promptFunction = KernelFunctionFactory.CreateFromPrompt( + "Your role is always to return this text - 'A Game-Changer for the Transportation Industry'. Don't ask for more details or context.", + functionName: "FindLatestNews", + description: "Searches for the latest news."); + + kernel.Plugins.Add(KernelPluginFactory.CreateFromFunctions( + "NewsProvider", + "Delivers up-to-date news content.", + [promptFunction])); + + AzureOpenAIPromptExecutionSettings settings = new() { ToolCallBehavior = AzureOpenAIToolCallBehavior.AutoInvokeKernelFunctions }; + + // Act + var streamingResult = kernel.InvokePromptStreamingAsync("Show me the latest news as they are.", new(settings)); + + var builder = new StringBuilder(); + + await foreach (var update in streamingResult) + { + builder.Append(update.ToString()); + } + + var result = builder.ToString(); + + // Assert + Assert.NotNull(result); + Assert.Contains("Transportation", result, StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task ConnectorSpecificChatMessageContentClassesCanBeUsedForManualFunctionCallingAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: 
true); + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?"); + + var settings = new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = AzureOpenAIToolCallBehavior.EnableKernelFunctions }; + + var sut = kernel.GetRequiredService(); + + // Act + var result = await sut.GetChatMessageContentAsync(chatHistory, settings, kernel); + + // Current way of handling function calls manually using connector specific chat message content class. + var toolCalls = ((AzureOpenAIChatMessageContent)result).ToolCalls.OfType().ToList(); + + while (toolCalls.Count > 0) + { + // Adding LLM function call request to chat history + chatHistory.Add(result); + + // Iterating over the requested function calls and invoking them + foreach (var toolCall in toolCalls) + { + string content = kernel.Plugins.TryGetFunctionAndArguments(toolCall, out KernelFunction? function, out KernelArguments? arguments) ? + JsonSerializer.Serialize((await function.InvokeAsync(kernel, arguments)).GetValue()) : + "Unable to find function. 
Please try again!"; + + // Adding the result of the function call to the chat history + chatHistory.Add(new ChatMessageContent( + AuthorRole.Tool, + content, + metadata: new Dictionary(1) { { AzureOpenAIChatMessageContent.ToolIdProperty, toolCall.Id } })); + } + + // Sending the functions invocation results back to the LLM to get the final response + result = await sut.GetChatMessageContentAsync(chatHistory, settings, kernel); + toolCalls = ((AzureOpenAIChatMessageContent)result).ToolCalls.OfType().ToList(); + } + + // Assert + Assert.Contains("rain", result.Content, StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task ConnectorAgnosticFunctionCallingModelClassesCanBeUsedForManualFunctionCallingAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?"); + + var settings = new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = AzureOpenAIToolCallBehavior.EnableKernelFunctions }; + + var sut = kernel.GetRequiredService(); + + // Act + var messageContent = await sut.GetChatMessageContentAsync(chatHistory, settings, kernel); + + var functionCalls = FunctionCallContent.GetFunctionCalls(messageContent).ToArray(); + + while (functionCalls.Length != 0) + { + // Adding function call from LLM to chat history + chatHistory.Add(messageContent); + + // Iterating over the requested function calls and invoking them + foreach (var functionCall in functionCalls) + { + var result = await functionCall.InvokeAsync(kernel); + + chatHistory.Add(result.ToChatMessage()); + } + + // Sending the functions invocation results to the LLM to get the final response + messageContent = await sut.GetChatMessageContentAsync(chatHistory, settings, kernel); + functionCalls = FunctionCallContent.GetFunctionCalls(messageContent).ToArray(); + } + + // Assert + 
Assert.Contains("rain", messageContent.Content, StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task ConnectorAgnosticFunctionCallingModelClassesCanPassFunctionExceptionToConnectorAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + var chatHistory = new ChatHistory(); + chatHistory.AddSystemMessage("Add the \"Error\" keyword to the response, if you are unable to answer a question or an error has happen."); + chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?"); + + var settings = new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = AzureOpenAIToolCallBehavior.EnableKernelFunctions }; + + var completionService = kernel.GetRequiredService(); + + // Act + var messageContent = await completionService.GetChatMessageContentAsync(chatHistory, settings, kernel); + + var functionCalls = FunctionCallContent.GetFunctionCalls(messageContent).ToArray(); + + while (functionCalls.Length != 0) + { + // Adding function call from LLM to chat history + chatHistory.Add(messageContent); + + // Iterating over the requested function calls and invoking them + foreach (var functionCall in functionCalls) + { + // Simulating an exception + var exception = new OperationCanceledException("The operation was canceled due to timeout."); + + chatHistory.Add(new FunctionResultContent(functionCall, exception).ToChatMessage()); + } + + // Sending the functions execution results back to the LLM to get the final response + messageContent = await completionService.GetChatMessageContentAsync(chatHistory, settings, kernel); + functionCalls = FunctionCallContent.GetFunctionCalls(messageContent).ToArray(); + } + + // Assert + Assert.NotNull(messageContent.Content); + + Assert.Contains("error", messageContent.Content, StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task 
ConnectorAgnosticFunctionCallingModelClassesSupportSimulatedFunctionCallsAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + var chatHistory = new ChatHistory(); + chatHistory.AddSystemMessage("if there's a tornado warning, please add the 'tornado' keyword to the response."); + chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?"); + + var settings = new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = AzureOpenAIToolCallBehavior.EnableKernelFunctions }; + + var completionService = kernel.GetRequiredService(); + + // Act + var messageContent = await completionService.GetChatMessageContentAsync(chatHistory, settings, kernel); + + var functionCalls = FunctionCallContent.GetFunctionCalls(messageContent).ToArray(); + + while (functionCalls.Length > 0) + { + // Adding function call from LLM to chat history + chatHistory.Add(messageContent); + + // Iterating over the requested function calls and invoking them + foreach (var functionCall in functionCalls) + { + var result = await functionCall.InvokeAsync(kernel); + + chatHistory.AddMessage(AuthorRole.Tool, [result]); + } + + // Adding a simulated function call to the connector response message + var simulatedFunctionCall = new FunctionCallContent("weather-alert", id: "call_123"); + messageContent.Items.Add(simulatedFunctionCall); + + // Adding a simulated function result to chat history + var simulatedFunctionResult = "A Tornado Watch has been issued, with potential for severe thunderstorms causing unusual sky colors like green, yellow, or dark gray. 
Stay informed and follow safety instructions from authorities."; + chatHistory.Add(new FunctionResultContent(simulatedFunctionCall, simulatedFunctionResult).ToChatMessage()); + + // Sending the functions invocation results back to the LLM to get the final response + messageContent = await completionService.GetChatMessageContentAsync(chatHistory, settings, kernel); + functionCalls = FunctionCallContent.GetFunctionCalls(messageContent).ToArray(); + } + + // Assert + Assert.Contains("tornado", messageContent.Content, StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task ItFailsIfNoFunctionResultProvidedAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?"); + + var settings = new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = AzureOpenAIToolCallBehavior.EnableKernelFunctions }; + + var completionService = kernel.GetRequiredService(); + + // Act + var result = await completionService.GetChatMessageContentAsync(chatHistory, settings, kernel); + + chatHistory.Add(result); + + var exception = await Assert.ThrowsAsync(() => completionService.GetChatMessageContentAsync(chatHistory, settings, kernel)); + + // Assert + Assert.Contains("'tool_calls' must be followed by tool", exception.Message, StringComparison.InvariantCulture); + } + + [Fact] + public async Task ConnectorAgnosticFunctionCallingModelClassesCanBeUsedForAutoFunctionCallingAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?"); + + var settings = new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = AzureOpenAIToolCallBehavior.AutoInvokeKernelFunctions }; + + var 
sut = kernel.GetRequiredService(); + + // Act + await sut.GetChatMessageContentAsync(chatHistory, settings, kernel); + + // Assert + Assert.Equal(5, chatHistory.Count); + + var userMessage = chatHistory[0]; + Assert.Equal(AuthorRole.User, userMessage.Role); + + // LLM requested the current time. + var getCurrentTimeFunctionCallRequestMessage = chatHistory[1]; + Assert.Equal(AuthorRole.Assistant, getCurrentTimeFunctionCallRequestMessage.Role); + + var getCurrentTimeFunctionCallRequest = getCurrentTimeFunctionCallRequestMessage.Items.OfType().Single(); + Assert.Equal("GetCurrentUtcTime", getCurrentTimeFunctionCallRequest.FunctionName); + Assert.Equal("HelperFunctions", getCurrentTimeFunctionCallRequest.PluginName); + Assert.NotNull(getCurrentTimeFunctionCallRequest.Id); + + // Connector invoked the GetCurrentUtcTime function and added result to chat history. + var getCurrentTimeFunctionCallResultMessage = chatHistory[2]; + Assert.Equal(AuthorRole.Tool, getCurrentTimeFunctionCallResultMessage.Role); + Assert.Single(getCurrentTimeFunctionCallResultMessage.Items.OfType()); // Current function calling model adds TextContent item representing the result of the function call. + + var getCurrentTimeFunctionCallResult = getCurrentTimeFunctionCallResultMessage.Items.OfType().Single(); + Assert.Equal("GetCurrentUtcTime", getCurrentTimeFunctionCallResult.FunctionName); + Assert.Equal("HelperFunctions", getCurrentTimeFunctionCallResult.PluginName); + Assert.Equal(getCurrentTimeFunctionCallRequest.Id, getCurrentTimeFunctionCallResult.CallId); + Assert.NotNull(getCurrentTimeFunctionCallResult.Result); + + // LLM requested the weather for Boston. 
+ var getWeatherForCityFunctionCallRequestMessage = chatHistory[3]; + Assert.Equal(AuthorRole.Assistant, getWeatherForCityFunctionCallRequestMessage.Role); + + var getWeatherForCityFunctionCallRequest = getWeatherForCityFunctionCallRequestMessage.Items.OfType().Single(); + Assert.Equal("Get_Weather_For_City", getWeatherForCityFunctionCallRequest.FunctionName); + Assert.Equal("HelperFunctions", getWeatherForCityFunctionCallRequest.PluginName); + Assert.NotNull(getWeatherForCityFunctionCallRequest.Id); + + // Connector invoked the Get_Weather_For_City function and added result to chat history. + var getWeatherForCityFunctionCallResultMessage = chatHistory[4]; + Assert.Equal(AuthorRole.Tool, getWeatherForCityFunctionCallResultMessage.Role); + Assert.Single(getWeatherForCityFunctionCallResultMessage.Items.OfType()); // Current function calling model adds TextContent item representing the result of the function call. + + var getWeatherForCityFunctionCallResult = getWeatherForCityFunctionCallResultMessage.Items.OfType().Single(); + Assert.Equal("Get_Weather_For_City", getWeatherForCityFunctionCallResult.FunctionName); + Assert.Equal("HelperFunctions", getWeatherForCityFunctionCallResult.PluginName); + Assert.Equal(getWeatherForCityFunctionCallRequest.Id, getWeatherForCityFunctionCallResult.CallId); + Assert.NotNull(getWeatherForCityFunctionCallResult.Result); + } + + [Fact] + public async Task ConnectorAgnosticFunctionCallingModelClassesCanBeUsedForManualFunctionCallingForStreamingAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + var settings = new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = AzureOpenAIToolCallBehavior.EnableKernelFunctions }; + + var sut = kernel.GetRequiredService(); + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?"); + + string? 
result = null; + + // Act + while (true) + { + AuthorRole? authorRole = null; + var fccBuilder = new FunctionCallContentBuilder(); + var textContent = new StringBuilder(); + + await foreach (var streamingContent in sut.GetStreamingChatMessageContentsAsync(chatHistory, settings, kernel)) + { + textContent.Append(streamingContent.Content); + authorRole ??= streamingContent.Role; + fccBuilder.Append(streamingContent); + } + + var functionCalls = fccBuilder.Build(); + if (functionCalls.Any()) + { + var fcContent = new ChatMessageContent(role: authorRole ?? default, content: null); + chatHistory.Add(fcContent); + + // Iterating over the requested function calls and invoking them + foreach (var functionCall in functionCalls) + { + fcContent.Items.Add(functionCall); + + var functionResult = await functionCall.InvokeAsync(kernel); + + chatHistory.Add(functionResult.ToChatMessage()); + } + + continue; + } + + result = textContent.ToString(); + break; + } + + // Assert + Assert.Contains("rain", result, StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task ConnectorAgnosticFunctionCallingModelClassesCanBeUsedForAutoFunctionCallingForStreamingAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + var chatHistory = new ChatHistory(); + chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?"); + + var settings = new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = AzureOpenAIToolCallBehavior.AutoInvokeKernelFunctions }; + + var sut = kernel.GetRequiredService(); + + var result = new StringBuilder(); + + // Act + await foreach (var contentUpdate in sut.GetStreamingChatMessageContentsAsync(chatHistory, settings, kernel)) + { + result.Append(contentUpdate.Content); + } + + // Assert + Assert.Equal(5, chatHistory.Count); + + var userMessage = chatHistory[0]; + Assert.Equal(AuthorRole.User, userMessage.Role); + + // LLM requested the 
current time. + var getCurrentTimeFunctionCallRequestMessage = chatHistory[1]; + Assert.Equal(AuthorRole.Assistant, getCurrentTimeFunctionCallRequestMessage.Role); + + var getCurrentTimeFunctionCallRequest = getCurrentTimeFunctionCallRequestMessage.Items.OfType().Single(); + Assert.Equal("GetCurrentUtcTime", getCurrentTimeFunctionCallRequest.FunctionName); + Assert.Equal("HelperFunctions", getCurrentTimeFunctionCallRequest.PluginName); + Assert.NotNull(getCurrentTimeFunctionCallRequest.Id); + + // Connector invoked the GetCurrentUtcTime function and added result to chat history. + var getCurrentTimeFunctionCallResultMessage = chatHistory[2]; + Assert.Equal(AuthorRole.Tool, getCurrentTimeFunctionCallResultMessage.Role); + Assert.Single(getCurrentTimeFunctionCallResultMessage.Items.OfType()); // Current function calling model adds TextContent item representing the result of the function call. + + var getCurrentTimeFunctionCallResult = getCurrentTimeFunctionCallResultMessage.Items.OfType().Single(); + Assert.Equal("GetCurrentUtcTime", getCurrentTimeFunctionCallResult.FunctionName); + Assert.Equal("HelperFunctions", getCurrentTimeFunctionCallResult.PluginName); + Assert.Equal(getCurrentTimeFunctionCallRequest.Id, getCurrentTimeFunctionCallResult.CallId); + Assert.NotNull(getCurrentTimeFunctionCallResult.Result); + + // LLM requested the weather for Boston. + var getWeatherForCityFunctionCallRequestMessage = chatHistory[3]; + Assert.Equal(AuthorRole.Assistant, getWeatherForCityFunctionCallRequestMessage.Role); + + var getWeatherForCityFunctionCallRequest = getWeatherForCityFunctionCallRequestMessage.Items.OfType().Single(); + Assert.Equal("Get_Weather_For_City", getWeatherForCityFunctionCallRequest.FunctionName); + Assert.Equal("HelperFunctions", getWeatherForCityFunctionCallRequest.PluginName); + Assert.NotNull(getWeatherForCityFunctionCallRequest.Id); + + // Connector invoked the Get_Weather_For_City function and added result to chat history. 
+ var getWeatherForCityFunctionCallResultMessage = chatHistory[4]; + Assert.Equal(AuthorRole.Tool, getWeatherForCityFunctionCallResultMessage.Role); + Assert.Single(getWeatherForCityFunctionCallResultMessage.Items.OfType()); // Current function calling model adds TextContent item representing the result of the function call. + + var getWeatherForCityFunctionCallResult = getWeatherForCityFunctionCallResultMessage.Items.OfType().Single(); + Assert.Equal("Get_Weather_For_City", getWeatherForCityFunctionCallResult.FunctionName); + Assert.Equal("HelperFunctions", getWeatherForCityFunctionCallResult.PluginName); + Assert.Equal(getWeatherForCityFunctionCallRequest.Id, getWeatherForCityFunctionCallResult.CallId); + Assert.NotNull(getWeatherForCityFunctionCallResult.Result); + } + + [Fact] + public async Task ConnectorAgnosticFunctionCallingModelClassesCanPassFunctionExceptionToConnectorForStreamingAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + var settings = new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = AzureOpenAIToolCallBehavior.EnableKernelFunctions }; + + var sut = kernel.GetRequiredService(); + + var chatHistory = new ChatHistory(); + chatHistory.AddSystemMessage("Add the \"Error\" keyword to the response, if you are unable to answer a question or an error has happen."); + chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?"); + + string? result = null; + + // Act + while (true) + { + AuthorRole? 
authorRole = null; + var fccBuilder = new FunctionCallContentBuilder(); + var textContent = new StringBuilder(); + + await foreach (var streamingContent in sut.GetStreamingChatMessageContentsAsync(chatHistory, settings, kernel)) + { + textContent.Append(streamingContent.Content); + authorRole ??= streamingContent.Role; + fccBuilder.Append(streamingContent); + } + + var functionCalls = fccBuilder.Build(); + if (functionCalls.Any()) + { + var fcContent = new ChatMessageContent(role: authorRole ?? default, content: null); + chatHistory.Add(fcContent); + + // Iterating over the requested function calls and invoking them + foreach (var functionCall in functionCalls) + { + fcContent.Items.Add(functionCall); + + // Simulating an exception + var exception = new OperationCanceledException("The operation was canceled due to timeout."); + + chatHistory.Add(new FunctionResultContent(functionCall, exception).ToChatMessage()); + } + + continue; + } + + result = textContent.ToString(); + break; + } + + // Assert + Assert.Contains("error", result, StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task ConnectorAgnosticFunctionCallingModelClassesSupportSimulatedFunctionCallsForStreamingAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(importHelperPlugin: true); + + var settings = new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = AzureOpenAIToolCallBehavior.EnableKernelFunctions }; + + var sut = kernel.GetRequiredService(); + + var chatHistory = new ChatHistory(); + chatHistory.AddSystemMessage("if there's a tornado warning, please add the 'tornado' keyword to the response."); + chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?"); + + string? result = null; + + // Act + while (true) + { + AuthorRole? 
authorRole = null; + var fccBuilder = new FunctionCallContentBuilder(); + var textContent = new StringBuilder(); + + await foreach (var streamingContent in sut.GetStreamingChatMessageContentsAsync(chatHistory, settings, kernel)) + { + textContent.Append(streamingContent.Content); + authorRole ??= streamingContent.Role; + fccBuilder.Append(streamingContent); + } + + var functionCalls = fccBuilder.Build(); + if (functionCalls.Any()) + { + var fcContent = new ChatMessageContent(role: authorRole ?? default, content: null); + chatHistory.Add(fcContent); + + // Iterating over the requested function calls and invoking them + foreach (var functionCall in functionCalls) + { + fcContent.Items.Add(functionCall); + + var functionResult = await functionCall.InvokeAsync(kernel); + + chatHistory.Add(functionResult.ToChatMessage()); + } + + // Adding a simulated function call to the connector response message + var simulatedFunctionCall = new FunctionCallContent("weather-alert", id: "call_123"); + fcContent.Items.Add(simulatedFunctionCall); + + // Adding a simulated function result to chat history + var simulatedFunctionResult = "A Tornado Watch has been issued, with potential for severe thunderstorms causing unusual sky colors like green, yellow, or dark gray. 
Stay informed and follow safety instructions from authorities."; + chatHistory.Add(new FunctionResultContent(simulatedFunctionCall, simulatedFunctionResult).ToChatMessage()); + + continue; + } + + result = textContent.ToString(); + break; + } + + // Assert + Assert.Contains("tornado", result, StringComparison.InvariantCultureIgnoreCase); + } + + private Kernel CreateAndInitializeKernel(bool importHelperPlugin = false) + { + var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get(); + Assert.NotNull(azureOpenAIConfiguration); + Assert.NotNull(azureOpenAIConfiguration.ChatDeploymentName); + Assert.NotNull(azureOpenAIConfiguration.ApiKey); + Assert.NotNull(azureOpenAIConfiguration.Endpoint); + + var kernelBuilder = Kernel.CreateBuilder(); + + kernelBuilder.AddAzureOpenAIChatCompletion( + deploymentName: azureOpenAIConfiguration.ChatDeploymentName, + modelId: azureOpenAIConfiguration.ChatModelId, + endpoint: azureOpenAIConfiguration.Endpoint, + apiKey: azureOpenAIConfiguration.ApiKey); + + var kernel = kernelBuilder.Build(); + + if (importHelperPlugin) + { + kernel.ImportPluginFromFunctions("HelperFunctions", + [ + kernel.CreateFunctionFromMethod(() => DateTime.UtcNow.ToString("R"), "GetCurrentUtcTime", "Retrieves the current time in UTC."), + kernel.CreateFunctionFromMethod((string cityName) => + { + return cityName switch + { + "Boston" => "61 and rainy", + _ => "31 and snowing", + }; + }, "Get_Weather_For_City", "Gets the current weather for the specified city"), + kernel.CreateFunctionFromMethod((WeatherParameters parameters) => + { + if (parameters.City.Name == "Dublin" && (parameters.City.Country == "Ireland" || parameters.City.Country == "IE")) + { + return Task.FromResult(42.8); // 42.8 Fahrenheit. 
+ } + + throw new NotSupportedException($"Weather in {parameters.City.Name} ({parameters.City.Country}) is not supported."); + }, "Get_Current_Temperature", "Get current temperature."), + kernel.CreateFunctionFromMethod((double temperatureInFahrenheit) => + { + double temperatureInCelsius = (temperatureInFahrenheit - 32) * 5 / 9; + return Task.FromResult(temperatureInCelsius); + }, "Convert_Temperature_From_Fahrenheit_To_Celsius", "Convert temperature from Fahrenheit to Celsius.") + ]); + } + + return kernel; + } + + public record WeatherParameters(City City); + + public class City + { + public string Name { get; set; } = string.Empty; + public string Country { get; set; } = string.Empty; + } + + private sealed class FakeFunctionFilter : IFunctionInvocationFilter + { + private readonly Func, Task>? _onFunctionInvocation; + + public FakeFunctionFilter( + Func, Task>? onFunctionInvocation = null) + { + this._onFunctionInvocation = onFunctionInvocation; + } + + public Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func next) => + this._onFunctionInvocation?.Invoke(context, next) ?? Task.CompletedTask; + } + + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); +} diff --git a/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NonStreamingTests.cs b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NonStreamingTests.cs new file mode 100644 index 000000000000..cf87a2192a62 --- /dev/null +++ b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NonStreamingTests.cs @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Collections.Generic; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using Microsoft.SemanticKernel.TextGeneration; +using OpenAI.Chat; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; + +namespace SemanticKernel.IntegrationTestsV2.Connectors.AzureOpenAI; + +#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. + +public sealed class AzureOpenAIChatCompletionNonStreamingTests() +{ + [Fact] + public async Task ChatCompletionShouldUseChatSystemPromptAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var chatCompletion = kernel.Services.GetRequiredService(); + + var settings = new AzureOpenAIPromptExecutionSettings { ChatSystemPrompt = "Reply \"I don't know\" to every question." }; + + // Act + var result = await chatCompletion.GetChatMessageContentAsync("What is the capital of France?", settings, kernel); + + // Assert + Assert.Contains("I don't know", result.Content); + } + + [Fact] + public async Task ChatCompletionShouldUseChatHistoryAndReturnMetadataAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var chatCompletion = kernel.Services.GetRequiredService(); + + var chatHistory = new ChatHistory("Reply \"I don't know\" to every question."); + chatHistory.AddUserMessage("What is the capital of France?"); + + // Act + var result = await chatCompletion.GetChatMessageContentAsync(chatHistory, null, kernel); + + // Assert + Assert.Contains("I don't know", result.Content); + Assert.NotNull(result.Metadata); + + Assert.True(result.Metadata.TryGetValue("Id", out object? id)); + Assert.NotNull(id); + + Assert.True(result.Metadata.TryGetValue("CreatedAt", out object? 
createdAt)); + Assert.NotNull(createdAt); + + Assert.True(result.Metadata.ContainsKey("PromptFilterResults")); + + Assert.True(result.Metadata.ContainsKey("SystemFingerprint")); + + Assert.True(result.Metadata.TryGetValue("Usage", out object? usageObject)); + Assert.NotNull(usageObject); + + var jsonObject = JsonSerializer.SerializeToElement(usageObject); + Assert.True(jsonObject.TryGetProperty("InputTokens", out JsonElement promptTokensJson)); + Assert.True(promptTokensJson.TryGetInt32(out int promptTokens)); + Assert.NotEqual(0, promptTokens); + + Assert.True(jsonObject.TryGetProperty("OutputTokens", out JsonElement completionTokensJson)); + Assert.True(completionTokensJson.TryGetInt32(out int completionTokens)); + Assert.NotEqual(0, completionTokens); + + Assert.True(result.Metadata.ContainsKey("ContentFilterResults")); + + Assert.True(result.Metadata.TryGetValue("FinishReason", out object? finishReason)); + Assert.Equal("Stop", finishReason); + + Assert.True(result.Metadata.TryGetValue("LogProbabilityInfo", out object? logProbabilityInfo)); + Assert.Empty((logProbabilityInfo as IReadOnlyList)!); + } + + [Fact] + public async Task TextGenerationShouldUseChatSystemPromptAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var textGeneration = kernel.Services.GetRequiredService(); + + var settings = new AzureOpenAIPromptExecutionSettings { ChatSystemPrompt = "Reply \"I don't know\" to every question." }; + + // Act + var result = await textGeneration.GetTextContentAsync("What is the capital of France?", settings, kernel); + + // Assert + Assert.Contains("I don't know", result.Text); + } + + [Fact] + public async Task TextGenerationShouldReturnMetadataAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var textGeneration = kernel.Services.GetRequiredService(); + + // Act + var result = await textGeneration.GetTextContentAsync("Reply \"I don't know\" to every question. 
What is the capital of France?", null, kernel); + + // Assert + Assert.Contains("I don't know", result.Text); + Assert.NotNull(result.Metadata); + + Assert.True(result.Metadata.TryGetValue("Id", out object? id)); + Assert.NotNull(id); + + Assert.True(result.Metadata.TryGetValue("CreatedAt", out object? createdAt)); + Assert.NotNull(createdAt); + + Assert.True(result.Metadata.ContainsKey("PromptFilterResults")); + + Assert.True(result.Metadata.ContainsKey("SystemFingerprint")); + + Assert.True(result.Metadata.TryGetValue("Usage", out object? usageObject)); + Assert.NotNull(usageObject); + + var jsonObject = JsonSerializer.SerializeToElement(usageObject); + Assert.True(jsonObject.TryGetProperty("InputTokens", out JsonElement promptTokensJson)); + Assert.True(promptTokensJson.TryGetInt32(out int promptTokens)); + Assert.NotEqual(0, promptTokens); + + Assert.True(jsonObject.TryGetProperty("OutputTokens", out JsonElement completionTokensJson)); + Assert.True(completionTokensJson.TryGetInt32(out int completionTokens)); + Assert.NotEqual(0, completionTokens); + + Assert.True(result.Metadata.ContainsKey("ContentFilterResults")); + + Assert.True(result.Metadata.TryGetValue("FinishReason", out object? finishReason)); + Assert.Equal("Stop", finishReason); + + Assert.True(result.Metadata.TryGetValue("LogProbabilityInfo", out object? 
logProbabilityInfo)); + Assert.Empty((logProbabilityInfo as IReadOnlyList)!); + } + + #region internals + + private Kernel CreateAndInitializeKernel() + { + var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get(); + Assert.NotNull(azureOpenAIConfiguration); + Assert.NotNull(azureOpenAIConfiguration.ChatDeploymentName); + Assert.NotNull(azureOpenAIConfiguration.ApiKey); + Assert.NotNull(azureOpenAIConfiguration.Endpoint); + + var kernelBuilder = Kernel.CreateBuilder(); + + kernelBuilder.AddAzureOpenAIChatCompletion( + deploymentName: azureOpenAIConfiguration.ChatDeploymentName, + modelId: azureOpenAIConfiguration.ChatModelId, + endpoint: azureOpenAIConfiguration.Endpoint, + apiKey: azureOpenAIConfiguration.ApiKey); + + return kernelBuilder.Build(); + } + + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + #endregion +} diff --git a/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_StreamingTests.cs b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_StreamingTests.cs new file mode 100644 index 000000000000..3e841d04f775 --- /dev/null +++ b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_StreamingTests.cs @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using Microsoft.SemanticKernel.TextGeneration; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; + +namespace SemanticKernel.IntegrationTestsV2.Connectors.AzureOpenAI; + +#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. + +public sealed class AzureOpenAIChatCompletionStreamingTests() +{ + [Fact] + public async Task ChatCompletionShouldUseChatSystemPromptAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var chatCompletion = kernel.Services.GetRequiredService(); + + var settings = new AzureOpenAIPromptExecutionSettings { ChatSystemPrompt = "Reply \"I don't know\" to every question." 
}; + + var stringBuilder = new StringBuilder(); + + // Act + await foreach (var update in chatCompletion.GetStreamingChatMessageContentsAsync("What is the capital of France?", settings, kernel)) + { + stringBuilder.Append(update.Content); + } + + // Assert + Assert.Contains("I don't know", stringBuilder.ToString()); + } + + [Fact] + public async Task ChatCompletionShouldUseChatHistoryAndReturnMetadataAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var chatCompletion = kernel.Services.GetRequiredService(); + + var chatHistory = new ChatHistory("Reply \"I don't know\" to every question."); + chatHistory.AddUserMessage("What is the capital of France?"); + + var stringBuilder = new StringBuilder(); + var metadata = new Dictionary(); + + // Act + await foreach (var update in chatCompletion.GetStreamingChatMessageContentsAsync(chatHistory, null, kernel)) + { + stringBuilder.Append(update.Content); + + foreach (var key in update.Metadata!.Keys) + { + metadata[key] = update.Metadata[key]; + } + } + + // Assert + Assert.Contains("I don't know", stringBuilder.ToString()); + Assert.NotNull(metadata); + + Assert.True(metadata.TryGetValue("Id", out object? id)); + Assert.NotNull(id); + + Assert.True(metadata.TryGetValue("CreatedAt", out object? createdAt)); + Assert.NotNull(createdAt); + + Assert.True(metadata.ContainsKey("SystemFingerprint")); + + Assert.True(metadata.TryGetValue("FinishReason", out object? finishReason)); + Assert.Equal("Stop", finishReason); + } + + [Fact] + public async Task TextGenerationShouldUseChatSystemPromptAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var textGeneration = kernel.Services.GetRequiredService(); + + var settings = new AzureOpenAIPromptExecutionSettings { ChatSystemPrompt = "Reply \"I don't know\" to every question." 
}; + + var stringBuilder = new StringBuilder(); + + // Act + await foreach (var update in textGeneration.GetStreamingTextContentsAsync("What is the capital of France?", settings, kernel)) + { + stringBuilder.Append(update); + } + + // Assert + Assert.Contains("I don't know", stringBuilder.ToString()); + } + + [Fact] + public async Task TextGenerationShouldReturnMetadataAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var textGeneration = kernel.Services.GetRequiredService(); + + // Act + var stringBuilder = new StringBuilder(); + var metadata = new Dictionary(); + + // Act + await foreach (var update in textGeneration.GetStreamingTextContentsAsync("Reply \"I don't know\" to every question. What is the capital of France?", null, kernel)) + { + stringBuilder.Append(update); + + foreach (var key in update.Metadata!.Keys) + { + metadata[key] = update.Metadata[key]; + } + } + + // Assert + Assert.Contains("I don't know", stringBuilder.ToString()); + Assert.NotNull(metadata); + + Assert.True(metadata.TryGetValue("Id", out object? id)); + Assert.NotNull(id); + + Assert.True(metadata.TryGetValue("CreatedAt", out object? createdAt)); + Assert.NotNull(createdAt); + + Assert.True(metadata.ContainsKey("SystemFingerprint")); + + Assert.True(metadata.TryGetValue("FinishReason", out object? 
finishReason)); + Assert.Equal("Stop", finishReason); + } + + #region internals + + private Kernel CreateAndInitializeKernel() + { + var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get(); + Assert.NotNull(azureOpenAIConfiguration); + Assert.NotNull(azureOpenAIConfiguration.ChatDeploymentName); + Assert.NotNull(azureOpenAIConfiguration.ApiKey); + Assert.NotNull(azureOpenAIConfiguration.Endpoint); + + var kernelBuilder = Kernel.CreateBuilder(); + + kernelBuilder.AddAzureOpenAIChatCompletion( + deploymentName: azureOpenAIConfiguration.ChatDeploymentName, + modelId: azureOpenAIConfiguration.ChatModelId, + endpoint: azureOpenAIConfiguration.Endpoint, + apiKey: azureOpenAIConfiguration.ApiKey); + + return kernelBuilder.Build(); + } + + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + #endregion +} diff --git a/dotnet/src/IntegrationTestsV2/IntegrationTestsV2.csproj b/dotnet/src/IntegrationTestsV2/IntegrationTestsV2.csproj index f3c704a27307..13bcc5ba0f44 100644 --- a/dotnet/src/IntegrationTestsV2/IntegrationTestsV2.csproj +++ b/dotnet/src/IntegrationTestsV2/IntegrationTestsV2.csproj @@ -42,6 +42,7 @@ + diff --git a/dotnet/src/IntegrationTestsV2/TestHelpers.cs b/dotnet/src/IntegrationTestsV2/TestHelpers.cs new file mode 100644 index 000000000000..350370d6c056 --- /dev/null +++ b/dotnet/src/IntegrationTestsV2/TestHelpers.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.IO; +using System.Linq; +using System.Reflection; +using Microsoft.SemanticKernel; + +namespace SemanticKernel.IntegrationTestsV2; + +internal static class TestHelpers +{ + private const string PluginsFolder = "../../../../../../prompt_template_samples"; + + internal static void ImportAllSamplePlugins(Kernel kernel) + { + ImportSamplePromptFunctions(kernel, PluginsFolder, + "ChatPlugin", + "SummarizePlugin", + "WriterPlugin", + "CalendarPlugin", + "ChildrensBookPlugin", + "ClassificationPlugin", + "CodingPlugin", + "FunPlugin", + "IntentDetectionPlugin", + "MiscPlugin", + "QAPlugin"); + } + + internal static void ImportAllSampleSkills(Kernel kernel) + { + ImportSamplePromptFunctions(kernel, "./skills", "FunSkill"); + } + + internal static IReadOnlyKernelPluginCollection ImportSamplePlugins(Kernel kernel, params string[] pluginNames) + { + return ImportSamplePromptFunctions(kernel, PluginsFolder, pluginNames); + } + + internal static IReadOnlyKernelPluginCollection ImportSamplePromptFunctions(Kernel kernel, string path, params string[] pluginNames) + { + string? currentAssemblyDirectory = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); + if (string.IsNullOrWhiteSpace(currentAssemblyDirectory)) + { + throw new InvalidOperationException("Unable to determine current assembly directory."); + } + + string parentDirectory = Path.GetFullPath(Path.Combine(currentAssemblyDirectory, path)); + + return new KernelPluginCollection( + from pluginName in pluginNames + select kernel.ImportPluginFromPromptDirectory(Path.Combine(parentDirectory, pluginName))); + } +} diff --git a/dotnet/src/IntegrationTestsV2/TestSettings/AzureOpenAIConfiguration.cs b/dotnet/src/IntegrationTestsV2/TestSettings/AzureOpenAIConfiguration.cs new file mode 100644 index 000000000000..6a15a4c89dd7 --- /dev/null +++ b/dotnet/src/IntegrationTestsV2/TestSettings/AzureOpenAIConfiguration.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Diagnostics.CodeAnalysis; + +namespace SemanticKernel.IntegrationTests.TestSettings; + +[SuppressMessage("Performance", "CA1812:Internal class that is apparently never instantiated", + Justification = "Configuration classes are instantiated through IConfiguration.")] +internal sealed class AzureOpenAIConfiguration(string serviceId, string deploymentName, string endpoint, string apiKey, string? chatDeploymentName = null, string? modelId = null, string? chatModelId = null, string? embeddingModelId = null) +{ + public string ServiceId { get; set; } = serviceId; + public string DeploymentName { get; set; } = deploymentName; + public string ApiKey { get; set; } = apiKey; + public string? ChatDeploymentName { get; set; } = chatDeploymentName ?? deploymentName; + public string ModelId { get; set; } = modelId ?? deploymentName; + public string ChatModelId { get; set; } = chatModelId ?? deploymentName; + public string EmbeddingModelId { get; set; } = embeddingModelId ?? "text-embedding-ada-002"; + public string Endpoint { get; set; } = endpoint; +} From f9162e155d59787fafb8c22d847b58dc9433c049 Mon Sep 17 00:00:00 2001 From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Date: Fri, 28 Jun 2024 19:30:37 +0100 Subject: [PATCH 2/6] Update dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceCollectionExtensionsTests.cs Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> --- .../Extensions/AzureOpenAIServiceCollectionExtensionsTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceCollectionExtensionsTests.cs index 15aab3daf8b0..041cee3f3cc9 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceCollectionExtensionsTests.cs +++ 
b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceCollectionExtensionsTests.cs @@ -12,7 +12,7 @@ namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Extensions; /// -/// Unit tests for class. +/// Unit tests for class. /// public sealed class AzureOpenAIServiceCollectionExtensionsTests { From a3c38e79717eb5ef3b82510efd5c8a5b4d0408a4 Mon Sep 17 00:00:00 2001 From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Date: Fri, 28 Jun 2024 19:30:54 +0100 Subject: [PATCH 3/6] Update dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceKernelBuilderExtensionsTests.cs Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> --- .../AzureOpenAIServiceKernelBuilderExtensionsTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceKernelBuilderExtensionsTests.cs index 2ecd92151871..6025eb1d447f 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceKernelBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceKernelBuilderExtensionsTests.cs @@ -12,7 +12,7 @@ namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Extensions; /// -/// Unit tests for class. +/// Unit tests for class. 
/// public sealed class AzureOpenAIServiceKernelBuilderExtensionsTests { From 7a2b0d7207726e7c0c78d9fd02ec3852b385b91a Mon Sep 17 00:00:00 2001 From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Date: Fri, 28 Jun 2024 19:31:03 +0100 Subject: [PATCH 4/6] Update dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NonStreamingTests.cs Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> --- .../AzureOpenAI/AzureOpenAIChatCompletion_NonStreamingTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NonStreamingTests.cs b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NonStreamingTests.cs index cf87a2192a62..72d5ff34dec4 100644 --- a/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NonStreamingTests.cs +++ b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_NonStreamingTests.cs @@ -17,7 +17,7 @@ namespace SemanticKernel.IntegrationTestsV2.Connectors.AzureOpenAI; #pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. 
-public sealed class AzureOpenAIChatCompletionNonStreamingTests() +public sealed class AzureOpenAIChatCompletionNonStreamingTests { [Fact] public async Task ChatCompletionShouldUseChatSystemPromptAsync() From 9829947323bf39a98d1fec81e1a5b8841b50f02f Mon Sep 17 00:00:00 2001 From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Date: Fri, 28 Jun 2024 19:31:10 +0100 Subject: [PATCH 5/6] Update dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_StreamingTests.cs Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> --- .../AzureOpenAI/AzureOpenAIChatCompletion_StreamingTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_StreamingTests.cs b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_StreamingTests.cs index 3e841d04f775..57fb1c73fb72 100644 --- a/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_StreamingTests.cs +++ b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletion_StreamingTests.cs @@ -16,7 +16,7 @@ namespace SemanticKernel.IntegrationTestsV2.Connectors.AzureOpenAI; #pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. 
-public sealed class AzureOpenAIChatCompletionStreamingTests() +public sealed class AzureOpenAIChatCompletionStreamingTests { [Fact] public async Task ChatCompletionShouldUseChatSystemPromptAsync() From c3b08d6d2955678683776ebd82c5e277b2dd89ba Mon Sep 17 00:00:00 2001 From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Date: Fri, 28 Jun 2024 19:31:25 +0100 Subject: [PATCH 6/6] Update dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletionTests.cs Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> --- .../Connectors/AzureOpenAI/AzureOpenAIChatCompletionTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletionTests.cs b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletionTests.cs index ce4641f380cf..04f1be7e45c7 100644 --- a/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletionTests.cs +++ b/dotnet/src/IntegrationTestsV2/Connectors/AzureOpenAI/AzureOpenAIChatCompletionTests.cs @@ -22,7 +22,7 @@ namespace SemanticKernel.IntegrationTestsV2.Connectors.AzureOpenAI; #pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. -public sealed class AzureOpenAIChatCompletionTests() +public sealed class AzureOpenAIChatCompletionTests { [Fact] //[Fact(Skip = "Skipping while we investigate issue with GitHub actions.")]