Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/EFCore.Relational/Query/ISqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,13 @@ public interface ISqlExpressionFactory
/// <param name="operand">An expression to compare with <see cref="CaseWhenClause.Test" /> in <paramref name="whenClauses" />.</param>
/// <param name="whenClauses">A list of <see cref="CaseWhenClause" /> to compare or evaluate and get result from.</param>
/// <param name="elseResult">A value to return if no <paramref name="whenClauses" /> matches, if any.</param>
/// <param name="existingExpr">An optional expression that can be re-used if it matches the new expression.</param>
/// <returns>An expression representing a CASE statement in a SQL tree.</returns>
SqlExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult);
SqlExpression Case(
SqlExpression? operand,
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult,
SqlExpression? existingExpr = null);

/// <summary>
/// Creates a new <see cref="CaseExpression" /> which represent a CASE statement in a SQL tree.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ private Expression VisitCase(CaseExpression caseExpression)

var elseResult = (SqlExpression?)Visit(caseExpression.ElseResult);

return caseExpression.Update(operand, whenClauses, elseResult);
return _sqlExpressionFactory.Case(operand, whenClauses, elseResult, caseExpression);
}

private Expression VisitSelect(SelectExpression selectExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,6 @@ protected override Expression VisitExtension(Expression extensionExpression)
return shapedQueryExpression.Update(newQueryExpression, newShaperExpression);
}

// Only applies to 'CASE WHEN condition...' not 'CASE operand WHEN...'
if (extensionExpression is CaseExpression
{
Operand: null, ElseResult: CaseExpression { Operand: null } nestedCaseExpression
} caseExpression)
{
return VisitExtension(
_sqlExpressionFactory.Case(
caseExpression.WhenClauses.Union(nestedCaseExpression.WhenClauses).ToList(),
nestedCaseExpression.ElseResult));
}

if (extensionExpression is SqlBinaryExpression sqlBinaryExpression)
{
return SimplifySqlBinary(sqlBinaryExpression);
Expand Down
110 changes: 107 additions & 3 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,11 @@ public virtual SqlExpression Negate(SqlExpression operand)
=> MakeUnary(ExpressionType.Negate, operand, operand.Type, operand.TypeMapping)!;

/// <inheritdoc />
public virtual SqlExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
public virtual SqlExpression Case(
SqlExpression? operand,
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult,
SqlExpression? existingExpr = null)
{
RelationalTypeMapping? testTypeMapping;
if (operand == null)
Expand Down Expand Up @@ -714,13 +718,113 @@ public virtual SqlExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhen
var typeMappedWhenClauses = new List<CaseWhenClause>();
foreach (var caseWhenClause in whenClauses)
{
var test = caseWhenClause.Test;

if (operand == null && test is CaseExpression { Operand: null, WhenClauses: [var nestedSingleClause] } testExpr)
{
if (nestedSingleClause.Result is SqlConstantExpression { Value: true }
&& testExpr.ElseResult is null or SqlConstantExpression { Value: false or null })
{
// WHEN CASE
// WHEN x THEN TRUE
// ELSE FALSE/NULL
// END THEN y
// simplifies to
// WHEN x THEN y
test = nestedSingleClause.Test;
}
else if (nestedSingleClause.Result is SqlConstantExpression { Value: false or null }
&& testExpr.ElseResult is SqlConstantExpression { Value: true })
{
// same for the negated results
test = Not(nestedSingleClause.Test);
}
}

typeMappedWhenClauses.Add(
new CaseWhenClause(
ApplyTypeMapping(caseWhenClause.Test, testTypeMapping),
ApplyTypeMapping(test, testTypeMapping),
ApplyTypeMapping(caseWhenClause.Result, resultTypeMapping)));
}

return new CaseExpression(operand, typeMappedWhenClauses, elseResult);
if (operand is null && elseResult is CaseExpression { Operand: null } nestedCaseExpression)
{
typeMappedWhenClauses.AddRange(nestedCaseExpression.WhenClauses);
elseResult = nestedCaseExpression.ElseResult;
}

typeMappedWhenClauses = typeMappedWhenClauses
.Where(c => !IsSkipped(c))
.TakeUpTo(IsMatched)
.DistinctBy(c => c.Test)
.ToList();

// CASE
// ...
// WHEN TRUE THEN a
// ELSE b
// END
// simplifies to
// CASE
// ...
// ELSE a
// END
if (typeMappedWhenClauses.Count > 0 && IsMatched(typeMappedWhenClauses[^1]))
{
elseResult = typeMappedWhenClauses[^1].Result;
typeMappedWhenClauses.RemoveAt(typeMappedWhenClauses.Count - 1);
}

var nullResult = Constant(null, elseResult?.Type ?? whenClauses[0].Result.Type, resultTypeMapping);

// if there are no whenClauses left (e.g. their tests evaluated to false):
// - if there is Else block, return it
// - if there is no Else block, return null
if (typeMappedWhenClauses.Count == 0)
{
return elseResult ?? nullResult;
}

// omit `ELSE NULL` (this makes it easier to compare/reuse expressions)
if (elseResult is SqlConstantExpression { Value: null })
{
elseResult = null;
}

// CASE
// ...
// WHEN x THEN CASE
// WHEN y THEN a
// ELSE b
// END
// ELSE b
// END
// simplifies to
// CASE
// ...
// WHEN x AND y THEN a
// ELSE b
// END
if (operand == null
&& typeMappedWhenClauses[^1].Result is CaseExpression { Operand: null, WhenClauses: [var lastClause] } lastCase
&& Equals(elseResult, lastCase.ElseResult))
{
typeMappedWhenClauses[^1] = new(AndAlso(typeMappedWhenClauses[^1].Test, lastClause.Test), lastClause.Result);
elseResult = lastCase.ElseResult;
}

return existingExpr is CaseExpression expr
&& operand == expr.Operand
&& typeMappedWhenClauses.SequenceEqual(expr.WhenClauses)
&& elseResult == expr.ElseResult
? expr
: new CaseExpression(operand, typeMappedWhenClauses, elseResult);

bool IsSkipped(CaseWhenClause clause)
=> operand is null && clause.Test is SqlConstantExpression { Value: false or null };

bool IsMatched(CaseWhenClause clause)
=> operand is null && clause.Test is SqlConstantExpression { Value: true };
}

/// <inheritdoc />
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
return elseResult ?? _sqlExpressionFactory.Constant(null, caseExpression.Type, caseExpression.TypeMapping);
}

