Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -660,20 +660,10 @@ private SqlExpression Not(SqlExpression operand, SqlExpression? existingExpressi
SqlBinaryExpression { OperatorType: ExpressionType.OrElse } binary
=> AndAlso(Not(binary.Left), Not(binary.Right)),

// use equality where possible
// !(a == true) -> a == false
// !(a == false) -> a == true
SqlBinaryExpression { OperatorType: ExpressionType.Equal, Right: SqlConstantExpression { Value: bool } } binary
=> Equal(binary.Left, Not(binary.Right)),

// !(true == a) -> false == a
// !(false == a) -> true == a
SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: SqlConstantExpression { Value: bool } } binary
=> Equal(Not(binary.Left), binary.Right),

// !(a == b) -> a != b
SqlBinaryExpression { OperatorType: ExpressionType.Equal } sqlBinaryOperand => NotEqual(
sqlBinaryOperand.Left, sqlBinaryOperand.Right),

// !(a != b) -> a == b
SqlBinaryExpression { OperatorType: ExpressionType.NotEqual } sqlBinaryOperand => Equal(
sqlBinaryOperand.Left, sqlBinaryOperand.Right),
Expand Down
143 changes: 87 additions & 56 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ SqlExpression ProcessJoinPredicate(SqlExpression predicate)
right,
leftNullable,
rightNullable,
optimize: true,
out _);

return result;
Expand Down Expand Up @@ -1107,6 +1108,7 @@ protected virtual SqlExpression VisitSqlBinary(
right,
leftNullable,
rightNullable,
optimize,
out nullable);

if (optimized is SqlUnaryExpression { Operand: ColumnExpression optimizedUnaryColumnOperand } optimizedUnary)
Expand All @@ -1124,7 +1126,7 @@ protected virtual SqlExpression VisitSqlBinary(
// we assume that NullSemantics rewrite is only needed (on the current level)
// if the optimization didn't make any changes.
// Reason is that optimization can/will change the nullability of the resulting expression
// and that inforation is not tracked/stored anywhere
// and that information is not tracked/stored anywhere
// so we can no longer rely on nullabilities that we computed earlier (leftNullable, rightNullable)
// when performing null semantics rewrite.
// It should be fine because current optimizations *radically* change the expression
Expand Down Expand Up @@ -1451,6 +1453,7 @@ private SqlExpression OptimizeComparison(
SqlExpression right,
bool leftNullable,
bool rightNullable,
bool optimize,
out bool nullable)
{
var leftNullValue = leftNullable && left is SqlConstantExpression or SqlParameterExpression;
Expand Down Expand Up @@ -1531,47 +1534,20 @@ private SqlExpression OptimizeComparison(
&& !rightNullable
&& sqlBinaryExpression.OperatorType is ExpressionType.Equal or ExpressionType.NotEqual)
{
var leftUnary = left as SqlUnaryExpression;
var rightUnary = right as SqlUnaryExpression;

var leftNegated = IsLogicalNot(leftUnary);
var rightNegated = IsLogicalNot(rightUnary);

if (leftNegated)
{
left = leftUnary!.Operand;
}

if (rightNegated)
{
right = rightUnary!.Operand;
}

// a == b <=> !a == !b -> a == b
// !a == b <=> a == !b -> a != b
// a != b <=> !a != !b -> a != b
// !a != b <=> a != !b -> a == b

nullable = false;

return sqlBinaryExpression.OperatorType == ExpressionType.Equal ^ leftNegated == rightNegated
? _sqlExpressionFactory.NotEqual(left, right)
: _sqlExpressionFactory.Equal(left, right);
return OptimizeBooleanComparison(sqlBinaryExpression, left, right, optimize);
}

nullable = false;

return sqlBinaryExpression.Update(left, right);
}

private SqlExpression RewriteNullSemantics(
private SqlExpression OptimizeBooleanComparison(
SqlBinaryExpression sqlBinaryExpression,
SqlExpression left,
SqlExpression right,
bool leftNullable,
bool rightNullable,
bool optimize,
out bool nullable)
bool optimize)
{
var leftUnary = left as SqlUnaryExpression;
var rightUnary = right as SqlUnaryExpression;
Expand All @@ -1589,22 +1565,49 @@ private SqlExpression RewriteNullSemantics(
right = rightUnary!.Operand;
}

var notEqual = sqlBinaryExpression.OperatorType == ExpressionType.Equal ^ leftNegated == rightNegated;

// prefer equality in predicates
if (optimize && notEqual && left.Type == typeof(bool))
{
if (right is ColumnExpression && (left is not ColumnExpression || leftNegated))
{
left = _sqlExpressionFactory.Not(left);
}
else
{
right = _sqlExpressionFactory.Not(right);
}

return _sqlExpressionFactory.Equal(left, right);
}

// a == b <=> !a == !b -> a == b
// !a == b <=> a == !b -> a != b
// a != b <=> !a != !b -> a != b
// !a != b <=> a != !b -> a == b

return notEqual
? _sqlExpressionFactory.NotEqual(left, right)
: _sqlExpressionFactory.Equal(left, right);
}

private SqlExpression RewriteNullSemantics(
SqlBinaryExpression sqlBinaryExpression,
SqlExpression left,
SqlExpression right,
bool leftNullable,
bool rightNullable,
bool optimize,
out bool nullable)
{
var leftIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(left), leftNullable);
var leftIsNotNull = _sqlExpressionFactory.Not(leftIsNull);
var leftIsNotNull = OptimizeNotExpression(_sqlExpressionFactory.Not(leftIsNull));

var rightIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(right), rightNullable);
var rightIsNotNull = _sqlExpressionFactory.Not(rightIsNull);
var rightIsNotNull = OptimizeNotExpression(_sqlExpressionFactory.Not(rightIsNull));

