Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ public static bool TryConvertToArray(
{
subquery.ApplyProjection();

// TODO: Should the type be an array, or enumerable/queryable?
var arrayClrType = projection.Type.MakeArrayType();
var arrayClrType = typeof(IEnumerable<>).MakeGenericType(projection.Type);

switch (projection)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Cosmos.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;

namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
Expand All @@ -21,7 +23,9 @@ private static readonly MethodInfo GetParameterValueMethodInfo
= typeof(CosmosProjectionBindingExpressionVisitor)
.GetTypeInfo().GetDeclaredMethod(nameof(GetParameterValue))!;

private readonly CosmosQueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor;
private readonly CosmosSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly ITypeMappingSource _typeMappingSource;
private readonly IModel _model;
private SelectExpression _selectExpression;
private bool _clientEval;
Expand All @@ -39,10 +43,14 @@ private static readonly MethodInfo GetParameterValueMethodInfo
/// </summary>
public CosmosProjectionBindingExpressionVisitor(
IModel model,
CosmosSqlTranslatingExpressionVisitor sqlTranslator)
CosmosQueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor,
CosmosSqlTranslatingExpressionVisitor sqlTranslator,
ITypeMappingSource typeMappingSource)
{
_model = model;
_queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor;
_sqlTranslator = sqlTranslator;
_typeMappingSource = typeMappingSource;
_selectExpression = null!;
}

Expand Down Expand Up @@ -570,6 +578,50 @@ UnaryExpression unaryExpression
lambda);
}
}
else if (method is { Name: nameof(Enumerable.ToList), IsGenericMethod: true }
&& method.DeclaringType == typeof(Enumerable)
&& methodCallExpression.Arguments is [var argument]
&& argument.Type.TryGetElementType(typeof(IQueryable<>)) != null)
{
if (_queryableMethodTranslatingExpressionVisitor.TranslateSubquery(argument) is not ShapedQueryExpression subquery
|| !subquery.TryConvertToArray(_typeMappingSource, out var array))
{
throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
}

// If ToList() was composed over a subquery with operators, the result here is an ArrayExpression (ARRAY(SELECT ...)), whose
// CLR Type is IEnumerable<T>. This can be directly used in the resulting ProjectingBindingExpression - the shaper will
// simply read the JSON results out successfully.
// But if ToList() is composed directly over an array property, that property could have type e.g. T[], which will be read
// in the shaper, and then the cast from T[] to List<T> will fail. As a result, wrap the array in an additional
// "reprojection" subquery, effectively to change the CLR type.
if (array is SqlExpression scalarArray
&& !(array.Type.IsGenericType && array.Type.GetGenericTypeDefinition() == typeof(IEnumerable<>)))
{
Check.DebugAssert(
array is not ScalarArrayExpression and not ObjectArrayExpression, "ArrayExpression should be IEnumerable");

if (scalarArray is not { TypeMapping.ElementTypeMapping: CosmosTypeMapping elementTypeMapping })
{
throw new UnreachableException("Scalar array with no element type mapping");
}

// TODO: Proper alias management (#33894).
var arrayReprojectionSubquery = SelectExpression.CreateForCollection(
array, "i", new ScalarReferenceExpression("i", elementTypeMapping.ClrType, elementTypeMapping));
arrayReprojectionSubquery.ApplyProjection();

array = new ScalarArrayExpression(
arrayReprojectionSubquery,
methodCallExpression.Type, // List<>
_typeMappingSource.FindMapping(methodCallExpression.Type, _model, elementTypeMapping));
}

return new ProjectionBindingExpression(
_selectExpression,
_selectExpression.AddToProjection(array),
methodCallExpression.Type);
}
}

var @object = Visit(methodCallExpression.Object);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.EntityFrameworkCore.Cosmos.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal;
using Microsoft.EntityFrameworkCore.Internal;

namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
Expand Down Expand Up @@ -54,7 +55,7 @@ public CosmosQueryableMethodTranslatingExpressionVisitor(
_methodCallTranslatorProvider,
this);
_projectionBindingExpressionVisitor =
new CosmosProjectionBindingExpressionVisitor(_queryCompilationContext.Model, _sqlTranslator);
new CosmosProjectionBindingExpressionVisitor(_queryCompilationContext.Model, this, _sqlTranslator, _typeMappingSource);
_subquery = false;
}

Expand All @@ -81,7 +82,7 @@ protected CosmosQueryableMethodTranslatingExpressionVisitor(
_methodCallTranslatorProvider,
parentVisitor);
_projectionBindingExpressionVisitor =
new CosmosProjectionBindingExpressionVisitor(_queryCompilationContext.Model, _sqlTranslator);
new CosmosProjectionBindingExpressionVisitor(_queryCompilationContext.Model, this, _sqlTranslator, _typeMappingSource);
_subquery = true;
}

