Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/EFCore.Design/Query/Internal/CSharpToLinqTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1245,10 +1245,10 @@ private sealed class FakeFieldInfo(
public bool IsNonNullableReferenceType { get; } = isNonNullableReferenceType;

public override object[] GetCustomAttributes(bool inherit)
=> Array.Empty<object>();
=> [];

public override object[] GetCustomAttributes(Type attributeType, bool inherit)
=> Array.Empty<object>();
=> [];

public override bool IsDefined(Type attributeType, bool inherit)
=> false;
Expand Down Expand Up @@ -1289,10 +1289,10 @@ public override RuntimeFieldHandle FieldHandle
private sealed class FakeConstructorInfo(Type type, ParameterInfo[] parameters) : ConstructorInfo
{
public override object[] GetCustomAttributes(bool inherit)
=> Array.Empty<object>();
=> [];

public override object[] GetCustomAttributes(Type attributeType, bool inherit)
=> Array.Empty<object>();
=> [];

public override bool IsDefined(Type attributeType, bool inherit)
=> false;
Expand Down
127 changes: 110 additions & 17 deletions src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -734,18 +734,52 @@ void ProcessCapturedVariables()

for (var i = 1; i < parameters.Length; i++)
{
var parameter = parameters[i];
var (parameterName, parameterType) = (parameters[i].Name!, parameters[i].ParameterType);

if (parameter.ParameterType == typeof(CancellationToken))
if (parameterType == typeof(CancellationToken))
{
continue;
}

if (_funcletizer.CalculatePathsToEvaluatableRoots(operatorMethodCall, i) is not ExpressionTreeFuncletizer.PathNode
evaluatableRootPaths)
ExpressionTreeFuncletizer.PathNode? evaluatableRootPaths;

// ExecuteUpdate requires really special handling: the function accepts a Func<SetPropertyCalls...> argument, but
// we need to run funcletization on the setter lambdas added via that Func<>.
if (operatorMethodCall.Method is
{
Name: nameof(EntityFrameworkQueryableExtensions.ExecuteUpdate)
or nameof(EntityFrameworkQueryableExtensions.ExecuteUpdateAsync),
IsGenericMethod: true
}
&& operatorMethodCall.Method.DeclaringType == typeof(EntityFrameworkQueryableExtensions))
{
// There are no captured variables in this lambda argument - skip the argument
continue;
// First, statically convert the Func<SetPropertyCalls...> to a NewArrayExpression which represents all the
// setters; since that's an expression, we can run the funcletizer on it.
var settersExpression = ProcessExecuteUpdate(operatorMethodCall);
evaluatableRootPaths = _funcletizer.CalculatePathsToEvaluatableRoots(settersExpression);

if (evaluatableRootPaths is null)
{
// There are no captured variables in this lambda argument - skip the argument
continue;
}

// If there were captured variables, generate code to evaluate and build the same NewArrayExpression at runtime,
// and then fall through to the normal logic, generating variable extractors against that NewArrayExpression
// (local var) instead of against the method argument.
code.AppendLine(
$"var setters = {parameterName}(new SetPropertyCalls<{sourceElementTypeName}>()).BuildSettersExpression();");
parameterName = "setters";
parameterType = typeof(NewArrayExpression);
}
else
{
evaluatableRootPaths = _funcletizer.CalculatePathsToEvaluatableRoots(operatorMethodCall, i);
if (evaluatableRootPaths is null)
{
// There are no captured variables in this lambda argument - skip the argument
continue;
}
}

// We have a lambda argument with captured variables. Use the information returned by the funcletizer to generate code
Expand All @@ -756,11 +790,11 @@ void ProcessCapturedVariables()
declaredQueryContextVariable = true;
}

if (!parameter.ParameterType.IsSubclassOf(typeof(Expression)))
if (!parameterType.IsSubclassOf(typeof(Expression)))
{
// Special case: this is a non-lambda argument (Skip/Take/FromSql).
// Simply add the argument directly as a parameter
code.AppendLine($"""queryContext.AddParameter("{evaluatableRootPaths.ParameterName}", {parameter.Name});""");
code.AppendLine($"""queryContext.AddParameter("{evaluatableRootPaths.ParameterName}", {parameterName});""");
continue;
}

Expand All @@ -769,7 +803,7 @@ void ProcessCapturedVariables()
// Lambda argument. Recurse through evaluatable path trees.
foreach (var child in evaluatableRootPaths.Children!)
{
GenerateCapturedVariableExtractors(parameter.Name!, parameter.ParameterType, child);
GenerateCapturedVariableExtractors(parameterName, parameterType, child);

void GenerateCapturedVariableExtractors(
string currentIdentifier,
Expand All @@ -786,12 +820,13 @@ void GenerateCapturedVariableExtractors(

var variableName = capturedVariablesPathTree.ExpressionType.Name;
variableName = char.ToLower(variableName[0]) + variableName[1..^"Expression".Length] + ++variableCounter;
code.AppendLine(
$"var {variableName} = ({capturedVariablesPathTree.ExpressionType.Name}){roslynPathSegment};");

if (capturedVariablesPathTree.Children?.Count > 0)
{
// This is an intermediate node which has captured variables in the children. Continue recursing down.
code.AppendLine(
$"var {variableName} = ({capturedVariablesPathTree.ExpressionType.Name}){roslynPathSegment};");

foreach (var child in capturedVariablesPathTree.Children)
{
GenerateCapturedVariableExtractors(variableName, capturedVariablesPathTree.ExpressionType, child);
Expand All @@ -816,7 +851,7 @@ void GenerateCapturedVariableExtractors(
{
code
.Append('"').Append(capturedVariablesPathTree.ParameterName!).AppendLine("\",")
.AppendLine($"Expression.Lambda<Func<object?>>(Expression.Convert({variableName}, typeof(object)))")
.AppendLine($"Expression.Lambda<Func<object?>>(Expression.Convert({roslynPathSegment}, typeof(object)))")
.AppendLine(".Compile(preferInterpretation: true)")
.AppendLine(".Invoke());");
}
Expand Down Expand Up @@ -1073,15 +1108,23 @@ or nameof(EntityFrameworkQueryableExtensions.ToListAsync)
QueryableMethods.GetSumWithSelector(
method.GetParameters()[1].ParameterType.GenericTypeArguments[0].GenericTypeArguments[1])),

// ExecuteDelete/Update behave just like other scalar-returning operators
// ExecuteDelete behaves just like other scalar-returning operators
nameof(EntityFrameworkQueryableExtensions.ExecuteDeleteAsync) when method.DeclaringType
== typeof(EntityFrameworkQueryableExtensions)
=> RewriteToSync(
typeof(EntityFrameworkQueryableExtensions).GetMethod(nameof(EntityFrameworkQueryableExtensions.ExecuteDelete))),
nameof(EntityFrameworkQueryableExtensions.ExecuteUpdateAsync) when method.DeclaringType
== typeof(EntityFrameworkQueryableExtensions)
=> RewriteToSync(
typeof(EntityFrameworkQueryableExtensions).GetMethod(nameof(EntityFrameworkQueryableExtensions.ExecuteUpdate))),

// ExecuteUpdate is special; it accepts a non-expression-tree argument (Func<SetPropertyCalls, SetPropertyCalls>),
// evaluates it immediately, and injects a different MethodCall node into the expression tree with the resulting setter
// expressions.
// When statically analyzing ExecuteUpdate, we have to manually perform the same thing.
nameof(EntityFrameworkQueryableExtensions.ExecuteUpdate) or nameof(EntityFrameworkQueryableExtensions.ExecuteUpdateAsync)
when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)
=> Expression.Call(
EntityFrameworkQueryableExtensions.ExecuteUpdateMethodInfo.MakeGenericMethod(
terminatingOperator.Arguments[0].Type.GetSequenceType()),
penultimateOperator,
ProcessExecuteUpdate(terminatingOperator)),

// In the regular case (sync terminating operator which needs to stay in the query tree), simply compose the terminating
// operator over the penultimate and return that.
Expand Down Expand Up @@ -1116,6 +1159,56 @@ MethodCallExpression RewriteToSync(MethodInfo? syncMethod)
}
}

// Accepts an expression tree representing a series of SetProperty() calls, parses them and passes them through the SetPropertyCalls
// builder; returns the resulting NewArrayExpression representing all the setters.
private static NewArrayExpression ProcessExecuteUpdate(MethodCallExpression executeUpdateCall)
{
var setPropertyCalls = Activator.CreateInstance<SetPropertyCalls>();
var settersLambda = (LambdaExpression)executeUpdateCall.Arguments[1];
var settersParameter = settersLambda.Parameters.Single();
var expression = settersLambda.Body;

while (expression != settersParameter)
{
if (expression is MethodCallExpression
{
Method:
{
IsGenericMethod: true,
Name: nameof(SetPropertyCalls<int>.SetProperty),
DeclaringType.IsGenericType: true,
},
Arguments:
[
UnaryExpression { NodeType: ExpressionType.Quote, Operand: LambdaExpression propertySelector },
Expression valueSelector
]
} methodCallExpression
&& methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(SetPropertyCalls<>))
{
if (valueSelector is UnaryExpression
{
NodeType: ExpressionType.Quote,
Operand: LambdaExpression unwrappedValueSelector
})
{
setPropertyCalls.SetProperty(propertySelector, unwrappedValueSelector);
}
else
{
setPropertyCalls.SetProperty(propertySelector, valueSelector);
}

expression = methodCallExpression.Object;
continue;
}

throw new InvalidOperationException(RelationalStrings.InvalidArgumentToExecuteUpdate);
}

return setPropertyCalls.BuildSettersExpression();
}

/// <summary>
/// Contains information on a failure to precompile a specific query in the user's source code.
/// Includes information about the query, its location, and the exception that occured.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,22 @@ public partial class RelationalQueryableMethodTranslatingExpressionVisitor
typeof(RelationalSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterValueExtractor))!;

/// <inheritdoc />
protected override UpdateExpression? TranslateExecuteUpdate(ShapedQueryExpression source, LambdaExpression setPropertyCalls)
protected override UpdateExpression? TranslateExecuteUpdate(ShapedQueryExpression source, IReadOnlyList<ExecuteUpdateSetter> setters)
{
if (setters.Count == 0)
{
throw new UnreachableException("Empty setters list");
}

// Our source may have IncludeExpressions because of owned entities or auto-include; unwrap these, as they're meaningless for
// ExecuteUpdate's lambdas. Note that we don't currently support updates across tables.
source = source.UpdateShaperExpression(new IncludePruner().Visit(source.ShaperExpression));

var setters = new List<(LambdaExpression PropertySelector, Expression ValueExpression)>();
PopulateSetPropertyCalls(setPropertyCalls.Body, setters, setPropertyCalls.Parameters[0]);
if (TranslationErrorDetails != null)
{
return null;
}

if (setters.Count == 0)
{
AddTranslationErrorDetails(RelationalStrings.NoSetPropertyInvocation);
return null;
}

// Translate the setters: the left (property) selectors get translated to ColumnExpressions, the right (value) selectors to
// arbitrary SqlExpressions.
// Note that if the query isn't natively supported, we'll do a pushdown (see PushdownWithPkInnerJoinPredicate below); if that
Expand Down Expand Up @@ -67,42 +64,9 @@ public partial class RelationalQueryableMethodTranslatingExpressionVisitor

return PushdownWithPkInnerJoinPredicate();

void PopulateSetPropertyCalls(
Expression expression,
List<(LambdaExpression, Expression)> list,
ParameterExpression parameter)
{
switch (expression)
{
case ParameterExpression p
when parameter == p:
break;

case MethodCallExpression
{
Method:
{
IsGenericMethod: true,
Name: nameof(SetPropertyCalls<int>.SetProperty),
DeclaringType.IsGenericType: true
}
} methodCallExpression
when methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(SetPropertyCalls<>):
list.Add(((LambdaExpression)methodCallExpression.Arguments[0], methodCallExpression.Arguments[1]));

PopulateSetPropertyCalls(methodCallExpression.Object!, list, parameter);

break;

default:
AddTranslationErrorDetails(RelationalStrings.InvalidArgumentToExecuteUpdate);
break;
}
}

bool TranslateSetters(
ShapedQueryExpression source,
List<(LambdaExpression PropertySelector, Expression ValueExpression)> setters,
IReadOnlyList<ExecuteUpdateSetter> setters,
[NotNullWhen(true)] out List<ColumnValueSetter>? translatedSetters,
[NotNullWhen(true)] out TableExpressionBase? targetTable)
{
Expand Down Expand Up @@ -464,7 +428,7 @@ SqlParameterExpression parameter
var inner = source;
var outerParameter = Expression.Parameter(entityType.ClrType);
var outerKeySelector = Expression.Lambda(outerParameter.CreateKeyValuesExpression(pk.Properties), outerParameter);
var firstPropertyLambdaExpression = setters[0].Item1;
var firstPropertyLambdaExpression = setters[0].PropertySelector;
var entitySource = GetEntitySource(RelationalDependencies.Model, firstPropertyLambdaExpression.Body);
var innerKeySelector = Expression.Lambda(
entitySource.CreateKeyValuesExpression(pk.Properties), firstPropertyLambdaExpression.Parameters);
Expand All @@ -481,6 +445,7 @@ SqlParameterExpression parameter

var propertyReplacement = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Outer");
var valueReplacement = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Inner");
var rewrittenSetters = new ExecuteUpdateSetter[setters.Count];
for (var i = 0; i < setters.Count; i++)
{
var (propertyExpression, valueExpression) = setters[i];
Expand All @@ -499,14 +464,14 @@ SqlParameterExpression parameter
transparentIdentifierParameter)
: valueExpression;

setters[i] = (propertyExpression, valueExpression);
rewrittenSetters[i] = new(propertyExpression, valueExpression);
}

tableExpression = (TableExpression)outerSelectExpression.Tables[0];

// Re-translate the property selectors to get column expressions pointing to the new outer select expression (the original one
// has been pushed down into a subquery).
if (!TranslateSetters(outer, setters, out var translatedSetters, out _))
if (!TranslateSetters(outer, rewrittenSetters, out var translatedSetters, out _))
{
return null;
}
Expand Down
7 changes: 6 additions & 1 deletion src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,12 @@ protected virtual SqlExpression VisitSqlParameter(
bool allowOptimizedExpansion,
out bool nullable)
{
var parameterValue = ParameterValues[sqlParameterExpression.Name];
if (!ParameterValues.TryGetValue(sqlParameterExpression.Name, out var parameterValue))
{
throw new UnreachableException(
$"Encountered SqlParameter with name '{sqlParameterExpression.Name}', but such a parameter does not exist.");
}

nullable = parameterValue == null;

if (nullable)
Expand Down
Loading