Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@
<!-- Memory stores -->
<PackageVersion Include="Pgvector" Version="0.2.0" />
<PackageVersion Include="NRedisStack" Version="0.12.0" />
<PackageVersion Include="Milvus.Client" Version="2.2.2-preview.6" />
<PackageVersion Include="Milvus.Client" Version="2.3.0-preview.1" />
<PackageVersion Include="Testcontainers.Milvus" Version="3.8.0" />
<!-- Symbols -->
<PackageVersion Include="Microsoft.SourceLink.GitHub" Version="8.0.0" />
<!-- Toolset -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class MilvusMemoryStore : IMemoryStore, IDisposable
{
private readonly int _vectorSize;
private readonly SimilarityMetricType _metricType;
private readonly ConsistencyLevel _consistencyLevel;
private readonly bool _ownsMilvusClient;
private readonly string _indexName;

Expand All @@ -36,18 +37,10 @@ public class MilvusMemoryStore : IMemoryStore, IDisposable
private const string TimestampFieldName = "timestamp";

private const int DefaultMilvusPort = 19530;
private const ConsistencyLevel DefaultConsistencyLevel = ConsistencyLevel.Session;
private const int DefaultVarcharLength = 65_535;

private readonly QueryParameters _queryParametersWithEmbedding = new()
{
OutputFields = { IsReferenceFieldName, ExternalSourceNameFieldName, IdFieldName, DescriptionFieldName, TextFieldName, AdditionalMetadataFieldName, EmbeddingFieldName, KeyFieldName, TimestampFieldName }
};

private readonly QueryParameters _queryParametersWithoutEmbedding = new()
{
OutputFields = { IsReferenceFieldName, ExternalSourceNameFieldName, IdFieldName, DescriptionFieldName, TextFieldName, AdditionalMetadataFieldName, KeyFieldName, TimestampFieldName }
};
private readonly QueryParameters _queryParametersWithEmbedding;
private readonly QueryParameters _queryParametersWithoutEmbedding;

private readonly SearchParameters _searchParameters = new()
{
Expand All @@ -64,7 +57,7 @@ public class MilvusMemoryStore : IMemoryStore, IDisposable
/// <summary>
/// Creates a new <see cref="MilvusMemoryStore" />, connecting to the given hostname on the default Milvus port of 19530.
/// For more advanced configuration opens, construct a <see cref="MilvusClient" /> instance and pass it to
/// <see cref="MilvusMemoryStore(MilvusClient, string, int, SimilarityMetricType)" />.
/// <see cref="MilvusMemoryStore(MilvusClient, string, int, SimilarityMetricType, ConsistencyLevel)" />.
/// </summary>
/// <param name="host">The hostname or IP address to connect to.</param>
/// <param name="port">The port to connect to. Defaults to 19530.</param>
Expand All @@ -73,6 +66,7 @@ public class MilvusMemoryStore : IMemoryStore, IDisposable
/// <param name="indexName">The name of the index to use. Defaults to <see cref="DefaultIndexName" />.</param>
/// <param name="vectorSize">The size of the vectors used in Milvus. Defaults to 1536.</param>
/// <param name="metricType">The metric used to measure similarity between vectors. Defaults to <see cref="SimilarityMetricType.Ip" />.</param>
/// <param name="consistencyLevel">The consistency level to be used in the search. Defaults to <see cref="ConsistencyLevel.Session"/>.</param>
/// <param name="loggerFactory">An optional logger factory through which the Milvus client will log.</param>
public MilvusMemoryStore(
string host,
Expand All @@ -82,16 +76,19 @@ public MilvusMemoryStore(
string? indexName = null,
int vectorSize = 1536,
SimilarityMetricType metricType = SimilarityMetricType.Ip,
ConsistencyLevel consistencyLevel = ConsistencyLevel.Session,
ILoggerFactory? loggerFactory = null)
: this(new MilvusClient(host, port, ssl, database, callOptions: default, loggerFactory), indexName, vectorSize, metricType)
: this(
new MilvusClient(host, port, ssl, database, callOptions: default, loggerFactory),
indexName, vectorSize, metricType, consistencyLevel)
{
this._ownsMilvusClient = true;
}

/// <summary>
/// Creates a new <see cref="MilvusMemoryStore" />, connecting to the given hostname on the default Milvus port of 19530.
/// For more advanced configuration opens, construct a <see cref="MilvusClient" /> instance and pass it to
/// <see cref="MilvusMemoryStore(MilvusClient, string, int, SimilarityMetricType)" />.
/// <see cref="MilvusMemoryStore(MilvusClient, string, int, SimilarityMetricType, ConsistencyLevel)" />.
/// </summary>
/// <param name="host">The hostname or IP address to connect to.</param>
/// <param name="username">The username to use for authentication.</param>
Expand All @@ -102,6 +99,7 @@ public MilvusMemoryStore(
/// <param name="indexName">The name of the index to use. Defaults to <see cref="DefaultIndexName" />.</param>
/// <param name="vectorSize">The size of the vectors used in Milvus. Defaults to 1536.</param>
/// <param name="metricType">The metric used to measure similarity between vectors. Defaults to <see cref="SimilarityMetricType.Ip" />.</param>
/// <param name="consistencyLevel">The consistency level to be used in the search. Defaults to <see cref="ConsistencyLevel.Session"/>.</param>
/// <param name="loggerFactory">An optional logger factory through which the Milvus client will log.</param>
public MilvusMemoryStore(
string host,
Expand All @@ -113,16 +111,19 @@ public MilvusMemoryStore(
string? indexName = null,
int vectorSize = 1536,
SimilarityMetricType metricType = SimilarityMetricType.Ip,
ConsistencyLevel consistencyLevel = ConsistencyLevel.Session,
ILoggerFactory? loggerFactory = null)
: this(new MilvusClient(host, username, password, port, ssl, database, callOptions: default, loggerFactory), indexName, vectorSize, metricType)
: this(
new MilvusClient(host, username, password, port, ssl, database, callOptions: default, loggerFactory),
indexName, vectorSize, metricType, consistencyLevel)
{
this._ownsMilvusClient = true;
}

/// <summary>
/// Creates a new <see cref="MilvusMemoryStore" />, connecting to the given hostname on the default Milvus port of 19530.
/// For more advanced configuration opens, construct a <see cref="MilvusClient" /> instance and pass it to
/// <see cref="MilvusMemoryStore(MilvusClient, string, int, SimilarityMetricType)" />.
/// <see cref="MilvusMemoryStore(MilvusClient, string, int, SimilarityMetricType, ConsistencyLevel)" />.
/// </summary>
/// <param name="host">The hostname or IP address to connect to.</param>
/// <param name="apiKey">An API key to be used for authentication, instead of a username and password.</param>
Expand All @@ -132,6 +133,7 @@ public MilvusMemoryStore(
/// <param name="indexName">The name of the index to use. Defaults to <see cref="DefaultIndexName" />.</param>
/// <param name="vectorSize">The size of the vectors used in Milvus. Defaults to 1536.</param>
/// <param name="metricType">The metric used to measure similarity between vectors. Defaults to <see cref="SimilarityMetricType.Ip" />.</param>
/// <param name="consistencyLevel">The consistency level to be used in the search. Defaults to <see cref="ConsistencyLevel.Session"/>.</param>
/// <param name="loggerFactory">An optional logger factory through which the Milvus client will log.</param>
public MilvusMemoryStore(
string host,
Expand All @@ -142,8 +144,11 @@ public MilvusMemoryStore(
string? indexName = null,
int vectorSize = 1536,
SimilarityMetricType metricType = SimilarityMetricType.Ip,
ConsistencyLevel consistencyLevel = ConsistencyLevel.Session,
ILoggerFactory? loggerFactory = null)
: this(new MilvusClient(host, apiKey, port, ssl, database, callOptions: default, loggerFactory), indexName, vectorSize, metricType)
: this(
new MilvusClient(host, apiKey, port, ssl, database, callOptions: default, loggerFactory),
indexName, vectorSize, metricType, consistencyLevel)
{
this._ownsMilvusClient = true;
}
Expand All @@ -155,27 +160,43 @@ public MilvusMemoryStore(
/// <param name="indexName">The name of the index to use. Defaults to <see cref="DefaultIndexName" />.</param>
/// <param name="vectorSize">The size of the vectors used in Milvus. Defaults to 1536.</param>
/// <param name="metricType">The metric used to measure similarity between vectors. Defaults to <see cref="SimilarityMetricType.Ip" />.</param>
/// <param name="consistencyLevel">The consistency level to be used in the search. Defaults to <see cref="ConsistencyLevel.Session"/>.</param>
public MilvusMemoryStore(
MilvusClient client,
string? indexName = null,
int vectorSize = 1536,
SimilarityMetricType metricType = SimilarityMetricType.Ip)
: this(client, ownsMilvusClient: false, indexName, vectorSize, metricType)
SimilarityMetricType metricType = SimilarityMetricType.Ip,
ConsistencyLevel consistencyLevel = ConsistencyLevel.Session)
: this(client, ownsMilvusClient: false, indexName, vectorSize, metricType, consistencyLevel)
{
}

private MilvusMemoryStore(
MilvusClient client,
bool ownsMilvusClient,
string? indexName = null,
int vectorSize = 1536,
SimilarityMetricType metricType = SimilarityMetricType.Ip)
string? indexName,
int vectorSize,
SimilarityMetricType metricType,
ConsistencyLevel consistencyLevel)
{
this.Client = client;
this._indexName = indexName ?? DefaultIndexName;
this._vectorSize = vectorSize;
this._metricType = metricType;
this._ownsMilvusClient = ownsMilvusClient;
this._consistencyLevel = consistencyLevel;

this._queryParametersWithEmbedding = new()
{
OutputFields = { IsReferenceFieldName, ExternalSourceNameFieldName, IdFieldName, DescriptionFieldName, TextFieldName, AdditionalMetadataFieldName, EmbeddingFieldName, KeyFieldName, TimestampFieldName },
ConsistencyLevel = this._consistencyLevel
};

this._queryParametersWithoutEmbedding = new()
{
OutputFields = { IsReferenceFieldName, ExternalSourceNameFieldName, IdFieldName, DescriptionFieldName, TextFieldName, AdditionalMetadataFieldName, KeyFieldName, TimestampFieldName },
ConsistencyLevel = this._consistencyLevel
};
}

#endregion Constructors
Expand All @@ -196,7 +217,7 @@ public async Task CreateCollectionAsync(string collectionName, CancellationToken
EnableDynamicFields = true
};

MilvusCollection collection = await this.Client.CreateCollectionAsync(collectionName, schema, DefaultConsistencyLevel, cancellationToken: cancellationToken).ConfigureAwait(false);
MilvusCollection collection = await this.Client.CreateCollectionAsync(collectionName, schema, this._consistencyLevel, cancellationToken: cancellationToken).ConfigureAwait(false);

await collection.CreateIndexAsync(EmbeddingFieldName, metricType: this._metricType, indexName: this._indexName, cancellationToken: cancellationToken).ConfigureAwait(false);
await collection.WaitForIndexBuildAsync("float_vector", this._indexName, cancellationToken: cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -228,8 +249,6 @@ public async Task<string> UpsertAsync(string collectionName, MemoryRecord record
{
MilvusCollection collection = this.Client.GetCollection(collectionName);

await collection.DeleteAsync($@"{IdFieldName} in [""{record.Metadata.Id}""]", cancellationToken: cancellationToken).ConfigureAwait(false);

var metadata = record.Metadata;

List<FieldData> fieldData = new()
Expand All @@ -246,7 +265,7 @@ public async Task<string> UpsertAsync(string collectionName, MemoryRecord record
FieldData.Create(TimestampFieldName, new[] { record.Timestamp?.ToString(CultureInfo.InvariantCulture) ?? string.Empty }, isDynamic: true)
};

MutationResult result = await collection.InsertAsync(fieldData, cancellationToken: cancellationToken).ConfigureAwait(false);
MutationResult result = await collection.UpsertAsync(fieldData, cancellationToken: cancellationToken).ConfigureAwait(false);

return result.Ids.StringIds![0];
}
Expand All @@ -257,9 +276,6 @@ public async IAsyncEnumerable<string> UpsertBatchAsync(
IEnumerable<MemoryRecord> records,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
// TODO: Milvus v2.3.0 will have a 1st-class upsert API which we should use.
// In the meantime, we do delete+insert, following the Python connector's example.

StringBuilder idString = new();

List<bool> isReferenceData = new();
Expand Down Expand Up @@ -295,7 +311,6 @@ public async IAsyncEnumerable<string> UpsertBatchAsync(
}

MilvusCollection collection = this.Client.GetCollection(collectionName);
await collection.DeleteAsync($"{IdFieldName} in [{idString}]", cancellationToken: cancellationToken).ConfigureAwait(false);

FieldData[] fieldData =
{
Expand All @@ -311,7 +326,7 @@ public async IAsyncEnumerable<string> UpsertBatchAsync(
FieldData.Create(TimestampFieldName, timestampData, isDynamic: true)
};

MutationResult result = await collection.InsertAsync(fieldData, cancellationToken: cancellationToken).ConfigureAwait(false);
MutationResult result = await collection.UpsertAsync(fieldData, cancellationToken: cancellationToken).ConfigureAwait(false);

foreach (var id in result.Ids.StringIds!)
{
Expand Down Expand Up @@ -355,7 +370,10 @@ public async IAsyncEnumerable<MemoryRecord> GetBatchAsync(

IReadOnlyList<FieldData> fields = await this.Client
.GetCollection(collectionName)
.QueryAsync($"{IdFieldName} in [{idString}]", withEmbeddings ? this._queryParametersWithEmbedding : this._queryParametersWithoutEmbedding, cancellationToken: cancellationToken)
.QueryAsync(
$"{IdFieldName} in [{idString}]",
withEmbeddings ? this._queryParametersWithEmbedding : this._queryParametersWithoutEmbedding,
cancellationToken: cancellationToken)
.ConfigureAwait(false);

var rowCount = fields[0].RowCount;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Threading.Tasks;
using Milvus.Client;
using Testcontainers.Milvus;
using Xunit;

namespace SemanticKernel.IntegrationTests.Connectors.Milvus;

public sealed class MilvusFixture : IAsyncLifetime
{
private readonly MilvusContainer _container = new MilvusBuilder().Build();

public string Host => this._container.Hostname;
public int Port => this._container.GetMappedPublicPort(MilvusBuilder.MilvusGrpcPort);

public MilvusClient CreateClient()
=> new(this.Host, "root", "milvus", this.Port);

public Task InitializeAsync()
=> this._container.StartAsync();

public Task DisposeAsync()
=> this._container.DisposeAsync().AsTask();
}
Loading