From 7544a23ca7f30c58c5e240e8a564a301ba1dd249 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Thu, 11 Jul 2024 19:42:53 +0100 Subject: [PATCH 1/3] Added metadata, integration tests and some more adjustments --- dotnet/Directory.Packages.props | 2 +- .../Connectors.Ollama.UnitTests.csproj | 2 +- .../Connectors.Ollama/Core/ServiceBase.cs | 6 +- .../OllamaKernelBuilderExtensions.cs | 18 +- .../OllamaServiceCollectionExtensions.cs | 18 +- .../Connectors.Ollama/OllamaMetadata.cs | 54 ++++- .../Services/OllamaChatCompletionService.cs | 15 +- .../OllamaTextEmbeddingGenerationService.cs | 6 +- .../Services/OllamaTextGenerationService.cs | 14 +- .../Ollama/OllamaCompletionTests.cs | 221 ++++++++++++++++++ .../Ollama/OllamaTextEmbeddingTests.cs | 111 +++++++++ .../IntegrationTests/IntegrationTests.csproj | 1 + .../TestSettings/OllamaConfiguration.cs | 14 ++ 13 files changed, 439 insertions(+), 43 deletions(-) create mode 100644 dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs create mode 100644 dotnet/src/IntegrationTests/TestSettings/OllamaConfiguration.cs diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index bc2f3c81d3bc..e0bfad396dcb 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -34,7 +34,7 @@ - + diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj index 427f079b3c65..489e1b416d89 100644 --- a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj @@ -27,7 +27,7 @@ all - + diff --git a/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs b/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs index 192cbc238f2e..57b19adb0442 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs @@ -22,7 +22,7 @@ public abstract class ServiceBase internal readonly OllamaApiClient _client; internal ServiceBase(string model, - Uri baseUri, + Uri endpoint, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) { @@ -31,7 +31,7 @@ internal ServiceBase(string model, if (httpClient is not null) { - httpClient.BaseAddress ??= baseUri; + httpClient.BaseAddress ??= endpoint; // Try to add User-Agent header. if (!httpClient.DefaultRequestHeaders.TryGetValues("User-Agent", out _)) @@ -52,7 +52,7 @@ internal ServiceBase(string model, #pragma warning disable CA2000 // Dispose objects before losing scope // Client needs to be created to be able to inject Semantic Kernel headers var internalClient = HttpClientProvider.GetHttpClient(); - internalClient.BaseAddress = baseUri; + internalClient.BaseAddress = endpoint; internalClient.DefaultRequestHeaders.Add("User-Agent", HttpHeaderConstant.Values.UserAgent); internalClient.DefaultRequestHeaders.Add(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(typeof(Kernel))); diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs index c491d0e4397d..1e93e3d49baa 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs @@ -23,14 +23,14 @@ public static class OllamaKernelBuilderExtensions /// /// The kernel builder. /// The model for text generation. - /// The base uri to Ollama hosted service. + /// The endpoint to Ollama hosted service. /// The optional service ID. /// The optional custom HttpClient. /// The updated kernel builder. public static IKernelBuilder AddOllamaTextGeneration( this IKernelBuilder builder, string modelId, - Uri baseUri, + Uri endpoint, string? serviceId = null, HttpClient? httpClient = null) { @@ -39,7 +39,7 @@ public static IKernelBuilder AddOllamaTextGeneration( builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaTextGenerationService( model: modelId, - baseUri: baseUri, + endpoint: endpoint, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), loggerFactory: serviceProvider.GetService())); return builder; @@ -74,14 +74,14 @@ public static IKernelBuilder AddOllamaTextGeneration( /// /// The kernel builder. /// The model for text generation. - /// The base uri to Ollama hosted service. + /// The endpoint to Ollama hosted service. /// The optional service ID. /// The optional custom HttpClient. /// The updated kernel builder. public static IKernelBuilder AddOllamaChatCompletion( this IKernelBuilder builder, string modelId, - Uri baseUri, + Uri endpoint, string? serviceId = null, HttpClient? httpClient = null) { @@ -91,7 +91,7 @@ public static IKernelBuilder AddOllamaChatCompletion( builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaChatCompletionService( model: modelId, - baseUri: baseUri, + endpoint: endpoint, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), loggerFactory: serviceProvider.GetService())); @@ -128,14 +128,14 @@ public static IKernelBuilder AddOllamaChatCompletion( /// /// The kernel builder. /// The model for text generation. - /// The base uri to Ollama hosted service. + /// The endpoint to Ollama hosted service. /// The optional service ID. /// The optional custom HttpClient. /// The updated kernel builder. public static IKernelBuilder AddOllamaTextEmbeddingGeneration( this IKernelBuilder builder, string modelId, - Uri baseUri, + Uri endpoint, string? serviceId = null, HttpClient? httpClient = null) { @@ -144,7 +144,7 @@ public static IKernelBuilder AddOllamaTextEmbeddingGeneration( builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaTextEmbeddingGenerationService( model: modelId, - baseUri: baseUri, + endpoint: endpoint, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), loggerFactory: serviceProvider.GetService())); diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs index 7d9e1e14f33e..d8c4658351e1 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs @@ -22,13 +22,13 @@ public static class OllamaServiceCollectionExtensions /// /// The target service collection. /// The model for text generation. - /// The base uri to Ollama hosted service. + /// The endpoint to Ollama hosted service. /// The optional service ID. /// The updated kernel builder. public static IServiceCollection AddOllamaTextGeneration( this IServiceCollection services, string modelId, - Uri baseUri, + Uri endpoint, string? serviceId = null) { Verify.NotNull(services); @@ -36,7 +36,7 @@ public static IServiceCollection AddOllamaTextGeneration( return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaTextGenerationService( model: modelId, - baseUri: baseUri, + endpoint: endpoint, httpClient: HttpClientProvider.GetHttpClient(serviceProvider), loggerFactory: serviceProvider.GetService())); } @@ -69,13 +69,13 @@ public static IServiceCollection AddOllamaTextGeneration( /// /// The target service collection. /// The model for text generation. - /// The base uri to Ollama hosted service. + /// The endpoint to Ollama hosted service. /// Optional service ID. /// The updated service collection. public static IServiceCollection AddOllamaChatCompletion( this IServiceCollection services, string modelId, - Uri baseUri, + Uri endpoint, string? serviceId = null) { Verify.NotNull(services); @@ -83,7 +83,7 @@ public static IServiceCollection AddOllamaChatCompletion( services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaChatCompletionService( model: modelId, - baseUri: baseUri, + endpoint: endpoint, httpClient: HttpClientProvider.GetHttpClient(serviceProvider), loggerFactory: serviceProvider.GetService())); @@ -118,13 +118,13 @@ public static IServiceCollection AddOllamaChatCompletion( /// /// The target service collection. /// The model for text generation. - /// The base uri to Ollama hosted service. + /// The endpoint to Ollama hosted service. /// Optional service ID. /// The updated kernel builder. public static IServiceCollection AddOllamaTextEmbeddingGeneration( this IServiceCollection services, string modelId, - Uri baseUri, + Uri endpoint, string? serviceId = null) { Verify.NotNull(services); @@ -132,7 +132,7 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaTextEmbeddingGenerationService( model: modelId, - baseUri: baseUri, + endpoint: endpoint, httpClient: HttpClientProvider.GetHttpClient(serviceProvider), loggerFactory: serviceProvider.GetService())); } diff --git a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs index dbe16cbeafab..962826b525f0 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs @@ -4,6 +4,7 @@ using System.Collections.ObjectModel; using System.Runtime.CompilerServices; using OllamaSharp.Models; +using OllamaSharp.Models.Chat; namespace Microsoft.SemanticKernel.Connectors.Ollama; @@ -12,15 +13,45 @@ namespace Microsoft.SemanticKernel.Connectors.Ollama; /// public sealed class OllamaMetadata : ReadOnlyDictionary { - internal OllamaMetadata(GenerateCompletionDoneResponseStream ollamaResponse) : base(new Dictionary()) + internal OllamaMetadata(GenerateCompletionResponseStream? ollamaResponse) : base(new Dictionary()) { - this.TotalDuration = ollamaResponse.TotalDuration; - this.EvalCount = ollamaResponse.EvalCount; - this.EvalDuration = ollamaResponse.EvalDuration; + if (ollamaResponse is null) + { + return; + } + this.CreatedAt = ollamaResponse.CreatedAt; - this.LoadDuration = ollamaResponse.LoadDuration; - this.PromptEvalCount = ollamaResponse.PromptEvalCount; - this.PromptEvalDuration = ollamaResponse.PromptEvalDuration; + this.Done = ollamaResponse.Done; + + if (ollamaResponse is GenerateCompletionDoneResponseStream doneResponse) + { + this.TotalDuration = doneResponse.TotalDuration; + this.EvalCount = doneResponse.EvalCount; + this.EvalDuration = doneResponse.EvalDuration; + this.LoadDuration = doneResponse.LoadDuration; + this.PromptEvalCount = doneResponse.PromptEvalCount; + this.PromptEvalDuration = doneResponse.PromptEvalDuration; + } + } + + internal OllamaMetadata(ChatResponseStream? message) : base(new Dictionary()) + { + if (message is null) + { + return; + } + this.CreatedAt = message?.CreatedAt; + this.Done = message?.Done; + + if (message is ChatDoneResponseStream doneMessage) + { + this.TotalDuration = doneMessage.TotalDuration; + this.EvalCount = doneMessage.EvalCount; + this.EvalDuration = doneMessage.EvalDuration; + this.LoadDuration = doneMessage.LoadDuration; + this.PromptEvalCount = doneMessage.PromptEvalCount; + this.PromptEvalDuration = doneMessage.PromptEvalDuration; + } } /// @@ -59,6 +90,15 @@ public string? CreatedAt internal init => this.SetValueInDictionary(value); } + /// + /// The response is done + /// + public bool? Done + { + get => this.GetValueFromDictionary() as bool?; + internal init => this.SetValueInDictionary(value); + } + /// /// Time in nano seconds spent generating the response /// diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs index c6546622bc59..830b0ba82b9d 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs @@ -24,15 +24,15 @@ public sealed class OllamaChatCompletionService : ServiceBase, IChatCompletionSe /// Initializes a new instance of the class. /// /// The hosted model. - /// The base uri including the port where Ollama server is hosted + /// The endpoint including the port where Ollama server is hosted /// Optional HTTP client to be used for communication with the Ollama API. /// Optional logger factory to be used for logging. public OllamaChatCompletionService( string model, - Uri baseUri, + Uri endpoint, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) - : base(model, baseUri, httpClient, loggerFactory) + : base(model, endpoint, httpClient, loggerFactory) { } @@ -73,7 +73,7 @@ public async Task> GetChatMessageContentsAsync role: GetAuthorRole(message.Role) ?? AuthorRole.Assistant, content: message.Content, modelId: this._client.SelectedModel, - innerContent: message)]; + innerContent: message)]; // Currently the Ollama Message does not provide any metadata } /// @@ -88,7 +88,12 @@ public async IAsyncEnumerable GetStreamingChatMessa await foreach (var message in this._client.StreamChat(request, cancellationToken).ConfigureAwait(false)) { - yield return new StreamingChatMessageContent(GetAuthorRole(message?.Message.Role), message?.Message.Content, modelId: message?.Model, innerContent: message); + yield return new StreamingChatMessageContent( + GetAuthorRole(message?.Message.Role), + message?.Message.Content, + modelId: message?.Model, + innerContent: message, + metadata: new OllamaMetadata(message)); } } diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs index 121df1caf995..f5a7021d97fa 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs @@ -23,15 +23,15 @@ public sealed class OllamaTextEmbeddingGenerationService : ServiceBase, ITextEmb /// Initializes a new instance of the class. /// /// The hosted model. - /// The base uri including the port where Ollama server is hosted + /// The endpoint including the port where Ollama server is hosted /// Optional HTTP client to be used for communication with the Ollama API. /// Optional logger factory to be used for logging. public OllamaTextEmbeddingGenerationService( string model, - Uri baseUri, + Uri endpoint, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) - : base(model, baseUri, httpClient, loggerFactory) + : base(model, endpoint, httpClient, loggerFactory) { } diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs index 0294004811ff..405c345776ce 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs @@ -22,15 +22,15 @@ public sealed class OllamaTextGenerationService : ServiceBase, ITextGenerationSe /// Initializes a new instance of the class. /// /// The Ollama model for the text generation service. - /// The base uri including the port where Ollama server is hosted + /// The endpoint including the port where Ollama server is hosted /// Optional HTTP client to be used for communication with the Ollama API. /// Optional logger factory to be used for logging. public OllamaTextGenerationService( string model, - Uri baseUri, + Uri endpoint, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) - : base(model, baseUri, httpClient, loggerFactory) + : base(model, endpoint, httpClient, loggerFactory) { } @@ -60,7 +60,11 @@ public async Task> GetTextContentsAsync( { var content = await this._client.GetCompletion(prompt, null, cancellationToken).ConfigureAwait(false); - return [new(content.Response, modelId: this._client.SelectedModel, innerContent: content)]; + return [new(content.Response, modelId: this._client.SelectedModel, innerContent: content, metadata: + new Dictionary() + { + ["Context"] = content.Context + })]; } /// @@ -72,7 +76,7 @@ public async IAsyncEnumerable GetStreamingTextContentsAsyn { await foreach (var content in this._client.StreamCompletion(prompt, null, cancellationToken).ConfigureAwait(false)) { - yield return new StreamingTextContent(content?.Response, modelId: content?.Model, innerContent: content); + yield return new StreamingTextContent(content?.Response, modelId: content?.Model, innerContent: content, metadata: new OllamaMetadata(content)); } } } diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs new file mode 100644 index 000000000000..3c7fc21407a4 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.Ollama; + +#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. + +public sealed class OllamaCompletionTests(ITestOutputHelper output) : IDisposable +{ + private const string InputParameterName = "input"; + private readonly IKernelBuilder _kernelBuilder = Kernel.CreateBuilder(); + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + [Theory(Skip = "For manual verification only")] + [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] + public async Task ItStreamingTestAsync(string prompt, string expectedAnswerContains) + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + + this.ConfigureChatOllama(this._kernelBuilder); + + Kernel target = builder.Build(); + + IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin"); + + StringBuilder fullResult = new(); + // Act + await foreach (var content in target.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt })) + { + if (content is StreamingChatMessageContent messageContent) + { + Assert.NotNull(messageContent.Role); + } + + fullResult.Append(content); + } + + // Assert + Assert.Contains(expectedAnswerContains, fullResult.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItShouldReturnMetadataAsync() + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + + this.ConfigureChatOllama(this._kernelBuilder); + + var kernel = this._kernelBuilder.Build(); + + var plugin = TestHelpers.ImportSamplePlugins(kernel, "FunPlugin"); + + // Act + StreamingKernelContent? lastUpdate = null; + await foreach (var update in kernel.InvokeStreamingAsync(plugin["FunPlugin"]["Limerick"])) + { + lastUpdate = update; + } + + // Assert + Assert.NotNull(lastUpdate); + Assert.NotNull(lastUpdate.Metadata); + + // CreatedAt + Assert.True(lastUpdate.Metadata.TryGetValue("CreatedAt", out object? createdAt)); + } + + [Theory(Skip = "For manual verification only")] + [InlineData("\n")] + [InlineData("\r\n")] + public async Task ItCompletesWithDifferentLineEndingsAsync(string lineEnding) + { + // Arrange + var prompt = + "Given a json input and a request. Apply the request on the json input and return the result. " + + $"Put the result in between tags{lineEnding}" + + $$"""Input:{{lineEnding}}{"name": "John", "age": 30}{{lineEnding}}{{lineEnding}}Request:{{lineEnding}}name"""; + + const string ExpectedAnswerContains = "result"; + + this._kernelBuilder.Services.AddSingleton(this._logger); + this.ConfigureChatOllama(this._kernelBuilder); + + Kernel target = this._kernelBuilder.Build(); + + IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin"); + + // Act + FunctionResult actual = await target.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }); + + // Assert + Assert.Contains(ExpectedAnswerContains, actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItInvokePromptTestAsync() + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + this.ConfigureChatOllama(builder); + Kernel target = builder.Build(); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + FunctionResult actual = await target.InvokePromptAsync(prompt, new(new OllamaPromptExecutionSettings() { Temperature = 0.5f })); + + // Assert + Assert.Contains("Pike Place", actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Theory(Skip = "For manual verification only")] + [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] + public async Task ItInvokeTestAsync(string prompt, string expectedAnswerContains) + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + + this.ConfigureChatOllama(this._kernelBuilder); + + Kernel target = builder.Build(); + + IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin"); + + // Act + FunctionResult actual = await target.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }); + + // Assert + Assert.Contains(expectedAnswerContains, actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItShouldHaveSemanticKernelVersionHeaderAsync() + { + // Arrange + var config = this._configuration.GetSection("Ollama").Get(); + Assert.NotNull(config); + Assert.NotNull(config.ModelId); + Assert.NotNull(config.Endpoint); + + using var defaultHandler = new HttpClientHandler(); + using var httpHeaderHandler = new HttpHeaderHandler(defaultHandler); + using var httpClient = new HttpClient(httpHeaderHandler); + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + builder.AddOllamaChatCompletion( + endpoint: config.Endpoint, + modelId: config.ModelId, + httpClient: httpClient); + Kernel target = builder.Build(); + + // Act + var result = await target.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?"); + + // Assert + Assert.NotNull(httpHeaderHandler.RequestHeaders); + Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values)); + } + + #region internals + + private readonly XunitLogger _logger = new(output); + private readonly RedirectOutput _testOutputHelper = new(output); + + public void Dispose() + { + this._logger.Dispose(); + this._testOutputHelper.Dispose(); + } + + private void ConfigureChatOllama(IKernelBuilder kernelBuilder) + { + var config = this._configuration.GetSection("Ollama").Get(); + + Assert.NotNull(config); + Assert.NotNull(config.Endpoint); + Assert.NotNull(config.ModelId); + + kernelBuilder.AddOllamaChatCompletion( + modelId: config.ModelId, + endpoint: config.Endpoint); + } + + private sealed class HttpHeaderHandler(HttpMessageHandler innerHandler) : DelegatingHandler(innerHandler) + { + public System.Net.Http.Headers.HttpRequestHeaders? RequestHeaders { get; private set; } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + this.RequestHeaders = request.Headers; + return await base.SendAsync(request, cancellationToken); + } + } + + #endregion +} diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs new file mode 100644 index 000000000000..d26f48a79742 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using Microsoft.SemanticKernel.Embeddings; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Ollama; + +public sealed class OllamaTextEmbeddingTests +{ + private const int AdaVectorLength = 1536; + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + [Theory(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] + [InlineData("test sentence")] + public async Task OpenAITestAsync(string testInputString) + { + // Arrange + OllamaConfiguration? config = this._configuration.GetSection("Ollama").Get(); + Assert.NotNull(config); + Assert.NotNull(config.ModelId); + Assert.NotNull(config.Endpoint); + + var embeddingGenerator = new OllamaTextEmbeddingGenerationService(config.ModelId, config.Endpoint); + + // Act + var singleResult = await embeddingGenerator.GenerateEmbeddingAsync(testInputString); + var batchResult = await embeddingGenerator.GenerateEmbeddingsAsync([testInputString, testInputString, testInputString]); + + // Assert + Assert.Equal(AdaVectorLength, singleResult.Length); + Assert.Equal(3, batchResult.Count); + } + + [Theory(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] + [InlineData(null, 3072)] + [InlineData(1024, 1024)] + public async Task OpenAIWithDimensionsAsync(int? dimensions, int expectedVectorLength) + { + // Arrange + const string TestInputString = "test sentence"; + + OpenAIConfiguration? openAIConfiguration = this._configuration.GetSection("OpenAIEmbeddings").Get(); + Assert.NotNull(openAIConfiguration); + + var embeddingGenerator = new OpenAITextEmbeddingGenerationService( + "text-embedding-3-large", + openAIConfiguration.ApiKey, + dimensions: dimensions); + + // Act + var result = await embeddingGenerator.GenerateEmbeddingAsync(TestInputString); + + // Assert + Assert.Equal(expectedVectorLength, result.Length); + } + + [Theory] + [InlineData("test sentence")] + public async Task AzureOpenAITestAsync(string testInputString) + { + // Arrange + AzureOpenAIConfiguration? azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAIEmbeddings").Get(); + Assert.NotNull(azureOpenAIConfiguration); + + var embeddingGenerator = new AzureOpenAITextEmbeddingGenerationService(azureOpenAIConfiguration.DeploymentName, + azureOpenAIConfiguration.Endpoint, + azureOpenAIConfiguration.ApiKey); + + // Act + var singleResult = await embeddingGenerator.GenerateEmbeddingAsync(testInputString); + var batchResult = await embeddingGenerator.GenerateEmbeddingsAsync([testInputString, testInputString, testInputString]); + + // Assert + Assert.Equal(AdaVectorLength, singleResult.Length); + Assert.Equal(3, batchResult.Count); + } + + [Theory] + [InlineData(null, 3072)] + [InlineData(1024, 1024)] + public async Task AzureOpenAIWithDimensionsAsync(int? dimensions, int expectedVectorLength) + { + // Arrange + const string TestInputString = "test sentence"; + + AzureOpenAIConfiguration? azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAIEmbeddings").Get(); + Assert.NotNull(azureOpenAIConfiguration); + + var embeddingGenerator = new AzureOpenAITextEmbeddingGenerationService( + "text-embedding-3-large", + azureOpenAIConfiguration.Endpoint, + azureOpenAIConfiguration.ApiKey, + dimensions: dimensions); + + // Act + var result = await embeddingGenerator.GenerateEmbeddingAsync(TestInputString); + + // Assert + Assert.Equal(expectedVectorLength, result.Length); + } +} diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index df5afa473ce7..9c14051ef665 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -63,6 +63,7 @@ + diff --git a/dotnet/src/IntegrationTests/TestSettings/OllamaConfiguration.cs b/dotnet/src/IntegrationTests/TestSettings/OllamaConfiguration.cs new file mode 100644 index 000000000000..cbf6e52351c4 --- /dev/null +++ b/dotnet/src/IntegrationTests/TestSettings/OllamaConfiguration.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace SemanticKernel.IntegrationTests.TestSettings; + +[SuppressMessage("Performance", "CA1812:Internal class that is apparently never instantiated", + Justification = "Configuration classes are instantiated through IConfiguration.")] +internal sealed class OllamaConfiguration +{ + public string? ModelId { get; set; } + public Uri? Endpoint { get; set; } +} From 50f500de99a0484ef07e8438636020900a12b8a6 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Thu, 11 Jul 2024 19:53:36 +0100 Subject: [PATCH 2/3] Fix typo --- docs/decisions/0046-kernel-content-graduation.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/decisions/0046-kernel-content-graduation.md b/docs/decisions/0046-kernel-content-graduation.md index 43518ddfa2d3..368c59bd7621 100644 --- a/docs/decisions/0046-kernel-content-graduation.md +++ b/docs/decisions/0046-kernel-content-graduation.md @@ -85,7 +85,7 @@ Pros: - With no deferred content we have simpler API and a single responsibility for contents. - Can be written and read in both `Data` or `DataUri` formats. - Can have a `Uri` reference property, which is common for specialized contexts. -- Fully serializeable. +- Fully serializable. - Data Uri parameters support (serialization included). - Data Uri and Base64 validation checks - Data Uri and Data can be dynamically generated @@ -197,7 +197,7 @@ Pros: - Can be used as a `BinaryContent` type - Can be written and read in both `Data` or `DataUri` formats. - Can have a `Uri` dedicated for referenced location. -- Fully serializeable. +- Fully serializable. - Data Uri parameters support (serialization included). - Data Uri and Base64 validation checks - Can be retrieved @@ -254,7 +254,7 @@ Pros: - Can be used as a `BinaryContent` type - Can be written and read in both `Data` or `DataUri` formats. - Can have a `Uri` dedicated for referenced location. -- Fully serializeable. +- Fully serializable. - Data Uri parameters support (serialization included). - Data Uri and Base64 validation checks - Can be retrieved From cd79a6a72dac961561527d6eb555f78f0c6fedf8 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Fri, 12 Jul 2024 09:29:06 +0100 Subject: [PATCH 3/3] Add missing integration tests --- .../OllamaKernelBuilderExtensions.cs | 12 +- .../OllamaServiceCollectionExtensions.cs | 12 +- .../OllamaPromptExecutionSettings.cs | 2 +- .../Services/OllamaChatCompletionService.cs | 12 +- .../OllamaTextEmbeddingGenerationService.cs | 12 +- .../Services/OllamaTextGenerationService.cs | 12 +- .../Ollama/OllamaCompletionTests.cs | 4 +- .../Ollama/OllamaTextEmbeddingTests.cs | 94 +++----- .../Ollama/OllamaTextGenerationTests.cs | 221 ++++++++++++++++++ 9 files changed, 279 insertions(+), 102 deletions(-) create mode 100644 dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextGenerationTests.cs diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs index 1e93e3d49baa..e442e8f9799e 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs @@ -38,7 +38,7 @@ public static IKernelBuilder AddOllamaTextGeneration( builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaTextGenerationService( - model: modelId, + modelId: modelId, endpoint: endpoint, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), loggerFactory: serviceProvider.GetService())); @@ -63,7 +63,7 @@ public static IKernelBuilder AddOllamaTextGeneration( builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaTextGenerationService( - model: modelId, + modelId: modelId, ollamaClient: ollamaClient, loggerFactory: serviceProvider.GetService())); return builder; @@ -90,7 +90,7 @@ public static IKernelBuilder AddOllamaChatCompletion( builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaChatCompletionService( - model: modelId, + modelId: modelId, endpoint: endpoint, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), loggerFactory: serviceProvider.GetService())); @@ -116,7 +116,7 @@ public static IKernelBuilder AddOllamaChatCompletion( builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaChatCompletionService( - model: modelId, + modelId: modelId, client: ollamaClient, loggerFactory: serviceProvider.GetService())); @@ -143,7 +143,7 @@ public static IKernelBuilder AddOllamaTextEmbeddingGeneration( builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaTextEmbeddingGenerationService( - model: modelId, + modelId: modelId, endpoint: endpoint, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), loggerFactory: serviceProvider.GetService())); @@ -169,7 +169,7 @@ public static IKernelBuilder AddOllamaTextEmbeddingGeneration( builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaTextEmbeddingGenerationService( - model: modelId, + modelId: modelId, ollamaClient: ollamaClient, loggerFactory: serviceProvider.GetService())); diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs index d8c4658351e1..0a5497c74a73 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs @@ -35,7 +35,7 @@ public static IServiceCollection AddOllamaTextGeneration( return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaTextGenerationService( - model: modelId, + modelId: modelId, endpoint: endpoint, httpClient: HttpClientProvider.GetHttpClient(serviceProvider), loggerFactory: serviceProvider.GetService())); @@ -59,7 +59,7 @@ public static IServiceCollection AddOllamaTextGeneration( return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaTextGenerationService( - model: modelId, + modelId: modelId, ollamaClient: ollamaClient, loggerFactory: serviceProvider.GetService())); } @@ -82,7 +82,7 @@ public static IServiceCollection AddOllamaChatCompletion( services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaChatCompletionService( - model: modelId, + modelId: modelId, endpoint: endpoint, httpClient: HttpClientProvider.GetHttpClient(serviceProvider), loggerFactory: serviceProvider.GetService())); @@ -108,7 +108,7 @@ public static IServiceCollection AddOllamaChatCompletion( return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaChatCompletionService( - model: modelId, + modelId: modelId, client: ollamaClient, loggerFactory: serviceProvider.GetService())); } @@ -131,7 +131,7 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaTextEmbeddingGenerationService( - model: modelId, + modelId: modelId, endpoint: endpoint, httpClient: HttpClientProvider.GetHttpClient(serviceProvider), loggerFactory: serviceProvider.GetService())); @@ -155,7 +155,7 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => new OllamaTextEmbeddingGenerationService( - model: modelId, + modelId: modelId, ollamaClient: ollamaClient, loggerFactory: serviceProvider.GetService())); } diff --git a/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs index 9fc47bb9bb1b..283c6790c549 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs @@ -8,7 +8,7 @@ namespace Microsoft.SemanticKernel.Connectors.Ollama; /// -/// Ollama Execution Settings. +/// Ollama Prompt Execution Settings. /// public sealed class OllamaPromptExecutionSettings : PromptExecutionSettings { diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs index 830b0ba82b9d..f611b9625e88 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs @@ -23,30 +23,30 @@ public sealed class OllamaChatCompletionService : ServiceBase, IChatCompletionSe /// /// Initializes a new instance of the class. /// - /// The hosted model. + /// The hosted model. /// The endpoint including the port where Ollama server is hosted /// Optional HTTP client to be used for communication with the Ollama API. /// Optional logger factory to be used for logging. public OllamaChatCompletionService( - string model, + string modelId, Uri endpoint, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) - : base(model, endpoint, httpClient, loggerFactory) + : base(modelId, endpoint, httpClient, loggerFactory) { } /// /// Initializes a new instance of the class. /// - /// The hosted model. + /// The hosted model. /// The Ollama API client. /// Optional logger factory to be used for logging. public OllamaChatCompletionService( - string model, + string modelId, OllamaApiClient client, ILoggerFactory? loggerFactory = null) - : base(model, client, loggerFactory) + : base(modelId, client, loggerFactory) { } diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs index f5a7021d97fa..13adcd165c80 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs @@ -22,30 +22,30 @@ public sealed class OllamaTextEmbeddingGenerationService : ServiceBase, ITextEmb /// /// Initializes a new instance of the class. /// - /// The hosted model. + /// The hosted model. /// The endpoint including the port where Ollama server is hosted /// Optional HTTP client to be used for communication with the Ollama API. /// Optional logger factory to be used for logging. public OllamaTextEmbeddingGenerationService( - string model, + string modelId, Uri endpoint, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) - : base(model, endpoint, httpClient, loggerFactory) + : base(modelId, endpoint, httpClient, loggerFactory) { } /// /// Initializes a new instance of the class. /// - /// The hosted model. + /// The hosted model. /// The Ollama API client. /// Optional logger factory to be used for logging. public OllamaTextEmbeddingGenerationService( - string model, + string modelId, OllamaApiClient ollamaClient, ILoggerFactory? loggerFactory = null) - : base(model, ollamaClient, loggerFactory) + : base(modelId, ollamaClient, loggerFactory) { } diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs index 405c345776ce..29acd5f342c5 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs @@ -21,30 +21,30 @@ public sealed class OllamaTextGenerationService : ServiceBase, ITextGenerationSe /// /// Initializes a new instance of the class. /// - /// The Ollama model for the text generation service. + /// The Ollama model for the text generation service. /// The endpoint including the port where Ollama server is hosted /// Optional HTTP client to be used for communication with the Ollama API. /// Optional logger factory to be used for logging. public OllamaTextGenerationService( - string model, + string modelId, Uri endpoint, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) - : base(model, endpoint, httpClient, loggerFactory) + : base(modelId, endpoint, httpClient, loggerFactory) { } /// /// Initializes a new instance of the class. /// - /// The hosted model. + /// The hosted model. /// The Ollama API client. /// Optional logger factory to be used for logging. public OllamaTextGenerationService( - string model, + string modelId, OllamaApiClient ollamaClient, ILoggerFactory? loggerFactory = null) - : base(model, ollamaClient, loggerFactory) + : base(modelId, ollamaClient, loggerFactory) { } diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs index 3c7fc21407a4..4fabf80936ff 100644 --- a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Linq; using System.Net.Http; using System.Text; using System.Threading; @@ -11,7 +10,6 @@ using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.Ollama; -using Microsoft.SemanticKernel.Connectors.OpenAI; using SemanticKernel.IntegrationTests.TestSettings; using Xunit; using Xunit.Abstractions; @@ -33,7 +31,7 @@ public sealed class OllamaCompletionTests(ITestOutputHelper output) : IDisposabl [Theory(Skip = "For manual verification only")] [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] - public async Task ItStreamingTestAsync(string prompt, string expectedAnswerContains) + public async Task ItInvokeStreamingWorksAsync(string prompt, string expectedAnswerContains) { // Arrange this._kernelBuilder.Services.AddSingleton(this._logger); diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs index d26f48a79742..f530098b473b 100644 --- a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs @@ -3,7 +3,6 @@ using System.Threading.Tasks; using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel.Connectors.Ollama; -using Microsoft.SemanticKernel.Connectors.OpenAI; using Microsoft.SemanticKernel.Embeddings; using SemanticKernel.IntegrationTests.TestSettings; using Xunit; @@ -12,50 +11,29 @@ namespace SemanticKernel.IntegrationTests.Connectors.Ollama; public sealed class OllamaTextEmbeddingTests { - private const int AdaVectorLength = 1536; private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() - .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) .AddEnvironmentVariables() .AddUserSecrets() .Build(); - [Theory(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] - [InlineData("test sentence")] - public async Task OpenAITestAsync(string testInputString) + [Theory(Skip = "For manual verification only")] + [InlineData("mxbai-embed-large", 1024)] + [InlineData("nomic-embed-text", 768)] + [InlineData("all-minilm", 384)] + public async Task GenerateEmbeddingHasExpectedLengthForModelAsync(string modelId, int expectedVectorLength) { // Arrange + const string TestInputString = "test sentence"; + OllamaConfiguration? config = this._configuration.GetSection("Ollama").Get(); Assert.NotNull(config); - Assert.NotNull(config.ModelId); Assert.NotNull(config.Endpoint); - var embeddingGenerator = new OllamaTextEmbeddingGenerationService(config.ModelId, config.Endpoint); - - // Act - var singleResult = await embeddingGenerator.GenerateEmbeddingAsync(testInputString); - var batchResult = await embeddingGenerator.GenerateEmbeddingsAsync([testInputString, testInputString, testInputString]); - - // Assert - Assert.Equal(AdaVectorLength, singleResult.Length); - Assert.Equal(3, batchResult.Count); - } - - [Theory(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] - [InlineData(null, 3072)] - [InlineData(1024, 1024)] - public async Task OpenAIWithDimensionsAsync(int? dimensions, int expectedVectorLength) - { - // Arrange - const string TestInputString = "test sentence"; - - OpenAIConfiguration? openAIConfiguration = this._configuration.GetSection("OpenAIEmbeddings").Get(); - Assert.NotNull(openAIConfiguration); - - var embeddingGenerator = new OpenAITextEmbeddingGenerationService( - "text-embedding-3-large", - openAIConfiguration.ApiKey, - dimensions: dimensions); + var embeddingGenerator = new OllamaTextEmbeddingGenerationService( + modelId, + config.Endpoint); // Act var result = await embeddingGenerator.GenerateEmbeddingAsync(TestInputString); @@ -64,48 +42,28 @@ public async Task OpenAIWithDimensionsAsync(int? dimensions, int expectedVectorL Assert.Equal(expectedVectorLength, result.Length); } - [Theory] - [InlineData("test sentence")] - public async Task AzureOpenAITestAsync(string testInputString) - { - // Arrange - AzureOpenAIConfiguration? azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAIEmbeddings").Get(); - Assert.NotNull(azureOpenAIConfiguration); - - var embeddingGenerator = new AzureOpenAITextEmbeddingGenerationService(azureOpenAIConfiguration.DeploymentName, - azureOpenAIConfiguration.Endpoint, - azureOpenAIConfiguration.ApiKey); - - // Act - var singleResult = await embeddingGenerator.GenerateEmbeddingAsync(testInputString); - var batchResult = await embeddingGenerator.GenerateEmbeddingsAsync([testInputString, testInputString, testInputString]); - - // Assert - Assert.Equal(AdaVectorLength, singleResult.Length); - Assert.Equal(3, batchResult.Count); - } - - [Theory] - [InlineData(null, 3072)] - [InlineData(1024, 1024)] - public async Task AzureOpenAIWithDimensionsAsync(int? dimensions, int expectedVectorLength) + [Theory(Skip = "For manual verification only")] + [InlineData("mxbai-embed-large", 1024)] + [InlineData("nomic-embed-text", 768)] + [InlineData("all-minilm", 384)] + public async Task GenerateEmbeddingsHasExpectedResultsLengthForModelAsync(string modelId, int expectedVectorLength) { // Arrange - const string TestInputString = "test sentence"; + string[] testInputStrings = ["test sentence 1", "test sentence 2", "test sentence 3"]; - AzureOpenAIConfiguration? azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAIEmbeddings").Get(); - Assert.NotNull(azureOpenAIConfiguration); + OllamaConfiguration? config = this._configuration.GetSection("Ollama").Get(); + Assert.NotNull(config); + Assert.NotNull(config.Endpoint); - var embeddingGenerator = new AzureOpenAITextEmbeddingGenerationService( - "text-embedding-3-large", - azureOpenAIConfiguration.Endpoint, - azureOpenAIConfiguration.ApiKey, - dimensions: dimensions); + var embeddingGenerator = new OllamaTextEmbeddingGenerationService( + modelId, + config.Endpoint); // Act - var result = await embeddingGenerator.GenerateEmbeddingAsync(TestInputString); + var result = await embeddingGenerator.GenerateEmbeddingsAsync(testInputStrings); // Assert - Assert.Equal(expectedVectorLength, result.Length); + Assert.Equal(testInputStrings.Length, result.Count); + Assert.All(result, r => Assert.Equal(expectedVectorLength, r.Length)); } } diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextGenerationTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextGenerationTests.cs new file mode 100644 index 000000000000..597fdf331db2 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextGenerationTests.cs @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Ollama; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.Ollama; + +#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. + +public sealed class OllamaTextGenerationTests(ITestOutputHelper output) : IDisposable +{ + private const string InputParameterName = "input"; + private readonly IKernelBuilder _kernelBuilder = Kernel.CreateBuilder(); + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + [Theory(Skip = "For manual verification only")] + [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] + public async Task ItInvokeStreamingWorksAsync(string prompt, string expectedAnswerContains) + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + + this.ConfigureTextOllama(this._kernelBuilder); + + Kernel target = builder.Build(); + + IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin"); + + StringBuilder fullResult = new(); + // Act + await foreach (var content in target.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt })) + { + fullResult.Append(content); + Assert.NotNull(content.Metadata); + } + + // Assert + Assert.Contains(expectedAnswerContains, fullResult.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItShouldReturnMetadataAsync() + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + + this.ConfigureTextOllama(this._kernelBuilder); + + var kernel = this._kernelBuilder.Build(); + + var plugin = TestHelpers.ImportSamplePlugins(kernel, "FunPlugin"); + + // Act + StreamingKernelContent? lastUpdate = null; + await foreach (var update in kernel.InvokeStreamingAsync(plugin["FunPlugin"]["Limerick"])) + { + lastUpdate = update; + } + + // Assert + Assert.NotNull(lastUpdate); + Assert.NotNull(lastUpdate.Metadata); + + // CreatedAt + Assert.True(lastUpdate.Metadata.TryGetValue("CreatedAt", out object? createdAt)); + Assert.IsType(lastUpdate.Metadata); + OllamaMetadata ollamaMetadata = (OllamaMetadata)lastUpdate.Metadata; + Assert.NotNull(ollamaMetadata.CreatedAt); + Assert.NotEqual(0, ollamaMetadata.TotalDuration); + Assert.NotEqual(0, ollamaMetadata.EvalDuration); + } + + [Theory(Skip = "For manual verification only")] + [InlineData("\n")] + [InlineData("\r\n")] + public async Task ItCompletesWithDifferentLineEndingsAsync(string lineEnding) + { + // Arrange + var prompt = + "Given a json input and a request. Apply the request on the json input and return the result. " + + $"Put the result in between tags{lineEnding}" + + $$"""Input:{{lineEnding}}{"name": "John", "age": 30}{{lineEnding}}{{lineEnding}}Request:{{lineEnding}}name"""; + + const string ExpectedAnswerContains = "result"; + + this._kernelBuilder.Services.AddSingleton(this._logger); + this.ConfigureTextOllama(this._kernelBuilder); + + Kernel target = this._kernelBuilder.Build(); + + IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin"); + + // Act + FunctionResult actual = await target.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }); + + // Assert + Assert.Contains(ExpectedAnswerContains, actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItInvokePromptTestAsync() + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + this.ConfigureTextOllama(builder); + Kernel target = builder.Build(); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + FunctionResult actual = await target.InvokePromptAsync(prompt, new(new OllamaPromptExecutionSettings() { Temperature = 0.5f })); + + // Assert + Assert.Contains("Pike Place", actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Theory(Skip = "For manual verification only")] + [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] + public async Task ItInvokeTestAsync(string prompt, string expectedAnswerContains) + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + + this.ConfigureTextOllama(this._kernelBuilder); + + Kernel target = builder.Build(); + + IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin"); + + // Act + FunctionResult actual = await target.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }); + + // Assert + Assert.Contains(expectedAnswerContains, actual.GetValue(), StringComparison.OrdinalIgnoreCase); + Assert.NotNull(actual.Metadata); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItShouldHaveSemanticKernelVersionHeaderAsync() + { + // Arrange + var config = this._configuration.GetSection("Ollama").Get(); + Assert.NotNull(config); + Assert.NotNull(config.ModelId); + Assert.NotNull(config.Endpoint); + + using var defaultHandler = new HttpClientHandler(); + using var httpHeaderHandler = new HttpHeaderHandler(defaultHandler); + using var httpClient = new HttpClient(httpHeaderHandler); + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + builder.AddOllamaTextGeneration( + endpoint: config.Endpoint, + modelId: config.ModelId, + httpClient: httpClient); + Kernel target = builder.Build(); + + // Act + var result = await target.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?"); + + // Assert + Assert.NotNull(httpHeaderHandler.RequestHeaders); + Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values)); + } + + #region internals + + private readonly XunitLogger _logger = new(output); + private readonly RedirectOutput _testOutputHelper = new(output); + + public void Dispose() + { + this._logger.Dispose(); + this._testOutputHelper.Dispose(); + } + + private void ConfigureTextOllama(IKernelBuilder kernelBuilder) + { + var config = this._configuration.GetSection("Ollama").Get(); + + Assert.NotNull(config); + Assert.NotNull(config.Endpoint); + Assert.NotNull(config.ModelId); + + kernelBuilder.AddOllamaTextGeneration( + modelId: config.ModelId, + endpoint: config.Endpoint); + } + + private sealed class HttpHeaderHandler(HttpMessageHandler innerHandler) : DelegatingHandler(innerHandler) + { + public System.Net.Http.Headers.HttpRequestHeaders? RequestHeaders { get; private set; } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + this.RequestHeaders = request.Headers; + return await base.SendAsync(request, cancellationToken); + } + } + + #endregion +}