Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Data.Common;
using Microsoft.Extensions.VectorData;

namespace Microsoft.SemanticKernel.Connectors.Sqlite;
Expand All @@ -17,12 +16,12 @@ public interface ISqliteVectorStoreRecordCollectionFactory
/// </summary>
/// <typeparam name="TKey">The data type of the record key.</typeparam>
/// <typeparam name="TRecord">The data model to use for adding, updating and retrieving data from storage.</typeparam>
/// <param name="connection"><see cref="DbConnection"/> that will be used to manage the data in SQLite.</param>
/// <param name="connectionString">The connection string for the SQLite database represented by this <see cref="SqliteVectorStore"/>.</param>
/// <param name="name">The name of the collection to connect to.</param>
/// <param name="vectorStoreRecordDefinition">An optional record definition that defines the schema of the record type. If not present, attributes on <typeparamref name="TRecord"/> will be used.</param>
/// <returns>The new instance of <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>.</returns>
IVectorStoreRecordCollection<TKey, TRecord> CreateVectorStoreRecordCollection<TKey, TRecord>(
DbConnection connection,
string connectionString,
string name,
VectorStoreRecordDefinition? vectorStoreRecordDefinition)
where TKey : notnull;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Data;
using System;
using Microsoft.Data.Sqlite;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.VectorData;
Expand All @@ -22,29 +22,12 @@ public static class SqliteServiceCollectionExtensions
/// <param name="options">Optional options to further configure the <see cref="IVectorStore"/>.</param>
/// <param name="serviceId">An optional service id to use as the service key.</param>
/// <returns>Service collection.</returns>
[Obsolete("Use AddSqliteVectorStore with connectionString instead.", error: true)]
public static IServiceCollection AddSqliteVectorStore(
this IServiceCollection services,
SqliteVectorStoreOptions? options = default,
string? serviceId = default)
{
services.AddKeyedTransient<IVectorStore>(
serviceId,
(sp, obj) =>
{
var connection = sp.GetRequiredService<SqliteConnection>();

if (connection.State != ConnectionState.Open)
{
connection.Open();
}

var selectedOptions = options ?? sp.GetService<SqliteVectorStoreOptions>();

return new SqliteVectorStore(connection, options);
});

return services;
}
=> throw new InvalidOperationException("Use AddSqliteVectorStore with connectionString instead.");

/// <summary>
/// Register a SQLite <see cref="IVectorStore"/> with the specified service ID.
Expand All @@ -60,24 +43,9 @@ public static IServiceCollection AddSqliteVectorStore(
string connectionString,
SqliteVectorStoreOptions? options = default,
string? serviceId = default)
{
services.AddKeyedTransient<IVectorStore>(
=> services.AddKeyedSingleton<IVectorStore>(
serviceId,
(sp, obj) =>
{
var connection = new SqliteConnection(connectionString);
var extensionName = GetExtensionName(options?.VectorSearchExtensionName);

connection.Open();

connection.LoadExtension(extensionName);

var selectedOptions = options ?? sp.GetService<SqliteVectorStoreOptions>();
return new SqliteVectorStore(connection, options);
});

return services;
}
(sp, _) => new SqliteVectorStore(connectionString, options ?? sp.GetService<SqliteVectorStoreOptions>()));

