From 96dfa8d34a419e03248497c4ca44127bbf61ed3b Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Tue, 2 Jul 2024 16:37:04 +0100 Subject: [PATCH] Ollama Updates --- dotnet/Directory.Packages.props | 1 + dotnet/SK-dotnet.sln | 18 ++ .../Connectors.Ollama.UnitTests.csproj | 49 +++++ .../HttpMessageHandlerStub.cs | 48 +++++ .../OllamaKernelBuilderExtensionsTests.cs | 59 ++++++ .../OllamaPromptExecutionSettingsTests.cs | 64 +++++++ .../OllamaServiceCollectionExtensionsTests.cs | 57 ++++++ .../OllamaTestHelper.cs | 50 +++++ .../Services/OllamaChatCompletionTests.cs | 176 +++++++++++++++++ .../OllamaTextEmbeddingGenerationTests.cs | 103 ++++++++++ .../Services/OllamaTextGenerationTests.cs | 154 +++++++++++++++ .../chat_completion_test_response.txt | 1 + .../chat_completion_test_response_stream.txt | 6 + .../TestData/embeddings_test_response.json | 13 ++ .../text_generation_test_response.txt | 1 + .../text_generation_test_response_stream.txt | 5 + .../Connectors.Ollama/AssemblyInfo.cs | 6 + .../Connectors.Ollama.csproj | 34 ++++ .../Core/OllamaChatResponseStreamer.cs | 26 +++ .../Connectors.Ollama/Core/ServiceBase.cs | 72 +++++++ .../OllamaKernelBuilderExtensions.cs | 178 ++++++++++++++++++ .../OllamaServiceCollectionExtensions.cs | 162 ++++++++++++++++ .../Connectors.Ollama/OllamaMetadata.cs | 94 +++++++++ .../OllamaPromptExecutionSettings.cs | 121 ++++++++++++ .../Services/OllamaChatCompletionService.cs | 136 +++++++++++++ .../OllamaTextEmbeddingGenerationService.cs | 77 ++++++++ .../Services/OllamaTextGenerationService.cs | 78 ++++++++ 27 files changed, 1789 insertions(+) create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/HttpMessageHandlerStub.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaKernelBuilderExtensionsTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaPromptExecutionSettingsTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaServiceCollectionExtensionsTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaTestHelper.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextEmbeddingGenerationTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextGenerationTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/chat_completion_test_response.txt create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/chat_completion_test_response_stream.txt create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/embeddings_test_response.json create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/text_generation_test_response.txt create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/text_generation_test_response_stream.txt create mode 100644 dotnet/src/Connectors/Connectors.Ollama/AssemblyInfo.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Connectors.Ollama.csproj create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Core/OllamaChatResponseStreamer.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 24495d03f3c5..bc2f3c81d3bc 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -37,6 +37,7 @@ + diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index ab2e6d7d75e8..17bcfae59d0c 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -318,6 +318,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Redis.UnitTests" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Qdrant.UnitTests", "src\Connectors\Connectors.Qdrant.UnitTests\Connectors.Qdrant.UnitTests.csproj", "{E92AE954-8F3A-4A6F-A4F9-DC12017E5AAF}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Ollama", "src\Connectors\Connectors.Ollama\Connectors.Ollama.csproj", "{E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Ollama.UnitTests", "src\Connectors\Connectors.Ollama.UnitTests\Connectors.Ollama.UnitTests.csproj", "{924DB138-1223-4C99-B6E6-0938A3FA14EF}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -787,6 +791,18 @@ Global {E92AE954-8F3A-4A6F-A4F9-DC12017E5AAF}.Publish|Any CPU.Build.0 = Debug|Any CPU {E92AE954-8F3A-4A6F-A4F9-DC12017E5AAF}.Release|Any CPU.ActiveCfg = Release|Any CPU {E92AE954-8F3A-4A6F-A4F9-DC12017E5AAF}.Release|Any CPU.Build.0 = Release|Any CPU + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}.Publish|Any CPU.ActiveCfg = Publish|Any CPU + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}.Publish|Any CPU.Build.0 = Publish|Any CPU + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}.Release|Any CPU.Build.0 = Release|Any CPU + {924DB138-1223-4C99-B6E6-0938A3FA14EF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {924DB138-1223-4C99-B6E6-0938A3FA14EF}.Debug|Any CPU.Build.0 = Debug|Any CPU + {924DB138-1223-4C99-B6E6-0938A3FA14EF}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {924DB138-1223-4C99-B6E6-0938A3FA14EF}.Publish|Any CPU.Build.0 = Debug|Any CPU + {924DB138-1223-4C99-B6E6-0938A3FA14EF}.Release|Any CPU.ActiveCfg = Release|Any CPU + {924DB138-1223-4C99-B6E6-0938A3FA14EF}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -895,6 +911,8 @@ Global {B0B3901E-AF56-432B-8FAA-858468E5D0DF} = {24503383-A8C4-4255-9998-28D70FE8E99A} {1D4667B9-9381-4E32-895F-123B94253EE8} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {E92AE954-8F3A-4A6F-A4F9-DC12017E5AAF} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1} + {924DB138-1223-4C99-B6E6-0938A3FA14EF} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj new file mode 100644 index 000000000000..427f079b3c65 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj @@ -0,0 +1,49 @@ + + + + SemanticKernel.Connectors.Ollama.UnitTests + SemanticKernel.Connectors.Ollama.UnitTests + net8.0 + 12 + LatestMajor + true + enable + disable + false + CA2007,CA1861,VSTHRD111,CS1591,SKEXP0001,SKEXP0070 + + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + + + + + + Always + + + + diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/HttpMessageHandlerStub.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/HttpMessageHandlerStub.cs new file mode 100644 index 000000000000..0da4dfa3d098 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/HttpMessageHandlerStub.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace SemanticKernel.Connectors.Ollama.UnitTests; + +internal sealed class HttpMessageHandlerStub : DelegatingHandler +{ + public HttpRequestHeaders? RequestHeaders { get; private set; } + + public HttpContentHeaders? ContentHeaders { get; private set; } + + public byte[]? RequestContent { get; private set; } + + public Uri? RequestUri { get; private set; } + + public HttpMethod? Method { get; private set; } + + public HttpResponseMessage ResponseToReturn { get; set; } + + public HttpMessageHandlerStub() + { + this.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK); + this.ResponseToReturn.Content = new StringContent("{}", Encoding.UTF8, "application/json"); + } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + this.Method = request.Method; + this.RequestUri = request.RequestUri; + this.RequestHeaders = request.Headers; + if (request.Content is not null) + { +#pragma warning disable CA2016 // Forward the 'CancellationToken' parameter to methods; overload doesn't exist on .NET Framework + this.RequestContent = await request.Content.ReadAsByteArrayAsync(); +#pragma warning restore CA2016 + } + + this.ContentHeaders = request.Content?.Headers; + + return await Task.FromResult(this.ResponseToReturn); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaKernelBuilderExtensionsTests.cs new file mode 100644 index 000000000000..571f99983bbd --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaKernelBuilderExtensionsTests.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.TextGeneration; +using Xunit; + +namespace SemanticKernel.Connectors.Ollama.UnitTests; + +/// +/// Unit tests of . +/// +public class OllamaKernelBuilderExtensionsTests +{ + [Fact] + public void AddOllamaTextGenerationCreatesService() + { + var builder = Kernel.CreateBuilder(); + builder.AddOllamaTextGeneration("model", new Uri("http://localhost:11434")); + + var kernel = builder.Build(); + var service = kernel.GetRequiredService(); + + Assert.NotNull(kernel); + Assert.NotNull(service); + Assert.IsType(service); + } + + [Fact] + public void AddOllamaChatCompletionCreatesService() + { + var builder = Kernel.CreateBuilder(); + builder.AddOllamaChatCompletion("model", new Uri("http://localhost:11434")); + + var kernel = builder.Build(); + var service = kernel.GetRequiredService(); + + Assert.NotNull(kernel); + Assert.NotNull(service); + Assert.IsType(service); + } + + [Fact] + public void AddOllamaTextEmbeddingGenerationCreatesService() + { + var builder = Kernel.CreateBuilder(); + builder.AddOllamaTextEmbeddingGeneration("model", new Uri("http://localhost:11434")); + + var kernel = builder.Build(); + var service = kernel.GetRequiredService(); + + Assert.NotNull(kernel); + Assert.NotNull(service); + Assert.IsType(service); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaPromptExecutionSettingsTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaPromptExecutionSettingsTests.cs new file mode 100644 index 000000000000..931b1f0674a8 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaPromptExecutionSettingsTests.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Xunit; + +namespace SemanticKernel.Connectors.Ollama.UnitTests; + +/// +/// Unit tests of . +/// +public class OllamaPromptExecutionSettingsTests +{ + [Fact] + public void FromExecutionSettingsWhenAlreadyOllamaShouldReturnSameAsync() + { + // Arrange + var executionSettings = new OllamaPromptExecutionSettings(); + + // Act + var ollamaExecutionSettings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + + // Assert + Assert.Same(executionSettings, ollamaExecutionSettings); + } + + [Fact] + public void FromExecutionSettingsWhenNullShouldReturnDefaultAsync() + { + // Arrange + OllamaPromptExecutionSettings? executionSettings = null; + + // Act + var ollamaExecutionSettings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + + // Assert + Assert.Null(ollamaExecutionSettings.Stop); + Assert.Null(ollamaExecutionSettings.Temperature); + Assert.Null(ollamaExecutionSettings.TopP); + Assert.Null(ollamaExecutionSettings.TopK); + } + + [Fact] + public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecialized() + { + string jsonSettings = """ + { + "stop": "stop me", + "temperature": 0.5, + "top_p": 0.9, + "top_k": 100 + } + """; + + var executionSettings = JsonSerializer.Deserialize(jsonSettings); + var ollamaExecutionSettings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + + Assert.Equal("stop me", ollamaExecutionSettings.Stop); + Assert.Equal(0.5f, ollamaExecutionSettings.Temperature); + Assert.Equal(0.9f, ollamaExecutionSettings.TopP!.Value, 0.1f); + Assert.Equal(100, ollamaExecutionSettings.TopK); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaServiceCollectionExtensionsTests.cs new file mode 100644 index 000000000000..4762acadc65e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaServiceCollectionExtensionsTests.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.TextGeneration; +using Xunit; + +namespace SemanticKernel.Connectors.Ollama.UnitTests; + +/// +/// Unit tests of . +/// +public class OllamaServiceCollectionExtensionsTests +{ + [Fact] + public void AddOllamaTextGenerationToServiceCollection() + { + var services = new ServiceCollection(); + services.AddOllamaTextGeneration("model", new Uri("http://localhost:11434")); + + var serviceProvider = services.BuildServiceProvider(); + var service = serviceProvider.GetRequiredService(); + + Assert.NotNull(service); + Assert.IsType(service); + } + + [Fact] + public void AddOllamaChatCompletionToServiceCollection() + { + var services = new ServiceCollection(); + services.AddOllamaChatCompletion("model", new Uri("http://localhost:11434")); + + var serviceProvider = services.BuildServiceProvider(); + var service = serviceProvider.GetRequiredService(); + + Assert.NotNull(service); + Assert.IsType(service); + } + + [Fact] + public void AddOllamaTextEmbeddingsGenerationToServiceCollection() + { + var services = new ServiceCollection(); + services.AddOllamaTextEmbeddingGeneration("model", new Uri("http://localhost:11434")); + + var serviceProvider = services.BuildServiceProvider(); + var service = serviceProvider.GetRequiredService(); + + Assert.NotNull(service); + Assert.IsType(service); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaTestHelper.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaTestHelper.cs new file mode 100644 index 000000000000..33d2c24c87e3 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaTestHelper.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.IO; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Moq; +using Moq.Protected; + +namespace SemanticKernel.Connectors.Ollama.UnitTests; + +/// +/// Helper for HuggingFace test purposes. +/// +internal static class OllamaTestHelper +{ + /// + /// Reads test response from file for mocking purposes. + /// + /// Name of the file with test response. + internal static string GetTestResponse(string fileName) + { + return File.ReadAllText($"./TestData/{fileName}"); + } + + internal static ReadOnlyMemory GetTestResponseBytes(string fileName) + { + return File.ReadAllBytes($"./TestData/{fileName}"); + } + + /// + /// Returns mocked instance of . + /// + /// Message to return for mocked . + internal static HttpClientHandler GetHttpClientHandlerMock(HttpResponseMessage httpResponseMessage) + { + var httpClientHandler = new Mock(); + + httpClientHandler + .Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(httpResponseMessage); + + return httpClientHandler.Object; + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs new file mode 100644 index 000000000000..622268ecd2a5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs @@ -0,0 +1,176 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama; +using OllamaSharp.Models.Chat; +using Xunit; + +namespace SemanticKernel.Connectors.Ollama.UnitTests; + +public sealed class OllamaChatCompletionTests : IDisposable +{ + private readonly HttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + + public OllamaChatCompletionTests() + { + this._messageHandlerStub = new HttpMessageHandlerStub(); + this._messageHandlerStub.ResponseToReturn.Content = new StringContent(File.ReadAllText("TestData/chat_completion_test_response.txt")); + this._httpClient = new HttpClient(this._messageHandlerStub, false) { BaseAddress = new Uri("http://localhost:11434") }; + } + + [Fact] + public async Task UserAgentHeaderShouldBeUsedAsync() + { + //Arrange + var sut = new OllamaChatCompletionService( + "fake-model", + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + + var chat = new ChatHistory(); + chat.AddMessage(AuthorRole.User, "fake-text"); + + //Act + await sut.GetChatMessageContentsAsync(chat); + + //Assert + Assert.True(this._messageHandlerStub.RequestHeaders?.Contains("User-Agent")); + + var values = this._messageHandlerStub.RequestHeaders!.GetValues("User-Agent"); + var value = values.SingleOrDefault(); + Assert.Equal("Semantic-Kernel", value); + } + + [Fact] + public async Task WhenHttpClientDoesNotHaveBaseAddressProvidedEndpointShouldBeUsedAsync() + { + //Arrange + this._httpClient.BaseAddress = null; + var sut = new OllamaChatCompletionService("fake-model", new Uri("https://fake-random-test-host/fake-path/"), httpClient: this._httpClient); + var chat = new ChatHistory(); + chat.AddMessage(AuthorRole.User, "fake-text"); + + //Act + await sut.GetChatMessageContentsAsync(chat); + + //Assert + Assert.StartsWith("https://fake-random-test-host/fake-path", this._messageHandlerStub.RequestUri?.AbsoluteUri, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ShouldSendPromptToServiceAsync() + { + //Arrange + var sut = new OllamaChatCompletionService( + "fake-model", + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + var chat = new ChatHistory(); + chat.AddMessage(AuthorRole.User, "fake-text"); + + //Act + await sut.GetChatMessageContentsAsync(chat); + + //Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.Equal("fake-text", requestPayload.Messages!.First().Content); + } + + [Fact] + public async Task ShouldHandleServiceResponseAsync() + { + //Arrange + var sut = new OllamaChatCompletionService( + "fake-model", + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + + var chat = new ChatHistory(); + chat.AddMessage(AuthorRole.User, "fake-text"); + + //Act + var messages = await sut.GetChatMessageContentsAsync(chat); + + //Assert + Assert.NotNull(messages); + + var message = messages.SingleOrDefault(); + Assert.NotNull(message); + Assert.Equal("This is test completion response", message.Content); + } + + [Fact] + public async Task GetChatMessageContentsShouldHaveModelIdDefinedAsync() + { + //Arrange + var sut = new OllamaChatCompletionService( + "fake-model", + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + + this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK) + { + Content = new StringContent(File.ReadAllText("TestData/chat_completion_test_response.txt")) + }; + + var chat = new ChatHistory(); + chat.AddMessage(AuthorRole.User, "fake-text"); + + //Act + var messages = await sut.GetChatMessageContentsAsync(chat); + + //Assert + Assert.NotNull(messages); + var message = messages.SingleOrDefault(); + Assert.NotNull(message); + + // Assert + Assert.NotNull(message.ModelId); + Assert.Equal("fake-model", message.ModelId); + } + + [Fact] + public async Task GetStreamingChatMessageContentsShouldHaveModelIdDefinedAsync() + { + //Arrange + var expectedModel = "phi3"; + var sut = new OllamaChatCompletionService( + expectedModel, + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + + this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK) + { + Content = new StreamContent(File.OpenRead("TestData/chat_completion_test_response_stream.txt")) + }; + + var chat = new ChatHistory(); + chat.AddMessage(AuthorRole.User, "fake-text"); + + // Act + StreamingChatMessageContent? lastMessage = null; + await foreach (var message in sut.GetStreamingChatMessageContentsAsync(chat)) + { + lastMessage = message; + } + + // Assert + Assert.NotNull(lastMessage!.ModelId); + Assert.Equal(expectedModel, lastMessage.ModelId); + } + + public void Dispose() + { + this._httpClient.Dispose(); + this._messageHandlerStub.Dispose(); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextEmbeddingGenerationTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..4e9cf00754cf --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextEmbeddingGenerationTests.cs @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.Ollama; +using OllamaSharp.Models; +using Xunit; + +namespace SemanticKernel.Connectors.Ollama.UnitTests; + +public sealed class OllamaTextEmbeddingGenerationTests : IDisposable +{ + private readonly HttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + + public OllamaTextEmbeddingGenerationTests() + { + this._messageHandlerStub = new HttpMessageHandlerStub(); + this._messageHandlerStub.ResponseToReturn.Content = new StringContent(OllamaTestHelper.GetTestResponse("embeddings_test_response.json")); + this._httpClient = new HttpClient(this._messageHandlerStub, false); + } + + [Fact] + public async Task UserAgentHeaderShouldBeUsedAsync() + { + //Arrange + var sut = new OllamaTextEmbeddingGenerationService( + "fake-model", + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + + //Act + await sut.GenerateEmbeddingsAsync(new List { "fake-text" }); + + //Assert + Assert.True(this._messageHandlerStub.RequestHeaders?.Contains("User-Agent")); + + var values = this._messageHandlerStub.RequestHeaders!.GetValues("User-Agent"); + var value = values.SingleOrDefault(); + Assert.Equal("Semantic-Kernel", value); + } + + [Fact] + public async Task WhenHttpClientDoesNotHaveBaseAddressProvidedEndpointShouldBeUsedAsync() + { + //Arrange + this._httpClient.BaseAddress = null; + var sut = new OllamaTextEmbeddingGenerationService("fake-model", new Uri("https://fake-random-test-host/fake-path/"), httpClient: this._httpClient); + + //Act + await sut.GenerateEmbeddingsAsync(new List { "fake-text" }); + + //Assert + Assert.StartsWith("https://fake-random-test-host/fake-path", this._messageHandlerStub.RequestUri?.AbsoluteUri, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ShouldSendPromptToServiceAsync() + { + //Arrange + var sut = new OllamaTextEmbeddingGenerationService( + "fake-model", + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + + //Act + await sut.GenerateEmbeddingsAsync(new List { "fake-text" }); + + //Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.Equal("fake-text", requestPayload.Prompt); + } + + [Fact] + public async Task ShouldHandleServiceResponseAsync() + { + //Arrange + var sut = new OllamaTextEmbeddingGenerationService( + "fake-model", + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + + //Act + var contents = await sut.GenerateEmbeddingsAsync(new List { "fake-text" }); + + //Assert + Assert.NotNull(contents); + + var content = contents.SingleOrDefault(); + Assert.Equal(8, content.Length); + } + + public void Dispose() + { + this._httpClient.Dispose(); + this._messageHandlerStub.Dispose(); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextGenerationTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextGenerationTests.cs new file mode 100644 index 000000000000..e5d9bd6d1884 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextGenerationTests.cs @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.TextGeneration; +using OllamaSharp.Models; +using Xunit; + +namespace SemanticKernel.Connectors.Ollama.UnitTests; + +public sealed class OllamaTextGenerationTests : IDisposable +{ + private readonly HttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + + public OllamaTextGenerationTests() + { + this._messageHandlerStub = new HttpMessageHandlerStub(); + this._messageHandlerStub.ResponseToReturn.Content = new StringContent(OllamaTestHelper.GetTestResponse("text_generation_test_response.txt")); + this._httpClient = new HttpClient(this._messageHandlerStub, false); + } + + [Fact] + public async Task UserAgentHeaderShouldBeUsedAsync() + { + //Arrange + var sut = new OllamaTextGenerationService( + "fake-model", + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + + //Act + await sut.GetTextContentsAsync("fake-text"); + + //Assert + Assert.True(this._messageHandlerStub.RequestHeaders?.Contains("User-Agent")); + + var values = this._messageHandlerStub.RequestHeaders!.GetValues("User-Agent"); + var value = values.SingleOrDefault(); + Assert.Equal("Semantic-Kernel", value); + } + + [Fact] + public async Task WhenHttpClientDoesNotHaveBaseAddressProvidedEndpointShouldBeUsedAsync() + { + //Arrange + this._httpClient.BaseAddress = null; + var sut = new OllamaTextGenerationService("fake-model", new Uri("https://fake-random-test-host/fake-path/"), httpClient: this._httpClient); + + //Act + await sut.GetTextContentsAsync("fake-text"); + + //Assert + Assert.StartsWith("https://fake-random-test-host/fake-path", this._messageHandlerStub.RequestUri?.AbsoluteUri, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ShouldSendPromptToServiceAsync() + { + //Arrange + var sut = new OllamaTextGenerationService( + "fake-model", + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + + //Act + await sut.GetTextContentsAsync("fake-text"); + + //Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.Equal("fake-text", requestPayload.Prompt); + } + + [Fact] + public async Task ShouldHandleServiceResponseAsync() + { + //Arrange + var sut = new OllamaTextGenerationService( + "fake-model", + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + + //Act + var contents = await sut.GetTextContentsAsync("fake-test"); + + //Assert + Assert.NotNull(contents); + + var content = contents.SingleOrDefault(); + Assert.NotNull(content); + Assert.Equal("This is test completion response", content.Text); + } + + [Fact] + public async Task GetTextContentsShouldHaveModelIdDefinedAsync() + { + //Arrange + var sut = new OllamaTextGenerationService( + "fake-model", + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + + // Act + var textContent = await sut.GetTextContentAsync("Any prompt"); + + // Assert + Assert.NotNull(textContent.ModelId); + Assert.Equal("fake-model", textContent.ModelId); + } + + [Fact] + public async Task GetStreamingTextContentsShouldHaveModelIdDefinedAsync() + { + //Arrange + var expectedModel = "phi3"; + var sut = new OllamaTextGenerationService( + expectedModel, + new Uri("http://localhost:11434"), + httpClient: this._httpClient); + + this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK) + { + Content = new StreamContent(File.OpenRead("TestData/text_generation_test_response_stream.txt")) + }; + + // Act + StreamingTextContent? lastTextContent = null; + await foreach (var textContent in sut.GetStreamingTextContentsAsync("Any prompt")) + { + lastTextContent = textContent; + } + + // Assert + Assert.NotNull(lastTextContent!.ModelId); + Assert.Equal(expectedModel, lastTextContent.ModelId); + } + + /// + /// Disposes resources used by this class. + /// + public void Dispose() + { + this._messageHandlerStub.Dispose(); + + this._httpClient.Dispose(); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/chat_completion_test_response.txt b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/chat_completion_test_response.txt new file mode 100644 index 000000000000..b27faf2a1fb5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/chat_completion_test_response.txt @@ -0,0 +1 @@ +{"model":"phi3","created_at":"2024-07-02T12:30:00.295693434Z","message":{"role":"assistant","content":"This is test completion response"},"done_reason":"stop","done":true,"total_duration":69492224084,"load_duration":66637210792,"prompt_eval_count":10,"prompt_eval_duration":41164000,"eval_count":236,"eval_duration":2712234000} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/chat_completion_test_response_stream.txt b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/chat_completion_test_response_stream.txt new file mode 100644 index 000000000000..a0678c024d27 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/chat_completion_test_response_stream.txt @@ -0,0 +1,6 @@ +{"model":"phi3","created_at":"2024-07-02T11:45:16.216898458Z","message":{"role":"assistant","content":" these"},"done":false} +{"model":"phi3","created_at":"2024-07-02T11:45:16.22693076Z","message":{"role":"assistant","content":" times"},"done":false} +{"model":"phi3","created_at":"2024-07-02T11:45:16.236570847Z","message":{"role":"assistant","content":" of"},"done":false} +{"model":"phi3","created_at":"2024-07-02T11:45:16.246538945Z","message":{"role":"assistant","content":" day"},"done":false} +{"model":"phi3","created_at":"2024-07-02T11:45:16.25611096Z","message":{"role":"assistant","content":"."},"done":false} +{"model":"phi3","created_at":"2024-07-02T11:45:16.265598822Z","message":{"role":"assistant","content":""},"done_reason":"stop","done":true,"total_duration":58123571935,"load_duration":55561676662,"prompt_eval_count":10,"prompt_eval_duration":34847000,"eval_count":239,"eval_duration":2381751000} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/embeddings_test_response.json b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/embeddings_test_response.json new file mode 100644 index 000000000000..63218204d30d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/embeddings_test_response.json @@ -0,0 +1,13 @@ +{ + "embedding": [ + -0.08541165292263031, + 0.08639130741357803, + -0.12805694341659546, + -0.2877824902534485, + 0.2114177942276001, + -0.29374566674232483, + -0.10496602207422256, + 0.009402364492416382 + ], + "model": "fake-model" +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/text_generation_test_response.txt b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/text_generation_test_response.txt new file mode 100644 index 000000000000..b8d071565e02 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/text_generation_test_response.txt @@ -0,0 +1 @@ +{"model":"llama2","created_at":"2024-07-02T14:32:06.227383499Z","response":"This is test completion response","done":true,"done_reason":"stop","context":[518,25580,29962,3532,298,434,29889],"total_duration":94608355007,"load_duration":91044187969,"prompt_eval_count":26,"prompt_eval_duration":22543000,"eval_count":313,"eval_duration":3490651000} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/text_generation_test_response_stream.txt b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/text_generation_test_response_stream.txt new file mode 100644 index 000000000000..f662ae912202 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/text_generation_test_response_stream.txt @@ -0,0 +1,5 @@ +{"model":"phi3","created_at":"2024-07-02T12:22:37.03627019Z","response":" day","done":false} +{"model":"phi3","created_at":"2024-07-02T12:22:37.048915655Z","response":"light","done":false} +{"model":"phi3","created_at":"2024-07-02T12:22:37.060968719Z","response":" hours","done":false} +{"model":"phi3","created_at":"2024-07-02T12:22:37.072390403Z","response":".","done":false} +{"model":"phi3","created_at":"2024-07-02T12:22:37.091017292Z","response":"","done":true,"done_reason":"stop","context":[32010,3750,338,278,14744,7254,29973,32007,32001,450,2769,278,14744,5692,7254,304,502,373,11563,756,304,437,411,278,14801,292,310,6575,4366,491,278,25005,29889,8991,4366,29892,470,4796,3578,29892,338,1754,701,310,263,18272,310,11955,393,508,367,3595,297,263,17251,17729,313,1127,29892,24841,29892,13328,29892,7933,29892,7254,29892,1399,5973,29892,322,28008,1026,467,910,18272,310,11955,338,2998,408,4796,3578,1363,372,3743,599,278,1422,281,6447,1477,29879,12420,4208,29889,13,13,10401,6575,4366,24395,11563,29915,29879,25005,29892,21577,13206,21337,763,21767,307,1885,322,288,28596,14801,20511,29899,29893,6447,1477,3578,313,9539,322,28008,1026,29897,901,1135,5520,29899,29893,6447,1477,3578,313,1127,322,13328,467,4001,1749,5076,526,901,20502,304,7254,3578,322,278,8991,5692,901,4796,515,1749,18520,373,11563,2861,304,445,14801,292,2779,29892,591,17189,573,278,14744,408,7254,29889,13,13,2528,17658,29892,5998,1716,7254,322,28008,1026,281,6447,1477,29879,310,3578,526,29574,22829,491,4799,13206,21337,29892,1749,639,1441,338,451,28482,491,278,28008,1026,2927,1951,5199,5076,526,3109,20502,304,372,29889,12808,29892,6575,4366,20888,11563,29915,29879,7101,756,263,6133,26171,297,278,13328,29899,12692,760,310,278,18272,9401,304,2654,470,28008,1026,11955,2861,304,9596,280,1141,14801,292,29892,607,4340,26371,2925,1749,639,1441,310,278,7254,14744,29889,13,13,797,15837,29892,278,14801,292,310,20511,281,6447,1477,3578,313,9539,322,28008,1026,29897,491,11563,29915,29879,25005,9946,502,304,1074,263,758,24130,10835,7254,14744,2645,2462,4366,6199,29889,32007],"total_duration":64697743903,"load_duration":61368714283,"prompt_eval_count":10,"prompt_eval_duration":40919000,"eval_count":304,"eval_duration":3237325000} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Ollama/AssemblyInfo.cs b/dotnet/src/Connectors/Connectors.Ollama/AssemblyInfo.cs new file mode 100644 index 000000000000..fe66371dbc58 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +// This assembly is currently experimental. +[assembly: Experimental("SKEXP0070")] diff --git a/dotnet/src/Connectors/Connectors.Ollama/Connectors.Ollama.csproj b/dotnet/src/Connectors/Connectors.Ollama/Connectors.Ollama.csproj new file mode 100644 index 000000000000..e75d956fd50e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Connectors.Ollama.csproj @@ -0,0 +1,34 @@ + + + + + Microsoft.SemanticKernel.Connectors.Ollama + $(AssemblyName) + netstandard2.0 + alpha + + + + + + + + + Semantic Kernel - Ollama AI connectors + Semantic Kernel connector for Ollama. Contains clients for text generation. + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Ollama/Core/OllamaChatResponseStreamer.cs b/dotnet/src/Connectors/Connectors.Ollama/Core/OllamaChatResponseStreamer.cs new file mode 100644 index 000000000000..6a7818e100f5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Core/OllamaChatResponseStreamer.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Concurrent; +using OllamaSharp.Models.Chat; +using OllamaSharp.Streamer; + +namespace Microsoft.SemanticKernel.Connectors.Ollama.Core; + +internal class OllamaChatResponseStreamer : IResponseStreamer +{ + private readonly ConcurrentQueue _messages = new(); + public void Stream(ChatResponseStream stream) + { + if (stream.Message?.Content is null) + { + return; + } + + this._messages.Enqueue(stream.Message.Content); + } + + public bool TryGetMessage(out string result) + { + return this._messages.TryDequeue(out result); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs b/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs new file mode 100644 index 000000000000..192cbc238f2e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Net.Http; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.Http; +using Microsoft.SemanticKernel.Services; +using OllamaSharp; + +namespace Microsoft.SemanticKernel.Connectors.Ollama.Core; + +/// +/// Represents the core of a service. +/// +public abstract class ServiceBase +{ + /// + /// Attributes of the service. + /// + internal Dictionary AttributesInternal { get; } = new(); + internal readonly OllamaApiClient _client; + + internal ServiceBase(string model, + Uri baseUri, + HttpClient? httpClient = null, + ILoggerFactory? loggerFactory = null) + { + Verify.NotNullOrWhiteSpace(model); + this.AttributesInternal.Add(AIServiceExtensions.ModelIdKey, model); + + if (httpClient is not null) + { + httpClient.BaseAddress ??= baseUri; + + // Try to add User-Agent header. + if (!httpClient.DefaultRequestHeaders.TryGetValues("User-Agent", out _)) + { + httpClient.DefaultRequestHeaders.Add("User-Agent", HttpHeaderConstant.Values.UserAgent); + } + + // Try to add Semantic Kernel Version header + if (!httpClient.DefaultRequestHeaders.TryGetValues(HttpHeaderConstant.Names.SemanticKernelVersion, out _)) + { + httpClient.DefaultRequestHeaders.Add(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(typeof(Kernel))); + } + + this._client = new(httpClient, model); + } + else + { +#pragma warning disable CA2000 // Dispose objects before losing scope + // Client needs to be created to be able to inject Semantic Kernel headers + var internalClient = HttpClientProvider.GetHttpClient(); + internalClient.BaseAddress = baseUri; + internalClient.DefaultRequestHeaders.Add("User-Agent", HttpHeaderConstant.Values.UserAgent); + internalClient.DefaultRequestHeaders.Add(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(typeof(Kernel))); + + this._client = new(internalClient, model); +#pragma warning restore CA2000 // Dispose objects before losing scope + } + } + + internal ServiceBase(string model, + OllamaApiClient ollamaClient, + ILoggerFactory? loggerFactory = null) + { + Verify.NotNullOrWhiteSpace(model); + this._client = ollamaClient; + this.AttributesInternal.Add(AIServiceExtensions.ModelIdKey, model); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs new file mode 100644 index 000000000000..c491d0e4397d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs @@ -0,0 +1,178 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.Http; +using Microsoft.SemanticKernel.TextGeneration; +using OllamaSharp; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods for adding Ollama Text Generation service to the kernel builder. +/// +public static class OllamaKernelBuilderExtensions +{ + /// + /// Add Ollama Text Generation service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The base uri to Ollama hosted service. + /// The optional service ID. + /// The optional custom HttpClient. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaTextGeneration( + this IKernelBuilder builder, + string modelId, + Uri baseUri, + string? serviceId = null, + HttpClient? httpClient = null) + { + Verify.NotNull(builder); + + builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextGenerationService( + model: modelId, + baseUri: baseUri, + httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), + loggerFactory: serviceProvider.GetService())); + return builder; + } + + /// + /// Add Ollama Text Generation service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The Ollama Sharp library client. + /// The optional service ID. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaTextGeneration( + this IKernelBuilder builder, + string modelId, + OllamaApiClient ollamaClient, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextGenerationService( + model: modelId, + ollamaClient: ollamaClient, + loggerFactory: serviceProvider.GetService())); + return builder; + } + + /// + /// Add Ollama Chat Completion service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The base uri to Ollama hosted service. + /// The optional service ID. + /// The optional custom HttpClient. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaChatCompletion( + this IKernelBuilder builder, + string modelId, + Uri baseUri, + string? serviceId = null, + HttpClient? httpClient = null) + { + Verify.NotNull(builder); + Verify.NotNull(modelId); + + builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaChatCompletionService( + model: modelId, + baseUri: baseUri, + httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), + loggerFactory: serviceProvider.GetService())); + + return builder; + } + + /// + /// Add Ollama Chat Completion service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The Ollama Sharp library client. + /// The optional service ID. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaChatCompletion( + this IKernelBuilder builder, + string modelId, + OllamaApiClient ollamaClient, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaChatCompletionService( + model: modelId, + client: ollamaClient, + loggerFactory: serviceProvider.GetService())); + + return builder; + } + + /// + /// Add Ollama Text Embeddings Generation service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The base uri to Ollama hosted service. + /// The optional service ID. + /// The optional custom HttpClient. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaTextEmbeddingGeneration( + this IKernelBuilder builder, + string modelId, + Uri baseUri, + string? serviceId = null, + HttpClient? httpClient = null) + { + Verify.NotNull(builder); + + builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextEmbeddingGenerationService( + model: modelId, + baseUri: baseUri, + httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), + loggerFactory: serviceProvider.GetService())); + + return builder; + } + + /// + /// Add Ollama Text Embeddings Generation service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The Ollama Sharp library client. + /// The optional service ID. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaTextEmbeddingGeneration( + this IKernelBuilder builder, + string modelId, + OllamaApiClient ollamaClient, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextEmbeddingGenerationService( + model: modelId, + ollamaClient: ollamaClient, + loggerFactory: serviceProvider.GetService())); + + return builder; + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs new file mode 100644 index 000000000000..7d9e1e14f33e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.Http; +using Microsoft.SemanticKernel.TextGeneration; +using OllamaSharp; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods for adding Ollama Text Generation service to the kernel builder. +/// +public static class OllamaServiceCollectionExtensions +{ + /// + /// Add Ollama Text Generation service to the specified service collection. + /// + /// The target service collection. + /// The model for text generation. + /// The base uri to Ollama hosted service. + /// The optional service ID. + /// The updated kernel builder. + public static IServiceCollection AddOllamaTextGeneration( + this IServiceCollection services, + string modelId, + Uri baseUri, + string? serviceId = null) + { + Verify.NotNull(services); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextGenerationService( + model: modelId, + baseUri: baseUri, + httpClient: HttpClientProvider.GetHttpClient(serviceProvider), + loggerFactory: serviceProvider.GetService())); + } + + /// + /// Add Ollama Text Generation service to the kernel builder. + /// + /// The target service collection. + /// The model for text generation. + /// The Ollama Sharp library client. + /// The optional service ID. + /// The updated kernel builder. + public static IServiceCollection AddOllamaTextGeneration( + this IServiceCollection services, + string modelId, + OllamaApiClient ollamaClient, + string? serviceId = null) + { + Verify.NotNull(services); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextGenerationService( + model: modelId, + ollamaClient: ollamaClient, + loggerFactory: serviceProvider.GetService())); + } + + /// + /// Add Ollama Chat Completion and Text Generation services to the specified service collection. + /// + /// The target service collection. + /// The model for text generation. + /// The base uri to Ollama hosted service. + /// Optional service ID. + /// The updated service collection. + public static IServiceCollection AddOllamaChatCompletion( + this IServiceCollection services, + string modelId, + Uri baseUri, + string? serviceId = null) + { + Verify.NotNull(services); + + services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaChatCompletionService( + model: modelId, + baseUri: baseUri, + httpClient: HttpClientProvider.GetHttpClient(serviceProvider), + loggerFactory: serviceProvider.GetService())); + + return services; + } + + /// + /// Add Ollama Chat Completion service to the kernel builder. + /// + /// The target service collection. + /// The model for text generation. + /// The Ollama Sharp library client. + /// The optional service ID. + /// The updated kernel builder. + public static IServiceCollection AddOllamaChatCompletion( + this IServiceCollection services, + string modelId, + OllamaApiClient ollamaClient, + string? serviceId = null) + { + Verify.NotNull(services); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaChatCompletionService( + model: modelId, + client: ollamaClient, + loggerFactory: serviceProvider.GetService())); + } + + /// + /// Add Ollama Text Embedding Generation services to the kernel builder. + /// + /// The target service collection. + /// The model for text generation. + /// The base uri to Ollama hosted service. + /// Optional service ID. + /// The updated kernel builder. + public static IServiceCollection AddOllamaTextEmbeddingGeneration( + this IServiceCollection services, + string modelId, + Uri baseUri, + string? serviceId = null) + { + Verify.NotNull(services); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextEmbeddingGenerationService( + model: modelId, + baseUri: baseUri, + httpClient: HttpClientProvider.GetHttpClient(serviceProvider), + loggerFactory: serviceProvider.GetService())); + } + + /// + /// Add Ollama Text Embeddings Generation service to the kernel builder. + /// + /// The target service collection. + /// The model for text generation. + /// The Ollama Sharp library client. + /// The optional service ID. + /// The updated kernel builder. + public static IServiceCollection AddOllamaTextEmbeddingGeneration( + this IServiceCollection services, + string modelId, + OllamaApiClient ollamaClient, + string? serviceId = null) + { + Verify.NotNull(services); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextEmbeddingGenerationService( + model: modelId, + ollamaClient: ollamaClient, + loggerFactory: serviceProvider.GetService())); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs new file mode 100644 index 000000000000..dbe16cbeafab --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Runtime.CompilerServices; +using OllamaSharp.Models; + +namespace Microsoft.SemanticKernel.Connectors.Ollama; + +/// +/// Represents the metadata of the Ollama response. +/// +public sealed class OllamaMetadata : ReadOnlyDictionary +{ + internal OllamaMetadata(GenerateCompletionDoneResponseStream ollamaResponse) : base(new Dictionary()) + { + this.TotalDuration = ollamaResponse.TotalDuration; + this.EvalCount = ollamaResponse.EvalCount; + this.EvalDuration = ollamaResponse.EvalDuration; + this.CreatedAt = ollamaResponse.CreatedAt; + this.LoadDuration = ollamaResponse.LoadDuration; + this.PromptEvalCount = ollamaResponse.PromptEvalCount; + this.PromptEvalDuration = ollamaResponse.PromptEvalDuration; + } + + /// + /// Time spent in nanoseconds evaluating the prompt + /// + public long PromptEvalDuration + { + get => this.GetValueFromDictionary() as long? ?? 0; + internal init => this.SetValueInDictionary(value); + } + + /// + /// Number of tokens in the prompt + /// + public int PromptEvalCount + { + get => this.GetValueFromDictionary() as int? ?? 0; + internal init => this.SetValueInDictionary(value); + } + + /// + /// Time spent in nanoseconds loading the model + /// + public long LoadDuration + { + get => this.GetValueFromDictionary() as long? ?? 0; + internal init => this.SetValueInDictionary(value); + } + + /// + /// Returns the prompt's feedback related to the content filters. + /// + public string? CreatedAt + { + get => this.GetValueFromDictionary() as string; + internal init => this.SetValueInDictionary(value); + } + + /// + /// Time in nano seconds spent generating the response + /// + public long EvalDuration + { + get => this.GetValueFromDictionary() as long? ?? 0; + internal init => this.SetValueInDictionary(value); + } + + /// + /// Number of tokens the response + /// + public int EvalCount + { + get => this.GetValueFromDictionary() as int? ?? 0; + internal init => this.SetValueInDictionary(value); + } + + /// + /// Time spent in nanoseconds generating the response + /// + public long TotalDuration + { + get => this.GetValueFromDictionary() as long? ?? 0; + internal init => this.SetValueInDictionary(value); + } + + private void SetValueInDictionary(object? value, [CallerMemberName] string propertyName = "") + => this.Dictionary[propertyName] = value; + + private object? GetValueFromDictionary([CallerMemberName] string propertyName = "") + => this.Dictionary.TryGetValue(propertyName, out var value) ? value : null; +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs new file mode 100644 index 000000000000..9fc47bb9bb1b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.SemanticKernel.Text; + +namespace Microsoft.SemanticKernel.Connectors.Ollama; + +/// +/// Ollama Execution Settings. +/// +public sealed class OllamaPromptExecutionSettings : PromptExecutionSettings +{ + /// + /// Gets the specialization for the Ollama execution settings. + /// + /// Generic prompt execution settings. + /// Specialized Ollama execution settings. + public static OllamaPromptExecutionSettings FromExecutionSettings(PromptExecutionSettings? executionSettings) + { + switch (executionSettings) + { + case null: + return new(); + case OllamaPromptExecutionSettings settings: + return settings; + } + + var json = JsonSerializer.Serialize(executionSettings); + var ollamaExecutionSettings = JsonSerializer.Deserialize(json, JsonOptionsCache.ReadPermissive); + if (ollamaExecutionSettings is not null) + { + return ollamaExecutionSettings; + } + + throw new ArgumentException( + $"Invalid execution settings, cannot convert to {nameof(OllamaPromptExecutionSettings)}", + nameof(executionSettings)); + } + + /// + /// Sets the stop sequences to use. When this pattern is encountered the + /// LLM will stop generating text and return. Multiple stop patterns may + /// be set by specifying multiple separate stop parameters in a modelfile. + /// + [JsonPropertyName("stop")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? Stop + { + get => this._stop; + + set + { + this.ThrowIfFrozen(); + this._stop = value; + } + } + + /// + /// Reduces the probability of generating nonsense. A higher value + /// (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) + /// will be more conservative. (Default: 40) + /// + [JsonPropertyName("top_k")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? TopK + { + get => this._topK; + + set + { + this.ThrowIfFrozen(); + this._topK = value; + } + } + + /// + /// Works together with top-k. A higher value (e.g., 0.95) will lead to + /// more diverse text, while a lower value (e.g., 0.5) will generate more + /// focused and conservative text. (Default: 0.9) + /// + [JsonPropertyName("top_p")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? TopP + { + get => this._topP; + + set + { + this.ThrowIfFrozen(); + this._topP = value; + } + } + + /// + /// The temperature of the model. Increasing the temperature will make the + /// model answer more creatively. (Default: 0.8) + /// + [JsonPropertyName("temperature")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? Temperature + { + get => this._temperature; + + set + { + this.ThrowIfFrozen(); + this._temperature = value; + } + } + + #region private ================================================================================ + + private string? _stop; + private float? _temperature; + private float? _topP; + private int? _topK; + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs new file mode 100644 index 000000000000..c6546622bc59 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama.Core; +using OllamaSharp; +using OllamaSharp.Models.Chat; + +namespace Microsoft.SemanticKernel.Connectors.Ollama; + +/// +/// Represents a chat completion service using Ollama Original API. +/// +public sealed class OllamaChatCompletionService : ServiceBase, IChatCompletionService +{ + /// + /// Initializes a new instance of the class. + /// + /// The hosted model. + /// The base uri including the port where Ollama server is hosted + /// Optional HTTP client to be used for communication with the Ollama API. + /// Optional logger factory to be used for logging. + public OllamaChatCompletionService( + string model, + Uri baseUri, + HttpClient? httpClient = null, + ILoggerFactory? loggerFactory = null) + : base(model, baseUri, httpClient, loggerFactory) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The hosted model. + /// The Ollama API client. + /// Optional logger factory to be used for logging. + public OllamaChatCompletionService( + string model, + OllamaApiClient client, + ILoggerFactory? loggerFactory = null) + : base(model, client, loggerFactory) + { + } + + /// + public IReadOnlyDictionary Attributes => this.AttributesInternal; + + /// + public async Task> GetChatMessageContentsAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + var settings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + var request = CreateChatRequest(chatHistory, settings, this._client.SelectedModel); + + var answer = await this._client.SendChat(request, _ => { }, cancellationToken).ConfigureAwait(false); + + // Ollama Client gives back the same requested history with added message at the end + // To be compatible with this API behavior, we only return the added message (last). + var message = answer.Last(); + + return [new ChatMessageContent( + role: GetAuthorRole(message.Role) ?? AuthorRole.Assistant, + content: message.Content, + modelId: this._client.SelectedModel, + innerContent: message)]; + } + + /// + public async IAsyncEnumerable GetStreamingChatMessageContentsAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var settings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + var request = CreateChatRequest(chatHistory, settings, this._client.SelectedModel); + + await foreach (var message in this._client.StreamChat(request, cancellationToken).ConfigureAwait(false)) + { + yield return new StreamingChatMessageContent(GetAuthorRole(message?.Message.Role), message?.Message.Content, modelId: message?.Model, innerContent: message); + } + } + + private static AuthorRole? GetAuthorRole(ChatRole? role) => role.ToString().ToUpperInvariant() switch + { + "USER" => AuthorRole.User, + "ASSISTANT" => AuthorRole.Assistant, + "SYSTEM" => AuthorRole.System, + _ => null + }; + + private static ChatRequest CreateChatRequest(ChatHistory chatHistory, OllamaPromptExecutionSettings settings, string selectedModel) + { + var messages = new List(); + foreach (var chatHistoryMessage in chatHistory) + { + ChatRole role = ChatRole.User; + if (chatHistoryMessage.Role == AuthorRole.System) + { + role = ChatRole.System; + } + else if (chatHistoryMessage.Role == AuthorRole.Assistant) + { + role = ChatRole.Assistant; + } + + messages.Add(new Message(role, chatHistoryMessage.Content!)); + } + + var request = new ChatRequest + { + Options = new() + { + Temperature = settings.Temperature, + TopP = settings.TopP, + TopK = settings.TopK, + Stop = settings.Stop + }, + Messages = messages.ToList(), + Model = selectedModel, + Stream = true + }; + return request; + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs new file mode 100644 index 000000000000..121df1caf995 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.Connectors.Ollama.Core; +using Microsoft.SemanticKernel.Embeddings; +using OllamaSharp; +using OllamaSharp.Models; + +namespace Microsoft.SemanticKernel.Connectors.Ollama; + +/// +/// Represents a embedding generation service using Ollama Original API. +/// +public sealed class OllamaTextEmbeddingGenerationService : ServiceBase, ITextEmbeddingGenerationService +{ + /// + /// Initializes a new instance of the class. + /// + /// The hosted model. + /// The base uri including the port where Ollama server is hosted + /// Optional HTTP client to be used for communication with the Ollama API. + /// Optional logger factory to be used for logging. + public OllamaTextEmbeddingGenerationService( + string model, + Uri baseUri, + HttpClient? httpClient = null, + ILoggerFactory? loggerFactory = null) + : base(model, baseUri, httpClient, loggerFactory) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The hosted model. + /// The Ollama API client. + /// Optional logger factory to be used for logging. + public OllamaTextEmbeddingGenerationService( + string model, + OllamaApiClient ollamaClient, + ILoggerFactory? loggerFactory = null) + : base(model, ollamaClient, loggerFactory) + { + } + + /// + public IReadOnlyDictionary Attributes => this.AttributesInternal; + + /// + public async Task>> GenerateEmbeddingsAsync( + IList data, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + var tasks = new List>(); + foreach (var prompt in data) + { + tasks.Add(this._client.GenerateEmbeddings(prompt, cancellationToken: cancellationToken)); + } + + await Task.WhenAll(tasks.ToArray()).ConfigureAwait(false); + + return new List>( + tasks.Select( + task => new ReadOnlyMemory(task.Result.Embedding + .Select(@decimal => (float)@decimal).ToArray() + ) + ) + ); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs new file mode 100644 index 000000000000..0294004811ff --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.Connectors.Ollama.Core; +using Microsoft.SemanticKernel.TextGeneration; +using OllamaSharp; + +namespace Microsoft.SemanticKernel.Connectors.Ollama; + +/// +/// Represents a text generation service using Ollama Original API. +/// +public sealed class OllamaTextGenerationService : ServiceBase, ITextGenerationService +{ + /// + /// Initializes a new instance of the class. + /// + /// The Ollama model for the text generation service. + /// The base uri including the port where Ollama server is hosted + /// Optional HTTP client to be used for communication with the Ollama API. + /// Optional logger factory to be used for logging. + public OllamaTextGenerationService( + string model, + Uri baseUri, + HttpClient? httpClient = null, + ILoggerFactory? loggerFactory = null) + : base(model, baseUri, httpClient, loggerFactory) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The hosted model. + /// The Ollama API client. + /// Optional logger factory to be used for logging. + public OllamaTextGenerationService( + string model, + OllamaApiClient ollamaClient, + ILoggerFactory? loggerFactory = null) + : base(model, ollamaClient, loggerFactory) + { + } + + /// + public IReadOnlyDictionary Attributes => this.AttributesInternal; + + /// + public async Task> GetTextContentsAsync( + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + var content = await this._client.GetCompletion(prompt, null, cancellationToken).ConfigureAwait(false); + + return [new(content.Response, modelId: this._client.SelectedModel, innerContent: content)]; + } + + /// + public async IAsyncEnumerable GetStreamingTextContentsAsync( + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var content in this._client.StreamCompletion(prompt, null, cancellationToken).ConfigureAwait(false)) + { + yield return new StreamingTextContent(content?.Response, modelId: content?.Model, innerContent: content); + } + } +}