From 7544a23ca7f30c58c5e240e8a564a301ba1dd249 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Thu, 11 Jul 2024 19:42:53 +0100
Subject: [PATCH 1/3] Added metadata, integration tests and some more
adjustments
---
dotnet/Directory.Packages.props | 2 +-
.../Connectors.Ollama.UnitTests.csproj | 2 +-
.../Connectors.Ollama/Core/ServiceBase.cs | 6 +-
.../OllamaKernelBuilderExtensions.cs | 18 +-
.../OllamaServiceCollectionExtensions.cs | 18 +-
.../Connectors.Ollama/OllamaMetadata.cs | 54 ++++-
.../Services/OllamaChatCompletionService.cs | 15 +-
.../OllamaTextEmbeddingGenerationService.cs | 6 +-
.../Services/OllamaTextGenerationService.cs | 14 +-
.../Ollama/OllamaCompletionTests.cs | 221 ++++++++++++++++++
.../Ollama/OllamaTextEmbeddingTests.cs | 111 +++++++++
.../IntegrationTests/IntegrationTests.csproj | 1 +
.../TestSettings/OllamaConfiguration.cs | 14 ++
13 files changed, 439 insertions(+), 43 deletions(-)
create mode 100644 dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs
create mode 100644 dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs
create mode 100644 dotnet/src/IntegrationTests/TestSettings/OllamaConfiguration.cs
diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props
index bc2f3c81d3bc..e0bfad396dcb 100644
--- a/dotnet/Directory.Packages.props
+++ b/dotnet/Directory.Packages.props
@@ -34,7 +34,7 @@
-
+
diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj
index 427f079b3c65..489e1b416d89 100644
--- a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj
+++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj
@@ -27,7 +27,7 @@
all
-
+
diff --git a/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs b/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs
index 192cbc238f2e..57b19adb0442 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs
@@ -22,7 +22,7 @@ public abstract class ServiceBase
internal readonly OllamaApiClient _client;
internal ServiceBase(string model,
- Uri baseUri,
+ Uri endpoint,
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
{
@@ -31,7 +31,7 @@ internal ServiceBase(string model,
if (httpClient is not null)
{
- httpClient.BaseAddress ??= baseUri;
+ httpClient.BaseAddress ??= endpoint;
// Try to add User-Agent header.
if (!httpClient.DefaultRequestHeaders.TryGetValues("User-Agent", out _))
@@ -52,7 +52,7 @@ internal ServiceBase(string model,
#pragma warning disable CA2000 // Dispose objects before losing scope
// Client needs to be created to be able to inject Semantic Kernel headers
var internalClient = HttpClientProvider.GetHttpClient();
- internalClient.BaseAddress = baseUri;
+ internalClient.BaseAddress = endpoint;
internalClient.DefaultRequestHeaders.Add("User-Agent", HttpHeaderConstant.Values.UserAgent);
internalClient.DefaultRequestHeaders.Add(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(typeof(Kernel)));
diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs
index c491d0e4397d..1e93e3d49baa 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs
@@ -23,14 +23,14 @@ public static class OllamaKernelBuilderExtensions
///
/// The kernel builder.
/// The model for text generation.
- /// The base uri to Ollama hosted service.
+ /// The endpoint to Ollama hosted service.
/// The optional service ID.
/// The optional custom HttpClient.
/// The updated kernel builder.
public static IKernelBuilder AddOllamaTextGeneration(
this IKernelBuilder builder,
string modelId,
- Uri baseUri,
+ Uri endpoint,
string? serviceId = null,
HttpClient? httpClient = null)
{
@@ -39,7 +39,7 @@ public static IKernelBuilder AddOllamaTextGeneration(
builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaTextGenerationService(
model: modelId,
- baseUri: baseUri,
+ endpoint: endpoint,
httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider),
loggerFactory: serviceProvider.GetService()));
return builder;
@@ -74,14 +74,14 @@ public static IKernelBuilder AddOllamaTextGeneration(
///
/// The kernel builder.
/// The model for text generation.
- /// The base uri to Ollama hosted service.
+ /// The endpoint to Ollama hosted service.
/// The optional service ID.
/// The optional custom HttpClient.
/// The updated kernel builder.
public static IKernelBuilder AddOllamaChatCompletion(
this IKernelBuilder builder,
string modelId,
- Uri baseUri,
+ Uri endpoint,
string? serviceId = null,
HttpClient? httpClient = null)
{
@@ -91,7 +91,7 @@ public static IKernelBuilder AddOllamaChatCompletion(
builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaChatCompletionService(
model: modelId,
- baseUri: baseUri,
+ endpoint: endpoint,
httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider),
loggerFactory: serviceProvider.GetService()));
@@ -128,14 +128,14 @@ public static IKernelBuilder AddOllamaChatCompletion(
///
/// The kernel builder.
/// The model for text generation.
- /// The base uri to Ollama hosted service.
+ /// The endpoint to Ollama hosted service.
/// The optional service ID.
/// The optional custom HttpClient.
/// The updated kernel builder.
public static IKernelBuilder AddOllamaTextEmbeddingGeneration(
this IKernelBuilder builder,
string modelId,
- Uri baseUri,
+ Uri endpoint,
string? serviceId = null,
HttpClient? httpClient = null)
{
@@ -144,7 +144,7 @@ public static IKernelBuilder AddOllamaTextEmbeddingGeneration(
builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaTextEmbeddingGenerationService(
model: modelId,
- baseUri: baseUri,
+ endpoint: endpoint,
httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider),
loggerFactory: serviceProvider.GetService()));
diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs
index 7d9e1e14f33e..d8c4658351e1 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs
@@ -22,13 +22,13 @@ public static class OllamaServiceCollectionExtensions
///
/// The target service collection.
/// The model for text generation.
- /// The base uri to Ollama hosted service.
+ /// The endpoint to Ollama hosted service.
/// The optional service ID.
/// The updated kernel builder.
public static IServiceCollection AddOllamaTextGeneration(
this IServiceCollection services,
string modelId,
- Uri baseUri,
+ Uri endpoint,
string? serviceId = null)
{
Verify.NotNull(services);
@@ -36,7 +36,7 @@ public static IServiceCollection AddOllamaTextGeneration(
return services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaTextGenerationService(
model: modelId,
- baseUri: baseUri,
+ endpoint: endpoint,
httpClient: HttpClientProvider.GetHttpClient(serviceProvider),
loggerFactory: serviceProvider.GetService()));
}
@@ -69,13 +69,13 @@ public static IServiceCollection AddOllamaTextGeneration(
///
/// The target service collection.
/// The model for text generation.
- /// The base uri to Ollama hosted service.
+ /// The endpoint to Ollama hosted service.
/// Optional service ID.
/// The updated service collection.
public static IServiceCollection AddOllamaChatCompletion(
this IServiceCollection services,
string modelId,
- Uri baseUri,
+ Uri endpoint,
string? serviceId = null)
{
Verify.NotNull(services);
@@ -83,7 +83,7 @@ public static IServiceCollection AddOllamaChatCompletion(
services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaChatCompletionService(
model: modelId,
- baseUri: baseUri,
+ endpoint: endpoint,
httpClient: HttpClientProvider.GetHttpClient(serviceProvider),
loggerFactory: serviceProvider.GetService()));
@@ -118,13 +118,13 @@ public static IServiceCollection AddOllamaChatCompletion(
///
/// The target service collection.
/// The model for text generation.
- /// The base uri to Ollama hosted service.
+ /// The endpoint to Ollama hosted service.
/// Optional service ID.
/// The updated kernel builder.
public static IServiceCollection AddOllamaTextEmbeddingGeneration(
this IServiceCollection services,
string modelId,
- Uri baseUri,
+ Uri endpoint,
string? serviceId = null)
{
Verify.NotNull(services);
@@ -132,7 +132,7 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration(
return services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaTextEmbeddingGenerationService(
model: modelId,
- baseUri: baseUri,
+ endpoint: endpoint,
httpClient: HttpClientProvider.GetHttpClient(serviceProvider),
loggerFactory: serviceProvider.GetService()));
}
diff --git a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs
index dbe16cbeafab..962826b525f0 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs
@@ -4,6 +4,7 @@
using System.Collections.ObjectModel;
using System.Runtime.CompilerServices;
using OllamaSharp.Models;
+using OllamaSharp.Models.Chat;
namespace Microsoft.SemanticKernel.Connectors.Ollama;
@@ -12,15 +13,45 @@ namespace Microsoft.SemanticKernel.Connectors.Ollama;
///
public sealed class OllamaMetadata : ReadOnlyDictionary
{
- internal OllamaMetadata(GenerateCompletionDoneResponseStream ollamaResponse) : base(new Dictionary())
+ internal OllamaMetadata(GenerateCompletionResponseStream? ollamaResponse) : base(new Dictionary())
{
- this.TotalDuration = ollamaResponse.TotalDuration;
- this.EvalCount = ollamaResponse.EvalCount;
- this.EvalDuration = ollamaResponse.EvalDuration;
+ if (ollamaResponse is null)
+ {
+ return;
+ }
+
this.CreatedAt = ollamaResponse.CreatedAt;
- this.LoadDuration = ollamaResponse.LoadDuration;
- this.PromptEvalCount = ollamaResponse.PromptEvalCount;
- this.PromptEvalDuration = ollamaResponse.PromptEvalDuration;
+ this.Done = ollamaResponse.Done;
+
+ if (ollamaResponse is GenerateCompletionDoneResponseStream doneResponse)
+ {
+ this.TotalDuration = doneResponse.TotalDuration;
+ this.EvalCount = doneResponse.EvalCount;
+ this.EvalDuration = doneResponse.EvalDuration;
+ this.LoadDuration = doneResponse.LoadDuration;
+ this.PromptEvalCount = doneResponse.PromptEvalCount;
+ this.PromptEvalDuration = doneResponse.PromptEvalDuration;
+ }
+ }
+
+ internal OllamaMetadata(ChatResponseStream? message) : base(new Dictionary())
+ {
+ if (message is null)
+ {
+ return;
+ }
+ this.CreatedAt = message?.CreatedAt;
+ this.Done = message?.Done;
+
+ if (message is ChatDoneResponseStream doneMessage)
+ {
+ this.TotalDuration = doneMessage.TotalDuration;
+ this.EvalCount = doneMessage.EvalCount;
+ this.EvalDuration = doneMessage.EvalDuration;
+ this.LoadDuration = doneMessage.LoadDuration;
+ this.PromptEvalCount = doneMessage.PromptEvalCount;
+ this.PromptEvalDuration = doneMessage.PromptEvalDuration;
+ }
}
///
@@ -59,6 +90,15 @@ public string? CreatedAt
internal init => this.SetValueInDictionary(value);
}
+ ///
+ /// The response is done
+ ///
+ public bool? Done
+ {
+ get => this.GetValueFromDictionary() as bool?;
+ internal init => this.SetValueInDictionary(value);
+ }
+
///
/// Time in nano seconds spent generating the response
///
diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs
index c6546622bc59..830b0ba82b9d 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs
@@ -24,15 +24,15 @@ public sealed class OllamaChatCompletionService : ServiceBase, IChatCompletionSe
/// Initializes a new instance of the class.
///
/// The hosted model.
- /// The base uri including the port where Ollama server is hosted
+ /// The endpoint including the port where Ollama server is hosted
/// Optional HTTP client to be used for communication with the Ollama API.
/// Optional logger factory to be used for logging.
public OllamaChatCompletionService(
string model,
- Uri baseUri,
+ Uri endpoint,
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
- : base(model, baseUri, httpClient, loggerFactory)
+ : base(model, endpoint, httpClient, loggerFactory)
{
}
@@ -73,7 +73,7 @@ public async Task> GetChatMessageContentsAsync
role: GetAuthorRole(message.Role) ?? AuthorRole.Assistant,
content: message.Content,
modelId: this._client.SelectedModel,
- innerContent: message)];
+ innerContent: message)]; // Currently the Ollama Message does not provide any metadata
}
///
@@ -88,7 +88,12 @@ public async IAsyncEnumerable GetStreamingChatMessa
await foreach (var message in this._client.StreamChat(request, cancellationToken).ConfigureAwait(false))
{
- yield return new StreamingChatMessageContent(GetAuthorRole(message?.Message.Role), message?.Message.Content, modelId: message?.Model, innerContent: message);
+ yield return new StreamingChatMessageContent(
+ GetAuthorRole(message?.Message.Role),
+ message?.Message.Content,
+ modelId: message?.Model,
+ innerContent: message,
+ metadata: new OllamaMetadata(message));
}
}
diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs
index 121df1caf995..f5a7021d97fa 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs
@@ -23,15 +23,15 @@ public sealed class OllamaTextEmbeddingGenerationService : ServiceBase, ITextEmb
/// Initializes a new instance of the class.
///
/// The hosted model.
- /// The base uri including the port where Ollama server is hosted
+ /// The endpoint including the port where Ollama server is hosted
/// Optional HTTP client to be used for communication with the Ollama API.
/// Optional logger factory to be used for logging.
public OllamaTextEmbeddingGenerationService(
string model,
- Uri baseUri,
+ Uri endpoint,
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
- : base(model, baseUri, httpClient, loggerFactory)
+ : base(model, endpoint, httpClient, loggerFactory)
{
}
diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs
index 0294004811ff..405c345776ce 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs
@@ -22,15 +22,15 @@ public sealed class OllamaTextGenerationService : ServiceBase, ITextGenerationSe
/// Initializes a new instance of the class.
///
/// The Ollama model for the text generation service.
- /// The base uri including the port where Ollama server is hosted
+ /// The endpoint including the port where Ollama server is hosted
/// Optional HTTP client to be used for communication with the Ollama API.
/// Optional logger factory to be used for logging.
public OllamaTextGenerationService(
string model,
- Uri baseUri,
+ Uri endpoint,
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
- : base(model, baseUri, httpClient, loggerFactory)
+ : base(model, endpoint, httpClient, loggerFactory)
{
}
@@ -60,7 +60,11 @@ public async Task> GetTextContentsAsync(
{
var content = await this._client.GetCompletion(prompt, null, cancellationToken).ConfigureAwait(false);
- return [new(content.Response, modelId: this._client.SelectedModel, innerContent: content)];
+ return [new(content.Response, modelId: this._client.SelectedModel, innerContent: content, metadata:
+ new Dictionary()
+ {
+ ["Context"] = content.Context
+ })];
}
///
@@ -72,7 +76,7 @@ public async IAsyncEnumerable GetStreamingTextContentsAsyn
{
await foreach (var content in this._client.StreamCompletion(prompt, null, cancellationToken).ConfigureAwait(false))
{
- yield return new StreamingTextContent(content?.Response, modelId: content?.Model, innerContent: content);
+ yield return new StreamingTextContent(content?.Response, modelId: content?.Model, innerContent: content, metadata: new OllamaMetadata(content));
}
}
}
diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs
new file mode 100644
index 000000000000..3c7fc21407a4
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs
@@ -0,0 +1,221 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Linq;
+using System.Net.Http;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Connectors.Ollama;
+using Microsoft.SemanticKernel.Connectors.OpenAI;
+using SemanticKernel.IntegrationTests.TestSettings;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Ollama;
+
+#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only.
+
+public sealed class OllamaCompletionTests(ITestOutputHelper output) : IDisposable
+{
+ private const string InputParameterName = "input";
+ private readonly IKernelBuilder _kernelBuilder = Kernel.CreateBuilder();
+ private readonly IConfigurationRoot _configuration = new ConfigurationBuilder()
+ .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true)
+ .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true)
+ .AddEnvironmentVariables()
+ .AddUserSecrets()
+ .Build();
+
+ [Theory(Skip = "For manual verification only")]
+ [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")]
+ public async Task ItStreamingTestAsync(string prompt, string expectedAnswerContains)
+ {
+ // Arrange
+ this._kernelBuilder.Services.AddSingleton(this._logger);
+ var builder = this._kernelBuilder;
+
+ this.ConfigureChatOllama(this._kernelBuilder);
+
+ Kernel target = builder.Build();
+
+ IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin");
+
+ StringBuilder fullResult = new();
+ // Act
+ await foreach (var content in target.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }))
+ {
+ if (content is StreamingChatMessageContent messageContent)
+ {
+ Assert.NotNull(messageContent.Role);
+ }
+
+ fullResult.Append(content);
+ }
+
+ // Assert
+ Assert.Contains(expectedAnswerContains, fullResult.ToString(), StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Fact(Skip = "For manual verification only")]
+ public async Task ItShouldReturnMetadataAsync()
+ {
+ // Arrange
+ this._kernelBuilder.Services.AddSingleton(this._logger);
+
+ this.ConfigureChatOllama(this._kernelBuilder);
+
+ var kernel = this._kernelBuilder.Build();
+
+ var plugin = TestHelpers.ImportSamplePlugins(kernel, "FunPlugin");
+
+ // Act
+ StreamingKernelContent? lastUpdate = null;
+ await foreach (var update in kernel.InvokeStreamingAsync(plugin["FunPlugin"]["Limerick"]))
+ {
+ lastUpdate = update;
+ }
+
+ // Assert
+ Assert.NotNull(lastUpdate);
+ Assert.NotNull(lastUpdate.Metadata);
+
+ // CreatedAt
+ Assert.True(lastUpdate.Metadata.TryGetValue("CreatedAt", out object? createdAt));
+ }
+
+ [Theory(Skip = "For manual verification only")]
+ [InlineData("\n")]
+ [InlineData("\r\n")]
+ public async Task ItCompletesWithDifferentLineEndingsAsync(string lineEnding)
+ {
+ // Arrange
+ var prompt =
+ "Given a json input and a request. Apply the request on the json input and return the result. " +
+ $"Put the result in between tags{lineEnding}" +
+ $$"""Input:{{lineEnding}}{"name": "John", "age": 30}{{lineEnding}}{{lineEnding}}Request:{{lineEnding}}name""";
+
+ const string ExpectedAnswerContains = "result";
+
+ this._kernelBuilder.Services.AddSingleton(this._logger);
+ this.ConfigureChatOllama(this._kernelBuilder);
+
+ Kernel target = this._kernelBuilder.Build();
+
+ IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin");
+
+ // Act
+ FunctionResult actual = await target.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt });
+
+ // Assert
+ Assert.Contains(ExpectedAnswerContains, actual.GetValue(), StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Fact(Skip = "For manual verification only")]
+ public async Task ItInvokePromptTestAsync()
+ {
+ // Arrange
+ this._kernelBuilder.Services.AddSingleton(this._logger);
+ var builder = this._kernelBuilder;
+ this.ConfigureChatOllama(builder);
+ Kernel target = builder.Build();
+
+ var prompt = "Where is the most famous fish market in Seattle, Washington, USA?";
+
+ // Act
+ FunctionResult actual = await target.InvokePromptAsync(prompt, new(new OllamaPromptExecutionSettings() { Temperature = 0.5f }));
+
+ // Assert
+ Assert.Contains("Pike Place", actual.GetValue(), StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Theory(Skip = "For manual verification only")]
+ [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")]
+ public async Task ItInvokeTestAsync(string prompt, string expectedAnswerContains)
+ {
+ // Arrange
+ this._kernelBuilder.Services.AddSingleton(this._logger);
+ var builder = this._kernelBuilder;
+
+ this.ConfigureChatOllama(this._kernelBuilder);
+
+ Kernel target = builder.Build();
+
+ IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin");
+
+ // Act
+ FunctionResult actual = await target.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt });
+
+ // Assert
+ Assert.Contains(expectedAnswerContains, actual.GetValue(), StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Fact(Skip = "For manual verification only")]
+ public async Task ItShouldHaveSemanticKernelVersionHeaderAsync()
+ {
+ // Arrange
+ var config = this._configuration.GetSection("Ollama").Get();
+ Assert.NotNull(config);
+ Assert.NotNull(config.ModelId);
+ Assert.NotNull(config.Endpoint);
+
+ using var defaultHandler = new HttpClientHandler();
+ using var httpHeaderHandler = new HttpHeaderHandler(defaultHandler);
+ using var httpClient = new HttpClient(httpHeaderHandler);
+ this._kernelBuilder.Services.AddSingleton(this._logger);
+ var builder = this._kernelBuilder;
+ builder.AddOllamaChatCompletion(
+ endpoint: config.Endpoint,
+ modelId: config.ModelId,
+ httpClient: httpClient);
+ Kernel target = builder.Build();
+
+ // Act
+ var result = await target.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?");
+
+ // Assert
+ Assert.NotNull(httpHeaderHandler.RequestHeaders);
+ Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values));
+ }
+
+ #region internals
+
+ private readonly XunitLogger _logger = new(output);
+ private readonly RedirectOutput _testOutputHelper = new(output);
+
+ public void Dispose()
+ {
+ this._logger.Dispose();
+ this._testOutputHelper.Dispose();
+ }
+
+ private void ConfigureChatOllama(IKernelBuilder kernelBuilder)
+ {
+ var config = this._configuration.GetSection("Ollama").Get();
+
+ Assert.NotNull(config);
+ Assert.NotNull(config.Endpoint);
+ Assert.NotNull(config.ModelId);
+
+ kernelBuilder.AddOllamaChatCompletion(
+ modelId: config.ModelId,
+ endpoint: config.Endpoint);
+ }
+
+ private sealed class HttpHeaderHandler(HttpMessageHandler innerHandler) : DelegatingHandler(innerHandler)
+ {
+ public System.Net.Http.Headers.HttpRequestHeaders? RequestHeaders { get; private set; }
+
+ protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ this.RequestHeaders = request.Headers;
+ return await base.SendAsync(request, cancellationToken);
+ }
+ }
+
+ #endregion
+}
diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs
new file mode 100644
index 000000000000..d26f48a79742
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs
@@ -0,0 +1,111 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Threading.Tasks;
+using Microsoft.Extensions.Configuration;
+using Microsoft.SemanticKernel.Connectors.Ollama;
+using Microsoft.SemanticKernel.Connectors.OpenAI;
+using Microsoft.SemanticKernel.Embeddings;
+using SemanticKernel.IntegrationTests.TestSettings;
+using Xunit;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Ollama;
+
+public sealed class OllamaTextEmbeddingTests
+{
+ private const int AdaVectorLength = 1536;
+ private readonly IConfigurationRoot _configuration = new ConfigurationBuilder()
+ .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true)
+ .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true)
+ .AddEnvironmentVariables()
+ .AddUserSecrets()
+ .Build();
+
+ [Theory(Skip = "Ollama is a locally hosted service and may not be available. This test is for manual verification.")]
+ [InlineData("test sentence")]
+ public async Task OpenAITestAsync(string testInputString)
+ {
+ // Arrange
+ OllamaConfiguration? config = this._configuration.GetSection("Ollama").Get();
+ Assert.NotNull(config);
+ Assert.NotNull(config.ModelId);
+ Assert.NotNull(config.Endpoint);
+
+ var embeddingGenerator = new OllamaTextEmbeddingGenerationService(config.ModelId, config.Endpoint);
+
+ // Act
+ var singleResult = await embeddingGenerator.GenerateEmbeddingAsync(testInputString);
+ var batchResult = await embeddingGenerator.GenerateEmbeddingsAsync([testInputString, testInputString, testInputString]);
+
+ // Assert
+ Assert.Equal(AdaVectorLength, singleResult.Length);
+ Assert.Equal(3, batchResult.Count);
+ }
+
+ [Theory(Skip = "OpenAI will often throttle requests. This test is for manual verification.")]
+ [InlineData(null, 3072)]
+ [InlineData(1024, 1024)]
+ public async Task OpenAIWithDimensionsAsync(int? dimensions, int expectedVectorLength)
+ {
+ // Arrange
+ const string TestInputString = "test sentence";
+
+ OpenAIConfiguration? openAIConfiguration = this._configuration.GetSection("OpenAIEmbeddings").Get();
+ Assert.NotNull(openAIConfiguration);
+
+ var embeddingGenerator = new OpenAITextEmbeddingGenerationService(
+ "text-embedding-3-large",
+ openAIConfiguration.ApiKey,
+ dimensions: dimensions);
+
+ // Act
+ var result = await embeddingGenerator.GenerateEmbeddingAsync(TestInputString);
+
+ // Assert
+ Assert.Equal(expectedVectorLength, result.Length);
+ }
+
+ [Theory]
+ [InlineData("test sentence")]
+ public async Task AzureOpenAITestAsync(string testInputString)
+ {
+ // Arrange
+ AzureOpenAIConfiguration? azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAIEmbeddings").Get();
+ Assert.NotNull(azureOpenAIConfiguration);
+
+ var embeddingGenerator = new AzureOpenAITextEmbeddingGenerationService(azureOpenAIConfiguration.DeploymentName,
+ azureOpenAIConfiguration.Endpoint,
+ azureOpenAIConfiguration.ApiKey);
+
+ // Act
+ var singleResult = await embeddingGenerator.GenerateEmbeddingAsync(testInputString);
+ var batchResult = await embeddingGenerator.GenerateEmbeddingsAsync([testInputString, testInputString, testInputString]);
+
+ // Assert
+ Assert.Equal(AdaVectorLength, singleResult.Length);
+ Assert.Equal(3, batchResult.Count);
+ }
+
+ [Theory]
+ [InlineData(null, 3072)]
+ [InlineData(1024, 1024)]
+ public async Task AzureOpenAIWithDimensionsAsync(int? dimensions, int expectedVectorLength)
+ {
+ // Arrange
+ const string TestInputString = "test sentence";
+
+ AzureOpenAIConfiguration? azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAIEmbeddings").Get();
+ Assert.NotNull(azureOpenAIConfiguration);
+
+ var embeddingGenerator = new AzureOpenAITextEmbeddingGenerationService(
+ "text-embedding-3-large",
+ azureOpenAIConfiguration.Endpoint,
+ azureOpenAIConfiguration.ApiKey,
+ dimensions: dimensions);
+
+ // Act
+ var result = await embeddingGenerator.GenerateEmbeddingAsync(TestInputString);
+
+ // Assert
+ Assert.Equal(expectedVectorLength, result.Length);
+ }
+}
diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj
index df5afa473ce7..9c14051ef665 100644
--- a/dotnet/src/IntegrationTests/IntegrationTests.csproj
+++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj
@@ -63,6 +63,7 @@
+
diff --git a/dotnet/src/IntegrationTests/TestSettings/OllamaConfiguration.cs b/dotnet/src/IntegrationTests/TestSettings/OllamaConfiguration.cs
new file mode 100644
index 000000000000..cbf6e52351c4
--- /dev/null
+++ b/dotnet/src/IntegrationTests/TestSettings/OllamaConfiguration.cs
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Diagnostics.CodeAnalysis;
+
+namespace SemanticKernel.IntegrationTests.TestSettings;
+
+[SuppressMessage("Performance", "CA1812:Internal class that is apparently never instantiated",
+ Justification = "Configuration classes are instantiated through IConfiguration.")]
+internal sealed class OllamaConfiguration
+{
+ public string? ModelId { get; set; }
+ public Uri? Endpoint { get; set; }
+}
From 50f500de99a0484ef07e8438636020900a12b8a6 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Thu, 11 Jul 2024 19:53:36 +0100
Subject: [PATCH 2/3] Fix typo
---
docs/decisions/0046-kernel-content-graduation.md | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/docs/decisions/0046-kernel-content-graduation.md b/docs/decisions/0046-kernel-content-graduation.md
index 43518ddfa2d3..368c59bd7621 100644
--- a/docs/decisions/0046-kernel-content-graduation.md
+++ b/docs/decisions/0046-kernel-content-graduation.md
@@ -85,7 +85,7 @@ Pros:
- With no deferred content we have simpler API and a single responsibility for contents.
- Can be written and read in both `Data` or `DataUri` formats.
- Can have a `Uri` reference property, which is common for specialized contexts.
-- Fully serializeable.
+- Fully serializable.
- Data Uri parameters support (serialization included).
- Data Uri and Base64 validation checks
- Data Uri and Data can be dynamically generated
@@ -197,7 +197,7 @@ Pros:
- Can be used as a `BinaryContent` type
- Can be written and read in both `Data` or `DataUri` formats.
- Can have a `Uri` dedicated for referenced location.
-- Fully serializeable.
+- Fully serializable.
- Data Uri parameters support (serialization included).
- Data Uri and Base64 validation checks
- Can be retrieved
@@ -254,7 +254,7 @@ Pros:
- Can be used as a `BinaryContent` type
- Can be written and read in both `Data` or `DataUri` formats.
- Can have a `Uri` dedicated for referenced location.
-- Fully serializeable.
+- Fully serializable.
- Data Uri parameters support (serialization included).
- Data Uri and Base64 validation checks
- Can be retrieved
From cd79a6a72dac961561527d6eb555f78f0c6fedf8 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Fri, 12 Jul 2024 09:29:06 +0100
Subject: [PATCH 3/3] Add missing integration tests
---
.../OllamaKernelBuilderExtensions.cs | 12 +-
.../OllamaServiceCollectionExtensions.cs | 12 +-
.../OllamaPromptExecutionSettings.cs | 2 +-
.../Services/OllamaChatCompletionService.cs | 12 +-
.../OllamaTextEmbeddingGenerationService.cs | 12 +-
.../Services/OllamaTextGenerationService.cs | 12 +-
.../Ollama/OllamaCompletionTests.cs | 4 +-
.../Ollama/OllamaTextEmbeddingTests.cs | 94 +++-----
.../Ollama/OllamaTextGenerationTests.cs | 221 ++++++++++++++++++
9 files changed, 279 insertions(+), 102 deletions(-)
create mode 100644 dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextGenerationTests.cs
diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs
index 1e93e3d49baa..e442e8f9799e 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs
@@ -38,7 +38,7 @@ public static IKernelBuilder AddOllamaTextGeneration(
builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaTextGenerationService(
- model: modelId,
+ modelId: modelId,
endpoint: endpoint,
httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider),
loggerFactory: serviceProvider.GetService()));
@@ -63,7 +63,7 @@ public static IKernelBuilder AddOllamaTextGeneration(
builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaTextGenerationService(
- model: modelId,
+ modelId: modelId,
ollamaClient: ollamaClient,
loggerFactory: serviceProvider.GetService()));
return builder;
@@ -90,7 +90,7 @@ public static IKernelBuilder AddOllamaChatCompletion(
builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaChatCompletionService(
- model: modelId,
+ modelId: modelId,
endpoint: endpoint,
httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider),
loggerFactory: serviceProvider.GetService()));
@@ -116,7 +116,7 @@ public static IKernelBuilder AddOllamaChatCompletion(
builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaChatCompletionService(
- model: modelId,
+ modelId: modelId,
client: ollamaClient,
loggerFactory: serviceProvider.GetService()));
@@ -143,7 +143,7 @@ public static IKernelBuilder AddOllamaTextEmbeddingGeneration(
builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaTextEmbeddingGenerationService(
- model: modelId,
+ modelId: modelId,
endpoint: endpoint,
httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider),
loggerFactory: serviceProvider.GetService()));
@@ -169,7 +169,7 @@ public static IKernelBuilder AddOllamaTextEmbeddingGeneration(
builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaTextEmbeddingGenerationService(
- model: modelId,
+ modelId: modelId,
ollamaClient: ollamaClient,
loggerFactory: serviceProvider.GetService()));
diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs
index d8c4658351e1..0a5497c74a73 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs
@@ -35,7 +35,7 @@ public static IServiceCollection AddOllamaTextGeneration(
return services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaTextGenerationService(
- model: modelId,
+ modelId: modelId,
endpoint: endpoint,
httpClient: HttpClientProvider.GetHttpClient(serviceProvider),
loggerFactory: serviceProvider.GetService()));
@@ -59,7 +59,7 @@ public static IServiceCollection AddOllamaTextGeneration(
return services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaTextGenerationService(
- model: modelId,
+ modelId: modelId,
ollamaClient: ollamaClient,
loggerFactory: serviceProvider.GetService()));
}
@@ -82,7 +82,7 @@ public static IServiceCollection AddOllamaChatCompletion(
services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaChatCompletionService(
- model: modelId,
+ modelId: modelId,
endpoint: endpoint,
httpClient: HttpClientProvider.GetHttpClient(serviceProvider),
loggerFactory: serviceProvider.GetService()));
@@ -108,7 +108,7 @@ public static IServiceCollection AddOllamaChatCompletion(
return services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaChatCompletionService(
- model: modelId,
+ modelId: modelId,
client: ollamaClient,
loggerFactory: serviceProvider.GetService()));
}
@@ -131,7 +131,7 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration(
return services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaTextEmbeddingGenerationService(
- model: modelId,
+ modelId: modelId,
endpoint: endpoint,
httpClient: HttpClientProvider.GetHttpClient(serviceProvider),
loggerFactory: serviceProvider.GetService()));
@@ -155,7 +155,7 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration(
return services.AddKeyedSingleton(serviceId, (serviceProvider, _) =>
new OllamaTextEmbeddingGenerationService(
- model: modelId,
+ modelId: modelId,
ollamaClient: ollamaClient,
loggerFactory: serviceProvider.GetService()));
}
diff --git a/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs
index 9fc47bb9bb1b..283c6790c549 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs
@@ -8,7 +8,7 @@
namespace Microsoft.SemanticKernel.Connectors.Ollama;
///
-/// Ollama Execution Settings.
+/// Ollama Prompt Execution Settings.
///
public sealed class OllamaPromptExecutionSettings : PromptExecutionSettings
{
diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs
index 830b0ba82b9d..f611b9625e88 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs
@@ -23,30 +23,30 @@ public sealed class OllamaChatCompletionService : ServiceBase, IChatCompletionSe
///
/// Initializes a new instance of the class.
///
- /// The hosted model.
+ /// The hosted model.
/// The endpoint including the port where Ollama server is hosted
/// Optional HTTP client to be used for communication with the Ollama API.
/// Optional logger factory to be used for logging.
public OllamaChatCompletionService(
- string model,
+ string modelId,
Uri endpoint,
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
- : base(model, endpoint, httpClient, loggerFactory)
+ : base(modelId, endpoint, httpClient, loggerFactory)
{
}
///
/// Initializes a new instance of the class.
///
- /// The hosted model.
+ /// The hosted model.
/// The Ollama API client.
/// Optional logger factory to be used for logging.
public OllamaChatCompletionService(
- string model,
+ string modelId,
OllamaApiClient client,
ILoggerFactory? loggerFactory = null)
- : base(model, client, loggerFactory)
+ : base(modelId, client, loggerFactory)
{
}
diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs
index f5a7021d97fa..13adcd165c80 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs
@@ -22,30 +22,30 @@ public sealed class OllamaTextEmbeddingGenerationService : ServiceBase, ITextEmb
///
/// Initializes a new instance of the class.
///
- /// The hosted model.
+ /// The hosted model.
/// The endpoint including the port where Ollama server is hosted
/// Optional HTTP client to be used for communication with the Ollama API.
/// Optional logger factory to be used for logging.
public OllamaTextEmbeddingGenerationService(
- string model,
+ string modelId,
Uri endpoint,
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
- : base(model, endpoint, httpClient, loggerFactory)
+ : base(modelId, endpoint, httpClient, loggerFactory)
{
}
///
/// Initializes a new instance of the class.
///
- /// The hosted model.
+ /// The hosted model.
/// The Ollama API client.
/// Optional logger factory to be used for logging.
public OllamaTextEmbeddingGenerationService(
- string model,
+ string modelId,
OllamaApiClient ollamaClient,
ILoggerFactory? loggerFactory = null)
- : base(model, ollamaClient, loggerFactory)
+ : base(modelId, ollamaClient, loggerFactory)
{
}
diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs
index 405c345776ce..29acd5f342c5 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs
@@ -21,30 +21,30 @@ public sealed class OllamaTextGenerationService : ServiceBase, ITextGenerationSe
///
/// Initializes a new instance of the class.
///
- /// The Ollama model for the text generation service.
+ /// The Ollama model for the text generation service.
/// The endpoint including the port where Ollama server is hosted
/// Optional HTTP client to be used for communication with the Ollama API.
/// Optional logger factory to be used for logging.
public OllamaTextGenerationService(
- string model,
+ string modelId,
Uri endpoint,
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
- : base(model, endpoint, httpClient, loggerFactory)
+ : base(modelId, endpoint, httpClient, loggerFactory)
{
}
///
/// Initializes a new instance of the class.
///
- /// The hosted model.
+ /// The hosted model.
/// The Ollama API client.
/// Optional logger factory to be used for logging.
public OllamaTextGenerationService(
- string model,
+ string modelId,
OllamaApiClient ollamaClient,
ILoggerFactory? loggerFactory = null)
- : base(model, ollamaClient, loggerFactory)
+ : base(modelId, ollamaClient, loggerFactory)
{
}
diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs
index 3c7fc21407a4..4fabf80936ff 100644
--- a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs
+++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
-using System.Linq;
using System.Net.Http;
using System.Text;
using System.Threading;
@@ -11,7 +10,6 @@
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.Ollama;
-using Microsoft.SemanticKernel.Connectors.OpenAI;
using SemanticKernel.IntegrationTests.TestSettings;
using Xunit;
using Xunit.Abstractions;
@@ -33,7 +31,7 @@ public sealed class OllamaCompletionTests(ITestOutputHelper output) : IDisposabl
[Theory(Skip = "For manual verification only")]
[InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")]
- public async Task ItStreamingTestAsync(string prompt, string expectedAnswerContains)
+ public async Task ItInvokeStreamingWorksAsync(string prompt, string expectedAnswerContains)
{
// Arrange
this._kernelBuilder.Services.AddSingleton(this._logger);
diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs
index d26f48a79742..f530098b473b 100644
--- a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs
+++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs
@@ -3,7 +3,6 @@
using System.Threading.Tasks;
using Microsoft.Extensions.Configuration;
using Microsoft.SemanticKernel.Connectors.Ollama;
-using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Embeddings;
using SemanticKernel.IntegrationTests.TestSettings;
using Xunit;
@@ -12,50 +11,29 @@ namespace SemanticKernel.IntegrationTests.Connectors.Ollama;
public sealed class OllamaTextEmbeddingTests
{
- private const int AdaVectorLength = 1536;
private readonly IConfigurationRoot _configuration = new ConfigurationBuilder()
- .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true)
+ .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true)
.AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true)
.AddEnvironmentVariables()
 .AddUserSecrets<OllamaTextEmbeddingTests>()
.Build();
- [Theory(Skip = "OpenAI will often throttle requests. This test is for manual verification.")]
- [InlineData("test sentence")]
- public async Task OpenAITestAsync(string testInputString)
+ [Theory(Skip = "For manual verification only")]
+ [InlineData("mxbai-embed-large", 1024)]
+ [InlineData("nomic-embed-text", 768)]
+ [InlineData("all-minilm", 384)]
+ public async Task GenerateEmbeddingHasExpectedLengthForModelAsync(string modelId, int expectedVectorLength)
{
// Arrange
+ const string TestInputString = "test sentence";
+
+ OllamaConfiguration? config = this._configuration.GetSection("Ollama").Get<OllamaConfiguration>();
Assert.NotNull(config);
- Assert.NotNull(config.ModelId);
Assert.NotNull(config.Endpoint);
- var embeddingGenerator = new OllamaTextEmbeddingGenerationService(config.ModelId, config.Endpoint);
-
- // Act
- var singleResult = await embeddingGenerator.GenerateEmbeddingAsync(testInputString);
- var batchResult = await embeddingGenerator.GenerateEmbeddingsAsync([testInputString, testInputString, testInputString]);
-
- // Assert
- Assert.Equal(AdaVectorLength, singleResult.Length);
- Assert.Equal(3, batchResult.Count);
- }
-
- [Theory(Skip = "OpenAI will often throttle requests. This test is for manual verification.")]
- [InlineData(null, 3072)]
- [InlineData(1024, 1024)]
- public async Task OpenAIWithDimensionsAsync(int? dimensions, int expectedVectorLength)
- {
- // Arrange
- const string TestInputString = "test sentence";
-
- OpenAIConfiguration? openAIConfiguration = this._configuration.GetSection("OpenAIEmbeddings").Get();
- Assert.NotNull(openAIConfiguration);
-
- var embeddingGenerator = new OpenAITextEmbeddingGenerationService(
- "text-embedding-3-large",
- openAIConfiguration.ApiKey,
- dimensions: dimensions);
+ var embeddingGenerator = new OllamaTextEmbeddingGenerationService(
+ modelId,
+ config.Endpoint);
// Act
var result = await embeddingGenerator.GenerateEmbeddingAsync(TestInputString);
@@ -64,48 +42,28 @@ public async Task OpenAIWithDimensionsAsync(int? dimensions, int expectedVectorL
Assert.Equal(expectedVectorLength, result.Length);
}
- [Theory]
- [InlineData("test sentence")]
- public async Task AzureOpenAITestAsync(string testInputString)
- {
- // Arrange
- AzureOpenAIConfiguration? azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAIEmbeddings").Get();
- Assert.NotNull(azureOpenAIConfiguration);
-
- var embeddingGenerator = new AzureOpenAITextEmbeddingGenerationService(azureOpenAIConfiguration.DeploymentName,
- azureOpenAIConfiguration.Endpoint,
- azureOpenAIConfiguration.ApiKey);
-
- // Act
- var singleResult = await embeddingGenerator.GenerateEmbeddingAsync(testInputString);
- var batchResult = await embeddingGenerator.GenerateEmbeddingsAsync([testInputString, testInputString, testInputString]);
-
- // Assert
- Assert.Equal(AdaVectorLength, singleResult.Length);
- Assert.Equal(3, batchResult.Count);
- }
-
- [Theory]
- [InlineData(null, 3072)]
- [InlineData(1024, 1024)]
- public async Task AzureOpenAIWithDimensionsAsync(int? dimensions, int expectedVectorLength)
+ [Theory(Skip = "For manual verification only")]
+ [InlineData("mxbai-embed-large", 1024)]
+ [InlineData("nomic-embed-text", 768)]
+ [InlineData("all-minilm", 384)]
+ public async Task GenerateEmbeddingsHasExpectedResultsLengthForModelAsync(string modelId, int expectedVectorLength)
{
// Arrange
- const string TestInputString = "test sentence";
+ string[] testInputStrings = ["test sentence 1", "test sentence 2", "test sentence 3"];
- AzureOpenAIConfiguration? azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAIEmbeddings").Get();
- Assert.NotNull(azureOpenAIConfiguration);
+ OllamaConfiguration? config = this._configuration.GetSection("Ollama").Get<OllamaConfiguration>();
+ Assert.NotNull(config);
+ Assert.NotNull(config.Endpoint);
- var embeddingGenerator = new AzureOpenAITextEmbeddingGenerationService(
- "text-embedding-3-large",
- azureOpenAIConfiguration.Endpoint,
- azureOpenAIConfiguration.ApiKey,
- dimensions: dimensions);
+ var embeddingGenerator = new OllamaTextEmbeddingGenerationService(
+ modelId,
+ config.Endpoint);
// Act
- var result = await embeddingGenerator.GenerateEmbeddingAsync(TestInputString);
+ var result = await embeddingGenerator.GenerateEmbeddingsAsync(testInputStrings);
// Assert
- Assert.Equal(expectedVectorLength, result.Length);
+ Assert.Equal(testInputStrings.Length, result.Count);
+ Assert.All(result, r => Assert.Equal(expectedVectorLength, r.Length));
}
}
diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextGenerationTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextGenerationTests.cs
new file mode 100644
index 000000000000..597fdf331db2
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextGenerationTests.cs
@@ -0,0 +1,221 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Net.Http;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Connectors.Ollama;
+using SemanticKernel.IntegrationTests.TestSettings;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace SemanticKernel.IntegrationTests.Connectors.Ollama;
+
+#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only.
+
+public sealed class OllamaTextGenerationTests(ITestOutputHelper output) : IDisposable
+{
+ private const string InputParameterName = "input";
+ private readonly IKernelBuilder _kernelBuilder = Kernel.CreateBuilder();
+ private readonly IConfigurationRoot _configuration = new ConfigurationBuilder()
+ .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true)
+ .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true)
+ .AddEnvironmentVariables()
+ .AddUserSecrets<OllamaTextGenerationTests>()
+ .Build();
+
+ [Theory(Skip = "For manual verification only")]
+ [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")]
+ public async Task ItInvokeStreamingWorksAsync(string prompt, string expectedAnswerContains)
+ {
+ // Arrange
+ this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
+ var builder = this._kernelBuilder;
+
+ this.ConfigureTextOllama(this._kernelBuilder);
+
+ Kernel target = builder.Build();
+
+ IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin");
+
+ StringBuilder fullResult = new();
+ // Act
+ await foreach (var content in target.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }))
+ {
+ fullResult.Append(content);
+ Assert.NotNull(content.Metadata);
+ }
+
+ // Assert
+ Assert.Contains(expectedAnswerContains, fullResult.ToString(), StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Fact(Skip = "For manual verification only")]
+ public async Task ItShouldReturnMetadataAsync()
+ {
+ // Arrange
+ this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
+
+ this.ConfigureTextOllama(this._kernelBuilder);
+
+ var kernel = this._kernelBuilder.Build();
+
+ var plugin = TestHelpers.ImportSamplePlugins(kernel, "FunPlugin");
+
+ // Act
+ StreamingKernelContent? lastUpdate = null;
+ await foreach (var update in kernel.InvokeStreamingAsync(plugin["FunPlugin"]["Limerick"]))
+ {
+ lastUpdate = update;
+ }
+
+ // Assert
+ Assert.NotNull(lastUpdate);
+ Assert.NotNull(lastUpdate.Metadata);
+
+ // CreatedAt
+ Assert.True(lastUpdate.Metadata.TryGetValue("CreatedAt", out object? createdAt));
+ Assert.IsType<OllamaMetadata>(lastUpdate.Metadata);
+ OllamaMetadata ollamaMetadata = (OllamaMetadata)lastUpdate.Metadata;
+ Assert.NotNull(ollamaMetadata.CreatedAt);
+ Assert.NotEqual(0, ollamaMetadata.TotalDuration);
+ Assert.NotEqual(0, ollamaMetadata.EvalDuration);
+ }
+
+ [Theory(Skip = "For manual verification only")]
+ [InlineData("\n")]
+ [InlineData("\r\n")]
+ public async Task ItCompletesWithDifferentLineEndingsAsync(string lineEnding)
+ {
+ // Arrange
+ var prompt =
+ "Given a json input and a request. Apply the request on the json input and return the result. " +
+ $"Put the result in between <result> tags{lineEnding}" +
+ $$"""Input:{{lineEnding}}{"name": "John", "age": 30}{{lineEnding}}{{lineEnding}}Request:{{lineEnding}}name""";
+
+ const string ExpectedAnswerContains = "result";
+
+ this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
+ this.ConfigureTextOllama(this._kernelBuilder);
+
+ Kernel target = this._kernelBuilder.Build();
+
+ IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin");
+
+ // Act
+ FunctionResult actual = await target.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt });
+
+ // Assert
+ Assert.Contains(ExpectedAnswerContains, actual.GetValue<string>(), StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Fact(Skip = "For manual verification only")]
+ public async Task ItInvokePromptTestAsync()
+ {
+ // Arrange
+ this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
+ var builder = this._kernelBuilder;
+ this.ConfigureTextOllama(builder);
+ Kernel target = builder.Build();
+
+ var prompt = "Where is the most famous fish market in Seattle, Washington, USA?";
+
+ // Act
+ FunctionResult actual = await target.InvokePromptAsync(prompt, new(new OllamaPromptExecutionSettings() { Temperature = 0.5f }));
+
+ // Assert
+ Assert.Contains("Pike Place", actual.GetValue<string>(), StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Theory(Skip = "For manual verification only")]
+ [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")]
+ public async Task ItInvokeTestAsync(string prompt, string expectedAnswerContains)
+ {
+ // Arrange
+ this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
+ var builder = this._kernelBuilder;
+
+ this.ConfigureTextOllama(this._kernelBuilder);
+
+ Kernel target = builder.Build();
+
+ IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin");
+
+ // Act
+ FunctionResult actual = await target.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt });
+
+ // Assert
+ Assert.Contains(expectedAnswerContains, actual.GetValue<string>(), StringComparison.OrdinalIgnoreCase);
+ Assert.NotNull(actual.Metadata);
+ }
+
+ [Fact(Skip = "For manual verification only")]
+ public async Task ItShouldHaveSemanticKernelVersionHeaderAsync()
+ {
+ // Arrange
+ var config = this._configuration.GetSection("Ollama").Get<OllamaConfiguration>();
+ Assert.NotNull(config);
+ Assert.NotNull(config.ModelId);
+ Assert.NotNull(config.Endpoint);
+
+ using var defaultHandler = new HttpClientHandler();
+ using var httpHeaderHandler = new HttpHeaderHandler(defaultHandler);
+ using var httpClient = new HttpClient(httpHeaderHandler);
+ this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
+ var builder = this._kernelBuilder;
+ builder.AddOllamaTextGeneration(
+ endpoint: config.Endpoint,
+ modelId: config.ModelId,
+ httpClient: httpClient);
+ Kernel target = builder.Build();
+
+ // Act
+ var result = await target.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?");
+
+ // Assert
+ Assert.NotNull(httpHeaderHandler.RequestHeaders);
+ Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values));
+ }
+
+ #region internals
+
+ private readonly XunitLogger<Kernel> _logger = new(output);
+ private readonly RedirectOutput _testOutputHelper = new(output);
+
+ public void Dispose()
+ {
+ this._logger.Dispose();
+ this._testOutputHelper.Dispose();
+ }
+
+ private void ConfigureTextOllama(IKernelBuilder kernelBuilder)
+ {
+ var config = this._configuration.GetSection("Ollama").Get<OllamaConfiguration>();
+
+ Assert.NotNull(config);
+ Assert.NotNull(config.Endpoint);
+ Assert.NotNull(config.ModelId);
+
+ kernelBuilder.AddOllamaTextGeneration(
+ modelId: config.ModelId,
+ endpoint: config.Endpoint);
+ }
+
+ private sealed class HttpHeaderHandler(HttpMessageHandler innerHandler) : DelegatingHandler(innerHandler)
+ {
+ public System.Net.Http.Headers.HttpRequestHeaders? RequestHeaders { get; private set; }
+
+ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ this.RequestHeaders = request.Headers;
+ return await base.SendAsync(request, cancellationToken);
+ }
+ }
+
+ #endregion
+}