/// <summary>
/// Register a SQLite <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/> and <see cref="IVectorizedSearch{TRecord}"/> with the specified service ID
Expand All @@ -91,33 +59,14 @@ public static IServiceCollection AddSqliteVectorStore(
/// <param name="options">Optional options to further configure the <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>.</param>
/// <param name="serviceId">An optional service id to use as the service key.</param>
/// <returns>Service collection.</returns>
[Obsolete("Use AddSqliteVectorStoreRecordCollection with connectionString instead.", error: true)]
public static IServiceCollection AddSqliteVectorStoreRecordCollection<TKey, TRecord>(
this IServiceCollection services,
string collectionName,
SqliteVectorStoreRecordCollectionOptions<TRecord>? options = default,
string? serviceId = default)
where TKey : notnull
{
services.AddKeyedTransient<IVectorStoreRecordCollection<TKey, TRecord>>(
serviceId,
(sp, obj) =>
{
var connection = sp.GetRequiredService<SqliteConnection>();

if (connection.State != ConnectionState.Open)
{
connection.Open();
}

var selectedOptions = options ?? sp.GetService<SqliteVectorStoreRecordCollectionOptions<TRecord>>();

return (new SqliteVectorStoreRecordCollection<TRecord>(connection, collectionName, selectedOptions) as IVectorStoreRecordCollection<TKey, TRecord>)!;
});

AddVectorizedSearch<TKey, TRecord>(services, serviceId);

return services;
}
=> throw new InvalidOperationException("Use AddSqliteVectorStore with connectionString instead.");

/// <summary>
/// Register a SQLite <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/> and <see cref="IVectorizedSearch{TRecord}"/> with the specified service ID.
Expand All @@ -139,21 +88,14 @@ public static IServiceCollection AddSqliteVectorStoreRecordCollection<TKey, TRec
string? serviceId = default)
where TKey : notnull
{
services.AddKeyedTransient<IVectorStoreRecordCollection<TKey, TRecord>>(
services.AddKeyedSingleton<IVectorStoreRecordCollection<TKey, TRecord>>(
serviceId,
(sp, obj) =>
{
var connection = new SqliteConnection(connectionString);
var extensionName = GetExtensionName(options?.VectorSearchExtensionName);

connection.Open();

connection.LoadExtension(extensionName);
Comment thread
roji marked this conversation as resolved.

var selectedOptions = options ?? sp.GetService<SqliteVectorStoreRecordCollectionOptions<TRecord>>();

return (new SqliteVectorStoreRecordCollection<TRecord>(connection, collectionName, selectedOptions) as IVectorStoreRecordCollection<TKey, TRecord>)!;
});
(sp, _) => (
new SqliteVectorStoreRecordCollection<TRecord>(
connectionString,
collectionName,
options ?? sp.GetService<SqliteVectorStoreRecordCollectionOptions<TRecord>>())
as IVectorStoreRecordCollection<TKey, TRecord>)!);

AddVectorizedSearch<TKey, TRecord>(services, serviceId);

Expand All @@ -169,20 +111,7 @@ public static IServiceCollection AddSqliteVectorStoreRecordCollection<TKey, TRec
/// <param name="serviceId">The service id that the registrations should use.</param>
private static void AddVectorizedSearch<TKey, TRecord>(IServiceCollection services, string? serviceId)
where TKey : notnull
{
services.AddKeyedTransient<IVectorizedSearch<TRecord>>(
=> services.AddKeyedSingleton<IVectorizedSearch<TRecord>>(
serviceId,
(sp, obj) =>
{
return sp.GetRequiredKeyedService<IVectorStoreRecordCollection<TKey, TRecord>>(serviceId);
});
}

/// <summary>
/// Returns extension name for vector search.
/// </summary>
private static string GetExtensionName(string? extensionName)
{
return !string.IsNullOrWhiteSpace(extensionName) ? extensionName! : SqliteConstants.VectorSearchExtensionName;
}
(sp, _) => sp.GetRequiredKeyedService<IVectorStoreRecordCollection<TKey, TRecord>>(serviceId));
}
33 changes: 22 additions & 11 deletions dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,36 @@ namespace Microsoft.SemanticKernel.Connectors.Sqlite;
/// </remarks>
public class SqliteVectorStore : IVectorStore
{
/// <summary><see cref="DbConnection"/> that will be used to manage the data in SQLite.</summary>
private readonly DbConnection _connection;
/// <summary>The connection string for the SQLite database represented by this <see cref="SqliteVectorStore"/>.</summary>
private readonly string _connectionString;

/// <summary>Optional configuration options for this class.</summary>
private readonly SqliteVectorStoreOptions _options;

/// <summary>
/// Initializes a new instance of the <see cref="SqliteVectorStore"/> class.
/// </summary>
/// <param name="connection"><see cref="SqliteConnection"/> that will be used to manage the data in SQLite.</param>
/// <param name="connectionString">The connection string for the SQLite database represented by this <see cref="SqliteVectorStore"/>.</param>
/// <param name="options">Optional configuration options for this class.</param>
public SqliteVectorStore(
DbConnection connection,
SqliteVectorStoreOptions? options = default)
public SqliteVectorStore(string connectionString, SqliteVectorStoreOptions? options = default)
{
Verify.NotNull(connection);
Verify.NotNull(connectionString);

this._connection = connection;
this._connectionString = connectionString;
this._options = options ?? new();
}

/// <summary>
/// Initializes a new instance of the <see cref="SqliteVectorStore"/> class.
/// </summary>
/// <param name="connection"><see cref="SqliteConnection"/> that will be used to manage the data in SQLite.</param>
/// <param name="options">Optional configuration options for this class.</param>
[Obsolete("Use the constructor that accepts a connection string instead.", error: true)]
public SqliteVectorStore(
DbConnection connection,
SqliteVectorStoreOptions? options = default)
=> throw new InvalidOperationException("Use the constructor that accepts a connection string instead.");

/// <inheritdoc />
public virtual IVectorStoreRecordCollection<TKey, TRecord> GetCollection<TKey, TRecord>(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null)
where TKey : notnull
Expand All @@ -47,7 +56,7 @@ public virtual IVectorStoreRecordCollection<TKey, TRecord> GetCollection<TKey, T
if (this._options.VectorStoreCollectionFactory is not null)
{
return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection<TKey, TRecord>(
this._connection,
this._connectionString,
name,
vectorStoreRecordDefinition);
}
Expand All @@ -59,7 +68,7 @@ public virtual IVectorStoreRecordCollection<TKey, TRecord> GetCollection<TKey, T
}

var recordCollection = new SqliteVectorStoreRecordCollection<TRecord>(
this._connection,
this._connectionString,
name,
new()
{
Expand All @@ -77,7 +86,9 @@ public virtual async IAsyncEnumerable<string> ListCollectionNamesAsync([Enumerat
const string TablePropertyName = "name";
const string Query = $"SELECT {TablePropertyName} FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';";

using var command = this._connection.CreateCommand();
using var connection = new SqliteConnection(this._connectionString);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
using var command = connection.CreateCommand();

command.CommandText = Query;

Expand Down
Loading