Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -820,18 +820,28 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
ExpressionType.Negate or ExpressionType.NegateChecked
=> sqlExpressionFactory.Negate(sqlOperand!),

ExpressionType.Convert or ExpressionType.ConvertChecked
when operand.Type.IsInterface
&& unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type)
// Convert nodes can be an explicit user gesture in the query, or they may get introduced by the compiler (e.g. when a Child is
// passed as an argument for a parameter of type Parent). The latter type should generally get stripped out as a pure C#/LINQ
// artifact that shouldn't affect translation, but the former may be an indication from the user that they want to apply a
// type change.
ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs
when operand.Type.IsInterface && unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type)
// We strip out implicit conversions, e.g. float[] -> ReadOnlyMemory<float> (for vector search)
|| (unaryExpression.Method is { IsSpecialName: true, Name: "op_Implicit" }
&& IsReadOnlyMemory(unaryExpression.Type.UnwrapNullableType()))
|| unaryExpression.Type.UnwrapNullableType() == operand.Type
|| unaryExpression.Type.UnwrapNullableType() == typeof(Enum)
// Object convert needs to be converted to explicit cast when mismatching types
// But we let is pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types
// But we let it pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types
|| unaryExpression.Type == typeof(object)
=> sqlOperand!,

_ => QueryCompilationContext.NotTranslatedExpression
};

// Returns true only when the given type is a generic type (but not an open generic type definition)
// whose generic type definition is ReadOnlyMemory<>.
static bool IsReadOnlyMemory(Type type)
{
    // A generic type definition such as ReadOnlyMemory<> itself is deliberately rejected here.
    var isInstantiatedGeneric = type is not null && type.IsGenericType && !type.IsGenericTypeDefinition;

    return isInstantiatedGeneric && type!.GetGenericTypeDefinition() == typeof(ReadOnlyMemory<>);
}
}

/// <inheritdoc />
Expand Down
35 changes: 34 additions & 1 deletion src/EFCore.Cosmos/Storage/Internal/CosmosVectorTypeMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Storage.Json;
using Newtonsoft.Json.Linq;

namespace Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal;

Expand Down Expand Up @@ -63,7 +64,8 @@ public CosmosVectorTypeMapping(CosmosTypeMapping mapping, CosmosVectorType vecto
: this(
new CoreTypeMappingParameters(
mapping.ClrType,
converter: mapping.Converter,
// This is a hack to allow both arrays and ROM types without different function overloads or type mappings.
converter: mapping.Converter?.GetType() == typeof(BytesToStringConverter) ? null : mapping.Converter,
mapping.Comparer,
mapping.KeyComparer,
elementMapping: mapping.ElementTypeMapping,
Expand Down Expand Up @@ -114,4 +116,35 @@ public override CoreTypeMapping WithComposedConverter(
/// </summary>
protected override CoreTypeMapping Clone(CoreTypeMappingParameters parameters)
{
    // Build the copy from the supplied parameters while carrying over this mapping's vector type.
    return new CosmosVectorTypeMapping(parameters, VectorType);
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public override JToken? GenerateJToken(object? value)
{
    // This is a hack to allow both arrays and ReadOnlyMemory<T> values without different function overloads or
    // type mappings: the ReadOnlyMemory<T> instantiations handled below are copied into plain arrays, while every
    // other value (null and arrays included) is passed through untouched.
    object? payload = value switch
    {
        ReadOnlyMemory<byte> bytes => bytes.ToArray(),
        ReadOnlyMemory<sbyte> signedBytes => signedBytes.ToArray(),
        ReadOnlyMemory<float> singles => singles.ToArray(),
        _ => value
    };

    if (payload is null)
    {
        return null;
    }

    return JToken.FromObject(payload, CosmosClientWrapper.Serializer);
}
}
66 changes: 43 additions & 23 deletions test/EFCore.Cosmos.FunctionalTests/VectorSearchCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,19 @@ public virtual async Task Query_for_vector_distance_bytes_array()
await using var context = CreateContext();
var inputVector = new byte[] { 2, 1, 4, 3, 5, 2, 5, 7, 3, 1 };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>().Select(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)).ToListAsync());
var booksFromStore = await context
.Set<Book>()
.Select(e => EF.Functions.VectorDistance(e.BytesArray, inputVector))
.ToListAsync();

// Assert.Equal(3, booksFromStore.Count);
// Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));
Assert.Equal(3, booksFromStore.Count);
Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));

