From ca2d6472585eedc07a4ef6a98baedfaf399eabe3 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Sat, 13 Jul 2024 00:49:11 +0100 Subject: [PATCH 1/2] Adding metadata to chat messages --- dotnet/Directory.Packages.props | 4 +-- .../Connectors.Ollama.UnitTests.csproj | 2 +- .../OllamaPromptExecutionSettingsTests.cs | 5 ++-- .../Services/OllamaChatCompletionTests.cs | 15 +++++++--- .../Connectors.Ollama/OllamaMetadata.cs | 28 +++++++++++++++++++ .../OllamaPromptExecutionSettings.cs | 5 ++-- .../Services/OllamaChatCompletionService.cs | 24 ++++++++-------- 7 files changed, 61 insertions(+), 22 deletions(-) diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index bc2f3c81d3bc..a4005b9d7abf 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -34,10 +34,10 @@ - + - + diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj index 427f079b3c65..a08dcc759b89 100644 --- a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj @@ -27,7 +27,7 @@ all - + diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaPromptExecutionSettingsTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaPromptExecutionSettingsTests.cs index 931b1f0674a8..314d05876e6f 100644 --- a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaPromptExecutionSettingsTests.cs +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaPromptExecutionSettingsTests.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Linq; using System.Text.Json; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.Ollama; @@ -46,7 +47,7 @@ public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecia { string jsonSettings = """ { - "stop": "stop me", + "stop": ["stop me"], "temperature": 0.5, "top_p": 0.9, "top_k": 100 @@ -56,7 +57,7 @@ public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecia var executionSettings = JsonSerializer.Deserialize(jsonSettings); var ollamaExecutionSettings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); - Assert.Equal("stop me", ollamaExecutionSettings.Stop); + Assert.Equal("stop me", ollamaExecutionSettings.Stop?.FirstOrDefault()); Assert.Equal(0.5f, ollamaExecutionSettings.Temperature); Assert.Equal(0.9f, ollamaExecutionSettings.TopP!.Value, 0.1f); Assert.Equal(100, ollamaExecutionSettings.TopK); diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs index 622268ecd2a5..a3cf41d62706 100644 --- a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs @@ -109,11 +109,11 @@ public async Task ShouldHandleServiceResponseAsync() } [Fact] - public async Task GetChatMessageContentsShouldHaveModelIdDefinedAsync() + public async Task GetChatMessageContentsShouldHaveModelAndMetadataAsync() { //Arrange var sut = new OllamaChatCompletionService( - "fake-model", + "phi3", new Uri("http://localhost:11434"), httpClient: this._httpClient); @@ -135,11 +135,11 @@ public async Task GetChatMessageContentsShouldHaveModelIdDefinedAsync() // Assert Assert.NotNull(message.ModelId); - Assert.Equal("fake-model", message.ModelId); + Assert.Equal("phi3", message.ModelId); } [Fact] - public async Task GetStreamingChatMessageContentsShouldHaveModelIdDefinedAsync() + public async Task GetStreamingChatMessageContentsShouldHaveModelAndMetadataAsync() { //Arrange var expectedModel = "phi3"; @@ -161,11 +161,18 @@ public async Task GetStreamingChatMessageContentsShouldHaveModelIdDefinedAsync() await foreach (var message in sut.GetStreamingChatMessageContentsAsync(chat)) { lastMessage = message; + Assert.NotNull(message.Metadata); } // Assert Assert.NotNull(lastMessage!.ModelId); Assert.Equal(expectedModel, lastMessage.ModelId); + + Assert.IsType(lastMessage.Metadata); + var metadata = lastMessage.Metadata as OllamaMetadata; + Assert.NotNull(metadata); + Assert.NotEmpty(metadata); + Assert.True(metadata.Done); } public void Dispose() diff --git a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs index dbe16cbeafab..398a649e3494 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs @@ -4,6 +4,7 @@ using System.Collections.ObjectModel; using System.Runtime.CompilerServices; using OllamaSharp.Models; +using OllamaSharp.Models.Chat; namespace Microsoft.SemanticKernel.Connectors.Ollama; @@ -21,6 +22,24 @@ internal OllamaMetadata(GenerateCompletionDoneResponseStream ollamaResponse) : b this.LoadDuration = ollamaResponse.LoadDuration; this.PromptEvalCount = ollamaResponse.PromptEvalCount; this.PromptEvalDuration = ollamaResponse.PromptEvalDuration; + this.Done = ollamaResponse.Done; + } + + internal OllamaMetadata(ChatResponse response) : base(new Dictionary()) + { + this.TotalDuration = response.TotalDuration; + this.EvalCount = response.EvalCount; + this.EvalDuration = response.EvalDuration; + this.CreatedAt = response.CreatedAt; + this.LoadDuration = response.LoadDuration; + this.PromptEvalDuration = response.PromptEvalDuration; + this.CreatedAt = response.CreatedAt; + } + + internal OllamaMetadata(ChatResponseStream message) : base(new Dictionary()) + { + this.CreatedAt = message.CreatedAt; + this.Done = message.Done; } /// @@ -86,6 +105,15 @@ public long TotalDuration internal init => this.SetValueInDictionary(value); } + /// + /// Informs when the response generation process is complete. + /// + public bool? Done + { + get => this.GetValueFromDictionary() as bool?; + internal init => this.SetValueInDictionary(value); + } + private void SetValueInDictionary(object? value, [CallerMemberName] string propertyName = "") => this.Dictionary[propertyName] = value; diff --git a/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs index 9fc47bb9bb1b..4757f0c13520 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Text.Json; using System.Text.Json.Serialization; using Microsoft.SemanticKernel.Text; @@ -46,7 +47,7 @@ public static OllamaPromptExecutionSettings FromExecutionSettings(PromptExecutio /// [JsonPropertyName("stop")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] - public string? Stop + public List? Stop { get => this._stop; @@ -112,7 +113,7 @@ public float? Temperature #region private ================================================================================ - private string? _stop; + private List? _stop; private float? _temperature; private float? _topP; private int? _topK; diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs index c6546622bc59..312c7e0f856b 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs @@ -63,17 +63,14 @@ public async Task> GetChatMessageContentsAsync var settings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); var request = CreateChatRequest(chatHistory, settings, this._client.SelectedModel); - var answer = await this._client.SendChat(request, _ => { }, cancellationToken).ConfigureAwait(false); - - // Ollama Client gives back the same requested history with added message at the end - // To be compatible with this API behavior, we only return the added message (last). - var message = answer.Last(); + var response = await this._client.Chat(request, cancellationToken).ConfigureAwait(false); return [new ChatMessageContent( - role: GetAuthorRole(message.Role) ?? AuthorRole.Assistant, - content: message.Content, - modelId: this._client.SelectedModel, - innerContent: message)]; + role: GetAuthorRole(response.Message.Role) ?? AuthorRole.Assistant, + content: response.Message.Content, + modelId: response.Model, + innerContent: response, + metadata: new OllamaMetadata(response))]; } /// @@ -88,7 +85,12 @@ public async IAsyncEnumerable GetStreamingChatMessa await foreach (var message in this._client.StreamChat(request, cancellationToken).ConfigureAwait(false)) { - yield return new StreamingChatMessageContent(GetAuthorRole(message?.Message.Role), message?.Message.Content, modelId: message?.Model, innerContent: message); + yield return new StreamingChatMessageContent( + role: GetAuthorRole(message!.Message.Role), + content: message.Message.Content, + modelId: message.Model, + innerContent: message, + metadata: new OllamaMetadata(message)); } } @@ -125,7 +127,7 @@ private static ChatRequest CreateChatRequest(ChatHistory chatHistory, OllamaProm Temperature = settings.Temperature, TopP = settings.TopP, TopK = settings.TopK, - Stop = settings.Stop + Stop = settings.Stop?.ToArray() }, Messages = messages.ToList(), Model = selectedModel, From a235a20dbf9d1065d8ae55f3c0c4bc57fe4a9173 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Sat, 13 Jul 2024 01:06:33 +0100 Subject: [PATCH 2/2] Conflict fix --- dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs index 81de1af14d4f..fd7aba01819b 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs @@ -65,12 +65,6 @@ internal OllamaMetadata(ChatResponse response) : base(new Dictionary()) - { - this.CreatedAt = message.CreatedAt; - this.Done = message.Done; - } - /// /// Time spent in nanoseconds evaluating the prompt ///