From da4934fa97749b1c20044f71b5b6739568107698 Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 19:53:28 +0200 Subject: [PATCH 01/14] Add support for system messages in GeminiRequest --- .../Clients/GeminiChatCompletionClient.cs | 41 ++++--------------- .../Core/Gemini/Models/GeminiRequest.cs | 40 +++++++++++++++++- 2 files changed, 46 insertions(+), 35 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index e52b5f4e6bd6..6d527485a576 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -297,8 +297,7 @@ private ChatCompletionState ValidateInputAndCreateChatCompletionState( Kernel? kernel, PromptExecutionSettings? executionSettings) { - var chatHistoryCopy = new ChatHistory(chatHistory); - ValidateAndPrepareChatHistory(chatHistoryCopy); + ValidateChatHistory(chatHistory); var geminiExecutionSettings = GeminiPromptExecutionSettings.FromExecutionSettings(executionSettings); ValidateMaxTokens(geminiExecutionSettings.MaxTokens); @@ -315,7 +314,7 @@ private ChatCompletionState ValidateInputAndCreateChatCompletionState( AutoInvoke = CheckAutoInvokeCondition(kernel, geminiExecutionSettings), ChatHistory = chatHistory, ExecutionSettings = geminiExecutionSettings, - GeminiRequest = CreateRequest(chatHistoryCopy, geminiExecutionSettings, kernel), + GeminiRequest = CreateRequest(chatHistory, geminiExecutionSettings, kernel), Kernel = kernel! // not null if auto-invoke is true }; } @@ -517,46 +516,20 @@ private static bool CheckAutoInvokeCondition(Kernel? kernel, GeminiPromptExecuti return autoInvoke; } - private static void ValidateAndPrepareChatHistory(ChatHistory chatHistory) + private static void ValidateChatHistory(ChatHistory chatHistory) { Verify.NotNullOrEmpty(chatHistory); - - if (chatHistory.Where(message => message.Role == AuthorRole.System).ToList() is { Count: > 0 } systemMessages) - { - if (chatHistory.Count == systemMessages.Count) - { - throw new InvalidOperationException("Chat history can't contain only system messages."); - } - - if (systemMessages.Count > 1) - { - throw new InvalidOperationException("Chat history can't contain more than one system message. " + - "Only the first system message will be processed but will be converted to the user message before sending to the Gemini api."); - } - - ConvertSystemMessageToUserMessageInChatHistory(chatHistory, systemMessages[0]); - } - ValidateChatHistoryMessagesOrder(chatHistory); } - private static void ConvertSystemMessageToUserMessageInChatHistory(ChatHistory chatHistory, ChatMessageContent systemMessage) - { - // TODO: This solution is needed due to the fact that Gemini API doesn't support system messages. Maybe in the future we will be able to remove it. - chatHistory.Remove(systemMessage); - if (!string.IsNullOrWhiteSpace(systemMessage.Content)) - { - chatHistory.Insert(0, new ChatMessageContent(AuthorRole.User, systemMessage.Content)); - chatHistory.Insert(1, new ChatMessageContent(AuthorRole.Assistant, "OK")); - } - } - private static void ValidateChatHistoryMessagesOrder(ChatHistory chatHistory) { bool incorrectOrder = false; - // Exclude tool calls from the validation + // Exclude tool calls and system messages from the validation ChatHistory chatHistoryCopy = new(chatHistory - .Where(message => message.Role != AuthorRole.Tool && (message is not GeminiChatMessageContent { ToolCalls: not null }))); + .Where(message => message.Role != AuthorRole.Tool && + message.Role != AuthorRole.System && + message is not GeminiChatMessageContent { ToolCalls: not null })); for (int i = 0; i < chatHistoryCopy.Count; i++) { if (chatHistoryCopy[i].Role != (i % 2 == 0 ? AuthorRole.User : AuthorRole.Assistant) || diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs index def81d9a7083..ac8e477d4202 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs @@ -26,6 +26,9 @@ internal sealed class GeminiRequest [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public IList? Tools { get; set; } + [JsonPropertyName("systemInstruction")] + public GeminiContent? SystemMessages { get; set; } + public void AddFunction(GeminiFunction function) { // NOTE: Currently Gemini only supports one tool i.e. function calling. @@ -95,7 +98,10 @@ private static GeminiRequest CreateGeminiRequest(ChatHistory chatHistory) { GeminiRequest obj = new() { - Contents = chatHistory.Select(CreateGeminiContentFromChatMessage).ToList() + Contents = chatHistory + .Where(message => message.Role != AuthorRole.System) + .Select(CreateGeminiContentFromChatMessage).ToList(), + SystemMessages = CreateSystemChatMessages(chatHistory) }; return obj; } @@ -109,6 +115,20 @@ private static GeminiContent CreateGeminiContentFromChatMessage(ChatMessageConte }; } + private static GeminiContent? CreateSystemChatMessages(ChatHistory chatHistory) + { + var contents = chatHistory.Where(message => message.Role == AuthorRole.System).ToList(); + if (contents.Count == 0) + { + return null; + } + + return new GeminiContent + { + Parts = CreateGeminiParts(contents) + }; + } + public void AddChatMessage(ChatMessageContent message) { Verify.NotNull(this.Contents); @@ -117,6 +137,24 @@ public void AddChatMessage(ChatMessageContent message) this.Contents.Add(CreateGeminiContentFromChatMessage(message)); } + private static List CreateGeminiParts(IEnumerable contents) + { + List? parts = null; + foreach (var content in contents) + { + if (parts == null) + { + parts = CreateGeminiParts(content); + } + else + { + parts.AddRange(CreateGeminiParts(content)); + } + } + + return parts!; + } + private static List CreateGeminiParts(ChatMessageContent content) { List parts = []; From 419f0d11651425a25b59b8208a32424aab365c59 Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 20:03:47 +0200 Subject: [PATCH 02/14] Updated unit tests and code. --- .../Clients/GeminiChatGenerationTests.cs | 41 ++++++++++++------- .../Clients/GeminiChatStreamingTests.cs | 33 +++++++++++++-- .../Clients/GeminiChatCompletionClient.cs | 9 ++++ 3 files changed, 64 insertions(+), 19 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs index 6b5bda155483..d5ba021e29ac 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs @@ -259,26 +259,35 @@ await Assert.ThrowsAsync( } [Fact] - public async Task ShouldThrowInvalidOperationExceptionIfChatHistoryContainsMoreThanOneSystemMessageAsync() + public async Task ShouldPassSystemMessageToRequestAsync() { + // Arrange var client = this.CreateChatCompletionClient(); - var chatHistory = new ChatHistory("System message"); - chatHistory.AddSystemMessage("System message 2"); - chatHistory.AddSystemMessage("System message 3"); - chatHistory.AddUserMessage("hello"); + string message = "System message"; + var chatHistory = new ChatHistory(message); + chatHistory.AddUserMessage("Hello"); - // Act & Assert - await Assert.ThrowsAsync( - () => client.GenerateChatMessageAsync(chatHistory)); + // Act + await client.GenerateChatMessageAsync(chatHistory); + + // Assert + GeminiRequest? request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(request); + Assert.NotNull(request.SystemMessages); + var systemMessage = request.SystemMessages.Parts![0].Text; + Assert.Null(request.SystemMessages.Role); + Assert.Equal(message, systemMessage); } [Fact] - public async Task ShouldPassConvertedSystemMessageToUserMessageToRequestAsync() + public async Task ShouldPassMultipleSystemMessagesToRequestAsync() { // Arrange + string[] messages = ["System message 1", "System message 2", "System message 3"]; var client = this.CreateChatCompletionClient(); - string message = "System message"; - var chatHistory = new ChatHistory(message); + var chatHistory = new ChatHistory(messages[0]); + chatHistory.AddSystemMessage(messages[1]); + chatHistory.AddSystemMessage(messages[2]); chatHistory.AddUserMessage("Hello"); // Act @@ -287,10 +296,12 @@ public async Task ShouldPassConvertedSystemMessageToUserMessageToRequestAsync() // Assert GeminiRequest? request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); Assert.NotNull(request); - var systemMessage = request.Contents[0].Parts![0].Text; - var messageRole = request.Contents[0].Role; - Assert.Equal(AuthorRole.User, messageRole); - Assert.Equal(message, systemMessage); + Assert.NotNull(request.SystemMessages); + Assert.Null(request.SystemMessages.Role); + Assert.Collection(request.SystemMessages.Parts!, + item => Assert.Equal(messages[0], item.Text), + item => Assert.Equal(messages[1], item.Text), + item => Assert.Equal(messages[2], item.Text)); } [Fact] diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatStreamingTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatStreamingTests.cs index 73b647429297..6f81e109d2c0 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatStreamingTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatStreamingTests.cs @@ -248,7 +248,7 @@ public async Task ShouldUsePromptExecutionSettingsAsync() } [Fact] - public async Task ShouldPassConvertedSystemMessageToUserMessageToRequestAsync() + public async Task ShouldPassSystemMessageToRequestAsync() { // Arrange var client = this.CreateChatCompletionClient(); @@ -262,12 +262,37 @@ public async Task ShouldPassConvertedSystemMessageToUserMessageToRequestAsync() // Assert GeminiRequest? request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); Assert.NotNull(request); - var systemMessage = request.Contents[0].Parts![0].Text; - var messageRole = request.Contents[0].Role; - Assert.Equal(AuthorRole.User, messageRole); + Assert.NotNull(request.SystemMessages); + var systemMessage = request.SystemMessages.Parts![0].Text; + Assert.Null(request.SystemMessages.Role); Assert.Equal(message, systemMessage); } + [Fact] + public async Task ShouldPassMultipleSystemMessagesToRequestAsync() + { + // Arrange + string[] messages = ["System message 1", "System message 2", "System message 3"]; + var client = this.CreateChatCompletionClient(); + var chatHistory = new ChatHistory(messages[0]); + chatHistory.AddSystemMessage(messages[1]); + chatHistory.AddSystemMessage(messages[2]); + chatHistory.AddUserMessage("Hello"); + + // Act + await client.StreamGenerateChatMessageAsync(chatHistory).ToListAsync(); + + // Assert + GeminiRequest? request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(request); + Assert.NotNull(request.SystemMessages); + Assert.Null(request.SystemMessages.Role); + Assert.Collection(request.SystemMessages.Parts!, + item => Assert.Equal(messages[0], item.Text), + item => Assert.Equal(messages[1], item.Text), + item => Assert.Equal(messages[2], item.Text)); + } + [Theory] [InlineData(0)] [InlineData(-15)] diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index 6d527485a576..892f5473d001 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -519,9 +519,18 @@ private static bool CheckAutoInvokeCondition(Kernel? kernel, GeminiPromptExecuti private static void ValidateChatHistory(ChatHistory chatHistory) { Verify.NotNullOrEmpty(chatHistory); + ThrowIfChatContainsOnlySystemMessages(chatHistory); ValidateChatHistoryMessagesOrder(chatHistory); } + private static void ThrowIfChatContainsOnlySystemMessages(ChatHistory chatHistory) + { + if (chatHistory.All(message => message.Role == AuthorRole.System)) + { + throw new InvalidOperationException("Chat history can't contain only system messages."); + } + } + private static void ValidateChatHistoryMessagesOrder(ChatHistory chatHistory) { bool incorrectOrder = false; From 5fe330928c94f79c3190b1c5eb1ce87382c7835f Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 20:11:25 +0200 Subject: [PATCH 03/14] Add new integration tests for chat generation and streaming with system messages --- .../Core/Gemini/Models/GeminiRequest.cs | 1 + .../Gemini/GeminiChatCompletionTests.cs | 51 +++++++++++++++++++ dotnet/src/IntegrationTests/testsettings.json | 8 +-- 3 files changed, 56 insertions(+), 4 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs index ac8e477d4202..a2515c234c41 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs @@ -27,6 +27,7 @@ internal sealed class GeminiRequest public IList? Tools { get; set; } [JsonPropertyName("systemInstruction")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public GeminiContent? SystemMessages { get; set; } public void AddFunction(GeminiFunction function) diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs index 321ede0ff115..d39b1eb14297 100644 --- a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs @@ -64,6 +64,57 @@ public async Task ChatStreamingReturnsValidResponseAsync(ServiceType serviceType this.Output.WriteLine(message); } + [RetryTheory] + [InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")] + [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] + public async Task ChatGenerationWithSystemMessagesAsync(ServiceType serviceType) + { + // Arrange + var chatHistory = new ChatHistory("You are helpful assistant. Your name is Roger."); + chatHistory.AddSystemMessage("You know ACDD equals 1520"); + chatHistory.AddUserMessage("Hello, I'm Brandon, how are you?"); + chatHistory.AddAssistantMessage("I'm doing well, thanks for asking."); + chatHistory.AddUserMessage("Tell me your name and the value of ACDD."); + + var sut = this.GetChatService(serviceType); + + // Act + var response = await sut.GetChatMessageContentAsync(chatHistory); + + // Assert + Assert.NotNull(response.Content); + this.Output.WriteLine(response.Content); + Assert.Contains("1520", response.Content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Roger", response.Content, StringComparison.OrdinalIgnoreCase); + } + + [RetryTheory] + [InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")] + [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] + public async Task ChatStreamingWithSystemMessagesAsync(ServiceType serviceType) + { + // Arrange + var chatHistory = new ChatHistory("You are helpful assistant. Your name is Roger."); + chatHistory.AddSystemMessage("You know ACDD equals 1520"); + chatHistory.AddUserMessage("Hello, I'm Brandon, how are you?"); + chatHistory.AddAssistantMessage("I'm doing well, thanks for asking."); + chatHistory.AddUserMessage("Tell me your name and the value of ACDD."); + + var sut = this.GetChatService(serviceType); + + // Act + var response = + await sut.GetStreamingChatMessageContentsAsync(chatHistory).ToListAsync(); + + // Assert + Assert.NotEmpty(response); + Assert.True(response.Count > 1); + var message = string.Concat(response.Select(c => c.Content)); + this.Output.WriteLine(message); + Assert.Contains("1520", message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Roger", message, StringComparison.OrdinalIgnoreCase); + } + [RetryTheory] [InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")] [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] diff --git a/dotnet/src/IntegrationTests/testsettings.json b/dotnet/src/IntegrationTests/testsettings.json index 39ec5c4d3b1c..66df73f8b7a5 100644 --- a/dotnet/src/IntegrationTests/testsettings.json +++ b/dotnet/src/IntegrationTests/testsettings.json @@ -51,8 +51,8 @@ "EmbeddingModelId": "embedding-001", "ApiKey": "", "Gemini": { - "ModelId": "gemini-1.0-pro", - "VisionModelId": "gemini-1.0-pro-vision" + "ModelId": "gemini-1.5-flash", + "VisionModelId": "gemini-1.5-flash" } }, "VertexAI": { @@ -61,8 +61,8 @@ "Location": "us-central1", "ProjectId": "", "Gemini": { - "ModelId": "gemini-1.0-pro", - "VisionModelId": "gemini-1.0-pro-vision" + "ModelId": "gemini-1.5-flash", + "VisionModelId": "gemini-1.5-flash" } }, "Bing": { From a69b8bf225b84d0f18e09bf677736c9b70848541 Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 20:34:53 +0200 Subject: [PATCH 04/14] Refactor --- .../Core/Gemini/Clients/GeminiChatCompletionClient.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index 892f5473d001..3d7814023074 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -164,11 +164,11 @@ public async Task> GenerateChatMessageAsync( for (state.Iteration = 1; ; state.Iteration++) { - GeminiResponse geminiResponse; List chatResponses; using (var activity = ModelDiagnostics.StartCompletionActivity( this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings)) { + GeminiResponse geminiResponse; try { geminiResponse = await this.SendRequestAndReturnValidGeminiResponseAsync( From 98454b9f5aa761c5af12ce85a421ae01ffd5f5e1 Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 20:41:30 +0200 Subject: [PATCH 05/14] Refactor --- .../Connectors.Google/Core/Gemini/Models/GeminiRequest.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs index a2515c234c41..5afca6b8458c 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs @@ -102,7 +102,7 @@ private static GeminiRequest CreateGeminiRequest(ChatHistory chatHistory) Contents = chatHistory .Where(message => message.Role != AuthorRole.System) .Select(CreateGeminiContentFromChatMessage).ToList(), - SystemMessages = CreateSystemChatMessages(chatHistory) + SystemMessages = CreateSystemMessages(chatHistory) }; return obj; } @@ -116,7 +116,7 @@ private static GeminiContent CreateGeminiContentFromChatMessage(ChatMessageConte }; } - private static GeminiContent? CreateSystemChatMessages(ChatHistory chatHistory) + private static GeminiContent? CreateSystemMessages(ChatHistory chatHistory) { var contents = chatHistory.Where(message => message.Role == AuthorRole.System).ToList(); if (contents.Count == 0) From ab0eeefc569e760619ea5feed10652e875d34501 Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 20:51:00 +0200 Subject: [PATCH 06/14] Removed messages order validation and unit tests. --- .../Clients/GeminiChatGenerationTests.cs | 30 ------------------- .../Clients/GeminiChatCompletionClient.cs | 27 ----------------- 2 files changed, 57 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs index d5ba021e29ac..f406072bc91b 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs @@ -304,36 +304,6 @@ public async Task ShouldPassMultipleSystemMessagesToRequestAsync() item => Assert.Equal(messages[2], item.Text)); } - [Fact] - public async Task ShouldThrowNotSupportedIfChatHistoryHaveIncorrectOrderAsync() - { - // Arrange - var client = this.CreateChatCompletionClient(); - var chatHistory = new ChatHistory(); - chatHistory.AddUserMessage("Hello"); - chatHistory.AddAssistantMessage("Hi"); - chatHistory.AddAssistantMessage("Hi me again"); - chatHistory.AddUserMessage("How are you?"); - - // Act & Assert - await Assert.ThrowsAsync( - () => client.GenerateChatMessageAsync(chatHistory)); - } - - [Fact] - public async Task ShouldThrowNotSupportedIfChatHistoryNotEndWithUserMessageAsync() - { - // Arrange - var client = this.CreateChatCompletionClient(); - var chatHistory = new ChatHistory(); - chatHistory.AddUserMessage("Hello"); - chatHistory.AddAssistantMessage("Hi"); - - // Act & Assert - await Assert.ThrowsAsync( - () => client.GenerateChatMessageAsync(chatHistory)); - } - [Fact] public async Task ShouldThrowArgumentExceptionIfChatHistoryIsEmptyAsync() { diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index 3d7814023074..898a15f15be3 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -520,7 +520,6 @@ private static void ValidateChatHistory(ChatHistory chatHistory) { Verify.NotNullOrEmpty(chatHistory); ThrowIfChatContainsOnlySystemMessages(chatHistory); - ValidateChatHistoryMessagesOrder(chatHistory); } private static void ThrowIfChatContainsOnlySystemMessages(ChatHistory chatHistory) @@ -531,32 +530,6 @@ private static void ThrowIfChatContainsOnlySystemMessages(ChatHistory chatHistor } } - private static void ValidateChatHistoryMessagesOrder(ChatHistory chatHistory) - { - bool incorrectOrder = false; - // Exclude tool calls and system messages from the validation - ChatHistory chatHistoryCopy = new(chatHistory - .Where(message => message.Role != AuthorRole.Tool && - message.Role != AuthorRole.System && - message is not GeminiChatMessageContent { ToolCalls: not null })); - for (int i = 0; i < chatHistoryCopy.Count; i++) - { - if (chatHistoryCopy[i].Role != (i % 2 == 0 ? AuthorRole.User : AuthorRole.Assistant) || - (i == chatHistoryCopy.Count - 1 && chatHistoryCopy[i].Role != AuthorRole.User)) - { - incorrectOrder = true; - break; - } - } - - if (incorrectOrder) - { - throw new NotSupportedException( - "Gemini API support only chat history with order of messages alternates between the user and the assistant. " + - "Last message have to be User message."); - } - } - private async IAsyncEnumerable ProcessChatResponseStreamAsync( Stream responseStream, [EnumeratorCancellation] CancellationToken ct) From 0c6c7f54cbd64264f39145cc3f1d79ca5f44b372 Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 20:57:35 +0200 Subject: [PATCH 07/14] Added integration tests --- .../Gemini/GeminiChatCompletionTests.cs | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs index d39b1eb14297..5732a3e4719a 100644 --- a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs @@ -64,6 +64,53 @@ public async Task ChatStreamingReturnsValidResponseAsync(ServiceType serviceType this.Output.WriteLine(message); } + [RetryTheory] + [InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")] + [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] + public async Task ChatGenerationOnlyAssistantMessagesReturnsValidResponseAsync(ServiceType serviceType) + { + // Arrange + var chatHistory = new ChatHistory(); + chatHistory.AddAssistantMessage("I'm Brandon, I'm very thirsty"); + chatHistory.AddAssistantMessage("Could you help me get some..."); + + var sut = this.GetChatService(serviceType); + + // Act + var response = await sut.GetChatMessageContentAsync(chatHistory); + + // Assert + Assert.NotNull(response.Content); + this.Output.WriteLine(response.Content); + string[] resultWords = ["drink", "water", "tea", "coffee", "juice", "soda"]; + Assert.Contains(resultWords, word => response.Content.Contains(word, StringComparison.OrdinalIgnoreCase)); + } + + [RetryTheory] + [InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")] + [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] + public async Task ChatStreamingOnlyAssistantMessagesReturnsValidResponseAsync(ServiceType serviceType) + { + // Arrange + var chatHistory = new ChatHistory(); + chatHistory.AddAssistantMessage("I'm Brandon, I'm very thirsty"); + chatHistory.AddAssistantMessage("Could you help me get some..."); + + var sut = this.GetChatService(serviceType); + + // Act + var response = + await sut.GetStreamingChatMessageContentsAsync(chatHistory).ToListAsync(); + + // Assert + Assert.NotEmpty(response); + Assert.True(response.Count > 1); + var message = string.Concat(response.Select(c => c.Content)); + this.Output.WriteLine(message); + string[] resultWords = ["drink", "water", "tea", "coffee", "juice", "soda"]; + Assert.Contains(resultWords, word => message.Contains(word, StringComparison.OrdinalIgnoreCase)); + } + [RetryTheory] [InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")] [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] From ae3e02e1435d6b63273dcb4bebf6e732bdaeae52 Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 22:07:02 +0200 Subject: [PATCH 08/14] Added new test to gemini request --- .../Core/Gemini/GeminiRequestTests.cs | 31 ++++++++++++++++--- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs index 4053fb8ee79f..25ddb3f100b1 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs @@ -126,7 +126,8 @@ public void FromChatHistoryItReturnsGeminiRequestWithSafetySettings() public void FromChatHistoryItReturnsGeminiRequestWithChatHistory() { // Arrange - ChatHistory chatHistory = []; + string systemMessage = "system-message"; + var chatHistory = new ChatHistory(systemMessage); chatHistory.AddUserMessage("user-message"); chatHistory.AddAssistantMessage("assist-message"); chatHistory.AddUserMessage("user-message2"); @@ -136,14 +137,34 @@ public void FromChatHistoryItReturnsGeminiRequestWithChatHistory() var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings); // Assert + Assert.NotNull(request.SystemMessages?.Parts); + Assert.Single(request.SystemMessages.Parts); + Assert.Equal(request.SystemMessages.Parts[0].Text, systemMessage); Assert.Collection(request.Contents, - c => Assert.Equal(chatHistory[0].Content, c.Parts![0].Text), c => Assert.Equal(chatHistory[1].Content, c.Parts![0].Text), - c => Assert.Equal(chatHistory[2].Content, c.Parts![0].Text)); + c => Assert.Equal(chatHistory[2].Content, c.Parts![0].Text), + c => Assert.Equal(chatHistory[3].Content, c.Parts![0].Text)); Assert.Collection(request.Contents, - c => Assert.Equal(chatHistory[0].Role, c.Role), c => Assert.Equal(chatHistory[1].Role, c.Role), - c => Assert.Equal(chatHistory[2].Role, c.Role)); + c => Assert.Equal(chatHistory[2].Role, c.Role), + c => Assert.Equal(chatHistory[3].Role, c.Role)); + } + + [Fact] + public void FromChatHistoryMultipleSystemMessagesItReturnsGeminiRequestWithSystemMessages() + { + // Arrange + string[] systemMessages = ["system-message", "system-message2", "system-message3", "system-message4"]; + var chatHistory = new ChatHistory(systemMessages[0]); + chatHistory.AddUserMessage("user-message"); + var executionSettings = new GeminiPromptExecutionSettings(); + + // Act + var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings); + + // Assert + Assert.NotNull(request.SystemMessages?.Parts); + Assert.Contains(request.SystemMessages.Parts, part => systemMessages.Contains(part.Text)); } [Fact] From a051bac3bacc3206ea53773157823df6e5636178 Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 22:09:39 +0200 Subject: [PATCH 09/14] Simplify gemini request tests names --- .../Core/Gemini/GeminiRequestTests.cs | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs index 25ddb3f100b1..96f0d04a4b32 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs @@ -15,7 +15,7 @@ namespace SemanticKernel.Connectors.Google.UnitTests.Core.Gemini; public sealed class GeminiRequestTests { [Fact] - public void FromPromptItReturnsGeminiRequestWithConfiguration() + public void FromPromptItReturnsWithConfiguration() { // Arrange var prompt = "prompt-example"; @@ -37,7 +37,7 @@ public void FromPromptItReturnsGeminiRequestWithConfiguration() } [Fact] - public void FromPromptItReturnsGeminiRequestWithSafetySettings() + public void FromPromptItReturnsWithSafetySettings() { // Arrange var prompt = "prompt-example"; @@ -59,7 +59,7 @@ public void FromPromptItReturnsGeminiRequestWithSafetySettings() } [Fact] - public void FromPromptItReturnsGeminiRequestWithPrompt() + public void FromPromptItReturnsWithPrompt() { // Arrange var prompt = "prompt-example"; @@ -73,7 +73,7 @@ public void FromPromptItReturnsGeminiRequestWithPrompt() } [Fact] - public void FromChatHistoryItReturnsGeminiRequestWithConfiguration() + public void FromChatHistoryItReturnsWithConfiguration() { // Arrange ChatHistory chatHistory = []; @@ -98,7 +98,7 @@ public void FromChatHistoryItReturnsGeminiRequestWithConfiguration() } [Fact] - public void FromChatHistoryItReturnsGeminiRequestWithSafetySettings() + public void FromChatHistoryItReturnsWithSafetySettings() { // Arrange ChatHistory chatHistory = []; @@ -123,7 +123,7 @@ public void FromChatHistoryItReturnsGeminiRequestWithSafetySettings() } [Fact] - public void FromChatHistoryItReturnsGeminiRequestWithChatHistory() + public void FromChatHistoryItReturnsWithChatHistory() { // Arrange string systemMessage = "system-message"; @@ -151,7 +151,7 @@ public void FromChatHistoryItReturnsGeminiRequestWithChatHistory() } [Fact] - public void FromChatHistoryMultipleSystemMessagesItReturnsGeminiRequestWithSystemMessages() + public void FromChatHistoryMultipleSystemMessagesItReturnsWithSystemMessages() { // Arrange string[] systemMessages = ["system-message", "system-message2", "system-message3", "system-message4"]; @@ -168,7 +168,7 @@ public void FromChatHistoryMultipleSystemMessagesItReturnsGeminiRequestWithSyste } [Fact] - public void FromChatHistoryTextAsTextContentItReturnsGeminiRequestWithChatHistory() + public void FromChatHistoryTextAsTextContentItReturnsWithChatHistory() { // Arrange ChatHistory chatHistory = []; @@ -188,7 +188,7 @@ public void FromChatHistoryTextAsTextContentItReturnsGeminiRequestWithChatHistor } [Fact] - public void FromChatHistoryImageAsImageContentItReturnsGeminiRequestWithChatHistory() + public void FromChatHistoryImageAsImageContentItReturnsWithChatHistory() { // Arrange ReadOnlyMemory imageAsBytes = new byte[] { 0x00, 0x01, 0x02, 0x03 }; @@ -293,7 +293,7 @@ public void FromChatHistoryToolCallsNotNullAddsFunctionCalls() } [Fact] - public void AddFunctionItAddsFunctionToGeminiRequest() + public void AddFunctionToGeminiRequest() { // Arrange var request = new GeminiRequest(); @@ -308,7 +308,7 @@ public void AddFunctionItAddsFunctionToGeminiRequest() } [Fact] - public void AddMultipleFunctionsItAddsFunctionsToGeminiRequest() + public void AddMultipleFunctionsToGeminiRequest() { // Arrange var request = new GeminiRequest(); @@ -329,7 +329,7 @@ public void AddMultipleFunctionsItAddsFunctionsToGeminiRequest() } [Fact] - public void AddChatMessageToRequestItAddsChatMessageToGeminiRequest() + public void AddChatMessageToRequestt() { // Arrange ChatHistory chat = []; From 1b004ee165363d64439d8712a6de66f60295e8aa Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 22:10:33 +0200 Subject: [PATCH 10/14] Rename SystemMessages to SystemInstruction --- .../Core/Gemini/Clients/GeminiChatGenerationTests.cs | 12 ++++++------ .../Core/Gemini/Clients/GeminiChatStreamingTests.cs | 12 ++++++------ .../Core/Gemini/GeminiRequestTests.cs | 12 ++++++------ .../Core/Gemini/Models/GeminiRequest.cs | 4 ++-- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs index f406072bc91b..5232c40b005d 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs @@ -273,9 +273,9 @@ public async Task ShouldPassSystemMessageToRequestAsync() // Assert GeminiRequest? request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); Assert.NotNull(request); - Assert.NotNull(request.SystemMessages); - var systemMessage = request.SystemMessages.Parts![0].Text; - Assert.Null(request.SystemMessages.Role); + Assert.NotNull(request.SystemInstruction); + var systemMessage = request.SystemInstruction.Parts![0].Text; + Assert.Null(request.SystemInstruction.Role); Assert.Equal(message, systemMessage); } @@ -296,9 +296,9 @@ public async Task ShouldPassMultipleSystemMessagesToRequestAsync() // Assert GeminiRequest? request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); Assert.NotNull(request); - Assert.NotNull(request.SystemMessages); - Assert.Null(request.SystemMessages.Role); - Assert.Collection(request.SystemMessages.Parts!, + Assert.NotNull(request.SystemInstruction); + Assert.Null(request.SystemInstruction.Role); + Assert.Collection(request.SystemInstruction.Parts!, item => Assert.Equal(messages[0], item.Text), item => Assert.Equal(messages[1], item.Text), item => Assert.Equal(messages[2], item.Text)); diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatStreamingTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatStreamingTests.cs index 6f81e109d2c0..d47115fe4ebc 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatStreamingTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatStreamingTests.cs @@ -262,9 +262,9 @@ public async Task ShouldPassSystemMessageToRequestAsync() // Assert GeminiRequest? request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); Assert.NotNull(request); - Assert.NotNull(request.SystemMessages); - var systemMessage = request.SystemMessages.Parts![0].Text; - Assert.Null(request.SystemMessages.Role); + Assert.NotNull(request.SystemInstruction); + var systemMessage = request.SystemInstruction.Parts![0].Text; + Assert.Null(request.SystemInstruction.Role); Assert.Equal(message, systemMessage); } @@ -285,9 +285,9 @@ public async Task ShouldPassMultipleSystemMessagesToRequestAsync() // Assert GeminiRequest? request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); Assert.NotNull(request); - Assert.NotNull(request.SystemMessages); - Assert.Null(request.SystemMessages.Role); - Assert.Collection(request.SystemMessages.Parts!, + Assert.NotNull(request.SystemInstruction); + Assert.Null(request.SystemInstruction.Role); + Assert.Collection(request.SystemInstruction.Parts!, item => Assert.Equal(messages[0], item.Text), item => Assert.Equal(messages[1], item.Text), item => Assert.Equal(messages[2], item.Text)); diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs index 96f0d04a4b32..9aff54e2cf24 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs @@ -137,9 +137,9 @@ public void FromChatHistoryItReturnsWithChatHistory() var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings); // Assert - Assert.NotNull(request.SystemMessages?.Parts); - Assert.Single(request.SystemMessages.Parts); - Assert.Equal(request.SystemMessages.Parts[0].Text, systemMessage); + Assert.NotNull(request.SystemInstruction?.Parts); + Assert.Single(request.SystemInstruction.Parts); + Assert.Equal(request.SystemInstruction.Parts[0].Text, systemMessage); Assert.Collection(request.Contents, c => Assert.Equal(chatHistory[1].Content, c.Parts![0].Text), c => Assert.Equal(chatHistory[2].Content, c.Parts![0].Text), @@ -163,8 +163,8 @@ public void FromChatHistoryMultipleSystemMessagesItReturnsWithSystemMessages() var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings); // Assert - Assert.NotNull(request.SystemMessages?.Parts); - Assert.Contains(request.SystemMessages.Parts, part => systemMessages.Contains(part.Text)); + Assert.NotNull(request.SystemInstruction?.Parts); + Assert.Contains(request.SystemInstruction.Parts, part => systemMessages.Contains(part.Text)); } [Fact] @@ -329,7 +329,7 @@ public void AddMultipleFunctionsToGeminiRequest() } [Fact] - public void AddChatMessageToRequestt() + public void AddChatMessageToRequest() { // Arrange ChatHistory chat = []; diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs index 5afca6b8458c..c50b6b33db46 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs @@ -28,7 +28,7 @@ internal sealed class GeminiRequest [JsonPropertyName("systemInstruction")] [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] - public GeminiContent? SystemMessages { get; set; } + public GeminiContent? SystemInstruction { get; set; } public void AddFunction(GeminiFunction function) { @@ -102,7 +102,7 @@ private static GeminiRequest CreateGeminiRequest(ChatHistory chatHistory) Contents = chatHistory .Where(message => message.Role != AuthorRole.System) .Select(CreateGeminiContentFromChatMessage).ToList(), - SystemMessages = CreateSystemMessages(chatHistory) + SystemInstruction = CreateSystemMessages(chatHistory) }; return obj; } From 7d618febb987c19569f4b86051f805c4006fac66 Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 22:11:46 +0200 Subject: [PATCH 11/14] Minor --- .../Core/Gemini/GeminiRequestTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs index 9aff54e2cf24..e79009c60f9e 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs @@ -184,7 +184,7 @@ public void FromChatHistoryTextAsTextContentItReturnsWithChatHistory() Assert.Collection(request.Contents, c => Assert.Equal(chatHistory[0].Content, c.Parts![0].Text), c => Assert.Equal(chatHistory[1].Content, c.Parts![0].Text), - c => Assert.Equal(chatHistory[2].Items!.Cast().Single().Text, c.Parts![0].Text)); + c => Assert.Equal(chatHistory[2].Items.Cast().Single().Text, c.Parts![0].Text)); } [Fact] @@ -208,7 +208,7 @@ public void FromChatHistoryImageAsImageContentItReturnsWithChatHistory() Assert.Collection(request.Contents, c => Assert.Equal(chatHistory[0].Content, c.Parts![0].Text), c => Assert.Equal(chatHistory[1].Content, c.Parts![0].Text), - c => Assert.Equal(chatHistory[2].Items!.Cast().Single().Uri, + c => Assert.Equal(chatHistory[2].Items.Cast().Single().Uri, c.Parts![0].FileData!.FileUri), c => Assert.True(imageAsBytes.ToArray() .SequenceEqual(Convert.FromBase64String(c.Parts![0].InlineData!.InlineData)))); From 69e93c614430e412df7946e7edd58c6c8ef8972b Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 22:18:49 +0200 Subject: [PATCH 12/14] Fixed tests --- .../Core/Gemini/GeminiRequestTests.cs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs index e79009c60f9e..e74ce51d4463 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs @@ -157,6 +157,9 @@ public void FromChatHistoryMultipleSystemMessagesItReturnsWithSystemMessages() string[] systemMessages = ["system-message", "system-message2", "system-message3", "system-message4"]; var chatHistory = new ChatHistory(systemMessages[0]); chatHistory.AddUserMessage("user-message"); + chatHistory.AddSystemMessage(systemMessages[1]); + chatHistory.AddMessage(AuthorRole.System, + [new TextContent(systemMessages[2]), new TextContent(systemMessages[3])]); var executionSettings = new GeminiPromptExecutionSettings(); // Act @@ -164,7 +167,7 @@ public void FromChatHistoryMultipleSystemMessagesItReturnsWithSystemMessages() // Assert Assert.NotNull(request.SystemInstruction?.Parts); - Assert.Contains(request.SystemInstruction.Parts, part => systemMessages.Contains(part.Text)); + Assert.All(systemMessages, msg => Assert.Contains(request.SystemInstruction.Parts, p => p.Text == msg)); } [Fact] From 11ff2bd8281421f426f3249c6d32918f617a33fd Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Tue, 2 Jul 2024 22:20:13 +0200 Subject: [PATCH 13/14] Refactor --- .../Core/Gemini/Clients/GeminiChatCompletionClient.cs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index 898a15f15be3..9750af44c0c7 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -519,11 +519,6 @@ private static bool CheckAutoInvokeCondition(Kernel? kernel, GeminiPromptExecuti private static void ValidateChatHistory(ChatHistory chatHistory) { Verify.NotNullOrEmpty(chatHistory); - ThrowIfChatContainsOnlySystemMessages(chatHistory); - } - - private static void ThrowIfChatContainsOnlySystemMessages(ChatHistory chatHistory) - { if (chatHistory.All(message => message.Role == AuthorRole.System)) { throw new InvalidOperationException("Chat history can't contain only system messages."); From f25181b1a6baabfb429f23f43458e088036d0225 Mon Sep 17 00:00:00 2001 From: Krzysztof Kasprowicz Date: Wed, 3 Jul 2024 18:31:40 +0200 Subject: [PATCH 14/14] Updated gemini samples to use system message. --- .../ChatCompletion/Google_GeminiChatCompletion.cs | 2 +- .../Google_GeminiChatCompletionStreaming.cs | 2 +- .../Concepts/ChatCompletion/Google_GeminiVision.cs | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletion.cs b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletion.cs index de2e996dc2fc..2e8f750e5476 100644 --- a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletion.cs +++ b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletion.cs @@ -89,7 +89,7 @@ private async Task SimpleChatAsync(Kernel kernel) { Console.WriteLine("======== Simple Chat ========"); - var chatHistory = new ChatHistory(); + var chatHistory = new ChatHistory("You are an expert in the tool shop."); var chat = kernel.GetRequiredService(); // First user message diff --git a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletionStreaming.cs b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletionStreaming.cs index 97f4873cfd52..803a6b6fafcd 100644 --- a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletionStreaming.cs +++ b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletionStreaming.cs @@ -90,7 +90,7 @@ private async Task StreamingChatAsync(Kernel kernel) { Console.WriteLine("======== Streaming Chat ========"); - var chatHistory = new ChatHistory(); + var chatHistory = new ChatHistory("You are an expert in the tool shop."); var chat = kernel.GetRequiredService(); // First user message diff --git a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiVision.cs b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiVision.cs index 1bf70ca28f5b..179b2b40937d 100644 --- a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiVision.cs +++ b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiVision.cs @@ -14,7 +14,7 @@ public async Task GoogleAIAsync() Console.WriteLine("============= Google AI - Gemini Chat Completion with vision ============="); string geminiApiKey = TestConfiguration.GoogleAI.ApiKey; - string geminiModelId = "gemini-pro-vision"; + string geminiModelId = TestConfiguration.GoogleAI.Gemini.ModelId; if (geminiApiKey is null) { @@ -28,7 +28,7 @@ public async Task GoogleAIAsync() apiKey: geminiApiKey) .Build(); - var chatHistory = new ChatHistory(); + var chatHistory = new ChatHistory("Your job is describing images."); var chatCompletionService = kernel.GetRequiredService(); // Load the image from the resources @@ -55,7 +55,7 @@ public async Task VertexAIAsync() Console.WriteLine("============= Vertex AI - Gemini Chat Completion with vision ============="); string geminiBearerKey = TestConfiguration.VertexAI.BearerKey; - string geminiModelId = "gemini-pro-vision"; + string geminiModelId = TestConfiguration.VertexAI.Gemini.ModelId; string geminiLocation = TestConfiguration.VertexAI.Location; string geminiProject = TestConfiguration.VertexAI.ProjectId; @@ -96,7 +96,7 @@ public async Task VertexAIAsync() // location: TestConfiguration.VertexAI.Location, // projectId: TestConfiguration.VertexAI.ProjectId); - var chatHistory = new ChatHistory(); + var chatHistory = new ChatHistory("Your job is describing images."); var chatCompletionService = kernel.GetRequiredService(); // Load the image from the resources