diff --git a/dotnet/src/SemanticKernel.Skills/Skills.Memory.Sqlite/Database.cs b/dotnet/src/SemanticKernel.Skills/Skills.Memory.Sqlite/Database.cs index a737e21914b1..e587fd00e298 100644 --- a/dotnet/src/SemanticKernel.Skills/Skills.Memory.Sqlite/Database.cs +++ b/dotnet/src/SemanticKernel.Skills/Skills.Memory.Sqlite/Database.cs @@ -29,13 +29,46 @@ public static async Task CreateConnectionAsync(string filename public static async Task InsertAsync(this SqliteConnection conn, string collection, string key, string? value, string? timestamp, CancellationToken cancel = default) + { + await CreateTableAsync(conn, cancel); + SqliteCommand cmd = conn.CreateCommand(); + cmd.CommandText = $@" + INSERT into {TableName}(collection, key, value, timestamp) + VALUES(@collection, @key, @value, @timestamp); "; + cmd.Parameters.AddWithValue("@collection", collection); + cmd.Parameters.AddWithValue("@key", key); + cmd.Parameters.AddWithValue("@value", value ?? string.Empty); + cmd.Parameters.AddWithValue("@timestamp", timestamp ?? string.Empty); + await cmd.ExecuteNonQueryAsync(cancel); + } + + public static async Task InsertOrIgnoreAsync(this SqliteConnection conn, + string collection, string key, string? value, string? timestamp, CancellationToken cancel = default) { await CreateTableAsync(conn, cancel); SqliteCommand cmd = conn.CreateCommand(); cmd.CommandText = $@" - INSERT INTO {TableName}(collection, key, value, timestamp) + INSERT or IGNORE into {TableName}(collection, key, value, timestamp) VALUES(@collection, @key, @value, @timestamp); "; + + cmd.Parameters.AddWithValue("@collection", collection); + cmd.Parameters.AddWithValue("@key", key); + cmd.Parameters.AddWithValue("@value", value ?? string.Empty); + cmd.Parameters.AddWithValue("@timestamp", timestamp ?? string.Empty); + await cmd.ExecuteNonQueryAsync(cancel); + } + + public static async Task UpsertAsync(this SqliteConnection conn, + string collection, string key, string? value, string? timestamp, CancellationToken cancel = default) + { + await CreateTableAsync(conn, cancel); + + SqliteCommand cmd = conn.CreateCommand(); + cmd.CommandText = $@" + INSERT or REPLACE into {TableName}(collection, key, value, timestamp) + VALUES(@collection, @key, @value, @timestamp); "; + cmd.Parameters.AddWithValue("@collection", collection); cmd.Parameters.AddWithValue("@key", key); cmd.Parameters.AddWithValue("@value", value ?? string.Empty); diff --git a/dotnet/src/SemanticKernel.Skills/Skills.Memory.Sqlite/SqliteDataStore.cs b/dotnet/src/SemanticKernel.Skills/Skills.Memory.Sqlite/SqliteDataStore.cs new file mode 100644 index 000000000000..26a7ddb57339 --- /dev/null +++ b/dotnet/src/SemanticKernel.Skills/Skills.Memory.Sqlite/SqliteDataStore.cs @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.Sqlite; +using Microsoft.SemanticKernel.Memory.Storage; +using Microsoft.SemanticKernel.Skills.Memory.Sqlite; + +namespace Microsoft.SemanticKernel.Skills.Memory.Sqlite; + +/// +/// An implementation of backed by a SQLite database. +/// +/// The data is saved to a database file, specified in the constructor. +/// The data persists between subsequent instances. Only one instance may access the file at a time. +/// The caller is responsible for deleting the file. +/// The type of data to be stored in this data store. +public class SqliteDataStore : IDataStore, IDisposable +{ + /// + /// Connect a Sqlite database + /// + /// Path to the database file. If file does not exist, it will be created. + /// Cancellation token + [SuppressMessage("Design", "CA1000:Do not declare static members on generic types", + Justification = "Static factory method used to ensure successful connection.")] + public static async Task> ConnectAsync(string filename, + CancellationToken cancel = default) + { + SqliteConnection dbConnection = await Database.CreateConnectionAsync(filename, cancel); + return new SqliteDataStore(dbConnection); + } + + /// + public IAsyncEnumerable GetCollectionsAsync(CancellationToken cancel = default) + { + return this._dbConnection.GetCollectionsAsync(cancel); + } + + /// + public async IAsyncEnumerable> GetAllAsync(string collection, + [EnumeratorCancellation] CancellationToken cancel = default) + { + await foreach (DatabaseEntry dbEntry in this._dbConnection.ReadAllAsync(collection, cancel)) + { + yield return DataEntry.Create(dbEntry.Key, dbEntry.Value, ParseTimestamp(dbEntry.Timestamp)); + } + } + + /// + public async Task?> GetAsync(string collection, string key, CancellationToken cancel = default) + { + DatabaseEntry? entry = await this._dbConnection.ReadAsync(collection, key, cancel); + if (entry.HasValue) + { + DatabaseEntry dbEntry = entry.Value; + return DataEntry.Create(dbEntry.Key, dbEntry.Value, ParseTimestamp(dbEntry.Timestamp)); + } + + return null; + } + + /// + public async Task> PutAsync(string collection, DataEntry data, CancellationToken cancel = default) + { + await this._dbConnection.InsertAsync(collection, data.Key, data.ValueString, ToTimestampString(data.Timestamp), cancel); + return data; + } + + /// + public Task RemoveAsync(string collection, string key, CancellationToken cancel = default) + { + return this._dbConnection.DeleteAsync(collection, key, cancel); + } + + /// + /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. + /// + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + this.Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + #region protected ================================================================================ + + /// + /// Performs dispose specifically on the SqliteConnection + /// + protected virtual void Dispose(bool disposing) + { + if (!this._disposedValue) + { + if (disposing) + { + this._dbConnection.Dispose(); + } + + this._disposedValue = true; + } + } + + /// + /// Constructor + /// + /// DB connection + protected SqliteDataStore(SqliteConnection dbConnection) + { + this._dbConnection = dbConnection; + } + + #endregion + + #region private ================================================================================ + + private readonly SqliteConnection _dbConnection; + private bool _disposedValue; + + /// + /// Convert timestamp to string + /// + private static string? ToTimestampString(DateTimeOffset? timestamp) + { + return timestamp?.ToString("u", CultureInfo.InvariantCulture); + } + + /// + /// Convert string to timestamp + /// + private static DateTimeOffset? ParseTimestamp(string? str) + { + if (!string.IsNullOrEmpty(str) + && DateTimeOffset.TryParse(str, CultureInfo.InvariantCulture, DateTimeStyles.AssumeUniversal, out DateTimeOffset timestamp)) + { + return timestamp; + } + + return null; + } + + #endregion +} diff --git a/dotnet/src/SemanticKernel.Skills/Skills.Memory.Sqlite/SqliteMemoryStore.cs b/dotnet/src/SemanticKernel.Skills/Skills.Memory.Sqlite/SqliteMemoryStore.cs index 298029a47826..7ffd7e1bdb9e 100644 --- a/dotnet/src/SemanticKernel.Skills/Skills.Memory.Sqlite/SqliteMemoryStore.cs +++ b/dotnet/src/SemanticKernel.Skills/Skills.Memory.Sqlite/SqliteMemoryStore.cs @@ -2,80 +2,145 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; +using System.Collections.ObjectModel; +using System.Data; +using System.Data.Common; using System.Globalization; +using System.Linq; using System.Runtime.CompilerServices; +using System.Text; using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Sqlite; +using Microsoft.SemanticKernel.AI.Embeddings; +using Microsoft.SemanticKernel.AI.Embeddings.VectorOperations; +using Microsoft.SemanticKernel.Diagnostics; +using Microsoft.SemanticKernel.Memory; +using Microsoft.SemanticKernel.Memory.Collections; using Microsoft.SemanticKernel.Memory.Storage; namespace Microsoft.SemanticKernel.Skills.Memory.Sqlite; - -/// -/// An implementation of backed by a SQLite database. -/// -/// The data is saved to a database file, specified in the constructor. -/// The data persists between subsequent instances. Only one instance may access the file at a time. -/// The caller is responsible for deleting the file. -/// The type of data to be stored in this data store. -public class SqliteDataStore : IDataStore, IDisposable +public class SqliteMemoryStore : IMemoryStore, IDisposable + where TEmbedding : unmanaged { - /// - /// Connect a Sqlite database - /// - /// Path to the database file. If file does not exist, it will be created. - /// Cancellation token - [SuppressMessage("Design", "CA1000:Do not declare static members on generic types", - Justification = "Static factory method used to ensure successful connection.")] - public static async Task> ConnectAsync(string filename, - CancellationToken cancel = default) + public SqliteMemoryStore(SqliteConnection dbConnection) { - SqliteConnection dbConnection = await Database.CreateConnectionAsync(filename, cancel); - return new SqliteDataStore(dbConnection); + this._dbConnection = dbConnection; + } + + /// + public IAsyncEnumerable<(IEmbeddingWithMetadata, double)> GetNearestMatchesAsync( + string collection, + Embedding embedding, + int limit = 1, + double minRelevanceScore = 0) + { + if (limit <= 0) + { + return AsyncEnumerable.Empty<(IEmbeddingWithMetadata, double)>(); + } + + IAsyncEnumerable>> asyncEmbeddingCollection = this.TryGetCollectionAsync(collection); + IEnumerable>> embeddingCollection = asyncEmbeddingCollection.ToEnumerable(); + + if (embeddingCollection == null || !embeddingCollection!.Any()) + { + return AsyncEnumerable.Empty<(IEmbeddingWithMetadata, double)>(); + } + + EmbeddingReadOnlySpan embeddingSpan = new(embedding.AsReadOnlySpan()); + + TopNCollection> embeddings = new(limit); + + foreach (var item in embeddingCollection!) + { + if (item.Value != null) + { + EmbeddingReadOnlySpan itemSpan = new(item.Value.Embedding.AsReadOnlySpan()); + double similarity = embeddingSpan.CosineSimilarity(itemSpan); + if (similarity >= minRelevanceScore) + { + embeddings.Add(new(item.Value, similarity)); + } + } + } + + embeddings.SortByScore(); + + return embeddings.Select(x => (x.Value, x.Score.Value)).ToAsyncEnumerable(); } /// - public IAsyncEnumerable GetCollectionsAsync(CancellationToken cancel = default) + public async IAsyncEnumerable GetCollectionsAsync([EnumeratorCancellation] CancellationToken cancel = default) { - return this._dbConnection.GetCollectionsAsync(cancel); + await this._dbConnection.OpenAsync(cancel); + + await foreach (var elem in this._dbConnection.GetCollectionsAsync(cancel)) + { + yield return elem; + } + + await this._dbConnection.CloseAsync(); } /// - public async IAsyncEnumerable> GetAllAsync(string collection, + public async IAsyncEnumerable>> GetAllAsync(string collection, [EnumeratorCancellation] CancellationToken cancel = default) { + await this._dbConnection.OpenAsync(cancel); + await foreach (DatabaseEntry dbEntry in this._dbConnection.ReadAllAsync(collection, cancel)) { - yield return DataEntry.Create(dbEntry.Key, dbEntry.Value, ParseTimestamp(dbEntry.Timestamp)); + var embedding = new Embedding(); + var val = (IEmbeddingWithMetadata)MemoryRecord.FromJson(dbEntry.Value, embedding); + yield return DataEntry.Create>(dbEntry.Key, val, ParseTimestamp(dbEntry.Timestamp)); } + + await this._dbConnection.CloseAsync(); } /// - public async Task?> GetAsync(string collection, string key, CancellationToken cancel = default) + public async Task>?> GetAsync(string collection, string key, CancellationToken cancel = default) { + await this._dbConnection.OpenAsync(cancel); + DatabaseEntry? entry = await this._dbConnection.ReadAsync(collection, key, cancel); + + await this._dbConnection.CloseAsync(); + if (entry.HasValue) { DatabaseEntry dbEntry = entry.Value; - return DataEntry.Create(dbEntry.Key, dbEntry.Value, ParseTimestamp(dbEntry.Timestamp)); + var embedding = new Embedding(); + var val = (IEmbeddingWithMetadata)MemoryRecord.FromJson(dbEntry.Value, embedding); + + return DataEntry.Create>(dbEntry.Key, val, ParseTimestamp(dbEntry.Timestamp)); } return null; } /// - public async Task> PutAsync(string collection, DataEntry data, CancellationToken cancel = default) + public async Task>> PutAsync(string collection, DataEntry> data, CancellationToken cancel = default) { - await this._dbConnection.InsertAsync(collection, data.Key, data.ValueString, ToTimestampString(data.Timestamp), cancel); + await this._dbConnection.OpenAsync(cancel); + + await this._dbConnection.UpsertAsync(collection, data.Key, JsonSerializer.Serialize(data.Value), ToTimestampString(data.Timestamp), cancel); + + await this._dbConnection.CloseAsync(); return data; } /// - public Task RemoveAsync(string collection, string key, CancellationToken cancel = default) + public async Task RemoveAsync(string collection, string key, CancellationToken cancel = default) { - return this._dbConnection.DeleteAsync(collection, key, cancel); + await this._dbConnection.OpenAsync(cancel); + + await this._dbConnection.DeleteAsync(collection, key, cancel); + + await this._dbConnection.CloseAsync(); + return; } /// @@ -90,6 +155,9 @@ public void Dispose() #region protected ================================================================================ + /// + /// Performs dispose specifically on the SqliteConnection + /// protected virtual void Dispose(bool disposing) { if (!this._disposedValue) @@ -103,43 +171,42 @@ protected virtual void Dispose(bool disposing) } } - #endregion - - #region private ================================================================================ - - private readonly SqliteConnection _dbConnection; - private bool _disposedValue; - /// - /// Constructor + /// Get all entries in the database that match the collectionName; Limited to the datatable initially supplied when openning the SqliteConnection. /// - /// DB connection - private SqliteDataStore(SqliteConnection dbConnection) + protected async IAsyncEnumerable>> TryGetCollectionAsync(string collectionName, [EnumeratorCancellation] CancellationToken cancel = default) { - this._dbConnection = dbConnection; - } + await this._dbConnection.OpenAsync(cancel); - // TODO: never used - private static string? ValueToString(TValue? value) - { - if (value != null) + await foreach (DatabaseEntry dbEntry in this._dbConnection.ReadAllAsync(collectionName, cancel)) { - if (typeof(TValue) == typeof(string)) - { - return value.ToString(); - } + var embedding = new Embedding(); + var val = (IEmbeddingWithMetadata)MemoryRecord.FromJson(dbEntry.Value, new Embedding()); - return JsonSerializer.Serialize(value); + yield return DataEntry.Create>(dbEntry.Key, val, ParseTimestamp(dbEntry.Timestamp)); } - return null; + await this._dbConnection.CloseAsync(); } + #endregion + + #region private ================================================================================ + + private readonly SqliteConnection _dbConnection; + private bool _disposedValue; + + /// + /// Convert string to timestamp + /// private static string? ToTimestampString(DateTimeOffset? timestamp) { return timestamp?.ToString("u", CultureInfo.InvariantCulture); } + /// + /// Convert timestamp to string + /// private static DateTimeOffset? ParseTimestamp(string? str) { if (!string.IsNullOrEmpty(str) @@ -151,5 +218,18 @@ private SqliteDataStore(SqliteConnection dbConnection) return null; } + /// + /// Calculates the cosine similarity between an and an + /// + /// The input to be compared. + /// The input to be compared. + /// A tuple consisting of the cosine similarity result. + private (IEmbeddingWithMetadata, double) PairEmbeddingWithSimilarity(Embedding embedding, + IEmbeddingWithMetadata embeddingWithData) + { + var similarity = embedding.Vector.ToArray().CosineSimilarity(embeddingWithData.Embedding.Vector.ToArray()); + return (embeddingWithData, similarity); + } + #endregion } diff --git a/dotnet/src/SemanticKernel/Memory/MemoryRecord.cs b/dotnet/src/SemanticKernel/Memory/MemoryRecord.cs index cfe32d229802..766afe201369 100644 --- a/dotnet/src/SemanticKernel/Memory/MemoryRecord.cs +++ b/dotnet/src/SemanticKernel/Memory/MemoryRecord.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Text.Json; +using System.Text.Json.Serialization; using Microsoft.SemanticKernel.AI.Embeddings; namespace Microsoft.SemanticKernel.Memory; @@ -13,11 +14,13 @@ public class MemoryRecord : IEmbeddingWithMetadata /// /// Source content embeddings. /// + [JsonPropertyName("embedding")] public Embedding Embedding { get; } /// /// Metadata associated with a Semantic Kernel memory. /// + [JsonPropertyName("metadata")] public MemoryRecordMetadata Metadata { get; } /// diff --git a/samples/dotnet/KernelHttpServer/KernelHttpServer.csproj b/samples/dotnet/KernelHttpServer/KernelHttpServer.csproj index 8bce105e9969..3311a8c69048 100644 --- a/samples/dotnet/KernelHttpServer/KernelHttpServer.csproj +++ b/samples/dotnet/KernelHttpServer/KernelHttpServer.csproj @@ -25,6 +25,7 @@ + diff --git a/samples/dotnet/KernelHttpServer/Program.cs b/samples/dotnet/KernelHttpServer/Program.cs index ba1be18b2e89..e8a3964b614b 100644 --- a/samples/dotnet/KernelHttpServer/Program.cs +++ b/samples/dotnet/KernelHttpServer/Program.cs @@ -1,11 +1,14 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Data.Common; using System.IO; using System.Text.Json; +using Microsoft.Data.Sqlite; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.SemanticKernel.Memory; +using Microsoft.SemanticKernel.Skills.Memory.Sqlite; namespace KernelHttpServer; @@ -13,6 +16,9 @@ public static class Program { public static void Main() { + string dbName = "SKDataBase.db"; + SqliteConnection dbConnection = new SqliteConnection($@"Data Source={dbName};"); + var host = new HostBuilder() .ConfigureFunctionsWorkerDefaults() .ConfigureAppConfiguration(configuration => @@ -24,7 +30,7 @@ public static void Main() }) .ConfigureServices(services => { - services.AddSingleton>(new VolatileMemoryStore()); + services.AddSingleton>(new SqliteMemoryStore(dbConnection)); // return JSON with expected lowercase naming services.Configure(options => @@ -35,5 +41,7 @@ public static void Main() .Build(); host.Run(); + + dbConnection.Dispose(); } }