AssertSql(
"""
SELECT VALUE c["BytesArray"]
@__p_1='[2,1,4,3,5,2,5,7,3,1]'

SELECT VALUE VectorDistance(c["BytesArray"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'})
FROM root c
""");
}
Expand All @@ -119,17 +122,20 @@ public virtual async Task Query_for_vector_distance_singles_array()
await using var context = CreateContext();
var inputVector = new[] { 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f, 0.86f, -0.78f };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>()
.Select(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector, false, DistanceFunction.DotProduct)).ToListAsync());
var booksFromStore = await context
.Set<Book>()
.Select(
e => EF.Functions.VectorDistance(e.SinglesArray, inputVector, false, DistanceFunction.DotProduct))
.ToListAsync();

// Assert.Equal(3, booksFromStore.Count);
// Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));
Assert.Equal(3, booksFromStore.Count);
Assert.All(booksFromStore, s => Assert.NotEqual(0.0, s));

AssertSql(
"""
SELECT VALUE c["SinglesArray"]
@__p_1='[0.33,-0.52,0.45,-0.67,0.89,-0.34,0.86,-0.78,0.86,-0.78]'

SELECT VALUE VectorDistance(c["SinglesArray"], @__p_1, false, {'distanceFunction':'dotproduct', 'dataType':'float32'})
FROM root c
""");
}
Expand Down Expand Up @@ -207,13 +213,20 @@ public virtual async Task Vector_distance_bytes_array_in_OrderBy()
await using var context = CreateContext();
var inputVector = new byte[] { 2, 1, 4, 6, 5, 2, 5, 7, 3, 1 };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>().OrderBy(e => EF.Functions.VectorDistance(e.BytesArray, inputVector)).ToListAsync());
var booksFromStore = await context
.Set<Book>()
.OrderBy(e => EF.Functions.VectorDistance(e.BytesArray, inputVector))
.ToListAsync();

// Assert.Equal(3, booksFromStore.Count);
Assert.Equal(3, booksFromStore.Count);
AssertSql(
"""
@__p_1='[2,1,4,6,5,2,5,7,3,1]'

AssertSql();
SELECT VALUE c
FROM root c
ORDER BY VectorDistance(c["BytesArray"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'uint8'})
""");
}

[ConditionalFact]
Expand All @@ -222,13 +235,20 @@ public virtual async Task Vector_distance_singles_array_in_OrderBy()
await using var context = CreateContext();
var inputVector = new[] { 0.33f, -0.52f, 0.45f, -0.67f, 0.89f, -0.34f, 0.86f, -0.78f };

// See Issue #34402
await Assert.ThrowsAsync<InvalidOperationException>(
() => context.Set<Book>().OrderBy(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector)).ToListAsync());
var booksFromStore = await context
.Set<Book>()
.OrderBy(e => EF.Functions.VectorDistance(e.SinglesArray, inputVector))
.ToListAsync();

// Assert.Equal(3, booksFromStore.Count);
Assert.Equal(3, booksFromStore.Count);
AssertSql(
"""
@__p_1='[0.33,-0.52,0.45,-0.67,0.89,-0.34,0.86,-0.78]'

AssertSql();
SELECT VALUE c
FROM root c
ORDER BY VectorDistance(c["SinglesArray"], @__p_1, false, {'distanceFunction':'cosine', 'dataType':'float32'})
""");
}

[ConditionalFact]
Expand Down