SqlExpression body;
if (leftNegated == rightNegated)
{
body = _sqlExpressionFactory.Equal(left, right);
}
else
{
// a == !b and !a == b in SQL evaluate the same as a != b
body = _sqlExpressionFactory.NotEqual(left, right);
}
var body = OptimizeBooleanComparison(sqlBinaryExpression, left, right, optimize);

// optimized expansion which doesn't distinguish between null and false
if (optimize && sqlBinaryExpression.OperatorType == ExpressionType.Equal)
Expand All @@ -1617,6 +1620,12 @@ private SqlExpression RewriteNullSemantics(
// doing a full null semantics rewrite - removing all nulls from truth table
nullable = false;

if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual)
{
// the factory takes care of simplifying equal <-> not-equal
body = _sqlExpressionFactory.Not(body);
}

// (a == b && (a != null && b != null)) || (a == null && b == null)
body = _sqlExpressionFactory.OrElse(
_sqlExpressionFactory.AndAlso(body, _sqlExpressionFactory.AndAlso(leftIsNotNull, rightIsNotNull)),
Expand All @@ -1625,7 +1634,7 @@ private SqlExpression RewriteNullSemantics(
if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual)
{
// the factory takes care of simplifying using DeMorgan
body = _sqlExpressionFactory.Not(body);
body = OptimizeNotExpression(_sqlExpressionFactory.Not(body));
}

return body;
Expand All @@ -1643,18 +1652,40 @@ protected virtual SqlExpression OptimizeNotExpression(SqlExpression expression)
return expression;
}

// !(a > b) -> a <= b
// !(a >= b) -> a < b
// !(a < b) -> a >= b
// !(a <= b) -> a > b
if (sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand
&& TryNegate(sqlBinaryOperand.OperatorType, out var negated))
if (sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand)
{
return _sqlExpressionFactory.MakeBinary(
negated,
sqlBinaryOperand.Left,
sqlBinaryOperand.Right,
sqlBinaryOperand.TypeMapping)!;
// !(a > b) -> a <= b
// !(a >= b) -> a < b
// !(a < b) -> a >= b
// !(a <= b) -> a > b
if (TryNegate(sqlBinaryOperand.OperatorType, out var negated))
{
return _sqlExpressionFactory.MakeBinary(
negated,
sqlBinaryOperand.Left,
sqlBinaryOperand.Right,
sqlBinaryOperand.TypeMapping)!;
}

// use equality where possible - at this point (true == null) and (false == null) have been converted to
// IS NULL / IS NOT NULL (i.e. false), so this optimization is safe to do. See #35393
// !(a == true) -> a == false
// !(a == false) -> a == true
if (sqlBinaryOperand is { OperatorType: ExpressionType.Equal, Right: SqlConstantExpression { Value: bool } })
{
return _sqlExpressionFactory.Equal(
sqlBinaryOperand.Left,
OptimizeNotExpression(_sqlExpressionFactory.Not(sqlBinaryOperand.Right)));
}

// !(true == a) -> false == a
// !(false == a) -> true == a
if (sqlBinaryOperand is { OperatorType: ExpressionType.Equal, Left: SqlConstantExpression { Value: bool } })
{
return _sqlExpressionFactory.Equal(
OptimizeNotExpression(_sqlExpressionFactory.Not(sqlBinaryOperand.Left)),
sqlBinaryOperand.Right);
}
}

// the factory can optimize most `NOT` expressions
Expand Down Expand Up @@ -2039,7 +2070,7 @@ private SqlExpression ProcessNullNotNull(SqlExpression sqlExpression, bool opera
return result;
}
}
break;
break;
}

return sqlUnaryExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2459,6 +2459,25 @@ await AssertQueryScalar(
ss => ss.Set<NullSemanticsEntity1>().Where(e => true).Select(e => e.Id));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Compare_constant_true_to_nullable_column_negated(bool async)
=> await AssertQueryScalar(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => !(true == x.NullableBoolA)).Select(x => x.Id));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Compare_constant_true_to_expression_which_evaluates_to_null(bool async)
{
var prm = default(bool?);

await AssertQueryScalar(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(x => x.NullableBoolA != null
&& !object.Equals(true, x.NullableBoolA == null ? null : prm)).Select(x => x.Id));
}

// We can't client-evaluate Like (for the expected results).
// However, since the test data has no LIKE wildcards, it effectively functions like equality - except that 'null like null' returns
// false instead of true. So we have this "lite" implementation which doesn't support wildcards.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,24 @@ public virtual Task GroupJoin_aggregate_nested_anonymous_key_selectors(bool asyn
(c, g) => new { c.CustomerID, Sum = g.Sum(x => x.CustomerID.Length) }),
elementSorter: e => e.CustomerID));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_on_true_equal_true(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().GroupJoin(
ss.Set<Order>(),
x => true,
x => true,
(c, g) => new { c, g })
.Select(x => new { x.c.CustomerID, Orders = x.g }),
elementSorter: e => e.CustomerID,
elementAsserter: (e, a) =>
{
Assert.Equal(e.CustomerID, a.CustomerID);
AssertCollection(e.Orders, a.Orders, elementSorter: ee => ee.OrderID);
});

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Inner_join_with_tautology_predicate_converts_to_cross_join(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1369,10 +1369,10 @@ public override async Task Where_navigation_property_to_collection(bool async)
SELECT [l].[Id], [l].[Date], [l].[Name], [l].[OneToMany_Optional_Self_Inverse1Id], [l].[OneToMany_Required_Self_Inverse1Id], [l].[OneToOne_Optional_Self1Id]
FROM [LevelOne] AS [l]
LEFT JOIN [LevelTwo] AS [l0] ON [l].[Id] = [l0].[Level1_Required_Id]
WHERE (
SELECT COUNT(*)
WHERE EXISTS (
SELECT 1
FROM [LevelThree] AS [l1]
WHERE [l0].[Id] IS NOT NULL AND [l0].[Id] = [l1].[OneToMany_Optional_Inverse3Id]) > 0
WHERE [l0].[Id] IS NOT NULL AND [l0].[Id] = [l1].[OneToMany_Optional_Inverse3Id])
""");
}

Expand All @@ -1385,10 +1385,10 @@ public override async Task Where_navigation_property_to_collection2(bool async)
SELECT [l].[Id], [l].[Level2_Optional_Id], [l].[Level2_Required_Id], [l].[Name], [l].[OneToMany_Optional_Inverse3Id], [l].[OneToMany_Optional_Self_Inverse3Id], [l].[OneToMany_Required_Inverse3Id], [l].[OneToMany_Required_Self_Inverse3Id], [l].[OneToOne_Optional_PK_Inverse3Id], [l].[OneToOne_Optional_Self3Id]
FROM [LevelThree] AS [l]
INNER JOIN [LevelTwo] AS [l0] ON [l].[Level2_Required_Id] = [l0].[Id]
WHERE (
SELECT COUNT(*)
WHERE EXISTS (
SELECT 1
FROM [LevelThree] AS [l1]
WHERE [l0].[Id] = [l1].[OneToMany_Optional_Inverse3Id]) > 0
WHERE [l0].[Id] = [l1].[OneToMany_Optional_Inverse3Id])
""");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5055,16 +5055,16 @@ WHERE [l4].[OneToOne_Required_PK_Date] IS NOT NULL AND [l4].[Level1_Required_Id]
) AS [l5] ON [l3].[Level2_Required_Id] = CASE
WHEN [l5].[OneToOne_Required_PK_Date] IS NOT NULL AND [l5].[Level1_Required_Id] IS NOT NULL AND [l5].[OneToMany_Required_Inverse2Id] IS NOT NULL THEN [l5].[Id]
END
WHERE [l1].[OneToOne_Required_PK_Date] IS NOT NULL AND [l1].[Level1_Required_Id] IS NOT NULL AND [l1].[OneToMany_Required_Inverse2Id] IS NOT NULL AND [l3].[Level2_Required_Id] IS NOT NULL AND [l3].[OneToMany_Required_Inverse3Id] IS NOT NULL AND (
SELECT COUNT(*)
WHERE [l1].[OneToOne_Required_PK_Date] IS NOT NULL AND [l1].[Level1_Required_Id] IS NOT NULL AND [l1].[OneToMany_Required_Inverse2Id] IS NOT NULL AND [l3].[Level2_Required_Id] IS NOT NULL AND [l3].[OneToMany_Required_Inverse3Id] IS NOT NULL AND EXISTS (
SELECT 1
FROM [Level1] AS [l6]
WHERE [l6].[Level2_Required_Id] IS NOT NULL AND [l6].[OneToMany_Required_Inverse3Id] IS NOT NULL AND CASE
WHEN [l5].[OneToOne_Required_PK_Date] IS NOT NULL AND [l5].[Level1_Required_Id] IS NOT NULL AND [l5].[OneToMany_Required_Inverse2Id] IS NOT NULL THEN [l5].[Id]
END IS NOT NULL AND (CASE
WHEN [l5].[OneToOne_Required_PK_Date] IS NOT NULL AND [l5].[Level1_Required_Id] IS NOT NULL AND [l5].[OneToMany_Required_Inverse2Id] IS NOT NULL THEN [l5].[Id]
END = [l6].[OneToMany_Optional_Inverse3Id] OR (CASE
WHEN [l5].[OneToOne_Required_PK_Date] IS NOT NULL AND [l5].[Level1_Required_Id] IS NOT NULL AND [l5].[OneToMany_Required_Inverse2Id] IS NOT NULL THEN [l5].[Id]
END IS NULL AND [l6].[OneToMany_Optional_Inverse3Id] IS NULL))) > 0
END IS NULL AND [l6].[OneToMany_Optional_Inverse3Id] IS NULL)))
""");
}

Expand Down Expand Up @@ -6255,16 +6255,16 @@ LEFT JOIN (
FROM [Level1] AS [l0]
WHERE [l0].[OneToOne_Required_PK_Date] IS NOT NULL AND [l0].[Level1_Required_Id] IS NOT NULL AND [l0].[OneToMany_Required_Inverse2Id] IS NOT NULL
) AS [l1] ON [l].[Id] = [l1].[Level1_Required_Id]
WHERE (
SELECT COUNT(*)
WHERE EXISTS (
SELECT 1
FROM [Level1] AS [l2]
WHERE [l2].[Level2_Required_Id] IS NOT NULL AND [l2].[OneToMany_Required_Inverse3Id] IS NOT NULL AND CASE
WHEN [l1].[OneToOne_Required_PK_Date] IS NOT NULL AND [l1].[Level1_Required_Id] IS NOT NULL AND [l1].[OneToMany_Required_Inverse2Id] IS NOT NULL THEN [l1].[Id]
END IS NOT NULL AND (CASE
WHEN [l1].[OneToOne_Required_PK_Date] IS NOT NULL AND [l1].[Level1_Required_Id] IS NOT NULL AND [l1].[OneToMany_Required_Inverse2Id] IS NOT NULL THEN [l1].[Id]
END = [l2].[OneToMany_Optional_Inverse3Id] OR (CASE
WHEN [l1].[OneToOne_Required_PK_Date] IS NOT NULL AND [l1].[Level1_Required_Id] IS NOT NULL AND [l1].[OneToMany_Required_Inverse2Id] IS NOT NULL THEN [l1].[Id]
END IS NULL AND [l2].[OneToMany_Optional_Inverse3Id] IS NULL))) > 0
END IS NULL AND [l2].[OneToMany_Optional_Inverse3Id] IS NULL)))
""");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,9 +556,9 @@ public override async Task String_ends_with_not_equals_nullable_column(bool asyn
FROM [FunkyCustomers] AS [f]
CROSS JOIN [FunkyCustomers] AS [f0]
WHERE CASE
WHEN [f].[FirstName] IS NOT NULL AND [f0].[LastName] IS NOT NULL AND RIGHT([f].[FirstName], LEN([f0].[LastName])) = [f0].[LastName] THEN CAST(1 AS bit)
WHEN [f].[FirstName] IS NULL OR [f0].[LastName] IS NULL OR RIGHT([f].[FirstName], LEN([f0].[LastName])) <> [f0].[LastName] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END <> [f].[NullableBool] OR [f].[NullableBool] IS NULL
END = [f].[NullableBool] OR [f].[NullableBool] IS NULL
""");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,9 +556,9 @@ public override async Task String_ends_with_not_equals_nullable_column(bool asyn
FROM [FunkyCustomers] AS [f]
CROSS JOIN [FunkyCustomers] AS [f0]
WHERE CASE
WHEN [f].[FirstName] IS NOT NULL AND [f0].[LastName] IS NOT NULL AND RIGHT([f].[FirstName], LEN([f0].[LastName])) = [f0].[LastName] THEN CAST(1 AS bit)
WHEN [f].[FirstName] IS NULL OR [f0].[LastName] IS NULL OR RIGHT([f].[FirstName], LEN([f0].[LastName])) <> [f0].[LastName] THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END <> [f].[NullableBool] OR [f].[NullableBool] IS NULL
END = [f].[NullableBool] OR [f].[NullableBool] IS NULL
""");
}

Expand Down
Loading