Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<Project>
<PropertyGroup>
<TargetFrameworks>net8.0;net10.0</TargetFrameworks>
<TargetFrameworks>net8.0;net9.0;net10.0</TargetFrameworks>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<LangVersion>12.0</LangVersion>
<LangVersion Condition="'$(TargetFramework)' == 'net10.0'">14.0</LangVersion>
Expand Down
3 changes: 3 additions & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@
</PropertyGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net8.0'">
<PackageVersion Include="Microsoft.EntityFrameworkCore" Version="8.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Relational" Version="8.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Sqlite" Version="8.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="8.0.0" />
<PackageVersion Include="Basic.Reference.Assemblies.Net80" Version="1.8.3" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net9.0'">
<PackageVersion Include="Microsoft.EntityFrameworkCore" Version="9.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Relational" Version="9.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Sqlite" Version="9.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="9.0.0" />
<PackageVersion Include="Basic.Reference.Assemblies.Net90" Version="1.8.3" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net10.0'">
<PackageVersion Include="Microsoft.EntityFrameworkCore" Version="10.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Relational" Version="10.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.Sqlite" Version="10.0.0" />
<PackageVersion Include="Microsoft.EntityFrameworkCore.SqlServer" Version="10.0.0" />
<PackageVersion Include="Basic.Reference.Assemblies.Net100" Version="1.8.3" />
Expand Down
38 changes: 38 additions & 0 deletions src/EntityFrameworkCore.Projectables.Abstractions/Variable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
namespace EntityFrameworkCore.Projectables;

