diff --git a/.github/_typos.toml b/.github/_typos.toml index 457e6bca4c2c..d9a2dcb7a2e4 100644 --- a/.github/_typos.toml +++ b/.github/_typos.toml @@ -39,6 +39,7 @@ prompty = "prompty" # prompty is a format name. ist = "ist" # German language dall = "dall" # OpenAI model name pn = "pn" # Kiota parameter +nin = "nin" # MongoDB "not in" operator [default.extend-identifiers] ags = "ags" # Azure Graph Service diff --git a/dotnet/Directory.Build.props b/dotnet/Directory.Build.props index 94d748c78057..13e279d799d6 100644 --- a/dotnet/Directory.Build.props +++ b/dotnet/Directory.Build.props @@ -25,6 +25,11 @@ True + + + $(NoWarn);CS8604;CS8602 + + $([System.IO.Path]::GetDirectoryName($([MSBuild]::GetPathOfFileAbove('.gitignore', '$(MSBuildThisFileDirectory)')))) diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index e93dc3df49a2..fcad75436cb8 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -115,11 +115,15 @@ - + - + + + + + diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 0a711f84f5f3..e1953ea0bf7e 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -117,6 +117,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Diagnostics", "Diagnostics" src\InternalUtilities\src\Diagnostics\RequiresUnreferencedCodeAttribute.cs = src\InternalUtilities\src\Diagnostics\RequiresUnreferencedCodeAttribute.cs src\InternalUtilities\src\Diagnostics\UnconditionalSuppressMessageAttribute.cs = src\InternalUtilities\src\Diagnostics\UnconditionalSuppressMessageAttribute.cs src\InternalUtilities\src\Diagnostics\Verify.cs = src\InternalUtilities\src\Diagnostics\Verify.cs + src\InternalUtilities\src\Diagnostics\UnreachableException.cs = src\InternalUtilities\src\Diagnostics\UnreachableException.cs EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Linq", "Linq", "{B00AD427-0047-4850-BEF9-BA8237EA9D8B}" @@ -140,6 +141,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "System", "System", "{3CDE10 src\InternalUtilities\src\System\InternalTypeConverter.cs = src\InternalUtilities\src\System\InternalTypeConverter.cs src\InternalUtilities\src\System\NonNullCollection.cs = src\InternalUtilities\src\System\NonNullCollection.cs src\InternalUtilities\src\System\TypeConverterFactory.cs = src\InternalUtilities\src\System\TypeConverterFactory.cs + src\InternalUtilities\src\System\IndexRange.cs = src\InternalUtilities\src\System\IndexRange.cs EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Type", "Type", "{E85EA4D0-BB7E-4DFD-882F-A76EB8C0B8FF}" @@ -439,6 +441,30 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "sk-chatgpt-azure-function", EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "kernel-functions-generator", "samples\Demos\CreateChatGptPlugin\MathPlugin\kernel-functions-generator\kernel-functions-generator.csproj", "{78785CB1-66CF-4895-D7E5-A440DD84BE86}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "VectorDataIntegrationTests", "VectorDataIntegrationTests", "{4F381919-F1BE-47D8-8558-3187ED04A84F}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "QdrantIntegrationTests", "src\VectorDataIntegrationTests\QdrantIntegrationTests\QdrantIntegrationTests.csproj", "{27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "VectorDataIntegrationTests", "src\VectorDataIntegrationTests\VectorDataIntegrationTests\VectorDataIntegrationTests.csproj", "{B29A972F-A774-4140-AECF-6B577C476627}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RedisIntegrationTests", "src\VectorDataIntegrationTests\RedisIntegrationTests\RedisIntegrationTests.csproj", "{F7EA82A4-A626-4316-AA47-EAC3A0E85870}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "PostgresIntegrationTests", "src\VectorDataIntegrationTests\PostgresIntegrationTests\PostgresIntegrationTests.csproj", "{3148FF01-38C7-4BEB-8CEC-9323EC7C593B}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "InMemoryIntegrationTests", "src\VectorDataIntegrationTests\InMemoryIntegrationTests\InMemoryIntegrationTests.csproj", "{F5126690-0FD1-4777-9EDF-B3F5B7B3730B}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "CosmosNoSQLIntegrationTests", "src\VectorDataIntegrationTests\CosmosNoSQLIntegrationTests\CosmosNoSQLIntegrationTests.csproj", "{E200425C-E501-430C-8A8B-BC0088BD94DB}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SqliteIntegrationTests", "src\VectorDataIntegrationTests\SqliteIntegrationTests\SqliteIntegrationTests.csproj", "{709B3933-5286-4139-8D83-8C7AA5746FAE}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "WeaviateIntegrationTests", "src\VectorDataIntegrationTests\WeaviateIntegrationTests\WeaviateIntegrationTests.csproj", "{E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MongoDBIntegrationTests", "src\VectorDataIntegrationTests\MongoDBIntegrationTests\MongoDBIntegrationTests.csproj", "{A0E65043-6B00-4836-850F-000A52238914}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "CosmosMongoDBIntegrationTests", "src\VectorDataIntegrationTests\CosmosMongoDBIntegrationTests\CosmosMongoDBIntegrationTests.csproj", "{11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AzureAISearchIntegrationTests", "src\VectorDataIntegrationTests\AzureAISearchIntegrationTests\AzureAISearchIntegrationTests.csproj", "{06181F0F-A375-43AE-B45F-73CBCFC30C14}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -1172,6 +1198,72 @@ Global {78785CB1-66CF-4895-D7E5-A440DD84BE86}.Publish|Any CPU.Build.0 = Debug|Any CPU {78785CB1-66CF-4895-D7E5-A440DD84BE86}.Release|Any CPU.ActiveCfg = Release|Any CPU {78785CB1-66CF-4895-D7E5-A440DD84BE86}.Release|Any CPU.Build.0 = Release|Any CPU + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Debug|Any CPU.Build.0 = Debug|Any CPU + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Publish|Any CPU.Build.0 = Debug|Any CPU + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Release|Any CPU.ActiveCfg = Release|Any CPU + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}.Release|Any CPU.Build.0 = Release|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Publish|Any CPU.ActiveCfg = Publish|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Publish|Any CPU.Build.0 = Publish|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B29A972F-A774-4140-AECF-6B577C476627}.Release|Any CPU.Build.0 = Release|Any CPU + {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Publish|Any CPU.Build.0 = Debug|Any CPU + {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F7EA82A4-A626-4316-AA47-EAC3A0E85870}.Release|Any CPU.Build.0 = Release|Any CPU + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B}.Publish|Any CPU.Build.0 = Debug|Any CPU + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B}.Release|Any CPU.Build.0 = Release|Any CPU + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B}.Publish|Any CPU.Build.0 = Debug|Any CPU + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B}.Release|Any CPU.Build.0 = Release|Any CPU + {E200425C-E501-430C-8A8B-BC0088BD94DB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E200425C-E501-430C-8A8B-BC0088BD94DB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E200425C-E501-430C-8A8B-BC0088BD94DB}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {E200425C-E501-430C-8A8B-BC0088BD94DB}.Publish|Any CPU.Build.0 = Debug|Any CPU + {E200425C-E501-430C-8A8B-BC0088BD94DB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E200425C-E501-430C-8A8B-BC0088BD94DB}.Release|Any CPU.Build.0 = Release|Any CPU + {709B3933-5286-4139-8D83-8C7AA5746FAE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {709B3933-5286-4139-8D83-8C7AA5746FAE}.Debug|Any CPU.Build.0 = Debug|Any CPU + {709B3933-5286-4139-8D83-8C7AA5746FAE}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {709B3933-5286-4139-8D83-8C7AA5746FAE}.Publish|Any CPU.Build.0 = Debug|Any CPU + {709B3933-5286-4139-8D83-8C7AA5746FAE}.Release|Any CPU.ActiveCfg = Release|Any CPU + {709B3933-5286-4139-8D83-8C7AA5746FAE}.Release|Any CPU.Build.0 = Release|Any CPU + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}.Publish|Any CPU.Build.0 = Debug|Any CPU + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F}.Release|Any CPU.Build.0 = Release|Any CPU + {A0E65043-6B00-4836-850F-000A52238914}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A0E65043-6B00-4836-850F-000A52238914}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A0E65043-6B00-4836-850F-000A52238914}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {A0E65043-6B00-4836-850F-000A52238914}.Publish|Any CPU.Build.0 = Debug|Any CPU + {A0E65043-6B00-4836-850F-000A52238914}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A0E65043-6B00-4836-850F-000A52238914}.Release|Any CPU.Build.0 = Release|Any CPU + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}.Publish|Any CPU.Build.0 = Debug|Any CPU + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB}.Release|Any CPU.Build.0 = Release|Any CPU + {06181F0F-A375-43AE-B45F-73CBCFC30C14}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {06181F0F-A375-43AE-B45F-73CBCFC30C14}.Debug|Any CPU.Build.0 = Debug|Any CPU + {06181F0F-A375-43AE-B45F-73CBCFC30C14}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {06181F0F-A375-43AE-B45F-73CBCFC30C14}.Publish|Any CPU.Build.0 = Debug|Any CPU + {06181F0F-A375-43AE-B45F-73CBCFC30C14}.Release|Any CPU.ActiveCfg = Release|Any CPU + {06181F0F-A375-43AE-B45F-73CBCFC30C14}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1333,6 +1425,18 @@ Global {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {2EB6E4C2-606D-B638-2E08-49EA2061C428} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} {78785CB1-66CF-4895-D7E5-A440DD84BE86} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} + {4F381919-F1BE-47D8-8558-3187ED04A84F} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} + {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {B29A972F-A774-4140-AECF-6B577C476627} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {F7EA82A4-A626-4316-AA47-EAC3A0E85870} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {3148FF01-38C7-4BEB-8CEC-9323EC7C593B} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {F5126690-0FD1-4777-9EDF-B3F5B7B3730B} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {E200425C-E501-430C-8A8B-BC0088BD94DB} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {709B3933-5286-4139-8D83-8C7AA5746FAE} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {E3CECC65-1B00-4E3A-90B6-FC7A2C64E41F} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {A0E65043-6B00-4836-850F-000A52238914} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {11DFBF14-6FBA-41F0-B7F3-A288952D6FDB} = {4F381919-F1BE-47D8-8558-3187ED04A84F} + {06181F0F-A375-43AE-B45F-73CBCFC30C14} = {4F381919-F1BE-47D8-8558-3187ED04A84F} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83} diff --git a/dotnet/SK-dotnet.sln.DotSettings b/dotnet/SK-dotnet.sln.DotSettings index d8964e230315..f5eec1700bcd 100644 --- a/dotnet/SK-dotnet.sln.DotSettings +++ b/dotnet/SK-dotnet.sln.DotSettings @@ -217,6 +217,7 @@ public void It$SOMENAME$() True True True + True True True True diff --git a/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreRecordCollection.cs b/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreRecordCollection.cs index 5c1c4b05c56f..000cb1ebba07 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreRecordCollection.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreEmbeddingGeneration/TextEmbeddingVectorStoreRecordCollection.cs @@ -8,7 +8,7 @@ namespace Memory.VectorStoreEmbeddingGeneration; /// -/// Decorator for a that generates embeddings for records on upsert and when using . +/// Decorator for a that generates embeddings for records on upsert and when using . /// /// /// This class is part of the sample. @@ -120,13 +120,13 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record } /// - public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { return this._decoratedVectorStoreRecordCollection.VectorizedSearchAsync(vector, options, cancellationToken); } /// - public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { var embeddingValue = await this._textEmbeddingGenerationService.GenerateEmbeddingAsync(searchText, cancellationToken: cancellationToken).ConfigureAwait(false); return await this.VectorizedSearchAsync(embeddingValue, options, cancellationToken).ConfigureAwait(false); diff --git a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs index 076be09c9ca5..1951f3a6dbee 100644 --- a/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs +++ b/dotnet/samples/Concepts/Memory/VectorStoreLangchainInterop/MappingVectorStoreRecordCollection.cs @@ -1,5 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. +// TODO: Commented out as part of implementing LINQ-based filtering, since MappingVectorStoreRecordCollection is no longer easy/feasible. +// TODO: The user provides an expression tree accepting a TPublicRecord, but we require an expression tree accepting a TInternalRecord. +// TODO: This is something that the user must provide, and is quite advanced. + +#if DISABLED + using System.Runtime.CompilerServices; using Microsoft.Extensions.VectorData; @@ -132,3 +138,5 @@ public async Task> VectorizedSearchAsync CreateVectorStoreRecordCollec return (collection as IVectorStoreRecordCollection)!; } +#if DISABLED_FOR_NOW // TODO: See note on MappingVectorStoreRecordCollection // If the user asked for a string key, we can add a decorator which converts back and forth between string and guid. // The string that the user provides will still need to contain a valid guid, since the Langchain created collection // uses guid keys. @@ -92,6 +93,7 @@ public IVectorStoreRecordCollection CreateVectorStoreRecordCollec return (stringKeyCollection as IVectorStoreRecordCollection)!; } +#endif throw new NotSupportedException("This VectorStore is only usable with Guid keys and LangchainDocument record types or string keys and LangchainDocument record types"); } diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs index c5160ac8739c..ff492ca58304 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_MultiStore_Common.cs @@ -70,8 +70,7 @@ public async Task IngestDataAndSearchAsync(string collectionName, Func.Category), "External Definitions"); - searchResult = await collection.VectorizedSearchAsync(searchVector, new() { Top = 3, Filter = filter }); + searchResult = await collection.VectorizedSearchAsync(searchVector, new() { Top = 3, NewFilter = g => g.Category == "External Definitions" }); resultRecords = await searchResult.Results.ToListAsync(); output.WriteLine("Search string: " + searchString); diff --git a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Simple.cs b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Simple.cs index a7eceb4046a9..5119881c3bda 100644 --- a/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Simple.cs +++ b/dotnet/samples/Concepts/Memory/VectorStore_VectorSearch_Simple.cs @@ -70,8 +70,7 @@ public async Task ExampleAsync() // Search the collection using a vector search with pre-filtering. searchString = "What is Retrieval Augmented Generation"; searchVector = await textEmbeddingGenerationService.GenerateEmbeddingAsync(searchString); - var filter = new VectorSearchFilter().EqualTo(nameof(Glossary.Category), "External Definitions"); - searchResult = await collection.VectorizedSearchAsync(searchVector, new() { Top = 3, Filter = filter }); + searchResult = await collection.VectorizedSearchAsync(searchVector, new() { Top = 3, NewFilter = g => g.Category == "External Definitions" }); resultRecords = await searchResult.Results.ToListAsync(); Console.WriteLine("Search string: " + searchString); diff --git a/dotnet/samples/Concepts/Search/VectorStore_TextSearch.cs b/dotnet/samples/Concepts/Search/VectorStore_TextSearch.cs index df52982104b8..f6a3d4ab6356 100644 --- a/dotnet/samples/Concepts/Search/VectorStore_TextSearch.cs +++ b/dotnet/samples/Concepts/Search/VectorStore_TextSearch.cs @@ -144,7 +144,7 @@ internal static async Task> CreateCo private sealed class VectorizedSearchWrapper(IVectorizedSearch vectorizedSearch, ITextEmbeddingGenerationService textEmbeddingGeneration) : IVectorizableTextSearch { /// - public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { var vectorizedQuery = await textEmbeddingGeneration!.GenerateEmbeddingAsync(searchText, cancellationToken: cancellationToken).ConfigureAwait(false); diff --git a/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs b/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs index 19c7cee676e8..9b7e889b25dd 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs +++ b/dotnet/samples/GettingStartedWithVectorStores/Step2_Vector_Search.cs @@ -71,7 +71,7 @@ public async Task SearchAnInMemoryVectorStoreWithFilteringAsync() new() { Top = 1, - Filter = new VectorSearchFilter().EqualTo(nameof(Glossary.Category), "AI") + NewFilter = g => g.Category == "AI" }); var searchResultItems = await searchResult.Results.ToListAsync(); diff --git a/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs b/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs index 7303ddc9801a..35ca4822a824 100644 --- a/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs +++ b/dotnet/samples/GettingStartedWithVectorStores/Step4_NonStringKey_VectorStore.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +#if DISABLED_FOR_NOW // TODO: See note in MappingVectorStoreRecordCollection + using System.Runtime.CompilerServices; using Microsoft.Extensions.VectorData; using Microsoft.SemanticKernel.Connectors.Qdrant; @@ -7,6 +9,7 @@ namespace GettingStartedWithVectorStores; + /// /// Example that shows that you can switch between different vector stores with the same code, in this case /// with a vector store that doesn't use string keys. @@ -193,3 +196,5 @@ public async Task> VectorizedSearchAsync /// Contains tests for the class. /// @@ -21,7 +23,7 @@ public void BuildFilterStringBuildsCorrectEqualityStringForEachFilterType(string var filter = new VectorSearchFilter().EqualTo(fieldName, fieldValue!); // Act. - var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(filter, new Dictionary { { fieldName, "storage_" + fieldName } }); + var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(filter, new Dictionary { { fieldName, "storage_" + fieldName } }); // Assert. Assert.Equal(expected, actual); @@ -34,7 +36,7 @@ public void BuildFilterStringBuildsCorrectTagContainsString() var filter = new VectorSearchFilter().AnyTagEqualTo("Tags", "mytag"); // Act. - var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(filter, new Dictionary { { "Tags", "storage_tags" } }); + var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(filter, new Dictionary { { "Tags", "storage_tags" } }); // Assert. Assert.Equal("storage_tags/any(t: t eq 'mytag')", actual); @@ -47,7 +49,7 @@ public void BuildFilterStringCombinesFilterOptions() var filter = new VectorSearchFilter().EqualTo("intField", 5).AnyTagEqualTo("Tags", "mytag"); // Act. - var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(filter, new Dictionary { { "Tags", "storage_tags" }, { "intField", "storage_intField" } }); + var actual = AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(filter, new Dictionary { { "Tags", "storage_tags" }, { "intField", "storage_intField" } }); // Assert. Assert.Equal("storage_intField eq 5 and storage_tags/any(t: t eq 'mytag')", actual); @@ -57,8 +59,8 @@ public void BuildFilterStringCombinesFilterOptions() public void BuildFilterStringThrowsForUnknownPropertyName() { // Act and assert. - Assert.Throws(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(new VectorSearchFilter().EqualTo("unknown", "value"), new Dictionary())); - Assert.Throws(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(new VectorSearchFilter().AnyTagEqualTo("unknown", "value"), new Dictionary())); + Assert.Throws(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(new VectorSearchFilter().EqualTo("unknown", "value"), new Dictionary())); + Assert.Throws(() => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(new VectorSearchFilter().AnyTagEqualTo("unknown", "value"), new Dictionary())); } public static IEnumerable DataTypeMappingOptions() diff --git a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs index 467207b29ace..eb240f91d9aa 100644 --- a/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAISearch.UnitTests/AzureAISearchVectorStoreRecordCollectionTests.cs @@ -20,6 +20,8 @@ namespace SemanticKernel.Connectors.AzureAISearch.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests.cs index 6e061892d2b9..9dee844e61d2 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreCollectionSearchMappingTests.cs @@ -9,6 +9,8 @@ namespace SemanticKernel.Connectors.AzureCosmosDBMongoDB.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Unit tests for class. /// diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs index 99815a1cee63..ab2fa157b212 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs @@ -13,6 +13,7 @@ using MongoDB.Driver; using Moq; using Xunit; +using MEVD = Microsoft.Extensions.VectorData; namespace SemanticKernel.Connectors.AzureCosmosDBMongoDB.UnitTests; @@ -643,7 +644,7 @@ public async Task VectorizedSearchThrowsExceptionWithNonExistentVectorPropertyNa this._mockMongoDatabase.Object, "collection"); - var options = new VectorSearchOptions { VectorPropertyName = "non-existent-property" }; + var options = new MEVD.VectorSearchOptions { VectorPropertyName = "non-existent-property" }; // Act & Assert await Assert.ThrowsAsync(async () => await (await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f]), options)).Results.FirstOrDefaultAsync()); diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs index 094028e516ab..37aa005777d5 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilderTests.cs @@ -9,6 +9,8 @@ namespace SemanticKernel.Connectors.AzureCosmosDBNoSQL.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Unit tests for class. /// @@ -35,7 +37,7 @@ public void BuildSearchQueryByDefaultReturnsValidQueryDefinition() .EqualTo("TestProperty2", "test-value-2") .AnyTagEqualTo("TestProperty3", "test-value-3"); - var searchOptions = new VectorSearchOptions { Filter = filter, Skip = 5, Top = 10 }; + var searchOptions = new VectorSearchOptions { Filter = filter, Skip = 5, Top = 10 }; // Act var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery( @@ -84,7 +86,7 @@ public void BuildSearchQueryWithoutOffsetReturnsQueryDefinitionWithTopParameter( .EqualTo("TestProperty2", "test-value-2") .AnyTagEqualTo("TestProperty3", "test-value-3"); - var searchOptions = new VectorSearchOptions { Filter = filter, Top = 10 }; + var searchOptions = new VectorSearchOptions { Filter = filter, Top = 10 }; // Act var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery( @@ -129,7 +131,7 @@ public void BuildSearchQueryWithInvalidFilterThrowsException() var filter = new VectorSearchFilter().EqualTo("non-existent-property", "test-value-2"); - var searchOptions = new VectorSearchOptions { Filter = filter, Skip = 5, Top = 10 }; + var searchOptions = new VectorSearchOptions { Filter = filter, Skip = 5, Top = 10 }; // Act & Assert Assert.Throws(() => @@ -150,7 +152,7 @@ public void BuildSearchQueryWithoutFilterDoesNotContainWhereClause() var vectorPropertyName = "test_property_1"; var fields = this._storagePropertyNames.Values.ToList(); - var searchOptions = new VectorSearchOptions { Skip = 5, Top = 10 }; + var searchOptions = new VectorSearchOptions { Skip = 5, Top = 10 }; // Act var queryDefinition = AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.BuildSearchQuery( @@ -181,10 +183,11 @@ public void BuildSearchQueryWithoutFilterDoesNotContainWhereClause() public void BuildSelectQueryByDefaultReturnsValidQueryDefinition() { // Arrange - const string ExpectedQueryText = "" + - "SELECT x.key,x.property_1,x.property_2 " + - "FROM x " + - "WHERE (x.key_property = @rk0 AND x.partition_key_property = @pk0) "; + const string ExpectedQueryText = """ + SELECT x.key,x.property_1,x.property_2 + FROM x + WHERE (x.key_property = @rk0 AND x.partition_key_property = @pk0) + """; const string KeyStoragePropertyName = "key_property"; const string PartitionKeyPropertyName = "partition_key_property"; @@ -211,4 +214,8 @@ public void BuildSelectQueryByDefaultReturnsValidQueryDefinition() Assert.Equal("@pk0", queryParameters[1].Name); Assert.Equal("partition_key", queryParameters[1].Value); } + +#pragma warning disable CA1812 // An internal class that is apparently never instantiated. If so, remove the code from the assembly. + private sealed class DummyType; +#pragma warning restore CA1812 } diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs index d8718eb2f2b5..24e4a2083f0b 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBNoSQL.UnitTests/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs @@ -612,7 +612,7 @@ public async Task VectorizedSearchWithNonExistentVectorPropertyNameThrowsExcepti this._mockDatabase.Object, "collection"); - var searchOptions = new VectorSearchOptions { VectorPropertyName = "non-existent-property" }; + var searchOptions = new VectorSearchOptions { VectorPropertyName = "non-existent-property" }; // Act & Assert await Assert.ThrowsAsync(async () => diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI/Connectors.AzureOpenAI.csproj b/dotnet/src/Connectors/Connectors.AzureOpenAI/Connectors.AzureOpenAI.csproj index 15d88496159b..9fcbdecf530e 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI/Connectors.AzureOpenAI.csproj +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI/Connectors.AzureOpenAI.csproj @@ -35,4 +35,5 @@ + \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreRecordCollectionTests.cs index 1cf974a77c84..bbf5c9611e32 100644 --- a/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.InMemory.UnitTests/InMemoryVectorStoreRecordCollectionTests.cs @@ -293,7 +293,7 @@ public async Task CanSearchWithVectorAsync(bool useDefinition, TKey testKe // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true }, + new() { IncludeVectors = true }, this._testCancellationToken); // Assert @@ -309,6 +309,7 @@ public async Task CanSearchWithVectorAsync(bool useDefinition, TKey testKe Assert.Equal(-1, actualResults[1].Score); } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete [Theory] [InlineData(true, TestRecordKey1, TestRecordKey2, "Equality")] [InlineData(true, TestRecordIntKey1, TestRecordIntKey2, "Equality")] @@ -337,7 +338,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, TK var filter = filterType == "Equality" ? new VectorSearchFilter().EqualTo("Data", $"data {testKey2}") : new VectorSearchFilter().AnyTagEqualTo("Tags", $"tag {testKey2}"); var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true, Filter = filter, IncludeTotalCount = true }, + new() { IncludeVectors = true, Filter = filter, IncludeTotalCount = true }, this._testCancellationToken); // Assert @@ -349,6 +350,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, TK Assert.Equal($"data {testKey2}", actualResults[0].Record.Data); Assert.Equal(-1, actualResults[0].Score); } +#pragma warning restore CS0618 // Type or member is obsolete [Theory] [InlineData(DistanceFunction.CosineSimilarity, 1, -1)] @@ -389,7 +391,7 @@ public async Task CanSearchWithDifferentDistanceFunctionsAsync(string distanceFu // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true }, + new() { IncludeVectors = true }, this._testCancellationToken); // Assert @@ -430,7 +432,7 @@ public async Task CanSearchManyRecordsAsync(bool useDefinition) // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true, Top = 10, Skip = 10, IncludeTotalCount = true }, + new() { IncludeVectors = true, Top = 10, Skip = 10, IncludeTotalCount = true }, this._testCancellationToken); // Assert @@ -506,7 +508,7 @@ public async Task ItCanSearchUsingTheGenericDataModelAsync(TKey testKey1, // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory([1, 1, 1, 1]), - new VectorSearchOptions { IncludeVectors = true, VectorPropertyName = "Vector" }, + new() { IncludeVectors = true, VectorPropertyName = "Vector" }, this._testCancellationToken); // Assert diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs new file mode 100644 index 000000000000..16164c2a3eca --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchFilterTranslator.cs @@ -0,0 +1,366 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; + +internal class AzureAISearchFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + private readonly StringBuilder _filter = new(); + + private static readonly char[] s_searchInDefaultDelimiter = [' ', ',']; + + internal string Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + Debug.Assert(this._filter.Length == 0); + + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this.Translate(lambdaExpression.Body); + return this._filter.ToString(); + } + + private void Translate(Expression? node) + { + switch (node) + { + case BinaryExpression binary: + this.TranslateBinary(binary); + return; + + case ConstantExpression constant: + this.TranslateConstant(constant); + return; + + case MemberExpression member: + this.TranslateMember(member); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + case UnaryExpression unary: + this.TranslateUnary(unary); + return; + + default: + throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); + } + } + + private void TranslateBinary(BinaryExpression binary) + { + this._filter.Append('('); + this.Translate(binary.Left); + + this._filter.Append(binary.NodeType switch + { + ExpressionType.Equal => " eq ", + ExpressionType.NotEqual => " ne ", + + ExpressionType.GreaterThan => " gt ", + ExpressionType.GreaterThanOrEqual => " ge ", + ExpressionType.LessThan => " lt ", + ExpressionType.LessThanOrEqual => " le ", + + ExpressionType.AndAlso => " and ", + ExpressionType.OrElse => " or ", + + _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) + }); + + this.Translate(binary.Right); + this._filter.Append(')'); + } + + private void TranslateConstant(ConstantExpression constant) + => this.GenerateLiteral(constant.Value); + + private void GenerateLiteral(object? value) + { + // TODO: Nullable + switch (value) + { + case byte b: + this._filter.Append(b); + return; + case short s: + this._filter.Append(s); + return; + case int i: + this._filter.Append(i); + return; + case long l: + this._filter.Append(l); + return; + + case string s: + this._filter.Append('\'').Append(s.Replace("'", "''")).Append('\''); // TODO: escaping + return; + case bool b: + this._filter.Append(b ? "true" : "false"); + return; + case Guid g: + this._filter.Append('\'').Append(g.ToString()).Append('\''); + return; + + case DateTime: + case DateTimeOffset: + throw new NotImplementedException(); + + case Array: + throw new NotImplementedException(); + + case null: + this._filter.Append("null"); + return; + + default: + throw new NotSupportedException("Unsupported constant type: " + value.GetType().Name); + } + } + + private void TranslateMember(MemberExpression memberExpression) + { + switch (memberExpression) + { + case var _ when this.TryGetField(memberExpression, out var column): + this._filter.Append(column); // TODO: Escape + return; + + // Identify captured lambda variables, inline them as constants + case var _ when TryGetCapturedValue(memberExpression, out var capturedValue): + this.GenerateLiteral(capturedValue); + return; + + default: + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + } + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + switch (source) + { + // Contains over array field (r => r.Strings.Contains("foo")) + case var _ when this.TryGetField(source, out _): + this.Translate(source); + this._filter.Append("/any(t: t eq "); + this.Translate(item); + this._filter.Append(')'); + return; + + // Contains over inline enumerable + case NewArrayExpression newArray: + var elements = new object?[newArray.Expressions.Count]; + + for (var i = 0; i < newArray.Expressions.Count; i++) + { + if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + { + throw new NotSupportedException("Invalid element in array"); + } + + elements[i] = elementValue; + } + + ProcessInlineEnumerable(elements, item); + return; + + // Contains over captured enumerable (we inline) + case var _ when TryGetConstant(source, out var constantEnumerable) + && constantEnumerable is IEnumerable enumerable and not string: + ProcessInlineEnumerable(enumerable, item); + return; + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + + void ProcessInlineEnumerable(IEnumerable elements, Expression item) + { + if (item.Type != typeof(string)) + { + throw new NotSupportedException("Contains over non-string arrays is not supported"); + } + + this._filter.Append("search.in("); + this.Translate(item); + this._filter.Append(", '"); + + string delimiter = ", "; + var startingPosition = this._filter.Length; + +RestartLoop: + var isFirst = true; + foreach (string element in elements) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._filter.Append(delimiter); + } + + // The default delimiter for search.in() is comma or space. + // If any element contains a comma or space, we switch to using pipe as the delimiter. + // If any contains a pipe, we throw (for now). + switch (delimiter) + { + case ", ": + if (element.IndexOfAny(s_searchInDefaultDelimiter) > -1) + { + delimiter = "|"; + this._filter.Length = startingPosition; + goto RestartLoop; + } + + break; + + case "|": + if (element.Contains('|')) + { + throw new NotSupportedException("Some elements contain both commas/spaces and pipes, cannot translate Contains"); + } + + break; + } + + this._filter.Append(element.Replace("'", "''")); + } + + this._filter.Append('\''); + + if (delimiter != ", ") + { + this._filter + .Append(", '") + .Append(delimiter) + .Append('\''); + } + + this._filter.Append(')'); + } + } + + private void TranslateUnary(UnaryExpression unary) + { + switch (unary.NodeType) + { + case ExpressionType.Not: + // Special handling for !(a == b) and !(a != b) + if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + this.TranslateBinary( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + } + + this._filter.Append("(not "); + this.Translate(unary.Operand); + this._filter.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + } + } + + private bool TryGetField(Expression expression, [NotNullWhen(true)] out string? field) + { + if (expression is MemberExpression member && member.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out field)) + { + throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + field = null; + return false; + } + + private static bool TryGetCapturedValue(Expression expression, out object? capturedValue) + { + if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + capturedValue = fieldInfo.GetValue(constant.Value); + return true; + } + + capturedValue = null; + return false; + } + + private static bool TryGetConstant(Expression expression, out object? constantValue) + { + switch (expression) + { + case ConstantExpression { Value: var v }: + constantValue = v; + return true; + + case var _ when TryGetCapturedValue(expression, out var capturedValue): + constantValue = capturedValue; + return true; + + default: + constantValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionSearchMapping.cs index ced35f244c5e..732b6aeae42c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreCollectionSearchMapping.cs @@ -12,6 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; /// internal static class AzureAISearchVectorStoreCollectionSearchMapping { +#pragma warning disable CS0618 // VectorSearchFilter is obsolete /// /// Build an OData filter string from the provided . /// @@ -19,10 +20,10 @@ internal static class AzureAISearchVectorStoreCollectionSearchMapping /// A mapping of data model property names to the names under which they are stored. /// The OData filter string. /// Thrown when a provided filter value is not supported. - public static string BuildFilterString(VectorSearchFilter? basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) + public static string BuildLegacyFilterString(VectorSearchFilter basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) { var filterString = string.Empty; - if (basicVectorSearchFilter?.FilterClauses is not null) + if (basicVectorSearchFilter.FilterClauses is not null) { // Map Equality clauses. var filterStrings = basicVectorSearchFilter?.FilterClauses.OfType().Select(x => @@ -60,6 +61,7 @@ public static string BuildFilterString(VectorSearchFilter? basicVectorSearchFilt return filterString; } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete /// /// Gets the name of the name under which the property with the given name is stored. diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs index bdf25bd2b8a4..9e92f5bbb722 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchVectorStoreRecordCollection.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Nodes; @@ -14,7 +15,6 @@ using Azure.Search.Documents.Indexes.Models; using Azure.Search.Documents.Models; using Microsoft.Extensions.VectorData; -using VectorData = Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; @@ -66,7 +66,7 @@ public sealed class AzureAISearchVectorStoreRecordCollection : IVectorS ]; /// The default options for vector search. - private static readonly VectorData.VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// Azure AI Search client that can be used to manage the list of indices in an Azure AI Search Service. private readonly SearchIndexClient _searchIndexClient; @@ -314,7 +314,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable reco } /// - public Task> VectorizedSearchAsync(TVector vector, VectorData.VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); @@ -335,7 +335,17 @@ public Task> VectorizedSearchAsync(TVector // Configure search settings. var vectorQueries = new List(); vectorQueries.Add(new VectorizedQuery(floatVector) { KNearestNeighborsCount = internalOptions.Top, Fields = { vectorFieldName } }); - var filterString = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(internalOptions.Filter, this._propertyReader.JsonPropertyNamesMap); + +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + // Build filter object. + var filter = internalOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(legacyFilter, this._propertyReader.JsonPropertyNamesMap), + { NewFilter: Expression> newFilter } => new AzureAISearchFilterTranslator().Translate(newFilter, this._propertyReader.StoragePropertyNamesMap), + _ => null + }; +#pragma warning restore CS0618 // Build search options. var searchOptions = new SearchOptions @@ -343,9 +353,14 @@ public Task> VectorizedSearchAsync(TVector VectorSearch = new(), Size = internalOptions.Top, Skip = internalOptions.Skip, - Filter = filterString, IncludeTotalCount = internalOptions.IncludeTotalCount, }; + + if (filter is not null) + { + searchOptions.Filter = filter; + } + searchOptions.VectorSearch.Queries.AddRange(vectorQueries); // Filter out vector fields if requested. @@ -359,7 +374,7 @@ public Task> VectorizedSearchAsync(TVector } /// - public Task> VectorizableTextSearchAsync(string searchText, VectorData.VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(searchText); @@ -375,7 +390,17 @@ public Task> VectorizableTextSearchAsync(string sea // Configure search settings. var vectorQueries = new List(); vectorQueries.Add(new VectorizableTextQuery(searchText) { KNearestNeighborsCount = internalOptions.Top, Fields = { vectorFieldName } }); - var filterString = AzureAISearchVectorStoreCollectionSearchMapping.BuildFilterString(internalOptions.Filter, this._propertyReader.JsonPropertyNamesMap); + +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + // Build filter object. + var filter = internalOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => AzureAISearchVectorStoreCollectionSearchMapping.BuildLegacyFilterString(legacyFilter, this._propertyReader.JsonPropertyNamesMap), + { NewFilter: Expression> newFilter } => new AzureAISearchFilterTranslator().Translate(newFilter, this._propertyReader.StoragePropertyNamesMap), + _ => null + }; +#pragma warning restore CS0618 // Build search options. var searchOptions = new SearchOptions @@ -383,9 +408,14 @@ public Task> VectorizableTextSearchAsync(string sea VectorSearch = new(), Size = internalOptions.Top, Skip = internalOptions.Skip, - Filter = filterString, IncludeTotalCount = internalOptions.IncludeTotalCount, }; + + if (filter is not null) + { + searchOptions.Filter = filter; + } + searchOptions.VectorSearch.Queries.AddRange(vectorQueries); // Filter out vector fields if requested. diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBFilterTranslator.cs new file mode 100644 index 000000000000..6c0b4e44e23b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBFilterTranslator.cs @@ -0,0 +1,258 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using MongoDB.Bson; + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +// MongoDB query reference: https://www.mongodb.com/docs/manual/reference/operator/query +// Information specific to vector search pre-filter: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter +internal class AzureCosmosDBMongoDBFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + internal BsonDocument Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + return this.Translate(lambdaExpression.Body); + } + + private BsonDocument Translate(Expression? node) + => node switch + { + BinaryExpression + { + NodeType: ExpressionType.Equal or ExpressionType.NotEqual + or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual + or ExpressionType.LessThan or ExpressionType.LessThanOrEqual + } binary + => this.TranslateEqualityComparison(binary), + + BinaryExpression { NodeType: ExpressionType.AndAlso or ExpressionType.OrElse } andOr + => this.TranslateAndOr(andOr), + UnaryExpression { NodeType: ExpressionType.Not } not + => this.TranslateNot(not), + + // MemberExpression is generally handled within e.g. TranslateEqualityComparison; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) + MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _) + => this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))), + + MethodCallExpression methodCall => this.TranslateMethodCall(methodCall), + + _ => throw new NotSupportedException("The following NodeType is unsupported: " + node?.NodeType) + }; + + private BsonDocument TranslateEqualityComparison(BinaryExpression binary) + { + if ((this.TryTranslateFieldAccess(binary.Left, out var storagePropertyName) && TryGetConstant(binary.Right, out var value)) + || (this.TryTranslateFieldAccess(binary.Right, out storagePropertyName) && TryGetConstant(binary.Left, out value))) + { + if (value is null) + { + throw new NotSupportedException("MongogDB does not support null checks in vector search pre-filters"); + } + + // Short form of equality (instead of $eq) + if (binary.NodeType is ExpressionType.Equal) + { + return new BsonDocument { [storagePropertyName] = BsonValue.Create(value) }; + } + + var filterOperator = binary.NodeType switch + { + ExpressionType.NotEqual => "$ne", + ExpressionType.GreaterThan => "$gt", + ExpressionType.GreaterThanOrEqual => "$gte", + ExpressionType.LessThan => "$lt", + ExpressionType.LessThanOrEqual => "$lte", + + _ => throw new UnreachableException() + }; + + return new BsonDocument { [storagePropertyName] = new BsonDocument { [filterOperator] = BsonValue.Create(value) } }; + } + + throw new NotSupportedException("Invalid equality/comparison"); + } + + private BsonDocument TranslateAndOr(BinaryExpression andOr) + { + var mongoOperator = andOr.NodeType switch + { + ExpressionType.AndAlso => "$and", + ExpressionType.OrElse => "$or", + _ => throw new UnreachableException() + }; + + var (left, right) = (this.Translate(andOr.Left), this.Translate(andOr.Right)); + + var nestedLeft = left.ElementCount == 1 && left.Elements.First() is var leftElement && leftElement.Name == mongoOperator ? (BsonArray)leftElement.Value : null; + var nestedRight = right.ElementCount == 1 && right.Elements.First() is var rightElement && rightElement.Name == mongoOperator ? (BsonArray)rightElement.Value : null; + + switch ((nestedLeft, nestedRight)) + { + case (not null, not null): + nestedLeft.AddRange(nestedRight); + return left; + case (not null, null): + nestedLeft.Add(right); + return left; + case (null, not null): + nestedRight.Insert(0, left); + return right; + case (null, null): + return new BsonDocument { [mongoOperator] = new BsonArray([left, right]) }; + } + } + + private BsonDocument TranslateNot(UnaryExpression not) + { + switch (not.Operand) + { + // Special handling for !(a == b) and !(a != b) + case BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary: + return this.TranslateEqualityComparison( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + + // Not over bool field (Filter => r => !r.Bool) + case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): + return this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(false))); + } + + var operand = this.Translate(not.Operand); + + // Identify NOT over $in, transform to $nin (https://www.mongodb.com/docs/manual/reference/operator/query/nin/#mongodb-query-op.-nin) + if (operand.ElementCount == 1 && operand.Elements.First() is { Name: var fieldName, Value: BsonDocument nested } && + nested.ElementCount == 1 && nested.Elements.First() is { Name: "$in", Value: BsonArray values }) + { + return new BsonDocument { [fieldName] = new BsonDocument { ["$nin"] = values } }; + } + + throw new NotSupportedException("MongogDB does not support the NOT operator in vector search pre-filters"); + } + + private BsonDocument TranslateMethodCall(MethodCallExpression methodCall) + => methodCall switch + { + // Enumerable.Contains() + { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable) + => this.TranslateContains(source, item), + + // List.Contains() + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>) => this.TranslateContains(source, item), + + _ => throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}") + }; + + private BsonDocument TranslateContains(Expression source, Expression item) + { + switch (source) + { + // Contains over array column (r => r.Strings.Contains("foo")) + case var _ when this.TryTranslateFieldAccess(source, out _): + throw new NotSupportedException("MongoDB does not support Contains within array fields ($elemMatch) in vector search pre-filters"); + + // Contains over inline enumerable + case NewArrayExpression newArray: + var elements = new object?[newArray.Expressions.Count]; + + for (var i = 0; i < newArray.Expressions.Count; i++) + { + if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + { + throw new NotSupportedException("Invalid element in array"); + } + + elements[i] = elementValue; + } + + return ProcessInlineEnumerable(elements, item); + + // Contains over captured enumerable (we inline) + case var _ when TryGetConstant(source, out var constantEnumerable) + && constantEnumerable is IEnumerable enumerable and not string: + return ProcessInlineEnumerable(enumerable, item); + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + + BsonDocument ProcessInlineEnumerable(IEnumerable elements, Expression item) + { + if (!this.TryTranslateFieldAccess(item, out var storagePropertyName)) + { + throw new NotSupportedException("Unsupported item type in Contains"); + } + + return new BsonDocument + { + [storagePropertyName] = new BsonDocument + { + ["$in"] = new BsonArray(from object? element in elements select BsonValue.Create(element)) + } + }; + } + } + + private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + { + if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) + { + throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + storagePropertyName = null; + return false; + } + + private static bool TryGetConstant(Expression expression, out object? constantValue) + { + switch (expression) + { + case ConstantExpression { Value: var v }: + constantValue = v; + return true; + + // This identifies compiler-generated closure types which contain captured variables. + case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): + constantValue = fieldInfo.GetValue(constant.Value); + return true; + + default: + constantValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs index 6e41eb7f3cb9..32377244112c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs @@ -20,6 +20,7 @@ internal static class AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping /// Returns distance function specified on vector property or default . public static string GetVectorPropertyDistanceFunction(string? distanceFunction) => !string.IsNullOrWhiteSpace(distanceFunction) ? distanceFunction! : MongoDBConstants.DefaultDistanceFunction; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete /// /// Build Azure CosmosDB MongoDB filter from the provided . /// @@ -86,6 +87,7 @@ internal static class AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping return filter; } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete /// Returns search part of the search query for index kind. public static BsonDocument GetSearchQueryForHnswIndex( diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs index d54a184e5771..a5d355150da3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Reflection; using System.Runtime.CompilerServices; using System.Threading; @@ -12,6 +13,7 @@ using MongoDB.Bson; using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver; +using MEVD = Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; @@ -33,7 +35,7 @@ public sealed class AzureCosmosDBMongoDBVectorStoreRecordCollection : I private const string DocumentPropertyName = "document"; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly MEVD.VectorSearchOptions s_defaultVectorSearchOptions = new(); /// that can be used to manage the collections in Azure CosmosDB MongoDB. private readonly IMongoDatabase _mongoDatabase; @@ -244,7 +246,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable reco /// public async Task> VectorizedSearchAsync( TVector vector, - VectorSearchOptions? options = null, + MEVD.VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); @@ -270,9 +272,17 @@ public async Task> VectorizedSearchAsync( var vectorPropertyName = this._storagePropertyNames[vectorProperty.DataModelPropertyName]; - var filter = AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter( - searchOptions.Filter, - this._storagePropertyNames); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + var filter = searchOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.BuildFilter( + legacyFilter, + this._storagePropertyNames), + { NewFilter: Expression> newFilter } => new AzureCosmosDBMongoDBFilterTranslator().Translate(newFilter, this._storagePropertyNames), + _ => null + }; +#pragma warning restore CS0618 // Constructing a query to fetch "skip + top" total items // to perform skip logic locally, since skip option is not part of API. @@ -371,7 +381,7 @@ private async Task> FindAsync(FilterDefinition> EnumerateAndMapSearchResultsAsync( IAsyncCursor cursor, - VectorSearchOptions searchOptions, + MEVD.VectorSearchOptions searchOptions, [EnumeratorCancellation] CancellationToken cancellationToken) { const string OperationName = "Aggregate"; diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLConstants.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLConstants.cs index 87aeee36355e..6dbb0d440b45 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLConstants.cs @@ -13,5 +13,5 @@ internal static class AzureCosmosDBNoSQLConstants /// Variable name for table in Azure CosmosDB NoSQL queries. /// Can be any string. Example: "SELECT x.Name FROM x". /// - internal const string TableQueryVariableName = "x"; + internal const char ContainerAlias = 'x'; } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLFilter.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLFilter.cs deleted file mode 100644 index 8cf6636c73e7..000000000000 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLFilter.cs +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Collections.Generic; - -namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; - -/// -/// Contains properties required to build query with filtering conditions. -/// -internal sealed class AzureCosmosDBNoSQLFilter -{ - public List? WhereClauseArguments { get; set; } - - public Dictionary? QueryParameters { get; set; } -} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.cs index a66eb5bfb719..1b0e7dcb8a7f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Text; using Microsoft.Azure.Cosmos; using Microsoft.Extensions.VectorData; @@ -21,13 +22,13 @@ internal static class AzureCosmosDBNoSQLVectorStoreCollectionQueryBuilder /// /// Builds to get items from Azure CosmosDB NoSQL using vector search. /// - public static QueryDefinition BuildSearchQuery( + public static QueryDefinition BuildSearchQuery( TVector vector, List fields, Dictionary storagePropertyNames, string vectorPropertyName, string scorePropertyName, - VectorSearchOptions searchOptions) + VectorSearchOptions searchOptions) { Verify.NotNull(vector); @@ -36,7 +37,7 @@ public static QueryDefinition BuildSearchQuery( const string LimitVariableName = "@limit"; const string TopVariableName = "@top"; - var tableVariableName = AzureCosmosDBNoSQLConstants.TableQueryVariableName; + var tableVariableName = AzureCosmosDBNoSQLConstants.ContainerAlias; var fieldsArgument = fields.Select(field => $"{tableVariableName}.{field}"); var vectorDistanceArgument = $"VectorDistance({tableVariableName}.{vectorPropertyName}, {VectorVariableName})"; @@ -44,19 +45,22 @@ public static QueryDefinition BuildSearchQuery( var selectClauseArguments = string.Join(SelectClauseDelimiter, [.. fieldsArgument, vectorDistanceArgumentWithAlias]); - var filter = BuildSearchFilter(searchOptions.Filter, storagePropertyNames); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + // Build filter object. + var (whereClause, filterParameters) = searchOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => BuildSearchFilter(legacyFilter, storagePropertyNames), + { NewFilter: Expression> newFilter } => new AzureCosmosDBNoSqlFilterTranslator().Translate(newFilter, storagePropertyNames), + _ => (null, []) + }; +#pragma warning restore CS0618 // VectorSearchFilter is obsolete - var filterQueryParameters = filter?.QueryParameters; - var filterWhereClauseArguments = filter?.WhereClauseArguments; - var queryParameters = new Dictionary + var queryParameters = new Dictionary { [VectorVariableName] = vector }; - var whereClause = filterWhereClauseArguments is { Count: > 0 } ? - $"WHERE {string.Join(AndConditionDelimiter, filterWhereClauseArguments)}" : - string.Empty; - // If Offset is not configured, use Top parameter instead of Limit/Offset // since it's more optimized. var topArgument = searchOptions.Skip == 0 ? $"TOP {TopVariableName} " : string.Empty; @@ -66,9 +70,9 @@ public static QueryDefinition BuildSearchQuery( builder.AppendLine($"SELECT {topArgument}{selectClauseArguments}"); builder.AppendLine($"FROM {tableVariableName}"); - if (filterWhereClauseArguments is { Count: > 0 }) + if (whereClause is not null) { - builder.AppendLine($"WHERE {string.Join(AndConditionDelimiter, filterWhereClauseArguments)}"); + builder.Append("WHERE ").AppendLine(whereClause); } builder.AppendLine($"ORDER BY {vectorDistanceArgument}"); @@ -86,9 +90,9 @@ public static QueryDefinition BuildSearchQuery( var queryDefinition = new QueryDefinition(builder.ToString()); - if (filterQueryParameters is { Count: > 0 }) + if (filterParameters is { Count: > 0 }) { - queryParameters = queryParameters.Union(filterQueryParameters).ToDictionary(k => k.Key, v => v.Value); + queryParameters = queryParameters.Union(filterParameters).ToDictionary(k => k.Key, v => v.Value); } foreach (var queryParameter in queryParameters) @@ -113,7 +117,7 @@ public static QueryDefinition BuildSelectQuery( const string RecordKeyVariableName = "@rk"; const string PartitionKeyVariableName = "@pk"; - var tableVariableName = AzureCosmosDBNoSQLConstants.TableQueryVariableName; + var tableVariableName = AzureCosmosDBNoSQLConstants.ContainerAlias; var selectClauseArguments = string.Join(SelectClauseDelimiter, fields.Select(field => $"{tableVariableName}.{field}")); @@ -123,10 +127,11 @@ public static QueryDefinition BuildSelectQuery( $"({tableVariableName}.{keyStoragePropertyName} = {RecordKeyVariableName}{index} {AndConditionDelimiter} " + $"{tableVariableName}.{partitionKeyStoragePropertyName} = {PartitionKeyVariableName}{index})")); - var query = - $"SELECT {selectClauseArguments} " + - $"FROM {tableVariableName} " + - $"WHERE {whereClauseArguments} "; + var query = $""" + SELECT {selectClauseArguments} + FROM {tableVariableName} + WHERE {whereClauseArguments} + """; var queryDefinition = new QueryDefinition(query); @@ -147,44 +152,43 @@ public static QueryDefinition BuildSelectQuery( #region private - private static AzureCosmosDBNoSQLFilter? BuildSearchFilter( - VectorSearchFilter? filter, +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + private static (string WhereClause, Dictionary Parameters) BuildSearchFilter( + VectorSearchFilter filter, Dictionary storagePropertyNames) { const string EqualOperator = "="; const string ArrayContainsOperator = "ARRAY_CONTAINS"; const string ConditionValueVariableName = "@cv"; - var tableVariableName = AzureCosmosDBNoSQLConstants.TableQueryVariableName; + var tableVariableName = AzureCosmosDBNoSQLConstants.ContainerAlias; - var filterClauses = filter?.FilterClauses.ToList(); - - if (filterClauses is not { Count: > 0 }) - { - return null; - } + var filterClauses = filter.FilterClauses.ToList(); - var whereClauseArguments = new List(); - var queryParameters = new Dictionary(); + var whereClauseBuilder = new StringBuilder(); + var queryParameters = new Dictionary(); for (var i = 0; i < filterClauses.Count; i++) { + if (i > 0) + { + whereClauseBuilder.Append(" AND "); + } var filterClause = filterClauses[i]; string queryParameterName = $"{ConditionValueVariableName}{i}"; object queryParameterValue; - string whereClauseArgument; if (filterClause is EqualToFilterClause equalToFilterClause) { var propertyName = GetStoragePropertyName(equalToFilterClause.FieldName, storagePropertyNames); - whereClauseArgument = $"{tableVariableName}.{propertyName} {EqualOperator} {queryParameterName}"; + whereClauseBuilder.Append($"{tableVariableName}.{propertyName} {EqualOperator} {queryParameterName}"); queryParameterValue = equalToFilterClause.Value; } else if (filterClause is AnyTagEqualToFilterClause anyTagEqualToFilterClause) { var propertyName = GetStoragePropertyName(anyTagEqualToFilterClause.FieldName, storagePropertyNames); - whereClauseArgument = $"{ArrayContainsOperator}({tableVariableName}.{propertyName}, {queryParameterName})"; + whereClauseBuilder.Append($"{ArrayContainsOperator}({tableVariableName}.{propertyName}, {queryParameterName})"); queryParameterValue = anyTagEqualToFilterClause.Value; } else @@ -196,16 +200,12 @@ public static QueryDefinition BuildSelectQuery( nameof(AnyTagEqualToFilterClause)])}"); } - whereClauseArguments.Add(whereClauseArgument); queryParameters.Add(queryParameterName, queryParameterValue); } - return new AzureCosmosDBNoSQLFilter - { - WhereClauseArguments = whereClauseArguments, - QueryParameters = queryParameters, - }; + return (whereClauseBuilder.ToString(), queryParameters); } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete private static string GetStoragePropertyName(string propertyName, Dictionary storagePropertyNames) { diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollection.cs index 6ab9222d2a14..53463cb943b4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollection.cs @@ -69,7 +69,7 @@ public sealed class AzureCosmosDBNoSQLVectorStoreRecordCollection : ]; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// that can be used to manage the collections in Azure CosmosDB NoSQL. private readonly Database _database; @@ -355,7 +355,7 @@ async IAsyncEnumerable IVectorStoreRecordCollect /// public Task> VectorizedSearchAsync( TVector vector, - VectorSearchOptions? options = null, + VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "VectorizedSearch"; @@ -679,7 +679,7 @@ private async IAsyncEnumerable> MapSearchResultsAsyn IAsyncEnumerable jsonObjects, string scorePropertyName, string operationName, - VectorSearchOptions searchOptions, + VectorSearchOptions searchOptions, [EnumeratorCancellation] CancellationToken cancellationToken) { await foreach (var jsonObject in jsonObjects.ConfigureAwait(false)) diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs new file mode 100644 index 000000000000..e18f176c2ea7 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBNoSQL/AzureCosmosDBNoSqlFilterTranslator.cs @@ -0,0 +1,282 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; + +internal class AzureCosmosDBNoSqlFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + private readonly Dictionary _parameters = new(); + private readonly StringBuilder _sql = new(); + + internal (string WhereClause, Dictionary Parameters) Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + Debug.Assert(this._sql.Length == 0); + + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this.Translate(lambdaExpression.Body); + return (this._sql.ToString(), this._parameters); + } + + private void Translate(Expression? node) + { + switch (node) + { + case BinaryExpression binary: + this.TranslateBinary(binary); + return; + + case ConstantExpression constant: + this.TranslateConstant(constant); + return; + + case MemberExpression member: + this.TranslateMember(member); + return; + + case NewArrayExpression newArray: + this.TranslateNewArray(newArray); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + case UnaryExpression unary: + this.TranslateUnary(unary); + return; + + default: + throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); + } + } + + private void TranslateBinary(BinaryExpression binary) + { + this._sql.Append('('); + this.Translate(binary.Left); + + this._sql.Append(binary.NodeType switch + { + ExpressionType.Equal => " = ", + ExpressionType.NotEqual => " <> ", + + ExpressionType.GreaterThan => " > ", + ExpressionType.GreaterThanOrEqual => " >= ", + ExpressionType.LessThan => " < ", + ExpressionType.LessThanOrEqual => " <= ", + + ExpressionType.AndAlso => " AND ", + ExpressionType.OrElse => " OR ", + + _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) + }); + + this.Translate(binary.Right); + this._sql.Append(')'); + } + + private void TranslateConstant(ConstantExpression constant) + { + // TODO: Nullable + switch (constant.Value) + { + case byte b: + this._sql.Append(b); + return; + case short s: + this._sql.Append(s); + return; + case int i: + this._sql.Append(i); + return; + case long l: + this._sql.Append(l); + return; + + case string s: + this._sql.Append('"').Append(s.Replace(@"\", @"\\").Replace("\"", "\\\"")).Append('"'); + return; + case bool b: + this._sql.Append(b ? "true" : "false"); + return; + case Guid g: + this._sql.Append('"').Append(g.ToString()).Append('"'); + return; + + case DateTime: + case DateTimeOffset: + throw new NotImplementedException(); + + case Array: + throw new NotImplementedException(); + + case null: + this._sql.Append("null"); + return; + + default: + throw new NotSupportedException("Unsupported constant type: " + constant.Value.GetType().Name); + } + } + + private void TranslateMember(MemberExpression memberExpression) + { + switch (memberExpression) + { + case var _ when this.TryGetPropertyAccess(memberExpression, out var column): + this._sql.Append(AzureCosmosDBNoSQLConstants.ContainerAlias).Append("[\"").Append(column).Append("\"]"); + return; + + // Identify captured lambda variables, translate to Cosmos parameters (@foo, @bar...) + case var _ when TryGetCapturedValue(memberExpression, out var name, out var value): + // Duplicate parameter name, create a new parameter with a different name + // TODO: Share the same parameter when it references the same captured value + if (this._parameters.ContainsKey(name)) + { + var baseName = name; + var i = 0; + do + { + name = baseName + (i++); + } while (this._parameters.ContainsKey(name)); + } + + name = '@' + name; + this._parameters.Add(name, value); + this._sql.Append(name); + return; + + default: + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + } + } + + private void TranslateNewArray(NewArrayExpression newArray) + { + this._sql.Append('['); + + for (var i = 0; i < newArray.Expressions.Count; i++) + { + if (i > 0) + { + this._sql.Append(", "); + } + + this.Translate(newArray.Expressions[i]); + } + + this._sql.Append(']'); + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + this._sql.Append("ARRAY_CONTAINS("); + this.Translate(source); + this._sql.Append(", "); + this.Translate(item); + this._sql.Append(')'); + } + + private void TranslateUnary(UnaryExpression unary) + { + switch (unary.NodeType) + { + // Special handling for !(a == b) and !(a != b) + case ExpressionType.Not: + if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + this.TranslateBinary( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + } + + this._sql.Append("(NOT "); + this.Translate(unary.Operand); + this._sql.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + } + } + + private bool TryGetPropertyAccess(Expression expression, [NotNullWhen(true)] out string? column) + { + if (expression is MemberExpression member && member.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + { + throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + column = null; + return false; + } + + private static bool TryGetCapturedValue(Expression expression, [NotNullWhen(true)] out string? name, out object? value) + { + if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + name = fieldInfo.Name; + value = fieldInfo.GetValue(constant.Value); + return true; + } + + name = null; + value = null; + return false; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreCollectionSearchMapping.cs index 7ecea345cb85..6b33671cef9f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreCollectionSearchMapping.cs @@ -88,6 +88,7 @@ public static float ConvertScore(float score, string? distanceFunction) } } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete /// /// Filter the provided records using the provided filter definition. /// @@ -95,15 +96,15 @@ public static float ConvertScore(float score, string? distanceFunction) /// The records to filter. /// The filtered records. /// Thrown when an unsupported filter clause is encountered. - public static IEnumerable FilterRecords(VectorSearchFilter? filter, IEnumerable records) + public static IEnumerable FilterRecords(VectorSearchFilter filter, IEnumerable records) { - if (filter == null) - { - return records; - } - return records.Where(record => { + if (record is null) + { + return false; + } + var result = true; // Run each filter clause against the record, and AND the results together. @@ -197,6 +198,7 @@ private static bool CheckAnyTagEqualTo(object record, AnyTagEqualToFilterClause return false; } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete /// /// Get the property info for the provided property name on the record. diff --git a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollection.cs index a2fe21e0cfc6..03fe957cca07 100644 --- a/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.InMemory/InMemoryVectorStoreRecordCollection.cs @@ -4,6 +4,7 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -29,7 +30,7 @@ public sealed class InMemoryVectorStoreRecordCollection : IVector ]; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// Internal storage for all of the record collections. private readonly ConcurrentDictionary> _internalCollections; @@ -210,7 +211,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record /// #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - Need to satisfy the interface which returns IAsyncEnumerable - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) #pragma warning restore CS1998 { Verify.NotNull(vector); @@ -234,13 +235,22 @@ public async Task> VectorizedSearchAsync(T throw new InvalidOperationException($"The collection does not have a vector field named '{internalOptions.VectorPropertyName}', so vector search is not possible."); } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete // Filter records using the provided filter before doing the vector comparison. - var filteredRecords = InMemoryVectorStoreCollectionSearchMapping.FilterRecords(internalOptions.Filter, this.GetCollectionDictionary().Values); + var allValues = this.GetCollectionDictionary().Values.Cast(); + var filteredRecords = internalOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => InMemoryVectorStoreCollectionSearchMapping.FilterRecords(legacyFilter, allValues), + { NewFilter: Expression> newFilter } => allValues.AsQueryable().Where(newFilter), + _ => allValues + }; +#pragma warning restore CS0618 // VectorSearchFilter is obsolete // Compare each vector in the filtered results with the provided vector. - var results = filteredRecords.Select((record) => + var results = filteredRecords.Select(record => { - var vectorObject = this._vectorResolver(vectorPropertyName!, (TRecord)record); + var vectorObject = this._vectorResolver(vectorPropertyName!, record); if (vectorObject is not ReadOnlyMemory dbVector) { return null; diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBFilterTranslator.cs new file mode 100644 index 000000000000..202908de1c0b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBFilterTranslator.cs @@ -0,0 +1,258 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using MongoDB.Bson; + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +// MongoDB query reference: https://www.mongodb.com/docs/manual/reference/operator/query +// Information specific to vector search pre-filter: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter +internal class MongoDBFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + internal BsonDocument Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + return this.Translate(lambdaExpression.Body); + } + + private BsonDocument Translate(Expression? node) + => node switch + { + BinaryExpression + { + NodeType: ExpressionType.Equal or ExpressionType.NotEqual + or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual + or ExpressionType.LessThan or ExpressionType.LessThanOrEqual + } binary + => this.TranslateEqualityComparison(binary), + + BinaryExpression { NodeType: ExpressionType.AndAlso or ExpressionType.OrElse } andOr + => this.TranslateAndOr(andOr), + UnaryExpression { NodeType: ExpressionType.Not } not + => this.TranslateNot(not), + + // MemberExpression is generally handled within e.g. TranslateEqualityComparison; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) + MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _) + => this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))), + + MethodCallExpression methodCall => this.TranslateMethodCall(methodCall), + + _ => throw new NotSupportedException("The following NodeType is unsupported: " + node?.NodeType) + }; + + private BsonDocument TranslateEqualityComparison(BinaryExpression binary) + { + if ((this.TryTranslateFieldAccess(binary.Left, out var storagePropertyName) && TryGetConstant(binary.Right, out var value)) + || (this.TryTranslateFieldAccess(binary.Right, out storagePropertyName) && TryGetConstant(binary.Left, out value))) + { + if (value is null) + { + throw new NotSupportedException("MongogDB does not support null checks in vector search pre-filters"); + } + + // Short form of equality (instead of $eq) + if (binary.NodeType is ExpressionType.Equal) + { + return new BsonDocument { [storagePropertyName] = BsonValue.Create(value) }; + } + + var filterOperator = binary.NodeType switch + { + ExpressionType.NotEqual => "$ne", + ExpressionType.GreaterThan => "$gt", + ExpressionType.GreaterThanOrEqual => "$gte", + ExpressionType.LessThan => "$lt", + ExpressionType.LessThanOrEqual => "$lte", + + _ => throw new UnreachableException() + }; + + return new BsonDocument { [storagePropertyName] = new BsonDocument { [filterOperator] = BsonValue.Create(value) } }; + } + + throw new NotSupportedException("Invalid equality/comparison"); + } + + private BsonDocument TranslateAndOr(BinaryExpression andOr) + { + var mongoOperator = andOr.NodeType switch + { + ExpressionType.AndAlso => "$and", + ExpressionType.OrElse => "$or", + _ => throw new UnreachableException() + }; + + var (left, right) = (this.Translate(andOr.Left), this.Translate(andOr.Right)); + + var nestedLeft = left.ElementCount == 1 && left.Elements.First() is var leftElement && leftElement.Name == mongoOperator ? (BsonArray)leftElement.Value : null; + var nestedRight = right.ElementCount == 1 && right.Elements.First() is var rightElement && rightElement.Name == mongoOperator ? (BsonArray)rightElement.Value : null; + + switch ((nestedLeft, nestedRight)) + { + case (not null, not null): + nestedLeft.AddRange(nestedRight); + return left; + case (not null, null): + nestedLeft.Add(right); + return left; + case (null, not null): + nestedRight.Insert(0, left); + return right; + case (null, null): + return new BsonDocument { [mongoOperator] = new BsonArray([left, right]) }; + } + } + + private BsonDocument TranslateNot(UnaryExpression not) + { + switch (not.Operand) + { + // Special handling for !(a == b) and !(a != b) + case BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary: + return this.TranslateEqualityComparison( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + + // Not over bool field (Filter => r => !r.Bool) + case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): + return this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(false))); + } + + var operand = this.Translate(not.Operand); + + // Identify NOT over $in, transform to $nin (https://www.mongodb.com/docs/manual/reference/operator/query/nin/#mongodb-query-op.-nin) + if (operand.ElementCount == 1 && operand.Elements.First() is { Name: var fieldName, Value: BsonDocument nested } && + nested.ElementCount == 1 && nested.Elements.First() is { Name: "$in", Value: BsonArray values }) + { + return new BsonDocument { [fieldName] = new BsonDocument { ["$nin"] = values } }; + } + + throw new NotSupportedException("MongogDB does not support the NOT operator in vector search pre-filters"); + } + + private BsonDocument TranslateMethodCall(MethodCallExpression methodCall) + => methodCall switch + { + // Enumerable.Contains() + { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable) + => this.TranslateContains(source, item), + + // List.Contains() + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>) => this.TranslateContains(source, item), + + _ => throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}") + }; + + private BsonDocument TranslateContains(Expression source, Expression item) + { + switch (source) + { + // Contains over array column (r => r.Strings.Contains("foo")) + case var _ when this.TryTranslateFieldAccess(source, out _): + throw new NotSupportedException("MongoDB does not support Contains within array fields ($elemMatch) in vector search pre-filters"); + + // Contains over inline enumerable + case NewArrayExpression newArray: + var elements = new object?[newArray.Expressions.Count]; + + for (var i = 0; i < newArray.Expressions.Count; i++) + { + if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + { + throw new NotSupportedException("Invalid element in array"); + } + + elements[i] = elementValue; + } + + return ProcessInlineEnumerable(elements, item); + + // Contains over captured enumerable (we inline) + case var _ when TryGetConstant(source, out var constantEnumerable) + && constantEnumerable is IEnumerable enumerable and not string: + return ProcessInlineEnumerable(enumerable, item); + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + + BsonDocument ProcessInlineEnumerable(IEnumerable elements, Expression item) + { + if (!this.TryTranslateFieldAccess(item, out var storagePropertyName)) + { + throw new NotSupportedException("Unsupported item type in Contains"); + } + + return new BsonDocument + { + [storagePropertyName] = new BsonDocument + { + ["$in"] = new BsonArray(from object? element in elements select BsonValue.Create(element)) + } + }; + } + } + + private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + { + if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) + { + throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + storagePropertyName = null; + return false; + } + + private static bool TryGetConstant(Expression expression, out object? constantValue) + { + switch (expression) + { + case ConstantExpression { Value: var v }: + constantValue = v; + return true; + + // This identifies compiler-generated closure types which contain captured variables. + case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): + constantValue = fieldInfo.GetValue(constant.Value); + return true; + + default: + constantValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs index 931b668f535d..de47f6723b23 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs @@ -16,6 +16,7 @@ internal static class MongoDBVectorStoreCollectionSearchMapping /// Returns distance function specified on vector property or default . public static string GetVectorPropertyDistanceFunction(string? distanceFunction) => !string.IsNullOrWhiteSpace(distanceFunction) ? distanceFunction! : MongoDBConstants.DefaultDistanceFunction; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete /// /// Build MongoDB filter from the provided . /// @@ -23,13 +24,13 @@ internal static class MongoDBVectorStoreCollectionSearchMapping /// A dictionary that maps from a property name to the storage name. /// Thrown when the provided filter type is unsupported. /// Thrown when property name specified in filter doesn't exist. - public static BsonDocument? BuildFilter( - VectorSearchFilter? vectorSearchFilter, + public static BsonDocument? BuildLegacyFilter( + VectorSearchFilter vectorSearchFilter, Dictionary storagePropertyNames) { const string EqualOperator = "$eq"; - var filterClauses = vectorSearchFilter?.FilterClauses.ToList(); + var filterClauses = vectorSearchFilter.FilterClauses.ToList(); if (filterClauses is not { Count: > 0 }) { @@ -82,6 +83,7 @@ internal static class MongoDBVectorStoreCollectionSearchMapping return filter; } +#pragma warning restore CS0618 /// Returns search part of the search query. public static BsonDocument GetSearchQuery( diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs index 353b3534dab9..25fc14e8196e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Reflection; using System.Runtime.CompilerServices; using System.Threading; @@ -11,6 +12,7 @@ using MongoDB.Bson; using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver; +using MEVD = Microsoft.Extensions.VectorData; namespace Microsoft.SemanticKernel.Connectors.MongoDB; @@ -32,7 +34,7 @@ public sealed class MongoDBVectorStoreRecordCollection : IVectorStoreRe private const string DocumentPropertyName = "document"; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly MEVD.VectorSearchOptions s_defaultVectorSearchOptions = new(); /// that can be used to manage the collections in MongoDB. private readonly IMongoDatabase _mongoDatabase; @@ -247,7 +249,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable reco /// public async Task> VectorizedSearchAsync( TVector vector, - VectorSearchOptions? options = null, + MEVD.VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); @@ -273,9 +275,15 @@ public async Task> VectorizedSearchAsync( var vectorPropertyName = this._storagePropertyNames[vectorProperty.DataModelPropertyName]; - var filter = MongoDBVectorStoreCollectionSearchMapping.BuildFilter( - searchOptions.Filter, - this._storagePropertyNames); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + var filter = searchOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(legacyFilter, this._storagePropertyNames), + { NewFilter: Expression> newFilter } => new MongoDBFilterTranslator().Translate(newFilter, this._storagePropertyNames), + _ => null + }; +#pragma warning restore CS0618 // Constructing a query to fetch "skip + top" total items // to perform skip logic locally, since skip option is not part of API. @@ -383,7 +391,7 @@ private async Task> FindAsync(FilterDefinition> EnumerateAndMapSearchResultsAsync( IAsyncCursor cursor, - VectorSearchOptions searchOptions, + MEVD.VectorSearchOptions searchOptions, [EnumeratorCancellation] CancellationToken cancellationToken) { const string OperationName = "Aggregate"; diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionSearchMapping.cs index e02e18807d9c..5b3d511c6b08 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreCollectionSearchMapping.cs @@ -12,6 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.Pinecone; /// internal static class PineconeVectorStoreCollectionSearchMapping { +#pragma warning disable CS0618 // FilterClause is obsolete /// /// Build a Pinecone from a set of filter clauses. /// @@ -59,4 +60,5 @@ public static MetadataMap BuildSearchFilter(IEnumerable? filterCla return metadataMap; } +#pragma warning restore CS0618 // FilterClause is obsolete } diff --git a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs index 8a956f53f635..8e1e8cf7aaf1 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Pinecone/PineconeVectorStoreRecordCollection.cs @@ -32,7 +32,7 @@ public sealed class PineconeVectorStoreRecordCollection : IVectorStoreR private const string GetOperationName = "Get"; private const string QueryOperationName = "Query"; - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); private readonly Sdk.PineconeClient _pineconeClient; private readonly PineconeVectorStoreRecordCollectionOptions _options; @@ -246,7 +246,7 @@ await this.RunOperationAsync( } /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); @@ -259,9 +259,12 @@ public async Task> VectorizedSearchAsync(T // Resolve options and build filter clause. var internalOptions = options ?? s_defaultVectorSearchOptions; var mapperOptions = new StorageToDataModelMapperOptions { IncludeVectors = options?.IncludeVectors ?? false }; + +#pragma warning disable CS0618 // FilterClause is obsolete var filter = PineconeVectorStoreCollectionSearchMapping.BuildSearchFilter( internalOptions.Filter?.FilterClauses, this._propertyReader.StoragePropertyNamesMap); +#pragma warning restore CS0618 // Get the current index. var indexNamespace = this.GetIndexNamespace(); diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs index d130d2f13b44..3c864cc6537f 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreCollectionSqlBuilder.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using System.Linq.Expressions; using Microsoft.Extensions.VectorData; using Pgvector; @@ -124,13 +126,16 @@ internal interface IPostgresVectorStoreCollectionSqlBuilder /// /// The schema of the table. /// The name of the table. - /// The properties of the table. + /// The property reader. /// The property which the vectors to compare are stored in. /// The vector to match. - /// The filter conditions for the query. + /// The filter conditions for the query. + /// The filter conditions for the query. /// The number of records to skip. /// Specifies whether to include vectors in the result. /// The maximum number of records to return. /// The built SQL command info. - PostgresSqlCommandInfo BuildGetNearestMatchCommand(string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, VectorSearchFilter? filter, int? skip, bool includeVectors, int limit); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + PostgresSqlCommandInfo BuildGetNearestMatchCommand(string schema, string tableName, VectorStoreRecordPropertyReader propertyReader, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, VectorSearchFilter? legacyFilter, Expression>? newFilter, int? skip, bool includeVectors, int limit); +#pragma warning restore CS0618 // VectorSearchFilter is obsolete } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs index 59aa9829c568..3fb62b667a92 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/IPostgresVectorStoreDbClient.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using System.Linq.Expressions; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; @@ -118,15 +120,18 @@ internal interface IPostgresVectorStoreDbClient /// Gets the nearest matches to the . /// /// The name assigned to a table of entries. - /// The properties to retrieve. - /// The property which the vectors to compare are stored in. + /// The property reader. + /// The vector property. /// The to compare the table's vector with. /// The maximum number of similarity results to return. - /// Optional conditions to filter the results. + /// Optional conditions to filter the results. + /// Optional conditions to filter the results. /// The number of entries to skip. /// If true, the vectors will be returned in the entries. /// The to monitor for cancellation requests. The default is . /// An asynchronous stream of objects that the nearest matches to the . - IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync(string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, - VectorSearchFilter? filter = default, int? skip = default, bool includeVectors = false, CancellationToken cancellationToken = default); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync(string tableName, VectorStoreRecordPropertyReader propertyReader, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, + VectorSearchFilter? legacyFilter = default, Expression>? newFilter = default, int? skip = default, bool includeVectors = false, CancellationToken cancellationToken = default); +#pragma warning restore CS0618 // VectorSearchFilter is obsolete } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs new file mode 100644 index 000000000000..6c68527da5c1 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresFilterTranslator.cs @@ -0,0 +1,332 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors.Postgres; + +internal class PostgresFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + private readonly List _parameterValues = new(); + private int _parameterIndex; + + private readonly StringBuilder _sql = new(); + + internal (string Clause, List Parameters) Translate( + IReadOnlyDictionary storagePropertyNames, + LambdaExpression lambdaExpression, + int startParamIndex) + { + Debug.Assert(this._sql.Length == 0); + + this._storagePropertyNames = storagePropertyNames; + + this._parameterIndex = startParamIndex; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this._sql.Append("WHERE "); + this.Translate(lambdaExpression.Body); + return (this._sql.ToString(), this._parameterValues); + } + + private void Translate(Expression? node) + { + switch (node) + { + case BinaryExpression binary: + this.TranslateBinary(binary); + return; + + case ConstantExpression constant: + this.TranslateConstant(constant); + return; + + case MemberExpression member: + this.TranslateMember(member); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + case UnaryExpression unary: + this.TranslateUnary(unary); + return; + + default: + throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); + } + } + + private void TranslateBinary(BinaryExpression binary) + { + // Special handling for null comparisons + switch (binary.NodeType) + { + case ExpressionType.Equal when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NOT NULL)"); + return; + + case ExpressionType.Equal when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NOT NULL)"); + return; + } + + this._sql.Append('('); + this.Translate(binary.Left); + + this._sql.Append(binary.NodeType switch + { + ExpressionType.Equal => " = ", + ExpressionType.NotEqual => " <> ", + + ExpressionType.GreaterThan => " > ", + ExpressionType.GreaterThanOrEqual => " >= ", + ExpressionType.LessThan => " < ", + ExpressionType.LessThanOrEqual => " <= ", + + ExpressionType.AndAlso => " AND ", + ExpressionType.OrElse => " OR ", + + _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) + }); + + this.Translate(binary.Right); + this._sql.Append(')'); + + static bool IsNull(Expression expression) + => expression is ConstantExpression { Value: null } + || (TryGetCapturedValue(expression, out var capturedValue) && capturedValue is null); + } + + private void TranslateConstant(ConstantExpression constant) + { + // TODO: Nullable + switch (constant.Value) + { + case byte b: + this._sql.Append(b); + return; + case short s: + this._sql.Append(s); + return; + case int i: + this._sql.Append(i); + return; + case long l: + this._sql.Append(l); + return; + + case string s: + this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); + return; + case bool b: + this._sql.Append(b ? "TRUE" : "FALSE"); + return; + case Guid g: + this._sql.Append('\'').Append(g.ToString()).Append('\''); + return; + + case DateTime: + case DateTimeOffset: + throw new NotImplementedException(); + + case Array: + throw new NotImplementedException(); + + case null: + this._sql.Append("NULL"); + return; + + default: + throw new NotSupportedException("Unsupported constant type: " + constant.Value.GetType().Name); + } + } + + private void TranslateMember(MemberExpression memberExpression) + { + switch (memberExpression) + { + case var _ when this.TryGetColumn(memberExpression, out var column): + this._sql.Append('"').Append(column).Append('"'); + return; + + // Identify captured lambda variables, translate to PostgreSQL parameters ($1, $2...) + case var _ when TryGetCapturedValue(memberExpression, out var capturedValue): + // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, + // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) + if (capturedValue is null) + { + this._sql.Append("NULL"); + } + else + { + this._parameterValues.Add(capturedValue); + this._sql.Append('$').Append(this._parameterIndex++); + } + return; + + default: + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + } + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + switch (source) + { + // Contains over array column (r => r.Strings.Contains("foo")) + case var _ when this.TryGetColumn(source, out _): + this.Translate(source); + this._sql.Append(" @> ARRAY["); + this.Translate(item); + this._sql.Append(']'); + return; + + // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) + case NewArrayExpression newArray: + this.Translate(item); + this._sql.Append(" IN ("); + + var isFirst = true; + foreach (var element in newArray.Expressions) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._sql.Append(", "); + } + + this.Translate(element); + } + + this._sql.Append(')'); + return; + + // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) + case var _ when TryGetCapturedValue(source, out _): + this.Translate(item); + this._sql.Append(" = ANY ("); + this.Translate(source); + this._sql.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + } + + private void TranslateUnary(UnaryExpression unary) + { + switch (unary.NodeType) + { + case ExpressionType.Not: + // Special handling for !(a == b) and !(a != b) + if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + this.TranslateBinary( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + } + + this._sql.Append("(NOT "); + this.Translate(unary.Operand); + this._sql.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + } + } + + private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column) + { + if (expression is MemberExpression member && member.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + { + throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + column = null; + return false; + } + + private static bool TryGetCapturedValue(Expression expression, out object? capturedValue) + { + if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + capturedValue = fieldInfo.GetValue(constant.Value); + return true; + } + + capturedValue = null; + return false; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs index d68412d31b7d..364c564703e4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreCollectionSqlBuilder.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Text; using Microsoft.Extensions.VectorData; using Npgsql; @@ -20,12 +21,13 @@ internal class PostgresVectorStoreCollectionSqlBuilder : IPostgresVectorStoreCol public PostgresSqlCommandInfo BuildDoesTableExistCommand(string schema, string tableName) { return new PostgresSqlCommandInfo( - commandText: @" - SELECT table_name - FROM information_schema.tables - WHERE table_schema = $1 - AND table_type = 'BASE TABLE' - AND table_name = $2", + commandText: """ +SELECT table_name +FROM information_schema.tables +WHERE table_schema = $1 + AND table_type = 'BASE TABLE' + AND table_name = $2 +""", parameters: [ new NpgsqlParameter() { Value = schema }, new NpgsqlParameter() { Value = tableName } @@ -37,11 +39,11 @@ FROM information_schema.tables public PostgresSqlCommandInfo BuildGetTablesCommand(string schema) { return new PostgresSqlCommandInfo( - commandText: @" - SELECT table_name - FROM information_schema.tables - WHERE table_schema = $1 - AND table_type = 'BASE TABLE'", + commandText: """ +SELECT table_name +FROM information_schema.tables +WHERE table_schema = $1 AND table_type = 'BASE TABLE' +""", parameters: [new NpgsqlParameter() { Value = schema }] ); } @@ -167,11 +169,12 @@ public PostgresSqlCommandInfo BuildUpsertCommand(string schema, string tableName var valuesParams = string.Join(", ", columns.Select((k, i) => $"${i + 1}")); var columnsWithIndex = columns.Select((k, i) => (col: k, idx: i)); var updateColumnsWithParams = string.Join(", ", columnsWithIndex.Where(c => c.col != keyColumn).Select(c => $"\"{c.col}\"=${c.idx + 1}")); - var commandText = $@" - INSERT INTO {schema}.""{tableName}"" ({columnNames}) - VALUES({valuesParams}) - ON CONFLICT (""{keyColumn}"") - DO UPDATE SET {updateColumnsWithParams};"; + var commandText = $""" +INSERT INTO {schema}."{tableName}" ({columnNames}) +VALUES ({valuesParams}) +ON CONFLICT ("{keyColumn}") +DO UPDATE SET {updateColumnsWithParams}; +"""; return new PostgresSqlCommandInfo(commandText) { @@ -204,11 +207,12 @@ public PostgresSqlCommandInfo BuildUpsertBatchCommand(string schema, string tabl var updateSetClause = string.Join(", ", columns.Where(c => c != keyColumn).Select(c => $"\"{c}\" = EXCLUDED.\"{c}\"")); // Generate the SQL command - var commandText = $@" - INSERT INTO {schema}.""{tableName}"" ({columnNames}) - VALUES {valuesRows} - ON CONFLICT (""{keyColumn}"") - DO UPDATE SET {updateSetClause}; "; + var commandText = $""" +INSERT INTO {schema}."{tableName}" ({columnNames}) +VALUES {valuesRows} +ON CONFLICT ("{keyColumn}") +DO UPDATE SET {updateSetClause}; +"""; // Generate the parameters var parameters = new List(); @@ -262,10 +266,11 @@ public PostgresSqlCommandInfo BuildGetCommand(string schema, string tableN var queryColumnList = string.Join(", ", queryColumns); return new PostgresSqlCommandInfo( - commandText: $@" - SELECT {queryColumnList} - FROM {schema}.""{tableName}"" - WHERE ""{keyColumn}"" = ${1};", + commandText: $""" +SELECT {queryColumnList} +FROM {schema}."{tableName}" +WHERE "{keyColumn}" = ${1}; +""", parameters: [new NpgsqlParameter() { Value = key }] ); } @@ -294,10 +299,11 @@ public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string t var keyParams = string.Join(", ", keys.Select((k, i) => $"${i + 1}")); // Generate the SQL command - var commandText = $@" - SELECT {columnNames} - FROM {schema}.""{tableName}"" - WHERE ""{keyColumn}"" = ANY($1);"; + var commandText = $""" +SELECT {columnNames} +FROM {schema}."{tableName}" +WHERE "{keyColumn}" = ANY($1); +"""; return new PostgresSqlCommandInfo(commandText) { @@ -309,9 +315,10 @@ public PostgresSqlCommandInfo BuildGetBatchCommand(string schema, string t public PostgresSqlCommandInfo BuildDeleteCommand(string schema, string tableName, string keyColumn, TKey key) { return new PostgresSqlCommandInfo( - commandText: $@" - DELETE FROM {schema}.""{tableName}"" - WHERE ""{keyColumn}"" = ${1};", + commandText: $""" +DELETE FROM {schema}."{tableName}" +WHERE "{keyColumn}" = ${1}; +""", parameters: [new NpgsqlParameter() { Value = key }] ); } @@ -333,9 +340,10 @@ public PostgresSqlCommandInfo BuildDeleteBatchCommand(string schema, strin } } - var commandText = $@" - DELETE FROM {schema}.""{tableName}"" - WHERE ""{keyColumn}"" = ANY($1);"; + var commandText = $""" +DELETE FROM {schema}."{tableName}" +WHERE "{keyColumn}" = ANY($1); +"""; return new PostgresSqlCommandInfo(commandText) { @@ -343,13 +351,14 @@ DELETE FROM {schema}.""{tableName}"" }; } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete /// - public PostgresSqlCommandInfo BuildGetNearestMatchCommand( - string schema, string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, - VectorSearchFilter? filter, int? skip, bool includeVectors, int limit) + public PostgresSqlCommandInfo BuildGetNearestMatchCommand( + string schema, string tableName, VectorStoreRecordPropertyReader propertyReader, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, + VectorSearchFilter? legacyFilter, Expression>? newFilter, int? skip, bool includeVectors, int limit) { var columns = string.Join(" ,", - properties + propertyReader.RecordDefinition.Properties .Select(property => property.StoragePropertyName ?? property.DataModelPropertyName) .Select(column => $"\"{column}\"") ); @@ -367,14 +376,24 @@ public PostgresSqlCommandInfo BuildGetNearestMatchCommand( }; var vectorColumn = vectorProperty.StoragePropertyName ?? vectorProperty.DataModelPropertyName; + // Start where clause params at 2, vector takes param 1. - var where = GenerateWhereClause(schema, tableName, properties, filter, startParamIndex: 2); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + var (where, parameters) = (oldFilter: legacyFilter, newFilter) switch + { + (not null, not null) => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + (not null, null) => GenerateLegacyFilterWhereClause(schema, tableName, propertyReader.RecordDefinition.Properties, legacyFilter, startParamIndex: 2), + (null, not null) => new PostgresFilterTranslator().Translate(propertyReader.StoragePropertyNamesMap, newFilter, startParamIndex: 2), + _ => (Clause: string.Empty, Parameters: []) + }; +#pragma warning restore CS0618 // VectorSearchFilter is obsolete - var commandText = $@" - SELECT {columns}, ""{vectorColumn}"" {distanceOp} $1 AS ""{PostgresConstants.DistanceColumnName}"" - FROM {schema}.""{tableName}"" {where.Clause} - ORDER BY {PostgresConstants.DistanceColumnName} - LIMIT {limit}"; + var commandText = $""" +SELECT {columns}, "{vectorColumn}" {distanceOp} $1 AS "{PostgresConstants.DistanceColumnName}" +FROM {schema}."{tableName}" {where} +ORDER BY {PostgresConstants.DistanceColumnName} +LIMIT {limit} +"""; if (skip.HasValue) { commandText += $" OFFSET {skip.Value}"; } @@ -383,9 +402,10 @@ ORDER BY {PostgresConstants.DistanceColumnName} // Instead we'll wrap the query in a subquery and modify the distance in the outer query. if (vectorProperty.DistanceFunction == DistanceFunction.CosineSimilarity) { - commandText = $@" - SELECT {columns}, 1 - ""{PostgresConstants.DistanceColumnName}"" AS ""{PostgresConstants.DistanceColumnName}"" - FROM ({commandText}) AS subquery"; + commandText = $""" +SELECT {columns}, 1 - "{PostgresConstants.DistanceColumnName}" AS "{PostgresConstants.DistanceColumnName}" +FROM ({commandText}) AS subquery +"""; } // For inner product, we need to take -1 * inner product. @@ -393,28 +413,27 @@ ORDER BY {PostgresConstants.DistanceColumnName} // Instead we'll wrap the query in a subquery and modify the distance in the outer query. if (vectorProperty.DistanceFunction == DistanceFunction.DotProductSimilarity) { - commandText = $@" - SELECT {columns}, -1 * ""{PostgresConstants.DistanceColumnName}"" AS ""{PostgresConstants.DistanceColumnName}"" - FROM ({commandText}) AS subquery"; + commandText = $""" +SELECT {columns}, -1 * "{PostgresConstants.DistanceColumnName}" AS "{PostgresConstants.DistanceColumnName}" +FROM ({commandText}) AS subquery +"""; } return new PostgresSqlCommandInfo(commandText) { - Parameters = [new NpgsqlParameter() { Value = vectorValue }, .. where.Parameters.Select(p => new NpgsqlParameter() { Value = p })] + Parameters = [new NpgsqlParameter { Value = vectorValue }, .. parameters.Select(p => new NpgsqlParameter { Value = p })] }; } - - internal static (string Clause, List Parameters) GenerateWhereClause(string schema, string tableName, IReadOnlyList properties, VectorSearchFilter? filter, int startParamIndex) +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + internal static (string Clause, List Parameters) GenerateLegacyFilterWhereClause(string schema, string tableName, IReadOnlyList properties, VectorSearchFilter legacyFilter, int startParamIndex) { - if (filter == null) { return (string.Empty, new List()); } - var whereClause = new StringBuilder("WHERE "); var filterClauses = new List(); var parameters = new List(); var paramIndex = startParamIndex; - foreach (var filterClause in filter.FilterClauses) + foreach (var filterClause in legacyFilter.FilterClauses) { if (filterClause is EqualToFilterClause equalTo) { @@ -450,4 +469,5 @@ internal static (string Clause, List Parameters) GenerateWhereClause(str whereClause.Append(string.Join(" AND ", filterClauses)); return (whereClause.ToString(), parameters); } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs index 5ef18cc88fdf..b97b24708b25 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreDbClient.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -172,21 +174,23 @@ public async Task DeleteAsync(string tableName, string keyColumn, TKey key } /// - public async IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync( - string tableName, IReadOnlyList properties, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, - VectorSearchFilter? filter = default, int? skip = default, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + public async IAsyncEnumerable<(Dictionary Row, double Distance)> GetNearestMatchesAsync( + string tableName, VectorStoreRecordPropertyReader propertyReader, VectorStoreRecordVectorProperty vectorProperty, Vector vectorValue, int limit, + VectorSearchFilter? legacyFilter = default, Expression>? newFilter = default, int? skip = default, bool includeVectors = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) +#pragma warning restore CS0618 // VectorSearchFilter is obsolete { NpgsqlConnection connection = await this.DataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); await using (connection) { - var commandInfo = this._sqlBuilder.BuildGetNearestMatchCommand(this._schema, tableName, properties, vectorProperty, vectorValue, filter, skip, includeVectors, limit); + var commandInfo = this._sqlBuilder.BuildGetNearestMatchCommand(this._schema, tableName, propertyReader, vectorProperty, vectorValue, legacyFilter, newFilter, skip, includeVectors, limit); using NpgsqlCommand cmd = commandInfo.ToNpgsqlCommand(connection); using NpgsqlDataReader dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) { var distance = dataReader.GetDouble(dataReader.GetOrdinal(PostgresConstants.DistanceColumnName)); - yield return (Row: this.GetRecord(dataReader, properties, includeVectors), Distance: distance); + yield return (Row: this.GetRecord(dataReader, propertyReader.RecordDefinition.Properties, includeVectors), Distance: distance); } } } diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs index de4a432ea48c..fd85896a46d4 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordCollection.cs @@ -37,7 +37,7 @@ public sealed class PostgresVectorStoreRecordCollection : IVector private readonly IVectorStoreRecordMapper> _mapper; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// /// Initializes a new instance of the class. @@ -250,7 +250,7 @@ public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancellat } /// - public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "VectorizedSearch"; @@ -261,7 +261,7 @@ public Task> VectorizedSearchAsync(TVector if (!PostgresConstants.SupportedVectorTypes.Contains(vectorType)) { throw new NotSupportedException( - $"The provided vector type {vectorType.FullName} is not supported by the SQLite connector. " + + $"The provided vector type {vectorType.FullName} is not supported by the PostgreSQL connector. " + $"Supported types are: {string.Join(", ", PostgresConstants.SupportedVectorTypes.Select(l => l.FullName))}"); } @@ -285,11 +285,14 @@ public Task> VectorizedSearchAsync(TVector { var results = this._client.GetNearestMatchesAsync( this.CollectionName, - this._propertyReader.RecordDefinition.Properties, + this._propertyReader, vectorProperty, pgVector, searchOptions.Top, +#pragma warning disable CS0618 // VectorSearchFilter is obsolete searchOptions.Filter, +#pragma warning restore CS0618 // VectorSearchFilter is obsolete + searchOptions.NewFilter, searchOptions.Skip, searchOptions.IncludeVectors, cancellationToken) diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs index 0b36f2003bf5..5e8509236e31 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/PostgresVectorStoreRecordPropertyMapping.cs @@ -143,7 +143,7 @@ public static (string PgType, bool IsNullable) GetPostgresTypeName(Type property // Handle enumerables if (VectorStoreRecordPropertyVerification.IsSupportedEnumerableType(propertyType)) { - Type elementType = propertyType.GetGenericArguments()[0]; + Type elementType = propertyType.IsArray ? propertyType.GetElementType()! : propertyType.GetGenericArguments()[0]; var underlyingPgType = GetPostgresTypeName(elementType); return (underlyingPgType.PgType + "[]", true); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantFilterTranslator.cs new file mode 100644 index 000000000000..a918883aa054 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantFilterTranslator.cs @@ -0,0 +1,382 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using Google.Protobuf.Collections; +using Qdrant.Client.Grpc; +using Range = Qdrant.Client.Grpc.Range; + +namespace Microsoft.SemanticKernel.Connectors.Qdrant; + +internal class QdrantFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + internal Filter Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + return this.Translate(lambdaExpression.Body); + } + + private Filter Translate(Expression? node) + => node switch + { + BinaryExpression { NodeType: ExpressionType.Equal } equal => this.TranslateEqual(equal.Left, equal.Right), + BinaryExpression { NodeType: ExpressionType.NotEqual } notEqual => this.TranslateEqual(notEqual.Left, notEqual.Right, negated: true), + + BinaryExpression + { + NodeType: ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual or ExpressionType.LessThan or ExpressionType.LessThanOrEqual + } comparison + => this.TranslateComparison(comparison), + + BinaryExpression { NodeType: ExpressionType.AndAlso } andAlso => this.TranslateAndAlso(andAlso.Left, andAlso.Right), + BinaryExpression { NodeType: ExpressionType.OrElse } orElse => this.TranslateOrElse(orElse.Left, orElse.Right), + UnaryExpression { NodeType: ExpressionType.Not } not => this.TranslateNot(not.Operand), + + // MemberExpression is generally handled within e.g. TranslateEqual; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) + MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _) + => this.TranslateEqual(member, Expression.Constant(true)), + + MethodCallExpression methodCall => this.TranslateMethodCall(methodCall), + + _ => throw new NotSupportedException("Qdrant does not support the following NodeType in filters: " + node?.NodeType) + }; + + private Filter TranslateEqual(Expression left, Expression right, bool negated = false) + { + return TryProcessEqual(left, right, out var result) + ? result + : TryProcessEqual(right, left, out result) + ? result + : throw new NotSupportedException("Equality expression not supported by Qdrant"); + + bool TryProcessEqual(Expression first, Expression second, [NotNullWhen(true)] out Filter? result) + { + // TODO: Nullable + if (this.TryTranslateFieldAccess(first, out var storagePropertyName) + && TryGetConstant(second, out var constantValue)) + { + var condition = constantValue is null + ? new Condition { IsNull = new() { Key = storagePropertyName } } + : new Condition + { + Field = new FieldCondition + { + Key = storagePropertyName, + Match = constantValue switch + { + string stringValue => new Match { Keyword = stringValue }, + int intValue => new Match { Integer = intValue }, + long longValue => new Match { Integer = longValue }, + bool boolValue => new Match { Boolean = boolValue }, + + _ => throw new InvalidOperationException($"Unsupported filter value type '{constantValue.GetType().Name}'.") + } + } + }; + + result = new Filter(); + if (negated) + { + result.MustNot.Add(condition); + } + else + { + result.Must.Add(condition); + } + return true; + } + + result = null; + return false; + } + } + + private Filter TranslateComparison(BinaryExpression comparison) + { + return TryProcessComparison(comparison.Left, comparison.Right, out var result) + ? result + : TryProcessComparison(comparison.Right, comparison.Left, out result) + ? result + : throw new NotSupportedException("Comparison expression not supported by Qdrant"); + + bool TryProcessComparison(Expression first, Expression second, [NotNullWhen(true)] out Filter? result) + { + // TODO: Nullable + if (this.TryTranslateFieldAccess(first, out var storagePropertyName) + && TryGetConstant(second, out var constantValue)) + { + double doubleConstantValue = constantValue switch + { + double d => d, + int i => i, + long l => l, + _ => throw new NotSupportedException($"Can't perform comparison on type '{constantValue?.GetType().Name}', which isn't convertible to double") + }; + + result = new Filter(); + result.Must.Add(new Condition + { + Field = new FieldCondition + { + Key = storagePropertyName, + Range = comparison.NodeType switch + { + ExpressionType.GreaterThan => new Range { Gt = doubleConstantValue }, + ExpressionType.GreaterThanOrEqual => new Range { Gte = doubleConstantValue }, + ExpressionType.LessThan => new Range { Lt = doubleConstantValue }, + ExpressionType.LessThanOrEqual => new Range { Lte = doubleConstantValue }, + + _ => throw new InvalidOperationException("Unreachable") + } + } + }); + return true; + } + + result = null; + return false; + } + } + + #region Logical operators + + private Filter TranslateAndAlso(Expression left, Expression right) + { + var leftFilter = this.Translate(left); + var rightFilter = this.Translate(right); + + // As long as there are only AND conditions (Must or MustNot), we can simply combine both filters into a single flat one. + // The moment there's a Should, things become a bit more complicated: + // 1. If a side contains both a Should and a Must/MustNot, it must be pushed down. + // 2. Otherwise, if the left's Should is empty, and the right side is only Should, we can just copy the right Should into the left's. + // 3. Finally, if both sides have a Should, we push down the right side and put the result in the left's Must. + if (leftFilter.Should.Count > 0 && (leftFilter.Must.Count > 0 || leftFilter.MustNot.Count > 0)) + { + leftFilter = new Filter { Must = { new Condition { Filter = leftFilter } } }; + } + + if (rightFilter.Should.Count > 0 && (rightFilter.Must.Count > 0 || rightFilter.MustNot.Count > 0)) + { + rightFilter = new Filter { Must = { new Condition { Filter = rightFilter } } }; + } + + if (rightFilter.Should.Count > 0) + { + if (leftFilter.Should.Count == 0) + { + leftFilter.Should.AddRange(rightFilter.Should); + } + else + { + rightFilter = new Filter { Must = { new Condition { Filter = rightFilter } } }; + } + } + + leftFilter.Must.AddRange(rightFilter.Must); + leftFilter.MustNot.AddRange(rightFilter.MustNot); + + return leftFilter; + } + + private Filter TranslateOrElse(Expression left, Expression right) + { + var leftFilter = this.Translate(left); + var rightFilter = this.Translate(right); + + var result = new Filter(); + result.Should.AddRange(GetShouldConditions(leftFilter)); + result.Should.AddRange(GetShouldConditions(rightFilter)); + return result; + + static RepeatedField GetShouldConditions(Filter filter) + => filter switch + { + { Must.Count: 0, MustNot.Count: 0 } => filter.Should, + { Must.Count: 1, MustNot.Count: 0, Should.Count: 0 } => [filter.Must[0]], + { Must.Count: 0, MustNot.Count: 1, Should.Count: 0 } => [filter.MustNot[0]], + + _ => [new Condition { Filter = filter }] + }; + } + + private Filter TranslateNot(Expression expression) + { + // Special handling for !(a == b) and !(a != b) + if (expression is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + return this.TranslateEqual(binary.Left, binary.Right, negated: binary.NodeType is ExpressionType.Equal); + } + + var filter = this.Translate(expression); + + switch (filter) + { + case { Must.Count: 1, MustNot.Count: 0, Should.Count: 0 }: + filter.MustNot.Add(filter.Must[0]); + filter.Must.RemoveAt(0); + return filter; + + case { Must.Count: 0, MustNot.Count: 1, Should.Count: 0 }: + filter.Must.Add(filter.MustNot[0]); + filter.MustNot.RemoveAt(0); + return filter; + + case { Must.Count: 0, MustNot.Count: 0, Should.Count: > 0 }: + filter.MustNot.AddRange(filter.Should); + filter.Should.Clear(); + return filter; + + default: + return new Filter { MustNot = { new Condition { Filter = filter } } }; + } + } + + #endregion Logical operators + + private Filter TranslateMethodCall(MethodCallExpression methodCall) + => methodCall switch + { + // Enumerable.Contains() + { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable) + => this.TranslateContains(source, item), + + // List.Contains() + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>) + => this.TranslateContains(source, item), + + _ => throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}") + }; + + private Filter TranslateContains(Expression source, Expression item) + { + switch (source) + { + // Contains over field enumerable + case var _ when this.TryTranslateFieldAccess(source, out _): + // Oddly, in Qdrant, tag list contains is handled using a Match condition, just like equality. + return this.TranslateEqual(source, item); + + // Contains over inline enumerable + case NewArrayExpression newArray: + var elements = new object?[newArray.Expressions.Count]; + + for (var i = 0; i < newArray.Expressions.Count; i++) + { + if (!TryGetConstant(newArray.Expressions[i], out var elementValue)) + { + throw new NotSupportedException("Invalid element in array"); + } + + elements[i] = elementValue; + } + + return ProcessInlineEnumerable(elements, item); + + // Contains over captured enumerable (we inline) + case var _ when TryGetConstant(source, out var constantEnumerable) + && constantEnumerable is IEnumerable enumerable and not string: + return ProcessInlineEnumerable(enumerable, item); + + default: + throw new NotSupportedException("Unsupported Contains"); + } + + Filter ProcessInlineEnumerable(IEnumerable elements, Expression item) + { + if (!this.TryTranslateFieldAccess(item, out var storagePropertyName)) + { + throw new NotSupportedException("Unsupported item type in Contains"); + } + + if (item.Type == typeof(string)) + { + var strings = new RepeatedStrings(); + + foreach (var value in elements) + { + strings.Strings.Add(value is string or null + ? (string?)value + : throw new ArgumentException("Non-string element in string Contains array")); + } + + return new Filter { Must = { new Condition { Field = new FieldCondition { Key = storagePropertyName, Match = new Match { Keywords = strings } } } } }; + } + + if (item.Type == typeof(int)) + { + var ints = new RepeatedIntegers(); + + foreach (var value in elements) + { + ints.Integers.Add(value is int intValue + ? intValue + : throw new ArgumentException("Non-int element in string Contains array")); + } + + return new Filter { Must = { new Condition { Field = new FieldCondition { Key = storagePropertyName, Match = new Match { Integers = ints } } } } }; + } + + throw new NotSupportedException("Contains only supported over array of ints or strings"); + } + } + + private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + { + if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) + { + throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + storagePropertyName = null; + return false; + } + + private static bool TryGetConstant(Expression expression, out object? constantValue) + { + switch (expression) + { + case ConstantExpression { Value: var v }: + constantValue = v; + return true; + + // This identifies compiler-generated closure types which contain captured variables. + case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): + constantValue = fieldInfo.GetValue(constant.Value); + return true; + + default: + constantValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionSearchMapping.cs index f2b9c91179e9..ec14ef585dfb 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreCollectionSearchMapping.cs @@ -12,6 +12,7 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// internal static class QdrantVectorStoreCollectionSearchMapping { +#pragma warning disable CS0618 // Type or member is obsolete /// /// Build a Qdrant from the provided . /// @@ -19,16 +20,10 @@ internal static class QdrantVectorStoreCollectionSearchMapping /// A mapping of data model property names to the names under which they are stored. /// The Qdrant . /// Thrown when the provided filter contains unsupported types, values or unknown properties. - public static Filter BuildFilter(VectorSearchFilter? basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) + public static Filter BuildFromLegacyFilter(VectorSearchFilter basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) { var filter = new Filter(); - // Return an empty filter if no filter clauses are provided. - if (basicVectorSearchFilter?.FilterClauses is null) - { - return filter; - } - foreach (var filterClause in basicVectorSearchFilter.FilterClauses) { string fieldName; @@ -72,6 +67,7 @@ public static Filter BuildFilter(VectorSearchFilter? basicVectorSearchFilter, IR return filter; } +#pragma warning restore CS0618 // Type or member is obsolete /// /// Map the given to a . diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs index 7dd77b76baff..e51ae549818a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantVectorStoreRecordCollection.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -29,7 +30,7 @@ public sealed class QdrantVectorStoreRecordCollection : IVectorStoreRec ]; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// The name of this database for telemetry purposes. private const string DatabaseName = "Qdrant"; @@ -457,7 +458,7 @@ private async IAsyncEnumerable GetBatchByPointIdAsync( } /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); @@ -473,8 +474,16 @@ public async Task> VectorizedSearchAsync(T var internalOptions = options ?? s_defaultVectorSearchOptions; +#pragma warning disable CS0618 // Type or member is obsolete // Build filter object. - var filter = QdrantVectorStoreCollectionSearchMapping.BuildFilter(internalOptions.Filter, this._propertyReader.StoragePropertyNamesMap); + var filter = internalOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(legacyFilter, this._propertyReader.StoragePropertyNamesMap), + { NewFilter: Expression> newFilter } => new QdrantFilterTranslator().Translate(newFilter, this._propertyReader.StoragePropertyNamesMap), + _ => new Filter() + }; +#pragma warning restore CS0618 // Type or member is obsolete // Specify the vector name if named vectors are used. string? vectorName = null; diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs new file mode 100644 index 000000000000..ec5bcd73514f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisFilterTranslator.cs @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors.Redis; + +internal class RedisFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + private readonly StringBuilder _filter = new(); + + internal string Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + Debug.Assert(this._filter.Length == 0); + + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this.Translate(lambdaExpression.Body); + return this._filter.ToString(); + } + + private void Translate(Expression? node) + { + switch (node) + { + case BinaryExpression + { + NodeType: ExpressionType.Equal or ExpressionType.NotEqual + or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual + or ExpressionType.LessThan or ExpressionType.LessThanOrEqual + } binary: + this.TranslateEqualityComparison(binary); + return; + + case BinaryExpression { NodeType: ExpressionType.AndAlso } andAlso: + // https://redis.io/docs/latest/develop/interact/search-and-query/query/combined/#and + this._filter.Append('('); + this.Translate(andAlso.Left); + this._filter.Append(' '); + this.Translate(andAlso.Right); + this._filter.Append(')'); + return; + + case BinaryExpression { NodeType: ExpressionType.OrElse } orElse: + // https://redis.io/docs/latest/develop/interact/search-and-query/query/combined/#or + this._filter.Append('('); + this.Translate(orElse.Left); + this._filter.Append(" | "); + this.Translate(orElse.Right); + this._filter.Append(')'); + return; + + case UnaryExpression { NodeType: ExpressionType.Not } not: + this.TranslateNot(not.Operand); + return; + + // MemberExpression is generally handled within e.g. TranslateEqual; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) + case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): + { + this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))); + return; + } + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + default: + throw new NotSupportedException("Redis does not support the following NodeType in filters: " + node?.NodeType); + } + } + + private void TranslateEqualityComparison(BinaryExpression binary) + { + if (!TryProcessEqualityComparison(binary.Left, binary.Right) && !TryProcessEqualityComparison(binary.Right, binary.Left)) + { + throw new NotSupportedException("Binary expression not supported by Redis"); + } + + bool TryProcessEqualityComparison(Expression first, Expression second) + { + // TODO: Nullable + if (this.TryTranslateFieldAccess(first, out var storagePropertyName) + && TryGetConstant(second, out var constantValue)) + { + // Numeric negation has a special syntax (!=), for the rest we nest in a NOT + if (binary.NodeType is ExpressionType.NotEqual && constantValue is not int or long or float or double) + { + this.TranslateNot(Expression.Equal(first, second)); + return true; + } + + // https://redis.io/docs/latest/develop/interact/search-and-query/query/exact-match + this._filter.Append('@').Append(storagePropertyName); + + this._filter.Append( + binary.NodeType switch + { + ExpressionType.Equal when constantValue is int or long or float or double => $" == {constantValue}", + ExpressionType.Equal when constantValue is string stringValue +#if NETSTANDARD2_0 + => $$""":{"{{stringValue.Replace("\"", "\"\"")}}"}""", +#else + => $$""":{"{{stringValue.Replace("\"", "\\\"", StringComparison.Ordinal)}}"}""", +#endif + ExpressionType.Equal when constantValue is null => throw new NotSupportedException("Null value type not supported"), // TODO + + ExpressionType.NotEqual when constantValue is int or long or float or double => $" != {constantValue}", + ExpressionType.NotEqual => throw new InvalidOperationException("Unreachable"), // Handled above + + ExpressionType.GreaterThan => $" > {constantValue}", + ExpressionType.GreaterThanOrEqual => $" >= {constantValue}", + ExpressionType.LessThan => $" < {constantValue}", + ExpressionType.LessThanOrEqual => $" <= {constantValue}", + + _ => throw new InvalidOperationException("Unsupported equality/comparison") + }); + + return true; + } + + return false; + } + } + + private void TranslateNot(Expression expression) + { + // https://redis.io/docs/latest/develop/interact/search-and-query/query/combined/#not + this._filter.Append("(-"); + this.Translate(expression); + this._filter.Append(')'); + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + // Contains over tag field + if (this.TryTranslateFieldAccess(source, out var storagePropertyName) + && TryGetConstant(item, out var itemConstant) + && itemConstant is string stringConstant) + { + this._filter + .Append('@') + .Append(storagePropertyName) + .Append(":{") + .Append(stringConstant) + .Append('}'); + return; + } + + throw new NotSupportedException("Contains supported only over tag field"); + } + + private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + { + if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) + { + throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + storagePropertyName = null; + return false; + } + + private static bool TryGetConstant(Expression expression, out object? constantValue) + { + switch (expression) + { + case ConstantExpression { Value: var v }: + constantValue = v; + return true; + + // This identifies compiler-generated closure types which contain captured variables. + case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): + constantValue = fieldInfo.GetValue(constant.Value); + return true; + + default: + constantValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs index 41971c5adb86..2a5d324e0171 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetVectorStoreRecordCollection.cs @@ -61,7 +61,7 @@ public sealed class RedisHashSetVectorStoreRecordCollection : IVectorSt ]; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// The Redis database to read/write records from. private readonly IDatabase _database; @@ -300,6 +300,7 @@ public async Task UpsertAsync(TRecord record, CancellationToken cancella // Upsert. var maybePrefixedKey = this.PrefixKeyIfNeeded(redisHashSetRecord.Key); + await this.RunOperationAsync( "HSET", () => this._database @@ -328,7 +329,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable reco } /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs index f8afa3ed875e..0d5f74d0821a 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonVectorStoreRecordCollection.cs @@ -44,7 +44,7 @@ public sealed class RedisJsonVectorStoreRecordCollection : IVectorStore ]; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// The Redis database to read/write records from. private readonly IDatabase _database; @@ -374,7 +374,7 @@ await this.RunOperationAsync( } /// - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(vector); diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs index d6603ca1634c..ea78a9e798c0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisVectorStoreCollectionSearchMapping.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Runtime.InteropServices; using Microsoft.Extensions.VectorData; using NRedisStack.Search; @@ -50,14 +51,24 @@ public static byte[] ValidateVectorAndConvertToBytes(TVector vector, st /// The name of the first vector property in the data model. /// The set of fields to limit the results to. Null for all. /// The . - public static Query BuildQuery(byte[] vectorBytes, VectorSearchOptions options, IReadOnlyDictionary storagePropertyNames, string firstVectorPropertyName, string[]? selectFields) + public static Query BuildQuery(byte[] vectorBytes, VectorSearchOptions options, IReadOnlyDictionary storagePropertyNames, string firstVectorPropertyName, string[]? selectFields) { // Resolve options. var vectorPropertyName = ResolveVectorFieldName(options.VectorPropertyName, storagePropertyNames, firstVectorPropertyName); // Build search query. var redisLimit = options.Top + options.Skip; - var filter = RedisVectorStoreCollectionSearchMapping.BuildFilter(options.Filter, storagePropertyNames); + +#pragma warning disable CS0618 // Type or member is obsolete + var filter = options switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => BuildLegacyFilter(legacyFilter, storagePropertyNames), + { NewFilter: Expression> newFilter } => new RedisFilterTranslator().Translate(newFilter, storagePropertyNames), + _ => "*" + }; +#pragma warning restore CS0618 // Type or member is obsolete + var query = new Query($"{filter}=>[KNN {redisLimit} @{vectorPropertyName} $embedding AS vector_score]") .AddParam("embedding", vectorBytes) .SetSortBy("vector_score") @@ -80,13 +91,9 @@ public static Query BuildQuery(byte[] vectorBytes, VectorSearchOptions options, /// A mapping of data model property names to the names under which they are stored. /// The Redis filter string. /// Thrown when a provided filter value is not supported. - public static string BuildFilter(VectorSearchFilter? basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) +#pragma warning disable CS0618 // Type or member is obsolete + public static string BuildLegacyFilter(VectorSearchFilter basicVectorSearchFilter, IReadOnlyDictionary storagePropertyNames) { - if (basicVectorSearchFilter == null) - { - return "*"; - } - var filterClauses = basicVectorSearchFilter.FilterClauses.Select(clause => { if (clause is EqualToFilterClause equalityFilterClause) @@ -116,6 +123,7 @@ public static string BuildFilter(VectorSearchFilter? basicVectorSearchFilter, IR return $"({string.Join(" ", filterClauses)})"; } +#pragma warning restore CS0618 // Type or member is obsolete /// /// Resolve the distance function to use for a search by checking the distance function of the vector property specified in options @@ -126,7 +134,7 @@ public static string BuildFilter(VectorSearchFilter? basicVectorSearchFilter, IR /// The first vector property in the record. /// The distance function for the vector we want to search. /// Thrown when a user asked for a vector property that doesn't exist on the record. - public static string ResolveDistanceFunction(VectorSearchOptions options, IReadOnlyList vectorProperties, VectorStoreRecordVectorProperty firstVectorProperty) + public static string ResolveDistanceFunction(VectorSearchOptions options, IReadOnlyList vectorProperties, VectorStoreRecordVectorProperty firstVectorProperty) { if (options.VectorPropertyName == null || vectorProperties.Count == 1) { diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs new file mode 100644 index 000000000000..2cb6b16fc8cd --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteFilterTranslator.cs @@ -0,0 +1,359 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; + +namespace Microsoft.SemanticKernel.Connectors.Sqlite; + +internal class SqliteFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + + private readonly Dictionary _parameters = new(); + + private readonly StringBuilder _sql = new(); + + internal (string Clause, Dictionary) Translate(IReadOnlyDictionary storagePropertyNames, LambdaExpression lambdaExpression) + { + Debug.Assert(this._sql.Length == 0); + + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this.Translate(lambdaExpression.Body); + return (this._sql.ToString(), this._parameters); + } + + private void Translate(Expression? node) + { + switch (node) + { + case BinaryExpression binary: + this.TranslateBinary(binary); + return; + + case ConstantExpression constant: + this.TranslateConstant(constant); + return; + + case MemberExpression member: + this.TranslateMember(member); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + case UnaryExpression unary: + this.TranslateUnary(unary); + return; + + default: + throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); + } + } + + private void TranslateBinary(BinaryExpression binary) + { + // Special handling for null comparisons + switch (binary.NodeType) + { + case ExpressionType.Equal when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NOT NULL)"); + return; + + case ExpressionType.Equal when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NOT NULL)"); + return; + } + + this._sql.Append('('); + this.Translate(binary.Left); + + this._sql.Append(binary.NodeType switch + { + ExpressionType.Equal => " = ", + ExpressionType.NotEqual => " <> ", + + ExpressionType.GreaterThan => " > ", + ExpressionType.GreaterThanOrEqual => " >= ", + ExpressionType.LessThan => " < ", + ExpressionType.LessThanOrEqual => " <= ", + + ExpressionType.AndAlso => " AND ", + ExpressionType.OrElse => " OR ", + + _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) + }); + + this.Translate(binary.Right); + this._sql.Append(')'); + + static bool IsNull(Expression expression) + => expression is ConstantExpression { Value: null } + || (TryGetCapturedValue(expression, out _, out var capturedValue) && capturedValue is null); + } + + private void TranslateConstant(ConstantExpression constant) + => this.GenerateLiteral(constant.Value); + + private void GenerateLiteral(object? value) + { + // TODO: Nullable + switch (value) + { + case byte b: + this._sql.Append(b); + return; + case short s: + this._sql.Append(s); + return; + case int i: + this._sql.Append(i); + return; + case long l: + this._sql.Append(l); + return; + + case string s: + this._sql.Append('\'').Append(s.Replace("'", "''")).Append('\''); + return; + case bool b: + this._sql.Append(b ? "TRUE" : "FALSE"); + return; + case Guid g: + this._sql.Append('\'').Append(g.ToString()).Append('\''); + return; + + case DateTime: + case DateTimeOffset: + throw new NotImplementedException(); + + case Array: + throw new NotImplementedException(); + + case null: + this._sql.Append("NULL"); + return; + + default: + throw new NotSupportedException("Unsupported constant type: " + value.GetType().Name); + } + } + + private void TranslateMember(MemberExpression memberExpression) + { + switch (memberExpression) + { + case var _ when this.TryGetColumn(memberExpression, out var column): + this._sql.Append('"').Append(column).Append('"'); + return; + + // Identify captured lambda variables, translate to PostgreSQL parameters ($1, $2...) + case var _ when TryGetCapturedValue(memberExpression, out var name, out var value): + // For null values, simply inline rather than parameterize; parameterized NULLs require setting NpgsqlDbType which is a bit more complicated, + // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) + if (value is null) + { + this._sql.Append("NULL"); + } + else + { + // Duplicate parameter name, create a new parameter with a different name + // TODO: Share the same parameter when it references the same captured value + if (this._parameters.ContainsKey(name)) + { + var baseName = name; + var i = 0; + do + { + name = baseName + (i++); + } while (this._parameters.ContainsKey(name)); + } + + this._parameters.Add(name, value); + this._sql.Append('@').Append(name); + } + return; + + default: + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + } + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + switch (source) + { + // TODO: support Contains over array fields (#10343) + // Contains over array column (r => r.Strings.Contains("foo")) + case var _ when this.TryGetColumn(source, out _): + goto default; + + // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) + case NewArrayExpression newArray: + { + this.Translate(item); + this._sql.Append(" IN ("); + + var isFirst = true; + foreach (var element in newArray.Expressions) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._sql.Append(", "); + } + + this.Translate(element); + } + + this._sql.Append(')'); + return; + } + + // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) + case var _ when TryGetCapturedValue(source, out _, out var value) && value is IEnumerable elements: + { + this.Translate(item); + this._sql.Append(" IN ("); + + var isFirst = true; + foreach (var element in elements) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._sql.Append(", "); + } + + this.GenerateLiteral(element); + } + + this._sql.Append(')'); + return; + } + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + } + + private void TranslateUnary(UnaryExpression unary) + { + switch (unary.NodeType) + { + case ExpressionType.Not: + // Special handling for !(a == b) and !(a != b) + if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + this.TranslateBinary( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + } + + this._sql.Append("(NOT "); + this.Translate(unary.Operand); + this._sql.Append(')'); + return; + + default: + throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + } + } + + private bool TryGetColumn(Expression expression, [NotNullWhen(true)] out string? column) + { + if (expression is MemberExpression member && member.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(member.Member.Name, out column)) + { + throw new InvalidOperationException($"Property name '{member.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + column = null; + return false; + } + + private static bool TryGetCapturedValue(Expression expression, [NotNullWhen(true)] out string? name, out object? value) + { + if (expression is MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + && constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true)) + { + name = fieldInfo.Name; + value = fieldInfo.GetValue(constant.Value); + return true; + } + + name = null; + value = null; + return false; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs index 028a838487d1..837e3044ddc7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreCollectionCommandBuilder.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Data.Common; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text; @@ -159,6 +160,8 @@ public DbCommand BuildSelectLeftJoinCommand( IReadOnlyList leftTablePropertyNames, IReadOnlyList rightTablePropertyNames, List conditions, + string? extraWhereFilter = null, + Dictionary? extraParameters = null, string? orderByPropertyName = null) { var builder = new StringBuilder(); @@ -169,7 +172,7 @@ .. leftTablePropertyNames.Select(property => $"{leftTable}.{property}"), .. rightTablePropertyNames.Select(property => $"{rightTable}.{property}"), ]; - var (command, whereClause) = this.GetCommandWithWhereClause(conditions); + var (command, whereClause) = this.GetCommandWithWhereClause(conditions, extraWhereFilter, extraParameters); builder.AppendLine($"SELECT {string.Join(", ", propertyNames)}"); builder.AppendLine($"FROM {leftTable} "); @@ -238,7 +241,10 @@ private static string GetColumnDefinition(SqliteColumn column) return string.Join(" ", columnDefinitionParts); } - private (DbCommand Command, string WhereClause) GetCommandWithWhereClause(List conditions) + private (DbCommand Command, string WhereClause) GetCommandWithWhereClause( + List conditions, + string? extraWhereFilter = null, + Dictionary? extraParameters = null) { const string WhereClauseOperator = " AND "; @@ -263,6 +269,22 @@ private static string GetColumnDefinition(SqliteColumn column) var whereClause = string.Join(WhereClauseOperator, whereClauseParts); + if (extraWhereFilter is not null) + { + if (conditions.Count > 0) + { + whereClause += " AND "; + } + + whereClause += extraWhereFilter; + + Debug.Assert(extraParameters is not null, "extraParameters must be provided when extraWhereFilter is provided."); + foreach (var p in extraParameters) + { + command.Parameters.Add(new SqliteParameter(p.Key, p.Value)); + } + } + return (command, whereClause); } diff --git a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs index 08c976abf43f..8ae095dd3bf0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Sqlite/SqliteVectorStoreRecordCollection.cs @@ -34,7 +34,7 @@ public sealed class SqliteVectorStoreRecordCollection : private readonly IVectorStoreRecordMapper> _mapper; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// Command builder for queries in SQLite database. private readonly SqliteVectorStoreCollectionCommandBuilder _commandBuilder; @@ -154,7 +154,7 @@ public async Task DeleteCollectionAsync(CancellationToken cancellationToken = de } /// - public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { const string LimitPropertyName = "k"; @@ -189,15 +189,35 @@ public Task> VectorizedSearchAsync(TVector new SqliteWhereEqualsCondition(LimitPropertyName, limit) }; - var filterConditions = this.GetFilterConditions(searchOptions.Filter, this._dataTableName); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + string? extraWhereFilter = null; + Dictionary? extraParameters = null; - if (filterConditions is { Count: > 0 }) + if (searchOptions.Filter is not null) { - conditions.AddRange(filterConditions); + if (searchOptions.NewFilter is not null) + { + throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"); + } + + // Old filter, we translate it to a list of SqliteWhereCondition, and merge these into the conditions we already have + var filterConditions = this.GetFilterConditions(searchOptions.Filter, this._dataTableName); + + if (filterConditions is { Count: > 0 }) + { + conditions.AddRange(filterConditions); + } + } + else if (searchOptions.NewFilter is not null) + { + (extraWhereFilter, extraParameters) = new SqliteFilterTranslator().Translate(this._propertyReader.StoragePropertyNamesMap, searchOptions.NewFilter); } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete var vectorSearchResults = new VectorSearchResults(this.EnumerateAndMapSearchResultsAsync( conditions, + extraWhereFilter, + extraParameters, searchOptions, cancellationToken)); @@ -288,7 +308,9 @@ public Task DeleteBatchAsync(IEnumerable keys, CancellationToken cancell private async IAsyncEnumerable> EnumerateAndMapSearchResultsAsync( List conditions, - VectorSearchOptions searchOptions, + string? extraWhereFilter, + Dictionary? extraParameters, + VectorSearchOptions searchOptions, [EnumeratorCancellation] CancellationToken cancellationToken) { const string OperationName = "VectorizedSearch"; @@ -311,6 +333,8 @@ private async IAsyncEnumerable> EnumerateAndMapSearc leftTableProperties, this._dataTableStoragePropertyNames.Value, conditions, + extraWhereFilter, + extraParameters, DistancePropertyName); using var reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); @@ -670,6 +694,7 @@ private async Task RunOperationAsync(string operationName, Func> o return new SqliteVectorStoreRecordMapper(this._propertyReader); } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete private List? GetFilterConditions(VectorSearchFilter? filter, string? tableName = null) { var filterClauses = filter?.FilterClauses.ToList(); @@ -706,6 +731,7 @@ private async Task RunOperationAsync(string operationName, Func> o return conditions; } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete /// /// Gets vector table name. diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs new file mode 100644 index 000000000000..2e4be5391159 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateFilterTranslator.cs @@ -0,0 +1,260 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; + +namespace Microsoft.SemanticKernel.Connectors.Weaviate; + +// https://weaviate.io/developers/weaviate/api/graphql/filters#filter-structure +internal class WeaviateFilterTranslator +{ + private IReadOnlyDictionary _storagePropertyNames = null!; + private ParameterExpression _recordParameter = null!; + private readonly StringBuilder _filter = new(); + + internal string Translate(LambdaExpression lambdaExpression, IReadOnlyDictionary storagePropertyNames) + { + Debug.Assert(this._filter.Length == 0); + + this._storagePropertyNames = storagePropertyNames; + + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._recordParameter = lambdaExpression.Parameters[0]; + + this.Translate(lambdaExpression.Body); + return this._filter.ToString(); + } + + private void Translate(Expression? node) + { + switch (node) + { + case BinaryExpression + { + NodeType: ExpressionType.Equal or ExpressionType.NotEqual + or ExpressionType.GreaterThan or ExpressionType.GreaterThanOrEqual + or ExpressionType.LessThan or ExpressionType.LessThanOrEqual + } binary: + this.TranslateEqualityComparison(binary); + return; + + case BinaryExpression { NodeType: ExpressionType.AndAlso } andAlso: + this._filter.Append("{ operator: And, operands: ["); + this.Translate(andAlso.Left); + this._filter.Append(", "); + this.Translate(andAlso.Right); + this._filter.Append("] }"); + return; + + case BinaryExpression { NodeType: ExpressionType.OrElse } orElse: + this._filter.Append("{ operator: Or, operands: ["); + this.Translate(orElse.Left); + this._filter.Append(", "); + this.Translate(orElse.Right); + this._filter.Append("] }"); + return; + + case UnaryExpression { NodeType: ExpressionType.Not } not: + { + switch (not.Operand) + { + // Special handling for !(a == b) and !(a != b) + case BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary: + this.TranslateEqualityComparison( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + + // Not over bool field (Filter => r => !r.Bool) + case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): + this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(false))); + return; + + default: + throw new NotSupportedException("Weaviate does not support the NOT operator (see https://github.com/weaviate/weaviate/issues/3683)"); + } + } + + // MemberExpression is generally handled within e.g. TranslateEqual; this is used to translate direct bool inside filter (e.g. Filter => r => r.Bool) + case MemberExpression member when member.Type == typeof(bool) && this.TryTranslateFieldAccess(member, out _): + this.TranslateEqualityComparison(Expression.Equal(member, Expression.Constant(true))); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall); + return; + + default: + throw new NotSupportedException("The following NodeType is unsupported: " + node?.NodeType); + } + } + + private void TranslateEqualityComparison(BinaryExpression binary) + { + if ((this.TryTranslateFieldAccess(binary.Left, out var storagePropertyName) && TryGetConstant(binary.Right, out var value)) + || (this.TryTranslateFieldAccess(binary.Right, out storagePropertyName) && TryGetConstant(binary.Left, out value))) + { + // { path: ["intPropName"], operator: Equal, ValueInt: 8 } + this._filter + .Append("{ path: [\"") + .Append(JsonEncodedText.Encode(storagePropertyName)) + .Append("\"], operator: "); + + // Special handling for null comparisons + if (value is null) + { + if (binary.NodeType is ExpressionType.Equal or ExpressionType.NotEqual) + { + this._filter + .Append("IsNull, valueBoolean: ") + .Append(binary.NodeType is ExpressionType.Equal ? "true" : "false") + .Append(" }"); + return; + } + + throw new NotSupportedException("null value supported only with equality/inequality checks"); + } + + // Operator + this._filter.Append(binary.NodeType switch + { + ExpressionType.Equal => "Equal", + ExpressionType.NotEqual => "NotEqual", + + ExpressionType.GreaterThan => "GreaterThan", + ExpressionType.GreaterThanOrEqual => "GreaterThanEqual", + ExpressionType.LessThan => "LessThan", + ExpressionType.LessThanOrEqual => "LessThanEqual", + + _ => throw new UnreachableException() + }); + + this._filter.Append(", "); + + // FieldType + var type = value.GetType(); + if (Nullable.GetUnderlyingType(type) is Type underlying) + { + type = underlying; + } + + this._filter.Append(value.GetType() switch + { + Type t when t == typeof(int) || t == typeof(long) || t == typeof(short) || t == typeof(byte) => "valueInt", + Type t when t == typeof(bool) => "valueBoolean", + Type t when t == typeof(string) || t == typeof(Guid) => "valueText", + Type t when t == typeof(float) || t == typeof(double) || t == typeof(decimal) => "valueNumber", + Type t when t == typeof(DateTimeOffset) => "valueDate", + + _ => throw new NotSupportedException($"Unsupported value type {type.FullName} in filter.") + }); + + this._filter.Append(": "); + + // Value + this._filter.Append(JsonSerializer.Serialize(value)); + + this._filter.Append('}'); + + return; + } + + throw new NotSupportedException("Invalid equality/comparison"); + } + + private void TranslateMethodCall(MethodCallExpression methodCall) + { + switch (methodCall) + { + // Enumerable.Contains() + case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains + when contains.Method.DeclaringType == typeof(Enumerable): + this.TranslateContains(source, item); + return; + + // List.Contains() + case + { + Method: + { + Name: nameof(Enumerable.Contains), + DeclaringType: { IsGenericType: true } declaringType + }, + Object: Expression source, + Arguments: [var item] + } when declaringType.GetGenericTypeDefinition() == typeof(List<>): + this.TranslateContains(source, item); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + // Contains over array + // { path: ["stringArrayPropName"], operator: ContainsAny, valueText: ["foo"] } + if (this.TryTranslateFieldAccess(source, out var storagePropertyName) + && TryGetConstant(item, out var itemConstant) + && itemConstant is string stringConstant) + { + this._filter + .Append("{ path: [\"") + .Append(JsonEncodedText.Encode(storagePropertyName)) + .Append("\"], operator: ContainsAny, valueText: [") + .Append(JsonEncodedText.Encode(stringConstant)) + .Append("]}"); + return; + } + + throw new NotSupportedException("Contains supported only over tag field"); + } + + private bool TryTranslateFieldAccess(Expression expression, [NotNullWhen(true)] out string? storagePropertyName) + { + if (expression is MemberExpression memberExpression && memberExpression.Expression == this._recordParameter) + { + if (!this._storagePropertyNames.TryGetValue(memberExpression.Member.Name, out storagePropertyName)) + { + throw new InvalidOperationException($"Property name '{memberExpression.Member.Name}' provided as part of the filter clause is not a valid property name."); + } + + return true; + } + + storagePropertyName = null; + return false; + } + + private static bool TryGetConstant(Expression expression, out object? constantValue) + { + switch (expression) + { + case ConstantExpression { Value: var v }: + constantValue = v; + return true; + + // This identifies compiler-generated closure types which contain captured variables. + case MemberExpression { Expression: ConstantExpression constant, Member: FieldInfo fieldInfo } + when constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) + && Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true): + constantValue = fieldInfo.GetValue(constant.Value); + return true; + + default: + constantValue = null; + return false; + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs index a4ba633535a7..fe8e965f67e3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollection.cs @@ -9,7 +9,6 @@ using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Nodes; -using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; @@ -75,7 +74,6 @@ public sealed class WeaviateVectorStoreRecordCollection : IVectorStoreR private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, Converters = { new WeaviateDateTimeOffsetConverter(), @@ -84,7 +82,7 @@ public sealed class WeaviateVectorStoreRecordCollection : IVectorStoreR }; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// that is used to interact with Weaviate API. private readonly HttpClient _httpClient; @@ -335,7 +333,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record /// public async Task> VectorizedSearchAsync( TVector vector, - VectorSearchOptions? options = null, + VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { const string OperationName = "VectorSearch"; diff --git a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionQueryBuilder.cs b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionQueryBuilder.cs index 397af63763a6..e665e7e85e08 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionQueryBuilder.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Weaviate/WeaviateVectorStoreRecordCollectionQueryBuilder.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using System.Text.Json; using Microsoft.Extensions.VectorData; @@ -17,13 +18,13 @@ internal static class WeaviateVectorStoreRecordCollectionQueryBuilder /// Builds Weaviate search query. /// More information here: . /// - public static string BuildSearchQuery( + public static string BuildSearchQuery( TVector vector, string collectionName, string vectorPropertyName, string keyPropertyName, JsonSerializerOptions jsonSerializerOptions, - VectorSearchOptions searchOptions, + VectorSearchOptions searchOptions, IReadOnlyDictionary storagePropertyNames, IReadOnlyList vectorPropertyStorageNames, IReadOnlyList dataPropertyStorageNames) @@ -32,11 +33,19 @@ public static string BuildSearchQuery( $"vectors {{ {string.Join(" ", vectorPropertyStorageNames)} }}" : string.Empty; - var filter = BuildFilter( - searchOptions.Filter, - jsonSerializerOptions, - keyPropertyName, - storagePropertyNames); +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + var filter = searchOptions switch + { + { Filter: not null, NewFilter: not null } => throw new ArgumentException("Either Filter or NewFilter can be specified, but not both"), + { Filter: VectorSearchFilter legacyFilter } => BuildLegacyFilter( + legacyFilter, + jsonSerializerOptions, + keyPropertyName, + storagePropertyNames), + { NewFilter: Expression> newFilter } => new WeaviateFilterTranslator().Translate(newFilter, storagePropertyNames), + _ => null + }; +#pragma warning restore CS0618 var vectorArray = JsonSerializer.Serialize(vector, jsonSerializerOptions); @@ -46,7 +55,7 @@ public static string BuildSearchQuery( {{collectionName}} ( limit: {{searchOptions.Top}} offset: {{searchOptions.Skip}} - {{filter}} + {{(filter is null ? "" : "where: " + filter)}} nearVector: { targetVectors: ["{{vectorPropertyName}}"] vector: {{vectorArray}} @@ -66,11 +75,12 @@ public static string BuildSearchQuery( #region private +#pragma warning disable CS0618 // Type or member is obsolete /// /// Builds filter for Weaviate search query. /// More information here: . /// - private static string BuildFilter( + private static string BuildLegacyFilter( VectorSearchFilter? vectorSearchFilter, JsonSerializerOptions jsonSerializerOptions, string keyPropertyName, @@ -134,8 +144,9 @@ private static string BuildFilter( operands.Add(operand); } - return $$"""where: { operator: And, operands: [{{string.Join(", ", operands)}}] }"""; + return $$"""{ operator: And, operands: [{{string.Join(", ", operands)}}] }"""; } +#pragma warning restore CS0618 // Type or member is obsolete /// /// Gets filter value type. diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs index 8242333ecea5..cea02dee086c 100644 --- a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs @@ -9,6 +9,8 @@ namespace SemanticKernel.Connectors.MongoDB.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Unit tests for class. /// @@ -20,32 +22,6 @@ public sealed class MongoDBVectorStoreCollectionSearchMappingTests ["Property2"] = "property_2", }; - [Fact] - public void BuildFilterWithNullVectorSearchFilterReturnsNull() - { - // Arrange - VectorSearchFilter? vectorSearchFilter = null; - - // Act - var filter = MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames); - - // Assert - Assert.Null(filter); - } - - [Fact] - public void BuildFilterWithoutFilterClausesReturnsNull() - { - // Arrange - VectorSearchFilter vectorSearchFilter = new(); - - // Act - var filter = MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames); - - // Assert - Assert.Null(filter); - } - [Fact] public void BuildFilterThrowsExceptionWithUnsupportedFilterClause() { @@ -53,7 +29,7 @@ public void BuildFilterThrowsExceptionWithUnsupportedFilterClause() var vectorSearchFilter = new VectorSearchFilter().AnyTagEqualTo("NonExistentProperty", "TestValue"); // Act & Assert - Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames)); + Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._storagePropertyNames)); } [Fact] @@ -63,7 +39,7 @@ public void BuildFilterThrowsExceptionWithNonExistentPropertyName() var vectorSearchFilter = new VectorSearchFilter().EqualTo("NonExistentProperty", "TestValue"); // Act & Assert - Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames)); + Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._storagePropertyNames)); } [Fact] @@ -75,7 +51,7 @@ public void BuildFilterThrowsExceptionWithMultipleFilterClausesOfSameType() .EqualTo("Property1", "TestValue2"); // Act & Assert - Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames)); + Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._storagePropertyNames)); } [Fact] @@ -86,7 +62,7 @@ public void BuilderFilterByDefaultReturnsValidFilter() var vectorSearchFilter = new VectorSearchFilter().EqualTo("Property1", "TestValue1"); // Act - var filter = MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames); + var filter = MongoDBVectorStoreCollectionSearchMapping.BuildLegacyFilter(vectorSearchFilter, this._storagePropertyNames); Assert.Equal(filter.ToJson(), expectedFilter.ToJson()); } diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs index 26a9b9fb00b7..7fa33bbd9967 100644 --- a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs @@ -13,6 +13,7 @@ using MongoDB.Driver; using Moq; using Xunit; +using MEVD = Microsoft.Extensions.VectorData; namespace SemanticKernel.Connectors.MongoDB.UnitTests; @@ -639,7 +640,7 @@ public async Task VectorizedSearchThrowsExceptionWithNonExistentVectorPropertyNa this._mockMongoDatabase.Object, "collection"); - var options = new VectorSearchOptions { VectorPropertyName = "non-existent-property" }; + var options = new MEVD.VectorSearchOptions { VectorPropertyName = "non-existent-property" }; // Act & Assert await Assert.ThrowsAsync(async () => await (await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f]), options)).Results.FirstOrDefaultAsync()); diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj index 68fbec524a28..0f884f0df59c 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj +++ b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj @@ -39,4 +39,10 @@ + + + + + + diff --git a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs index 675843a78c18..e1958f934c5d 100644 --- a/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Postgres.UnitTests/PostgresVectorStoreCollectionSqlBuilderTests.cs @@ -366,57 +366,4 @@ public void TestBuildDeleteBatchCommand() // Output this._output.WriteLine(cmdInfo.CommandText); } - - [Fact] - public void TestBuildGetNearestMatchCommand() - { - // Arrange - var builder = new PostgresVectorStoreCollectionSqlBuilder(); - - var vectorProperty = new VectorStoreRecordVectorProperty("embedding1", typeof(ReadOnlyMemory)) - { - Dimensions = 10, - IndexKind = "hnsw", - }; - - var recordDefinition = new VectorStoreRecordDefinition() - { - Properties = [ - new VectorStoreRecordKeyProperty("id", typeof(long)), - new VectorStoreRecordDataProperty("name", typeof(string)), - new VectorStoreRecordDataProperty("code", typeof(int)), - new VectorStoreRecordDataProperty("rating", typeof(float?)), - new VectorStoreRecordDataProperty("description", typeof(string)), - new VectorStoreRecordDataProperty("parking_is_included", typeof(bool)), - new VectorStoreRecordDataProperty("tags", typeof(List)), - vectorProperty, - new VectorStoreRecordVectorProperty("embedding2", typeof(ReadOnlyMemory?)) - { - Dimensions = 10, - IndexKind = "hnsw", - } - ] - }; - - var vector = new Vector(s_vector); - - // Act - var cmdInfo = builder.BuildGetNearestMatchCommand("public", "testcollection", - properties: recordDefinition.Properties, - vectorProperty: vectorProperty, - vectorValue: vector, - filter: null, - skip: null, - includeVectors: true, - limit: 10); - - // Assert - Assert.Contains("SELECT", cmdInfo.CommandText); - Assert.Contains("FROM public.\"testcollection\"", cmdInfo.CommandText); - Assert.Contains("ORDER BY", cmdInfo.CommandText); - Assert.Contains("LIMIT 10", cmdInfo.CommandText); - - // Output - this._output.WriteLine(cmdInfo.CommandText); - } } diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionSearchMappingTests.cs index 623f997a4ed2..afd5e545030a 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreCollectionSearchMappingTests.cs @@ -10,6 +10,8 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// @@ -35,7 +37,7 @@ public void BuildFilterMapsEqualityClause(string type) var filter = new VectorSearchFilter().EqualTo("FieldName", expected); // Act. - var actual = QdrantVectorStoreCollectionSearchMapping.BuildFilter(filter, new Dictionary() { { "FieldName", "storage_FieldName" } }); + var actual = QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(filter, new Dictionary() { { "FieldName", "storage_FieldName" } }); // Assert. Assert.Single(actual.Must); @@ -69,7 +71,7 @@ public void BuildFilterMapsTagContainsClause() var filter = new VectorSearchFilter().AnyTagEqualTo("FieldName", "Value"); // Act. - var actual = QdrantVectorStoreCollectionSearchMapping.BuildFilter(filter, new Dictionary() { { "FieldName", "storage_FieldName" } }); + var actual = QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(filter, new Dictionary() { { "FieldName", "storage_FieldName" } }); // Assert. Assert.Single(actual.Must); @@ -84,7 +86,7 @@ public void BuildFilterThrowsForUnknownFieldName() var filter = new VectorSearchFilter().EqualTo("FieldName", "Value"); // Act and Assert. - Assert.Throws(() => QdrantVectorStoreCollectionSearchMapping.BuildFilter(filter, new Dictionary())); + Assert.Throws(() => QdrantVectorStoreCollectionSearchMapping.BuildFromLegacyFilter(filter, new Dictionary())); } [Fact] diff --git a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs index 1bb89a91344e..666efcc4647b 100644 --- a/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Qdrant.UnitTests/QdrantVectorStoreRecordCollectionTests.cs @@ -545,6 +545,7 @@ public void CanCreateCollectionWithMismatchedDefinitionAndType() new() { VectorStoreRecordDefinition = definition, PointStructCustomMapper = Mock.Of, PointStruct>>() }); } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete [Theory] [MemberData(nameof(TestOptions))] public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool hasNamedVectors, TKey testRecordKey) @@ -593,6 +594,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bo Assert.Equal(new float[] { 1, 2, 3, 4 }, results.First().Record.Vector!.Value.ToArray()); Assert.Equal(0.5f, results.First().Score); } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete private void SetupRetrieveMock(List retrievedPoints) { diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs index 5457582661ee..fb15d0031c2b 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisHashSetVectorStoreRecordCollectionTests.cs @@ -415,6 +415,7 @@ public async Task CanUpsertRecordWithCustomMapperAsync() Times.Once); } +#pragma warning disable CS0618 // VectorSearchFilter is obsolete [Theory] [InlineData(true, true)] [InlineData(true, false)] @@ -508,6 +509,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, bool inc Assert.False(results.First().Record.Vector.HasValue); } } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete /// /// Tests that the collection can be created even if the definition and the type do not match. diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs index 20d1b0da5831..6cfe1f17960e 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisJsonVectorStoreRecordCollectionTests.cs @@ -16,6 +16,8 @@ namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// diff --git a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs index 8253801a8cb7..1301ee6a7eb9 100644 --- a/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs +++ b/dotnet/src/Connectors/Connectors.Redis.UnitTests/RedisVectorStoreCollectionSearchMappingTests.cs @@ -8,6 +8,8 @@ namespace Microsoft.SemanticKernel.Connectors.Redis.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// @@ -70,7 +72,7 @@ public void BuildQueryBuildsRedisQueryWithDefaults() var firstVectorPropertyName = "storage_Vector"; // Act. - var query = RedisVectorStoreCollectionSearchMapping.BuildQuery(byteArray, new VectorSearchOptions(), storagePropertyNames, firstVectorPropertyName, null); + var query = RedisVectorStoreCollectionSearchMapping.BuildQuery(byteArray, new VectorSearchOptions(), storagePropertyNames, firstVectorPropertyName, null); // Assert. Assert.NotNull(query); @@ -86,7 +88,7 @@ public void BuildQueryBuildsRedisQueryWithCustomVectorName() // Arrange. var floatVector = new ReadOnlyMemory(new float[] { 1.0f, 2.0f, 3.0f }); var byteArray = MemoryMarshal.AsBytes(floatVector.Span).ToArray(); - var vectorSearchOptions = new VectorSearchOptions { Top = 5, Skip = 3, VectorPropertyName = "Vector" }; + var vectorSearchOptions = new VectorSearchOptions { Top = 5, Skip = 3, VectorPropertyName = "Vector" }; var storagePropertyNames = new Dictionary() { { "Vector", "storage_Vector" }, @@ -108,7 +110,7 @@ public void BuildQueryFailsForInvalidVectorName() // Arrange. var floatVector = new ReadOnlyMemory(new float[] { 1.0f, 2.0f, 3.0f }); var byteArray = MemoryMarshal.AsBytes(floatVector.Span).ToArray(); - var vectorSearchOptions = new VectorSearchOptions { VectorPropertyName = "UnknownVector" }; + var vectorSearchOptions = new VectorSearchOptions { VectorPropertyName = "UnknownVector" }; var storagePropertyNames = new Dictionary() { { "Vector", "storage_Vector" }, @@ -149,7 +151,7 @@ public void BuildFilterBuildsEqualityFilter(string filterType) }; // Act. - var filter = RedisVectorStoreCollectionSearchMapping.BuildFilter(basicVectorSearchFilter, storagePropertyNames); + var filter = RedisVectorStoreCollectionSearchMapping.BuildLegacyFilter(basicVectorSearchFilter, storagePropertyNames); // Assert. switch (filterType) @@ -184,7 +186,7 @@ public void BuildFilterThrowsForInvalidValueType() // Act & Assert. Assert.Throws(() => { - var filter = RedisVectorStoreCollectionSearchMapping.BuildFilter(basicVectorSearchFilter, storagePropertyNames); + var filter = RedisVectorStoreCollectionSearchMapping.BuildLegacyFilter(basicVectorSearchFilter, storagePropertyNames); }); } @@ -201,7 +203,7 @@ public void BuildFilterThrowsForUnknownFieldName() // Act & Assert. Assert.Throws(() => { - var filter = RedisVectorStoreCollectionSearchMapping.BuildFilter(basicVectorSearchFilter, storagePropertyNames); + var filter = RedisVectorStoreCollectionSearchMapping.BuildLegacyFilter(basicVectorSearchFilter, storagePropertyNames); }); } @@ -211,7 +213,7 @@ public void ResolveDistanceFunctionReturnsCosineSimilarityIfNoDistanceFunctionSp var property = new VectorStoreRecordVectorProperty("Prop", typeof(ReadOnlyMemory)); // Act. - var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions(), [property], property); + var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions(), [property], property); // Assert. Assert.Equal(DistanceFunction.CosineSimilarity, resolvedDistanceFunction); @@ -223,7 +225,7 @@ public void ResolveDistanceFunctionReturnsDistanceFunctionFromFirstPropertyIfNoF var property = new VectorStoreRecordVectorProperty("Prop", typeof(ReadOnlyMemory)) { DistanceFunction = DistanceFunction.DotProductSimilarity }; // Act. - var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions(), [property], property); + var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions(), [property], property); // Assert. Assert.Equal(DistanceFunction.DotProductSimilarity, resolvedDistanceFunction); @@ -236,7 +238,7 @@ public void ResolveDistanceFunctionReturnsDistanceFunctionFromChosenPropertyIfFi var property2 = new VectorStoreRecordVectorProperty("Prop2", typeof(ReadOnlyMemory)) { DistanceFunction = DistanceFunction.DotProductSimilarity }; // Act. - var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions() { VectorPropertyName = "Prop2" }, [property1, property2], property1); + var resolvedDistanceFunction = RedisVectorStoreCollectionSearchMapping.ResolveDistanceFunction(new VectorSearchOptions { VectorPropertyName = "Prop2" }, [property1, property2], property1); // Assert. Assert.Equal(DistanceFunction.DotProductSimilarity, resolvedDistanceFunction); @@ -260,4 +262,8 @@ public void GetOutputScoreFromRedisScoreLeavesNonConsineSimilarityUntouched(stri // Act & Assert. Assert.Equal(score, RedisVectorStoreCollectionSearchMapping.GetOutputScoreFromRedisScore(score, distanceFunction)); } + +#pragma warning disable CA1812 // An internal class that is apparently never instantiated. If so, remove the code from the assembly. + private sealed class DummyType; +#pragma warning restore CA1812 } diff --git a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs index 9d79fd640a33..370756cb4344 100644 --- a/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Sqlite.UnitTests/SqliteVectorStoreCollectionCommandBuilderTests.cs @@ -233,6 +233,8 @@ public void ItBuildsSelectLeftJoinCommand(string? orderByPropertyName) leftTablePropertyNames, rightTablePropertyNames, conditions, + extraWhereFilter: null, + extraParameters: null, orderByPropertyName); // Assert diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionQueryBuilderTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionQueryBuilderTests.cs index 6c4f8336654f..a0fa8b4f0ae0 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionQueryBuilderTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionQueryBuilderTests.cs @@ -10,6 +10,8 @@ namespace SemanticKernel.Connectors.Weaviate.UnitTests; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Unit tests for class. /// @@ -72,7 +74,7 @@ hotelName hotelCode } """; - var searchOptions = new VectorSearchOptions + var searchOptions = new VectorSearchOptions { Skip = 2, Top = 3, @@ -102,7 +104,7 @@ hotelName hotelCode public void BuildSearchQueryWithIncludedVectorsReturnsValidQuery() { // Arrange - var searchOptions = new VectorSearchOptions + var searchOptions = new VectorSearchOptions { Skip = 2, Top = 3, @@ -133,7 +135,7 @@ public void BuildSearchQueryWithFilterReturnsValidQuery() const string ExpectedFirstSubquery = """{ path: ["hotelName"], operator: Equal, valueText: "Test Name" }"""; const string ExpectedSecondSubquery = """{ path: ["tags"], operator: ContainsAny, valueText: ["t1"] }"""; - var searchOptions = new VectorSearchOptions + var searchOptions = new VectorSearchOptions { Skip = 2, Top = 3, @@ -164,7 +166,7 @@ public void BuildSearchQueryWithFilterReturnsValidQuery() public void BuildSearchQueryWithInvalidFilterValueThrowsException() { // Arrange - var searchOptions = new VectorSearchOptions + var searchOptions = new VectorSearchOptions { Skip = 2, Top = 3, @@ -189,7 +191,7 @@ public void BuildSearchQueryWithInvalidFilterValueThrowsException() public void BuildSearchQueryWithNonExistentPropertyInFilterThrowsException() { // Arrange - var searchOptions = new VectorSearchOptions + var searchOptions = new VectorSearchOptions { Skip = 2, Top = 3, @@ -212,6 +214,9 @@ public void BuildSearchQueryWithNonExistentPropertyInFilterThrowsException() #region private +#pragma warning disable CA1812 // An internal class that is apparently never instantiated. If so, remove the code from the assembly. + private sealed class DummyType; +#pragma warning restore CA1812 private sealed class TestFilterValue; #endregion diff --git a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionTests.cs index 0871c4978977..8f7ea996101d 100644 --- a/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/Connectors/Connectors.Weaviate.UnitTests/WeaviateVectorStoreRecordCollectionTests.cs @@ -530,11 +530,12 @@ public async Task VectorizedSearchWithNonExistentVectorPropertyNameThrowsExcepti // Arrange var sut = new WeaviateVectorStoreRecordCollection(this._mockHttpClient, "Collection"); - var searchOptions = new VectorSearchOptions { VectorPropertyName = "non-existent-property" }; - // Act & Assert await Assert.ThrowsAsync(async () => - await (await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f]), searchOptions)).Results.ToListAsync()); + await (await sut.VectorizedSearchAsync( + new ReadOnlyMemory([1f, 2f, 3f]), + new() { VectorPropertyName = "non-existent-property" })) + .Results.ToListAsync()); } public void Dispose() diff --git a/dotnet/src/Connectors/VectorData.Abstractions/CompatibilitySuppressions.xml b/dotnet/src/Connectors/VectorData.Abstractions/CompatibilitySuppressions.xml index 0860b81e7585..cd9bfbaa3ca7 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/CompatibilitySuppressions.xml +++ b/dotnet/src/Connectors/VectorData.Abstractions/CompatibilitySuppressions.xml @@ -1,5 +1,5 @@  - + CP0001 @@ -15,6 +15,13 @@ lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0001 + T:Microsoft.Extensions.VectorData.VectorSearchOptions + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0001 T:Microsoft.Extensions.VectorData.DeleteRecordOptions @@ -29,6 +36,13 @@ lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0001 + T:Microsoft.Extensions.VectorData.VectorSearchOptions + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0001 T:Microsoft.Extensions.VectorData.DeleteRecordOptions @@ -43,6 +57,27 @@ lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0001 + T:Microsoft.Extensions.VectorData.VectorSearchOptions + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0002 M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) @@ -71,6 +106,20 @@ lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0002 M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) @@ -99,6 +148,20 @@ lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0002 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0002 M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) @@ -127,6 +190,20 @@ lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0006 M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,System.Threading.CancellationToken) @@ -155,6 +232,20 @@ lib/net462/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0006 M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,System.Threading.CancellationToken) @@ -183,6 +274,20 @@ lib/net8.0/Microsoft.Extensions.VectorData.Abstractions.dll true + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorizableTextSearch`1.VectorizableTextSearchAsync(System.String,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + + + CP0006 + M:Microsoft.Extensions.VectorData.IVectorizedSearch`1.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions{`0},System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + lib/netstandard2.0/Microsoft.Extensions.VectorData.Abstractions.dll + true + CP0006 M:Microsoft.Extensions.VectorData.IVectorStoreRecordCollection`2.DeleteAsync(`0,System.Threading.CancellationToken) diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizableTextSearch.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizableTextSearch.cs index a0d5181b7668..5e39a541ef86 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizableTextSearch.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizableTextSearch.cs @@ -20,6 +20,6 @@ public interface IVectorizableTextSearch /// The records found by the vector search, including their result scores. Task> VectorizableTextSearchAsync( string searchText, - VectorSearchOptions? options = default, + VectorSearchOptions? options = default, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizedSearch.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizedSearch.cs index 9ac93383b18d..3286fafc15fc 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizedSearch.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/IVectorizedSearch.cs @@ -21,6 +21,6 @@ public interface IVectorizedSearch /// The records found by the vector search, including their result scores. Task> VectorizedSearchAsync( TVector vector, - VectorSearchOptions? options = default, + VectorSearchOptions? options = default, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchFilter.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchFilter.cs index a8b941776eff..9d167fcb160b 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchFilter.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchFilter.cs @@ -14,6 +14,7 @@ namespace Microsoft.Extensions.VectorData; /// to request that the underlying service filter the search results. /// All clauses are combined with and. /// +[Obsolete("Use VectorSearchOptions.NewFilter instead of VectorSearchOptions.Filter")] public sealed class VectorSearchFilter { /// The filter clauses to and together. diff --git a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs index a5773b0cc606..65d9c6e157c2 100644 --- a/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs +++ b/dotnet/src/Connectors/VectorData.Abstractions/VectorSearch/VectorSearchOptions.cs @@ -1,17 +1,26 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Linq.Expressions; + namespace Microsoft.Extensions.VectorData; /// /// Options for vector search. /// -public class VectorSearchOptions +public class VectorSearchOptions { /// /// Gets or sets a search filter to use before doing the vector search. /// + [Obsolete("Use NewFilter instead")] public VectorSearchFilter? Filter { get; init; } + /// + /// Gets or sets a search filter to use before doing the vector search. + /// + public Expression>? NewFilter { get; init; } + /// /// Gets or sets the name of the vector property to search on. /// Use the name of the vector property from your data model or as provided in the record definition. diff --git a/dotnet/src/Functions/Functions.OpenApi/Functions.OpenApi.csproj b/dotnet/src/Functions/Functions.OpenApi/Functions.OpenApi.csproj index 2d37b88dca4a..1d72c971fcba 100644 --- a/dotnet/src/Functions/Functions.OpenApi/Functions.OpenApi.csproj +++ b/dotnet/src/Functions/Functions.OpenApi/Functions.OpenApi.csproj @@ -29,4 +29,8 @@ + + + + \ No newline at end of file diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs index e3a420a789f4..f7fb10081c76 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureAISearch/AzureAISearchVectorStoreRecordCollectionTests.cs @@ -14,6 +14,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.AzureAISearch; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Integration tests for class. /// Tests work with an Azure AI Search Instance. @@ -63,7 +65,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe var embedding = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("A great hotel"); var actual = await sut.VectorizedSearchAsync( embedding, - new VectorSearchOptions + new() { IncludeVectors = true, Filter = new VectorSearchFilter().EqualTo("HotelName", "MyHotel Upsert-1") diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs index c5929e0ecaa2..7f471405b8c9 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs @@ -12,6 +12,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.AzureCosmosDBMongoDB; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + [Collection("AzureCosmosDBMongoDBVectorStoreCollection")] public class AzureCosmosDBMongoDBVectorStoreRecordCollectionTests(AzureCosmosDBMongoDBVectorStoreFixture fixture) { diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs index 6a0e249f4d7e..3864a48288ef 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBNoSQL/AzureCosmosDBNoSQLVectorStoreRecordCollectionTests.cs @@ -13,6 +13,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.AzureCosmosDBNoSQL; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Integration tests for class. /// diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs index 11da55ba3329..3f88b10eef4b 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs @@ -12,6 +12,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.MongoDB; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + [Collection("MongoDBVectorStoreCollection")] public class MongoDBVectorStoreRecordCollectionTests(MongoDBVectorStoreFixture fixture) { diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs index e30b2f35fbae..7e19c73128d0 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Pinecone/PineconeVectorStoreRecordCollectionTests.cs @@ -15,6 +15,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Pinecone; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + [Collection("PineconeVectorStoreTests")] [PineconeApiKeySetCondition] public class PineconeVectorStoreRecordCollectionTests(PineconeVectorStoreFixture fixture) : IClassFixture @@ -293,7 +295,7 @@ public async Task InsertGetModifyDeleteVectorAsync(bool collectionFromVectorStor // update await hotelRecordCollection.UpsertAsync(langriSha); - // this is not great but no vectors are added so we can't query status for number of vectors like we do for insert/delete + // this is not great but no vectors are added so we can't query status for number of vectors like we do for insert/delete await Task.Delay(2000); var updated = await hotelRecordCollection.GetAsync("langri-sha", new GetRecordOptions { IncludeVectors = true }); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs index 7e3ae3ad9392..6a479f0b10bf 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresVectorStoreRecordCollectionTests.cs @@ -10,6 +10,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + [Collection("PostgresVectorStoreCollection")] public sealed class PostgresVectorStoreRecordCollectionTests(PostgresVectorStoreFixture fixture) { diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs index 135d09d025aa..940687525238 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Qdrant/QdrantVectorStoreRecordCollectionTests.cs @@ -15,6 +15,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Qdrant; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// @@ -66,7 +68,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool hasNamedVec var vector = await fixture.EmbeddingGenerator.GenerateEmbeddingAsync("A great hotel"); var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions { Filter = new VectorSearchFilter().EqualTo("HotelCode", 30).AnyTagEqualTo("Tags", "t2") }); + new() { Filter = new VectorSearchFilter().EqualTo("HotelCode", 30).AnyTagEqualTo("Tags", "t2") }); // Assert var collectionExistResult = await sut.CollectionExistsAsync(); diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs index ef7ba087cf87..61018b2b7589 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisHashSetVectorStoreRecordCollectionTests.cs @@ -13,6 +13,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Redis; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// @@ -65,7 +67,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe var actual = await sut .VectorizedSearchAsync( new ReadOnlyMemory(new[] { 30f, 31f, 32f, 33f }), - new VectorSearchOptions { Filter = new VectorSearchFilter().EqualTo("HotelCode", 1), IncludeVectors = true }); + new() { Filter = new VectorSearchFilter().EqualTo("HotelCode", 1), IncludeVectors = true }); // Assert var collectionExistResult = await sut.CollectionExistsAsync(); @@ -316,7 +318,7 @@ public async Task ItCanSearchWithFloat32VectorAndFilterAsync(string filterType, // Act var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions + new() { IncludeVectors = includeVectors, Filter = filter @@ -360,7 +362,7 @@ public async Task ItCanSearchWithFloat32VectorAndTopSkipAsync() // Act var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions + new() { Top = 3, Skip = 2 @@ -390,7 +392,7 @@ public async Task ItCanSearchWithFloat64VectorAsync(bool includeVectors) // Act var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions + new() { IncludeVectors = includeVectors, Top = 1 diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs index 1e6c3d9aed0e..a12d710d9446 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Redis/RedisJsonVectorStoreRecordCollectionTests.cs @@ -13,6 +13,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Redis; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Contains tests for the class. /// @@ -64,7 +66,7 @@ public async Task ItCanCreateACollectionUpsertGetAndSearchAsync(bool useRecordDe var getResult = await sut.GetAsync("Upsert-10", new GetRecordOptions { IncludeVectors = true }); var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new[] { 30f, 31f, 32f, 33f }), - new VectorSearchOptions { Filter = new VectorSearchFilter().EqualTo("HotelCode", 10) }); + new() { Filter = new VectorSearchFilter().EqualTo("HotelCode", 10) }); // Assert var collectionExistResult = await sut.CollectionExistsAsync(); @@ -346,7 +348,7 @@ public async Task ItCanSearchWithFloat32VectorAndFilterAsync(string filterType) // Act var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions { IncludeVectors = true, Filter = filter }); + new() { IncludeVectors = true, Filter = filter }); // Assert var searchResults = await actual.Results.ToListAsync(); @@ -384,7 +386,7 @@ public async Task ItCanSearchWithFloat32VectorAndTopSkipAsync() // Act var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions + new() { Top = 3, Skip = 2 @@ -414,7 +416,7 @@ public async Task ItCanSearchWithFloat64VectorAsync(bool includeVectors) // Act var actual = await sut.VectorizedSearchAsync( vector, - new VectorSearchOptions + new() { IncludeVectors = includeVectors, Top = 1 diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs index 214510438d59..c0dbb5fcf680 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Sqlite/SqliteVectorStoreRecordCollectionTests.cs @@ -10,6 +10,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Sqlite; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// /// Integration tests for class. /// diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateVectorStoreRecordCollectionTests.cs index 9ffaf3172eec..bd6348932937 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Weaviate/WeaviateVectorStoreRecordCollectionTests.cs @@ -10,6 +10,8 @@ namespace SemanticKernel.IntegrationTests.Connectors.Memory.Weaviate; +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + [Collection("WeaviateVectorStoreCollection")] public sealed class WeaviateVectorStoreRecordCollectionTests(WeaviateVectorStoreFixture fixture) { diff --git a/dotnet/src/IntegrationTests/Data/BaseVectorStoreTextSearchTests.cs b/dotnet/src/IntegrationTests/Data/BaseVectorStoreTextSearchTests.cs index 90ce87f14482..143c61f69e5f 100644 --- a/dotnet/src/IntegrationTests/Data/BaseVectorStoreTextSearchTests.cs +++ b/dotnet/src/IntegrationTests/Data/BaseVectorStoreTextSearchTests.cs @@ -102,7 +102,7 @@ public Task>> GenerateEmbeddingsAsync(IList protected sealed class VectorizedSearchWrapper(IVectorizedSearch vectorizedSearch, ITextEmbeddingGenerationService textEmbeddingGeneration) : IVectorizableTextSearch { /// - public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { var vectorizedQuery = await textEmbeddingGeneration!.GenerateEmbeddingAsync(searchText, cancellationToken: cancellationToken).ConfigureAwait(false); diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/UnreachableException.cs b/dotnet/src/InternalUtilities/src/Diagnostics/UnreachableException.cs new file mode 100644 index 000000000000..616073f54705 --- /dev/null +++ b/dotnet/src/InternalUtilities/src/Diagnostics/UnreachableException.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. + +#if NETSTANDARD2_0 + +// Polyfill for using UnreachableException with .NET Standard 2.0 + +namespace System.Diagnostics; + +#pragma warning disable CA1064 // Exceptions should be public +#pragma warning disable CA1812 // Internal class that is (sometimes) never instantiated. + +/// +/// Exception thrown when the program executes an instruction that was thought to be unreachable. +/// +internal sealed class UnreachableException : Exception +{ + private const string MessageText = "The program executed an instruction that was thought to be unreachable."; + + /// + /// Initializes a new instance of the class with the default error message. + /// + public UnreachableException() + : base(MessageText) + { + } + + /// + /// Initializes a new instance of the + /// class with a specified error message. + /// + /// The error message that explains the reason for the exception. + public UnreachableException(string? message) + : base(message ?? MessageText) + { + } + + /// + /// Initializes a new instance of the + /// class with a specified error message and a reference to the inner exception that is the cause of + /// this exception. + /// + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception. + public UnreachableException(string? message, Exception? innerException) + : base(message ?? MessageText, innerException) + { + } +} + +#endif diff --git a/dotnet/src/InternalUtilities/src/System/IndexRange.cs b/dotnet/src/InternalUtilities/src/System/IndexRange.cs new file mode 100644 index 000000000000..439e6e844fb6 --- /dev/null +++ b/dotnet/src/InternalUtilities/src/System/IndexRange.cs @@ -0,0 +1,288 @@ +// Copyright (c) Microsoft. All rights reserved. + +#if NETSTANDARD2_0 + +// Polyfill for using Index and Range with .NET Standard 2.0 (see https://www.meziantou.net/how-to-use-csharp-8-indices-and-ranges-in-dotnet-standard-2-0-and-dotn.htm) + +// https://github.com/dotnet/runtime/blob/419e949d258ecee4c40a460fb09c66d974229623/src/libraries/System.Private.CoreLib/src/System/Index.cs +// https://github.com/dotnet/runtime/blob/419e949d258ecee4c40a460fb09c66d974229623/src/libraries/System.Private.CoreLib/src/System/Range.cs + +#pragma warning disable RCS1168 +#pragma warning disable RCS1211 +#pragma warning disable IDE0009 +#pragma warning disable IDE0011 +#pragma warning disable IDE0090 + +using System.Runtime.CompilerServices; + +namespace System +{ + /// Represent a type can be used to index a collection either from the start or the end. + /// + /// Index is used by the C# compiler to support the new index syntax + /// + /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 } ; + /// int lastElement = someArray[^1]; // lastElement = 5 + /// + /// + internal readonly struct Index : IEquatable + { + private readonly int _value; + + /// Construct an Index using a value and indicating if the index is from the start or from the end. + /// The index value. it has to be zero or positive number. + /// Indicating if the index is from the start or from the end. + /// + /// If the Index constructed from the end, index value 1 means pointing at the last element and index value 0 means pointing at beyond last element. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Index(int value, bool fromEnd = false) + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative"); + } + + if (fromEnd) + _value = ~value; + else + _value = value; + } + + // The following private constructors mainly created for perf reason to avoid the checks + private Index(int value) + { + _value = value; + } + + /// Create an Index pointing at first element. + public static Index Start => new Index(0); + + /// Create an Index pointing at beyond last element. + public static Index End => new Index(~0); + + /// Create an Index from the start at the position indicated by the value. + /// The index value from the start. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Index FromStart(int value) + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative"); + } + + return new Index(value); + } + + /// Create an Index from the end at the position indicated by the value. + /// The index value from the end. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Index FromEnd(int value) + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "value must be non-negative"); + } + + return new Index(~value); + } + + /// Returns the index value. + public int Value + { + get + { + if (_value < 0) + { + return ~_value; + } + else + { + return _value; + } + } + } + + /// Indicates whether the index is from the start or the end. + public bool IsFromEnd => _value < 0; + + /// Calculate the offset from the start using the giving collection length. + /// The length of the collection that the Index will be used with. length has to be a positive value + /// + /// For performance reason, we don't validate the input length parameter and the returned offset value against negative values. + /// we don't validate either the returned offset is greater than the input length. + /// It is expected Index will be used with collections which always have non negative length/count. If the returned offset is negative and + /// then used to index a collection will get out of range exception which will be same affect as the validation. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int GetOffset(int length) + { + var offset = _value; + if (IsFromEnd) + { + // offset = length - (~value) + // offset = length + (~(~value) + 1) + // offset = length + value + 1 + + offset += length + 1; + } + return offset; + } + + /// Indicates whether the current Index object is equal to another object of the same type. + /// An object to compare with this object + public override bool Equals(object? value) => value is Index && _value == ((Index)value)._value; + + /// Indicates whether the current Index object is equal to another Index object. + /// An object to compare with this object + public bool Equals(Index other) => _value == other._value; + + /// Returns the hash code for this instance. + public override int GetHashCode() => _value; + + /// Converts integer number to an Index. + public static implicit operator Index(int value) => FromStart(value); + + /// Converts the value of the current Index object to its equivalent string representation. + public override string ToString() + { + if (IsFromEnd) + return "^" + ((uint)Value).ToString(); + + return ((uint)Value).ToString(); + } + } + + /// Represent a range has start and end indexes. + /// + /// Range is used by the C# compiler to support the range syntax. + /// + /// int[] someArray = new int[5] { 1, 2, 3, 4, 5 }; + /// int[] subArray1 = someArray[0..2]; // { 1, 2 } + /// int[] subArray2 = someArray[1..^0]; // { 2, 3, 4, 5 } + /// + /// + internal readonly struct Range : IEquatable + { + /// Represent the inclusive start index of the Range. + public Index Start { get; } + + /// Represent the exclusive end index of the Range. + public Index End { get; } + + /// Construct a Range object using the start and end indexes. + /// Represent the inclusive start index of the range. + /// Represent the exclusive end index of the range. + public Range(Index start, Index end) + { + Start = start; + End = end; + } + + /// Indicates whether the current Range object is equal to another object of the same type. + /// An object to compare with this object + public override bool Equals(object? value) => + value is Range r && + r.Start.Equals(Start) && + r.End.Equals(End); + + /// Indicates whether the current Range object is equal to another Range object. + /// An object to compare with this object + public bool Equals(Range other) => other.Start.Equals(Start) && other.End.Equals(End); + + /// Returns the hash code for this instance. + public override int GetHashCode() + { + return Start.GetHashCode() * 31 + End.GetHashCode(); + } + + /// Converts the value of the current Range object to its equivalent string representation. + public override string ToString() + { + return Start + ".." + End; + } + + /// Create a Range object starting from start index to the end of the collection. + public static Range StartAt(Index start) => new Range(start, Index.End); + + /// Create a Range object starting from first element in the collection to the end Index. + public static Range EndAt(Index end) => new Range(Index.Start, end); + + /// Create a Range object starting from first element to the end. + public static Range All => new Range(Index.Start, Index.End); + + /// Calculate the start offset and length of range object using a collection length. + /// The length of the collection that the range will be used with. length has to be a positive value. + /// + /// For performance reason, we don't validate the input length parameter against negative values. + /// It is expected Range will be used with collections which always have non negative length/count. + /// We validate the range is inside the length scope though. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public (int Offset, int Length) GetOffsetAndLength(int length) + { + int start; + var startIndex = Start; + if (startIndex.IsFromEnd) + start = length - startIndex.Value; + else + start = startIndex.Value; + + int end; + var endIndex = End; + if (endIndex.IsFromEnd) + end = length - endIndex.Value; + else + end = endIndex.Value; + + if ((uint)end > (uint)length || (uint)start > (uint)end) + { + throw new ArgumentOutOfRangeException(nameof(length)); + } + + return (start, end - start); + } + } +} + +namespace System.Runtime.CompilerServices +{ + internal static class RuntimeHelpers + { + /// + /// Slices the specified array using the specified range. + /// + public static T[] GetSubArray(T[] array, Range range) + { + if (array == null) + { + throw new ArgumentNullException(nameof(array)); + } + + (int offset, int length) = range.GetOffsetAndLength(array.Length); + + if (default(T) != null || typeof(T[]) == array.GetType()) + { + // We know the type of the array to be exactly T[]. + + if (length == 0) + { + return Array.Empty(); + } + + var dest = new T[length]; + Array.Copy(array, offset, dest, 0, length); + return dest; + } + else + { + // The array is actually a U[] where U:T. + var dest = (T[])Array.CreateInstance(array.GetType().GetElementType(), length); + Array.Copy(array, offset, dest, 0, length); + return dest; + } + } + } +} + +#endif diff --git a/dotnet/src/Plugins/Plugins.Web/Bing/BingTextSearch.cs b/dotnet/src/Plugins/Plugins.Web/Bing/BingTextSearch.cs index 97526f388b17..556e04f148d3 100644 --- a/dotnet/src/Plugins/Plugins.Web/Bing/BingTextSearch.cs +++ b/dotnet/src/Plugins/Plugins.Web/Bing/BingTextSearch.cs @@ -241,6 +241,7 @@ public TextSearchResult MapFromResultToTextSearchResult(object result) } } +#pragma warning disable CS0618 // FilterClause is obsolete /// /// Build a query string from the /// @@ -280,5 +281,7 @@ private static string BuildQuery(string query, TextSearchOptions searchOptions) return fullQuery.ToString(); } +#pragma warning restore CS0618 // FilterClause is obsolete + #endregion } diff --git a/dotnet/src/Plugins/Plugins.Web/Google/GoogleTextSearch.cs b/dotnet/src/Plugins/Plugins.Web/Google/GoogleTextSearch.cs index a42500fa7c4e..c4165a2edadc 100644 --- a/dotnet/src/Plugins/Plugins.Web/Google/GoogleTextSearch.cs +++ b/dotnet/src/Plugins/Plugins.Web/Google/GoogleTextSearch.cs @@ -160,6 +160,7 @@ public void Dispose() return await search.ExecuteAsync(cancellationToken).ConfigureAwait(false); } +#pragma warning disable CS0618 // FilterClause is obsolete /// /// Add basic filters to the Google search metadata. /// @@ -192,6 +193,7 @@ private void AddFilters(CseResource.ListRequest search, TextSearchOptions search } } } +#pragma warning restore CS0618 // FilterClause is obsolete /// /// Return the search results as instances of . diff --git a/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj b/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj index 235c08e4d52b..47043cbe1df8 100644 --- a/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj +++ b/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj @@ -57,6 +57,9 @@ + + + diff --git a/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockVectorizableTextSearch.cs b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockVectorizableTextSearch.cs index 454d82ace013..b39976adbebf 100644 --- a/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockVectorizableTextSearch.cs +++ b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockVectorizableTextSearch.cs @@ -13,7 +13,7 @@ public MockVectorizableTextSearch(IEnumerable> searc this._searchResults = ToAsyncEnumerable(searchResults); } - public Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { return Task.FromResult(new VectorSearchResults(this._searchResults)); } diff --git a/dotnet/src/SemanticKernel.Core/CompatibilitySuppressions.xml b/dotnet/src/SemanticKernel.Core/CompatibilitySuppressions.xml index de2e33319a56..6c9084abb2ce 100644 --- a/dotnet/src/SemanticKernel.Core/CompatibilitySuppressions.xml +++ b/dotnet/src/SemanticKernel.Core/CompatibilitySuppressions.xml @@ -1,5 +1,5 @@  - + CP0002 @@ -29,6 +29,13 @@ lib/net8.0/Microsoft.SemanticKernel.Core.dll true + + CP0002 + M:Microsoft.SemanticKernel.Data.VolatileVectorStoreRecordCollection`2.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/net8.0/Microsoft.SemanticKernel.Core.dll + lib/net8.0/Microsoft.SemanticKernel.Core.dll + true + CP0002 M:Microsoft.SemanticKernel.Data.VolatileVectorStoreRecordCollection`2.DeleteAsync(`0,Microsoft.Extensions.VectorData.DeleteRecordOptions,System.Threading.CancellationToken) @@ -57,4 +64,11 @@ lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll true + + CP0002 + M:Microsoft.SemanticKernel.Data.VolatileVectorStoreRecordCollection`2.VectorizedSearchAsync``1(``0,Microsoft.Extensions.VectorData.VectorSearchOptions,System.Threading.CancellationToken) + lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Core.dll + true + \ No newline at end of file diff --git a/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs b/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs index 6970294723ef..42781b1c5483 100644 --- a/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs +++ b/dotnet/src/SemanticKernel.Core/Data/TextSearch/VectorStoreTextSearch.cs @@ -197,9 +197,11 @@ private TextSearchStringMapper CreateTextSearchStringMapper() private async Task> ExecuteVectorSearchAsync(string query, TextSearchOptions? searchOptions, CancellationToken cancellationToken) { searchOptions ??= new TextSearchOptions(); - var vectorSearchOptions = new VectorSearchOptions + var vectorSearchOptions = new VectorSearchOptions { +#pragma warning disable CS0618 // VectorSearchFilter is obsolete Filter = searchOptions.Filter?.FilterClauses is not null ? new VectorSearchFilter(searchOptions.Filter.FilterClauses) : null, +#pragma warning restore CS0618 // VectorSearchFilter is obsolete Skip = searchOptions.Skip, Top = searchOptions.Top, }; diff --git a/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollection.cs b/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollection.cs index da062934cfbb..e94f321eed4a 100644 --- a/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollection.cs +++ b/dotnet/src/SemanticKernel.Core/Data/VolatileVectorStoreRecordCollection.cs @@ -31,7 +31,7 @@ public sealed class VolatileVectorStoreRecordCollection : IVector ]; /// The default options for vector search. - private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); /// Internal storage for all of the record collections. private readonly ConcurrentDictionary> _internalCollections; @@ -213,7 +213,7 @@ public async IAsyncEnumerable UpsertBatchAsync(IEnumerable record /// #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - Need to satisfy the interface which returns IAsyncEnumerable - public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizedSearchAsync(TVector vector, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) #pragma warning restore CS1998 { Verify.NotNull(vector); @@ -238,6 +238,11 @@ public async Task> VectorizedSearchAsync(T } // Filter records using the provided filter before doing the vector comparison. + if (internalOptions.NewFilter is not null) + { + throw new NotSupportedException("LINQ-based filtering is not supported with VolatileVectorStore, use Microsoft.SemanticKernel.Connectors.InMemory instead"); + } + var filteredRecords = VolatileVectorStoreCollectionSearchMapping.FilterRecords(internalOptions.Filter, this.GetCollectionDictionary().Values); // Compare each vector in the filtered results with the provided vector. diff --git a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj index c4c4956a3fa8..268c2e470314 100644 --- a/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj +++ b/dotnet/src/SemanticKernel.Core/SemanticKernel.Core.csproj @@ -54,5 +54,5 @@ - + \ No newline at end of file diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTestBase.cs b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTestBase.cs index 262289c567d0..c01fe06eddf4 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTestBase.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Data/VectorStoreTextSearchTestBase.cs @@ -126,7 +126,7 @@ public Task>> GenerateEmbeddingsAsync(IList public sealed class VectorizedSearchWrapper(IVectorizedSearch vectorizedSearch, ITextEmbeddingGenerationService textEmbeddingGeneration) : IVectorizableTextSearch { /// - public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) + public async Task> VectorizableTextSearchAsync(string searchText, VectorSearchOptions? options = null, CancellationToken cancellationToken = default) { var vectorizedQuery = await textEmbeddingGeneration!.GenerateEmbeddingAsync(searchText, cancellationToken: cancellationToken).ConfigureAwait(false); return await vectorizedSearch.VectorizedSearchAsync(vectorizedQuery, options, cancellationToken); diff --git a/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreRecordCollectionTests.cs b/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreRecordCollectionTests.cs index 9530c48fe574..edd169a725ff 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Data/VolatileVectorStoreRecordCollectionTests.cs @@ -294,7 +294,7 @@ public async Task CanSearchWithVectorAsync(bool useDefinition, TKey testKe // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true }, + new() { IncludeVectors = true }, this._testCancellationToken); // Assert @@ -338,7 +338,7 @@ public async Task CanSearchWithVectorAndFilterAsync(bool useDefinition, TK var filter = filterType == "Equality" ? new VectorSearchFilter().EqualTo("Data", $"data {testKey2}") : new VectorSearchFilter().AnyTagEqualTo("Tags", $"tag {testKey2}"); var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true, Filter = filter, IncludeTotalCount = true }, + new() { IncludeVectors = true, Filter = filter, IncludeTotalCount = true }, this._testCancellationToken); // Assert @@ -390,7 +390,7 @@ public async Task CanSearchWithDifferentDistanceFunctionsAsync(string distanceFu // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true }, + new() { IncludeVectors = true }, this._testCancellationToken); // Assert @@ -431,7 +431,7 @@ public async Task CanSearchManyRecordsAsync(bool useDefinition) // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory(new float[] { 1, 1, 1, 1 }), - new VectorSearchOptions { IncludeVectors = true, Top = 10, Skip = 10, IncludeTotalCount = true }, + new() { IncludeVectors = true, Top = 10, Skip = 10, IncludeTotalCount = true }, this._testCancellationToken); // Assert @@ -507,7 +507,7 @@ public async Task ItCanSearchUsingTheGenericDataModelAsync(TKey testKey1, // Act var actual = await sut.VectorizedSearchAsync( new ReadOnlyMemory([1, 1, 1, 1]), - new VectorSearchOptions { IncludeVectors = true, VectorPropertyName = "Vector" }, + new() { IncludeVectors = true, VectorPropertyName = "Vector" }, this._testCancellationToken); // Assert diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj new file mode 100644 index 000000000000..0fcc13f45809 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/AzureAISearchIntegrationTests.csproj @@ -0,0 +1,31 @@ + + + + net8.0;net472 + enable + enable + true + false + AzureAISearchIntegrationTests + b7762d10-e29b-4bb1-8b74-b6d69a667dd4 + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicFilterTests.cs new file mode 100644 index 000000000000..9683543d3e98 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchBasicFilterTests.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; + +namespace AzureAISearchIntegrationTests.Filter; + +public class AzureAISearchBasicFilterTests(AzureAISearchFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + // Azure AI Search only supports search.in() over strings + public override Task Contains_over_inline_int_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_int_array()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchFilterFixture.cs new file mode 100644 index 000000000000..a5ec5df341dd --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Filter/AzureAISearchFilterFixture.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. + +using AzureAISearchIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace AzureAISearchIntegrationTests.Filter; + +public class AzureAISearchFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => AzureAISearchTestStore.Instance; + + // Azure AI search only supports lowercase letters, digits or dashes. + protected override string StoreName => "filter-tests"; +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Properties/AssemblyAttributes.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Properties/AssemblyAttributes.cs new file mode 100644 index 000000000000..786c2742c2b3 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Properties/AssemblyAttributes.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +[assembly: AzureAISearchIntegrationTests.Support.AzureAISearchUrlRequiredAttribute] diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestEnvironment.cs new file mode 100644 index 000000000000..27e905656870 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestEnvironment.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Configuration; + +namespace AzureAISearchIntegrationTests.Support; + +#pragma warning disable CA1810 // Initialize all static fields when those fields are declared + +internal static class AzureAISearchTestEnvironment +{ + public static readonly string? ServiceUrl, ApiKey; + + public static bool IsConnectionInfoDefined => ServiceUrl is not null && ApiKey is not null; + + static AzureAISearchTestEnvironment() + { + var configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true) + .AddJsonFile(path: "testsettings.development.json", optional: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + var azureAISearchSection = configuration.GetSection("AzureAISearch"); + ServiceUrl = azureAISearchSection?["ServiceUrl"]; + ApiKey = azureAISearchSection?["ApiKey"]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestStore.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestStore.cs new file mode 100644 index 000000000000..791005d55c9a --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchTestStore.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Azure; +using Azure.Search.Documents.Indexes; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.AzureAISearch; +using VectorDataSpecificationTests.Support; + +namespace AzureAISearchIntegrationTests.Support; + +internal sealed class AzureAISearchTestStore : TestStore +{ + public static AzureAISearchTestStore Instance { get; } = new(); + + private SearchIndexClient? _client; + private AzureAISearchVectorStore? _defaultVectorStore; + + public SearchIndexClient Client + => this._client ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + public override IVectorStore DefaultVectorStore + => this._defaultVectorStore ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + public AzureAISearchVectorStore GetVectorStore(AzureAISearchVectorStoreOptions options) + => new(this.Client, options); + + private AzureAISearchTestStore() + { + } + + protected override Task StartAsync() + { + (string? serviceUrl, string? apiKey) = (AzureAISearchTestEnvironment.ServiceUrl, AzureAISearchTestEnvironment.ApiKey); + + if (string.IsNullOrWhiteSpace(serviceUrl) || string.IsNullOrWhiteSpace(apiKey)) + { + throw new InvalidOperationException("Service URL and API key are not configured, set AzureAISearch:ServiceUrl and AzureAISearch:ApiKey"); + } + + this._client = new SearchIndexClient(new Uri(serviceUrl), new AzureKeyCredential(apiKey)); + this._defaultVectorStore = new(this._client); + + return Task.CompletedTask; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchUrlRequiredAttribute.cs b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchUrlRequiredAttribute.cs new file mode 100644 index 000000000000..1b30639bc1be --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/AzureAISearchIntegrationTests/Support/AzureAISearchUrlRequiredAttribute.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Xunit; + +namespace AzureAISearchIntegrationTests.Support; + +/// +/// Checks whether the sqlite_vec extension is properly installed, and skips the test(s) otherwise. +/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly)] +public sealed class AzureAISearchUrlRequiredAttribute : Attribute, ITestCondition +{ + public ValueTask IsMetAsync() => new(AzureAISearchTestEnvironment.IsConnectionInfoDefined); + + public string Skip { get; set; } = "Service URL and API key are not configured, set AzureAISearch:ServiceUrl and AzureAISearch:ApiKey."; + + public string SkipReason + => this.Skip; +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj new file mode 100644 index 000000000000..aaf0dcf8160b --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/CosmosMongoDBIntegrationTests.csproj @@ -0,0 +1,29 @@ + + + + net8.0;net472 + enable + enable + true + false + MongoDBIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs new file mode 100644 index 000000000000..33d14908f537 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoBasicFilterTests.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace MongoDBIntegrationTests.Filter; + +public class CosmosMongoBasicFilterTests(CosmosMongoFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + // Specialized MongoDB syntax for NOT over Contains ($nin) + [ConditionalFact] + public virtual Task Not_over_Contains() + => this.TestFilterAsync(r => !new[] { 8, 10 }.Contains(r.Int)); + + #region Null checking + + // MongoDB currently doesn't support null checking ({ "Foo" : null }) in vector search pre-filters + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + #endregion + + #region Not + + // MongoDB currently doesn't support NOT in vector search pre-filters + // (https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter) + public override Task Not_over_And() + => Assert.ThrowsAsync(() => base.Not_over_And()); + + public override Task Not_over_Or() + => Assert.ThrowsAsync(() => base.Not_over_Or()); + + #endregion + + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + // AnyTagEqualTo not (currently) supported on SQLite + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_array() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_array()); + + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_List() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_List()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFilterFixture.cs new file mode 100644 index 000000000000..129c7b0cc337 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Filter/CosmosMongoFilterFixture.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. + +using MongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace MongoDBIntegrationTests.Filter; + +public class CosmosMongoFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => CosmosMongoDBTestStore.Instance; + + protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.IvfFlat; + protected override string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Properties/AssemblyAttributes.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Properties/AssemblyAttributes.cs new file mode 100644 index 000000000000..4e8438d68759 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Properties/AssemblyAttributes.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +[assembly: CosmosIntegrationTests.Support.CosmosConnectionStringRequired] diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs new file mode 100644 index 000000000000..c944d36eb78c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosNoSQLIntegrationTests.Support; +using VectorDataSpecificationTests.Xunit; + +namespace CosmosIntegrationTests.Support; + +/// +/// Checks whether the sqlite_vec extension is properly installed, and skips the test(s) otherwise. +/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly)] +public sealed class CosmosConnectionStringRequiredAttribute : Attribute, ITestCondition +{ + public ValueTask IsMetAsync() => new(CosmosMongoDBTestEnvironment.IsConnectionStringDefined); + + public string Skip { get; set; } = "The Cosmos connection string hasn't been configured (AzureCosmosDBMongoDB:ConnectionString)."; + + public string SkipReason + => this.Skip; +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestEnvironment.cs new file mode 100644 index 000000000000..1adcb225e66d --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestEnvironment.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Configuration; + +namespace CosmosNoSQLIntegrationTests.Support; + +#pragma warning disable CA1810 // Initialize all static fields when those fields are declared + +public static class CosmosMongoDBTestEnvironment +{ + public static readonly string? ConnectionString; + + public static bool IsConnectionStringDefined => ConnectionString is not null; + + static CosmosMongoDBTestEnvironment() + { + var configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true) + .AddJsonFile(path: "testsettings.development.json", optional: true) + .AddEnvironmentVariables() + .Build(); + + ConnectionString = configuration["AzureCosmosDBMongoDB:ConnectionString"]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs new file mode 100644 index 000000000000..b0d4c379ecf4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosMongoDBIntegrationTests/Support/CosmosMongoDBTestStore.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosNoSQLIntegrationTests.Support; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +using MongoDB.Driver; +using VectorDataSpecificationTests.Support; + +namespace MongoDBIntegrationTests.Support; + +public sealed class CosmosMongoDBTestStore : TestStore +{ + public static CosmosMongoDBTestStore Instance { get; } = new(); + + private MongoClient? _client; + private IMongoDatabase? _database; + private AzureCosmosDBMongoDBVectorStore? _defaultVectorStore; + + public MongoClient Client => this._client ?? throw new InvalidOperationException("Not initialized"); + public IMongoDatabase Database => this._database ?? throw new InvalidOperationException("Not initialized"); + + public override IVectorStore DefaultVectorStore + => this._defaultVectorStore ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + public AzureCosmosDBMongoDBVectorStore GetVectorStore(AzureCosmosDBMongoDBVectorStoreOptions options) + => new(this.Database, options); + + private CosmosMongoDBTestStore() + { + } + + protected override Task StartAsync() + { + if (string.IsNullOrWhiteSpace(CosmosMongoDBTestEnvironment.ConnectionString)) + { + throw new InvalidOperationException("Connection string is not configured, set the AzureCosmosDBMongoDB:ConnectionString environment variable"); + } + + this._client = new MongoClient(CosmosMongoDBTestEnvironment.ConnectionString); + this._database = this._client.GetDatabase("VectorSearchTests"); + this._defaultVectorStore = new(this._database); + + return Task.CompletedTask; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLIntegrationTests.csproj new file mode 100644 index 000000000000..dd8e3f7a9ba0 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/CosmosNoSQLIntegrationTests.csproj @@ -0,0 +1,29 @@ + + + + net8.0;net472 + enable + enable + true + false + CosmosNoSQLIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicFilterTests.cs new file mode 100644 index 000000000000..b67141d82e6c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLBasicFilterTests.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; + +namespace CosmosNoSQLIntegrationTests.Filter; + +public class CosmosNoSQLBasicFilterTests(CosmosNoSQLFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture; diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLFilterFixture.cs new file mode 100644 index 000000000000..8aaf6b86d4f9 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Filter/CosmosNoSQLFilterFixture.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using CosmosNoSQLIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace CosmosNoSQLIntegrationTests.Filter; + +public class CosmosNoSQLFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => CosmosNoSqlTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Properties/AssemblyAttributes.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Properties/AssemblyAttributes.cs new file mode 100644 index 000000000000..183a8a7c926c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Properties/AssemblyAttributes.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +[assembly: CosmosNoSQLIntegrationTests.Support.CosmosConnectionStringRequired] diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs new file mode 100644 index 000000000000..2183f166d3ec --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosConnectionStringRequiredAttribute.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Xunit; + +namespace CosmosNoSQLIntegrationTests.Support; + +/// +/// Checks whether the sqlite_vec extension is properly installed, and skips the test(s) otherwise. +/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly)] +public sealed class CosmosConnectionStringRequiredAttribute : Attribute, ITestCondition +{ + public ValueTask IsMetAsync() => new(CosmosNoSQLTestEnvironment.IsConnectionStringDefined); + + public string Skip { get; set; } = "The Cosmos connection string hasn't been configured (AzureCosmosDBNoSQL:ConnectionString)."; + + public string SkipReason + => this.Skip; +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestEnvironment.cs new file mode 100644 index 000000000000..bd2848a2cb8f --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestEnvironment.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Configuration; + +namespace CosmosNoSQLIntegrationTests.Support; + +#pragma warning disable CA1810 // Initialize all static fields when those fields are declared + +internal static class CosmosNoSQLTestEnvironment +{ + public static readonly string? ConnectionString; + + public static bool IsConnectionStringDefined => ConnectionString is not null; + + static CosmosNoSQLTestEnvironment() + { + var configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true) + .AddJsonFile(path: "testsettings.development.json", optional: true) + .AddEnvironmentVariables() + .Build(); + + ConnectionString = configuration["AzureCosmosDBNoSQL:ConnectionString"]; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs new file mode 100644 index 000000000000..7e3269ba2a27 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/CosmosNoSQLIntegrationTests/Support/CosmosNoSQLTestStore.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft. All rights reserved. + +#if NET472 +using System.Net.Http; +#endif +using System.Text.Json; +using Microsoft.Azure.Cosmos; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.AzureCosmosDBNoSQL; +using VectorDataSpecificationTests.Support; + +namespace CosmosNoSQLIntegrationTests.Support; + +#pragma warning disable CA1001 // Type owns disposable fields (_connection) but is not disposable + +internal sealed class CosmosNoSqlTestStore : TestStore +{ + public static CosmosNoSqlTestStore Instance { get; } = new(); + + private CosmosClient? _client; + private Database? _database; + private AzureCosmosDBNoSQLVectorStore? _defaultVectorStore; + + public CosmosClient Client + => this._client ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + public Database Database + => this._database ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + public override IVectorStore DefaultVectorStore + => this._defaultVectorStore ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + public AzureCosmosDBNoSQLVectorStore GetVectorStore(AzureCosmosDBNoSQLVectorStoreOptions options) + => new(this.Database, options); + + private CosmosNoSqlTestStore() + { + } + +#pragma warning disable CA5400 // HttpClient may be created without enabling CheckCertificateRevocationList + protected override async Task StartAsync() + { + var connectionString = CosmosNoSQLTestEnvironment.ConnectionString; + + if (string.IsNullOrWhiteSpace(connectionString)) + { + throw new InvalidOperationException("Connection string is not configured, set the AzureCosmosDBNoSQL:ConnectionString environment variable"); + } + + var options = new CosmosClientOptions + { + UseSystemTextJsonSerializerWithOptions = JsonSerializerOptions.Default, + ConnectionMode = ConnectionMode.Gateway, + HttpClientFactory = () => new HttpClient(new HttpClientHandler { ServerCertificateCustomValidationCallback = HttpClientHandler.DangerousAcceptAnyServerCertificateValidator }) + }; + + this._client = new CosmosClient(connectionString, options); + this._database = this._client.GetDatabase("VectorDataIntegrationTests"); + await this._client.CreateDatabaseIfNotExistsAsync("VectorDataIntegrationTests"); + this._defaultVectorStore = new(this._database); + } +#pragma warning restore CA5400 +} diff --git a/dotnet/src/VectorDataIntegrationTests/Directory.Build.props b/dotnet/src/VectorDataIntegrationTests/Directory.Build.props new file mode 100644 index 000000000000..f5d133b5fd9f --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/Directory.Build.props @@ -0,0 +1,20 @@ + + + + + $(NoWarn);CA1515 + $(NoWarn);CA1707 + $(NoWarn);CA1716 + $(NoWarn);CA1720 + $(NoWarn);CA1861 + $(NoWarn);CA2007;VSTHRD111 + $(NoWarn);CS1591 + $(NoWarn);IDE1006 + + + + + $(NoWarn);CS8604;CS8602 + + + diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicFilterTests.cs new file mode 100644 index 000000000000..32adf75e9017 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryBasicFilterTests.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; + +namespace PostgresIntegrationTests.Filter; + +public class InMemoryBasicFilterTests(InMemoryFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture; diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryFilterFixture.cs new file mode 100644 index 000000000000..7952d1dffad3 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Filter/InMemoryFilterFixture.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using InMemoryIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace PostgresIntegrationTests.Filter; + +public class InMemoryFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => InMemoryTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj new file mode 100644 index 000000000000..f77fff8de939 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/InMemoryIntegrationTests.csproj @@ -0,0 +1,26 @@ + + + + net8.0;net472 + enable + enable + true + false + InMemoryIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryTestStore.cs b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryTestStore.cs new file mode 100644 index 000000000000..246d5166c831 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/InMemoryIntegrationTests/Support/InMemoryTestStore.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.InMemory; +using VectorDataSpecificationTests.Support; + +namespace InMemoryIntegrationTests.Support; + +internal sealed class InMemoryTestStore : TestStore +{ + public static InMemoryTestStore Instance { get; } = new(); + + private InMemoryVectorStore _vectorStore = new(); + + public override IVectorStore DefaultVectorStore => this._vectorStore; + + private InMemoryTestStore() + { + } + + protected override Task StartAsync() + { + this._vectorStore = new InMemoryVectorStore(); + + return Task.CompletedTask; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs new file mode 100644 index 000000000000..a6ad4378f7a1 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBBasicFilterTests.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace MongoDBIntegrationTests.Filter; + +public class MongoDBBasicFilterTests(MongoDBFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + // Specialized MongoDB syntax for NOT over Contains ($nin) + [ConditionalFact] + public virtual Task Not_over_Contains() + => this.TestFilterAsync(r => !new[] { 8, 10 }.Contains(r.Int)); + + #region Null checking + + // MongoDB currently doesn't support null checking ({ "Foo" : null }) in vector search pre-filters + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + #endregion + + #region Not + + // MongoDB currently doesn't support NOT in vector search pre-filters + // (https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter) + public override Task Not_over_And() + => Assert.ThrowsAsync(() => base.Not_over_And()); + + public override Task Not_over_Or() + => Assert.ThrowsAsync(() => base.Not_over_Or()); + + #endregion + + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + // AnyTagEqualTo not (currently) supported on SQLite + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_array() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_array()); + + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_List() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_List()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBFilterFixture.cs new file mode 100644 index 000000000000..8774018ffabf --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Filter/MongoDBFilterFixture.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using MongoDBIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace MongoDBIntegrationTests.Filter; + +public class MongoDBFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => MongoDBTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBIntegrationTests.csproj new file mode 100644 index 000000000000..6aa9923ffaa2 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/MongoDBIntegrationTests.csproj @@ -0,0 +1,27 @@ + + + + net8.0;net472 + enable + enable + true + false + MongoDBIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBTestStore.cs b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBTestStore.cs new file mode 100644 index 000000000000..10ee96b890b6 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/MongoDBIntegrationTests/Support/MongoDBTestStore.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.MongoDB; +using MongoDB.Driver; +using Testcontainers.MongoDb; +using VectorDataSpecificationTests.Support; + +namespace MongoDBIntegrationTests.Support; + +internal sealed class MongoDBTestStore : TestStore +{ + public static MongoDBTestStore Instance { get; } = new(); + + private readonly MongoDbContainer _container = new MongoDbBuilder() + .WithImage("mongodb/mongodb-atlas-local:7.0.6") + .Build(); + + public MongoClient? _client { get; private set; } + public IMongoDatabase? _database { get; private set; } + private MongoDBVectorStore? _defaultVectorStore; + + public MongoClient Client => this._client ?? throw new InvalidOperationException("Not initialized"); + public IMongoDatabase Database => this._database ?? throw new InvalidOperationException("Not initialized"); + + public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + + public MongoDBVectorStore GetVectorStore(MongoDBVectorStoreOptions options) + => new(this.Database, options); + + private MongoDBTestStore() + { + } + + protected override async Task StartAsync() + { + await this._container.StartAsync(); + + this._client = new MongoClient(new MongoClientSettings + { + Server = new MongoServerAddress(this._container.Hostname, this._container.GetMappedPublicPort(MongoDbBuilder.MongoDbPort)), + DirectConnection = true, + // ReadConcern = ReadConcern.Linearizable, + // WriteConcern = WriteConcern.WMajority + }); + this._database = this._client.GetDatabase("VectorSearchTests"); + this._defaultVectorStore = new(this._database); + } + + protected override Task StopAsync() + => this._container.StopAsync(); +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicFilterTests.cs new file mode 100644 index 000000000000..4fad76458700 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresBasicFilterTests.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; +using Xunit.Sdk; + +namespace PostgresIntegrationTests.Filter; + +public class PostgresBasicFilterTests(PostgresFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + public override async Task Not_over_Or() + { + // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) + // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. + await Assert.ThrowsAsync(() => base.Not_over_Or()); + + // Compensate by adding a null check: + await this.TestFilterAsync(r => r.String != null && !(r.Int == 8 || r.String == "foo")); + } + + public override async Task NotEqual_with_string() + { + // As above, null semantics + negation + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + + await this.TestFilterAsync(r => r.String != null && r.String != "foo"); + } + + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_array() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_array()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresFilterFixture.cs new file mode 100644 index 000000000000..c65b37177003 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Filter/PostgresFilterFixture.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PostgresIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace PostgresIntegrationTests.Filter; + +public class PostgresFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => PostgresTestStore.Instance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj new file mode 100644 index 000000000000..0a039793dc49 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/PostgresIntegrationTests.csproj @@ -0,0 +1,27 @@ + + + + net8.0;net472 + enable + enable + true + false + PostgresIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs new file mode 100644 index 000000000000..1d4c540c216a --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/PostgresIntegrationTests/Support/PostgresTestStore.cs @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Postgres; +using Npgsql; +using Testcontainers.PostgreSql; +using VectorDataSpecificationTests.Support; + +namespace PostgresIntegrationTests.Support; + +#pragma warning disable SKEXP0020 + +internal sealed class PostgresTestStore : TestStore +{ + public static PostgresTestStore Instance { get; } = new(); + + private static readonly PostgreSqlContainer s_container = new PostgreSqlBuilder() + .WithImage("pgvector/pgvector:pg16") + .Build(); + + private NpgsqlDataSource? _dataSource; + private PostgresVectorStore? _defaultVectorStore; + + public NpgsqlDataSource DataSource => this._dataSource ?? throw new InvalidOperationException("Not initialized"); + + public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + + public PostgresVectorStore GetVectorStore(PostgresVectorStoreOptions options) + => new(this.DataSource, options); + + private PostgresTestStore() + { + } + + protected override async Task StartAsync() + { + await s_container.StartAsync(); + + var dataSourceBuilder = new NpgsqlDataSourceBuilder + { + ConnectionStringBuilder = + { + Host = s_container.Hostname, + Port = s_container.GetMappedPublicPort(5432), + Username = PostgreSqlBuilder.DefaultUsername, + Password = PostgreSqlBuilder.DefaultPassword, + Database = PostgreSqlBuilder.DefaultDatabase + } + }; + + dataSourceBuilder.UseVector(); + + this._dataSource = dataSourceBuilder.Build(); + + await using var connection = this._dataSource.CreateConnection(); + await connection.OpenAsync(); + using var command = new NpgsqlCommand("CREATE EXTENSION IF NOT EXISTS vector", connection); + await command.ExecuteNonQueryAsync(); + await connection.ReloadTypesAsync(); + + this._defaultVectorStore = new(this._dataSource); + } + + protected override async Task StopAsync() + { + await this._dataSource!.DisposeAsync(); + await s_container.StopAsync(); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicFilterTests.cs new file mode 100644 index 000000000000..11593833dddf --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantBasicFilterTests.cs @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; + +namespace QdrantIntegrationTests.Filter; + +public class QdrantBasicFilterTests(QdrantFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture; diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantFilterFixture.cs new file mode 100644 index 000000000000..8c8a6528b4f8 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Filter/QdrantFilterFixture.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft. All rights reserved. + +using QdrantIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace QdrantIntegrationTests.Filter; + +public class QdrantFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => QdrantTestStore.Instance; + + // Qdrant doesn't support the default Flat index kind + protected override string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Hnsw; +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantIntegrationTests.csproj new file mode 100644 index 000000000000..0ea8db51c21d --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/QdrantIntegrationTests.csproj @@ -0,0 +1,27 @@ + + + + net8.0;net472 + enable + enable + true + false + QdrantIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs new file mode 100644 index 000000000000..3537cf8c64e9 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/QdrantTestStore.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Qdrant; +using Qdrant.Client; +using QdrantIntegrationTests.Support.TestContainer; +using VectorDataSpecificationTests.Support; + +namespace QdrantIntegrationTests.Support; + +#pragma warning disable CA1001 // Type owns disposable fields but is not disposable + +internal sealed class QdrantTestStore : TestStore +{ + public static QdrantTestStore Instance { get; } = new(); + + private readonly QdrantContainer _container = new QdrantBuilder().Build(); + private QdrantClient? _client; + private QdrantVectorStore? _defaultVectorStore; + + public QdrantClient Client => this._client ?? throw new InvalidOperationException("Not initialized"); + + public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + + public QdrantVectorStore GetVectorStore(QdrantVectorStoreOptions options) + => new(this.Client, options); + + private QdrantTestStore() + { + } + + protected override async Task StartAsync() + { + await this._container.StartAsync(); + this._client = new QdrantClient(this._container.Hostname, this._container.GetMappedPublicPort(QdrantBuilder.QdrantGrpcPort)); + this._defaultVectorStore = new(this._client); + } + + protected override Task StopAsync() + => this._container.StopAsync(); +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantBuilder.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantBuilder.cs new file mode 100644 index 000000000000..a3444a9f0ee5 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantBuilder.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet.Models; +using DotNet.Testcontainers.Builders; +using DotNet.Testcontainers.Configurations; +using Qdrant.Client.Grpc; + +namespace QdrantIntegrationTests.Support.TestContainer; + +public sealed class QdrantBuilder : ContainerBuilder +{ + public const string QdrantImage = "qdrant/qdrant:" + QdrantGrpcClient.QdrantVersion; + + public const ushort QdrantHttpPort = 6333; + + public const ushort QdrantGrpcPort = 6334; + + public QdrantBuilder() : this(new QdrantConfiguration()) => this.DockerResourceConfiguration = this.Init().DockerResourceConfiguration; + + private QdrantBuilder(QdrantConfiguration dockerResourceConfiguration) : base(dockerResourceConfiguration) + => this.DockerResourceConfiguration = dockerResourceConfiguration; + + public QdrantBuilder WithConfigFile(string configPath) + => this.Merge(this.DockerResourceConfiguration, new QdrantConfiguration()) + .WithBindMount(configPath, "/qdrant/config/custom_config.yaml"); + + public QdrantBuilder WithCertificate(string certPath, string keyPath) + => this.Merge(this.DockerResourceConfiguration, new QdrantConfiguration()) + .WithBindMount(certPath, "/qdrant/tls/cert.pem") + .WithBindMount(keyPath, "/qdrant/tls/key.pem"); + + public override QdrantContainer Build() + { + this.Validate(); + return new QdrantContainer(this.DockerResourceConfiguration); + } + + protected override QdrantBuilder Init() + => base.Init() + .WithImage(QdrantImage) + .WithPortBinding(QdrantHttpPort, true) + .WithPortBinding(QdrantGrpcPort, true) + .WithWaitStrategy(Wait.ForUnixContainer() + .UntilMessageIsLogged(".*Actix runtime found; starting in Actix runtime.*")); + + protected override QdrantBuilder Clone(IResourceConfiguration resourceConfiguration) + => this.Merge(this.DockerResourceConfiguration, new QdrantConfiguration(resourceConfiguration)); + + protected override QdrantBuilder Merge(QdrantConfiguration oldValue, QdrantConfiguration newValue) + => new(new QdrantConfiguration(oldValue, newValue)); + + protected override QdrantConfiguration DockerResourceConfiguration { get; } + + protected override QdrantBuilder Clone(IContainerConfiguration resourceConfiguration) + => this.Merge(this.DockerResourceConfiguration, new QdrantConfiguration(resourceConfiguration)); +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantConfiguration.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantConfiguration.cs new file mode 100644 index 000000000000..219e4030c581 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantConfiguration.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet.Models; +using DotNet.Testcontainers.Configurations; + +namespace QdrantIntegrationTests.Support.TestContainer; + +public sealed class QdrantConfiguration : ContainerConfiguration +{ + /// + /// Initializes a new instance of the class. + /// + public QdrantConfiguration() + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The Docker resource configuration. + public QdrantConfiguration(IResourceConfiguration resourceConfiguration) + : base(resourceConfiguration) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The Docker resource configuration. + public QdrantConfiguration(IContainerConfiguration resourceConfiguration) + : base(resourceConfiguration) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The Docker resource configuration. + public QdrantConfiguration(QdrantConfiguration resourceConfiguration) + : this(new QdrantConfiguration(), resourceConfiguration) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The old Docker resource configuration. + /// The new Docker resource configuration. + public QdrantConfiguration(QdrantConfiguration oldValue, QdrantConfiguration newValue) + : base(oldValue, newValue) + { + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantContainer.cs b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantContainer.cs new file mode 100644 index 000000000000..f9c1ab05f1cc --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/QdrantIntegrationTests/Support/TestContainer/QdrantContainer.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft. All rights reserved. + +using DotNet.Testcontainers.Containers; + +namespace QdrantIntegrationTests.Support.TestContainer; + +public class QdrantContainer(QdrantConfiguration configuration) : DockerContainer(configuration); diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs new file mode 100644 index 000000000000..d0017e3a510c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisBasicFilterTests.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; +using Xunit.Sdk; + +namespace RedisIntegrationTests.Filter; + +public abstract class RedisBasicFilterTests(FilterFixtureBase fixture) : BasicFilterTestsBase(fixture) +{ + #region Equality with null + + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + #endregion + + #region Bool + + public override Task Bool() + => Assert.ThrowsAsync(() => base.Bool()); + + public override Task Not_over_bool() + => Assert.ThrowsAsync(() => base.Not_over_bool()); + + #endregion + + #region Contains + + public override Task Contains_over_inline_int_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_int_array()); + + public override Task Contains_over_inline_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_string_array()); + + public override Task Contains_over_inline_string_array_with_weird_chars() + => Assert.ThrowsAsync(() => base.Contains_over_inline_string_array_with_weird_chars()); + + public override Task Contains_over_captured_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_captured_string_array()); + + #endregion +} + +public class RedisJsonCollectionBasicFilterTests(RedisJsonCollectionFilterFixture fixture) : RedisBasicFilterTests(fixture), IClassFixture; + +public class RedisHashSetCollectionBasicFilterTests(RedisHashSetCollectionFilterFixture fixture) : RedisBasicFilterTests(fixture), IClassFixture +{ + // Null values are not supported in Redis HashSet + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + // Array fields not supported on Redis HashSet + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_array() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_array()); + + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_List() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_List()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisFilterFixture.cs new file mode 100644 index 000000000000..de751f36ca4e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Filter/RedisFilterFixture.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Redis; +using RedisIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace RedisIntegrationTests.Filter; + +public class RedisJsonCollectionFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => RedisTestStore.Instance; + + protected override string StoreName => "JsonCollectionFilterTests"; + + // Override to remove the bool property, which isn't (currently) supported on Redis/JSON + protected override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(bool)).ToList() + }; + + protected override IVectorStoreRecordCollection> CreateCollection() + => new RedisJsonVectorStoreRecordCollection>( + RedisTestStore.Instance.Database, + this.StoreName, + new() { VectorStoreRecordDefinition = this.GetRecordDefinition() }); +} + +public class RedisHashSetCollectionFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => RedisTestStore.Instance; + + protected override string StoreName => "HashSetCollectionFilterTests"; + + // Override to remove the bool property, which isn't (currently) supported on Redis + protected override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = base.GetRecordDefinition().Properties.Where(p => + p.PropertyType != typeof(bool) && + p.PropertyType != typeof(string[]) && + p.PropertyType != typeof(List)).ToList() + }; + + protected override IVectorStoreRecordCollection> CreateCollection() + => new RedisHashSetVectorStoreRecordCollection>( + RedisTestStore.Instance.Database, + this.StoreName, + new() { VectorStoreRecordDefinition = this.GetRecordDefinition() }); + + protected override List> BuildTestData() + { + var testData = base.BuildTestData(); + + foreach (var record in testData) + { + // Null values are not supported in Redis hashsets + record.String ??= string.Empty; + } + + return testData; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisIntegrationTests.csproj new file mode 100644 index 000000000000..5727b3b2650a --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/RedisIntegrationTests.csproj @@ -0,0 +1,27 @@ + + + + net8.0;net472 + enable + enable + true + false + RedisIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs new file mode 100644 index 000000000000..a1dd2f02c0bc --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/RedisIntegrationTests/Support/RedisTestStore.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Redis; +using StackExchange.Redis; +using Testcontainers.Redis; +using VectorDataSpecificationTests.Support; + +namespace RedisIntegrationTests.Support; + +internal sealed class RedisTestStore : TestStore +{ + public static RedisTestStore Instance { get; } = new(); + + private readonly RedisContainer _container = new RedisBuilder() + .WithImage("redis/redis-stack") + .Build(); + + private IDatabase? _database; + private RedisVectorStore? _defaultVectorStore; + + public IDatabase Database => this._database ?? throw new InvalidOperationException("Not initialized"); + + public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + + public RedisVectorStore GetVectorStore(RedisVectorStoreOptions options) + => new(this.Database, options); + + private RedisTestStore() + { + } + + protected override async Task StartAsync() + { + await this._container.StartAsync(); + var redis = await ConnectionMultiplexer.ConnectAsync($"{this._container.Hostname}:{this._container.GetMappedPublicPort(6379)},connectTimeout=60000,connectRetry=5"); + this._database = redis.GetDatabase(); + this._defaultVectorStore = new(this._database); + } + + protected override Task StopAsync() + => this._container.StopAsync(); +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicFilterTests.cs new file mode 100644 index 000000000000..9ca7878a414e --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteBasicFilterTests.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using Xunit; +using Xunit.Sdk; + +namespace SqliteIntegrationTests.Filter; + +public class SqliteBasicFilterTests(SqliteFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + public override async Task Not_over_Or() + { + // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) + // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. + await Assert.ThrowsAsync(() => base.Not_over_Or()); + + // Compensate by adding a null check: + await this.TestFilterAsync(r => r.String != null && !(r.Int == 8 || r.String == "foo")); + } + + public override async Task NotEqual_with_string() + { + // As above, null semantics + negation + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + + await this.TestFilterAsync(r => r.String != null && r.String != "foo"); + } + + // Array fields not (currently) supported on SQLite (see #10343) + public override Task Contains_over_field_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_array()); + + // List fields not (currently) supported on SQLite (see #10343) + public override Task Contains_over_field_string_List() + => Assert.ThrowsAsync(() => base.Contains_over_field_string_List()); + + // AnyTagEqualTo not (currently) supported on SQLite + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_array() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_array()); + + [Obsolete("Legacy filter support")] + public override Task Legacy_AnyTagEqualTo_List() + => Assert.ThrowsAsync(() => base.Legacy_AnyTagEqualTo_List()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteFilterFixture.cs new file mode 100644 index 000000000000..3dc9a0d10dad --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Filter/SqliteFilterFixture.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using SqliteIntegrationTests.Support; +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; + +namespace SqliteIntegrationTests.Filter; + +public class SqliteFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => SqliteTestStore.Instance; + + protected override string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; + + // Override to remove the string array property, which isn't (currently) supported on SQLite + protected override VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = base.GetRecordDefinition().Properties.Where(p => p.PropertyType != typeof(string[]) && p.PropertyType != typeof(List)).ToList() + }; +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Properties/AssemblyAttributes.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Properties/AssemblyAttributes.cs new file mode 100644 index 000000000000..89ee1c5e6025 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Properties/AssemblyAttributes.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +[assembly: SqliteIntegrationTests.Support.SqliteVecRequired] diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteIntegrationTests.csproj new file mode 100644 index 000000000000..a47480e526cd --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/SqliteIntegrationTests.csproj @@ -0,0 +1,26 @@ + + + + net8.0;net472 + enable + enable + true + false + SqliteIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestEnvironment.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestEnvironment.cs new file mode 100644 index 000000000000..e7dd76fb76fc --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestEnvironment.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Data; +using Microsoft.Data.Sqlite; + +namespace SqliteIntegrationTests.Support; + +internal static class SqliteTestEnvironment +{ + /// + /// SQLite extension name for vector search. + /// More information here: . + /// + private const string VectorSearchExtensionName = "vec0"; + + private static bool? s_isSqliteVecInstalled; + + internal static bool TryLoadSqliteVec(SqliteConnection connection) + { + if (!s_isSqliteVecInstalled.HasValue) + { + if (connection.State != ConnectionState.Open) + { + throw new ArgumentException("Connection must be open"); + } + + try + { + connection.LoadExtension(VectorSearchExtensionName); + s_isSqliteVecInstalled = true; + } + catch (SqliteException) + { + s_isSqliteVecInstalled = false; + } + } + + return s_isSqliteVecInstalled.Value; + } + + internal static bool IsSqliteVecInstalled + { + get + { + if (!s_isSqliteVecInstalled.HasValue) + { + using var connection = new SqliteConnection("Data Source=:memory:;"); + connection.Open(); + + s_isSqliteVecInstalled = TryLoadSqliteVec(connection); + } + + return s_isSqliteVecInstalled.Value; + } + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs new file mode 100644 index 000000000000..526eeac3b2d8 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteTestStore.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Data.Sqlite; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Sqlite; +using VectorDataSpecificationTests.Support; + +namespace SqliteIntegrationTests.Support; + +#pragma warning disable CA1001 // Type owns disposable fields (_connection) but is not disposable + +internal sealed class SqliteTestStore : TestStore +{ + public static SqliteTestStore Instance { get; } = new(); + + private SqliteConnection? _connection; + public SqliteConnection Connection + => this._connection ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + private SqliteVectorStore? _defaultVectorStore; + public override IVectorStore DefaultVectorStore + => this._defaultVectorStore ?? throw new InvalidOperationException("Call InitializeAsync() first"); + + private SqliteTestStore() + { + } + + protected override async Task StartAsync() + { + this._connection = new SqliteConnection("Data Source=:memory:"); + + await this.Connection.OpenAsync(); + + if (!SqliteTestEnvironment.TryLoadSqliteVec(this.Connection)) + { + this.Connection.Dispose(); + + // Note that we ignore sqlite_vec loading failures; the tests are decorated with [SqliteVecRequired], which causes + // them to be skipped if sqlite_vec isn't installed (better than an exception triggering failure here) + } + + this._defaultVectorStore = new SqliteVectorStore(this.Connection); + } + +#if NET8_0_OR_GREATER + protected override async Task StopAsync() + => await this.Connection.DisposeAsync(); +#else + protected override Task StopAsync() + { + this.Connection.Dispose(); + return Task.CompletedTask; + } +#endif +} diff --git a/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteVecRequiredAttribute.cs b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteVecRequiredAttribute.cs new file mode 100644 index 000000000000..9351fd679171 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/SqliteIntegrationTests/Support/SqliteVecRequiredAttribute.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Xunit; + +namespace SqliteIntegrationTests.Support; + +/// +/// Checks whether the sqlite_vec extension is properly installed, and skips the test(s) otherwise. +/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class | AttributeTargets.Assembly)] +public sealed class SqliteVecRequiredAttribute : Attribute, ITestCondition +{ + public ValueTask IsMetAsync() => new(SqliteTestEnvironment.IsSqliteVecInstalled); + + public string Skip { get; set; } = "The sqlite_vec extension is not installed."; + + public string SkipReason + => this.Skip; +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTestsBase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTestsBase.cs new file mode 100644 index 000000000000..f2022a2e7c60 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/BasicFilterTestsBase.cs @@ -0,0 +1,283 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq.Expressions; +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Xunit; +using Xunit; + +namespace VectorDataSpecificationTests.Filter; + +public abstract class BasicFilterTestsBase(FilterFixtureBase fixture) + where TKey : notnull +{ + #region Equality + + [ConditionalFact] + public virtual Task Equal_with_int() + => this.TestFilterAsync(r => r.Int == 8); + + [ConditionalFact] + public virtual Task Equal_with_string() + => this.TestFilterAsync(r => r.String == "foo"); + + [ConditionalFact] + public virtual Task Equal_with_string_containing_special_characters() + => this.TestFilterAsync(r => r.String == """with some special"characters'and\stuff"""); + + [ConditionalFact] + public virtual Task Equal_with_string_is_not_Contains() + => this.TestFilterAsync(r => r.String == "some", expectZeroResults: true); + + [ConditionalFact] + public virtual Task Equal_reversed() + => this.TestFilterAsync(r => 8 == r.Int); + + [ConditionalFact] + public virtual Task Equal_with_null_reference_type() + => this.TestFilterAsync(r => r.String == null); + + [ConditionalFact] + public virtual Task Equal_with_null_captured() + { + string? s = null; + + return this.TestFilterAsync(r => r.String == s); + } + + [ConditionalFact] + public virtual Task NotEqual_with_int() + => this.TestFilterAsync(r => r.Int != 8); + + [ConditionalFact] + public virtual Task NotEqual_with_string() + => this.TestFilterAsync(r => r.String != "foo"); + + [ConditionalFact] + public virtual Task NotEqual_reversed() + => this.TestFilterAsync(r => r.Int != 8); + + [ConditionalFact] + public virtual Task NotEqual_with_null_reference_type() + => this.TestFilterAsync(r => r.String != null); + + [ConditionalFact] + public virtual Task NotEqual_with_null_captured() + { + string? s = null; + + return this.TestFilterAsync(r => r.String != s); + } + + [ConditionalFact] + public virtual Task Bool() + => this.TestFilterAsync(r => r.Bool); + + #endregion Equality + + #region Comparison + + [ConditionalFact] + public virtual Task GreaterThan_with_int() + => this.TestFilterAsync(r => r.Int > 9); + + [ConditionalFact] + public virtual Task GreaterThanOrEqual_with_int() + => this.TestFilterAsync(r => r.Int >= 9); + + [ConditionalFact] + public virtual Task LessThan_with_int() + => this.TestFilterAsync(r => r.Int < 10); + + [ConditionalFact] + public virtual Task LessThanOrEqual_with_int() + => this.TestFilterAsync(r => r.Int <= 10); + + #endregion Comparison + + #region Logical operators + + [ConditionalFact] + public virtual Task And() + => this.TestFilterAsync(r => r.Int == 8 && r.String == "foo"); + + [ConditionalFact] + public virtual Task Or() + => this.TestFilterAsync(r => r.Int == 8 || r.String == "foo"); + + [ConditionalFact] + public virtual Task And_within_And() + => this.TestFilterAsync(r => (r.Int == 8 && r.String == "foo") && r.Int2 == 80); + + [ConditionalFact] + public virtual Task And_within_Or() + => this.TestFilterAsync(r => (r.Int == 8 && r.String == "foo") || r.Int2 == 100); + + [ConditionalFact] + public virtual Task Or_within_And() + => this.TestFilterAsync(r => (r.Int == 8 || r.Int == 9) && r.String == "foo"); + + [ConditionalFact] + public virtual Task Not_over_Equal() + // ReSharper disable once NegativeEqualityExpression + => this.TestFilterAsync(r => !(r.Int == 8)); + + [ConditionalFact] + public virtual Task Not_over_NotEqual() + // ReSharper disable once NegativeEqualityExpression + => this.TestFilterAsync(r => !(r.Int != 8)); + + [ConditionalFact] + public virtual Task Not_over_And() + => this.TestFilterAsync(r => !(r.Int == 8 && r.String == "foo")); + + [ConditionalFact] + public virtual Task Not_over_Or() + => this.TestFilterAsync(r => !(r.Int == 8 || r.String == "foo")); + + [ConditionalFact] + public virtual Task Not_over_bool() + => this.TestFilterAsync(r => !r.Bool); + + #endregion Logical operators + + #region Contains + + [ConditionalFact] + public virtual Task Contains_over_field_string_array() + => this.TestFilterAsync(r => r.StringArray.Contains("x")); + + [ConditionalFact] + public virtual Task Contains_over_field_string_List() + => this.TestFilterAsync(r => r.StringList.Contains("x")); + + [ConditionalFact] + public virtual Task Contains_over_inline_int_array() + => this.TestFilterAsync(r => new[] { 8, 10 }.Contains(r.Int)); + + [ConditionalFact] + public virtual Task Contains_over_inline_string_array() + => this.TestFilterAsync(r => new[] { "foo", "baz", "unknown" }.Contains(r.String)); + + [ConditionalFact] + public virtual Task Contains_over_inline_string_array_with_weird_chars() + => this.TestFilterAsync(r => new[] { "foo", "baz", "un , ' \"" }.Contains(r.String)); + + [ConditionalFact] + public virtual Task Contains_over_captured_string_array() + { + var array = new[] { "foo", "baz", "unknown" }; + + return this.TestFilterAsync(r => array.Contains(r.String)); + } + + #endregion Contains + + [ConditionalFact] + public virtual Task Captured_variable() + { + // ReSharper disable once ConvertToConstant.Local + var i = 8; + + return this.TestFilterAsync(r => r.Int == i); + } + + #region Legacy filter support + + [ConditionalFact] + [Obsolete("Legacy filter support")] + public virtual Task Legacy_equality() + => this.TestLegacyFilterAsync( + new VectorSearchFilter().EqualTo("Int", 8), + r => r.Int == 8); + + [ConditionalFact] + [Obsolete("Legacy filter support")] + public virtual Task Legacy_And() + => this.TestLegacyFilterAsync( + new VectorSearchFilter().EqualTo("Int", 8).EqualTo("String", "foo"), + r => r.Int == 8); + + [ConditionalFact] + [Obsolete("Legacy filter support")] + public virtual Task Legacy_AnyTagEqualTo_array() + => this.TestLegacyFilterAsync( + new VectorSearchFilter().AnyTagEqualTo("StringArray", "x"), + r => r.StringArray.Contains("x")); + + [ConditionalFact] + [Obsolete("Legacy filter support")] + public virtual Task Legacy_AnyTagEqualTo_List() + => this.TestLegacyFilterAsync( + new VectorSearchFilter().AnyTagEqualTo("StringList", "x"), + r => r.StringArray.Contains("x")); + + #endregion Legacy filter support + + protected virtual async Task TestFilterAsync( + Expression, bool>> filter, + bool expectZeroResults = false, + bool expectAllResults = false) + { + var expected = fixture.TestData.AsQueryable().Where(filter).OrderBy(r => r.Key).ToList(); + + if (expected.Count == 0 && !expectZeroResults) + { + Assert.Fail("The test returns zero results, and so is unreliable"); + } + + if (expected.Count == fixture.TestData.Count && !expectAllResults) + { + Assert.Fail("The test returns all results, and so is unreliable"); + } + + var results = await fixture.Collection.VectorizedSearchAsync( + new ReadOnlyMemory([1, 2, 3]), + new() + { + NewFilter = filter, + Top = fixture.TestData.Count + }); + + var actual = await results.Results.Select(r => r.Record).OrderBy(r => r.Key).ToListAsync(); + + Assert.Equal(expected, actual, (e, a) => + e.Int == a.Int && + e.String == a.String && + e.Int2 == a.Int2); + } + + [Obsolete("Legacy filter support")] + protected virtual async Task TestLegacyFilterAsync( + VectorSearchFilter legacyFilter, + Expression, bool>> expectedFilter, + bool expectZeroResults = false, + bool expectAllResults = false) + { + var expected = fixture.TestData.AsQueryable().Where(expectedFilter).OrderBy(r => r.Key).ToList(); + + if (expected.Count == 0 && !expectZeroResults) + { + Assert.Fail("The test returns zero results, and so is unreliable"); + } + + if (expected.Count == fixture.TestData.Count && !expectAllResults) + { + Assert.Fail("The test returns all results, and so is unreliable"); + } + + var results = await fixture.Collection.VectorizedSearchAsync( + new ReadOnlyMemory([1, 2, 3]), + new() + { + Filter = legacyFilter, + Top = fixture.TestData.Count + }); + + var actual = await results.Results.Select(r => r.Record).OrderBy(r => r.Key).ToListAsync(); + + Assert.Equal(expected, actual, (e, a) => + e.Int == a.Int && + e.String == a.String && + e.Int2 == a.Int2); + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs new file mode 100644 index 000000000000..436d1453d552 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Filter/FilterFixtureBase.cs @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Globalization; +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Support; +using Xunit; + +namespace VectorDataSpecificationTests.Filter; + +public abstract class FilterFixtureBase : IAsyncLifetime + where TKey : notnull +{ + private int _nextKeyValue = 1; + private List>? _testData; + + protected virtual string StoreName => "FilterTests"; + + protected abstract TestStore TestStore { get; } + + protected virtual string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineSimilarity; + protected virtual string IndexKind => Microsoft.Extensions.VectorData.IndexKind.Flat; + + protected virtual IVectorStoreRecordCollection> CreateCollection() + => this.TestStore.DefaultVectorStore.GetCollection>(this.StoreName, this.GetRecordDefinition()); + + public virtual async Task InitializeAsync() + { + await this.TestStore.ReferenceCountingStartAsync(); + + this.Collection = this.CreateCollection(); + + if (await this.Collection.CollectionExistsAsync()) + { + await this.Collection.DeleteCollectionAsync(); + } + + await this.Collection.CreateCollectionAsync(); + await this.SeedAsync(); + + // Some databases upsert asynchronously, meaning that our seed data may not be visible immediately to tests. + // Check and loop until it is. + for (var i = 0; i < 20; i++) + { + var results = await this.Collection.VectorizedSearchAsync( + new ReadOnlyMemory([1, 2, 3]), + new() + { + Top = this.TestData.Count, + NewFilter = r => r.Int > 0 + }); + var count = await results.Results.CountAsync(); + if (count == this.TestData.Count) + { + break; + } + + await Task.Delay(TimeSpan.FromMilliseconds(100)); + } + } + + protected virtual VectorStoreRecordDefinition GetRecordDefinition() + => new() + { + Properties = + [ + new VectorStoreRecordKeyProperty(nameof(FilterRecord.Key), typeof(TKey)), + new VectorStoreRecordVectorProperty(nameof(FilterRecord.Vector), typeof(ReadOnlyMemory?)) + { + Dimensions = 3, + DistanceFunction = this.DistanceFunction, + IndexKind = this.IndexKind + }, + + new VectorStoreRecordDataProperty(nameof(FilterRecord.Int), typeof(int)) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.String), typeof(string)) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.Bool), typeof(bool)) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.Int2), typeof(int)) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.StringArray), typeof(string[])) { IsFilterable = true }, + new VectorStoreRecordDataProperty(nameof(FilterRecord.StringList), typeof(List)) { IsFilterable = true } + ] + }; + + public virtual IVectorStoreRecordCollection> Collection { get; private set; } = null!; + + public List> TestData => this._testData ??= this.BuildTestData(); + + protected virtual List> BuildTestData() + { + // All records have the same vector - this fixture is about testing criteria filtering only + var vector = new ReadOnlyMemory([1, 2, 3]); + + return + [ + new() + { + Key = this.GenerateNextKey(), + Int = 8, + String = "foo", + Bool = true, + Int2 = 80, + StringArray = ["x", "y"], + StringList = ["x", "y"], + Vector = vector + }, + new() + { + Key = this.GenerateNextKey(), + Int = 9, + String = "bar", + Bool = false, + Int2 = 90, + StringArray = ["a", "b"], + StringList = ["a", "b"], + Vector = vector + }, + new() + { + Key = this.GenerateNextKey(), + Int = 9, + String = "foo", + Bool = true, + Int2 = 9, + StringArray = ["x"], + StringList = ["x"], + Vector = vector + }, + new() + { + Key = this.GenerateNextKey(), + Int = 10, + String = null, + Bool = false, + Int2 = 100, + StringArray = ["x", "y", "z"], + StringList = ["x", "y", "z"], + Vector = vector + }, + new() + { + Key = this.GenerateNextKey(), + Int = 11, + Bool = true, + String = """with some special"characters'and\stuff""", + Int2 = 101, + StringArray = ["y", "z"], + StringList = ["y", "z"], + Vector = vector + } + ]; + } + + protected virtual async Task SeedAsync() + { + // TODO: UpsertBatchAsync returns IAsyncEnumerable (to support server-generated keys?), but this makes it quite hard to use: + await foreach (var _ in this.Collection.UpsertBatchAsync(this.TestData)) + { + } + } + + protected virtual TKey GenerateNextKey() + => typeof(TKey) switch + { + _ when typeof(TKey) == typeof(int) => (TKey)(object)this._nextKeyValue++, + _ when typeof(TKey) == typeof(long) => (TKey)(object)(long)this._nextKeyValue++, + _ when typeof(TKey) == typeof(ulong) => (TKey)(object)(ulong)this._nextKeyValue++, + _ when typeof(TKey) == typeof(string) => (TKey)(object)(this._nextKeyValue++).ToString(CultureInfo.InvariantCulture), + _ when typeof(TKey) == typeof(Guid) => (TKey)(object)new Guid($"00000000-0000-0000-0000-00{this._nextKeyValue++:0000000000}"), + + _ => throw new NotSupportedException($"Unsupported key of type '{typeof(TKey).Name}', override {nameof(this.GenerateNextKey)}") + }; + + public virtual Task DisposeAsync() + => this.TestStore.ReferenceCountingStopAsync(); +} + +#pragma warning disable CS1819 // Properties should not return arrays +#pragma warning disable CA1819 // Properties should not return arrays +public class FilterRecord +{ + public TKey Key { get; set; } = default!; + public ReadOnlyMemory? Vector { get; set; } + + public int Int { get; set; } + public string? String { get; set; } + public bool Bool { get; set; } + public int Int2 { get; set; } + public string[] StringArray { get; set; } = null!; + public List StringList { get; set; } = null!; +} +#pragma warning restore CA1819 // Properties should not return arrays +#pragma warning restore CS1819 diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/TestStore.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/TestStore.cs new file mode 100644 index 000000000000..de7c0d252062 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Support/TestStore.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; + +namespace VectorDataSpecificationTests.Support; + +#pragma warning disable CA1001 // Type owns disposable fields but is not disposable + +public abstract class TestStore +{ + private readonly SemaphoreSlim _lock = new(1, 1); + private int _referenceCount; + + protected abstract Task StartAsync(); + + protected virtual Task StopAsync() + => Task.CompletedTask; + + public virtual async Task ReferenceCountingStartAsync() + { + await this._lock.WaitAsync(); + try + { + if (this._referenceCount++ == 0) + { + await this.StartAsync(); + } + } + finally + { + this._lock.Release(); + } + } + + public virtual async Task ReferenceCountingStopAsync() + { + await this._lock.WaitAsync(); + try + { + if (--this._referenceCount == 0) + { + await this.StopAsync(); + } + } + finally + { + this._lock.Release(); + } + } + + public abstract IVectorStore DefaultVectorStore { get; } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj new file mode 100644 index 000000000000..77fc8e90dbb2 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/VectorDataIntegrationTests.csproj @@ -0,0 +1,24 @@ + + + + net8.0;net472 + enable + enable + false + VectorDataSpecificationTests + + + + + + + + + + + + + + + + diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactAttribute.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactAttribute.cs new file mode 100644 index 000000000000..d4d93c8b5035 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactAttribute.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit; +using Xunit.Sdk; + +namespace VectorDataSpecificationTests.Xunit; + +[AttributeUsage(AttributeTargets.Method)] +[XunitTestCaseDiscoverer("VectorDataSpecificationTests.Xunit.ConditionalFactDiscoverer", "VectorDataIntegrationTests")] +public sealed class ConditionalFactAttribute : FactAttribute; diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactDiscoverer.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactDiscoverer.cs new file mode 100644 index 000000000000..1fbeafd3dd1c --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactDiscoverer.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace VectorDataSpecificationTests.Xunit; + +/// +/// Used dynamically from . +/// Make sure to update that class if you move this type. +/// +public class ConditionalFactDiscoverer(IMessageSink messageSink) : FactDiscoverer(messageSink) +{ + protected override IXunitTestCase CreateTestCase( + ITestFrameworkDiscoveryOptions discoveryOptions, + ITestMethod testMethod, + IAttributeInfo factAttribute) + => new ConditionalFactTestCase( + this.DiagnosticMessageSink, + discoveryOptions.MethodDisplayOrDefault(), + discoveryOptions.MethodDisplayOptionsOrDefault(), + testMethod); +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactTestCase.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactTestCase.cs new file mode 100644 index 000000000000..3dea216a1084 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalFactTestCase.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace VectorDataSpecificationTests.Xunit; + +public sealed class ConditionalFactTestCase : XunitTestCase +{ + [Obsolete("Called by the de-serializer; should only be called by deriving classes for de-serialization purposes")] + public ConditionalFactTestCase() + { + } + + public ConditionalFactTestCase( + IMessageSink diagnosticMessageSink, + TestMethodDisplay defaultMethodDisplay, + TestMethodDisplayOptions defaultMethodDisplayOptions, + ITestMethod testMethod, + object[]? testMethodArguments = null) + : base(diagnosticMessageSink, defaultMethodDisplay, defaultMethodDisplayOptions, testMethod, testMethodArguments) + { + } + + public override async Task RunAsync( + IMessageSink diagnosticMessageSink, + IMessageBus messageBus, + object[] constructorArguments, + ExceptionAggregator aggregator, + CancellationTokenSource cancellationTokenSource) + => await XunitTestCaseExtensions.TrySkipAsync(this, messageBus) + ? new RunSummary { Total = 1, Skipped = 1 } + : await base.RunAsync( + diagnosticMessageSink, + messageBus, + constructorArguments, + aggregator, + cancellationTokenSource); +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryAttribute.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryAttribute.cs new file mode 100644 index 000000000000..529f42ef1310 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ConditionalTheoryAttribute.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit; +using Xunit.Sdk; + +namespace VectorDataSpecificationTests.Xunit; + +[AttributeUsage(AttributeTargets.Method)] +[XunitTestCaseDiscoverer("VectorDataSpecificationTests.Xunit.VectorStoreFactDiscoverer", "VectorDataIntegrationTests")] +public sealed class ConditionalTheoryAttribute : TheoryAttribute; diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ITestCondition.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ITestCondition.cs new file mode 100644 index 000000000000..deca7716fb1a --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/ITestCondition.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace VectorDataSpecificationTests.Xunit; + +public interface ITestCondition +{ + ValueTask IsMetAsync(); + + string SkipReason { get; } +} diff --git a/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/XunitTestCaseExtensions.cs b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/XunitTestCaseExtensions.cs new file mode 100644 index 000000000000..2cf37205ead4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/VectorDataIntegrationTests/Xunit/XunitTestCaseExtensions.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Concurrent; +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace VectorDataSpecificationTests.Xunit; + +public static class XunitTestCaseExtensions +{ + private static readonly ConcurrentDictionary> s_typeAttributes = new(); + private static readonly ConcurrentDictionary> s_assemblyAttributes = new(); + + public static async ValueTask TrySkipAsync(XunitTestCase testCase, IMessageBus messageBus) + { + var method = testCase.Method; + var type = testCase.TestMethod.TestClass.Class; + var assembly = type.Assembly; + + var skipReasons = new List(); + var attributes = + s_assemblyAttributes.GetOrAdd( + assembly.Name, + a => assembly.GetCustomAttributes(typeof(ITestCondition)).ToList()) + .Concat( + s_typeAttributes.GetOrAdd( + type.Name, + t => type.GetCustomAttributes(typeof(ITestCondition)).ToList())) + .Concat(method.GetCustomAttributes(typeof(ITestCondition))) + .OfType() + .Select(attributeInfo => (ITestCondition)attributeInfo.Attribute); + + foreach (var attribute in attributes) + { + if (!await attribute.IsMetAsync()) + { + skipReasons.Add(attribute.SkipReason); + } + } + + if (skipReasons.Count > 0) + { + messageBus.QueueMessage( + new TestSkipped(new XunitTest(testCase, testCase.DisplayName), string.Join(Environment.NewLine, skipReasons))); + + return true; + } + + return false; + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs new file mode 100644 index 000000000000..2880d1b93859 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateBasicFilterTests.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using VectorDataSpecificationTests.Filter; +using Xunit; +using Xunit.Sdk; + +namespace WeaviateIntegrationTests.Filter; + +public class WeaviateBasicFilterTests(WeaviateFilterFixture fixture) : BasicFilterTestsBase(fixture), IClassFixture +{ + #region Filter by null + + // Null-state indexing needs to be set up, but that's not supported yet (#10358). + // We could interact with Weaviate directly (not via the abstraction) to do this. + + public override Task Equal_with_null_reference_type() + => Assert.ThrowsAsync(() => base.Equal_with_null_reference_type()); + + public override Task Equal_with_null_captured() + => Assert.ThrowsAsync(() => base.Equal_with_null_captured()); + + public override Task NotEqual_with_null_captured() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_captured()); + + public override Task NotEqual_with_null_reference_type() + => Assert.ThrowsAsync(() => base.NotEqual_with_null_reference_type()); + + #endregion + + #region Not + + // Weaviate currently doesn't support NOT (https://github.com/weaviate/weaviate/issues/3683) + public override Task Not_over_And() + => Assert.ThrowsAsync(() => base.Not_over_And()); + + public override Task Not_over_Or() + => Assert.ThrowsAsync(() => base.Not_over_Or()); + + #endregion + + #region Unsupported Contains scenarios + + public override Task Contains_over_captured_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_captured_string_array()); + + public override Task Contains_over_inline_int_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_int_array()); + + public override Task Contains_over_inline_string_array() + => Assert.ThrowsAsync(() => base.Contains_over_inline_int_array()); + + public override Task Contains_over_inline_string_array_with_weird_chars() + => Assert.ThrowsAsync(() => base.Contains_over_inline_string_array_with_weird_chars()); + + #endregion + + // In Weaviate, string equality on multi-word textual properties depends on tokenization + // (https://weaviate.io/developers/weaviate/api/graphql/filters#multi-word-queries-in-equal-filters) + public override Task Equal_with_string_is_not_Contains() + => Assert.ThrowsAsync(() => base.Equal_with_string_is_not_Contains()); +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateFilterFixture.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateFilterFixture.cs new file mode 100644 index 000000000000..f00b884780c2 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Filter/WeaviateFilterFixture.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. + +using VectorDataSpecificationTests.Filter; +using VectorDataSpecificationTests.Support; +using WeaviateIntegrationTests.Support; + +namespace WeaviateIntegrationTests.Filter; + +public class WeaviateFilterFixture : FilterFixtureBase +{ + protected override TestStore TestStore => WeaviateTestStore.Instance; + + protected override string DistanceFunction => Microsoft.Extensions.VectorData.DistanceFunction.CosineDistance; +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateBuilder.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateBuilder.cs new file mode 100644 index 000000000000..1745a902a348 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateBuilder.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet.Models; +using DotNet.Testcontainers.Builders; +using DotNet.Testcontainers.Configurations; + +namespace WeaviateIntegrationTests.Support.TestContainer; + +public sealed class WeaviateBuilder : ContainerBuilder +{ + public const string WeaviateImage = "semitechnologies/weaviate:1.26.4"; + public const ushort WeaviateHttpPort = 8080; + public const ushort WeaviateGrpcPort = 50051; + + public WeaviateBuilder() : this(new WeaviateConfiguration()) => this.DockerResourceConfiguration = this.Init().DockerResourceConfiguration; + + private WeaviateBuilder(WeaviateConfiguration dockerResourceConfiguration) : base(dockerResourceConfiguration) + => this.DockerResourceConfiguration = dockerResourceConfiguration; + + public override WeaviateContainer Build() + { + this.Validate(); + return new WeaviateContainer(this.DockerResourceConfiguration); + } + + protected override WeaviateBuilder Init() + => base.Init() + .WithImage(WeaviateImage) + .WithPortBinding(WeaviateHttpPort, true) + .WithPortBinding(WeaviateGrpcPort, true) + .WithEnvironment("AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED", "true") + .WithEnvironment("PERSISTENCE_DATA_PATH", "/var/lib/weaviate") + .WithWaitStrategy(Wait.ForUnixContainer() + .UntilPortIsAvailable(WeaviateHttpPort) + .UntilPortIsAvailable(WeaviateGrpcPort) + .UntilHttpRequestIsSucceeded(r => r.ForPath("/v1/.well-known/ready").ForPort(WeaviateHttpPort))); + + protected override WeaviateBuilder Clone(IResourceConfiguration resourceConfiguration) + => this.Merge(this.DockerResourceConfiguration, new WeaviateConfiguration(resourceConfiguration)); + + protected override WeaviateBuilder Merge(WeaviateConfiguration oldValue, WeaviateConfiguration newValue) + => new(new WeaviateConfiguration(oldValue, newValue)); + + protected override WeaviateConfiguration DockerResourceConfiguration { get; } + + protected override WeaviateBuilder Clone(IContainerConfiguration resourceConfiguration) + => this.Merge(this.DockerResourceConfiguration, new WeaviateConfiguration(resourceConfiguration)); +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateConfiguration.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateConfiguration.cs new file mode 100644 index 000000000000..56ea40b242e7 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateConfiguration.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Docker.DotNet.Models; +using DotNet.Testcontainers.Configurations; + +namespace WeaviateIntegrationTests.Support.TestContainer; + +public sealed class WeaviateConfiguration : ContainerConfiguration +{ + /// + /// Initializes a new instance of the class. + /// + public WeaviateConfiguration() + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The Docker resource configuration. + public WeaviateConfiguration(IResourceConfiguration resourceConfiguration) + : base(resourceConfiguration) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The Docker resource configuration. + public WeaviateConfiguration(IContainerConfiguration resourceConfiguration) + : base(resourceConfiguration) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The Docker resource configuration. + public WeaviateConfiguration(WeaviateConfiguration resourceConfiguration) + : this(new WeaviateConfiguration(), resourceConfiguration) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The old Docker resource configuration. + /// The new Docker resource configuration. + public WeaviateConfiguration(WeaviateConfiguration oldValue, WeaviateConfiguration newValue) + : base(oldValue, newValue) + { + } +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateContainer.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateContainer.cs new file mode 100644 index 000000000000..c209d662a4d4 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/TestContainer/WeaviateContainer.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft. All rights reserved. + +using DotNet.Testcontainers.Containers; + +namespace WeaviateIntegrationTests.Support.TestContainer; + +public class WeaviateContainer(WeaviateConfiguration configuration) : DockerContainer(configuration); diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs new file mode 100644 index 000000000000..d112a2abfe49 --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/Support/WeaviateTestStore.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft. All rights reserved. + +#if NET472 +using System.Net.Http; +#endif +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.Weaviate; +using VectorDataSpecificationTests.Support; +using WeaviateIntegrationTests.Support.TestContainer; + +namespace WeaviateIntegrationTests.Support; + +public sealed class WeaviateTestStore : TestStore +{ + public static WeaviateTestStore Instance { get; } = new(); + + private readonly WeaviateContainer _container = new WeaviateBuilder().Build(); + public HttpClient? _httpClient { get; private set; } + private WeaviateVectorStore? _defaultVectorStore; + + public HttpClient Client => this._httpClient ?? throw new InvalidOperationException("Not initialized"); + + public override IVectorStore DefaultVectorStore => this._defaultVectorStore ?? throw new InvalidOperationException("Not initialized"); + + public WeaviateVectorStore GetVectorStore(WeaviateVectorStoreOptions options) + => new(this.Client, options); + + private WeaviateTestStore() + { + } + + protected override async Task StartAsync() + { + await this._container.StartAsync(); + this._httpClient = new HttpClient { BaseAddress = new Uri($"http://localhost:{this._container.GetMappedPublicPort(WeaviateBuilder.WeaviateHttpPort)}/v1/") }; + this._defaultVectorStore = new(this._httpClient); + } + + protected override Task StopAsync() + => this._container.StopAsync(); +} diff --git a/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj new file mode 100644 index 000000000000..eb98407f35ee --- /dev/null +++ b/dotnet/src/VectorDataIntegrationTests/WeaviateIntegrationTests/WeaviateIntegrationTests.csproj @@ -0,0 +1,27 @@ + + + + net8.0;net472 + enable + enable + true + false + WeaviateIntegrationTests + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + +