From ca2d6472585eedc07a4ef6a98baedfaf399eabe3 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Sat, 13 Jul 2024 00:49:11 +0100
Subject: [PATCH 1/2] Adding metadata to chat messages
---
dotnet/Directory.Packages.props | 4 +--
.../Connectors.Ollama.UnitTests.csproj | 2 +-
.../OllamaPromptExecutionSettingsTests.cs | 5 ++--
.../Services/OllamaChatCompletionTests.cs | 15 +++++++---
.../Connectors.Ollama/OllamaMetadata.cs | 28 +++++++++++++++++++
.../OllamaPromptExecutionSettings.cs | 5 ++--
.../Services/OllamaChatCompletionService.cs | 24 ++++++++--------
7 files changed, 61 insertions(+), 22 deletions(-)
diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props
index bc2f3c81d3bc..a4005b9d7abf 100644
--- a/dotnet/Directory.Packages.props
+++ b/dotnet/Directory.Packages.props
@@ -34,10 +34,10 @@
-
+
-
+
diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj
index 427f079b3c65..a08dcc759b89 100644
--- a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj
+++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj
@@ -27,7 +27,7 @@
all
-
+
diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaPromptExecutionSettingsTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaPromptExecutionSettingsTests.cs
index 931b1f0674a8..314d05876e6f 100644
--- a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaPromptExecutionSettingsTests.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/OllamaPromptExecutionSettingsTests.cs
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
+using System.Linq;
using System.Text.Json;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.Ollama;
@@ -46,7 +47,7 @@ public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecia
{
string jsonSettings = """
{
- "stop": "stop me",
+ "stop": ["stop me"],
"temperature": 0.5,
"top_p": 0.9,
"top_k": 100
@@ -56,7 +57,7 @@ public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecia
var executionSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(jsonSettings);
var ollamaExecutionSettings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings);
- Assert.Equal("stop me", ollamaExecutionSettings.Stop);
+ Assert.Equal("stop me", ollamaExecutionSettings.Stop?.FirstOrDefault());
Assert.Equal(0.5f, ollamaExecutionSettings.Temperature);
Assert.Equal(0.9f, ollamaExecutionSettings.TopP!.Value, 0.1f);
Assert.Equal(100, ollamaExecutionSettings.TopK);
diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs
index 622268ecd2a5..a3cf41d62706 100644
--- a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs
@@ -109,11 +109,11 @@ public async Task ShouldHandleServiceResponseAsync()
}
[Fact]
- public async Task GetChatMessageContentsShouldHaveModelIdDefinedAsync()
+ public async Task GetChatMessageContentsShouldHaveModelAndMetadataAsync()
{
//Arrange
var sut = new OllamaChatCompletionService(
- "fake-model",
+ "phi3",
new Uri("http://localhost:11434"),
httpClient: this._httpClient);
@@ -135,11 +135,11 @@ public async Task GetChatMessageContentsShouldHaveModelIdDefinedAsync()
// Assert
Assert.NotNull(message.ModelId);
- Assert.Equal("fake-model", message.ModelId);
+ Assert.Equal("phi3", message.ModelId);
}
[Fact]
- public async Task GetStreamingChatMessageContentsShouldHaveModelIdDefinedAsync()
+ public async Task GetStreamingChatMessageContentsShouldHaveModelAndMetadataAsync()
{
//Arrange
var expectedModel = "phi3";
@@ -161,11 +161,18 @@ public async Task GetStreamingChatMessageContentsShouldHaveModelIdDefinedAsync()
await foreach (var message in sut.GetStreamingChatMessageContentsAsync(chat))
{
lastMessage = message;
+ Assert.NotNull(message.Metadata);
}
// Assert
Assert.NotNull(lastMessage!.ModelId);
Assert.Equal(expectedModel, lastMessage.ModelId);
+
+ Assert.IsType<OllamaMetadata>(lastMessage.Metadata);
+ var metadata = lastMessage.Metadata as OllamaMetadata;
+ Assert.NotNull(metadata);
+ Assert.NotEmpty(metadata);
+ Assert.True(metadata.Done);
}
public void Dispose()
diff --git a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs
index dbe16cbeafab..398a649e3494 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs
@@ -4,6 +4,7 @@
using System.Collections.ObjectModel;
using System.Runtime.CompilerServices;
using OllamaSharp.Models;
+using OllamaSharp.Models.Chat;
namespace Microsoft.SemanticKernel.Connectors.Ollama;
@@ -21,6 +22,24 @@ internal OllamaMetadata(GenerateCompletionDoneResponseStream ollamaResponse) : b
this.LoadDuration = ollamaResponse.LoadDuration;
this.PromptEvalCount = ollamaResponse.PromptEvalCount;
this.PromptEvalDuration = ollamaResponse.PromptEvalDuration;
+ this.Done = ollamaResponse.Done;
+ }
+
+ internal OllamaMetadata(ChatResponse response) : base(new Dictionary<string, object?>())
+ {
+ this.TotalDuration = response.TotalDuration;
+ this.EvalCount = response.EvalCount;
+ this.EvalDuration = response.EvalDuration;
+ this.CreatedAt = response.CreatedAt;
+ this.LoadDuration = response.LoadDuration;
+ this.PromptEvalDuration = response.PromptEvalDuration;
+ this.CreatedAt = response.CreatedAt;
+ }
+
+ internal OllamaMetadata(ChatResponseStream message) : base(new Dictionary<string, object?>())
+ {
+ this.CreatedAt = message.CreatedAt;
+ this.Done = message.Done;
}
///
@@ -86,6 +105,15 @@ public long TotalDuration
internal init => this.SetValueInDictionary(value);
}
+ ///
+ /// Informs when the response generation process is complete.
+ ///
+ public bool? Done
+ {
+ get => this.GetValueFromDictionary() as bool?;
+ internal init => this.SetValueInDictionary(value);
+ }
+
private void SetValueInDictionary(object? value, [CallerMemberName] string propertyName = "")
=> this.Dictionary[propertyName] = value;
diff --git a/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs
index 9fc47bb9bb1b..4757f0c13520 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/OllamaPromptExecutionSettings.cs
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
+using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.SemanticKernel.Text;
@@ -46,7 +47,7 @@ public static OllamaPromptExecutionSettings FromExecutionSettings(PromptExecutio
///
[JsonPropertyName("stop")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
- public string? Stop
+ public List<string>? Stop
{
get => this._stop;
@@ -112,7 +113,7 @@ public float? Temperature
#region private ================================================================================
- private string? _stop;
+ private List<string>? _stop;
private float? _temperature;
private float? _topP;
private int? _topK;
diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs
index c6546622bc59..312c7e0f856b 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs
@@ -63,17 +63,14 @@ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync
var settings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings);
var request = CreateChatRequest(chatHistory, settings, this._client.SelectedModel);
- var answer = await this._client.SendChat(request, _ => { }, cancellationToken).ConfigureAwait(false);
-
- // Ollama Client gives back the same requested history with added message at the end
- // To be compatible with this API behavior, we only return the added message (last).
- var message = answer.Last();
+ var response = await this._client.Chat(request, cancellationToken).ConfigureAwait(false);
return [new ChatMessageContent(
- role: GetAuthorRole(message.Role) ?? AuthorRole.Assistant,
- content: message.Content,
- modelId: this._client.SelectedModel,
- innerContent: message)];
+ role: GetAuthorRole(response.Message.Role) ?? AuthorRole.Assistant,
+ content: response.Message.Content,
+ modelId: response.Model,
+ innerContent: response,
+ metadata: new OllamaMetadata(response))];
}
///
@@ -88,7 +85,12 @@ public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync
await foreach (var message in this._client.StreamChat(request, cancellationToken).ConfigureAwait(false))
{
- yield return new StreamingChatMessageContent(GetAuthorRole(message?.Message.Role), message?.Message.Content, modelId: message?.Model, innerContent: message);
+ yield return new StreamingChatMessageContent(
+ role: GetAuthorRole(message!.Message.Role),
+ content: message.Message.Content,
+ modelId: message.Model,
+ innerContent: message,
+ metadata: new OllamaMetadata(message));
}
}
@@ -125,7 +127,7 @@ private static ChatRequest CreateChatRequest(ChatHistory chatHistory, OllamaProm
Temperature = settings.Temperature,
TopP = settings.TopP,
TopK = settings.TopK,
- Stop = settings.Stop
+ Stop = settings.Stop?.ToArray()
},
Messages = messages.ToList(),
Model = selectedModel,
From a235a20dbf9d1065d8ae55f3c0c4bc57fe4a9173 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Sat, 13 Jul 2024 01:06:33 +0100
Subject: [PATCH 2/2] Conflict fix
---
dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs | 6 ------
1 file changed, 6 deletions(-)
diff --git a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs
index 81de1af14d4f..fd7aba01819b 100644
--- a/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs
+++ b/dotnet/src/Connectors/Connectors.Ollama/OllamaMetadata.cs
@@ -65,12 +65,6 @@ internal OllamaMetadata(ChatResponse response) : base(new Dictionary<string, object?>())
- internal OllamaMetadata(ChatResponseStream message) : base(new Dictionary<string, object?>())
- {
- this.CreatedAt = message.CreatedAt;
- this.Done = message.Done;
- }
-
///
/// Time spent in nanoseconds evaluating the prompt
///