/// <summary>
/// Utility class for marking reused local variables in projectable expression trees,
/// enabling the SQL generator to hoist shared computations into <c>CROSS APPLY</c> (SQL Server)
/// or <c>CROSS JOIN LATERAL</c> (PostgreSQL) inline subqueries.
/// </summary>
public static class Variable
{
/// <summary>
/// Identity function that marks a reused local variable in a generated expression tree.
/// <para>
/// When the same <paramref name="name"/> appears more than once in a generated
/// expression tree (because the corresponding local variable was referenced multiple times
/// in a <c>[Projectable(AllowBlockBody = true)]</c> method body), the SQL generator hoists
/// the shared computation into a single inline subquery evaluated exactly once per row:
/// <code>
/// -- SQL Server
/// CROSS APPLY (SELECT &lt;inner expression&gt; AS [name]) AS [v]
///
/// -- PostgreSQL
/// CROSS JOIN LATERAL (SELECT &lt;inner expression&gt; AS "name") AS "v"
/// </code>
/// </para>
/// <para>
/// At runtime this method is a pure identity function: it returns
/// <paramref name="value"/> unchanged and has no observable effect.
/// </para>
/// </summary>
/// <typeparam name="T">The type of the value.</typeparam>
/// <param name="name">
/// The original local variable name, used to correlate multiple uses of the same
/// computation within a single expression tree.
/// </param>
/// <param name="value">The value to pass through unchanged.</param>
/// <returns><paramref name="value"/> unchanged.</returns>
public static T Wrap<T>(string name, T value) => value;
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ internal class BlockStatementConverter
private readonly SourceProductionContext _context;
private readonly ExpressionSyntaxRewriter _expressionRewriter;
private readonly Dictionary<string, ExpressionSyntax> _localVariables = new();
private readonly Dictionary<string, int> _localVariableReferenceCount = new();

// Pre-computed reference counts for each local variable across the code statements
// (statements that are not local declarations). Variables with count > 1 are wrapped
// in Variable.Wrap so the SQL generator can hoist them into a CROSS APPLY / LATERAL
// inline subquery, computing the expression exactly once per row.
private IReadOnlyDictionary<string, int> _preComputedRefCounts = new Dictionary<string, int>();

public BlockStatementConverter(SourceProductionContext context, ExpressionSyntaxRewriter expressionRewriter)
{
Expand Down Expand Up @@ -90,7 +97,12 @@ public BlockStatementConverter(SourceProductionContext context, ExpressionSyntax
return null;
}

// Right-to-left fold: build nested expressions so that each statement wraps the
// Pre-compute how many times each local variable is referenced in the code
// statements (non-declaration statements). Variables with count > 1 will be
// wrapped in Variable.Wrap in the generated expression tree so the SQL generator
// can identify shared computations and hoist them into a CROSS APPLY / LATERAL
// inline subquery.
_preComputedRefCounts = ComputeCodeStatementRefCounts(codeStatements);
// next as its "fallthrough" branch. This naturally handles chains like:
// if (a) return 1; if (b) return 2; return 3;
// => a ? 1 : (b ? 2 : 3)
Expand Down Expand Up @@ -346,9 +358,36 @@ private bool TryProcessLocalDeclaration(LocalDeclarationStatementSyntax localDec

/// <summary>
/// Replaces references to local variables in the given expression with their initializer expressions.
/// Also tracks how many times each variable is referenced via <see cref="_localVariableReferenceCount"/>.
/// Variables referenced more than once in the final expression are wrapped in
/// <c>Variable.Wrap("name", expr)</c> so the SQL generator can hoist them into a
/// <c>CROSS APPLY</c> / <c>CROSS JOIN LATERAL</c> inline subquery.
/// </summary>
private ExpressionSyntax ReplaceLocalVariables(ExpressionSyntax expression)
=> (ExpressionSyntax)new LocalVariableReplacer(_localVariables).Visit(expression);
=> (ExpressionSyntax)new LocalVariableReplacer(_localVariables, _localVariableReferenceCount, _preComputedRefCounts).Visit(expression);

/// <summary>
/// Counts the number of standalone identifier references to each local variable
/// in the given code statements (non-declaration statements).
/// </summary>
private static IReadOnlyDictionary<string, int> ComputeCodeStatementRefCounts(
IReadOnlyList<StatementSyntax> codeStatements)
{
// We deliberately don't use a set here — we want to count every occurrence,
// not just whether the identifier appears at all.
var counts = new Dictionary<string, int>(StringComparer.Ordinal);

foreach (var stmt in codeStatements)
{
foreach (var identifier in stmt.DescendantNodes().OfType<IdentifierNameSyntax>())
{
var name = identifier.Identifier.ValueText;
counts[name] = counts.TryGetValue(name, out var existing) ? existing + 1 : 1;
}
}

return counts;
}

private static LiteralExpressionSyntax DefaultLiteral()
=> SyntaxFactory.LiteralExpression(
Expand Down Expand Up @@ -450,21 +489,62 @@ private void ReportUnsupportedStatement(StatementSyntax statement, string member
private class LocalVariableReplacer : CSharpSyntaxRewriter
{
private readonly Dictionary<string, ExpressionSyntax> _localVariables;
private readonly Dictionary<string, int> _referenceCount;
private readonly IReadOnlyDictionary<string, int> _preComputedRefCounts;

public LocalVariableReplacer(Dictionary<string, ExpressionSyntax> localVariables)
public LocalVariableReplacer(
Dictionary<string, ExpressionSyntax> localVariables,
Dictionary<string, int> referenceCount,
IReadOnlyDictionary<string, int> preComputedRefCounts)
{
_localVariables = localVariables;
_referenceCount = referenceCount;
_preComputedRefCounts = preComputedRefCounts;
}

public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
{
if (_localVariables.TryGetValue(node.Identifier.Text, out var replacement))
{
return SyntaxFactory.ParenthesizedExpression(replacement.WithoutTrivia())
.WithTriviaFrom(node);
var varName = node.Identifier.Text;
_referenceCount[varName] = _referenceCount.TryGetValue(varName, out var count)
? count + 1
: 1;

var inner = SyntaxFactory.ParenthesizedExpression(replacement.WithoutTrivia());

// When the variable is referenced more than once in the code statements,
// wrap the substituted expression in Variable.Wrap("name", expr).
// This embeds a reuse marker into the generated expression tree so the
// SQL generator can hoist shared sub-computations into a CROSS APPLY /
// CROSS JOIN LATERAL inline subquery, evaluated exactly once per row.
if (_preComputedRefCounts.TryGetValue(varName, out var preCount) && preCount > 1)
{
return BuildVariableWrapCall(varName, inner).WithTriviaFrom(node);
}

return inner.WithTriviaFrom(node);
}

return base.VisitIdentifierName(node);
}
}
}

/// <summary>
/// Builds a <c>global::EntityFrameworkCore.Projectables.Variable.Wrap("name", value)</c>
/// invocation expression.
/// </summary>
private static ExpressionSyntax BuildVariableWrapCall(string name, ExpressionSyntax value)
=> SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.ParseName("global::EntityFrameworkCore.Projectables.Variable"),
SyntaxFactory.IdentifierName("Wrap")),
SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(new[]
{
SyntaxFactory.Argument(
SyntaxFactory.LiteralExpression(
SyntaxKind.StringLiteralExpression,
SyntaxFactory.Literal(name))),
SyntaxFactory.Argument(value),
})));
}}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using EntityFrameworkCore.Projectables.Infrastructure;
using EntityFrameworkCore.Projectables.Infrastructure.Internal;
using EntityFrameworkCore.Projectables.Query;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata.Conventions;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;
Expand Down Expand Up @@ -41,6 +42,10 @@ public void ApplyServices(IServiceCollection services)
// Register a convention that will ignore properties marked with the ProjectableAttribute
services.AddScoped<IConventionSetPlugin, ProjectablePropertiesNotMappedConventionPlugin>();

