diff --git a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletion.cs b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletion.cs index de2e996dc2fc..2e8f750e5476 100644 --- a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletion.cs +++ b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletion.cs @@ -89,7 +89,7 @@ private async Task SimpleChatAsync(Kernel kernel) { Console.WriteLine("======== Simple Chat ========"); - var chatHistory = new ChatHistory(); + var chatHistory = new ChatHistory("You are an expert in the tool shop."); var chat = kernel.GetRequiredService(); // First user message diff --git a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletionStreaming.cs b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletionStreaming.cs index 97f4873cfd52..803a6b6fafcd 100644 --- a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletionStreaming.cs +++ b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiChatCompletionStreaming.cs @@ -90,7 +90,7 @@ private async Task StreamingChatAsync(Kernel kernel) { Console.WriteLine("======== Streaming Chat ========"); - var chatHistory = new ChatHistory(); + var chatHistory = new ChatHistory("You are an expert in the tool shop."); var chat = kernel.GetRequiredService(); // First user message diff --git a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiVision.cs b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiVision.cs index 1bf70ca28f5b..179b2b40937d 100644 --- a/dotnet/samples/Concepts/ChatCompletion/Google_GeminiVision.cs +++ b/dotnet/samples/Concepts/ChatCompletion/Google_GeminiVision.cs @@ -14,7 +14,7 @@ public async Task GoogleAIAsync() Console.WriteLine("============= Google AI - Gemini Chat Completion with vision ============="); string geminiApiKey = TestConfiguration.GoogleAI.ApiKey; - string geminiModelId = "gemini-pro-vision"; + string geminiModelId = TestConfiguration.GoogleAI.Gemini.ModelId; if (geminiApiKey is null) { @@ -28,7 +28,7 @@ public async Task GoogleAIAsync() apiKey: geminiApiKey) .Build(); - var chatHistory = new ChatHistory(); + var chatHistory = new ChatHistory("Your job is describing images."); var chatCompletionService = kernel.GetRequiredService(); // Load the image from the resources @@ -55,7 +55,7 @@ public async Task VertexAIAsync() Console.WriteLine("============= Vertex AI - Gemini Chat Completion with vision ============="); string geminiBearerKey = TestConfiguration.VertexAI.BearerKey; - string geminiModelId = "gemini-pro-vision"; + string geminiModelId = TestConfiguration.VertexAI.Gemini.ModelId; string geminiLocation = TestConfiguration.VertexAI.Location; string geminiProject = TestConfiguration.VertexAI.ProjectId; @@ -96,7 +96,7 @@ public async Task VertexAIAsync() // location: TestConfiguration.VertexAI.Location, // projectId: TestConfiguration.VertexAI.ProjectId); - var chatHistory = new ChatHistory(); + var chatHistory = new ChatHistory("Your job is describing images."); var chatCompletionService = kernel.GetRequiredService(); // Load the image from the resources diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs index 6b5bda155483..5232c40b005d 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatGenerationTests.cs @@ -259,21 +259,7 @@ await Assert.ThrowsAsync( } [Fact] - public async Task ShouldThrowInvalidOperationExceptionIfChatHistoryContainsMoreThanOneSystemMessageAsync() - { - var client = this.CreateChatCompletionClient(); - var chatHistory = new ChatHistory("System message"); - chatHistory.AddSystemMessage("System message 2"); - chatHistory.AddSystemMessage("System message 3"); - chatHistory.AddUserMessage("hello"); - - // Act & Assert - await Assert.ThrowsAsync( - () => client.GenerateChatMessageAsync(chatHistory)); - } - - [Fact] - public async Task ShouldPassConvertedSystemMessageToUserMessageToRequestAsync() + public async Task ShouldPassSystemMessageToRequestAsync() { // Arrange var client = this.CreateChatCompletionClient(); @@ -287,40 +273,35 @@ public async Task ShouldPassConvertedSystemMessageToUserMessageToRequestAsync() // Assert GeminiRequest? request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); Assert.NotNull(request); - var systemMessage = request.Contents[0].Parts![0].Text; - var messageRole = request.Contents[0].Role; - Assert.Equal(AuthorRole.User, messageRole); + Assert.NotNull(request.SystemInstruction); + var systemMessage = request.SystemInstruction.Parts![0].Text; + Assert.Null(request.SystemInstruction.Role); Assert.Equal(message, systemMessage); } [Fact] - public async Task ShouldThrowNotSupportedIfChatHistoryHaveIncorrectOrderAsync() + public async Task ShouldPassMultipleSystemMessagesToRequestAsync() { // Arrange + string[] messages = ["System message 1", "System message 2", "System message 3"]; var client = this.CreateChatCompletionClient(); - var chatHistory = new ChatHistory(); + var chatHistory = new ChatHistory(messages[0]); + chatHistory.AddSystemMessage(messages[1]); + chatHistory.AddSystemMessage(messages[2]); chatHistory.AddUserMessage("Hello"); - chatHistory.AddAssistantMessage("Hi"); - chatHistory.AddAssistantMessage("Hi me again"); - chatHistory.AddUserMessage("How are you?"); - // Act & Assert - await Assert.ThrowsAsync( - () => client.GenerateChatMessageAsync(chatHistory)); - } - - [Fact] - public async Task ShouldThrowNotSupportedIfChatHistoryNotEndWithUserMessageAsync() - { - // Arrange - var client = this.CreateChatCompletionClient(); - var chatHistory = new ChatHistory(); - chatHistory.AddUserMessage("Hello"); - chatHistory.AddAssistantMessage("Hi"); + // Act + await client.GenerateChatMessageAsync(chatHistory); - // Act & Assert - await Assert.ThrowsAsync( - () => client.GenerateChatMessageAsync(chatHistory)); + // Assert + GeminiRequest? request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(request); + Assert.NotNull(request.SystemInstruction); + Assert.Null(request.SystemInstruction.Role); + Assert.Collection(request.SystemInstruction.Parts!, + item => Assert.Equal(messages[0], item.Text), + item => Assert.Equal(messages[1], item.Text), + item => Assert.Equal(messages[2], item.Text)); } [Fact] diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatStreamingTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatStreamingTests.cs index 73b647429297..d47115fe4ebc 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatStreamingTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/Clients/GeminiChatStreamingTests.cs @@ -248,7 +248,7 @@ public async Task ShouldUsePromptExecutionSettingsAsync() } [Fact] - public async Task ShouldPassConvertedSystemMessageToUserMessageToRequestAsync() + public async Task ShouldPassSystemMessageToRequestAsync() { // Arrange var client = this.CreateChatCompletionClient(); @@ -262,12 +262,37 @@ public async Task ShouldPassConvertedSystemMessageToUserMessageToRequestAsync() // Assert GeminiRequest? request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); Assert.NotNull(request); - var systemMessage = request.Contents[0].Parts![0].Text; - var messageRole = request.Contents[0].Role; - Assert.Equal(AuthorRole.User, messageRole); + Assert.NotNull(request.SystemInstruction); + var systemMessage = request.SystemInstruction.Parts![0].Text; + Assert.Null(request.SystemInstruction.Role); Assert.Equal(message, systemMessage); } + [Fact] + public async Task ShouldPassMultipleSystemMessagesToRequestAsync() + { + // Arrange + string[] messages = ["System message 1", "System message 2", "System message 3"]; + var client = this.CreateChatCompletionClient(); + var chatHistory = new ChatHistory(messages[0]); + chatHistory.AddSystemMessage(messages[1]); + chatHistory.AddSystemMessage(messages[2]); + chatHistory.AddUserMessage("Hello"); + + // Act + await client.StreamGenerateChatMessageAsync(chatHistory).ToListAsync(); + + // Assert + GeminiRequest? request = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(request); + Assert.NotNull(request.SystemInstruction); + Assert.Null(request.SystemInstruction.Role); + Assert.Collection(request.SystemInstruction.Parts!, + item => Assert.Equal(messages[0], item.Text), + item => Assert.Equal(messages[1], item.Text), + item => Assert.Equal(messages[2], item.Text)); + } + [Theory] [InlineData(0)] [InlineData(-15)] diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs index 4053fb8ee79f..e74ce51d4463 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Core/Gemini/GeminiRequestTests.cs @@ -15,7 +15,7 @@ namespace SemanticKernel.Connectors.Google.UnitTests.Core.Gemini; public sealed class GeminiRequestTests { [Fact] - public void FromPromptItReturnsGeminiRequestWithConfiguration() + public void FromPromptItReturnsWithConfiguration() { // Arrange var prompt = "prompt-example"; @@ -37,7 +37,7 @@ public void FromPromptItReturnsGeminiRequestWithConfiguration() } [Fact] - public void FromPromptItReturnsGeminiRequestWithSafetySettings() + public void FromPromptItReturnsWithSafetySettings() { // Arrange var prompt = "prompt-example"; @@ -59,7 +59,7 @@ public void FromPromptItReturnsGeminiRequestWithSafetySettings() } [Fact] - public void FromPromptItReturnsGeminiRequestWithPrompt() + public void FromPromptItReturnsWithPrompt() { // Arrange var prompt = "prompt-example"; @@ -73,7 +73,7 @@ public void FromPromptItReturnsGeminiRequestWithPrompt() } [Fact] - public void FromChatHistoryItReturnsGeminiRequestWithConfiguration() + public void FromChatHistoryItReturnsWithConfiguration() { // Arrange ChatHistory chatHistory = []; @@ -98,7 +98,7 @@ public void FromChatHistoryItReturnsGeminiRequestWithConfiguration() } [Fact] - public void FromChatHistoryItReturnsGeminiRequestWithSafetySettings() + public void FromChatHistoryItReturnsWithSafetySettings() { // Arrange ChatHistory chatHistory = []; @@ -123,10 +123,11 @@ public void FromChatHistoryItReturnsGeminiRequestWithSafetySettings() } [Fact] - public void FromChatHistoryItReturnsGeminiRequestWithChatHistory() + public void FromChatHistoryItReturnsWithChatHistory() { // Arrange - ChatHistory chatHistory = []; + string systemMessage = "system-message"; + var chatHistory = new ChatHistory(systemMessage); chatHistory.AddUserMessage("user-message"); chatHistory.AddAssistantMessage("assist-message"); chatHistory.AddUserMessage("user-message2"); @@ -136,18 +137,41 @@ public void FromChatHistoryItReturnsGeminiRequestWithChatHistory() var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings); // Assert + Assert.NotNull(request.SystemInstruction?.Parts); + Assert.Single(request.SystemInstruction.Parts); + Assert.Equal(request.SystemInstruction.Parts[0].Text, systemMessage); Assert.Collection(request.Contents, - c => Assert.Equal(chatHistory[0].Content, c.Parts![0].Text), c => Assert.Equal(chatHistory[1].Content, c.Parts![0].Text), - c => Assert.Equal(chatHistory[2].Content, c.Parts![0].Text)); + c => Assert.Equal(chatHistory[2].Content, c.Parts![0].Text), + c => Assert.Equal(chatHistory[3].Content, c.Parts![0].Text)); Assert.Collection(request.Contents, - c => Assert.Equal(chatHistory[0].Role, c.Role), c => Assert.Equal(chatHistory[1].Role, c.Role), - c => Assert.Equal(chatHistory[2].Role, c.Role)); + c => Assert.Equal(chatHistory[2].Role, c.Role), + c => Assert.Equal(chatHistory[3].Role, c.Role)); + } + + [Fact] + public void FromChatHistoryMultipleSystemMessagesItReturnsWithSystemMessages() + { + // Arrange + string[] systemMessages = ["system-message", "system-message2", "system-message3", "system-message4"]; + var chatHistory = new ChatHistory(systemMessages[0]); + chatHistory.AddUserMessage("user-message"); + chatHistory.AddSystemMessage(systemMessages[1]); + chatHistory.AddMessage(AuthorRole.System, + [new TextContent(systemMessages[2]), new TextContent(systemMessages[3])]); + var executionSettings = new GeminiPromptExecutionSettings(); + + // Act + var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings); + + // Assert + Assert.NotNull(request.SystemInstruction?.Parts); + Assert.All(systemMessages, msg => Assert.Contains(request.SystemInstruction.Parts, p => p.Text == msg)); } [Fact] - public void FromChatHistoryTextAsTextContentItReturnsGeminiRequestWithChatHistory() + public void FromChatHistoryTextAsTextContentItReturnsWithChatHistory() { // Arrange ChatHistory chatHistory = []; @@ -163,11 +187,11 @@ public void FromChatHistoryTextAsTextContentItReturnsGeminiRequestWithChatHistor Assert.Collection(request.Contents, c => Assert.Equal(chatHistory[0].Content, c.Parts![0].Text), c => Assert.Equal(chatHistory[1].Content, c.Parts![0].Text), - c => Assert.Equal(chatHistory[2].Items!.Cast().Single().Text, c.Parts![0].Text)); + c => Assert.Equal(chatHistory[2].Items.Cast().Single().Text, c.Parts![0].Text)); } [Fact] - public void FromChatHistoryImageAsImageContentItReturnsGeminiRequestWithChatHistory() + public void FromChatHistoryImageAsImageContentItReturnsWithChatHistory() { // Arrange ReadOnlyMemory imageAsBytes = new byte[] { 0x00, 0x01, 0x02, 0x03 }; @@ -187,7 +211,7 @@ public void FromChatHistoryImageAsImageContentItReturnsGeminiRequestWithChatHist Assert.Collection(request.Contents, c => Assert.Equal(chatHistory[0].Content, c.Parts![0].Text), c => Assert.Equal(chatHistory[1].Content, c.Parts![0].Text), - c => Assert.Equal(chatHistory[2].Items!.Cast().Single().Uri, + c => Assert.Equal(chatHistory[2].Items.Cast().Single().Uri, c.Parts![0].FileData!.FileUri), c => Assert.True(imageAsBytes.ToArray() .SequenceEqual(Convert.FromBase64String(c.Parts![0].InlineData!.InlineData)))); @@ -272,7 +296,7 @@ public void FromChatHistoryToolCallsNotNullAddsFunctionCalls() } [Fact] - public void AddFunctionItAddsFunctionToGeminiRequest() + public void AddFunctionToGeminiRequest() { // Arrange var request = new GeminiRequest(); @@ -287,7 +311,7 @@ public void AddFunctionItAddsFunctionToGeminiRequest() } [Fact] - public void AddMultipleFunctionsItAddsFunctionsToGeminiRequest() + public void AddMultipleFunctionsToGeminiRequest() { // Arrange var request = new GeminiRequest(); @@ -308,7 +332,7 @@ public void AddMultipleFunctionsItAddsFunctionsToGeminiRequest() } [Fact] - public void AddChatMessageToRequestItAddsChatMessageToGeminiRequest() + public void AddChatMessageToRequest() { // Arrange ChatHistory chat = []; diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs index e52b5f4e6bd6..9750af44c0c7 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Clients/GeminiChatCompletionClient.cs @@ -164,11 +164,11 @@ public async Task> GenerateChatMessageAsync( for (state.Iteration = 1; ; state.Iteration++) { - GeminiResponse geminiResponse; List chatResponses; using (var activity = ModelDiagnostics.StartCompletionActivity( this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings)) { + GeminiResponse geminiResponse; try { geminiResponse = await this.SendRequestAndReturnValidGeminiResponseAsync( @@ -297,8 +297,7 @@ private ChatCompletionState ValidateInputAndCreateChatCompletionState( Kernel? kernel, PromptExecutionSettings? executionSettings) { - var chatHistoryCopy = new ChatHistory(chatHistory); - ValidateAndPrepareChatHistory(chatHistoryCopy); + ValidateChatHistory(chatHistory); var geminiExecutionSettings = GeminiPromptExecutionSettings.FromExecutionSettings(executionSettings); ValidateMaxTokens(geminiExecutionSettings.MaxTokens); @@ -315,7 +314,7 @@ private ChatCompletionState ValidateInputAndCreateChatCompletionState( AutoInvoke = CheckAutoInvokeCondition(kernel, geminiExecutionSettings), ChatHistory = chatHistory, ExecutionSettings = geminiExecutionSettings, - GeminiRequest = CreateRequest(chatHistoryCopy, geminiExecutionSettings, kernel), + GeminiRequest = CreateRequest(chatHistory, geminiExecutionSettings, kernel), Kernel = kernel! // not null if auto-invoke is true }; } @@ -517,61 +516,12 @@ private static bool CheckAutoInvokeCondition(Kernel? kernel, GeminiPromptExecuti return autoInvoke; } - private static void ValidateAndPrepareChatHistory(ChatHistory chatHistory) + private static void ValidateChatHistory(ChatHistory chatHistory) { Verify.NotNullOrEmpty(chatHistory); - - if (chatHistory.Where(message => message.Role == AuthorRole.System).ToList() is { Count: > 0 } systemMessages) - { - if (chatHistory.Count == systemMessages.Count) - { - throw new InvalidOperationException("Chat history can't contain only system messages."); - } - - if (systemMessages.Count > 1) - { - throw new InvalidOperationException("Chat history can't contain more than one system message. " + - "Only the first system message will be processed but will be converted to the user message before sending to the Gemini api."); - } - - ConvertSystemMessageToUserMessageInChatHistory(chatHistory, systemMessages[0]); - } - - ValidateChatHistoryMessagesOrder(chatHistory); - } - - private static void ConvertSystemMessageToUserMessageInChatHistory(ChatHistory chatHistory, ChatMessageContent systemMessage) - { - // TODO: This solution is needed due to the fact that Gemini API doesn't support system messages. Maybe in the future we will be able to remove it. - chatHistory.Remove(systemMessage); - if (!string.IsNullOrWhiteSpace(systemMessage.Content)) - { - chatHistory.Insert(0, new ChatMessageContent(AuthorRole.User, systemMessage.Content)); - chatHistory.Insert(1, new ChatMessageContent(AuthorRole.Assistant, "OK")); - } - } - - private static void ValidateChatHistoryMessagesOrder(ChatHistory chatHistory) - { - bool incorrectOrder = false; - // Exclude tool calls from the validation - ChatHistory chatHistoryCopy = new(chatHistory - .Where(message => message.Role != AuthorRole.Tool && (message is not GeminiChatMessageContent { ToolCalls: not null }))); - for (int i = 0; i < chatHistoryCopy.Count; i++) - { - if (chatHistoryCopy[i].Role != (i % 2 == 0 ? AuthorRole.User : AuthorRole.Assistant) || - (i == chatHistoryCopy.Count - 1 && chatHistoryCopy[i].Role != AuthorRole.User)) - { - incorrectOrder = true; - break; - } - } - - if (incorrectOrder) + if (chatHistory.All(message => message.Role == AuthorRole.System)) { - throw new NotSupportedException( - "Gemini API support only chat history with order of messages alternates between the user and the assistant. " + - "Last message have to be User message."); + throw new InvalidOperationException("Chat history can't contain only system messages."); } } diff --git a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs index def81d9a7083..c50b6b33db46 100644 --- a/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs +++ b/dotnet/src/Connectors/Connectors.Google/Core/Gemini/Models/GeminiRequest.cs @@ -26,6 +26,10 @@ internal sealed class GeminiRequest [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public IList? Tools { get; set; } + [JsonPropertyName("systemInstruction")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public GeminiContent? SystemInstruction { get; set; } + public void AddFunction(GeminiFunction function) { // NOTE: Currently Gemini only supports one tool i.e. function calling. @@ -95,7 +99,10 @@ private static GeminiRequest CreateGeminiRequest(ChatHistory chatHistory) { GeminiRequest obj = new() { - Contents = chatHistory.Select(CreateGeminiContentFromChatMessage).ToList() + Contents = chatHistory + .Where(message => message.Role != AuthorRole.System) + .Select(CreateGeminiContentFromChatMessage).ToList(), + SystemInstruction = CreateSystemMessages(chatHistory) }; return obj; } @@ -109,6 +116,20 @@ private static GeminiContent CreateGeminiContentFromChatMessage(ChatMessageConte }; } + private static GeminiContent? CreateSystemMessages(ChatHistory chatHistory) + { + var contents = chatHistory.Where(message => message.Role == AuthorRole.System).ToList(); + if (contents.Count == 0) + { + return null; + } + + return new GeminiContent + { + Parts = CreateGeminiParts(contents) + }; + } + public void AddChatMessage(ChatMessageContent message) { Verify.NotNull(this.Contents); @@ -117,6 +138,24 @@ public void AddChatMessage(ChatMessageContent message) this.Contents.Add(CreateGeminiContentFromChatMessage(message)); } + private static List CreateGeminiParts(IEnumerable contents) + { + List? parts = null; + foreach (var content in contents) + { + if (parts == null) + { + parts = CreateGeminiParts(content); + } + else + { + parts.AddRange(CreateGeminiParts(content)); + } + } + + return parts!; + } + private static List CreateGeminiParts(ChatMessageContent content) { List parts = []; diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs index 321ede0ff115..5732a3e4719a 100644 --- a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatCompletionTests.cs @@ -64,6 +64,104 @@ public async Task ChatStreamingReturnsValidResponseAsync(ServiceType serviceType this.Output.WriteLine(message); } + [RetryTheory] + [InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")] + [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] + public async Task ChatGenerationOnlyAssistantMessagesReturnsValidResponseAsync(ServiceType serviceType) + { + // Arrange + var chatHistory = new ChatHistory(); + chatHistory.AddAssistantMessage("I'm Brandon, I'm very thirsty"); + chatHistory.AddAssistantMessage("Could you help me get some..."); + + var sut = this.GetChatService(serviceType); + + // Act + var response = await sut.GetChatMessageContentAsync(chatHistory); + + // Assert + Assert.NotNull(response.Content); + this.Output.WriteLine(response.Content); + string[] resultWords = ["drink", "water", "tea", "coffee", "juice", "soda"]; + Assert.Contains(resultWords, word => response.Content.Contains(word, StringComparison.OrdinalIgnoreCase)); + } + + [RetryTheory] + [InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")] + [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] + public async Task ChatStreamingOnlyAssistantMessagesReturnsValidResponseAsync(ServiceType serviceType) + { + // Arrange + var chatHistory = new ChatHistory(); + chatHistory.AddAssistantMessage("I'm Brandon, I'm very thirsty"); + chatHistory.AddAssistantMessage("Could you help me get some..."); + + var sut = this.GetChatService(serviceType); + + // Act + var response = + await sut.GetStreamingChatMessageContentsAsync(chatHistory).ToListAsync(); + + // Assert + Assert.NotEmpty(response); + Assert.True(response.Count > 1); + var message = string.Concat(response.Select(c => c.Content)); + this.Output.WriteLine(message); + string[] resultWords = ["drink", "water", "tea", "coffee", "juice", "soda"]; + Assert.Contains(resultWords, word => message.Contains(word, StringComparison.OrdinalIgnoreCase)); + } + + [RetryTheory] + [InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")] + [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] + public async Task ChatGenerationWithSystemMessagesAsync(ServiceType serviceType) + { + // Arrange + var chatHistory = new ChatHistory("You are helpful assistant. Your name is Roger."); + chatHistory.AddSystemMessage("You know ACDD equals 1520"); + chatHistory.AddUserMessage("Hello, I'm Brandon, how are you?"); + chatHistory.AddAssistantMessage("I'm doing well, thanks for asking."); + chatHistory.AddUserMessage("Tell me your name and the value of ACDD."); + + var sut = this.GetChatService(serviceType); + + // Act + var response = await sut.GetChatMessageContentAsync(chatHistory); + + // Assert + Assert.NotNull(response.Content); + this.Output.WriteLine(response.Content); + Assert.Contains("1520", response.Content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Roger", response.Content, StringComparison.OrdinalIgnoreCase); + } + + [RetryTheory] + [InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")] + [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] + public async Task ChatStreamingWithSystemMessagesAsync(ServiceType serviceType) + { + // Arrange + var chatHistory = new ChatHistory("You are helpful assistant. Your name is Roger."); + chatHistory.AddSystemMessage("You know ACDD equals 1520"); + chatHistory.AddUserMessage("Hello, I'm Brandon, how are you?"); + chatHistory.AddAssistantMessage("I'm doing well, thanks for asking."); + chatHistory.AddUserMessage("Tell me your name and the value of ACDD."); + + var sut = this.GetChatService(serviceType); + + // Act + var response = + await sut.GetStreamingChatMessageContentsAsync(chatHistory).ToListAsync(); + + // Assert + Assert.NotEmpty(response); + Assert.True(response.Count > 1); + var message = string.Concat(response.Select(c => c.Content)); + this.Output.WriteLine(message); + Assert.Contains("1520", message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Roger", message, StringComparison.OrdinalIgnoreCase); + } + [RetryTheory] [InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")] [InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")] diff --git a/dotnet/src/IntegrationTests/testsettings.json b/dotnet/src/IntegrationTests/testsettings.json index 39ec5c4d3b1c..66df73f8b7a5 100644 --- a/dotnet/src/IntegrationTests/testsettings.json +++ b/dotnet/src/IntegrationTests/testsettings.json @@ -51,8 +51,8 @@ "EmbeddingModelId": "embedding-001", "ApiKey": "", "Gemini": { - "ModelId": "gemini-1.0-pro", - "VisionModelId": "gemini-1.0-pro-vision" + "ModelId": "gemini-1.5-flash", + "VisionModelId": "gemini-1.5-flash" } }, "VertexAI": { @@ -61,8 +61,8 @@ "Location": "us-central1", "ProjectId": "", "Gemini": { - "ModelId": "gemini-1.0-pro", - "VisionModelId": "gemini-1.0-pro-vision" + "ModelId": "gemini-1.5-flash", + "VisionModelId": "gemini-1.5-flash" } }, "Bing": {