diff --git a/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs b/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs
index 4c7930bd2533..f924793951aa 100644
--- a/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs
+++ b/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs
@@ -70,11 +70,11 @@ public async Task UseKernelFunctionStrategiesWithAgentGroupChatAsync()
Determine which participant takes the next turn in a conversation based on the the most recent participant.
State only the name of the participant to take the next turn.
No participant should take more than one turn in a row.
-
+
Choose only from these participants:
- {{{ReviewerName}}}
- {{{CopyWriterName}}}
-
+
Always follow these rules when selecting the next participant:
- After {{{CopyWriterName}}}, it is {{{ReviewerName}}}'s turn.
- After {{{ReviewerName}}}, it is {{{CopyWriterName}}}'s turn.
@@ -133,9 +133,9 @@ No participant should take more than one turn in a row.
chat.AddChatMessage(message);
this.WriteAgentChatMessage(message);
- await foreach (ChatMessageContent responese in chat.InvokeAsync())
+ await foreach (ChatMessageContent response in chat.InvokeAsync())
{
- this.WriteAgentChatMessage(responese);
+ this.WriteAgentChatMessage(response);
}
Console.WriteLine($"\n[IS COMPLETED: {chat.IsComplete}]");
diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/ISqliteVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/ISqliteVectorStoreRecordCollectionFactory.cs
index 48bf1da53d2d..32f86679e1ec 100644
--- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/ISqliteVectorStoreRecordCollectionFactory.cs
+++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/ISqliteVectorStoreRecordCollectionFactory.cs
@@ -1,6 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
-using System.Data.Common;
using Microsoft.Extensions.VectorData;
namespace Microsoft.SemanticKernel.Connectors.Sqlite;
@@ -15,12 +14,12 @@ public interface ISqliteVectorStoreRecordCollectionFactory
///
/// The data type of the record key.
/// The data model to use for adding, updating and retrieving data from storage.
- /// that will be used to manage the data in SQLite.
+ /// The connection string for the SQLite database represented by this .
/// The name of the collection to connect to.
/// An optional record definition that defines the schema of the record type. If not present, attributes on will be used.
/// The new instance of .
IVectorStoreRecordCollection CreateVectorStoreRecordCollection(
- DbConnection connection,
+ string connectionString,
string name,
VectorStoreRecordDefinition? vectorStoreRecordDefinition)
where TKey : notnull;
diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteServiceCollectionExtensions.cs
index 9c962c0786d5..11c7ed589ba7 100644
--- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteServiceCollectionExtensions.cs
+++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteServiceCollectionExtensions.cs
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
-using System.Data;
+using System;
using Microsoft.Data.Sqlite;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.VectorData;
@@ -22,29 +22,12 @@ public static class SqliteServiceCollectionExtensions
/// Optional options to further configure the .
/// An optional service id to use as the service key.
/// Service collection.
+ [Obsolete("Use AddSqliteVectorStore with connectionString instead.", error: true)]
public static IServiceCollection AddSqliteVectorStore(
this IServiceCollection services,
SqliteVectorStoreOptions? options = default,
string? serviceId = default)
- {
- services.AddKeyedTransient(
- serviceId,
- (sp, obj) =>
- {
- var connection = sp.GetRequiredService();
-
- if (connection.State != ConnectionState.Open)
- {
- connection.Open();
- }
-
- var selectedOptions = options ?? sp.GetService();
-
- return new SqliteVectorStore(connection, options);
- });
-
- return services;
- }
+ => throw new InvalidOperationException("Use AddSqliteVectorStore with connectionString instead.");
///
/// Register a SQLite with the specified service ID.
@@ -60,24 +43,9 @@ public static IServiceCollection AddSqliteVectorStore(
string connectionString,
SqliteVectorStoreOptions? options = default,
string? serviceId = default)
- {
- services.AddKeyedTransient(
+ => services.AddKeyedSingleton(
serviceId,
- (sp, obj) =>
- {
- var connection = new SqliteConnection(connectionString);
- var extensionName = GetExtensionName(options?.VectorSearchExtensionName);
-
- connection.Open();
-
- connection.LoadExtension(extensionName);
-
- var selectedOptions = options ?? sp.GetService();
- return new SqliteVectorStore(connection, options);
- });
-
- return services;
- }
+ (sp, _) => new SqliteVectorStore(connectionString, options ?? sp.GetService()));
///
/// Register a SQLite and with the specified service ID
@@ -91,33 +59,14 @@ public static IServiceCollection AddSqliteVectorStore(
/// Optional options to further configure the .
/// An optional service id to use as the service key.
/// Service collection.
+ [Obsolete("Use AddSqliteVectorStoreRecordCollection with connectionString instead.", error: true)]
public static IServiceCollection AddSqliteVectorStoreRecordCollection(
this IServiceCollection services,
string collectionName,
SqliteVectorStoreRecordCollectionOptions? options = default,
string? serviceId = default)
where TKey : notnull
- {
- services.AddKeyedTransient>(
- serviceId,
- (sp, obj) =>
- {
- var connection = sp.GetRequiredService();
-
- if (connection.State != ConnectionState.Open)
- {
- connection.Open();
- }
-
- var selectedOptions = options ?? sp.GetService>();
-
- return (new SqliteVectorStoreRecordCollection(connection, collectionName, selectedOptions) as IVectorStoreRecordCollection)!;
- });
-
- AddVectorizedSearch(services, serviceId);
-
- return services;
- }
+ => throw new InvalidOperationException("Use AddSqliteVectorStore with connectionString instead.");
///
/// Register a SQLite and with the specified service ID.
@@ -139,21 +88,14 @@ public static IServiceCollection AddSqliteVectorStoreRecordCollection>(
+ services.AddKeyedSingleton>(
serviceId,
- (sp, obj) =>
- {
- var connection = new SqliteConnection(connectionString);
- var extensionName = GetExtensionName(options?.VectorSearchExtensionName);
-
- connection.Open();
-
- connection.LoadExtension(extensionName);
-
- var selectedOptions = options ?? sp.GetService>();
-
- return (new SqliteVectorStoreRecordCollection(connection, collectionName, selectedOptions) as IVectorStoreRecordCollection)!;
- });
+ (sp, _) => (
+ new SqliteVectorStoreRecordCollection(
+ connectionString,
+ collectionName,
+ options ?? sp.GetService>())
+ as IVectorStoreRecordCollection)!);
AddVectorizedSearch(services, serviceId);
@@ -169,20 +111,7 @@ public static IServiceCollection AddSqliteVectorStoreRecordCollectionThe service id that the registrations should use.
private static void AddVectorizedSearch(IServiceCollection services, string? serviceId)
where TKey : notnull
- {
- services.AddKeyedTransient>(
+ => services.AddKeyedSingleton>(
serviceId,
- (sp, obj) =>
- {
- return sp.GetRequiredKeyedService>(serviceId);
- });
- }
-
- ///
- /// Returns extension name for vector search.
- ///
- private static string GetExtensionName(string? extensionName)
- {
- return !string.IsNullOrWhiteSpace(extensionName) ? extensionName! : SqliteConstants.VectorSearchExtensionName;
- }
+ (sp, _) => sp.GetRequiredKeyedService>(serviceId));
}
diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs
index 86571536a5d5..3f246bd01e92 100644
--- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs
+++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs
@@ -18,8 +18,8 @@ namespace Microsoft.SemanticKernel.Connectors.Sqlite;
///
public sealed class SqliteVectorStore : IVectorStore
{
- /// that will be used to manage the data in SQLite.
- private readonly DbConnection _connection;
+ /// The connection string for the SQLite database represented by this .
+ private readonly string _connectionString;
/// Optional configuration options for this class.
private readonly SqliteVectorStoreOptions _options;
@@ -27,18 +27,27 @@ public sealed class SqliteVectorStore : IVectorStore
///
/// Initializes a new instance of the class.
///
- /// that will be used to manage the data in SQLite.
+ /// The connection string for the SQLite database represented by this .
/// Optional configuration options for this class.
- public SqliteVectorStore(
- DbConnection connection,
- SqliteVectorStoreOptions? options = default)
+ public SqliteVectorStore(string connectionString, SqliteVectorStoreOptions? options = default)
{
- Verify.NotNull(connection);
+ Verify.NotNull(connectionString);
- this._connection = connection;
+ this._connectionString = connectionString;
this._options = options ?? new();
}
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// that will be used to manage the data in SQLite.
+ /// Optional configuration options for this class.
+ [Obsolete("Use the constructor that accepts a connection string instead.", error: true)]
+ public SqliteVectorStore(
+ DbConnection connection,
+ SqliteVectorStoreOptions? options = default)
+ => throw new InvalidOperationException("Use the constructor that accepts a connection string instead.");
+
///
public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null)
where TKey : notnull
@@ -46,7 +55,7 @@ public IVectorStoreRecordCollection GetCollection(
if (this._options.VectorStoreCollectionFactory is not null)
{
return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection(
- this._connection,
+ this._connectionString,
name,
vectorStoreRecordDefinition);
}
@@ -57,7 +66,7 @@ public IVectorStoreRecordCollection GetCollection(
}
var recordCollection = new SqliteVectorStoreRecordCollection(
- this._connection,
+ this._connectionString,
name,
new()
{
@@ -75,7 +84,9 @@ public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancel
const string TablePropertyName = "name";
const string Query = $"SELECT {TablePropertyName} FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';";
- using var command = this._connection.CreateCommand();
+ using var connection = new SqliteConnection(this._connectionString);
+ await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
+ using var command = connection.CreateCommand();
command.CommandText = Query;
diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs
index 837e3044ddc7..3f9690d146a5 100644
--- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs
+++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs
@@ -17,21 +17,7 @@ namespace Microsoft.SemanticKernel.Connectors.Sqlite;
[SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities", Justification = "User input is passed using command parameters.")]
internal sealed class SqliteVectorStoreCollectionCommandBuilder
{
- /// that will be used to manage the data in SQLite.
- private readonly DbConnection _connection;
-
- ///
- /// Initializes a new instance of the class.
- ///
- /// that will be used to manage the data in SQLite.
- public SqliteVectorStoreCollectionCommandBuilder(DbConnection connection)
- {
- Verify.NotNull(connection);
-
- this._connection = connection;
- }
-
- public DbCommand BuildTableCountCommand(string tableName)
+ public DbCommand BuildTableCountCommand(SqliteConnection connection, string tableName)
{
Verify.NotNullOrWhiteSpace(tableName);
@@ -40,7 +26,7 @@ public DbCommand BuildTableCountCommand(string tableName)
var query = $"SELECT count(*) FROM {SystemTable} WHERE type='table' AND name={ParameterName};";
- var command = this._connection.CreateCommand();
+ var command = connection.CreateCommand();
command.CommandText = query;
@@ -49,7 +35,7 @@ public DbCommand BuildTableCountCommand(string tableName)
return command;
}
- public DbCommand BuildCreateTableCommand(string tableName, IReadOnlyList columns, bool ifNotExists)
+ public DbCommand BuildCreateTableCommand(SqliteConnection connection, string tableName, IReadOnlyList columns, bool ifNotExists)
{
var builder = new StringBuilder();
@@ -58,7 +44,7 @@ public DbCommand BuildCreateTableCommand(string tableName, IReadOnlyList columns,
bool ifNotExists,
@@ -78,18 +65,18 @@ public DbCommand BuildCreateVirtualTableCommand(
builder.AppendLine(string.Join(",\n", columns.Select(GetColumnDefinition)));
builder.Append(");");
- var command = this._connection.CreateCommand();
+ var command = connection.CreateCommand();
command.CommandText = builder.ToString();
return command;
}
- public DbCommand BuildDropTableCommand(string tableName)
+ public DbCommand BuildDropTableCommand(SqliteConnection connection, string tableName)
{
string query = $"DROP TABLE [{tableName}];";
- var command = this._connection.CreateCommand();
+ var command = connection.CreateCommand();
command.CommandText = query;
@@ -97,6 +84,7 @@ public DbCommand BuildDropTableCommand(string tableName)
}
public DbCommand BuildInsertCommand(
+ SqliteConnection connection,
string tableName,
string rowIdentifier,
IReadOnlyList columnNames,
@@ -104,7 +92,7 @@ public DbCommand BuildInsertCommand(
bool replaceIfExists = false)
{
var builder = new StringBuilder();
- var command = this._connection.CreateCommand();
+ var command = connection.CreateCommand();
var replacePlaceholder = replaceIfExists ? " OR REPLACE" : string.Empty;
@@ -133,6 +121,7 @@ public DbCommand BuildInsertCommand(
}
public DbCommand BuildSelectCommand(
+ SqliteConnection connection,
string tableName,
IReadOnlyList columnNames,
List conditions,
@@ -140,7 +129,7 @@ public DbCommand BuildSelectCommand(
{
var builder = new StringBuilder();
- var (command, whereClause) = this.GetCommandWithWhereClause(conditions);
+ var (command, whereClause) = this.GetCommandWithWhereClause(connection, conditions);
builder.AppendLine($"SELECT {string.Join(", ", columnNames)}");
builder.AppendLine($"FROM {tableName}");
@@ -154,6 +143,7 @@ public DbCommand BuildSelectCommand(
}
public DbCommand BuildSelectLeftJoinCommand(
+ SqliteConnection connection,
string leftTable,
string rightTable,
string joinColumnName,
@@ -172,7 +162,7 @@ .. leftTablePropertyNames.Select(property => $"{leftTable}.{property}"),
.. rightTablePropertyNames.Select(property => $"{rightTable}.{property}"),
];
- var (command, whereClause) = this.GetCommandWithWhereClause(conditions, extraWhereFilter, extraParameters);
+ var (command, whereClause) = this.GetCommandWithWhereClause(connection, conditions, extraWhereFilter, extraParameters);
builder.AppendLine($"SELECT {string.Join(", ", propertyNames)}");
builder.AppendLine($"FROM {leftTable} ");
@@ -187,12 +177,13 @@ .. rightTablePropertyNames.Select(property => $"{rightTable}.{property}"),
}
public DbCommand BuildDeleteCommand(
+ SqliteConnection connection,
string tableName,
List conditions)
{
var builder = new StringBuilder();
- var (command, whereClause) = this.GetCommandWithWhereClause(conditions);
+ var (command, whereClause) = this.GetCommandWithWhereClause(connection, conditions);
builder.AppendLine($"DELETE FROM [{tableName}]");
@@ -242,13 +233,14 @@ private static string GetColumnDefinition(SqliteColumn column)
}
private (DbCommand Command, string WhereClause) GetCommandWithWhereClause(
+ SqliteConnection connection,
List conditions,
string? extraWhereFilter = null,
Dictionary? extraParameters = null)
{
const string WhereClauseOperator = " AND ";
- var command = this._connection.CreateCommand();
+ var command = connection.CreateCommand();
var whereClauseParts = new List();
foreach (var condition in conditions)
diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs
index 8ae095dd3bf0..091643fb8bfb 100644
--- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs
+++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs
@@ -7,6 +7,7 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.Data.Sqlite;
using Microsoft.Extensions.VectorData;
namespace Microsoft.SemanticKernel.Connectors.Sqlite;
@@ -24,8 +25,8 @@ public sealed class SqliteVectorStoreRecordCollection :
/// The name of this database for telemetry purposes.
private const string DatabaseName = "SQLite";
- /// that will be used to manage the data in SQLite.
- private readonly DbConnection _connection;
+ /// The connection string for the SQLite database represented by this .
+ private readonly string _connectionString;
/// Optional configuration options for this class.
private readonly SqliteVectorStoreRecordCollectionOptions _options;
@@ -63,30 +64,34 @@ public sealed class SqliteVectorStoreRecordCollection :
/// Table name in SQLite for vector properties.
private readonly string _vectorTableName;
+ /// The sqlite_vec extension name to use.
+ private readonly string _vectorSearchExtensionName;
+
///
public string CollectionName { get; }
///
/// Initializes a new instance of the class.
///
- /// that will be used to manage the data in SQLite.
+ /// The connection string for the SQLite database represented by this .
/// The name of the collection/table that this will access.
/// Optional configuration options for this class.
public SqliteVectorStoreRecordCollection(
- DbConnection connection,
+ string connectionString,
string collectionName,
SqliteVectorStoreRecordCollectionOptions? options = default)
{
// Verify.
- Verify.NotNull(connection);
+ Verify.NotNull(connectionString);
Verify.NotNullOrWhiteSpace(collectionName);
VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.DictionaryCustomMapper is not null, SqliteConstants.SupportedKeyTypes);
VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null);
// Assign.
- this._connection = connection;
+ this._connectionString = connectionString;
this.CollectionName = collectionName;
this._options = options ?? new();
+ this._vectorSearchExtensionName = this._options.VectorSearchExtensionName ?? SqliteConstants.VectorSearchExtensionName;
this._dataTableName = this.CollectionName;
this._vectorTableName = GetVectorTableName(this._dataTableName, this._options);
@@ -111,7 +116,7 @@ public SqliteVectorStoreRecordCollection(
this._mapper = this.InitializeMapper();
- this._commandBuilder = new SqliteVectorStoreCollectionCommandBuilder(this._connection);
+ this._commandBuilder = new SqliteVectorStoreCollectionCommandBuilder();
}
///
@@ -119,7 +124,8 @@ public async Task CollectionExistsAsync(CancellationToken cancellationToke
{
const string OperationName = "TableCount";
- using var command = this._commandBuilder.BuildTableCountCommand(this._dataTableName);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ using var command = this._commandBuilder.BuildTableCountCommand(connection, this._dataTableName);
var result = await this
.RunOperationAsync(OperationName, () => command.ExecuteScalarAsync(cancellationToken))
@@ -131,25 +137,31 @@ public async Task CollectionExistsAsync(CancellationToken cancellationToke
}
///
- public Task CreateCollectionAsync(CancellationToken cancellationToken = default)
+ public async Task CreateCollectionAsync(CancellationToken cancellationToken = default)
{
- return this.InternalCreateCollectionAsync(ifNotExists: false, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ await this.InternalCreateCollectionAsync(connection, ifNotExists: false, cancellationToken)
+ .ConfigureAwait(false);
}
///
- public Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default)
+ public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default)
{
- return this.InternalCreateCollectionAsync(ifNotExists: true, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ await this.InternalCreateCollectionAsync(connection, ifNotExists: true, cancellationToken)
+ .ConfigureAwait(false);
}
///
public async Task DeleteCollectionAsync(CancellationToken cancellationToken = default)
{
- await this.DropTableAsync(this._dataTableName, cancellationToken).ConfigureAwait(false);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+
+ await this.DropTableAsync(connection, this._dataTableName, cancellationToken).ConfigureAwait(false);
if (this._vectorPropertiesExist)
{
- await this.DropTableAsync(this._vectorTableName, cancellationToken).ConfigureAwait(false);
+ await this.DropTableAsync(connection, this._vectorTableName, cancellationToken).ConfigureAwait(false);
}
}
@@ -227,39 +239,57 @@ public Task> VectorizedSearchAsync(TVector
#region Implementation of IVectorStoreRecordCollection
///
- public Task GetAsync(ulong key, GetRecordOptions? options = null, CancellationToken cancellationToken = default)
+ public async Task GetAsync(ulong key, GetRecordOptions? options = null, CancellationToken cancellationToken = default)
{
- return this.InternalGetAsync(key, options, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ return await this.InternalGetAsync(connection, key, options, cancellationToken).ConfigureAwait(false);
}
///
- public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable GetBatchAsync(
+ IEnumerable keys,
+ GetRecordOptions? options = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- return this.InternalGetBatchAsync(keys, options, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ await foreach (var record in this.InternalGetBatchAsync(connection, keys, options, cancellationToken).ConfigureAwait(false))
+ {
+ yield return record;
+ }
}
///
- public Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default)
+ public async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default)
{
- return this.InternalUpsertAsync(record, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ return await this.InternalUpsertAsync(connection, record, cancellationToken).ConfigureAwait(false);
}
///
- public IAsyncEnumerable UpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable UpsertBatchAsync(
+ IEnumerable records,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- return this.InternalUpsertBatchAsync(records, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ await foreach (var record in this.InternalUpsertBatchAsync(connection, records, cancellationToken)
+ .ConfigureAwait(false))
+ {
+ yield return record;
+ }
}
///
- public Task DeleteAsync(ulong key, CancellationToken cancellationToken = default)
+ public async Task DeleteAsync(ulong key, CancellationToken cancellationToken = default)
{
- return this.InternalDeleteAsync(key, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ await this.InternalDeleteAsync(connection, key, cancellationToken).ConfigureAwait(false);
}
///
- public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default)
+ public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default)
{
- return this.InternalDeleteBatchAsync(keys, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ await this.InternalDeleteBatchAsync(connection, keys, cancellationToken).ConfigureAwait(false);
}
#endregion
@@ -267,45 +297,73 @@ public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancella
#region Implementation of IVectorStoreRecordCollection
///
- public Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default)
+ public async Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default)
{
- return this.InternalGetAsync(key, options, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ return await this.InternalGetAsync(connection, key, options, cancellationToken).ConfigureAwait(false);
}
///
- public IAsyncEnumerable GetBatchAsync(IEnumerable keys, GetRecordOptions? options = null, CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable GetBatchAsync(
+ IEnumerable keys,
+ GetRecordOptions? options = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- return this.InternalGetBatchAsync(keys, options, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ await foreach (var record in this.InternalGetBatchAsync(connection, keys, options, cancellationToken).ConfigureAwait(false))
+ {
+ yield return record;
+ }
}
///
- Task IVectorStoreRecordCollection.UpsertAsync(TRecord record, CancellationToken cancellationToken)
+ async Task IVectorStoreRecordCollection.UpsertAsync(TRecord record, CancellationToken cancellationToken)
{
- return this.InternalUpsertAsync(record, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ return await this.InternalUpsertAsync(connection, record, cancellationToken)
+ .ConfigureAwait(false);
}
///
- IAsyncEnumerable IVectorStoreRecordCollection.UpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken)
+ async IAsyncEnumerable IVectorStoreRecordCollection.UpsertBatchAsync(
+ IEnumerable records,
+ [EnumeratorCancellation] CancellationToken cancellationToken)
{
- return this.InternalUpsertBatchAsync(records, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ await foreach (var record in this.InternalUpsertBatchAsync(connection, records, cancellationToken)
+ .ConfigureAwait(false))
+ {
+ yield return record;
+ }
}
///
- public Task DeleteAsync(string key, CancellationToken cancellationToken = default)
+ public async Task DeleteAsync(string key, CancellationToken cancellationToken = default)
{
- return this.InternalDeleteAsync(key, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ await this.InternalDeleteAsync(connection, key, cancellationToken)
+ .ConfigureAwait(false);
}
///
- public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default)
+ public async Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken = default)
{
- return this.InternalDeleteBatchAsync(keys, cancellationToken);
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
+ await this.InternalDeleteBatchAsync(connection, keys, cancellationToken).ConfigureAwait(false);
}
#endregion
#region private
+ private async ValueTask GetConnectionAsync(CancellationToken cancellationToken = default)
+ {
+ var connection = new SqliteConnection(this._connectionString);
+ await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
+ connection.LoadExtension(this._vectorSearchExtensionName);
+ return connection;
+ }
+
private async IAsyncEnumerable> EnumerateAndMapSearchResultsAsync(
List conditions,
string? extraWhereFilter,
@@ -326,7 +384,9 @@ private async IAsyncEnumerable> EnumerateAndMapSearc
properties.AddRange(this._propertyReader.VectorProperties);
}
+ using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
using var command = this._commandBuilder.BuildSelectLeftJoinCommand(
+ connection,
this._vectorTableName,
this._dataTableName,
this._propertyReader.KeyPropertyStoragePropertyName,
@@ -356,13 +416,14 @@ private async IAsyncEnumerable> EnumerateAndMapSearc
}
}
- private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken)
+ private Task InternalCreateCollectionAsync(SqliteConnection connection, bool ifNotExists, CancellationToken cancellationToken)
{
List dataTableColumns = SqliteVectorStoreRecordPropertyMapping.GetColumns(
this._dataTableProperties.Value,
this._propertyReader.StoragePropertyNamesMap);
List tasks = [this.CreateTableAsync(
+ connection,
this._dataTableName,
dataTableColumns,
ifNotExists,
@@ -379,6 +440,7 @@ private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken c
this._propertyReader.StoragePropertyNamesMap);
tasks.Add(this.CreateVirtualTableAsync(
+ connection,
this._vectorTableName,
vectorTableColumns,
ifNotExists,
@@ -389,34 +451,35 @@ private Task InternalCreateCollectionAsync(bool ifNotExists, CancellationToken c
return Task.WhenAll(tasks);
}
- private Task CreateTableAsync(string tableName, List columns, bool ifNotExists, CancellationToken cancellationToken)
+ private Task CreateTableAsync(SqliteConnection connection, string tableName, List columns, bool ifNotExists, CancellationToken cancellationToken)
{
const string OperationName = "CreateTable";
- using var command = this._commandBuilder.BuildCreateTableCommand(tableName, columns, ifNotExists);
+ using var command = this._commandBuilder.BuildCreateTableCommand(connection, tableName, columns, ifNotExists);
return this.RunOperationAsync(OperationName, () => command.ExecuteNonQueryAsync(cancellationToken));
}
- private Task CreateVirtualTableAsync(string tableName, List columns, bool ifNotExists, string extensionName, CancellationToken cancellationToken)
+ private Task CreateVirtualTableAsync(SqliteConnection connection, string tableName, List columns, bool ifNotExists, string extensionName, CancellationToken cancellationToken)
{
const string OperationName = "CreateVirtualTable";
- using var command = this._commandBuilder.BuildCreateVirtualTableCommand(tableName, columns, ifNotExists, extensionName);
+ using var command = this._commandBuilder.BuildCreateVirtualTableCommand(connection, tableName, columns, ifNotExists, extensionName);
return this.RunOperationAsync(OperationName, () => command.ExecuteNonQueryAsync(cancellationToken));
}
- private Task DropTableAsync(string tableName, CancellationToken cancellationToken)
+ private Task DropTableAsync(SqliteConnection connection, string tableName, CancellationToken cancellationToken)
{
const string OperationName = "DropTable";
- using var command = this._commandBuilder.BuildDropTableCommand(tableName);
+ using var command = this._commandBuilder.BuildDropTableCommand(connection, tableName);
return this.RunOperationAsync(OperationName, () => command.ExecuteNonQueryAsync(cancellationToken));
}
private async Task InternalGetAsync(
+ SqliteConnection connection,
TKey key,
GetRecordOptions? options,
CancellationToken cancellationToken)
@@ -428,12 +491,13 @@ private Task DropTableAsync(string tableName, CancellationToken cancellatio
TableName = this._dataTableName
};
- return await this.InternalGetBatchAsync(condition, options, cancellationToken)
+ return await this.InternalGetBatchAsync(connection, condition, options, cancellationToken)
.FirstOrDefaultAsync(cancellationToken)
.ConfigureAwait(false);
}
private IAsyncEnumerable InternalGetBatchAsync(
+ SqliteConnection connection,
IEnumerable keys,
GetRecordOptions? options,
CancellationToken cancellationToken)
@@ -449,10 +513,11 @@ private IAsyncEnumerable InternalGetBatchAsync(
TableName = this._dataTableName
};
- return this.InternalGetBatchAsync(condition, options, cancellationToken);
+ return this.InternalGetBatchAsync(connection, condition, options, cancellationToken);
}
private async IAsyncEnumerable InternalGetBatchAsync(
+ SqliteConnection connection,
SqliteWhereCondition condition,
GetRecordOptions? options,
[EnumeratorCancellation] CancellationToken cancellationToken)
@@ -467,6 +532,7 @@ private async IAsyncEnumerable InternalGetBatchAsync(
if (includeVectors)
{
command = this._commandBuilder.BuildSelectLeftJoinCommand(
+ connection,
this._dataTableName,
this._vectorTableName,
this._propertyReader.KeyPropertyStoragePropertyName,
@@ -479,6 +545,7 @@ private async IAsyncEnumerable InternalGetBatchAsync(
else
{
command = this._commandBuilder.BuildSelectCommand(
+ connection,
this._dataTableName,
this._dataTableStoragePropertyNames.Value,
[condition]);
@@ -499,7 +566,7 @@ private async IAsyncEnumerable InternalGetBatchAsync(
}
}
- private async Task InternalUpsertAsync(TRecord record, CancellationToken cancellationToken)
+ private async Task InternalUpsertAsync(SqliteConnection connection, TRecord record, CancellationToken cancellationToken)
{
const string OperationName = "Upsert";
@@ -515,14 +582,14 @@ private async Task InternalUpsertAsync(TRecord record, CancellationT
var condition = new SqliteWhereEqualsCondition(this._propertyReader.KeyPropertyStoragePropertyName, key);
- var upsertedRecordKey = await this.InternalUpsertBatchAsync([storageModel], condition, cancellationToken)
+ var upsertedRecordKey = await this.InternalUpsertBatchAsync(connection, [storageModel], condition, cancellationToken)
.FirstOrDefaultAsync(cancellationToken)
.ConfigureAwait(false);
return upsertedRecordKey ?? throw new VectorStoreOperationException("Error occurred during upsert operation.");
}
- private IAsyncEnumerable InternalUpsertBatchAsync(IEnumerable records, CancellationToken cancellationToken)
+ private IAsyncEnumerable InternalUpsertBatchAsync(SqliteConnection connection, IEnumerable records, CancellationToken cancellationToken)
{
const string OperationName = "UpsertBatch";
@@ -536,10 +603,11 @@ private IAsyncEnumerable InternalUpsertBatchAsync(IEnumerable(storageModels, condition, cancellationToken);
+ return this.InternalUpsertBatchAsync(connection, storageModels, condition, cancellationToken);
}
private async IAsyncEnumerable InternalUpsertBatchAsync(
+ SqliteConnection connection,
List> storageModels,
SqliteWhereCondition condition,
[EnumeratorCancellation] CancellationToken cancellationToken)
@@ -552,12 +620,14 @@ private async IAsyncEnumerable InternalUpsertBatchAsync(
// Deleting vector records first since current version of vector search extension
// doesn't support Upsert operation, only Delete/Insert.
using var vectorDeleteCommand = this._commandBuilder.BuildDeleteCommand(
+ connection,
this._vectorTableName,
[condition]);
await vectorDeleteCommand.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
using var vectorInsertCommand = this._commandBuilder.BuildInsertCommand(
+ connection,
this._vectorTableName,
this._propertyReader.KeyPropertyStoragePropertyName,
this._vectorTableStoragePropertyNames.Value,
@@ -567,11 +637,12 @@ private async IAsyncEnumerable InternalUpsertBatchAsync(
}
using var dataCommand = this._commandBuilder.BuildInsertCommand(
- this._dataTableName,
- this._propertyReader.KeyPropertyStoragePropertyName,
- this._dataTableStoragePropertyNames.Value,
- storageModels,
- replaceIfExists: true);
+ connection,
+ this._dataTableName,
+ this._propertyReader.KeyPropertyStoragePropertyName,
+ this._dataTableStoragePropertyNames.Value,
+ storageModels,
+ replaceIfExists: true);
using var reader = await dataCommand.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
@@ -588,16 +659,16 @@ private async IAsyncEnumerable InternalUpsertBatchAsync(
}
}
- private Task InternalDeleteAsync(TKey key, CancellationToken cancellationToken)
+ private Task InternalDeleteAsync(SqliteConnection connection, TKey key, CancellationToken cancellationToken)
{
Verify.NotNull(key);
var condition = new SqliteWhereEqualsCondition(this._propertyReader.KeyPropertyStoragePropertyName, key);
- return this.InternalDeleteBatchAsync(condition, cancellationToken);
+ return this.InternalDeleteBatchAsync(connection, condition, cancellationToken);
}
- private Task InternalDeleteBatchAsync(IEnumerable keys, CancellationToken cancellationToken)
+ private Task InternalDeleteBatchAsync(SqliteConnection connection, IEnumerable keys, CancellationToken cancellationToken)
{
Verify.NotNull(keys);
@@ -609,10 +680,10 @@ private Task InternalDeleteBatchAsync(IEnumerable keys, Cancellation
this._propertyReader.KeyPropertyStoragePropertyName,
keysList);
- return this.InternalDeleteBatchAsync(condition, cancellationToken);
+ return this.InternalDeleteBatchAsync(connection, condition, cancellationToken);
}
- private Task InternalDeleteBatchAsync(SqliteWhereCondition condition, CancellationToken cancellationToken)
+ private Task InternalDeleteBatchAsync(SqliteConnection connection, SqliteWhereCondition condition, CancellationToken cancellationToken)
{
const string OperationName = "Delete";
@@ -621,6 +692,7 @@ private Task InternalDeleteBatchAsync(SqliteWhereCondition condition, Cancellati
if (this._vectorPropertiesExist)
{
using var vectorCommand = this._commandBuilder.BuildDeleteCommand(
+ connection,
this._vectorTableName,
[condition]);
@@ -628,6 +700,7 @@ private Task InternalDeleteBatchAsync(SqliteWhereCondition condition, Cancellati
}
using var dataCommand = this._commandBuilder.BuildDeleteCommand(
+ connection,
this._dataTableName,
[condition]);
diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDBConnection.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDBConnection.cs
deleted file mode 100644
index 7c318e1ef413..000000000000
--- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDBConnection.cs
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright (c) Microsoft. All rights reserved.
-
-using System;
-using System.Data;
-using System.Data.Common;
-
-namespace SemanticKernel.Connectors.Sqlite.UnitTests;
-
-#pragma warning disable CS8618, CS8765
-
-internal sealed class FakeDBConnection(DbCommand command) : DbConnection
-{
- public override string ConnectionString { get; set; }
-
- public override string Database => throw new NotImplementedException();
-
- public override string DataSource => throw new NotImplementedException();
-
- public override string ServerVersion => throw new NotImplementedException();
-
- public override ConnectionState State => throw new NotImplementedException();
-
- public override void ChangeDatabase(string databaseName) => throw new NotImplementedException();
-
- public override void Close() => throw new NotImplementedException();
-
- public override void Open() => throw new NotImplementedException();
-
- protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) => throw new NotImplementedException();
-
- protected override DbCommand CreateDbCommand() => command;
-}
diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbCommand.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbCommand.cs
deleted file mode 100644
index df6062d9a4c1..000000000000
--- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbCommand.cs
+++ /dev/null
@@ -1,45 +0,0 @@
-// Copyright (c) Microsoft. All rights reserved.
-
-using System;
-using System.Data;
-using System.Data.Common;
-
-namespace SemanticKernel.Connectors.Sqlite.UnitTests;
-
-#pragma warning disable CS8618, CS8765
-
-internal sealed class FakeDbCommand(
- DbDataReader? dataReader = null,
- object? scalarResult = null) : DbCommand
-{
- public int ExecuteNonQueryCallCount { get; private set; } = 0;
-
- private readonly FakeDbParameterCollection _parameterCollection = [];
-
- public override string CommandText { get; set; }
- public override int CommandTimeout { get; set; }
- public override CommandType CommandType { get; set; }
- public override bool DesignTimeVisible { get; set; }
- public override UpdateRowSource UpdatedRowSource { get; set; }
- protected override DbConnection? DbConnection { get; set; }
-
- protected override DbParameterCollection DbParameterCollection => this._parameterCollection;
-
- protected override DbTransaction? DbTransaction { get; set; }
-
- public override void Cancel() => throw new NotImplementedException();
-
- public override int ExecuteNonQuery()
- {
- this.ExecuteNonQueryCallCount++;
- return 0;
- }
-
- public override object? ExecuteScalar() => scalarResult;
-
- public override void Prepare() => throw new NotImplementedException();
-
- protected override DbParameter CreateDbParameter() => throw new NotImplementedException();
-
- protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) => dataReader ?? throw new NotImplementedException();
-}
diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbParameterCollection.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbParameterCollection.cs
deleted file mode 100644
index 246b97a3360b..000000000000
--- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/Fakes/FakeDbParameterCollection.cs
+++ /dev/null
@@ -1,105 +0,0 @@
-// Copyright (c) Microsoft. All rights reserved.
-
-using System;
-using System.Collections;
-using System.Collections.Generic;
-using System.Data.Common;
-
-namespace SemanticKernel.Connectors.Sqlite.UnitTests;
-
-#pragma warning disable CA1812
-
-internal sealed class FakeDbParameterCollection : DbParameterCollection
-{
- private readonly List
public sealed class SqliteVectorStoreCollectionCommandBuilderTests : IDisposable
{
- private readonly FakeDbCommand _command;
- private readonly FakeDBConnection _connection;
+ private readonly SqliteCommand _command;
+ private readonly SqliteConnection _connection;
private readonly SqliteVectorStoreCollectionCommandBuilder _commandBuilder;
public SqliteVectorStoreCollectionCommandBuilderTests()
{
- this._command = new();
- this._connection = new(this._command);
- this._commandBuilder = new(this._connection);
+ this._command = new() { Connection = this._connection };
+ this._connection = new();
+ this._commandBuilder = new();
}
[Fact]
@@ -30,7 +31,7 @@ public void ItBuildsTableCountCommand()
const string TableName = "TestTable";
// Act
- var command = this._commandBuilder.BuildTableCountCommand(TableName);
+ var command = this._commandBuilder.BuildTableCountCommand(this._connection, TableName);
// Assert
Assert.Equal("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=@tableName;", command.CommandText);
@@ -53,7 +54,7 @@ public void ItBuildsCreateTableCommand(bool ifNotExists)
};
// Act
- var command = this._commandBuilder.BuildCreateTableCommand(TableName, columns, ifNotExists);
+ var command = this._commandBuilder.BuildCreateTableCommand(this._connection, TableName, columns, ifNotExists);
// Assert
Assert.Contains("CREATE TABLE", command.CommandText);
@@ -81,7 +82,7 @@ public void ItBuildsCreateVirtualTableCommand(bool ifNotExists)
};
// Act
- var command = this._commandBuilder.BuildCreateVirtualTableCommand(TableName, columns, ifNotExists, ExtensionName);
+ var command = this._commandBuilder.BuildCreateVirtualTableCommand(this._connection, TableName, columns, ifNotExists, ExtensionName);
// Assert
Assert.Contains("CREATE VIRTUAL TABLE", command.CommandText);
@@ -101,7 +102,7 @@ public void ItBuildsDropTableCommand()
const string TableName = "TestTable";
// Act
- var command = this._commandBuilder.BuildDropTableCommand(TableName);
+ var command = this._commandBuilder.BuildDropTableCommand(this._connection, TableName);
// Assert
Assert.Equal("DROP TABLE [TestTable];", command.CommandText);
@@ -125,6 +126,7 @@ public void ItBuildsInsertCommand(bool replaceIfExists)
// Act
var command = this._commandBuilder.BuildInsertCommand(
+ this._connection,
TableName,
RowIdentifier,
columnNames,
@@ -181,7 +183,7 @@ public void ItBuildsSelectCommand(string? orderByPropertyName)
};
// Act
- var command = this._commandBuilder.BuildSelectCommand(TableName, columnNames, conditions, orderByPropertyName);
+ var command = this._commandBuilder.BuildSelectCommand(this._connection, TableName, columnNames, conditions, orderByPropertyName);
// Assert
Assert.Contains("SELECT Id, Name, Age, Address", command.CommandText);
@@ -227,6 +229,7 @@ public void ItBuildsSelectLeftJoinCommand(string? orderByPropertyName)
// Act
var command = this._commandBuilder.BuildSelectLeftJoinCommand(
+ this._connection,
LeftTable,
RightTable,
JoinColumnName,
@@ -274,7 +277,7 @@ public void ItBuildsDeleteCommand()
};
// Act
- var command = this._commandBuilder.BuildDeleteCommand(TableName, conditions);
+ var command = this._commandBuilder.BuildDeleteCommand(this._connection, TableName, conditions);
// Assert
Assert.Contains("DELETE FROM [TestTable]", command.CommandText);
diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordCollectionTests.cs
index 631bf6cebf3d..59cc3c3401e4 100644
--- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordCollectionTests.cs
+++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreRecordCollectionTests.cs
@@ -1,5 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.
+// TODO: Reimplement these as integration tests, #10464
+
+#if DISABLED
+
using System;
using System.Collections.Generic;
using System.Data.Common;
@@ -400,3 +404,5 @@ private sealed class TestRecordWithoutVectorProperty
#endregion
}
+
+#endif
diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreTests.cs
index c6cf80e8b085..71656fe87a0d 100644
--- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreTests.cs
+++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreTests.cs
@@ -1,5 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.
+// TODO: Reimplement these as integration tests, #10464
+
+#if DISABLED
+
using System;
using System.Data.Common;
using System.Linq;
@@ -102,3 +106,5 @@ public async Task ListCollectionNamesReturnsCollectionNamesAsync()
Assert.Contains("collection2", collections);
}
}
+
+#endif
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteServiceCollectionExtensionsTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteServiceCollectionExtensionsTests.cs
index 2e3e6b32fe52..bfded601d8ec 100644
--- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteServiceCollectionExtensionsTests.cs
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteServiceCollectionExtensionsTests.cs
@@ -69,29 +69,6 @@ public void AddVectorStoreRecordCollectionWithNumericKeyAndSqliteConnectionRegis
Assert.IsType>(vectorizedSearch);
}
- [Fact(Skip = SkipReason)]
- public void ItClosesConnectionWhenDIServiceIsDisposed()
- {
- // Act
- using var connection = new SqliteConnection("Data Source=:memory:");
-
- this._serviceCollection.AddTransient(_ => connection);
-
- this._serviceCollection.AddSqliteVectorStore();
-
- var serviceProvider = this._serviceCollection.BuildServiceProvider();
-
- using (var scope = serviceProvider.CreateScope())
- {
- scope.ServiceProvider.GetRequiredService();
-
- Assert.Equal(ConnectionState.Open, connection.State);
- }
-
- // Assert
- Assert.Equal(ConnectionState.Closed, connection.State);
- }
-
#region private
#pragma warning disable CA1812 // Avoid uninstantiated internal classes
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreFixture.cs
index 6f07f20ddf67..c3a702c5a7c0 100644
--- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreFixture.cs
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreFixture.cs
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
+using System.IO;
using System.Threading.Tasks;
using Microsoft.Data.Sqlite;
using Microsoft.SemanticKernel.Connectors.Sqlite;
@@ -8,43 +9,22 @@
namespace SemanticKernel.IntegrationTests.Connectors.Memory.Sqlite;
-public class SqliteVectorStoreFixture : IAsyncLifetime, IDisposable
+public class SqliteVectorStoreFixture : IDisposable
{
- ///
- /// SQLite extension name for vector search.
- /// More information here: .
- ///
- private const string VectorSearchExtensionName = "vec0";
+ private readonly string _databasePath = Path.GetTempFileName();
- public SqliteConnection Connection { get; }
-
- public SqliteVectorStoreFixture()
- {
- this.Connection = new SqliteConnection("Data Source=:memory:");
- }
+ public string ConnectionString => $"Data Source={this._databasePath}";
public SqliteVectorStoreRecordCollection GetCollection(
string collectionName,
SqliteVectorStoreRecordCollectionOptions? options = default)
{
return new SqliteVectorStoreRecordCollection(
- this.Connection,
+ this.ConnectionString,
collectionName,
options);
}
- public Task DisposeAsync()
- {
- return Task.CompletedTask;
- }
-
- public async Task InitializeAsync()
- {
- await this.Connection.OpenAsync();
-
- this.Connection.LoadExtension(VectorSearchExtensionName);
- }
-
public void Dispose()
{
this.Dispose(true);
@@ -55,7 +35,7 @@ protected virtual void Dispose(bool disposing)
{
if (disposing)
{
- this.Connection.Dispose();
+ File.Delete(this._databasePath);
}
}
}
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs
index c0dbb5fcf680..77b97a0a4428 100644
--- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs
@@ -4,6 +4,7 @@
using System.Linq;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
+using Microsoft.Data.Sqlite;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Connectors.Sqlite;
using Xunit;
@@ -245,7 +246,9 @@ public async Task ItCanGetExistingRecordAsync(bool includeVectors)
var record = CreateTestHotel(HotelId);
- var commandData = fixture.Connection.CreateCommand();
+ using var connection = new SqliteConnection(fixture.ConnectionString);
+ await connection.OpenAsync();
+ var commandData = connection.CreateCommand();
commandData.CommandText =
$"INSERT INTO {collectionName} " +
@@ -262,7 +265,7 @@ public async Task ItCanGetExistingRecordAsync(bool includeVectors)
if (includeVectors)
{
- var commandVector = fixture.Connection.CreateCommand();
+ var commandVector = connection.CreateCommand();
commandVector.CommandText =
$"INSERT INTO vec_{collectionName} " +
diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreTests.cs
index 8a173250f7fe..6eca22778b02 100644
--- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreTests.cs
+++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreTests.cs
@@ -17,7 +17,7 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Sqlite;
[Collection("SqliteVectorStoreCollection")]
[DisableVectorStoreTests(Skip = "SQLite vector search extension is required")]
public sealed class SqliteVectorStoreTests(SqliteVectorStoreFixture fixture)
- : BaseVectorStoreTests>(new SqliteVectorStore(fixture.Connection!))
+ : BaseVectorStoreTests>(new SqliteVectorStore(fixture.ConnectionString))
{
[VectorStoreFact]
public async Task ItCanGetAListOfExistingCollectionNamesWhenRegisteredWithDIAsync()
@@ -25,7 +25,7 @@ public async Task ItCanGetAListOfExistingCollectionNamesWhenRegisteredWithDIAsyn
// Arrange
var serviceCollection = new ServiceCollection();
- serviceCollection.AddSqliteVectorStore(connectionString: "Data Source=:memory:");
+ serviceCollection.AddSqliteVectorStore(fixture.ConnectionString);
var provider = serviceCollection.BuildServiceProvider();
diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs
index 526eeac3b2d8..9b025c66610f 100644
--- a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs
+++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs
@@ -1,21 +1,16 @@
// Copyright (c) Microsoft. All rights reserved.
-using Microsoft.Data.Sqlite;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Connectors.Sqlite;
using VectorDataSpecificationTests.Support;
namespace SqliteIntegrationTests.Support;
-#pragma warning disable CA1001 // Type owns disposable fields (_connection) but is not disposable
-
internal sealed class SqliteTestStore : TestStore
{
- public static SqliteTestStore Instance { get; } = new();
+ private string? _databasePath;
- private SqliteConnection? _connection;
- public SqliteConnection Connection
- => this._connection ?? throw new InvalidOperationException("Call InitializeAsync() first");
+ public static SqliteTestStore Instance { get; } = new();
private SqliteVectorStore? _defaultVectorStore;
public override IVectorStore DefaultVectorStore
@@ -25,31 +20,17 @@ private SqliteTestStore()
{
}
- protected override async Task StartAsync()
+ protected override Task StartAsync()
{
- this._connection = new SqliteConnection("Data Source=:memory:");
-
- await this.Connection.OpenAsync();
-
- if (!SqliteTestEnvironment.TryLoadSqliteVec(this.Connection))
- {
- this.Connection.Dispose();
-
- // Note that we ignore sqlite_vec loading failures; the tests are decorated with [SqliteVecRequired], which causes
- // them to be skipped if sqlite_vec isn't installed (better than an exception triggering failure here)
- }
-
- this._defaultVectorStore = new SqliteVectorStore(this.Connection);
+ this._databasePath = Path.GetTempFileName();
+ this._defaultVectorStore = new SqliteVectorStore($"Data Source={this._databasePath}");
+ return Task.CompletedTask;
}
-#if NET8_0_OR_GREATER
- protected override async Task StopAsync()
- => await this.Connection.DisposeAsync();
-#else
protected override Task StopAsync()
{
- this.Connection.Dispose();
+ File.Delete(this._databasePath!);
+ this._databasePath = null;
return Task.CompletedTask;
}
-#endif
}