From 36342b40cd48c66cf7f19a3e880e09bac8bab97c Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Fri, 14 Feb 2025 18:48:51 +0200 Subject: [PATCH] Change Sqlite connector to accept connection string instead of DbConnection Closes #10454 --- .../Step04_KernelFunctionStrategies.cs | 8 +- ...qliteVectorStoreRecordCollectionFactory.cs | 5 +- .../SqliteServiceCollectionExtensions.cs | 103 ++------- .../SqliteVectorStore.cs | 33 ++- ...liteVectorStoreCollectionCommandBuilder.cs | 44 ++-- .../SqliteVectorStoreRecordCollection.cs | 195 ++++++++++++------ .../Fakes/FakeDBConnection.cs | 32 --- .../Fakes/FakeDbCommand.cs | 45 ---- .../Fakes/FakeDbParameterCollection.cs | 105 ---------- .../SqliteServiceCollectionExtensionsTests.cs | 69 +------ ...ectorStoreCollectionCommandBuilderTests.cs | 25 ++- .../SqliteVectorStoreRecordCollectionTests.cs | 6 + .../SqliteVectorStoreTests.cs | 6 + .../SqliteServiceCollectionExtensionsTests.cs | 23 --- .../Memory/Sqlite/SqliteVectorStoreFixture.cs | 32 +-- .../SqliteVectorStoreRecordCollectionTests.cs | 7 +- .../Memory/Sqlite/SqliteVectorStoreTests.cs | 4 +- .../Support/SqliteTestStore.cs | 35 +--- 18 files changed, 252 insertions(+), 525 deletions(-) delete mode 100644 dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDBConnection.cs delete mode 100644 dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbCommand.cs delete mode 100644 dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbParameterCollection.cs diff --git a/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs b/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs index 4c7930bd2533..f924793951aa 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs @@ -70,11 +70,11 @@ public async Task UseKernelFunctionStrategiesWithAgentGroupChatAsync() Determine which participant takes the next turn in a conversation based on the the most recent participant. State only the name of the participant to take the next turn. No participant should take more than one turn in a row. - + Choose only from these participants: - {{{ReviewerName}}} - {{{CopyWriterName}}} - + Always follow these rules when selecting the next participant: - After {{{CopyWriterName}}}, it is {{{ReviewerName}}}'s turn. - After {{{ReviewerName}}}, it is {{{CopyWriterName}}}'s turn. @@ -133,9 +133,9 @@ No participant should take more than one turn in a row. chat.AddChatMessage(message); this.WriteAgentChatMessage(message); - await foreach (ChatMessageContent responese in chat.InvokeAsync()) + await foreach (ChatMessageContent response in chat.InvokeAsync()) { - this.WriteAgentChatMessage(responese); + this.WriteAgentChatMessage(response); } Console.WriteLine($"\n[IS COMPLETED: {chat.IsComplete}]"); diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/ISqliteVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/ISqliteVectorStoreRecordCollectionFactory.cs index 48bf1da53d2d..32f86679e1ec 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/ISqliteVectorStoreRecordCollectionFactory.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/ISqliteVectorStoreRecordCollectionFactory.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Data.Common; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.Sqlite; @@ -15,12 +14,12 @@ public interface ISqliteVectorStoreRecordCollectionFactory /// /// The data type of the record key. /// The data model to use for adding, updating and retrieving data from storage. - /// that will be used to manage the data in SQLite. + /// The connection string for the SQLite database represented by this . /// The name of the collection to connect to. /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. /// The new instance of . IVectorStoreRecordCollection CreateVectorStoreRecordCollection( - DbConnection connection, + string connectionString, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) where TKey : notnull; diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteServiceCollectionExtensions.cs index 9c962c0786d5..11c7ed589ba7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteServiceCollectionExtensions.cs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Data; +using System; using Microsoft.Data.Sqlite; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; @@ -22,29 +22,12 @@ public static class SqliteServiceCollectionExtensions /// Optional options to further configure the . /// An optional service id to use as the service key. /// Service collection. + [Obsolete("Use AddSqliteVectorStore with connectionString instead.", error: true)] public static IServiceCollection AddSqliteVectorStore( this IServiceCollection services, SqliteVectorStoreOptions? options = default, string? serviceId = default) - { - services.AddKeyedTransient( - serviceId, - (sp, obj) => - { - var connection = sp.GetRequiredService(); - - if (connection.State != ConnectionState.Open) - { - connection.Open(); - } - - var selectedOptions = options ?? sp.GetService(); - - return new SqliteVectorStore(connection, options); - }); - - return services; - } + => throw new InvalidOperationException("Use AddSqliteVectorStore with connectionString instead."); /// /// Register a SQLite with the specified service ID. @@ -60,24 +43,9 @@ public static IServiceCollection AddSqliteVectorStore( string connectionString, SqliteVectorStoreOptions? options = default, string? serviceId = default) - { - services.AddKeyedTransient( + => services.AddKeyedSingleton( serviceId, - (sp, obj) => - { - var connection = new SqliteConnection(connectionString); - var extensionName = GetExtensionName(options?.VectorSearchExtensionName); - - connection.Open(); - - connection.LoadExtension(extensionName); - - var selectedOptions = options ?? sp.GetService(); - return new SqliteVectorStore(connection, options); - }); - - return services; - } + (sp, _) => new SqliteVectorStore(connectionString, options ?? sp.GetService())); /// /// Register a SQLite and with the specified service ID @@ -91,33 +59,14 @@ public static IServiceCollection AddSqliteVectorStore( /// Optional options to further configure the . /// An optional service id to use as the service key. /// Service collection. + [Obsolete("Use AddSqliteVectorStoreRecordCollection with connectionString instead.", error: true)] public static IServiceCollection AddSqliteVectorStoreRecordCollection( this IServiceCollection services, string collectionName, SqliteVectorStoreRecordCollectionOptions? options = default, string? serviceId = default) where TKey : notnull - { - services.AddKeyedTransient>( - serviceId, - (sp, obj) => - { - var connection = sp.GetRequiredService(); - - if (connection.State != ConnectionState.Open) - { - connection.Open(); - } - - var selectedOptions = options ?? sp.GetService>(); - - return (new SqliteVectorStoreRecordCollection(connection, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; - }); - - AddVectorizedSearch(services, serviceId); - - return services; - } + => throw new InvalidOperationException("Use AddSqliteVectorStore with connectionString instead."); /// /// Register a SQLite and with the specified service ID. @@ -139,21 +88,14 @@ public static IServiceCollection AddSqliteVectorStoreRecordCollection>( + services.AddKeyedSingleton>( serviceId, - (sp, obj) => - { - var connection = new SqliteConnection(connectionString); - var extensionName = GetExtensionName(options?.VectorSearchExtensionName); - - connection.Open(); - - connection.LoadExtension(extensionName); - - var selectedOptions = options ?? sp.GetService>(); - - return (new SqliteVectorStoreRecordCollection(connection, collectionName, selectedOptions) as IVectorStoreRecordCollection)!; - }); + (sp, _) => ( + new SqliteVectorStoreRecordCollection( + connectionString, + collectionName, + options ?? sp.GetService>()) + as IVectorStoreRecordCollection)!); AddVectorizedSearch(services, serviceId); @@ -169,20 +111,7 @@ public static IServiceCollection AddSqliteVectorStoreRecordCollectionThe service id that the registrations should use. private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) where TKey : notnull - { - services.AddKeyedTransient>( + => services.AddKeyedSingleton>( serviceId, - (sp, obj) => - { - return sp.GetRequiredKeyedService>(serviceId); - }); - } - - /// - /// Returns extension name for vector search. - /// - private static string GetExtensionName(string? extensionName) - { - return !string.IsNullOrWhiteSpace(extensionName) ? extensionName! : SqliteConstants.VectorSearchExtensionName; - } + (sp, _) => sp.GetRequiredKeyedService>(serviceId)); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs index 86571536a5d5..3f246bd01e92 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs @@ -18,8 +18,8 @@ namespace Microsoft.SemanticKernel.Connectors.Sqlite; /// public sealed class SqliteVectorStore : IVectorStore { - /// that will be used to manage the data in SQLite. - private readonly DbConnection _connection; + /// The connection string for the SQLite database represented by this . + private readonly string _connectionString; /// Optional configuration options for this class. private readonly SqliteVectorStoreOptions _options; @@ -27,18 +27,27 @@ public sealed class SqliteVectorStore : IVectorStore /// /// Initializes a new instance of the class. /// - /// that will be used to manage the data in SQLite. + /// The connection string for the SQLite database represented by this . /// Optional configuration options for this class. - public SqliteVectorStore( - DbConnection connection, - SqliteVectorStoreOptions? options = default) + public SqliteVectorStore(string connectionString, SqliteVectorStoreOptions? options = default) { - Verify.NotNull(connection); + Verify.NotNull(connectionString); - this._connection = connection; + this._connectionString = connectionString; this._options = options ?? new(); } + /// + /// Initializes a new instance of the class. + /// + /// that will be used to manage the data in SQLite. + /// Optional configuration options for this class. + [Obsolete("Use the constructor that accepts a connection string instead.", error: true)] + public SqliteVectorStore( + DbConnection connection, + SqliteVectorStoreOptions? options = default) + => throw new InvalidOperationException("Use the constructor that accepts a connection string instead."); + /// public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) where TKey : notnull @@ -46,7 +55,7 @@ public IVectorStoreRecordCollection GetCollection( if (this._options.VectorStoreCollectionFactory is not null) { return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection( - this._connection, + this._connectionString, name, vectorStoreRecordDefinition); } @@ -57,7 +66,7 @@ public IVectorStoreRecordCollection GetCollection( } var recordCollection = new SqliteVectorStoreRecordCollection( - this._connection, + this._connectionString, name, new() { @@ -75,7 +84,9 @@ public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancel const string TablePropertyName = "name"; const string Query = $"SELECT {TablePropertyName} FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"; - using var command = this._connection.CreateCommand(); + using var connection = new SqliteConnection(this._connectionString); + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + using var command = connection.CreateCommand(); command.CommandText = Query; diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs index 837e3044ddc7..3f9690d146a5 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs @@ -17,21 +17,7 @@ namespace Microsoft.SemanticKernel.Connectors.Sqlite; [SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "User input is passed using command parameters.")] internal sealed class SqliteVectorStoreCollectionCommandBuilder { - /// that will be used to manage the data in SQLite. - private readonly DbConnection _connection; - - /// - /// Initializes a new instance of the class. - /// - /// that will be used to manage the data in SQLite. - public SqliteVectorStoreCollectionCommandBuilder(DbConnection connection) - { - Verify.NotNull(connection); - - this._connection = connection; - } - - public DbCommand BuildTableCountCommand(string tableName) + public DbCommand BuildTableCountCommand(SqliteConnection connection, string tableName) { Verify.NotNullOrWhiteSpace(tableName); @@ -40,7 +26,7 @@ public DbCommand BuildTableCountCommand(string tableName) var query = $"SELECT count(*) FROM {SystemTable} WHERE type='table' AND name={ParameterName};"; - var command = this._connection.CreateCommand(); + var command = connection.CreateCommand(); command.CommandText = query; @@ -49,7 +35,7 @@ public DbCommand BuildTableCountCommand(string tableName) return command; } - public DbCommand BuildCreateTableCommand(string tableName, IReadOnlyList columns, bool ifNotExists) + public DbCommand BuildCreateTableCommand(SqliteConnection connection, string tableName, IReadOnlyList columns, bool ifNotExists) { var builder = new StringBuilder(); @@ -58,7 +44,7 @@ public DbCommand BuildCreateTableCommand(string tableName, IReadOnlyList columns, bool ifNotExists, @@ -78,18 +65,18 @@ public DbCommand BuildCreateVirtualTableCommand( builder.AppendLine(string.Join(",\n", columns.Select(GetColumnDefinition))); builder.Append(");"); - var command = this._connection.CreateCommand(); + var command = connection.CreateCommand(); command.CommandText = builder.ToString(); return command; } - public DbCommand BuildDropTableCommand(string tableName) + public DbCommand BuildDropTableCommand(SqliteConnection connection, string tableName) { string query = $"DROP TABLE [{tableName}];"; - var command = this._connection.CreateCommand(); + var command = connection.CreateCommand(); command.CommandText = query; @@ -97,6 +84,7 @@ public DbCommand BuildDropTableCommand(string tableName) } public DbCommand BuildInsertCommand( + SqliteConnection connection, string tableName, string rowIdentifier, IReadOnlyList columnNames, @@ -104,7 +92,7 @@ public DbCommand BuildInsertCommand( bool replaceIfExists = false) { var builder = new StringBuilder(); - var command = this._connection.CreateCommand(); + var command = connection.CreateCommand(); var replacePlaceholder = replaceIfExists ? " OR REPLACE" : string.Empty; @@ -133,6 +121,7 @@ public DbCommand BuildInsertCommand( } public DbCommand BuildSelectCommand( + SqliteConnection connection, string tableName, IReadOnlyList columnNames, List conditions, @@ -140,7 +129,7 @@ public DbCommand BuildSelectCommand( { var builder = new StringBuilder(); - var (command, whereClause) = this.GetCommandWithWhereClause(conditions); + var (command, whereClause) = this.GetCommandWithWhereClause(connection, conditions); builder.AppendLine($"SELECT {string.Join(", ", columnNames)}"); builder.AppendLine($"FROM {tableName}"); @@ -154,6 +143,7 @@ public DbCommand BuildSelectCommand( } public DbCommand BuildSelectLeftJoinCommand( + SqliteConnection connection, string leftTable, string rightTable, string joinColumnName, @@ -172,7 +162,7 @@ .. leftTablePropertyNames.Select(property => $"{leftTable}.{property}"), .. rightTablePropertyNames.Select(property => $"{rightTable}.{property}"), ]; - var (command, whereClause) = this.GetCommandWithWhereClause(conditions, extraWhereFilter, extraParameters); + var (command, whereClause) = this.GetCommandWithWhereClause(connection, conditions, extraWhereFilter, extraParameters); builder.AppendLine($"SELECT {string.Join(", ", propertyNames)}"); builder.AppendLine($"FROM {leftTable} "); @@ -187,12 +177,13 @@ .. rightTablePropertyNames.Select(property => $"{rightTable}.{property}"), } public DbCommand BuildDeleteCommand( + SqliteConnection connection, string tableName, List conditions) { var builder = new StringBuilder(); - var (command, whereClause) = this.GetCommandWithWhereClause(conditions); + var (command, whereClause) = this.GetCommandWithWhereClause(connection, conditions); builder.AppendLine($"DELETE FROM [{tableName}]"); @@ -242,13 +233,14 @@ private static string GetColumnDefinition(SqliteColumn column) } private (DbCommand Command, string WhereClause) GetCommandWithWhereClause( + SqliteConnection connection, List conditions, string? extraWhereFilter = null, Dictionary? extraParameters = null) { const string WhereClauseOperator = " AND "; - var command = this._connection.CreateCommand(); + var command = connection.CreateCommand(); var whereClauseParts = new List(); foreach (var condition in conditions) diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs index 8ae095dd3bf0..091643fb8bfb 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs @@ -7,6 +7,7 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Data.Sqlite; using Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.Sqlite; @@ -24,8 +25,8 @@ public sealed class SqliteVectorStoreRecordCollection : /// The name of this database for telemetry purposes. private const string DatabaseName = "SQLite"; - /// that will be used to manage the data in SQLite. - private readonly DbConnection _connection; + /// The connection string for the SQLite database represented by this . + private readonly string _connectionString; /// Optional configuration options for this class. private readonly SqliteVectorStoreRecordCollectionOptions _options; @@ -63,30 +64,34 @@ public sealed class SqliteVectorStoreRecordCollection : /// Table name in SQLite for vector properties. private readonly string _vectorTableName; + /// The sqlite_vec extension name to use. + private readonly string _vectorSearchExtensionName; + /// public string CollectionName { get; } /// /// Initializes a new instance of the class. /// - /// that will be used to manage the data in SQLite. + /// The connection string for the SQLite database represented by this . /// The name of the collection/table that this will access. /// Optional configuration options for this class. public SqliteVectorStoreRecordCollection( - DbConnection connection, + string connectionString, string collectionName, SqliteVectorStoreRecordCollectionOptions? options = default) { // Verify. - Verify.NotNull(connection); + Verify.NotNull(connectionString); Verify.NotNullOrWhiteSpace(collectionName); VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.DictionaryCustomMapper is not null, SqliteConstants.SupportedKeyTypes); VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); // Assign. - this._connection = connection; + this._connectionString = connectionString; this.CollectionName = collectionName; this._options = options ?? new(); + this._vectorSearchExtensionName = this._options.VectorSearchExtensionName ?? SqliteConstants.VectorSearchExtensionName; this._dataTableName = this.CollectionName; this._vectorTableName = GetVectorTableName(this._dataTableName, this._options); @@ -111,7 +116,7 @@ public SqliteVectorStoreRecordCollection( this._mapper = this.InitializeMapper(); - this._commandBuilder = new SqliteVectorStoreCollectionCommandBuilder(this._connection); + this._commandBuilder = new SqliteVectorStoreCollectionCommandBuilder(); } /// @@ -119,7 +124,8 @@ public async Task CollectionExistsAsync(CancellationToken cancellationToke { const string OperationName = "TableCount"; - using var command = this._commandBuilder.BuildTableCountCommand(this._dataTableName); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + using var command = this._commandBuilder.BuildTableCountCommand(connection, this._dataTableName); var result = await this .RunOperationAsync(OperationName, () => command.ExecuteScalarAsync(cancellationToken)) @@ -131,25 +137,31 @@ public async Task CollectionExistsAsync(CancellationToken cancellationToke } /// - public Task CreateCollectionAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) { - return this.InternalCreateCollectionAsync(ifNotExists: false, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + await this.InternalCreateCollectionAsync(connection, ifNotExists: false, cancellationToken) + .ConfigureAwait(false); } /// - public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) { - return this.InternalCreateCollectionAsync(ifNotExists: true, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + await this.InternalCreateCollectionAsync(connection, ifNotExists: true, cancellationToken) + .ConfigureAwait(false); } /// public async Task DeleteCollectionAsync(CancellationToken cancellationToken = default) { - await this.DropTableAsync(this._dataTableName, cancellationToken).ConfigureAwait(false); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + + await this.DropTableAsync(connection, this._dataTableName, cancellationToken).ConfigureAwait(false); if (this._vectorPropertiesExist) { - await this.DropTableAsync(this._vectorTableName, cancellationToken).ConfigureAwait(false); + await this.DropTableAsync(connection, this._vectorTableName, cancellationToken).ConfigureAwait(false); } } @@ -227,39 +239,57 @@ public Task> VectorizedSearchAsync(TVector #region Implementation of IVectorStoreRecordCollection /// - public Task GetAsync(ulong key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async Task GetAsync(ulong key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { - return this.InternalGetAsync(key, options, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + return await this.InternalGetAsync(connection, key, options, cancellationToken).ConfigureAwait(false); } /// - public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetBatchAsync( + IEnumerable keys, + GetRecordOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) { - return this.InternalGetBatchAsync(keys, options, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + await foreach (var record in this.InternalGetBatchAsync(connection, keys, options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } } /// - public Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) { - return this.InternalUpsertAsync(record, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + return await this.InternalUpsertAsync(connection, record, cancellationToken).ConfigureAwait(false); } /// - public IAsyncEnumerable UpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken = default) + public async IAsyncEnumerable UpsertBatchAsync( + IEnumerable records, + [EnumeratorCancellation] CancellationToken cancellationToken = default) { - return this.InternalUpsertBatchAsync(records, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + await foreach (var record in this.InternalUpsertBatchAsync(connection, records, cancellationToken) + .ConfigureAwait(false)) + { + yield return record; + } } /// - public Task DeleteAsync(ulong key, CancellationToken cancellationToken = default) + public async Task DeleteAsync(ulong key, CancellationToken cancellationToken = default) { - return this.InternalDeleteAsync(key, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + await this.InternalDeleteAsync(connection, key, cancellationToken).ConfigureAwait(false); } /// - public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) { - return this.InternalDeleteBatchAsync(keys, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + await this.InternalDeleteBatchAsync(connection, keys, cancellationToken).ConfigureAwait(false); } #endregion @@ -267,45 +297,73 @@ public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancella #region Implementation of IVectorStoreRecordCollection /// - public Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) { - return this.InternalGetAsync(key, options, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + return await this.InternalGetAsync(connection, key, options, cancellationToken).ConfigureAwait(false); } /// - public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + public async IAsyncEnumerable GetBatchAsync( + IEnumerable keys, + GetRecordOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) { - return this.InternalGetBatchAsync(keys, options, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + await foreach (var record in this.InternalGetBatchAsync(connection, keys, options, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } } /// - Task IVectorStoreRecordCollection.UpsertAsync(TRecord record, CancellationToken cancellationToken) + async Task IVectorStoreRecordCollection.UpsertAsync(TRecord record, CancellationToken cancellationToken) { - return this.InternalUpsertAsync(record, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + return await this.InternalUpsertAsync(connection, record, cancellationToken) + .ConfigureAwait(false); } /// - IAsyncEnumerable IVectorStoreRecordCollection.UpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken) + async IAsyncEnumerable IVectorStoreRecordCollection.UpsertBatchAsync( + IEnumerable records, + [EnumeratorCancellation] CancellationToken cancellationToken) { - return this.InternalUpsertBatchAsync(records, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + await foreach (var record in this.InternalUpsertBatchAsync(connection, records, cancellationToken) + .ConfigureAwait(false)) + { + yield return record; + } } /// - public Task DeleteAsync(string key, CancellationToken cancellationToken = default) + public async Task DeleteAsync(string key, CancellationToken cancellationToken = default) { - return this.InternalDeleteAsync(key, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + await this.InternalDeleteAsync(connection, key, cancellationToken) + .ConfigureAwait(false); } /// - public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) + public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default) { - return this.InternalDeleteBatchAsync(keys, cancellationToken); + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); + await this.InternalDeleteBatchAsync(connection, keys, cancellationToken).ConfigureAwait(false); } #endregion #region private + private async ValueTask GetConnectionAsync(CancellationToken cancellationToken = default) + { + var connection = new SqliteConnection(this._connectionString); + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + connection.LoadExtension(this._vectorSearchExtensionName); + return connection; + } + private async IAsyncEnumerable> EnumerateAndMapSearchResultsAsync( List conditions, string? extraWhereFilter, @@ -326,7 +384,9 @@ private async IAsyncEnumerable> EnumerateAndMapSearc properties.AddRange(this._propertyReader.VectorProperties); } + using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false); using var command = this._commandBuilder.BuildSelectLeftJoinCommand( + connection, this._vectorTableName, this._dataTableName, this._propertyReader.KeyPropertyStoragePropertyName, @@ -356,13 +416,14 @@ private async IAsyncEnumerable> EnumerateAndMapSearc } } - private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken) + private Task InternalCreateCollectionAsync(SqliteConnection connection, bool ifNotExists, CancellationToken cancellationToken) { List dataTableColumns = SqliteVectorStoreRecordPropertyMapping.GetColumns( this._dataTableProperties.Value, this._propertyReader.StoragePropertyNamesMap); List tasks = [this.CreateTableAsync( + connection, this._dataTableName, dataTableColumns, ifNotExists, @@ -379,6 +440,7 @@ private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken c this._propertyReader.StoragePropertyNamesMap); tasks.Add(this.CreateVirtualTableAsync( + connection, this._vectorTableName, vectorTableColumns, ifNotExists, @@ -389,34 +451,35 @@ private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken c return Task.WhenAll(tasks); } - private Task CreateTableAsync(string tableName, List columns, bool ifNotExists, CancellationToken cancellationToken) + private Task CreateTableAsync(SqliteConnection connection, string tableName, List columns, bool ifNotExists, CancellationToken cancellationToken) { const string OperationName = "CreateTable"; - using var command = this._commandBuilder.BuildCreateTableCommand(tableName, columns, ifNotExists); + using var command = this._commandBuilder.BuildCreateTableCommand(connection, tableName, columns, ifNotExists); return this.RunOperationAsync(OperationName, () => command.ExecuteNonQueryAsync(cancellationToken)); } - private Task CreateVirtualTableAsync(string tableName, List columns, bool ifNotExists, string extensionName, CancellationToken cancellationToken) + private Task CreateVirtualTableAsync(SqliteConnection connection, string tableName, List columns, bool ifNotExists, string extensionName, CancellationToken cancellationToken) { const string OperationName = "CreateVirtualTable"; - using var command = this._commandBuilder.BuildCreateVirtualTableCommand(tableName, columns, ifNotExists, extensionName); + using var command = this._commandBuilder.BuildCreateVirtualTableCommand(connection, tableName, columns, ifNotExists, extensionName); return this.RunOperationAsync(OperationName, () => command.ExecuteNonQueryAsync(cancellationToken)); } - private Task DropTableAsync(string tableName, CancellationToken cancellationToken) + private Task DropTableAsync(SqliteConnection connection, string tableName, CancellationToken cancellationToken) { const string OperationName = "DropTable"; - using var command = this._commandBuilder.BuildDropTableCommand(tableName); + using var command = this._commandBuilder.BuildDropTableCommand(connection, tableName); return this.RunOperationAsync(OperationName, () => command.ExecuteNonQueryAsync(cancellationToken)); } private async Task InternalGetAsync( + SqliteConnection connection, TKey key, GetRecordOptions? options, CancellationToken cancellationToken) @@ -428,12 +491,13 @@ private Task DropTableAsync(string tableName, CancellationToken cancellatio TableName = this._dataTableName }; - return await this.InternalGetBatchAsync(condition, options, cancellationToken) + return await this.InternalGetBatchAsync(connection, condition, options, cancellationToken) .FirstOrDefaultAsync(cancellationToken) .ConfigureAwait(false); } private IAsyncEnumerable InternalGetBatchAsync( + SqliteConnection connection, IEnumerable keys, GetRecordOptions? options, CancellationToken cancellationToken) @@ -449,10 +513,11 @@ private IAsyncEnumerable InternalGetBatchAsync( TableName = this._dataTableName }; - return this.InternalGetBatchAsync(condition, options, cancellationToken); + return this.InternalGetBatchAsync(connection, condition, options, cancellationToken); } private async IAsyncEnumerable InternalGetBatchAsync( + SqliteConnection connection, SqliteWhereCondition condition, GetRecordOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) @@ -467,6 +532,7 @@ private async IAsyncEnumerable InternalGetBatchAsync( if (includeVectors) { command = this._commandBuilder.BuildSelectLeftJoinCommand( + connection, this._dataTableName, this._vectorTableName, this._propertyReader.KeyPropertyStoragePropertyName, @@ -479,6 +545,7 @@ private async IAsyncEnumerable InternalGetBatchAsync( else { command = this._commandBuilder.BuildSelectCommand( + connection, this._dataTableName, this._dataTableStoragePropertyNames.Value, [condition]); @@ -499,7 +566,7 @@ private async IAsyncEnumerable InternalGetBatchAsync( } } - private async Task InternalUpsertAsync(TRecord record, CancellationToken cancellationToken) + private async Task InternalUpsertAsync(SqliteConnection connection, TRecord record, CancellationToken cancellationToken) { const string OperationName = "Upsert"; @@ -515,14 +582,14 @@ private async Task InternalUpsertAsync(TRecord record, CancellationT var condition = new SqliteWhereEqualsCondition(this._propertyReader.KeyPropertyStoragePropertyName, key); - var upsertedRecordKey = await this.InternalUpsertBatchAsync([storageModel], condition, cancellationToken) + var upsertedRecordKey = await this.InternalUpsertBatchAsync(connection, [storageModel], condition, cancellationToken) .FirstOrDefaultAsync(cancellationToken) .ConfigureAwait(false); return upsertedRecordKey ?? throw new VectorStoreOperationException("Error occurred during upsert operation."); } - private IAsyncEnumerable InternalUpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken) + private IAsyncEnumerable InternalUpsertBatchAsync(SqliteConnection connection, IEnumerable records, CancellationToken cancellationToken) { const string OperationName = "UpsertBatch"; @@ -536,10 +603,11 @@ private IAsyncEnumerable InternalUpsertBatchAsync(IEnumerable(storageModels, condition, cancellationToken); + return this.InternalUpsertBatchAsync(connection, storageModels, condition, cancellationToken); } private async IAsyncEnumerable InternalUpsertBatchAsync( + SqliteConnection connection, List> storageModels, SqliteWhereCondition condition, [EnumeratorCancellation] CancellationToken cancellationToken) @@ -552,12 +620,14 @@ private async IAsyncEnumerable InternalUpsertBatchAsync( // Deleting vector records first since current version of vector search extension // doesn't support Upsert operation, only Delete/Insert. using var vectorDeleteCommand = this._commandBuilder.BuildDeleteCommand( + connection, this._vectorTableName, [condition]); await vectorDeleteCommand.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); using var vectorInsertCommand = this._commandBuilder.BuildInsertCommand( + connection, this._vectorTableName, this._propertyReader.KeyPropertyStoragePropertyName, this._vectorTableStoragePropertyNames.Value, @@ -567,11 +637,12 @@ private async IAsyncEnumerable InternalUpsertBatchAsync( } using var dataCommand = this._commandBuilder.BuildInsertCommand( - this._dataTableName, - this._propertyReader.KeyPropertyStoragePropertyName, - this._dataTableStoragePropertyNames.Value, - storageModels, - replaceIfExists: true); + connection, + this._dataTableName, + this._propertyReader.KeyPropertyStoragePropertyName, + this._dataTableStoragePropertyNames.Value, + storageModels, + replaceIfExists: true); using var reader = await dataCommand.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); @@ -588,16 +659,16 @@ private async IAsyncEnumerable InternalUpsertBatchAsync( } } - private Task InternalDeleteAsync(TKey key, CancellationToken cancellationToken) + private Task InternalDeleteAsync(SqliteConnection connection, TKey key, CancellationToken cancellationToken) { Verify.NotNull(key); var condition = new SqliteWhereEqualsCondition(this._propertyReader.KeyPropertyStoragePropertyName, key); - return this.InternalDeleteBatchAsync(condition, cancellationToken); + return this.InternalDeleteBatchAsync(connection, condition, cancellationToken); } - private Task InternalDeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken) + private Task InternalDeleteBatchAsync(SqliteConnection connection, IEnumerable keys, CancellationToken cancellationToken) { Verify.NotNull(keys); @@ -609,10 +680,10 @@ private Task InternalDeleteBatchAsync(IEnumerable keys, Cancellation this._propertyReader.KeyPropertyStoragePropertyName, keysList); - return this.InternalDeleteBatchAsync(condition, cancellationToken); + return this.InternalDeleteBatchAsync(connection, condition, cancellationToken); } - private Task InternalDeleteBatchAsync(SqliteWhereCondition condition, CancellationToken cancellationToken) + private Task InternalDeleteBatchAsync(SqliteConnection connection, SqliteWhereCondition condition, CancellationToken cancellationToken) { const string OperationName = "Delete"; @@ -621,6 +692,7 @@ private Task InternalDeleteBatchAsync(SqliteWhereCondition condition, Cancellati if (this._vectorPropertiesExist) { using var vectorCommand = this._commandBuilder.BuildDeleteCommand( + connection, this._vectorTableName, [condition]); @@ -628,6 +700,7 @@ private Task InternalDeleteBatchAsync(SqliteWhereCondition condition, Cancellati } using var dataCommand = this._commandBuilder.BuildDeleteCommand( + connection, this._dataTableName, [condition]); diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDBConnection.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDBConnection.cs deleted file mode 100644 index 7c318e1ef413..000000000000 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDBConnection.cs +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Data; -using System.Data.Common; - -namespace SemanticKernel.Connectors.Sqlite.UnitTests; - -#pragma warning disable CS8618, CS8765 - -internal sealed class FakeDBConnection(DbCommand command) : DbConnection -{ - public override string ConnectionString { get; set; } - - public override string Database => throw new NotImplementedException(); - - public override string DataSource => throw new NotImplementedException(); - - public override string ServerVersion => throw new NotImplementedException(); - - public override ConnectionState State => throw new NotImplementedException(); - - public override void ChangeDatabase(string databaseName) => throw new NotImplementedException(); - - public override void Close() => throw new NotImplementedException(); - - public override void Open() => throw new NotImplementedException(); - - protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => throw new NotImplementedException(); - - protected override DbCommand CreateDbCommand() => command; -} diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbCommand.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbCommand.cs deleted file mode 100644 index df6062d9a4c1..000000000000 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbCommand.cs +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Data; -using System.Data.Common; - -namespace SemanticKernel.Connectors.Sqlite.UnitTests; - -#pragma warning disable CS8618, CS8765 - -internal sealed class FakeDbCommand( - DbDataReader? dataReader = null, - object? scalarResult = null) : DbCommand -{ - public int ExecuteNonQueryCallCount { get; private set; } = 0; - - private readonly FakeDbParameterCollection _parameterCollection = []; - - public override string CommandText { get; set; } - public override int CommandTimeout { get; set; } - public override CommandType CommandType { get; set; } - public override bool DesignTimeVisible { get; set; } - public override UpdateRowSource UpdatedRowSource { get; set; } - protected override DbConnection? DbConnection { get; set; } - - protected override DbParameterCollection DbParameterCollection => this._parameterCollection; - - protected override DbTransaction? DbTransaction { get; set; } - - public override void Cancel() => throw new NotImplementedException(); - - public override int ExecuteNonQuery() - { - this.ExecuteNonQueryCallCount++; - return 0; - } - - public override object? ExecuteScalar() => scalarResult; - - public override void Prepare() => throw new NotImplementedException(); - - protected override DbParameter CreateDbParameter() => throw new NotImplementedException(); - - protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) => dataReader ?? throw new NotImplementedException(); -} diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbParameterCollection.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbParameterCollection.cs deleted file mode 100644 index 246b97a3360b..000000000000 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbParameterCollection.cs +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections; -using System.Collections.Generic; -using System.Data.Common; - -namespace SemanticKernel.Connectors.Sqlite.UnitTests; - -#pragma warning disable CA1812 - -internal sealed class FakeDbParameterCollection : DbParameterCollection -{ - private readonly List _parameters = []; - - public override int Count => this._parameters.Count; - - public override object SyncRoot => throw new NotImplementedException(); - - public override int Add(object value) - { - this._parameters.Add(value); - return default; - } - - public override void AddRange(Array values) - { - this._parameters.AddRange([.. values]); - } - - public override void Clear() - { - this._parameters.Clear(); - } - - public override bool Contains(object value) - { - return this._parameters.Contains(value); - } - - public override bool Contains(string value) - { - return this._parameters.Contains(value); - } - - public override void CopyTo(Array array, int index) - { - this._parameters.CopyTo([.. array], index); - } - - public override IEnumerator GetEnumerator() - { - return this._parameters.GetEnumerator(); - } - - public override int IndexOf(object value) - { - return this._parameters.IndexOf(value); - } - - public override int IndexOf(string parameterName) - { - return this._parameters.IndexOf(parameterName); - } - - public override void Insert(int index, object value) - { - this._parameters.Insert(index, value); - } - - public override void Remove(object value) - { - this._parameters.Remove(value); - } - - public override void RemoveAt(int index) - { - this._parameters.RemoveAt(index); - } - - public override void RemoveAt(string parameterName) - { - throw new NotImplementedException(); - } - - protected override DbParameter GetParameter(int index) - { - return (this._parameters[index] as DbParameter)!; - } - - protected override DbParameter GetParameter(string parameterName) - { - throw new NotImplementedException(); - } - - protected override void SetParameter(int index, DbParameter value) - { - this._parameters[index] = value; - } - - protected override void SetParameter(string parameterName, DbParameter value) - { - throw new NotImplementedException(); - } -} diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteServiceCollectionExtensionsTests.cs index 69488cf4d8d4..e7f78e388c02 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteServiceCollectionExtensionsTests.cs @@ -1,12 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Data; -using Microsoft.Data.Sqlite; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.Sqlite; -using Moq; using Xunit; namespace SemanticKernel.Connectors.Sqlite.UnitTests; @@ -18,23 +15,11 @@ public sealed class SqliteServiceCollectionExtensionsTests { private readonly IServiceCollection _serviceCollection = new ServiceCollection(); - [Theory] - [InlineData(ConnectionState.Open)] - [InlineData(ConnectionState.Closed)] - public void AddVectorStoreRegistersClass(ConnectionState connectionState) + [Fact] + public void AddVectorStoreRegistersClass() { - // Arrange - var expectedOpenCalls = connectionState == ConnectionState.Closed ? 1 : 0; - - var mockConnection = new Mock(); - - mockConnection.Setup(l => l.State).Returns(connectionState); - mockConnection.Setup(l => l.Open()); - - this._serviceCollection.AddTransient((_) => mockConnection.Object); - // Act - this._serviceCollection.AddSqliteVectorStore(); + this._serviceCollection.AddSqliteVectorStore("Data Source=:memory:"); var serviceProvider = this._serviceCollection.BuildServiceProvider(); var vectorStore = serviceProvider.GetRequiredService(); @@ -42,30 +27,13 @@ public void AddVectorStoreRegistersClass(ConnectionState connectionState) // Assert Assert.NotNull(vectorStore); Assert.IsType(vectorStore); - - mockConnection.Verify(l => l.Open(), Times.Exactly(expectedOpenCalls)); } - [Theory] - [InlineData(ConnectionState.Open)] - [InlineData(ConnectionState.Closed)] - public void AddVectorStoreRecordCollectionWithStringKeyRegistersClass(ConnectionState connectionState) + [Fact] + public void AddVectorStoreRecordCollectionWithStringKeyRegistersClass() { - // Arrange - var expectedOpenCalls = connectionState == ConnectionState.Closed ? 1 : 0; - - var mockConnection = new Mock(); - - mockConnection.SetupSequence(l => l.State) - .Returns(connectionState) - .Returns(ConnectionState.Open); - - mockConnection.Setup(l => l.Open()); - - this._serviceCollection.AddTransient((_) => mockConnection.Object); - // Act - this._serviceCollection.AddSqliteVectorStoreRecordCollection("testcollection"); + this._serviceCollection.AddSqliteVectorStoreRecordCollection("testcollection", "Data Source=:memory:"); var serviceProvider = this._serviceCollection.BuildServiceProvider(); @@ -77,30 +45,13 @@ public void AddVectorStoreRecordCollectionWithStringKeyRegistersClass(Connection var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); Assert.IsType>(vectorizedSearch); - - mockConnection.Verify(l => l.Open(), Times.Exactly(expectedOpenCalls)); } - [Theory] - [InlineData(ConnectionState.Open)] - [InlineData(ConnectionState.Closed)] - public void AddVectorStoreRecordCollectionWithNumericKeyRegistersClass(ConnectionState connectionState) + [Fact] + public void AddVectorStoreRecordCollectionWithNumericKeyRegistersClass() { - // Arrange - var expectedOpenCalls = connectionState == ConnectionState.Closed ? 1 : 0; - - var mockConnection = new Mock(); - - mockConnection.SetupSequence(l => l.State) - .Returns(connectionState) - .Returns(ConnectionState.Open); - - mockConnection.Setup(l => l.Open()); - - this._serviceCollection.AddTransient((_) => mockConnection.Object); - // Act - this._serviceCollection.AddSqliteVectorStoreRecordCollection("testcollection"); + this._serviceCollection.AddSqliteVectorStoreRecordCollection("testcollection", "Data Source=:memory:"); var serviceProvider = this._serviceCollection.BuildServiceProvider(); @@ -112,8 +63,6 @@ public void AddVectorStoreRecordCollectionWithNumericKeyRegistersClass(Connectio var vectorizedSearch = serviceProvider.GetRequiredService>(); Assert.NotNull(vectorizedSearch); Assert.IsType>(vectorizedSearch); - - mockConnection.Verify(l => l.Open(), Times.Exactly(expectedOpenCalls)); } #region private diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs index 370756cb4344..af04eae7a26e 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using Microsoft.Data.Sqlite; using Microsoft.SemanticKernel.Connectors.Sqlite; using Xunit; @@ -12,15 +13,15 @@ namespace SemanticKernel.Connectors.Sqlite.UnitTests; /// public sealed class SqliteVectorStoreCollectionCommandBuilderTests : IDisposable { - private readonly FakeDbCommand _command; - private readonly FakeDBConnection _connection; + private readonly SqliteCommand _command; + private readonly SqliteConnection _connection; private readonly SqliteVectorStoreCollectionCommandBuilder _commandBuilder; public SqliteVectorStoreCollectionCommandBuilderTests() { - this._command = new(); - this._connection = new(this._command); - this._commandBuilder = new(this._connection); + this._command = new() { Connection = this._connection }; + this._connection = new(); + this._commandBuilder = new(); } [Fact] @@ -30,7 +31,7 @@ public void ItBuildsTableCountCommand() const string TableName = "TestTable"; // Act - var command = this._commandBuilder.BuildTableCountCommand(TableName); + var command = this._commandBuilder.BuildTableCountCommand(this._connection, TableName); // Assert Assert.Equal("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=@tableName;", command.CommandText); @@ -53,7 +54,7 @@ public void ItBuildsCreateTableCommand(bool ifNotExists) }; // Act - var command = this._commandBuilder.BuildCreateTableCommand(TableName, columns, ifNotExists); + var command = this._commandBuilder.BuildCreateTableCommand(this._connection, TableName, columns, ifNotExists); // Assert Assert.Contains("CREATE TABLE", command.CommandText); @@ -81,7 +82,7 @@ public void ItBuildsCreateVirtualTableCommand(bool ifNotExists) }; // Act - var command = this._commandBuilder.BuildCreateVirtualTableCommand(TableName, columns, ifNotExists, ExtensionName); + var command = this._commandBuilder.BuildCreateVirtualTableCommand(this._connection, TableName, columns, ifNotExists, ExtensionName); // Assert Assert.Contains("CREATE VIRTUAL TABLE", command.CommandText); @@ -101,7 +102,7 @@ public void ItBuildsDropTableCommand() const string TableName = "TestTable"; // Act - var command = this._commandBuilder.BuildDropTableCommand(TableName); + var command = this._commandBuilder.BuildDropTableCommand(this._connection, TableName); // Assert Assert.Equal("DROP TABLE [TestTable];", command.CommandText); @@ -125,6 +126,7 @@ public void ItBuildsInsertCommand(bool replaceIfExists) // Act var command = this._commandBuilder.BuildInsertCommand( + this._connection, TableName, RowIdentifier, columnNames, @@ -181,7 +183,7 @@ public void ItBuildsSelectCommand(string? orderByPropertyName) }; // Act - var command = this._commandBuilder.BuildSelectCommand(TableName, columnNames, conditions, orderByPropertyName); + var command = this._commandBuilder.BuildSelectCommand(this._connection, TableName, columnNames, conditions, orderByPropertyName); // Assert Assert.Contains("SELECT Id, Name, Age, Address", command.CommandText); @@ -227,6 +229,7 @@ public void ItBuildsSelectLeftJoinCommand(string? orderByPropertyName) // Act var command = this._commandBuilder.BuildSelectLeftJoinCommand( + this._connection, LeftTable, RightTable, JoinColumnName, @@ -274,7 +277,7 @@ public void ItBuildsDeleteCommand() }; // Act - var command = this._commandBuilder.BuildDeleteCommand(TableName, conditions); + var command = this._commandBuilder.BuildDeleteCommand(this._connection, TableName, conditions); // Assert Assert.Contains("DELETE FROM [TestTable]", command.CommandText); diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordCollectionTests.cs index 631bf6cebf3d..59cc3c3401e4 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordCollectionTests.cs @@ -1,5 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +// TODO: Reimplement these as integration tests, #10464 + +#if DISABLED + using System; using System.Collections.Generic; using System.Data.Common; @@ -400,3 +404,5 @@ private sealed class TestRecordWithoutVectorProperty #endregion } + +#endif diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreTests.cs index c6cf80e8b085..71656fe87a0d 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreTests.cs @@ -1,5 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +// TODO: Reimplement these as integration tests, #10464 + +#if DISABLED + using System; using System.Data.Common; using System.Linq; @@ -102,3 +106,5 @@ public async Task ListCollectionNamesReturnsCollectionNamesAsync() Assert.Contains("collection2", collections); } } + +#endif diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteServiceCollectionExtensionsTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteServiceCollectionExtensionsTests.cs index 2e3e6b32fe52..bfded601d8ec 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteServiceCollectionExtensionsTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteServiceCollectionExtensionsTests.cs @@ -69,29 +69,6 @@ public void AddVectorStoreRecordCollectionWithNumericKeyAndSqliteConnectionRegis Assert.IsType>(vectorizedSearch); } - [Fact(Skip = SkipReason)] - public void ItClosesConnectionWhenDIServiceIsDisposed() - { - // Act - using var connection = new SqliteConnection("Data Source=:memory:"); - - this._serviceCollection.AddTransient(_ => connection); - - this._serviceCollection.AddSqliteVectorStore(); - - var serviceProvider = this._serviceCollection.BuildServiceProvider(); - - using (var scope = serviceProvider.CreateScope()) - { - scope.ServiceProvider.GetRequiredService(); - - Assert.Equal(ConnectionState.Open, connection.State); - } - - // Assert - Assert.Equal(ConnectionState.Closed, connection.State); - } - #region private #pragma warning disable CA1812 // Avoid uninstantiated internal classes diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreFixture.cs index 6f07f20ddf67..c3a702c5a7c0 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreFixture.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.IO; using System.Threading.Tasks; using Microsoft.Data.Sqlite; using Microsoft.SemanticKernel.Connectors.Sqlite; @@ -8,43 +9,22 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Sqlite; -public class SqliteVectorStoreFixture : IAsyncLifetime, IDisposable +public class SqliteVectorStoreFixture : IDisposable { - /// - /// SQLite extension name for vector search. - /// More information here: . - /// - private const string VectorSearchExtensionName = "vec0"; + private readonly string _databasePath = Path.GetTempFileName(); - public SqliteConnection Connection { get; } - - public SqliteVectorStoreFixture() - { - this.Connection = new SqliteConnection("Data Source=:memory:"); - } + public string ConnectionString => $"Data Source={this._databasePath}"; public SqliteVectorStoreRecordCollection GetCollection( string collectionName, SqliteVectorStoreRecordCollectionOptions? options = default) { return new SqliteVectorStoreRecordCollection( - this.Connection, + this.ConnectionString, collectionName, options); } - public Task DisposeAsync() - { - return Task.CompletedTask; - } - - public async Task InitializeAsync() - { - await this.Connection.OpenAsync(); - - this.Connection.LoadExtension(VectorSearchExtensionName); - } - public void Dispose() { this.Dispose(true); @@ -55,7 +35,7 @@ protected virtual void Dispose(bool disposing) { if (disposing) { - this.Connection.Dispose(); + File.Delete(this._databasePath); } } } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs index c0dbb5fcf680..77b97a0a4428 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Runtime.InteropServices; using System.Threading.Tasks; +using Microsoft.Data.Sqlite; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Sqlite; using Xunit; @@ -245,7 +246,9 @@ public async Task ItCanGetExistingRecordAsync(bool includeVectors) var record = CreateTestHotel(HotelId); - var commandData = fixture.Connection.CreateCommand(); + using var connection = new SqliteConnection(fixture.ConnectionString); + await connection.OpenAsync(); + var commandData = connection.CreateCommand(); commandData.CommandText = $"INSERT INTO {collectionName} " + @@ -262,7 +265,7 @@ public async Task ItCanGetExistingRecordAsync(bool includeVectors) if (includeVectors) { - var commandVector = fixture.Connection.CreateCommand(); + var commandVector = connection.CreateCommand(); commandVector.CommandText = $"INSERT INTO vec_{collectionName} " + diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreTests.cs index 8a173250f7fe..6eca22778b02 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreTests.cs @@ -17,7 +17,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Sqlite; [Collection("SqliteVectorStoreCollection")] [DisableVectorStoreTests(Skip = "SQLite vector search extension is required")] public sealed class SqliteVectorStoreTests(SqliteVectorStoreFixture fixture) - : BaseVectorStoreTests>(new SqliteVectorStore(fixture.Connection!)) + : BaseVectorStoreTests>(new SqliteVectorStore(fixture.ConnectionString)) { [VectorStoreFact] public async Task ItCanGetAListOfExistingCollectionNamesWhenRegisteredWithDIAsync() @@ -25,7 +25,7 @@ public async Task ItCanGetAListOfExistingCollectionNamesWhenRegisteredWithDIAsyn // Arrange var serviceCollection = new ServiceCollection(); - serviceCollection.AddSqliteVectorStore(connectionString: "Data Source=:memory:"); + serviceCollection.AddSqliteVectorStore(fixture.ConnectionString); var provider = serviceCollection.BuildServiceProvider(); diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs index 526eeac3b2d8..9b025c66610f 100644 --- a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs @@ -1,21 +1,16 @@ // Copyright (c) Microsoft. All rights reserved. -using Microsoft.Data.Sqlite; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Sqlite; using VectorDataSpecificationTests.Support; namespace SqliteIntegrationTests.Support; -#pragma warning disable CA1001 // Type owns disposable fields (_connection) but is not disposable - internal sealed class SqliteTestStore : TestStore { - public static SqliteTestStore Instance { get; } = new(); + private string? _databasePath; - private SqliteConnection? _connection; - public SqliteConnection Connection - => this._connection ?? throw new InvalidOperationException("Call InitializeAsync() first"); + public static SqliteTestStore Instance { get; } = new(); private SqliteVectorStore? _defaultVectorStore; public override IVectorStore DefaultVectorStore @@ -25,31 +20,17 @@ private SqliteTestStore() { } - protected override async Task StartAsync() + protected override Task StartAsync() { - this._connection = new SqliteConnection("Data Source=:memory:"); - - await this.Connection.OpenAsync(); - - if (!SqliteTestEnvironment.TryLoadSqliteVec(this.Connection)) - { - this.Connection.Dispose(); - - // Note that we ignore sqlite_vec loading failures; the tests are decorated with [SqliteVecRequired], which causes - // them to be skipped if sqlite_vec isn't installed (better than an exception triggering failure here) - } - - this._defaultVectorStore = new SqliteVectorStore(this.Connection); + this._databasePath = Path.GetTempFileName(); + this._defaultVectorStore = new SqliteVectorStore($"Data Source={this._databasePath}"); + return Task.CompletedTask; } -#if NET8_0_OR_GREATER - protected override async Task StopAsync() - => await this.Connection.DisposeAsync(); -#else protected override Task StopAsync() { - this.Connection.Dispose(); + File.Delete(this._databasePath!); + this._databasePath = null; return Task.CompletedTask; } -#endif }