diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index b76695dcb047..74f5fab75658 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -38,7 +38,7 @@ - + @@ -135,8 +135,8 @@ runtime; build; native; contentfiles; analyzers; buildtransitive - - - + + + \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextEmbeddingGenerationTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextEmbeddingGenerationTests.cs index 4e9cf00754cf..53462080eb06 100644 --- a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextEmbeddingGenerationTests.cs +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextEmbeddingGenerationTests.cs @@ -68,12 +68,12 @@ public async Task ShouldSendPromptToServiceAsync() httpClient: this._httpClient); //Act - await sut.GenerateEmbeddingsAsync(new List { "fake-text" }); + await sut.GenerateEmbeddingsAsync(["fake-text"]); //Assert var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); Assert.NotNull(requestPayload); - Assert.Equal("fake-text", requestPayload.Prompt); + Assert.Equal("fake-text", requestPayload.Input[0]); } [Fact] @@ -90,9 +90,10 @@ public async Task ShouldHandleServiceResponseAsync() //Assert Assert.NotNull(contents); + Assert.Equal(2, contents.Count); - var content = contents.SingleOrDefault(); - Assert.Equal(8, content.Length); + var content = contents[0]; + Assert.Equal(5, content.Length); } public void Dispose() diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/embeddings_test_response.json b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/embeddings_test_response.json index 63218204d30d..3316addba6dd 100644 --- a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/embeddings_test_response.json +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/embeddings_test_response.json @@ -1,13 +1,19 @@ { - "embedding": [ - -0.08541165292263031, - 0.08639130741357803, - -0.12805694341659546, - -0.2877824902534485, - 0.2114177942276001, - -0.29374566674232483, - -0.10496602207422256, - 0.009402364492416382 - ], - "model": "fake-model" + "model": "fake-model", + "embeddings": [ + [ + 0.020765934, + 0.007495159, + 0.01268963, + 0.013938076, + -0.04621073 + ], + [ + 0.025005031, + 0.009804744, + -0.016960088, + -0.024823941, + -0.02756831 + ] + ] } \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs index 13adcd165c80..9e152f917f88 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs @@ -9,6 +9,7 @@ using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.Connectors.Ollama.Core; using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.Services; using OllamaSharp; using OllamaSharp.Models; @@ -58,20 +59,20 @@ public async Task>> GenerateEmbeddingsAsync( Kernel? kernel = null, CancellationToken cancellationToken = default) { - var tasks = new List>(); - foreach (var prompt in data) + var request = new GenerateEmbeddingRequest { - tasks.Add(this._client.GenerateEmbeddings(prompt, cancellationToken: cancellationToken)); - } + Model = this.GetModelId()!, + Input = data.ToList() + }; + + var response = await this._client.GenerateEmbeddings(request, cancellationToken: cancellationToken).ConfigureAwait(false); - await Task.WhenAll(tasks.ToArray()).ConfigureAwait(false); + List> embeddings = []; + foreach (var embedding in response.Embeddings) + { + embeddings.Add(embedding.Select(@decimal => (float)@decimal).ToArray()); + } - return new List>( - tasks.Select( - task => new ReadOnlyMemory(task.Result.Embedding - .Select(@decimal => (float)@decimal).ToArray() - ) - ) - ); + return embeddings; } }