From a84d9e37b266dd7d8d373b732e3c4909c37abe07 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Wed, 14 Aug 2024 15:35:07 +0200 Subject: [PATCH 1/4] Cosmos: strip implicit casts to allow vector search over arrays Fixes #34402 --- .../CosmosSqlTranslatingExpressionVisitor.cs | 13 ++-- .../VectorSearchCosmosTest.cs | 67 ++++++++++++------- 2 files changed, 52 insertions(+), 28 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index e71bf98a6a1..8bd0d118b22 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -820,13 +820,18 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) ExpressionType.Negate or ExpressionType.NegateChecked => sqlExpressionFactory.Negate(sqlOperand!), - ExpressionType.Convert or ExpressionType.ConvertChecked - when operand.Type.IsInterface - && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) + // Convert nodes can be an explicit user gesture in the query, or they may get introduced by the compiler (e.g. when a Child is + // passed as an argument for a parameter of type Parent). The latter type should generally get stripped out as a pure C#/LINQ + // artifact that shouldn't affect translation, but the former may be an indication from the user that they want to apply a + // type change. + ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs + when operand.Type.IsInterface && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) + // We strip out implicit conversions, e.g. 
float[] -> ReadOnlyMemory (for vector search) + || unaryExpression.Method is { IsSpecialName: true, Name: "op_Implicit"} || unaryExpression.Type.UnwrapNullableType() == operand.Type || unaryExpression.Type.UnwrapNullableType() == typeof(Enum) // Object convert needs to be converted to explicit cast when mismatching types - // But we let is pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types + // But we let it pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types || unaryExpression.Type == typeof(object) => sqlOperand!, diff --git a/test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs index c67ca8c3cba..843aa659238 100644 --- a/test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs @@ -99,16 +99,19 @@ public virtual async Task Query_for_vector_distance_bytes_array() await using var context = CreateContext(); var inputVector = new byte[] { 2, 1, 4, 3, 5, 2, 5, 7, 3, 1 }; - // See Issue #34402 - await Assert.ThrowsAsync( - () => context.Set().Select(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)).ToListAsync()); + var booksFromStore = await context + .Set() + .Select(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)) + .ToListAsync(); - // Assert.Equal(3, booksFromStore.Count); - // Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s)); + Assert.Equal(3, booksFromStore.Count); + Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s)); AssertSql( """ -SELECT VALUE c["BytesArray"] +@__inputVector_1='[2,1,4,3,5,2,5,7,3,1]' + +SELECT VALUE VectorDistance(c["Bytes"], @__inputVector_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'}) FROM root c """); } @@ -119,17 +122,20 @@ public virtual async Task Query_for_vector_distance_singles_array() await using var context = 
CreateContext(); var inputVector = new[] { 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f }; - // See Issue #34402 - await Assert.ThrowsAsync( - () => context.Set() - .Select(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector, false, DistanceFunction.DotProduct)).ToListAsync()); + var booksFromStore = await context + .Set() + .Select( + e => EF.Functions.VectorDistance(e.SinglesArray, inputVector, false, DistanceFunction.DotProduct)) + .ToListAsync(); - // Assert.Equal(3, booksFromStore.Count); - // Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s)); + Assert.Equal(3, booksFromStore.Count); + Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s)); AssertSql( """ -SELECT VALUE c["SinglesArray"] +@__inputVector_1='[0.33,-0.52,0.45,-0.67,0.89,-0.34,0.86,-0.78,0.86,-0.78]' + +SELECT VALUE VectorDistance(c["Singles"], @__inputVector_1, false, {'distanceFunction':'dotproduct', 'dataType':'float32'}) FROM root c """); } @@ -207,14 +213,20 @@ public virtual async Task Vector_distance_bytes_array_in_OrderBy() await using var context = CreateContext(); var inputVector = new byte[] { 2, 1, 4, 6, 5, 2, 5, 7, 3, 1 }; - // See Issue #34402 - await Assert.ThrowsAsync( - () => context.Set().OrderBy(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)).ToListAsync()); - - // Assert.Equal(3, booksFromStore.Count); + var booksFromStore = await context + .Set() + .OrderBy(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)) + .ToListAsync(); + Assert.Equal(3, booksFromStore.Count); AssertSql( -); + """ +@__p_1='[2,1,4,6,5,2,5,7,3,1]' + +SELECT VALUE c +FROM root c +ORDER BY VectorDistance(c["Bytes"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'}) +"""); } [ConditionalFact] @@ -223,13 +235,20 @@ public virtual async Task Vector_distance_singles_array_in_OrderBy() await using var context = CreateContext(); var inputVector = new[] { 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f }; - // See Issue 
#34402 - await Assert.ThrowsAsync( - () => context.Set().OrderBy(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector)).ToListAsync()); + var booksFromStore = await context + .Set() + .OrderBy(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector)) + .ToListAsync(); - // Assert.Equal(3, booksFromStore.Count); + Assert.Equal(3, booksFromStore.Count); + AssertSql( + """ +@__p_1='[0.33,-0.52,0.45,-0.67,0.89,-0.34,0.86,-0.78]' - AssertSql(); +SELECT VALUE c +FROM root c +ORDER BY VectorDistance(c["Singles"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'float32'}) +"""); } [ConditionalFact] From c8a5471e0028a20042b8f5bc677f8290a011220b Mon Sep 17 00:00:00 2001 From: Arthur Vickers Date: Tue, 20 Aug 2024 16:59:24 +0100 Subject: [PATCH 2/4] Hack to support byte arrays and ReadOnlyMemory without overloads or appropriate type mappings. --- .../Internal/CosmosVectorTypeMapping.cs | 35 ++++++++++++++++++- .../VectorSearchCosmosTest.cs | 12 +++---- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/EFCore.Cosmos/Storage/Internal/CosmosVectorTypeMapping.cs b/src/EFCore.Cosmos/Storage/Internal/CosmosVectorTypeMapping.cs index ab83b515244..e9279c00e62 100644 --- a/src/EFCore.Cosmos/Storage/Internal/CosmosVectorTypeMapping.cs +++ b/src/EFCore.Cosmos/Storage/Internal/CosmosVectorTypeMapping.cs @@ -4,6 +4,7 @@ using System.Diagnostics.CodeAnalysis; using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; using Microsoft.EntityFrameworkCore.Storage.Json; +using Newtonsoft.Json.Linq; namespace Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; @@ -63,7 +64,8 @@ public CosmosVectorTypeMapping(CosmosTypeMapping mapping, CosmosVectorType vecto : this( new CoreTypeMappingParameters( mapping.ClrType, - converter: mapping.Converter, + // This is a hack to allow both arrays and ROM types without different function overloads or type mappings. + converter: mapping.Converter?.GetType() == typeof(BytesToStringConverter) ? 
null : mapping.Converter, mapping.Comparer, mapping.KeyComparer, elementMapping: mapping.ElementTypeMapping, @@ -114,4 +116,35 @@ public override CoreTypeMapping WithComposedConverter( /// </summary> protected override CoreTypeMapping Clone(CoreTypeMappingParameters parameters) => new CosmosVectorTypeMapping(parameters, VectorType); + + /// <summary> + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// </summary> + public override JToken? GenerateJToken(object? value) + { + // This is a hack to allow both arrays and ROM types without different function overloads or type mappings. + var type = value?.GetType(); + if (type?.IsArray is false) + { + if (type == typeof(ReadOnlyMemory<byte>)) + { + value = ((ReadOnlyMemory<byte>)value!).ToArray(); + } + else if (type == typeof(ReadOnlyMemory<sbyte>)) + { + value = ((ReadOnlyMemory<sbyte>)value!).ToArray(); + } + else if (type == typeof(ReadOnlyMemory<float>)) + { + value = ((ReadOnlyMemory<float>)value!).ToArray(); + } + } + + return value == null + ? 
null + : JToken.FromObject(value, CosmosClientWrapper.Serializer); + } } diff --git a/test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs index 843aa659238..38e32f1bccc 100644 --- a/test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs @@ -109,9 +109,9 @@ public virtual async Task Query_for_vector_distance_bytes_array() AssertSql( """ -@__inputVector_1='[2,1,4,3,5,2,5,7,3,1]' +@__p_1='[2,1,4,3,5,2,5,7,3,1]' -SELECT VALUE VectorDistance(c["Bytes"], @__inputVector_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'}) +SELECT VALUE VectorDistance(c["BytesArray"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'}) FROM root c """); } @@ -133,9 +133,9 @@ public virtual async Task Query_for_vector_distance_singles_array() AssertSql( """ -@__inputVector_1='[0.33,-0.52,0.45,-0.67,0.89,-0.34,0.86,-0.78,0.86,-0.78]' +@__p_1='[0.33,-0.52,0.45,-0.67,0.89,-0.34,0.86,-0.78,0.86,-0.78]' -SELECT VALUE VectorDistance(c["Singles"], @__inputVector_1, false, {'distanceFunction':'dotproduct', 'dataType':'float32'}) +SELECT VALUE VectorDistance(c["SinglesArray"], @__p_1, false, {'distanceFunction':'dotproduct', 'dataType':'float32'}) FROM root c """); } @@ -225,7 +225,7 @@ public virtual async Task Vector_distance_bytes_array_in_OrderBy() SELECT VALUE c FROM root c -ORDER BY VectorDistance(c["Bytes"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'}) +ORDER BY VectorDistance(c["BytesArray"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'}) """); } @@ -247,7 +247,7 @@ public virtual async Task Vector_distance_singles_array_in_OrderBy() SELECT VALUE c FROM root c -ORDER BY VectorDistance(c["Singles"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'float32'}) +ORDER BY VectorDistance(c["SinglesArray"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'float32'}) """); } From 
3c0b38ab8997e48722afa1a5c5f5ff87a1fc1356 Mon Sep 17 00:00:00 2001 From: Arthur Vickers Date: Wed, 21 Aug 2024 12:55:17 +0100 Subject: [PATCH 3/4] Be more restrictive in the casts we allow --- .../Internal/CosmosSqlTranslatingExpressionVisitor.cs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index 8bd0d118b22..797b4bf4cbe 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -827,7 +827,8 @@ ExpressionType.Negate or ExpressionType.NegateChecked ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs when operand.Type.IsInterface && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) // We strip out implicit conversions, e.g. float[] -> ReadOnlyMemory (for vector search) - || unaryExpression.Method is { IsSpecialName: true, Name: "op_Implicit"} + || (unaryExpression.Method is { IsSpecialName: true, Name: "op_Implicit" } + && IsRom(unaryExpression.Type.UnwrapNullableType())) || unaryExpression.Type.UnwrapNullableType() == operand.Type || unaryExpression.Type.UnwrapNullableType() == typeof(Enum) // Object convert needs to be converted to explicit cast when mismatching types @@ -837,6 +838,10 @@ ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs _ => QueryCompilationContext.NotTranslatedExpression }; + + static bool IsRom(Type type) + => type is { IsGenericType: true, IsGenericTypeDefinition: false } + && type.GetGenericTypeDefinition() == typeof(ReadOnlyMemory<>); } /// From 7e10d07ac8b1e51a3e4a76034419e77935c56f1b Mon Sep 17 00:00:00 2001 From: Arthur Vickers Date: Wed, 21 Aug 2024 14:44:04 +0100 Subject: [PATCH 4/4] Nit --- .../Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index 797b4bf4cbe..9d803df02c6 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -828,7 +828,7 @@ ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs when operand.Type.IsInterface && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type) // We strip out implicit conversions, e.g. float[] -> ReadOnlyMemory (for vector search) || (unaryExpression.Method is { IsSpecialName: true, Name: "op_Implicit" } - && IsRom(unaryExpression.Type.UnwrapNullableType())) + && IsReadOnlyMemory(unaryExpression.Type.UnwrapNullableType())) || unaryExpression.Type.UnwrapNullableType() == operand.Type || unaryExpression.Type.UnwrapNullableType() == typeof(Enum) // Object convert needs to be converted to explicit cast when mismatching types @@ -839,7 +839,7 @@ ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs _ => QueryCompilationContext.NotTranslatedExpression }; - static bool IsRom(Type type) + static bool IsReadOnlyMemory(Type type) => type is { IsGenericType: true, IsGenericTypeDefinition: false } && type.GetGenericTypeDefinition() == typeof(ReadOnlyMemory<>); }