// Translate Variable.Wrap(name, expr) calls to VariableWrapSqlExpression so the
// ProjectablesQuerySqlGenerator can decide whether to inline or CROSS-APPLY them.
services.AddSingleton<IMethodCallTranslatorPlugin, VariableWrapTranslatorPlugin>();

static object CreateTargetInstance(IServiceProvider services, ServiceDescriptor descriptor)
{
if (descriptor.ImplementationInstance is not null)
Expand All @@ -57,6 +62,47 @@ static object CreateTargetInstance(IServiceProvider services, ServiceDescriptor
// Custom convention to handle global query filters, etc
services.AddScoped<IConventionSetPlugin, CustomConventionSetPlugin>();

// Register the SQL generator factory that emits CROSS APPLY / CROSS JOIN LATERAL
// subqueries for reused local variables in block-bodied projectable methods.
services.Replace(ServiceDescriptor.Scoped<IQuerySqlGeneratorFactory, ProjectablesQuerySqlGeneratorFactory>());

// Wrap the query translation postprocessor to handle VariableWrapSqlExpression before
// EF Core's SqlNullabilityProcessor encounters it.
var postprocessorDescriptor = services.FirstOrDefault(x => x.ServiceType == typeof(IQueryTranslationPostprocessorFactory));
if (postprocessorDescriptor is not null)
{
var decoratorObjectFactory = ActivatorUtilities.CreateFactory(
typeof(VariableWrapQueryTranslationPostprocessorFactory),
new[] { postprocessorDescriptor.ServiceType });

services.Replace(ServiceDescriptor.Describe(
postprocessorDescriptor.ServiceType,
serviceProvider => decoratorObjectFactory(serviceProvider, new[] { CreateTargetInstance(serviceProvider, postprocessorDescriptor) }),
postprocessorDescriptor.Lifetime
));
}

#if NET8_0 || NET9_0
// In EF Core 8/9 the execution-time SqlNullabilityProcessor (run inside
// RelationalParameterBasedSqlProcessor.Optimize) throws on unknown
// TableExpressionBase subtypes — including our InlineSubqueryExpression.
// ProjectablesParameterBasedSqlProcessorFactory decorates the provider's factory
// and temporarily hides InlineSubqueryExpression tables around the nullability pass.
var paramSqlDescriptor = services.FirstOrDefault(x => x.ServiceType == typeof(IRelationalParameterBasedSqlProcessorFactory));
if (paramSqlDescriptor is not null)
{
var paramFactory = ActivatorUtilities.CreateFactory(
typeof(ProjectablesParameterBasedSqlProcessorFactory),
new[] { paramSqlDescriptor.ServiceType });

services.Replace(ServiceDescriptor.Describe(
paramSqlDescriptor.ServiceType,
sp => paramFactory(sp, new[] { CreateTargetInstance(sp, paramSqlDescriptor) }),
paramSqlDescriptor.Lifetime
));
}
#endif

if (_compatibilityMode is CompatibilityMode.Full)
{
var targetDescriptor = services.FirstOrDefault(x => x.ServiceType == typeof(IQueryCompiler));
Expand Down
Loading