return caseExpression.Update(operand, whenClauses, elseResult);
return _sqlExpressionFactory.Case(operand, whenClauses, elseResult, caseExpression);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ protected override Expression VisitCase(CaseExpression caseExpression)

_isSearchCondition = parentSearchCondition;

return ApplyConversion(caseExpression.Update(operand, whenClauses, elseResult), condition: false);
return ApplyConversion(_sqlExpressionFactory.Case(operand, whenClauses, elseResult, caseExpression), condition: false);
}

/// <summary>
Expand Down
12 changes: 12 additions & 0 deletions src/Shared/EnumerableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ public static IEnumerable<T> Distinct<T>(
where T : class
=> source.Distinct(new DynamicEqualityComparer<T>(comparer));

public static IEnumerable<T> TakeUpTo<T>(this IEnumerable<T> source, Func<T, bool> predicate)
Comment thread
maumar marked this conversation as resolved.
{
foreach (var item in source)
{
yield return item;
if (predicate(item))
{
yield break;
}
}
}

private sealed class DynamicEqualityComparer<T> : IEqualityComparer<T>
where T : class
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2012,6 +2012,62 @@ public override Task Select_over_10_nested_ternary_condition(bool async)
SELECT ((c["CustomerID"] = "1") ? "01" : ((c["CustomerID"] = "2") ? "02" : ((c["CustomerID"] = "3") ? "03" : ((c["CustomerID"] = "4") ? "04" : ((c["CustomerID"] = "5") ? "05" : ((c["CustomerID"] = "6") ? "06" : ((c["CustomerID"] = "7") ? "07" : ((c["CustomerID"] = "8") ? "08" : ((c["CustomerID"] = "9") ? "09" : ((c["CustomerID"] = "10") ? "10" : ((c["CustomerID"] = "11") ? "11" : null))))))))))) AS c
FROM root c
WHERE (c["Discriminator"] = "Customer")
""");
});

public override Task Select_conditional_drops_false(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Select_conditional_drops_false(a);

AssertSql(
"""
SELECT (((c["OrderID"] % 2) = 0) ? c["OrderID"] : -(c["OrderID"])) AS c
FROM root c
WHERE (c["Discriminator"] = "Order")
""");
});

public override Task Select_conditional_terminates_at_true(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Select_conditional_terminates_at_true(a);

AssertSql(
"""
SELECT (((c["OrderID"] % 2) = 0) ? c["OrderID"] : 0) AS c
FROM root c
WHERE (c["Discriminator"] = "Order")
""");
});

public override Task Select_conditional_flatten_nested_results(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Select_conditional_flatten_nested_results(a);

AssertSql(
"""
SELECT (((c["OrderID"] % 2) = 0) ? (((c["OrderID"] % 5) = 0) ? -(c["OrderID"]) : c["OrderID"]) : c["OrderID"]) AS c
FROM root c
WHERE (c["Discriminator"] = "Order")
""");
});

public override Task Select_conditional_flatten_nested_tests(bool async)
=> Fixture.NoSyncTest(
async, async a =>
{
await base.Select_conditional_flatten_nested_tests(a);

AssertSql(
"""
SELECT ((((c["OrderID"] % 2) = 0) ? false : true) ? c["OrderID"] : -(c["OrderID"])) AS c
FROM root c
WHERE (c["Discriminator"] = "Order")
""");
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,52 @@ public virtual Task Select_over_10_nested_ternary_condition(bool isAsync)
? "11"
: null);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_conditional_drops_false(bool async)
=> AssertQueryScalar(
async,
ss => from o in ss.Set<Order>()
select o.OrderID % 2 == 0
? o.OrderID
: false
? 0
: -o.OrderID );

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_conditional_terminates_at_true(bool async)
=> AssertQueryScalar(
async,
ss => from o in ss.Set<Order>()
select o.OrderID % 2 == 0
? o.OrderID
: true
? 0
: -o.OrderID );

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_conditional_flatten_nested_results(bool async)
=> AssertQueryScalar(
async,
ss => from o in ss.Set<Order>()
select o.OrderID % 2 == 0
? o.OrderID % 5 == 0
? -o.OrderID
: o.OrderID
: o.OrderID);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_conditional_flatten_nested_tests(bool async)
=> AssertQueryScalar(
async,
ss => from o in ss.Set<Order>()
select (o.OrderID % 2 == 0 ? false : true)
? o.OrderID
: -o.OrderID);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Projection_in_a_subquery_should_be_liftable(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1677,7 +1677,6 @@ public override async Task Conditional_expression_with_conditions_does_not_colla
"""
SELECT CASE
WHEN [c0].[Id] IS NOT NULL THEN [c0].[Processed] ^ CAST(1 AS bit)
ELSE NULL
END AS [Processing]
FROM [Carts] AS [c]
LEFT JOIN [Configuration] AS [c0] ON [c].[ConfigurationId] = [c0].[Id]
Expand Down Expand Up @@ -2398,7 +2397,6 @@ FROM [Customers] AS [c]
LEFT JOIN [Countries] AS [c1] ON [c0].[CountryId] = [c1].[Id]
WHERE CASE
WHEN [c0].[Id] IS NOT NULL THEN [c1].[CountryName]
ELSE NULL
END = N'COUNTRY'
""");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2990,15 +2990,13 @@ public override async Task Project_collection_and_nested_conditional(bool async)
WHEN [l].[Id] = 1 THEN N'01'
WHEN [l].[Id] = 2 THEN N'02'
WHEN [l].[Id] = 3 THEN N'03'
ELSE NULL
END
FROM [LevelOne] AS [l]
LEFT JOIN [LevelTwo] AS [l0] ON [l].[Id] = [l0].[OneToMany_Optional_Inverse2Id]
WHERE CASE
WHEN [l].[Id] = 1 THEN N'01'
WHEN [l].[Id] = 2 THEN N'02'
WHEN [l].[Id] = 3 THEN N'03'
ELSE NULL
END = N'02'
ORDER BY [l].[Id], [l0].[Id]
""");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3878,7 +3878,6 @@ public override async Task Project_collection_and_nested_conditional(bool async)
WHEN [l].[Id] = 1 THEN N'01'
WHEN [l].[Id] = 2 THEN N'02'
WHEN [l].[Id] = 3 THEN N'03'
ELSE NULL
END
FROM [Level1] AS [l]
LEFT JOIN (
Expand All @@ -3892,7 +3891,6 @@ WHERE CASE
WHEN [l].[Id] = 1 THEN N'01'
WHEN [l].[Id] = 2 THEN N'02'
WHEN [l].[Id] = 3 THEN N'03'
ELSE NULL
END = N'02'
ORDER BY [l].[Id], [l1].[c]
""");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4425,14 +4425,12 @@ public override async Task Project_collection_and_nested_conditional(bool async)
WHEN [l].[Id] = 1 THEN N'01'
WHEN [l].[Id] = 2 THEN N'02'
WHEN [l].[Id] = 3 THEN N'03'
ELSE NULL
END
FROM [LevelOne] AS [l]
WHERE CASE
WHEN [l].[Id] = 1 THEN N'01'
WHEN [l].[Id] = 2 THEN N'02'
WHEN [l].[Id] = 3 THEN N'03'
ELSE NULL
END = N'02'
ORDER BY [l].[Id]
""",
Expand All @@ -4445,7 +4443,6 @@ WHERE CASE
WHEN [l].[Id] = 1 THEN N'01'
WHEN [l].[Id] = 2 THEN N'02'
WHEN [l].[Id] = 3 THEN N'03'
ELSE NULL
END = N'02'
ORDER BY [l].[Id], [l2].[Id]
""");
Expand Down
Loading