Expand Down Expand Up @@ -1131,8 +1132,10 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
// ElementAtOrDefault over an array of scalars
case SqlExpression scalarArray when projection is SqlExpression element:
{
var slice = _sqlExpressionFactory.Function(
"ARRAY_SLICE", [scalarArray, translatedCount], scalarArray.Type, scalarArray.TypeMapping);
var arrayType = typeof(IEnumerable<>).MakeGenericType(projection.Type);
var arrayTypeMapping = _typeMappingSource.FindMapping(arrayType, _queryCompilationContext.Model, element.TypeMapping);

var slice = _sqlExpressionFactory.Function("ARRAY_SLICE", [scalarArray, translatedCount], arrayType, arrayTypeMapping);

// TODO: Proper alias management (#33894). Ideally reach into the source of the original SelectExpression and use that alias.
var translatedSelect = SelectExpression.CreateForCollection(
Expand All @@ -1145,8 +1148,10 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
// ElementAtOrDefault over an array os structural types
case not null when projectedStructuralTypeShaper is not null:
{
var arrayType = typeof(IEnumerable<>).MakeGenericType(projectedStructuralTypeShaper.Type);

// TODO: Proper alias management (#33894).
var slice = new ObjectFunctionExpression("ARRAY_SLICE", [array, translatedCount], projectedStructuralTypeShaper.Type);
var slice = new ObjectFunctionExpression("ARRAY_SLICE", [array, translatedCount], arrayType);
var translatedSelect = SelectExpression.CreateForCollection(
slice,
"i",
Expand Down Expand Up @@ -1591,7 +1596,7 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
// value conversion). #34026.
var elementClrType = inlineQueryRootExpression.ElementType;
var elementTypeMapping = _typeMappingSource.FindMapping(elementClrType)!;
var arrayTypeMapping = _typeMappingSource.FindMapping(elementClrType.MakeArrayType()); // TODO: IEnumerable?
var arrayTypeMapping = _typeMappingSource.FindMapping(typeof(IEnumerable<>).MakeGenericType(elementClrType));
var inlineArray = new ArrayConstantExpression(elementClrType, translatedItems, arrayTypeMapping);

// TODO: Do proper alias management: #33894
Expand Down Expand Up @@ -1620,7 +1625,7 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
// TODO: Temporary hack - need to perform proper derivation of the array type mapping from the element (e.g. for
// value conversion). #34026.
var elementClrType = parameterQueryRootExpression.ElementType;
var arrayTypeMapping = _typeMappingSource.FindMapping(elementClrType.MakeArrayType()); // TODO: IEnumerable?
var arrayTypeMapping = _typeMappingSource.FindMapping(typeof(IEnumerable<>).MakeGenericType(elementClrType));
var elementTypeMapping = _typeMappingSource.FindMapping(elementClrType)!;
var sqlParameterExpression = new SqlParameterExpression(parameterQueryRootExpression.ParameterExpression, arrayTypeMapping);

Expand Down Expand Up @@ -1689,13 +1694,17 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
&& source2.TryConvertToArray(_typeMappingSource, out var array2, out var projection2, ignoreOrderings)
&& projection1.Type == projection2.Type)
{
var arrayType = typeof(IEnumerable<>).MakeGenericType(projection1.Type);

// Set operation over arrays of scalars
if (projection1 is SqlExpression sqlProjection1
&& projection2 is SqlExpression sqlProjection2
&& (sqlProjection1.TypeMapping ?? sqlProjection2.TypeMapping) is CoreTypeMapping typeMapping)
&& (sqlProjection1.TypeMapping ?? sqlProjection2.TypeMapping) is CosmosTypeMapping typeMapping)
{
var arrayTypeMapping = _typeMappingSource.FindMapping(arrayType, _queryCompilationContext.Model, typeMapping);

// TODO: Proper alias management (#33894).
var translation = _sqlExpressionFactory.Function(functionName, [array1, array2], projection1.Type, typeMapping);
var translation = _sqlExpressionFactory.Function(functionName, [array1, array2], arrayType, arrayTypeMapping);
var select = SelectExpression.CreateForCollection(
translation, "i", new ScalarReferenceExpression("i", projection1.Type, typeMapping));
return source1.UpdateQueryExpression(select);
Expand All @@ -1707,7 +1716,7 @@ when methodCallExpression.TryGetIndexerArguments(_queryCompilationContext.Model,
&& structuralType1 == structuralType2)
{
// TODO: Proper alias management (#33894).
var translation = new ObjectFunctionExpression(functionName, [array1, array2], projection1.Type);
var translation = new ObjectFunctionExpression(functionName, [array1, array2], arrayType);
var select = SelectExpression.CreateForCollection(
translation, "i", new ObjectReferenceExpression((IEntityType)structuralType1, "i"));
return CreateShapedQueryExpression(select, structuralType1.ClrType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </remarks>
[DebuggerDisplay("{Microsoft.EntityFrameworkCore.Query.ExpressionPrinter.Print(this), nq}")]
public class ScalarAccessExpression(Expression @object, string propertyName, Type clrType, CoreTypeMapping? typeMapping)
: SqlExpression(clrType, typeMapping), IAccessExpression
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,27 +322,9 @@ public virtual void ReplaceProjectionMapping(IDictionary<ProjectionMember, Expre
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual int AddToProjection(SqlExpression sqlExpression)
public virtual int AddToProjection(Expression sqlExpression)
=> AddToProjection(sqlExpression, null);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual int AddToProjection(EntityProjectionExpression entityProjection)
=> AddToProjection(entityProjection, null);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual int AddToProjection(ObjectArrayAccessExpression objectArrayAccess)
=> AddToProjection(objectArrayAccess, null);

private int AddToProjection(Expression expression, string? alias)
{
var existingIndex = _projection.FindIndex(pe => pe.Expression.Equals(expression));
Expand Down
4 changes: 2 additions & 2 deletions src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ private SqlExpression ApplyTypeMappingOnSqlBinary(
// TODO: This infers based on the CLR type; need to properly infer based on the element type mapping
// TODO: being applied here (e.g. WHERE @p[1] = c.PropertyWithValueConverter). #34026
var arrayTypeMapping = left.TypeMapping
?? (typeMapping is null ? null : typeMappingSource.FindMapping(typeMapping.ClrType.MakeArrayType()));
?? (typeMapping is null ? null : typeMappingSource.FindMapping(typeof(IEnumerable<>).MakeGenericType(typeMapping.ClrType)));
return new SqlBinaryExpression(
ExpressionType.ArrayIndex,
ApplyTypeMapping(left, arrayTypeMapping),
Expand Down Expand Up @@ -291,7 +291,7 @@ private InExpression ApplyTypeMappingOnIn(InExpression inExpression)
var arrayClrType = arrayExpression.Type switch
{
var t when t.TryGetSequenceType() != typeof(object) => t,
{ IsArray: true } => itemExpression.Type.MakeArrayType(),
{ IsArray: true } => typeof(IEnumerable<>).MakeGenericType(itemExpression.Type),
{ IsConstructedGenericType: true, GenericTypeArguments.Length: 1 } t
=> t.GetGenericTypeDefinition().MakeGenericType(itemExpression.Type),
_ => throw new InvalidOperationException(
Expand Down
20 changes: 10 additions & 10 deletions test/EFCore.Cosmos.FunctionalTests/Query/OwnedQueryCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,8 @@ public override Task Where_owned_collection_navigation_ToList_Count(bool async)
async, async a =>
{
// TODO: #34011
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets persisted
// as null instead of [].
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets
// persisted as null instead of [] when there are no Details. So we change the Count we check to 1.
await AssertQuery(
a,
ss => ss.Set<OwnedPerson>()
Expand All @@ -388,7 +388,7 @@ await AssertQuery(
// TODO: The following should project out a["Details"], not a: #34067
AssertSql(
"""
SELECT a
SELECT a["Details"]
FROM root c
JOIN a IN c["Orders"]
WHERE (c["Discriminator"] IN ("OwnedPerson", "Branch", "LeafB", "LeafA") AND (ARRAY_LENGTH(a["Details"]) = 1))
Expand All @@ -401,8 +401,8 @@ public override Task Where_collection_navigation_ToArray_Count(bool async)
async, async a =>
{
// TODO: #34011
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets persisted
// as null instead of [].
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets
// persisted as null instead of [] when there are no Details. So we change the Count we check to 1.
await AssertQuery(
a,
ss => ss.Set<OwnedPerson>()
Expand All @@ -428,8 +428,8 @@ public override Task Where_collection_navigation_AsEnumerable_Count(bool async)
async, async a =>
{
// TODO: #34011
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets persisted
// as null instead of [].
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets
// persisted as null instead of [] when there are no Details. So we change the Count we check to 1.
await AssertQuery(
a,
ss => ss.Set<OwnedPerson>()
Expand All @@ -455,8 +455,8 @@ public override Task Where_collection_navigation_ToList_Count_member(bool async)
async, async a =>
{
// TODO: #34011
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets persisted
// as null instead of [].
// We override this test because the test data for this class gets saved incorrectly - the Order.Details collection gets
// persisted as null instead of [] when there are no Details. So we change the Count we check to 1.
await AssertQuery(
a,
ss => ss.Set<OwnedPerson>()
Expand All @@ -469,7 +469,7 @@ await AssertQuery(

AssertSql(
"""
SELECT a
SELECT a["Details"]
FROM root c
JOIN a IN c["Orders"]
WHERE (c["Discriminator"] IN ("OwnedPerson", "Branch", "LeafB", "LeafA") AND (ARRAY_LENGTH(a["Details"]) = 1))
Expand Down
Loading