Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class CosmosQueryableMethodTranslatingExpressionVisitor : QueryableMethod
private readonly CosmosSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly CosmosProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor;
private readonly bool _subquery;
private ReadItemInfo? _readItemExpression;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down Expand Up @@ -91,56 +92,85 @@ protected CosmosQueryableMethodTranslatingExpressionVisitor(
[return: NotNullIfNotNull(nameof(expression))]
public override Expression? Visit(Expression? expression)
{
if (expression is MethodCallExpression
if (_queryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll // Issue #33893
&& expression is MethodCallExpression
{
Method: { Name: nameof(Queryable.FirstOrDefault), IsGenericMethod: true },
Arguments:
[
MethodCallExpression
{
Method: { Name: nameof(Queryable.Where), IsGenericMethod: true },
Arguments:
[
EntityQueryRootExpression { EntityType: var entityType },
UnaryExpression { Operand: LambdaExpression lambdaExpression, NodeType: ExpressionType.Quote }
]
} whereMethodCall
]
} firstOrDefaultMethodCall
&& firstOrDefaultMethodCall.Method.GetGenericMethodDefinition() == QueryableMethods.FirstOrDefaultWithoutPredicate
&& whereMethodCall.Method.GetGenericMethodDefinition() == QueryableMethods.Where)
Arguments: [MethodCallExpression innerMethodCall]
})
{
var queryProperties = new List<IProperty>();
var parameterNames = new List<string>();

if (ExtractPartitionKeyFromPredicate(entityType, lambdaExpression.Body, queryProperties, parameterNames))
var clrType = innerMethodCall.Type.TryGetSequenceType() ?? typeof(object);
if (innerMethodCall is
{
Method: { Name: nameof(Queryable.Select), IsGenericMethod: true },
Arguments:
[
MethodCallExpression innerInnerMethodCall,
UnaryExpression { NodeType: ExpressionType.Quote } unaryExpression
]
})
{
var entityTypePrimaryKeyProperties = entityType.FindPrimaryKey()!.Properties;
var idProperty = entityType.GetProperties()
.First(p => p.GetJsonPropertyName() == StoreKeyConvention.IdPropertyJsonName);
var partitionKeyProperties = entityType.GetPartitionKeyProperties();

if (entityTypePrimaryKeyProperties.SequenceEqual(queryProperties)
&& (!partitionKeyProperties.Any()
|| partitionKeyProperties.All(p => entityTypePrimaryKeyProperties.Contains(p)))
&& (idProperty.GetValueGeneratorFactory() != null
|| entityTypePrimaryKeyProperties.Contains(idProperty)))
if (unaryExpression.Operand is LambdaExpression)
{
var propertyParameterList = queryProperties.Zip(
parameterNames,
(property, parameter) => (property, parameter))
.ToDictionary(tuple => tuple.property, tuple => tuple.parameter);
innerMethodCall = innerInnerMethodCall;
}
}

var readItemExpression = new ReadItemExpression(entityType, propertyParameterList);
if (innerMethodCall is
{
Method: { Name: nameof(Queryable.Where), IsGenericMethod: true },
Arguments:
[
EntityQueryRootExpression { EntityType: var entityType },
UnaryExpression { Operand: LambdaExpression lambdaExpression, NodeType: ExpressionType.Quote }
]
})
{
var queryProperties = new List<IProperty>();
var parameterNames = new List<string>();

if (ExtractPartitionKeyFromPredicate(entityType, lambdaExpression.Body, queryProperties, parameterNames))
{
var entityTypePrimaryKeyProperties = entityType.FindPrimaryKey()!.Properties;
var idProperty = entityType.GetProperties()
.First(p => p.GetJsonPropertyName() == StoreKeyConvention.IdPropertyJsonName);
var partitionKeyProperties = entityType.GetPartitionKeyProperties();

if (entityTypePrimaryKeyProperties.SequenceEqual(queryProperties)
&& (!partitionKeyProperties.Any()
|| partitionKeyProperties.All(p => entityTypePrimaryKeyProperties.Contains(p)))
// This should ideally only be looking for properties with the `IdValueGeneratorFactory` generator. since
// this is how the `id` property will be generated from other key values.
&& ((idProperty.GetValueGeneratorFactory() != null
// If we can't create an instance, then we might not be able to construct the resource id.
&& CanCreateEmptyInstance(entityType))
|| entityTypePrimaryKeyProperties.Contains(idProperty)))
{
var propertyParameterList = queryProperties.Zip(
parameterNames,
(property, parameter) => (property, parameter))
.ToDictionary(tuple => tuple.property, tuple => tuple.parameter);

return CreateShapedQueryExpression(entityType, readItemExpression)
.UpdateResultCardinality(ResultCardinality.SingleOrDefault);
_readItemExpression = new ReadItemInfo(entityType, propertyParameterList, clrType);
}
}
}
}

return base.Visit(expression);

static bool CanCreateEmptyInstance(IEntityType entityType)
{
var binding = entityType.ServiceOnlyConstructorBinding;
if (binding == null)
{
_ = entityType.ConstructorBinding;
binding = entityType.ServiceOnlyConstructorBinding;
}

return binding != null;
}

static bool ExtractPartitionKeyFromPredicate(
IEntityType entityType,
Expression joinCondition,
Expand Down Expand Up @@ -256,7 +286,11 @@ protected override QueryableMethodTranslatingExpressionVisitor CreateSubqueryVis
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType)
=> CreateShapedQueryExpression(entityType, _sqlExpressionFactory.Select(entityType));
=> CreateShapedQueryExpression(
entityType,
_readItemExpression == null
? _sqlExpressionFactory.Select(entityType)
: _sqlExpressionFactory.ReadItem(entityType, _readItemExpression));

private ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType, Expression queryExpression)
{
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ private sealed class ReadItemQueryingEnumerable<T> : IEnumerable<T>, IAsyncEnume
{
private readonly CosmosQueryContext _cosmosQueryContext;
private readonly string _cosmosContainer;
private readonly ReadItemExpression _readItemExpression;
private readonly ReadItemInfo _readItemInfo;
private readonly Func<CosmosQueryContext, JObject, T> _shaper;
private readonly Type _contextType;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _queryLogger;
Expand All @@ -33,15 +33,15 @@ private sealed class ReadItemQueryingEnumerable<T> : IEnumerable<T>, IAsyncEnume
public ReadItemQueryingEnumerable(
CosmosQueryContext cosmosQueryContext,
string cosmosContainer,
ReadItemExpression readItemExpression,
ReadItemInfo readItemInfo,
Func<CosmosQueryContext, JObject, T> shaper,
Type contextType,
bool standAloneStateManager,
bool threadSafetyChecksEnabled)
{
_cosmosQueryContext = cosmosQueryContext;
_cosmosContainer = cosmosContainer;
_readItemExpression = readItemExpression;
_readItemInfo = readItemInfo;
_shaper = shaper;
_contextType = contextType;
_queryLogger = _cosmosQueryContext.QueryLogger;
Expand All @@ -67,7 +67,7 @@ public string ToQueryString()

private bool TryGetPartitionKey(out PartitionKey partitionKeyValue)
{
var properties = _readItemExpression.EntityType.GetPartitionKeyProperties();
var properties = _readItemInfo.EntityType.GetPartitionKeyProperties();
if (!properties.Any())
{
partitionKeyValue = PartitionKey.None;
Expand Down Expand Up @@ -95,7 +95,7 @@ private bool TryGetPartitionKey(out PartitionKey partitionKeyValue)

private bool TryGetResourceId(out string resourceId)
{
var idProperty = _readItemExpression.EntityType.GetProperties()
var idProperty = _readItemInfo.EntityType.GetProperties()
.FirstOrDefault(p => p.GetJsonPropertyName() == StoreKeyConvention.IdPropertyJsonName);

if (TryGetParameterValue(idProperty, out var value))
Expand Down Expand Up @@ -124,7 +124,7 @@ private bool TryGetResourceId(out string resourceId)
private bool TryGetParameterValue(IProperty property, out object value)
{
value = null;
return _readItemExpression.PropertyParameters.TryGetValue(property, out var parameterName)
return _readItemInfo.PropertyParameters.TryGetValue(property, out var parameterName)
&& _cosmosQueryContext.ParameterValues.TryGetValue(parameterName, out value);
}

Expand All @@ -139,39 +139,36 @@ private static string GetString(IProperty property, object value)

private bool TryGenerateIdFromKeys(IProperty idProperty, out object value)
{
var entityEntry = Activator.CreateInstance(_readItemExpression.EntityType.ClrType);

#pragma warning disable EF1001 // Internal EF Core API usage.
// The idea here is that if a `IdValueGeneratorFactory` has been configured to generate an `id` value from the
// values of other properties, then we need an entity instance to use with the value generator.
var entityInstance = _readItemInfo.EntityType.GetOrCreateEmptyMaterializer(_cosmosQueryContext.EntityMaterializerSource)
(new MaterializationContext(ValueBuffer.Empty, _cosmosQueryContext.Context));

var internalEntityEntry = new InternalEntityEntry(
_cosmosQueryContext.Context.GetDependencies().StateManager, _readItemExpression.EntityType, entityEntry);
#pragma warning restore EF1001 // Internal EF Core API usage.
_cosmosQueryContext.Context.GetDependencies().StateManager, _readItemInfo.EntityType, entityInstance);

foreach (var keyProperty in _readItemExpression.EntityType.FindPrimaryKey().Properties)
foreach (var keyProperty in _readItemInfo.EntityType.FindPrimaryKey().Properties)
{
var property = _readItemExpression.EntityType.FindProperty(keyProperty.Name);
var property = _readItemInfo.EntityType.FindProperty(keyProperty.Name);

if (TryGetParameterValue(property, out var parameterValue))
{
#pragma warning disable EF1001 // Internal EF Core API usage.
internalEntityEntry[property] = parameterValue;
#pragma warning restore EF1001 // Internal EF Core API usage.
}
}

#pragma warning disable EF1001 // Internal EF Core API usage.
internalEntityEntry.SetEntityState(EntityState.Added);

value = internalEntityEntry[idProperty];

internalEntityEntry.SetEntityState(EntityState.Detached);
#pragma warning restore EF1001 // Internal EF Core API usage.

return value != null;
#pragma warning restore EF1001 // Internal EF Core API usage.
}

private sealed class Enumerator : IEnumerator<T>, IAsyncEnumerator<T>
{
private readonly CosmosQueryContext _cosmosQueryContext;
private readonly ReadItemInfo _readItemInfo;
private readonly string _cosmosContainer;
private readonly Func<CosmosQueryContext, JObject, T> _shaper;
private readonly Type _contextType;
Expand All @@ -188,6 +185,7 @@ private sealed class Enumerator : IEnumerator<T>, IAsyncEnumerator<T>
public Enumerator(ReadItemQueryingEnumerable<T> readItemEnumerable, CancellationToken cancellationToken = default)
{
_cosmosQueryContext = readItemEnumerable._cosmosQueryContext;
_readItemInfo = readItemEnumerable._readItemInfo;
_cosmosContainer = readItemEnumerable._cosmosContainer;
_shaper = readItemEnumerable._shaper;
_contextType = readItemEnumerable._contextType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,45 +59,36 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery
QueryCompilationContext.QueryContextParameter,
jObjectParameter);

return New(
typeof(QueryingEnumerable<>).MakeGenericType(shaperLambda.ReturnType).GetConstructors()[0],
Convert(
QueryCompilationContext.QueryContextParameter,
typeof(CosmosQueryContext)),
Constant(sqlExpressionFactory),
Constant(querySqlGeneratorFactory),
Constant(selectExpression),
Constant(shaperLambda.Compile()),
Constant(_contextType),
Constant(cosmosQueryCompilationContext.CosmosContainer),
Constant(_partitionKeyValueFromExtension, typeof(PartitionKey)),
Constant(
QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution),
Constant(_threadSafetyChecksEnabled));
var cosmosQueryContextConstant = Convert(QueryCompilationContext.QueryContextParameter, typeof(CosmosQueryContext));
var shaperConstant = Constant(shaperLambda.Compile());
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'd need to move out this Compile() to make Cosmos compatible with pre-compiled queries.

@roji De we have a separate issue tracking this work?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, opened #33909 to track Cosmos NativeAOT support.

var contextTypeConstant = Constant(_contextType);
var containerConstant = Constant(cosmosQueryCompilationContext.CosmosContainer);
var threadSafetyConstant = Constant(_threadSafetyChecksEnabled);
var standAloneStateManagerConstant = Constant(
QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution);

case ReadItemExpression readItemExpression:
shaperBody = new CosmosProjectionBindingRemovingReadItemExpressionVisitor(
readItemExpression, jObjectParameter,
QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll)
.Visit(shaperBody);

var shaperReadItemLambda = Lambda(
shaperBody,
QueryCompilationContext.QueryContextParameter,
jObjectParameter);

return New(
typeof(ReadItemQueryingEnumerable<>).MakeGenericType(shaperReadItemLambda.ReturnType).GetConstructors()[0],
Convert(
QueryCompilationContext.QueryContextParameter,
typeof(CosmosQueryContext)),
Constant(cosmosQueryCompilationContext.CosmosContainer),
Constant(readItemExpression),
Constant(shaperReadItemLambda.Compile()),
Constant(_contextType),
Constant(
QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution),
Constant(_threadSafetyChecksEnabled));
return selectExpression.ReadItemInfo != null
? New(
typeof(ReadItemQueryingEnumerable<>).MakeGenericType(selectExpression.ReadItemInfo.Type).GetConstructors()[0],
cosmosQueryContextConstant,
containerConstant,
Constant(selectExpression.ReadItemInfo),
shaperConstant,
contextTypeConstant,
standAloneStateManagerConstant,
threadSafetyConstant)
: New(
typeof(QueryingEnumerable<>).MakeGenericType(shaperLambda.ReturnType).GetConstructors()[0],
cosmosQueryContextConstant,
Constant(sqlExpressionFactory),
Constant(querySqlGeneratorFactory),
Constant(selectExpression),
shaperConstant,
contextTypeConstant,
containerConstant,
Constant(_partitionKeyValueFromExtension, typeof(PartitionKey)),
standAloneStateManagerConstant,
threadSafetyConstant);

default:
throw new NotSupportedException(CoreStrings.UnhandledExpressionNode(shapedQueryExpression.QueryExpression));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ protected override Expression VisitExtension(Expression extensionExpression)
=> extensionExpression switch
{
ShapedQueryExpression shapedQueryExpression => VisitShapedQueryExpression(shapedQueryExpression),
ReadItemExpression readItemExpression => readItemExpression,
SelectExpression selectExpression => VisitSelect(selectExpression),
SqlConditionalExpression sqlConditionalExpression => VisitSqlConditional(sqlConditionalExpression),
_ => base.VisitExtension(extensionExpression)
Expand Down
Loading