From cedab3d2766c6a881477106efca306e7a646e114 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Thu, 27 Apr 2023 17:21:23 +0200 Subject: [PATCH 1/2] Support any expression type inside inline primitive collections Closes #30732 Closes #30734 --- .../Properties/CosmosStrings.Designer.cs | 8 + .../Properties/CosmosStrings.resx | 3 + .../CosmosSqlTranslatingExpressionVisitor.cs | 63 +++-- .../Properties/RelationalStrings.Designer.cs | 22 +- .../Properties/RelationalStrings.resx | 11 +- .../Query/ISqlExpressionFactory.cs | 14 +- .../Query/Internal/ContainsTranslator.cs | 31 ++- ...mSqlParameterExpandingExpressionVisitor.cs | 58 +++-- ...lExpressionSimplifyingExpressionVisitor.cs | 213 ++++------------ .../Query/QuerySqlGenerator.cs | 26 +- .../Query/RelationalQueryRootProcessor.cs | 2 +- ...yableMethodTranslatingExpressionVisitor.cs | 67 +++-- ...lationalSqlTranslatingExpressionVisitor.cs | 69 ++--- .../Query/SqlExpressionFactory.cs | 113 ++++++--- .../Query/SqlExpressions/InExpression.cs | 205 +++++++++++---- .../Query/SqlExpressions/SelectExpression.cs | 30 ++- .../SqlExpressions/SqlUnaryExpression.cs | 39 +-- .../Query/SqlExpressions/ValuesExpression.cs | 20 +- .../Query/SqlNullabilityProcessor.cs | 235 ++++++++++++------ ...rchConditionConvertingExpressionVisitor.cs | 30 ++- ...yableMethodTranslatingExpressionVisitor.cs | 2 +- ...yableMethodTranslatingExpressionVisitor.cs | 5 +- src/EFCore/Query/InlineQueryRootExpression.cs | 13 +- .../ParameterExtractingExpressionVisitor.cs | 25 +- src/EFCore/Query/QueryRootProcessor.cs | 77 +++--- .../Query/NorthwindWhereQueryCosmosTest.cs | 11 +- .../Query/NullSemanticsQueryTestBase.cs | 128 ++++++++++ .../Query/UdfDbFunctionTestBase.cs | 20 +- ...orthwindAggregateOperatorsQueryTestBase.cs | 10 +- .../PrimitiveCollectionsQueryTestBase.cs | 40 ++- .../FiltersInheritanceQuerySqlServerTest.cs | 2 +- ...eteMappingInheritanceQuerySqlServerTest.cs | 24 +- .../Query/InheritanceQuerySqlServerTest.cs | 4 +- ...indAggregateOperatorsQuerySqlServerTest.cs | 10 +- .../Query/NorthwindWhereQuerySqlServerTest.cs | 12 +- .../Query/NullSemanticsQuerySqlServerTest.cs | 233 ++++++++++++++++- ...imitiveCollectionsQueryOldSqlServerTest.cs | 43 +++- .../PrimitiveCollectionsQuerySqlServerTest.cs | 44 +++- ...TPCFiltersInheritanceQuerySqlServerTest.cs | 2 +- ...TPTFiltersInheritanceQuerySqlServerTest.cs | 2 +- ...ralFiltersInheritanceQuerySqlServerTest.cs | 2 +- .../Query/UdfDbFunctionSqlServerTests.cs | 2 +- .../PrimitiveCollectionsQuerySqliteTest.cs | 44 +++- 43 files changed, 1373 insertions(+), 641 deletions(-) diff --git a/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs b/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs index f238646821c..e4cd0e2ace6 100644 --- a/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs +++ b/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs @@ -37,6 +37,14 @@ public static string AnalyticalTTLMismatch(object? ttl1, object? entityType1, ob public static string CanConnectNotSupported => GetString("CanConnectNotSupported"); + /// + /// The query contained a new array expression containing non-constant elements, which could not be translated: '{newArrayExpression}'. + /// + public static string CannotTranslateNonConstantNewArrayExpression(object? newArrayExpression) + => string.Format( + GetString("CannotTranslateNonConstantNewArrayExpression", nameof(newArrayExpression)), + newArrayExpression); + /// /// Both the connection string and CredentialToken, account key or account endpoint were specified. Specify only one set of connection details. /// diff --git a/src/EFCore.Cosmos/Properties/CosmosStrings.resx b/src/EFCore.Cosmos/Properties/CosmosStrings.resx index 4f55bd92c71..6f0de57243e 100644 --- a/src/EFCore.Cosmos/Properties/CosmosStrings.resx +++ b/src/EFCore.Cosmos/Properties/CosmosStrings.resx @@ -123,6 +123,9 @@ The Cosmos database does not support 'CanConnect' or 'CanConnectAsync'. + + The query contained a new array expression containing non-constant elements, which could not be translated: '{newArrayExpression}'. + Both the connection string and CredentialToken, account key or account endpoint were specified. Specify only one set of connection details. diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index e52749044e0..e67d989751e 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -4,6 +4,7 @@ #nullable disable using System.Collections; +using System.Diagnostics.CodeAnalysis; using Microsoft.EntityFrameworkCore.Cosmos.Internal; namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; @@ -429,7 +430,9 @@ protected override Expression VisitMember(MemberExpression memberExpression) /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) - => GetConstantOrNull(memberInitExpression); + => TryEvaluateToConstant(memberInitExpression, out var sqlConstantExpression) + ? sqlConstantExpression + : QueryCompilationContext.NotTranslatedExpression; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -603,7 +606,9 @@ static Expression RemoveObjectConvert(Expression expression) /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override Expression VisitNew(NewExpression newExpression) - => GetConstantOrNull(newExpression); + => TryEvaluateToConstant(newExpression, out var sqlConstantExpression) + ? sqlConstantExpression + : QueryCompilationContext.NotTranslatedExpression; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -612,7 +617,15 @@ protected override Expression VisitNew(NewExpression newExpression) /// doing so can result in application failures when updating to a new Entity Framework Core release. /// protected override Expression VisitNewArray(NewArrayExpression newArrayExpression) - => null; + { + if (TryEvaluateToConstant(newArrayExpression, out var sqlConstantExpression)) + { + return sqlConstantExpression; + } + + AddTranslationErrorDetails(CosmosStrings.CannotTranslateNonConstantNewArrayExpression(newArrayExpression.Print())); + return QueryCompilationContext.NotTranslatedExpression; + } /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -990,38 +1003,34 @@ private static List ParameterListValueExtractor( private static bool IsNullSqlConstantExpression(Expression expression) => expression is SqlConstantExpression sqlConstant && sqlConstant.Value == null; - private static SqlConstantExpression GetConstantOrNull(Expression expression) - => CanEvaluate(expression) - ? new SqlConstantExpression( + private static bool TryEvaluateToConstant(Expression expression, out SqlConstantExpression sqlConstantExpression) + { + if (CanEvaluate(expression)) + { + sqlConstantExpression = new SqlConstantExpression( Expression.Constant( Expression.Lambda>(Expression.Convert(expression, typeof(object))) .Compile(preferInterpretation: true) .Invoke(), expression.Type), - null) - : null; + null); + return true; + } + + sqlConstantExpression = null; + return false; + } private static bool CanEvaluate(Expression expression) - { -#pragma warning disable IDE0066 // Convert switch statement to expression - switch (expression) -#pragma warning restore IDE0066 // Convert switch statement to expression + => expression switch { - case ConstantExpression: - return true; - - case NewExpression newExpression: - return newExpression.Arguments.All(e => CanEvaluate(e)); - - case MemberInitExpression memberInitExpression: - return CanEvaluate(memberInitExpression.NewExpression) - && memberInitExpression.Bindings.All( - mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression)); - - default: - return false; - } - } + ConstantExpression => true, + NewExpression e => e.Arguments.All(CanEvaluate), + NewArrayExpression e => e.Expressions.All(CanEvaluate), + MemberInitExpression e => CanEvaluate(e.NewExpression) + && e.Bindings.All(mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression)), + _ => false + }; [DebuggerStepThrough] private static bool TranslationFailed(Expression original, Expression translation, out SqlExpression castTranslation) diff --git a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs index 1ed5c6c8d0f..7b7f43e2888 100644 --- a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs +++ b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs @@ -53,6 +53,14 @@ public static string BadSequenceType public static string CannotChangeWhenOpen => GetString("CannotChangeWhenOpen"); + /// + /// The query contained a new array expression containing non-constant elements, which could not be translated: '{newArrayExpression}'. + /// + public static string CannotTranslateNonConstantNewArrayExpression(object? newArrayExpression) + => string.Format( + GetString("CannotTranslateNonConstantNewArrayExpression", nameof(newArrayExpression)), + newArrayExpression); + /// /// Can't configure a trigger on entity type '{entityType}', which is in a TPH hierarchy and isn't the root. Configure the trigger on the TPH root entity type '{rootEntityType}' instead. /// @@ -622,12 +630,12 @@ public static string DuplicateSeedDataSensitive(object? entityType, object? keyV entityType, keyValue, table); /// - /// Either {param1} or {param2} must be null. + /// Exactly one of '{param1}', '{param2}' or '{param3}' must be set. /// - public static string EitherOfTwoValuesMustBeNull(object? param1, object? param2) + public static string OneOfThreeValuesMustBeSet(object? param1, object? param2, object? param3) => string.Format( - GetString("EitherOfTwoValuesMustBeNull", nameof(param1), nameof(param2)), - param1, param2); + GetString("OneOfThreeValuesMustBeSet", nameof(param1), nameof(param2), nameof(param3)), + param1, param2, param3); /// /// Empty collections are not supported as constant query roots. @@ -1310,11 +1318,11 @@ public static string NoDbCommand => GetString("NoDbCommand"); /// - /// Expression of type '{type}' isn't supported as the Values of an InExpression; only constants and parameters are supported. + /// Expression of type '{type}' isn't supported in the values of an InExpression; only constants and parameters are supported. /// - public static string NonConstantOrParameterAsInExpressionValues(object? type) + public static string NonConstantOrParameterAsInExpressionValue(object? type) => string.Format( - GetString("NonConstantOrParameterAsInExpressionValues", nameof(type)), + GetString("NonConstantOrParameterAsInExpressionValue", nameof(type)), type); /// diff --git a/src/EFCore.Relational/Properties/RelationalStrings.resx b/src/EFCore.Relational/Properties/RelationalStrings.resx index 97b5fc2bc0d..d212c1beb9e 100644 --- a/src/EFCore.Relational/Properties/RelationalStrings.resx +++ b/src/EFCore.Relational/Properties/RelationalStrings.resx @@ -130,6 +130,9 @@ The instance of DbConnection is currently in use. The connection can only be changed when the existing connection is not being used. + + The query contained a new array expression containing non-constant elements, which could not be translated: '{newArrayExpression}'. + Can't configure a trigger on entity type '{entityType}', which is in a TPH hierarchy and isn't the root. Configure the trigger on the TPH root entity type '{rootEntityType}' instead. @@ -346,8 +349,8 @@ A seed entity for entity type '{entityType}' has the same key value {keyValue} as another seed entity mapped to the same table '{table}'. Key values should be unique across seed entities. - - Either {param1} or {param2} must be null. + + Exactly one of '{param1}', '{param2}' or '{param3}' must be set. Empty collections are not supported as inline query roots. @@ -911,8 +914,8 @@ Cannot create a DbCommand for a non-relational query. - - Expression of type '{type}' isn't supported as the Values of an InExpression; only constants and parameters are supported. + + Expression of type '{type}' isn't supported in the values of an InExpression; only constants and parameters are supported. 'FindMapping' was called on a 'RelationalTypeMappingSource' with a non-relational 'TypeMappingInfo'. diff --git a/src/EFCore.Relational/Query/ISqlExpressionFactory.cs b/src/EFCore.Relational/Query/ISqlExpressionFactory.cs index 3fa2be11a49..6462d240fb6 100644 --- a/src/EFCore.Relational/Query/ISqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/ISqlExpressionFactory.cs @@ -402,21 +402,29 @@ SqlFunctionExpression NiladicFunction( /// An expression representing an EXISTS operation in a SQL tree. ExistsExpression Exists(SelectExpression subquery); + /// + /// Creates a new which represents an IN operation in a SQL tree. + /// + /// An item to look into values. + /// A subquery in which item is searched. + /// An expression representing an IN operation in a SQL tree. + InExpression In(SqlExpression item, SelectExpression subquery); + /// /// Creates a new which represents an IN operation in a SQL tree. /// /// An item to look into values. /// A list of values in which item is searched. /// An expression representing an IN operation in a SQL tree. - InExpression In(SqlExpression item, SqlExpression values); + InExpression In(SqlExpression item, IReadOnlyList values); /// /// Creates a new which represents an IN operation in a SQL tree. /// /// An item to look into values. - /// A subquery in which item is searched. + /// A parameterized list of values in which the item is searched. /// An expression representing an IN operation in a SQL tree. - InExpression In(SqlExpression item, SelectExpression subquery); + InExpression In(SqlExpression item, SqlParameterExpression valuesParameter); /// /// Creates a new which represents a LIKE in a SQL tree. diff --git a/src/EFCore.Relational/Query/Internal/ContainsTranslator.cs b/src/EFCore.Relational/Query/Internal/ContainsTranslator.cs index e27e6c9e536..43d562142cd 100644 --- a/src/EFCore.Relational/Query/Internal/ContainsTranslator.cs +++ b/src/EFCore.Relational/Query/Internal/ContainsTranslator.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; namespace Microsoft.EntityFrameworkCore.Query.Internal; @@ -38,11 +39,14 @@ public ContainsTranslator(ISqlExpressionFactory sqlExpressionFactory) IReadOnlyList arguments, IDiagnosticsLogger logger) { + SqlExpression? itemExpression = null, valuesExpression = null; + + // Identify static Enumerable.Contains and instance List.Contains if (method.IsGenericMethod - && method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains) + && method.GetGenericMethodDefinition() == EnumerableMethods.Contains && ValidateValues(arguments[0])) { - return _sqlExpressionFactory.In(RemoveObjectConvert(arguments[1]), arguments[0]); + (itemExpression, valuesExpression) = (RemoveObjectConvert(arguments[1]), arguments[0]); } if (arguments.Count == 1 @@ -50,14 +54,33 @@ public ContainsTranslator(ISqlExpressionFactory sqlExpressionFactory) && instance != null && ValidateValues(instance)) { - return _sqlExpressionFactory.In(RemoveObjectConvert(arguments[0]), instance); + (itemExpression, valuesExpression) = (RemoveObjectConvert(arguments[0]), instance); + } + + if (itemExpression is not null && valuesExpression is not null) + { + switch (valuesExpression) + { + case SqlParameterExpression parameter: + return _sqlExpressionFactory.In(itemExpression, parameter); + + case SqlConstantExpression { Value: IEnumerable values }: + var valuesExpressions = new List(); + + foreach (var value in values) + { + valuesExpressions.Add(_sqlExpressionFactory.Constant(value)); + } + + return _sqlExpressionFactory.In(itemExpression, valuesExpressions); + } } return null; } private static bool ValidateValues(SqlExpression values) - => values is SqlConstantExpression || values is SqlParameterExpression; + => values is SqlConstantExpression or SqlParameterExpression; private static SqlExpression RemoveObjectConvert(SqlExpression expression) => expression is SqlUnaryExpression sqlUnaryExpression diff --git a/src/EFCore.Relational/Query/Internal/FromSqlParameterExpandingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/FromSqlParameterExpandingExpressionVisitor.cs index 6cf339b80a1..f95c11ea530 100644 --- a/src/EFCore.Relational/Query/Internal/FromSqlParameterExpandingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/FromSqlParameterExpandingExpressionVisitor.cs @@ -130,38 +130,58 @@ public virtual Expression Expand( return _visitedFromSqlExpressions[fromSql] = fromSql.Update( Expression.Constant(new CompositeRelationalParameter(parameterExpression.Name!, subParameters))); - case ConstantExpression constantExpression: - var existingValues = constantExpression.GetConstantValue(); + case ConstantExpression { Value: object?[] existingValues }: + { var constantValues = new object?[existingValues.Length]; for (var i = 0; i < existingValues.Length; i++) { - var value = existingValues[i]; - if (value is DbParameter dbParameter) - { - var parameterName = _parameterNameGenerator.GenerateNext(); - if (string.IsNullOrEmpty(dbParameter.ParameterName)) - { - dbParameter.ParameterName = parameterName; - } - else - { - parameterName = dbParameter.ParameterName; - } + constantValues[i] = ProcessConstantValue(existingValues[i]); + } - constantValues[i] = new RawRelationalParameter(parameterName, dbParameter); - } - else + return _visitedFromSqlExpressions[fromSql] = fromSql.Update(Expression.Constant(constantValues, typeof(object[]))); + } + + case NewArrayExpression { Expressions: var expressions }: + { + var constantValues = new object?[expressions.Count]; + for (var i = 0; i < constantValues.Length; i++) + { + if (expressions[i] is not SqlConstantExpression { Value: var existingValue }) { - constantValues[i] = _sqlExpressionFactory.Constant( - value, _typeMappingSource.GetMappingForValue(value)); + Check.DebugFail("FromSql.Arguments must be Constant/ParameterExpression"); + throw new InvalidOperationException(); } + + constantValues[i] = ProcessConstantValue(existingValue); } return _visitedFromSqlExpressions[fromSql] = fromSql.Update(Expression.Constant(constantValues, typeof(object[]))); + } default: Check.DebugFail("FromSql.Arguments must be Constant/ParameterExpression"); return null; } + + object ProcessConstantValue(object? existingConstantValue) + { + if (existingConstantValue is DbParameter dbParameter) + { + var parameterName = _parameterNameGenerator.GenerateNext(); + if (string.IsNullOrEmpty(dbParameter.ParameterName)) + { + dbParameter.ParameterName = parameterName; + } + else + { + parameterName = dbParameter.ParameterName; + } + + return new RawRelationalParameter(parameterName, dbParameter); + } + + return _sqlExpressionFactory.Constant( + existingConstantValue, _typeMappingSource.GetMappingForValue(existingConstantValue)); + } } } diff --git a/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs index 62585199c34..619a42985f2 100644 --- a/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/SqlExpressionSimplifyingExpressionVisitor.cs @@ -245,91 +245,39 @@ or ExpressionType.LessThan && leftCandidateInfo.ColumnExpression == rightCandidateInfo.ColumnExpression && leftCandidateInfo.OperationType == rightCandidateInfo.OperationType) { - var leftConstantIsEnumerable = leftCandidateInfo.ConstantValue is IEnumerable - && !(leftCandidateInfo.ConstantValue is string) - && !(leftCandidateInfo.ConstantValue is byte[]); - - var rightConstantIsEnumerable = rightCandidateInfo.ConstantValue is IEnumerable - && !(rightCandidateInfo.ConstantValue is string) - && !(rightCandidateInfo.ConstantValue is byte[]); - - if ((leftCandidateInfo.OperationType == ExpressionType.Equal - && sqlBinaryExpression.OperatorType == ExpressionType.OrElse) - || (leftCandidateInfo.OperationType == ExpressionType.NotEqual - && sqlBinaryExpression.OperatorType == ExpressionType.AndAlso)) + // for relational nulls we can't combine comparisons that contain null + // a != 1 && a != null would be converted to a NOT IN (1, null), which never returns any results + // we need to keep it in the original form so that a != null gets converted to a IS NOT NULL instead + // for c# null semantics it's fine because null semantics visitor extracts null back into proper null checks + var leftValues = leftCandidateInfo.ValueOrValues switch { - object leftValue; - object rightValue; - List resultArray; + IReadOnlyList v => v, + SqlConstantExpression c when !_useRelationalNulls || c.Value is not null => new[] { c }, + _ => null + }; - switch ((leftConstantIsEnumerable, rightConstantIsEnumerable)) - { - case (false, false): - // comparison + comparison - leftValue = leftCandidateInfo.ConstantValue; - rightValue = rightCandidateInfo.ConstantValue; - - // for relational nulls we can't combine comparisons that contain null - // a != 1 && a != null would be converted to a NOT IN (1, null), which never returns any results - // we need to keep it in the original form so that a != null gets converted to a IS NOT NULL instead - // for c# null semantics it's fine because null semantics visitor extracts null back into proper null checks - if (_useRelationalNulls && (leftValue == null || rightValue == null)) - { - return sqlBinaryExpression.Update(left, right); - } - - resultArray = ConstructCollection(leftValue, rightValue); - break; - - case (true, true): - // in + in - leftValue = leftCandidateInfo.ConstantValue; - rightValue = rightCandidateInfo.ConstantValue; - resultArray = UnionCollections((IEnumerable)leftValue, (IEnumerable)rightValue); - break; - - default: - // in + comparison - leftValue = leftConstantIsEnumerable - ? leftCandidateInfo.ConstantValue - : rightCandidateInfo.ConstantValue; - - rightValue = leftConstantIsEnumerable - ? rightCandidateInfo.ConstantValue - : leftCandidateInfo.ConstantValue; - - if (_useRelationalNulls && rightValue == null) - { - return sqlBinaryExpression.Update(left, right); - } - - resultArray = AddToCollection((IEnumerable)leftValue, rightValue); - break; - } - - var inExpression = _sqlExpressionFactory.In( - leftCandidateInfo.ColumnExpression, - _sqlExpressionFactory.Constant(resultArray, leftCandidateInfo.TypeMapping)); - - return leftCandidateInfo.OperationType switch - { - ExpressionType.Equal => inExpression, - ExpressionType.NotEqual => _sqlExpressionFactory.Not(inExpression), - _ => throw new InvalidOperationException("IMPOSSIBLE") - }; - } + var rightValues = rightCandidateInfo.ValueOrValues switch + { + IReadOnlyList v => v, + SqlConstantExpression c when !_useRelationalNulls || c.Value is not null => new[] { c }, + _ => null + }; - if (leftConstantIsEnumerable && rightConstantIsEnumerable) + if (leftValues is not null && rightValues is not null) { - // a IN (1, 2, 3) && a IN (2, 3, 4) -> a IN (2, 3) - // a NOT IN (1, 2, 3) || a NOT IN (2, 3, 4) -> a NOT IN (2, 3) - var resultArray = IntersectCollections( - (IEnumerable)leftCandidateInfo.ConstantValue, - (IEnumerable)rightCandidateInfo.ConstantValue); + // Union: + // a IN (1, 2) || a IN (2, 3) -> a IN (1, 2, 3) + // a IN (1, 2) || a = 3 -> a IN (1, 2, 3) + // a NOT IN (1, 2) && a <> 3 -> a NOT IN (1, 2, 3) + // Intersection: + // a IN (1, 2, 3) && a IN (2, 3, 4) -> a IN (2, 3) var inExpression = _sqlExpressionFactory.In( leftCandidateInfo.ColumnExpression, - _sqlExpressionFactory.Constant(resultArray, leftCandidateInfo.TypeMapping)); + (leftCandidateInfo.OperationType, sqlBinaryExpression.OperatorType) is + (ExpressionType.Equal, ExpressionType.OrElse) or (ExpressionType.NotEqual, ExpressionType.AndAlso) + ? leftValues.Union(rightValues).ToArray() + : leftValues.Intersect(rightValues).ToArray()); return leftCandidateInfo.OperationType switch { @@ -344,106 +292,45 @@ or ExpressionType.LessThan return sqlBinaryExpression.Update(left, right); } - private static List ConstructCollection(object left, object right) - => new() { left, right }; - - private static List AddToCollection(IEnumerable collection, object newElement) - { - var result = BuildListFromEnumerable(collection); - if (!result.Contains(newElement)) - { - result.Add(newElement); - } - - return result; - } - - private static List UnionCollections(IEnumerable first, IEnumerable second) + private static bool TryGetInExpressionCandidateInfo( + SqlExpression sqlExpression, + out (ColumnExpression ColumnExpression, object ValueOrValues, ExpressionType OperationType) candidateInfo) { - var result = BuildListFromEnumerable(first); - foreach (var collectionElement in second) + switch (sqlExpression) { - if (!result.Contains(collectionElement)) + case SqlUnaryExpression { OperatorType: ExpressionType.Not } sqlUnaryExpression + when TryGetInExpressionCandidateInfo(sqlUnaryExpression.Operand, out var inner): { - result.Add(collectionElement); - } - } - - return result; - } - - private static List IntersectCollections(IEnumerable first, IEnumerable second) - { - var firstList = BuildListFromEnumerable(first); - var result = new List(); + candidateInfo = (inner.ColumnExpression, inner.ValueOrValues, + inner.OperationType == ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal); - foreach (var collectionElement in second) - { - if (firstList.Contains(collectionElement)) - { - result.Add(collectionElement); + return true; } - } - return result; - } - - private static List BuildListFromEnumerable(IEnumerable collection) - { - List result; - if (collection is List list) - { - result = list; - } - else - { - result = new List(); - foreach (var collectionElement in collection) + case SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } sqlBinaryExpression: { - result.Add(collectionElement); - } - } + var column = (sqlBinaryExpression.Left as ColumnExpression ?? sqlBinaryExpression.Right as ColumnExpression); + var constant = (sqlBinaryExpression.Left as SqlConstantExpression ?? sqlBinaryExpression.Right as SqlConstantExpression); - return result; - } - - private static bool TryGetInExpressionCandidateInfo( - SqlExpression sqlExpression, - out (ColumnExpression ColumnExpression, object ConstantValue, RelationalTypeMapping TypeMapping, ExpressionType OperationType) - candidateInfo) - { - if (sqlExpression is SqlUnaryExpression { OperatorType: ExpressionType.Not } sqlUnaryExpression) - { - if (TryGetInExpressionCandidateInfo(sqlUnaryExpression.Operand, out var inner)) - { - candidateInfo = (inner.ColumnExpression, inner.ConstantValue, inner.TypeMapping, - inner.OperationType == ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal); + if (column != null && constant != null) + { + candidateInfo = (column, constant, sqlBinaryExpression.OperatorType); + return true; + } - return true; + goto default; } - } - else if (sqlExpression is SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } sqlBinaryExpression) - { - var column = (sqlBinaryExpression.Left as ColumnExpression ?? sqlBinaryExpression.Right as ColumnExpression); - var constant = (sqlBinaryExpression.Left as SqlConstantExpression ?? sqlBinaryExpression.Right as SqlConstantExpression); - if (column != null && constant != null) + case InExpression { Item: ColumnExpression column, Subquery: null, Values: { } values }: { - candidateInfo = (column, constant.Value!, constant.TypeMapping!, sqlBinaryExpression.OperatorType); + candidateInfo = (column, values, ExpressionType.Equal); + return true; } - } - else if (sqlExpression is InExpression - { - Item: ColumnExpression column, Subquery: null, Values: SqlConstantExpression valuesConstant - }) - { - candidateInfo = (column, valuesConstant.Value!, valuesConstant.TypeMapping!, ExpressionType.Equal); - return true; + default: + candidateInfo = default; + return false; } - - candidateInfo = default; - return false; } } diff --git a/src/EFCore.Relational/Query/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/QuerySqlGenerator.cs index 4de5bc63df7..80fdc22a2b5 100644 --- a/src/EFCore.Relational/Query/QuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/QuerySqlGenerator.cs @@ -982,30 +982,30 @@ protected sealed override Expression VisitIn(InExpression inExpression) /// Whether the given is negated. protected virtual void GenerateIn(InExpression inExpression, bool negated) { - if (inExpression.Values != null) + Check.DebugAssert( + inExpression.ValuesParameter is null, + "InExpression.ValuesParameter must have been expanded to constants before SQL generation (i.e. in SqlNullabilityProcessor)"); + + Visit(inExpression.Item); + _relationalCommandBuilder.Append(negated ? " NOT IN (" : " IN ("); + + if (inExpression.Values is not null) { - Visit(inExpression.Item); - _relationalCommandBuilder.Append(negated ? " NOT IN " : " IN "); - _relationalCommandBuilder.Append("("); - var valuesConstant = (SqlConstantExpression)inExpression.Values; - var valuesList = ((IEnumerable)valuesConstant.Value!) - .Select(v => new SqlConstantExpression(Expression.Constant(v), valuesConstant.TypeMapping)).ToList(); - GenerateList(valuesList, e => Visit(e)); - _relationalCommandBuilder.Append(")"); + GenerateList(inExpression.Values, e => Visit(e)); } else { - Visit(inExpression.Item); - _relationalCommandBuilder.Append(negated ? " NOT IN " : " IN "); - _relationalCommandBuilder.AppendLine("("); + _relationalCommandBuilder.AppendLine(); using (_relationalCommandBuilder.Indent()) { Visit(inExpression.Subquery); } - _relationalCommandBuilder.AppendLine().Append(")"); + _relationalCommandBuilder.AppendLine(); } + + _relationalCommandBuilder.Append(")"); } /// diff --git a/src/EFCore.Relational/Query/RelationalQueryRootProcessor.cs b/src/EFCore.Relational/Query/RelationalQueryRootProcessor.cs index d2c218bb5c9..10ba014ce94 100644 --- a/src/EFCore.Relational/Query/RelationalQueryRootProcessor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryRootProcessor.cs @@ -30,7 +30,7 @@ public RelationalQueryRootProcessor( /// Indicates that a can be converted to a ; /// this will later be translated to a SQL . /// - protected override bool ShouldConvertToInlineQueryRoot(ConstantExpression constantExpression) + protected override bool ShouldConvertToInlineQueryRoot(NewArrayExpression newArrayExpression) => true; /// diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index b65ce1c77c5..93793864ab6 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics.CodeAnalysis; +using System.Reflection.Metadata; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; @@ -207,8 +208,8 @@ when entityQueryRootExpression.GetType() == typeof(EntityQueryRootExpression) return new ShapedQueryExpression(selectExpression, shaperExpression); } - case InlineQueryRootExpression constantQueryRootExpression: - return VisitInlineQueryRoot(constantQueryRootExpression) ?? base.VisitExtension(extensionExpression); + case InlineQueryRootExpression inlineQueryRootExpression: + return VisitInlineQueryRoot(inlineQueryRootExpression) ?? base.VisitExtension(extensionExpression); case ParameterQueryRootExpression parameterQueryRootExpression: var sqlParameterExpression = @@ -319,26 +320,27 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp for (var i = 0; i < inlineQueryRootExpression.Values.Count; i++) { - var value = inlineQueryRootExpression.Values[i]; - - // We currently support constants only; supporting non-constant values in VALUES is tracked by #30734. - if (value is not ConstantExpression constantExpression) + // Note that we specifically don't apply the default type mapping to the translation, to allow it to get inferred later based + // on usage. + if (TranslateExpression(inlineQueryRootExpression.Values[i], applyDefaultTypeMapping: false) + is not SqlExpression translatedValue) { - AddTranslationErrorDetails(RelationalStrings.OnlyConstantsSupportedInInlineCollectionQueryRoots); return null; } - if (constantExpression.Value is null) - { - encounteredNull = true; - } + // TODO: Poor man's null semantics: in SqlNullabilityProcessor we don't fully handle the nullability of SelectExpression + // projections. Whether the SelectExpression's projection is nullable or not is determined here in translation, but at this + // point we don't know how to properly calculate nullability (and can't look at parameters). + // So for now, we assume the projected column is nullable if we see anything but non-null constants and non-nullable columns. + encounteredNull |= + translatedValue is not SqlConstantExpression { Value: not null } and not ColumnExpression { IsNullable: false }; rowExpressions.Add(new RowValueExpression(new[] { // Since VALUES may not guarantee row ordering, we add an _ord value by which we'll order. _sqlExpressionFactory.Constant(i, intTypeMapping), // Note that for the actual value, we must leave the type mapping null to allow it to get inferred later based on usage - _sqlExpressionFactory.Constant(constantExpression.Value, elementType, typeMapping: null) + translatedValue })); } @@ -529,27 +531,15 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent && projection is ColumnExpression projectedColumn && projectedColumn.Table == valuesExpression) { - var values = new object?[valuesExpression.RowValues.Count]; + var values = new SqlExpression[valuesExpression.RowValues.Count]; for (var i = 0; i < values.Length; i++) { // Skip the first value (_ord), which is irrelevant for Contains - if (valuesExpression.RowValues[i].Values[1] is SqlConstantExpression { Value: var constantValue }) - { - values[i] = constantValue; - } - else - { - // We only support constants for now - values = null; - break; - } + values[i] = valuesExpression.RowValues[i].Values[1]; } - if (values is not null) - { - var inExpression = _sqlExpressionFactory.In(translatedItem, _sqlExpressionFactory.Constant(values)); - return source.Update(_sqlExpressionFactory.Select(inExpression), source.ShaperExpression); - } + var inExpression = _sqlExpressionFactory.In(translatedItem, values); + return source.Update(_sqlExpressionFactory.Select(inExpression), source.ShaperExpression); } // Translate to IN with a subquery. @@ -1834,7 +1824,7 @@ protected virtual bool IsValidSelectExpressionForExecuteUpdate( protected virtual Expression ApplyInferredTypeMappings( Expression expression, IReadOnlyDictionary<(TableExpressionBase, string), RelationalTypeMapping?> inferredTypeMappings) - => new RelationalInferredTypeMappingApplier(inferredTypeMappings).Visit(expression); + => new RelationalInferredTypeMappingApplier(_sqlExpressionFactory, inferredTypeMappings).Visit(expression); /// /// Determines whether the given is ordered, typically because orderings have been added to it. @@ -2746,6 +2736,7 @@ private void RegisterInferredTypeMapping(ColumnExpression columnExpression, Rela /// protected class RelationalInferredTypeMappingApplier : ExpressionVisitor { + private readonly ISqlExpressionFactory _sqlExpressionFactory; private SelectExpression? _currentSelectExpression; /// @@ -2756,10 +2747,15 @@ protected class RelationalInferredTypeMappingApplier : ExpressionVisitor /// /// Creates a new instance of the class. /// + /// The SQL expression factory. /// The inferred type mappings to be applied back on their query roots. public RelationalInferredTypeMappingApplier( + ISqlExpressionFactory sqlExpressionFactory, IReadOnlyDictionary<(TableExpressionBase, string), RelationalTypeMapping?> inferredTypeMappings) - => _inferredTypeMappings = inferredTypeMappings; + { + _sqlExpressionFactory = sqlExpressionFactory; + _inferredTypeMappings = inferredTypeMappings; + } /// /// Attempts to find an inferred type mapping for the given table column. @@ -2855,30 +2851,27 @@ protected virtual ValuesExpression ApplyTypeMappingsOnValuesExpression(ValuesExp var newValues = new SqlExpression[newColumnNames.Count]; for (var j = 0; j < valuesExpression.ColumnNames.Count; j++) { - Check.DebugAssert(rowValue.Values[j] is SqlConstantExpression, "Non-constant SqlExpression in ValuesExpression"); - if (j == 0 && stripOrdering) { continue; } - var value = (SqlConstantExpression)rowValue.Values[j]; - SqlExpression newValue = value; + var value = rowValue.Values[j]; var inferredTypeMapping = inferredTypeMappings[j]; if (inferredTypeMapping is not null && value.TypeMapping is null) { - newValue = new SqlConstantExpression(Expression.Constant(value.Value, value.Type), inferredTypeMapping); + value = _sqlExpressionFactory.ApplyTypeMapping(value, inferredTypeMapping); // We currently add explicit conversions on the first row, to ensure that the inferred types are properly typed. // See #30605 for removing that when not needed. if (i == 0) { - newValue = new SqlUnaryExpression(ExpressionType.Convert, newValue, newValue.Type, newValue.TypeMapping); + value = new SqlUnaryExpression(ExpressionType.Convert, value, value.Type, value.TypeMapping); } } - newValues[j - (stripOrdering ? 1 : 0)] = newValue; + newValues[j - (stripOrdering ? 1 : 0)] = value; } newRowValues[i] = new RowValueExpression(newValues); diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 8749b8d3585..eb1cea178f8 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -743,7 +743,9 @@ protected override Expression VisitMember(MemberExpression memberExpression) /// protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression) - => GetConstantOrNotTranslated(memberInitExpression); + => TryEvaluateToConstant(memberInitExpression, out var sqlConstantExpression) + ? sqlConstantExpression + : QueryCompilationContext.NotTranslatedExpression; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -1002,11 +1004,21 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp /// protected override Expression VisitNew(NewExpression newExpression) - => GetConstantOrNotTranslated(newExpression); + => TryEvaluateToConstant(newExpression, out var sqlConstantExpression) + ? sqlConstantExpression + : QueryCompilationContext.NotTranslatedExpression; /// protected override Expression VisitNewArray(NewArrayExpression newArrayExpression) - => QueryCompilationContext.NotTranslatedExpression; + { + if (TryEvaluateToConstant(newArrayExpression, out var sqlConstantExpression)) + { + return sqlConstantExpression; + } + + AddTranslationErrorDetails(RelationalStrings.CannotTranslateNonConstantNewArrayExpression(newArrayExpression.Print())); + return QueryCompilationContext.NotTranslatedExpression; + } /// protected override Expression VisitParameter(ParameterExpression parameterExpression) @@ -1095,7 +1107,7 @@ SqlExpression GeneratePredicateTpt(EntityProjectionExpression entityProjectionEx _sqlExpressionFactory.Constant(discriminatorValues[0])) : _sqlExpressionFactory.In( entityProjectionExpression.DiscriminatorExpression!, - _sqlExpressionFactory.Constant(discriminatorValues)); + discriminatorValues.Select(d => _sqlExpressionFactory.Constant(d)).ToArray()); } } else @@ -1114,8 +1126,7 @@ SqlExpression GeneratePredicateTpt(EntityProjectionExpression entityProjectionEx _sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue())) : _sqlExpressionFactory.In( discriminatorColumn, - _sqlExpressionFactory.Constant( - concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList())); + concreteEntityTypes.Select(et => _sqlExpressionFactory.Constant(et.GetDiscriminatorValue())).ToArray()); } } else @@ -1569,16 +1580,23 @@ private static Expression ConvertObjectArrayEqualityComparison(Expression left, .Aggregate((a, b) => Expression.AndAlso(a, b)); } - private static Expression GetConstantOrNotTranslated(Expression expression) - => CanEvaluate(expression) - ? new SqlConstantExpression( + private static bool TryEvaluateToConstant(Expression expression, [NotNullWhen(true)] out SqlConstantExpression? sqlConstantExpression) + { + if (CanEvaluate(expression)) + { + sqlConstantExpression = new SqlConstantExpression( Expression.Constant( Expression.Lambda>(Expression.Convert(expression, typeof(object))) .Compile(preferInterpretation: true) .Invoke(), expression.Type), - null) - : QueryCompilationContext.NotTranslatedExpression; + null); + return true; + } + + sqlConstantExpression = null; + return false; + } private bool TryRewriteContainsEntity(Expression source, Expression item, [NotNullWhen(true)] out Expression? result) { @@ -1862,26 +1880,15 @@ when memberInitExpression.Bindings.SingleOrDefault( } private static bool CanEvaluate(Expression expression) - { -#pragma warning disable IDE0066 // Convert switch statement to expression - switch (expression) -#pragma warning restore IDE0066 // Convert switch statement to expression - { - case ConstantExpression: - return true; - - case NewExpression newExpression: - return newExpression.Arguments.All(e => CanEvaluate(e)); - - case MemberInitExpression memberInitExpression: - return CanEvaluate(memberInitExpression.NewExpression) - && memberInitExpression.Bindings.All( - mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression)); - - default: - return false; - } - } + => expression switch + { + ConstantExpression => true, + NewExpression e => e.Arguments.All(CanEvaluate), + NewArrayExpression e => e.Expressions.All(CanEvaluate), + MemberInitExpression e => CanEvaluate(e.NewExpression) + && e.Bindings.All(mb => mb is MemberAssignment memberAssignment && CanEvaluate(memberAssignment.Expression)), + _ => false + }; private static bool IsNullSqlConstantExpression(Expression expression) => expression is SqlConstantExpression sqlConstant && sqlConstant.Value == null; diff --git a/src/EFCore.Relational/Query/SqlExpressionFactory.cs b/src/EFCore.Relational/Query/SqlExpressionFactory.cs index 1f81e7d826b..552e17a99f6 100644 --- a/src/EFCore.Relational/Query/SqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/SqlExpressionFactory.cs @@ -239,28 +239,78 @@ private SqlExpression ApplyTypeMappingOnSqlBinary( resultTypeMapping); } - private SqlExpression ApplyTypeMappingOnIn(InExpression inExpression) + private InExpression ApplyTypeMappingOnIn(InExpression inExpression) { - var itemTypeMapping = (inExpression.Values != null - ? ExpressionExtensions.InferTypeMapping(inExpression.Item, inExpression.Values) - : inExpression.Subquery != null - ? ExpressionExtensions.InferTypeMapping(inExpression.Item, inExpression.Subquery.Projection[0].Expression) - : inExpression.Item.TypeMapping) - ?? _typeMappingSource.FindMapping(inExpression.Item.Type, Dependencies.Model); - - var item = ApplyTypeMapping(inExpression.Item, itemTypeMapping); - if (inExpression.Values != null) + var missingTypeMappingInValues = false; + + RelationalTypeMapping? valuesTypeMapping = null; + switch (inExpression) { - var values = ApplyTypeMapping(inExpression.Values, itemTypeMapping); + case { Subquery: SelectExpression subquery }: + valuesTypeMapping = subquery.Projection[0].Expression.TypeMapping; + break; + + case { ValuesParameter: SqlParameterExpression parameter }: + valuesTypeMapping = parameter.TypeMapping; + break; + + case { Values: IReadOnlyList values }: + // Note: there could be conflicting type mappings inside the values; we take the first. + foreach (var value in values) + { + if (value.TypeMapping is null) + { + missingTypeMappingInValues = true; + } + else + { + valuesTypeMapping = value.TypeMapping; + } + } + + break; + + default: + throw new ArgumentOutOfRangeException(); + } + + var item = ApplyTypeMapping( + inExpression.Item, + valuesTypeMapping ?? Dependencies.TypeMappingSource.FindMapping(inExpression.Item.Type, Dependencies.Model)); + + switch (inExpression) + { + case { Subquery: SelectExpression subquery }: + inExpression = inExpression.Update(item, subquery); + break; + + case { ValuesParameter: SqlParameterExpression parameter }: + inExpression = inExpression.Update(item, (SqlParameterExpression)ApplyTypeMapping(parameter, item.TypeMapping)); + break; + + case { Values: IReadOnlyList values }: + SqlExpression[]? newValues = null; - return item != inExpression.Item || values != inExpression.Values || inExpression.TypeMapping != _boolTypeMapping - ? new InExpression(item, values, _boolTypeMapping) - : inExpression; + if (missingTypeMappingInValues) + { + newValues = new SqlExpression[values.Count]; + + for (var i = 0; i < newValues.Length; i++) + { + newValues[i] = ApplyTypeMapping(values[i], item.TypeMapping); + } + } + + inExpression = inExpression.Update(item, newValues ?? values); + break; + + default: + throw new ArgumentOutOfRangeException(); } - return item != inExpression.Item || inExpression.TypeMapping != _boolTypeMapping - ? new InExpression(item, inExpression.Subquery!, _boolTypeMapping) - : inExpression; + return inExpression.TypeMapping == _boolTypeMapping + ? inExpression + : inExpression.ApplyTypeMapping(_boolTypeMapping); } private SqlExpression ApplyTypeMappingOnJsonScalar( @@ -585,30 +635,17 @@ public virtual SqlFunctionExpression NiladicFunction( public virtual ExistsExpression Exists(SelectExpression subquery) => new(subquery, _boolTypeMapping); - /// - public virtual InExpression In(SqlExpression item, SqlExpression values) - { - var typeMapping = item.TypeMapping ?? _typeMappingSource.FindMapping(item.Type, Dependencies.Model); - - item = ApplyTypeMapping(item, typeMapping); - values = ApplyTypeMapping(values, typeMapping); - - return new InExpression(item, values, _boolTypeMapping); - } - /// public virtual InExpression In(SqlExpression item, SelectExpression subquery) - { - var sqlExpression = subquery.Projection.Single().Expression; - var subqueryTypeMapping = sqlExpression.TypeMapping; + => ApplyTypeMappingOnIn(new InExpression(item, subquery, _boolTypeMapping)); - if (item.TypeMapping is null) - { - item = subqueryTypeMapping is null ? ApplyDefaultTypeMapping(item) : ApplyTypeMapping(item, subqueryTypeMapping); - } + /// + public virtual InExpression In(SqlExpression item, IReadOnlyList values) + => ApplyTypeMappingOnIn(new InExpression(item, values, _boolTypeMapping)); - return new InExpression(item, subquery, _boolTypeMapping); - } + /// + public virtual InExpression In(SqlExpression item, SqlParameterExpression valuesParameter) + => ApplyTypeMappingOnIn(new InExpression(item, valuesParameter, _boolTypeMapping)); /// public virtual LikeExpression Like(SqlExpression match, SqlExpression pattern, SqlExpression? escapeChar = null) @@ -671,7 +708,7 @@ private void AddConditions(SelectExpression selectExpression, IEntityType entity var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList(); var predicate = concreteEntityTypes.Count == 1 ? (SqlExpression)Equal(discriminatorColumn, Constant(concreteEntityTypes[0].GetDiscriminatorValue())) - : In(discriminatorColumn, Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList())); + : In(discriminatorColumn, concreteEntityTypes.Select(et => Constant(et.GetDiscriminatorValue())).ToArray()); selectExpression.ApplyPredicate(predicate); diff --git a/src/EFCore.Relational/Query/SqlExpressions/InExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/InExpression.cs index 525599a81ad..c4ec59efdd4 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/InExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/InExpression.cs @@ -1,8 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Collections; - namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; /// @@ -17,37 +15,54 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; public class InExpression : SqlExpression { /// - /// Creates a new instance of the class which represents a IN subquery expression. + /// Creates a new instance of the class, representing a SQL IN expression with a subquery. /// /// An item to look into values. - /// A subquery in which item is searched. + /// A subquery in which the item is searched. /// The associated with the expression. public InExpression( SqlExpression item, SelectExpression subquery, RelationalTypeMapping typeMapping) - : this(item, null, subquery, typeMapping) + : this(item, subquery, values: null, valuesParameter: null, typeMapping) + { + } + + /// + /// Creates a new instance of the class, representing a SQL IN expression with a given list + /// of values. + /// + /// An item to look into values. + /// A list of values in which the item is searched. + /// The associated with the expression. + public InExpression( + SqlExpression item, + IReadOnlyList values, + RelationalTypeMapping typeMapping) + : this(item, subquery: null, values, valuesParameter: null, typeMapping) { } /// - /// Creates a new instance of the class which represents a IN values expression. + /// Creates a new instance of the class, representing a SQL IN expression with a given + /// parameterized list of values. /// /// An item to look into values. - /// A list of values in which item is searched. + /// A parameterized list of values in which the item is searched. /// The associated with the expression. public InExpression( SqlExpression item, - SqlExpression values, + SqlParameterExpression valuesParameter, RelationalTypeMapping typeMapping) - : this(item, values, null, typeMapping) + : this(item, subquery: null, values: null, valuesParameter, typeMapping) { } private InExpression( SqlExpression item, - SqlExpression? values, SelectExpression? subquery, + IReadOnlyList? values, + SqlParameterExpression? valuesParameter, RelationalTypeMapping? typeMapping) : base(typeof(bool), typeMapping) { @@ -60,6 +75,7 @@ private InExpression( Item = item; Subquery = subquery; Values = values; + ValuesParameter = valuesParameter; } /// @@ -68,47 +84,118 @@ private InExpression( public virtual SqlExpression Item { get; } /// - /// The list of values to search item in. + /// The subquery to search the item in. /// - public virtual SqlExpression? Values { get; } + public virtual SelectExpression? Subquery { get; } /// - /// The subquery to search item in. + /// The list of values to search the item in. /// - public virtual SelectExpression? Subquery { get; } + public virtual IReadOnlyList? Values { get; } + + /// + /// A parameter containing the list of values to search the item in. The parameterized list get expanded to the actual value + /// before the query SQL is generated. + /// + public virtual SqlParameterExpression? ValuesParameter { get; } /// protected override Expression VisitChildren(ExpressionVisitor visitor) { var item = (SqlExpression)visitor.Visit(Item); var subquery = (SelectExpression?)visitor.Visit(Subquery); - var values = (SqlExpression?)visitor.Visit(Values); - return Update(item, values, subquery); + SqlExpression[]? values = null; + if (Values is not null) + { + for (var i = 0; i < Values.Count; i++) + { + var value = Values[i]; + var newValue = (SqlExpression)visitor.Visit(value); + + if (newValue != value && values is null) + { + values = new SqlExpression[Values.Count]; + for (var j = 0; j < i; j++) + { + values[j] = Values[j]; + } + } + + if (values is not null) + { + values[i] = newValue; + } + } + } + + var valuesParameter = (SqlParameterExpression?)visitor.Visit(ValuesParameter); + + return Update(item, subquery, values ?? Values, valuesParameter); } + /// + /// Applies supplied type mapping to this expression. + /// + /// A relational type mapping to apply. + /// A new expression which has supplied type mapping. + public virtual InExpression ApplyTypeMapping(RelationalTypeMapping? typeMapping) + => new(Item, Subquery, Values, ValuesParameter, typeMapping); + + /// + /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will + /// return this expression. + /// + /// The property of the result. + /// The property of the result. + /// This expression if no children changed, or an expression with the updated children. + public virtual InExpression Update(SqlExpression item, SelectExpression subquery) + => Update(item, subquery, values: null, valuesParameter: null); + /// /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will /// return this expression. /// /// The property of the result. /// The property of the result. + /// This expression if no children changed, or an expression with the updated children. + public virtual InExpression Update(SqlExpression item, IReadOnlyList values) + => Update(item, subquery: null, values, valuesParameter: null); + + /// + /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will + /// return this expression. + /// + /// The property of the result. + /// The property of the result. + /// This expression if no children changed, or an expression with the updated children. + public virtual InExpression Update(SqlExpression item, SqlParameterExpression valuesParameter) + => Update(item, subquery: null, values: null, valuesParameter); + + /// + /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will + /// return this expression. + /// + /// The property of the result. /// The property of the result. + /// The property of the result. + /// The property of the result. /// This expression if no children changed, or an expression with the updated children. public virtual InExpression Update( SqlExpression item, - SqlExpression? values, - SelectExpression? subquery) + SelectExpression? subquery, + IReadOnlyList? values, + SqlParameterExpression? valuesParameter) { - if (values != null - && subquery != null) + if ((subquery is null ? 0 : 1) + (values is null ? 0 : 1) + (valuesParameter is null ? 0 : 1) != 1) { - throw new ArgumentException(RelationalStrings.EitherOfTwoValuesMustBeNull(nameof(values), nameof(subquery))); + throw new ArgumentException( + RelationalStrings.OneOfThreeValuesMustBeSet(nameof(subquery), nameof(values), nameof(valuesParameter))); } - return item != Item || subquery != Subquery || values != Values - ? new InExpression(item, values, subquery, TypeMapping) - : this; + return item == Item && subquery == Subquery && values == Values && valuesParameter == ValuesParameter + ? this + : new InExpression(item, subquery, values, valuesParameter, TypeMapping); } /// @@ -118,31 +205,35 @@ protected override void Print(ExpressionPrinter expressionPrinter) expressionPrinter.Append(" IN "); expressionPrinter.Append("("); - if (Subquery != null) - { - using (expressionPrinter.Indent()) - { - expressionPrinter.Visit(Subquery); - } - } - else if (Values is SqlConstantExpression constantValuesExpression - && constantValuesExpression.Value is IEnumerable constantValues) + switch (this) { - var first = true; - foreach (var item in constantValues) - { - if (!first) + case { Subquery: not null }: + using (expressionPrinter.Indent()) { - expressionPrinter.Append(", "); + expressionPrinter.Visit(Subquery); } - first = false; - expressionPrinter.Append(constantValuesExpression.TypeMapping?.GenerateSqlLiteral(item) ?? item?.ToString() ?? "NULL"); - } - } - else - { - expressionPrinter.Visit(Values); + break; + + case { Values: not null }: + for (var i = 0; i < Values.Count; i++) + { + if (i > 0) + { + expressionPrinter.Append(", "); + } + + expressionPrinter.Visit(Values[i]); + } + + break; + + case { ValuesParameter: not null}: + expressionPrinter.Visit(ValuesParameter); + break; + + default: + throw new ArgumentOutOfRangeException(); } expressionPrinter.Append(")"); @@ -158,10 +249,28 @@ public override bool Equals(object? obj) private bool Equals(InExpression inExpression) => base.Equals(inExpression) && Item.Equals(inExpression.Item) - && (Values?.Equals(inExpression.Values) ?? inExpression.Values == null) - && (Subquery?.Equals(inExpression.Subquery) ?? inExpression.Subquery == null); + && (Subquery?.Equals(inExpression.Subquery) ?? inExpression.Subquery == null) + && (ValuesParameter?.Equals(inExpression.ValuesParameter) ?? inExpression.ValuesParameter == null) + && (ReferenceEquals(Values, inExpression.Values) + || (Values is not null && inExpression.Values is not null && Values.SequenceEqual(inExpression.Values))); /// public override int GetHashCode() - => HashCode.Combine(base.GetHashCode(), Item, Values, Subquery); + { + var hash = new HashCode(); + hash.Add(base.GetHashCode()); + hash.Add(Item); + hash.Add(Subquery); + hash.Add(ValuesParameter); + + if (Values is not null) + { + for (var i = 0; i < Values.Count; i++) + { + hash.Add(Values[i]); + } + } + + return hash.ToHashCode(); + } } diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index 2e360d253f1..6d8fa3e72ad 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -1903,15 +1903,30 @@ public void ApplyPredicate(SqlExpression sqlExpression) } } } - else if (sqlExpression is InExpression inExpression - && inExpression.Item is ColumnExpression itemColumn - && itemColumn.Table is TpcTablesExpression itemTpc + // Identify application of a predicate which narrows the discriminator (e.g. OfType) for TPC, apply it to + // _tpcDiscriminatorValues (which will be handled later) instead of as a WHERE predicate. + else if (sqlExpression is InExpression + { + Item: ColumnExpression { Table: TpcTablesExpression itemTpc } itemColumn, + Values: IReadOnlyList valueExpressions + } && _tpcDiscriminatorValues.TryGetValue(itemTpc, out var itemTuple) - && itemTuple.Item1.Equals(itemColumn) - && inExpression.Values is SqlConstantExpression itemConstant - && itemConstant.Value is List values) + && itemTuple.Item1.Equals(itemColumn)) { - var newList = itemTuple.Item2.Intersect(values).ToList(); + var constantValues = new string[valueExpressions.Count]; + for (var i = 0; i < constantValues.Length; i++) + { + if (valueExpressions[i] is SqlConstantExpression { Value: string value }) + { + constantValues[i] = value; + } + else + { + goto ApplyPredicate; + } + } + + var newList = itemTuple.Item2.Intersect(constantValues).ToList(); if (newList.Count > 0) { _tpcDiscriminatorValues[itemTpc] = (itemColumn, newList); @@ -1920,6 +1935,7 @@ public void ApplyPredicate(SqlExpression sqlExpression) } } + ApplyPredicate: sqlExpression = AssignUniqueAliases(sqlExpression); if (_groupBy.Count > 0) diff --git a/src/EFCore.Relational/Query/SqlExpressions/SqlUnaryExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SqlUnaryExpression.cs index ca465c71a6a..8b2b34f5d69 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SqlUnaryExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SqlUnaryExpression.cs @@ -79,21 +79,32 @@ public virtual SqlUnaryExpression Update(SqlExpression operand) /// protected override void Print(ExpressionPrinter expressionPrinter) { - if (OperatorType == ExpressionType.Convert - && TypeMapping != null) + switch (this) { - expressionPrinter.Append("CAST("); - expressionPrinter.Visit(Operand); - expressionPrinter.Append(" AS "); - expressionPrinter.Append(TypeMapping.StoreType); - expressionPrinter.Append(")"); - } - else - { - expressionPrinter.Append(OperatorType.ToString()); - expressionPrinter.Append("("); - expressionPrinter.Visit(Operand); - expressionPrinter.Append(")"); + case { OperatorType: ExpressionType.Convert, TypeMapping: not null }: + expressionPrinter.Append("CAST("); + expressionPrinter.Visit(Operand); + expressionPrinter.Append(" AS "); + expressionPrinter.Append(TypeMapping.StoreType); + expressionPrinter.Append(")"); + break; + + case { OperatorType: ExpressionType.Equal }: + expressionPrinter.Visit(Operand); + expressionPrinter.Append(" IS NULL"); + break; + + case { OperatorType: ExpressionType.NotEqual }: + expressionPrinter.Visit(Operand); + expressionPrinter.Append(" IS NOT NULL"); + break; + + default: + expressionPrinter.Append(OperatorType.ToString()); + expressionPrinter.Append("("); + expressionPrinter.Visit(Operand); + expressionPrinter.Append(")"); + break; } } diff --git a/src/EFCore.Relational/Query/SqlExpressions/ValuesExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/ValuesExpression.cs index 7da95b916bd..cb361af8c91 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/ValuesExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/ValuesExpression.cs @@ -42,23 +42,9 @@ public ValuesExpression( { Check.NotEmpty(rowValues, nameof(rowValues)); -#if DEBUG - if (rowValues.Any(rv => rv.Values.Count != columnNames.Count)) - { - throw new ArgumentException("All number of all row values doesn't match the number of column names"); - } - - if (rowValues.SelectMany(rv => rv.Values).Any( - v => v is not SqlConstantExpression and not SqlUnaryExpression - { - Operand: SqlConstantExpression, - OperatorType: ExpressionType.Convert - })) - { - // See #30734 for non-constants - throw new ArgumentException("Only constant expressions are supported in ValuesExpression"); - } -#endif + Check.DebugAssert( + rowValues.All(rv => rv.Values.Count == columnNames.Count), + "All row values must have a value count matching the number of column names"); RowValues = rowValues; ColumnNames = columnNames; diff --git a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs index 42c4a4f9363..9c73b434e18 100644 --- a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs +++ b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs @@ -3,6 +3,7 @@ using System.Collections; using System.Diagnostics.CodeAnalysis; +using System.Net.Http.Headers; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; namespace Microsoft.EntityFrameworkCore.Query; @@ -649,6 +650,8 @@ protected virtual SqlExpression VisitExists( /// An optimized sql expression. protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOptimizedExpansion, out bool nullable) { + // SQL IN returns null when the item is null, and when the values (or subquery projection) contains NULL and no match was made. + var item = Visit(inExpression.Item, out var itemNullable); if (inExpression.Subquery != null) @@ -669,7 +672,7 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt // we want to avoid double visitation var subqueryProjection = subquery.Projection.Single().Expression; - inExpression = inExpression.Update(item, values: null, subquery); + inExpression = inExpression.Update(item, subquery); var unwrappedSubqueryProjection = subqueryProjection; while (unwrappedSubqueryProjection is SqlUnaryExpression @@ -692,8 +695,6 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt nullable = false; - // SQL IN returns null when the item is null, and when the values (subquery projection) contains NULL and no match was made. - switch ((itemNullable, projectionNullable)) { // If both sides are non-nullable, IN never returns null, so is safe to use as-is. @@ -770,112 +771,188 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt // Non-subquery case - // for relational null semantics we don't need to extract null values from the array - if (UseRelationalNulls - || !(inExpression.Values is SqlConstantExpression || inExpression.Values is SqlParameterExpression)) + nullable = false; + + // For relational null semantics we don't need to extract null values from the array + if (UseRelationalNulls) { - var (valuesExpression, valuesList, _) = ProcessInExpressionValues(inExpression.Values!, extractNullValues: false); - nullable = false; + inExpression = ProcessInExpressionValues(inExpression, removeNulls: false, removeNullables: false, out _, out _); - return valuesList.Count == 0 - ? _sqlExpressionFactory.Constant(false, inExpression.TypeMapping) - : SimplifyInExpression( - inExpression.Update(item, valuesExpression, subquery: null), - valuesExpression, - valuesList); + return inExpression.Values! switch + { + [] => _sqlExpressionFactory.Constant(false, inExpression.TypeMapping), + [var v] => _sqlExpressionFactory.Equal(inExpression.Item, v), + [..] => inExpression + }; } - // for c# null semantics we need to remove nulls from Values and add IsNull/IsNotNull when necessary - var (inValuesExpression, inValuesList, hasNullValue) = ProcessInExpressionValues(inExpression.Values, extractNullValues: true); + // For all other scenarios, we need to compensate for the presence of nulls (constants/parameters) and nullables + // (columns/arbitrary expressions) in the value list. The following visits all the values, removing nulls (but not nullables) + // and returns the visited values with some information on what was found. + inExpression = ProcessInExpressionValues( + inExpression, removeNulls: true, removeNullables: false, out var valuesHasNull, out var nullableValues); - // either values array is empty or only contains null - if (inValuesList.Count == 0) + // Do some simplifications for when the value list contains only zero or one values + switch (inExpression.Values!) { - nullable = false; + // nullable IN (NULL) -> nullable IS NULL + case [] when valuesHasNull && itemNullable: + return _sqlExpressionFactory.IsNull(item); // a IN () -> false // non_nullable IN (NULL) -> false - // nullable IN (NULL) -> nullable IS NULL - return !hasNullValue || !itemNullable - ? _sqlExpressionFactory.Constant(false, inExpression.TypeMapping) - : _sqlExpressionFactory.IsNull(item); + case []: + return _sqlExpressionFactory.Constant(false, inExpression.TypeMapping); + + // a IN (1) -> a = 1 + // nullable IN (1, NULL) -> nullable IS NULL OR nullable = 1 + case [var singleValue]: + return Visit( + itemNullable && valuesHasNull + ? _sqlExpressionFactory.OrElse(_sqlExpressionFactory.IsNull(item), _sqlExpressionFactory.Equal(item, singleValue)) + : _sqlExpressionFactory.Equal(item, singleValue), + allowOptimizedExpansion, + out _); } - var simplifiedInExpression = SimplifyInExpression( - inExpression.Update(item, inValuesExpression, subquery: null), - inValuesExpression, - inValuesList); + // If the item is non-nullable and there are no nullables, return the expression without compensation; null has already been removed + // as it will never match, and the expression doesn't return NULL in any case: + // non_nullable IN (1, 2) -> non_nullable IN (1, 2) + // non_nullable IN (1, 2, NULL) -> non_nullable IN (1, 2) + if (!itemNullable && nullableValues.Count == 0) + { + return inExpression; + } - if (!itemNullable - || (allowOptimizedExpansion && !hasNullValue)) + // If we're in optimized mode and the item isn't nullable (no matter what the values have), or there are no nulls/nullable values, + // also return without compensation; null will only be returned if the item isn't found, and that will evaluate to false in + // optimized mode: + // non_nullable IN (1, 2, NULL, nullable) -> non_nullable IN (1, 2, nullable) (optimized) + // nullable IN (1, 2) -> nullable IN (1, 2) (optimized) + if (allowOptimizedExpansion && (!itemNullable || !valuesHasNull && nullableValues.Count == 0)) { - nullable = false; + return inExpression; + } - // non_nullable IN (1, 2) -> non_nullable IN (1, 2) - // non_nullable IN (1, 2, NULL) -> non_nullable IN (1, 2) - // nullable IN (1, 2) -> nullable IN (1, 2) (optimized) - return simplifiedInExpression; + // At this point, if there are any nullable values, we need to extract them out to create a pure, non-nullable/non-null list of + // values. We'll add them back via separate equality checks. + if (nullableValues.Count > 0) + { + inExpression = ProcessInExpressionValues(inExpression, removeNulls: true, removeNullables: true, out _, out nullableValues); } - nullable = false; + SqlExpression result = inExpression; + // If the item is nullable, we need to add compensation based on whether null was found in the values or not: // nullable IN (1, 2) -> nullable IN (1, 2) AND nullable IS NOT NULL (full) // nullable IN (1, 2, NULL) -> nullable IN (1, 2) OR nullable IS NULL (full) - return hasNullValue - ? _sqlExpressionFactory.OrElse(simplifiedInExpression, _sqlExpressionFactory.IsNull(item)) - : _sqlExpressionFactory.AndAlso(simplifiedInExpression, _sqlExpressionFactory.IsNotNull(item)); + if (itemNullable) + { + result = valuesHasNull + ? _sqlExpressionFactory.OrElse(inExpression, _sqlExpressionFactory.IsNull(item)) + : _sqlExpressionFactory.AndAlso(inExpression, _sqlExpressionFactory.IsNotNull(item)); + } - (SqlConstantExpression ProcessedValuesExpression, List ProcessedValuesList, bool HasNullValue) - ProcessInExpressionValues(SqlExpression valuesExpression, bool extractNullValues) + // If there are no nullables, we're done. + if (nullableValues.Count == 0) { - var inValues = new List(); - var hasNullValue = false; - RelationalTypeMapping? typeMapping; + return result; + } - IEnumerable values; - switch (valuesExpression) - { - case SqlConstantExpression sqlConstant: - typeMapping = sqlConstant.TypeMapping; - values = (IEnumerable)sqlConstant.Value!; - break; - - case SqlParameterExpression sqlParameter: - DoNotCache(); - typeMapping = sqlParameter.TypeMapping; - values = (IEnumerable?)ParameterValues[sqlParameter.Name] ?? Array.Empty(); - break; - - default: - throw new InvalidOperationException( - RelationalStrings.NonConstantOrParameterAsInExpressionValues(valuesExpression.GetType().Name)); - } + // At this point we know that there are nullable values; we need to extract these out and add regular individual equality checks + // for each one. + // non_nullable IN (1, 2, nullable) -> non_nullable IN (1, 2) OR (non_nullable = nullable AND nullable IS NOT NULL) (full) + // non_nullable IN (1, 2, NULL, nullable) -> non_nullable IN (1, 2) OR (non_nullable = nullable AND nullable IS NOT NULL) (full) + return nullableValues.Aggregate( + result, + (expr, nullableValue) => _sqlExpressionFactory.OrElse( + expr, + VisitSqlBinary(_sqlExpressionFactory.Equal(item, nullableValue), allowOptimizedExpansion, out _))); + + InExpression ProcessInExpressionValues( + InExpression inExpression, bool removeNulls, bool removeNullables, out bool hasNull, out List nullables) + { + List? processedValues = null; + (hasNull, nullables) = (false, new List()); - foreach (var value in values) + if (inExpression.ValuesParameter is SqlParameterExpression valuesParameter) { - if (value == null && extractNullValues) + // The InExpression has a values parameter. Expand it out, embedding its values as constants into the SQL; disable SQL + // caching. + DoNotCache(); + var typeMapping = inExpression.ValuesParameter.TypeMapping; + var values = (IEnumerable?)ParameterValues[valuesParameter.Name] ?? Array.Empty(); + + processedValues = new List(); + + foreach (var value in values) { - hasNullValue = true; - continue; - } + if (value == null && removeNulls) + { + hasNull = true; + continue; + } - inValues.Add(value); + processedValues.Add(_sqlExpressionFactory.Constant(value, typeMapping)); + } } + else + { + Check.DebugAssert(inExpression.Values is not null, "inExpression.Values is not null"); - var processedValuesExpression = _sqlExpressionFactory.Constant(inValues, typeMapping); + for (var i = 0; i < inExpression.Values.Count; i++) + { + var value = inExpression.Values[i]; - return (processedValuesExpression, inValues, hasNullValue); - } + if (IsNull(value)) + { + hasNull = true; + + if (removeNulls && processedValues is null) + { + CreateProcessedValues(); + } - SqlExpression SimplifyInExpression( - InExpression inExpression, - SqlConstantExpression inValuesExpression, - List inValuesList) - => inValuesList.Count == 1 - ? _sqlExpressionFactory.Equal( - inExpression.Item, - _sqlExpressionFactory.Constant(inValuesList[0], inExpression.Values!.TypeMapping)) - : inExpression; + continue; + } + + var visitedValue = Visit(value, out var valueNullable); + + if (valueNullable) + { + nullables.Add(visitedValue); + + if (removeNullables) + { + if (processedValues is null) + { + CreateProcessedValues(); + } + + continue; + } + } + + if (value != visitedValue && processedValues is null) + { + CreateProcessedValues(); + } + + processedValues?.Add(visitedValue); + + void CreateProcessedValues() + { + processedValues = new List(inExpression.Values!.Count - 1); + for (var j = 0; j < i; j++) + { + processedValues.Add(inExpression.Values[j]); + } + } + } + } + + return inExpression.Update(inExpression.Item, processedValues ?? inExpression.Values!); + } } /// diff --git a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs index 7ea65b2626a..150640dbe0c 100644 --- a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs @@ -223,10 +223,36 @@ protected override Expression VisitIn(InExpression inExpression) _isSearchCondition = false; var item = (SqlExpression)Visit(inExpression.Item); var subquery = (SelectExpression?)Visit(inExpression.Subquery); - var values = (SqlExpression?)Visit(inExpression.Values); + + var values = inExpression.Values; + SqlExpression[]? newValues = null; + if (values is not null) + { + for (var i = 0; i < values.Count; i++) + { + var value = values[i]; + var newValue = (SqlExpression)Visit(value); + + if (newValue != value && newValues is null) + { + newValues = new SqlExpression[values.Count]; + for (var j = 0; j < i; j++) + { + newValues[j] = values[j]; + } + } + + if (newValues is not null) + { + newValues[i] = newValue; + } + } + } + + var valuesParameter = (SqlParameterExpression?)Visit(inExpression.ValuesParameter); _isSearchCondition = parentSearchCondition; - return ApplyConversion(inExpression.Update(item, values, subquery), condition: true); + return ApplyConversion(inExpression.Update(item, subquery, newValues ?? values, valuesParameter), condition: true); } /// diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs index 7760b751214..45db664494c 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs @@ -393,7 +393,7 @@ public SqlServerInferredTypeMappingApplier( IRelationalTypeMappingSource typeMappingSource, ISqlExpressionFactory sqlExpressionFactory, IReadOnlyDictionary<(TableExpressionBase, string), RelationalTypeMapping?> inferredTypeMappings) - : base(inferredTypeMappings) + : base(sqlExpressionFactory, inferredTypeMappings) => (_typeMappingSource, _sqlExpressionFactory) = (typeMappingSource, sqlExpressionFactory); /// diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableMethodTranslatingExpressionVisitor.cs index b2796c81292..045bce96c57 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableMethodTranslatingExpressionVisitor.cs @@ -332,7 +332,6 @@ protected override Expression ApplyInferredTypeMappings( protected class SqliteInferredTypeMappingApplier : RelationalInferredTypeMappingApplier { private readonly IRelationalTypeMappingSource _typeMappingSource; - private readonly ISqlExpressionFactory _sqlExpressionFactory; private Dictionary? _currentSelectInferredTypeMappings; /// @@ -345,8 +344,8 @@ public SqliteInferredTypeMappingApplier( IRelationalTypeMappingSource typeMappingSource, ISqlExpressionFactory sqlExpressionFactory, IReadOnlyDictionary<(TableExpressionBase, string), RelationalTypeMapping?> inferredTypeMappings) - : base(inferredTypeMappings) - => (_typeMappingSource, _sqlExpressionFactory) = (typeMappingSource, sqlExpressionFactory); + : base(sqlExpressionFactory, inferredTypeMappings) + => _typeMappingSource = typeMappingSource; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to diff --git a/src/EFCore/Query/InlineQueryRootExpression.cs b/src/EFCore/Query/InlineQueryRootExpression.cs index 85eb769c139..79dd5db9113 100644 --- a/src/EFCore/Query/InlineQueryRootExpression.cs +++ b/src/EFCore/Query/InlineQueryRootExpression.cs @@ -45,12 +45,23 @@ public InlineQueryRootExpression(IReadOnlyList values, Type elementT public override Expression DetachQueryProvider() => new InlineQueryRootExpression(Values, ElementType); + /// + /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will + /// return this expression. + /// + /// The property of the result. + /// This expression if no children changed, or an expression with the updated children. + public virtual InlineQueryRootExpression Update(IReadOnlyList values) + => ReferenceEquals(values, Values) || values.SequenceEqual(Values) + ? this + : new InlineQueryRootExpression(values, ElementType); + /// protected override Expression VisitChildren(ExpressionVisitor visitor) => visitor.Visit(Values) is var visitedValues && ReferenceEquals(visitedValues, Values) ? this - : new InlineQueryRootExpression(visitedValues, Type); + : Update(visitedValues); /// protected override void Print(ExpressionPrinter expressionPrinter) diff --git a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs index 43ce2d91a3a..8a9fbe4eb7c 100644 --- a/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs @@ -515,7 +515,7 @@ public IDictionary Find(Expression expression) var parentEvaluatable = _evaluatable; var parentContainsClosure = _containsClosure; - _evaluatable = IsEvaluatableNodeType(expression) + _evaluatable = IsEvaluatableNodeType(expression, out var preferNoEvaluation) // Extension point to disable funcletization && _evaluatableExpressionFilter.IsEvaluatableExpression(expression, _model) // Don't evaluate QueryableMethods if in compiled query @@ -524,7 +524,7 @@ public IDictionary Find(Expression expression) base.Visit(expression); - if (_evaluatable) + if (_evaluatable && !preferNoEvaluation) { // Force parameterization when not in lambda _evaluatableExpressions[expression] = _containsClosure || !_inLambda; @@ -643,10 +643,23 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio return base.VisitConstant(constantExpression); } - private static bool IsEvaluatableNodeType(Expression expression) - => expression.NodeType != ExpressionType.Extension - || expression.CanReduce - && IsEvaluatableNodeType(expression.ReduceAndCheck()); + private static bool IsEvaluatableNodeType(Expression expression, out bool preferNoEvaluation) + { + switch (expression.NodeType) + { + case ExpressionType.NewArrayInit: + preferNoEvaluation = true; + return true; + + case ExpressionType.Extension: + preferNoEvaluation = false; + return expression.CanReduce && IsEvaluatableNodeType(expression.ReduceAndCheck(), out preferNoEvaluation); + + default: + preferNoEvaluation = false; + return true; + } + } private static bool IsQueryableMethod(Expression expression) => expression is MethodCallExpression methodCallExpression diff --git a/src/EFCore/Query/QueryRootProcessor.cs b/src/EFCore/Query/QueryRootProcessor.cs index 5bc21839385..4c208858e87 100644 --- a/src/EFCore/Query/QueryRootProcessor.cs +++ b/src/EFCore/Query/QueryRootProcessor.cs @@ -50,48 +50,18 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp var argument = methodCallExpression.Arguments[i]; var parameterType = parameters[i].ParameterType; - Expression? visitedArgument = null; - // This converts collections over constants and parameters to query roots, for later translation of LINQ operators over them. // The element type doesn't have to be directly mappable; we allow unknown CLR types in order to support value convertors // (the precise type mapping - with the value converter - will be inferred later based on LINQ operators composed on the root). // However, we do exclude element CLR types which are associated to entity types in our model, since Contains over entity // collections isn't yet supported (#30712). - if (parameterType.IsGenericType + var visitedArgument = parameterType.IsGenericType && (parameterType.GetGenericTypeDefinition() == typeof(IEnumerable<>) || parameterType.GetGenericTypeDefinition() == typeof(IQueryable<>)) && parameterType.GetGenericArguments()[0] is var elementClrType - && !_model.FindEntityTypes(elementClrType).Any()) - { - switch (argument) - { - case ConstantExpression { Value: IEnumerable values } constantExpression - when ShouldConvertToInlineQueryRoot(constantExpression): - - var valueExpressions = new List(); - foreach (var value in values) - { - valueExpressions.Add(Expression.Constant(value, elementClrType)); - } - visitedArgument = new InlineQueryRootExpression(valueExpressions, elementClrType); - break; - - // TODO: Support NewArrayExpression, see #30734. - - case ParameterExpression parameterExpression - when parameterExpression.Name?.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal) - == true - && ShouldConvertToParameterQueryRoot(parameterExpression): - visitedArgument = new ParameterQueryRootExpression(parameterExpression.Type.GetSequenceType(), parameterExpression); - break; - - default: - visitedArgument = null; - break; - } - } - - visitedArgument ??= Visit(argument); + && !_model.FindEntityTypes(elementClrType).Any() + ? VisitQueryRootCandidate(argument, elementClrType) + : Visit(argument); if (newArguments is not null) { @@ -115,12 +85,47 @@ when ShouldConvertToInlineQueryRoot(constantExpression): : methodCallExpression.Update(methodCallExpression.Object, newArguments); } + private Expression VisitQueryRootCandidate(Expression expression, Type elementClrType) + { + switch (expression) + { + // An array containing only constants is represented as a ConstantExpression with the array as the value. + // Convert that into a NewArrayExpression for use with InlineQueryRootExpression + case ConstantExpression { Value: IEnumerable values }: + var valueExpressions = new List(); + foreach (var value in values) + { + valueExpressions.Add(Expression.Constant(value, elementClrType)); + } + + if (ShouldConvertToInlineQueryRoot(Expression.NewArrayInit(elementClrType, valueExpressions))) + { + return new InlineQueryRootExpression(valueExpressions, elementClrType); + } + + goto default; + + case NewArrayExpression newArrayExpression + when ShouldConvertToInlineQueryRoot(newArrayExpression): + return new InlineQueryRootExpression(newArrayExpression.Expressions, elementClrType); + + case ParameterExpression parameterExpression + when parameterExpression.Name?.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal) + == true + && ShouldConvertToParameterQueryRoot(parameterExpression): + return new ParameterQueryRootExpression(parameterExpression.Type.GetSequenceType(), parameterExpression); + + default: + return Visit(expression); + } + } + /// /// Determines whether a should be converted to a . /// This handles cases inline expressions whose elements are all constants. /// - /// The constant expression that's a candidate for conversion to a query root. - protected virtual bool ShouldConvertToInlineQueryRoot(ConstantExpression constantExpression) + /// The new array expression that's a candidate for conversion to a query root. + protected virtual bool ShouldConvertToInlineQueryRoot(NewArrayExpression newArrayExpression) => false; /// diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs index 349c3abd31a..4ead70a6ba1 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs @@ -2745,14 +2745,11 @@ FROM root c public override async Task Array_of_parameters_Contains_OrElse_comparison_with_constant_gets_combined_to_one_in(bool async) { - await base.Array_of_parameters_Contains_OrElse_comparison_with_constant_gets_combined_to_one_in(async); + // #31051 + await AssertTranslationFailed( + () => base.Array_of_parameters_Contains_OrElse_comparison_with_constant_gets_combined_to_one_in(async)); - AssertSql( -""" -SELECT c -FROM root c -WHERE ((c["Discriminator"] = "Customer") AND (c["CustomerID"] IN ("ALFKI", "ANATR") OR (c["CustomerID"] = "ANTON"))) -"""); + AssertSql(); } public override async Task Multiple_OrElse_on_same_column_with_null_parameter_comparison_converted_to_in(bool async) diff --git a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs index 6b698250e53..0cdf792d0fc 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/NullSemanticsQueryTestBase.cs @@ -1152,6 +1152,8 @@ await AssertQueryScalar( ss => ss.Set().Where(e => !new List { null }.Contains(e.NullableIntA)).Select(e => e.Id)); } + #region Contains with subquery + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual async Task Null_semantics_contains_non_nullable_item_with_non_nullable_subquery(bool async) @@ -1220,6 +1222,132 @@ await AssertQueryScalar( .Select(e => e.Id)); } + #endregion Contains with subquery + + #region Contains with inline collection + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Null_semantics_contains_with_non_nullable_item_and_inline_non_nullable_values(bool async) + { + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new[] { 1, 2 }.Contains(e.IntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new[] { 1, 2 }.Contains(e.IntA)).Select(e => e.Id)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Null_semantics_contains_with_non_nullable_item_and_inline_non_nullable_values_with_null(bool async) + { + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new int?[] { 1, 2, null }.Contains(e.IntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new int?[] { 1, 2, null }.Contains(e.IntA)).Select(e => e.Id)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Null_semantics_contains_with_non_nullable_item_and_inline_nullable_values(bool async) + { + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new[] { 1, 2, e.NullableIntB }.Contains(e.IntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new[] { 1, 2, e.NullableIntB }.Contains(e.IntA)).Select(e => e.Id)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Null_semantics_contains_with_non_nullable_item_and_inline_nullable_values_with_null(bool async) + { + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new[] { 1, 2, e.NullableIntB, null }.Contains(e.IntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new[] { 1, 2, e.NullableIntB, null }.Contains(e.IntA)).Select(e => e.Id)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Null_semantics_contains_with_nullable_item_and_inline_non_nullable_values(bool async) + { + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new int?[] { 1, 2 }.Contains(e.NullableIntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new int?[] { 1, 2 }.Contains(e.NullableIntA)).Select(e => e.Id)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Null_semantics_contains_with_nullable_item_and_inline_non_nullable_values_with_null(bool async) + { + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new int?[] { 1, 2, null }.Contains(e.NullableIntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new int?[] { 1, 2, null }.Contains(e.NullableIntA)).Select(e => e.Id)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Null_semantics_contains_with_nullable_item_and_inline_nullable_values(bool async) + { + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new[] { 1, 2, e.NullableIntB }.Contains(e.NullableIntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new[] { 1, 2, e.NullableIntB }.Contains(e.NullableIntA)).Select(e => e.Id)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Null_semantics_contains_with_nullable_item_and_inline_nullable_values_with_null(bool async) + { + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new[] { 1, 2, e.NullableIntB, null }.Contains(e.NullableIntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new[] { 1, 2, e.NullableIntB, null }.Contains(e.NullableIntA)).Select(e => e.Id)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Null_semantics_contains_with_non_nullable_item_and_one_value(bool async) + { + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new[] { 1 }.Contains(e.IntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new[] { 1 }.Contains(e.IntA)).Select(e => e.Id)); + + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new int?[] { null }.Contains(e.IntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new int?[] { null }.Contains(e.IntA)).Select(e => e.Id)); + + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new[] { e.NullableIntB }.Contains(e.IntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new[] { e.NullableIntB }.Contains(e.IntA)).Select(e => e.Id)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Null_semantics_contains_with_nullable_item_and_one_value(bool async) + { + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new int?[] { 1 }.Contains(e.NullableIntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new int?[] { 1 }.Contains(e.NullableIntA)).Select(e => e.Id)); + + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new int?[] { null }.Contains(e.NullableIntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new int?[] { null }.Contains(e.NullableIntA)).Select(e => e.Id)); + + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => new[] { e.NullableIntB }.Contains(e.NullableIntA)).Select(e => e.Id)); + await AssertQueryScalar(async, ss => ss.Set() + .Where(e => !new[] { e.NullableIntB }.Contains(e.NullableIntA)).Select(e => e.Id)); + } + + #endregion Contains with inline collection + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual async Task Null_semantics_contains_non_nullable_item_with_values(bool async) diff --git a/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs index 88e9f2aa50c..3f76b23164f 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/UdfDbFunctionTestBase.cs @@ -296,7 +296,12 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) .HasTranslation( args => new InExpression( args.First(), - new SqlConstantExpression(Expression.Constant(abc), typeMapping: null), // args.First().TypeMapping), + new[] + { + new SqlConstantExpression(Expression.Constant(abc[0]), typeMapping: null), + new SqlConstantExpression(Expression.Constant(abc[1]), typeMapping: null), + new SqlConstantExpression(Expression.Constant(abc[2]), typeMapping: null) + }, // args.First().TypeMapping) typeMapping: null)); var trueFalse = new[] { true, false }; @@ -305,9 +310,18 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) args => new InExpression( new InExpression( args.First(), - new SqlConstantExpression(Expression.Constant(abc), args.First().TypeMapping), + new[] + { + new SqlConstantExpression(Expression.Constant(abc[0]), args.First().TypeMapping), + new SqlConstantExpression(Expression.Constant(abc[1]), args.First().TypeMapping), + new SqlConstantExpression(Expression.Constant(abc[2]), args.First().TypeMapping) + }, typeMapping: null), - new SqlConstantExpression(Expression.Constant(trueFalse), typeMapping: null), + new[] + { + new SqlConstantExpression(Expression.Constant(trueFalse[0]), typeMapping: null), + new SqlConstantExpression(Expression.Constant(trueFalse[1]), typeMapping: null) + }, typeMapping: null)); modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(NullableValueReturnType), Array.Empty())) diff --git a/test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs index c9870841924..b4304117daa 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs @@ -1396,7 +1396,7 @@ public virtual Task Contains_over_scalar_with_null_should_rewrite_to_identity_eq => AssertQuery( async, ss => ss.Set().Where( - o => ss.Set().Where(o => o.CustomerID == "VINET").Select(o => o.CustomerID).Contains(null))); + o => ss.Set().Where(o => o.CustomerID == "VINET").Select(o => o.EmployeeID).Contains(null))); [ConditionalTheory] [MemberData(nameof(IsAsyncData))] @@ -1404,7 +1404,7 @@ public virtual Task Contains_over_entityType_with_null_should_rewrite_to_identit => AssertQuery( async, ss => ss.Set().Where( - o => !ss.Set().Where(o => o.CustomerID == "VINET").Select(o => o.CustomerID).Contains(null)), + o => !ss.Set().Where(o => o.CustomerID == "VINET").Select(o => o.EmployeeID).Contains(null)), entryCount: 830); [ConditionalTheory] @@ -1413,9 +1413,9 @@ public virtual Task Contains_over_entityType_with_null_should_rewrite_to_identit => AssertQuery( async, ss => ss.Set().Where( - o => ss.Set().Where(o => o.CustomerID == "VINET").Select(o => o.CustomerID) + o => ss.Set().Where(o => o.CustomerID == "VINET").Select(o => o.EmployeeID) .Contains(null) - == ss.Set().Where(o => o.CustomerID != "VINET").Select(o => o.CustomerID) + == ss.Set().Where(o => o.CustomerID != "VINET").Select(o => o.EmployeeID) .Contains(null)), entryCount: 830); @@ -1425,7 +1425,7 @@ public virtual Task Contains_over_nullable_scalar_with_null_in_subquery_translat => AssertQueryScalar( async, ss => ss.Set().Select( - o => ss.Set().Where(o => o.CustomerID == "VINET").Select(o => o.CustomerID).Contains(null))); + o => ss.Set().Where(o => o.CustomerID == "VINET").Select(o => o.EmployeeID).Contains(null))); [ConditionalTheory] [MemberData(nameof(IsAsyncData))] diff --git a/test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs index f505a33852c..3cf9f2bab6e 100644 --- a/test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs @@ -115,15 +115,28 @@ public virtual Task Inline_collection_Contains_with_all_parameters(bool async) [ConditionalTheory] [MemberData(nameof(IsAsyncData))] - public virtual async Task Inline_collection_Contains_with_parameter_and_column_based_expression(bool async) + public virtual Task Inline_collection_Contains_with_constant_and_parameter(bool async) { - var i = 2; + var j = 999; - await AssertTranslationFailed( - () => AssertQuery( - async, - ss => ss.Set().Where(c => new[] { i, c.Int }.Contains(c.Id)), - entryCount: 1)); + return AssertQuery( + async, + ss => ss.Set().Where(c => new[] { 2, j }.Contains(c.Id)), + entryCount: 1); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Inline_collection_Contains_with_mixed_value_types(bool async) + { + // Note: see many nullability-related variations on this in NullSemanticsQueryTestBase + + var i = 11; + + await AssertQuery( + async, + ss => ss.Set().Where(c => new[] { 999, i, c.Id, c.Id + c.Int }.Contains(c.Int)), + entryCount: 1); } [ConditionalTheory] @@ -548,6 +561,19 @@ public virtual Task Column_collection_equality_inline_collection(bool async) ss => ss.Set().Where(c => c.Ints.SequenceEqual(new[] { 1, 10 })), entryCount: 1); + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Column_collection_equality_inline_collection_with_parameters(bool async) + { + var (i, j) = (1, 10); + + return AssertTranslationFailed( + () => AssertQuery( + async, + ss => ss.Set().Where(c => c.Ints == new[] { i, j }), + entryCount: 1)); + } + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual async Task Parameter_collection_in_subquery_Count_as_compiled_query(bool async) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/FiltersInheritanceQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/FiltersInheritanceQuerySqlServerTest.cs index a60cf4ff48a..dfcef05646e 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/FiltersInheritanceQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/FiltersInheritanceQuerySqlServerTest.cs @@ -89,7 +89,7 @@ public override async Task Can_use_of_type_bird_predicate(bool async) """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[Group], [a].[FoundOn] FROM [Animals] AS [a] -WHERE [a].[CountryId] = 1 AND [a].[CountryId] = 1 +WHERE [a].[CountryId] = 1 ORDER BY [a].[Species] """); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/IncompleteMappingInheritanceQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/IncompleteMappingInheritanceQuerySqlServerTest.cs index d1114119ee0..6191bd1fd84 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/IncompleteMappingInheritanceQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/IncompleteMappingInheritanceQuerySqlServerTest.cs @@ -128,7 +128,7 @@ public override async Task Can_use_is_kiwi(bool async) """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[Group], [a].[FoundOn] FROM [Animals] AS [a] -WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND [a].[Discriminator] = N'Kiwi' +WHERE [a].[Discriminator] = N'Kiwi' """); } @@ -233,7 +233,7 @@ public override async Task Can_use_of_type_kiwi(bool async) """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[FoundOn] FROM [Animals] AS [a] -WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND [a].[Discriminator] = N'Kiwi' +WHERE [a].[Discriminator] = N'Kiwi' """); } @@ -245,7 +245,7 @@ public override async Task Can_use_of_type_rose(bool async) """ SELECT [p].[Species], [p].[CountryId], [p].[Genus], [p].[Name], [p].[HasThorns] FROM [Plants] AS [p] -WHERE [p].[Genus] IN (1, 0) AND [p].[Genus] = 0 +WHERE [p].[Genus] = 0 """); } @@ -385,7 +385,7 @@ public override async Task Can_use_of_type_kiwi_where_north_on_derived_property( """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[FoundOn] FROM [Animals] AS [a] -WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND [a].[Discriminator] = N'Kiwi' AND [a].[FoundOn] = CAST(0 AS tinyint) +WHERE [a].[Discriminator] = N'Kiwi' AND [a].[FoundOn] = CAST(0 AS tinyint) """); } @@ -397,7 +397,7 @@ public override async Task Can_use_of_type_kiwi_where_south_on_derived_property( """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[FoundOn] FROM [Animals] AS [a] -WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND [a].[Discriminator] = N'Kiwi' AND [a].[FoundOn] = CAST(1 AS tinyint) +WHERE [a].[Discriminator] = N'Kiwi' AND [a].[FoundOn] = CAST(1 AS tinyint) """); } @@ -433,7 +433,7 @@ public override async Task Discriminator_used_when_projection_over_of_type(bool """ SELECT [a].[FoundOn] FROM [Animals] AS [a] -WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND [a].[Discriminator] = N'Kiwi' +WHERE [a].[Discriminator] = N'Kiwi' """); } @@ -641,7 +641,7 @@ public override async Task Discriminator_with_cast_in_shadow_property(bool async """ SELECT [a].[Name] AS [Predator] FROM [Animals] AS [a] -WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND N'Kiwi' = [a].[Discriminator] +WHERE [a].[Discriminator] = N'Kiwi' """); } @@ -681,7 +681,7 @@ public override async Task Using_is_operator_on_multiple_type_with_no_result(boo """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[Group], [a].[FoundOn] FROM [Animals] AS [a] -WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND [a].[Discriminator] = N'Kiwi' AND [a].[Discriminator] = N'Eagle' +WHERE 0 = 1 """); } @@ -693,7 +693,7 @@ public override async Task Using_is_operator_with_of_type_on_multiple_type_with_ """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[Group] FROM [Animals] AS [a] -WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND [a].[Discriminator] = N'Kiwi' AND [a].[Discriminator] = N'Eagle' +WHERE 0 = 1 """); } @@ -736,7 +736,7 @@ public override async Task GetType_in_hierarchy_in_leaf_type_with_sibling(bool a """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[Group], [a].[FoundOn] FROM [Animals] AS [a] -WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND [a].[Discriminator] = N'Eagle' +WHERE [a].[Discriminator] = N'Eagle' """); } @@ -748,7 +748,7 @@ public override async Task GetType_in_hierarchy_in_leaf_type_with_sibling2(bool """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[Group], [a].[FoundOn] FROM [Animals] AS [a] -WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND [a].[Discriminator] = N'Kiwi' +WHERE [a].[Discriminator] = N'Kiwi' """); } @@ -760,7 +760,7 @@ public override async Task GetType_in_hierarchy_in_leaf_type_with_sibling2_rever """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[Group], [a].[FoundOn] FROM [Animals] AS [a] -WHERE [a].[Discriminator] IN (N'Eagle', N'Kiwi') AND [a].[Discriminator] = N'Kiwi' +WHERE [a].[Discriminator] = N'Kiwi' """); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/InheritanceQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/InheritanceQuerySqlServerTest.cs index 7ef197c91da..68125e9cb73 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/InheritanceQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/InheritanceQuerySqlServerTest.cs @@ -652,7 +652,7 @@ public override async Task Using_is_operator_on_multiple_type_with_no_result(boo """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[Group], [a].[FoundOn] FROM [Animals] AS [a] -WHERE [a].[Discriminator] = N'Kiwi' AND [a].[Discriminator] = N'Eagle' +WHERE 0 = 1 """); } @@ -664,7 +664,7 @@ public override async Task Using_is_operator_with_of_type_on_multiple_type_with_ """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[Group] FROM [Animals] AS [a] -WHERE [a].[Discriminator] = N'Kiwi' AND [a].[Discriminator] = N'Eagle' +WHERE 0 = 1 """); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqlServerTest.cs index e50cfcaf7a2..863b77077ce 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqlServerTest.cs @@ -2220,7 +2220,7 @@ FROM [Orders] AS [o] WHERE EXISTS ( SELECT 1 FROM [Orders] AS [o0] - WHERE [o0].[CustomerID] = N'VINET' AND [o0].[CustomerID] IS NULL) + WHERE [o0].[CustomerID] = N'VINET' AND [o0].[EmployeeID] IS NULL) """); } @@ -2235,7 +2235,7 @@ FROM [Orders] AS [o] WHERE NOT EXISTS ( SELECT 1 FROM [Orders] AS [o0] - WHERE [o0].[CustomerID] = N'VINET' AND [o0].[CustomerID] IS NULL) + WHERE [o0].[CustomerID] = N'VINET' AND [o0].[EmployeeID] IS NULL) """); } @@ -2251,13 +2251,13 @@ WHERE CASE WHEN EXISTS ( SELECT 1 FROM [Orders] AS [o0] - WHERE [o0].[CustomerID] = N'VINET' AND [o0].[CustomerID] IS NULL) THEN CAST(1 AS bit) + WHERE [o0].[CustomerID] = N'VINET' AND [o0].[EmployeeID] IS NULL) THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END = CASE WHEN EXISTS ( SELECT 1 FROM [Orders] AS [o1] - WHERE ([o1].[CustomerID] <> N'VINET' OR [o1].[CustomerID] IS NULL) AND [o1].[CustomerID] IS NULL) THEN CAST(1 AS bit) + WHERE ([o1].[CustomerID] <> N'VINET' OR [o1].[CustomerID] IS NULL) AND [o1].[EmployeeID] IS NULL) THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END """); @@ -2273,7 +2273,7 @@ SELECT CASE WHEN EXISTS ( SELECT 1 FROM [Orders] AS [o0] - WHERE [o0].[CustomerID] = N'VINET' AND [o0].[CustomerID] IS NULL) THEN CAST(1 AS bit) + WHERE [o0].[CustomerID] = N'VINET' AND [o0].[EmployeeID] IS NULL) THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END FROM [Orders] AS [o] diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs index b6183aed44b..2bb1762a584 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindWhereQuerySqlServerTest.cs @@ -2509,7 +2509,7 @@ public override async Task Constant_array_Contains_OrElse_comparison_with_consta """ SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM [Customers] AS [c] -WHERE [c].[CustomerID] IN (N'ALFKI', N'ANATR', N'ANTON') +WHERE [c].[CustomerID] IN (N'ANTON', N'ALFKI', N'ANATR') """); } @@ -2561,14 +2561,12 @@ public override async Task Array_of_parameters_Contains_OrElse_comparison_with_c // issue #21462 AssertSql( """ -@__p_0='["ALFKI","ANATR"]' (Size = 4000) +@__prm1_0='ALFKI' (Size = 5) (DbType = StringFixedLength) +@__prm2_1='ANATR' (Size = 5) (DbType = StringFixedLength) SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM [Customers] AS [c] -WHERE [c].[CustomerID] IN ( - SELECT [p].[value] - FROM OPENJSON(@__p_0) WITH ([value] nchar(5) '$') AS [p] -) OR [c].[CustomerID] = N'ANTON' +WHERE [c].[CustomerID] IN (@__prm1_0, @__prm2_1, N'ANTON') """); } @@ -2641,7 +2639,7 @@ public override async Task Two_sets_of_comparison_combine_correctly2(bool async) """ SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] FROM [Customers] AS [c] -WHERE [c].[Region] <> N'WA' AND [c].[Region] IS NOT NULL +WHERE [c].[Region] IS NOT NULL AND [c].[Region] <> N'WA' """); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs index 29b53d17c7b..5fc242e50f2 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NullSemanticsQuerySqlServerTest.cs @@ -2017,6 +2017,237 @@ FROM [Entities2] AS [e0] """); } + #region Contains with inline collection + + public override async Task Null_semantics_contains_with_non_nullable_item_and_inline_non_nullable_values(bool async) + { + await base.Null_semantics_contains_with_non_nullable_item_and_inline_non_nullable_values(async); + + AssertSql( +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] IN (1, 2) +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] NOT IN (1, 2) +"""); + } + + public override async Task Null_semantics_contains_with_non_nullable_item_and_inline_non_nullable_values_with_null(bool async) + { + await base.Null_semantics_contains_with_non_nullable_item_and_inline_non_nullable_values_with_null(async); + + AssertSql( +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] IN (1, 2) +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] NOT IN (1, 2) +"""); + } + + public override async Task Null_semantics_contains_with_non_nullable_item_and_inline_nullable_values(bool async) + { + await base.Null_semantics_contains_with_non_nullable_item_and_inline_nullable_values(async); + + AssertSql( +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] IN (1, 2, [e].[NullableIntB]) +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] NOT IN (1, 2) AND ([e].[IntA] <> [e].[NullableIntB] OR [e].[NullableIntB] IS NULL) +"""); + } + + public override async Task Null_semantics_contains_with_non_nullable_item_and_inline_nullable_values_with_null(bool async) + { + await base.Null_semantics_contains_with_non_nullable_item_and_inline_nullable_values_with_null(async); + + AssertSql( +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] IN (1, 2, [e].[NullableIntB]) +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] NOT IN (1, 2) AND ([e].[IntA] <> [e].[NullableIntB] OR [e].[NullableIntB] IS NULL) +"""); + } + + public override async Task Null_semantics_contains_with_nullable_item_and_inline_non_nullable_values(bool async) + { + await base.Null_semantics_contains_with_nullable_item_and_inline_non_nullable_values(async); + + AssertSql( +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IN (1, 2) +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] NOT IN (1, 2) OR [e].[NullableIntA] IS NULL +"""); + } + + public override async Task Null_semantics_contains_with_nullable_item_and_inline_non_nullable_values_with_null(bool async) + { + await base.Null_semantics_contains_with_nullable_item_and_inline_non_nullable_values_with_null(async); + + AssertSql( +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IN (1, 2) OR [e].[NullableIntA] IS NULL +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] NOT IN (1, 2) AND [e].[NullableIntA] IS NOT NULL +"""); + } + + public override async Task Null_semantics_contains_with_nullable_item_and_inline_nullable_values(bool async) + { + await base.Null_semantics_contains_with_nullable_item_and_inline_nullable_values(async); + + AssertSql( +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE ([e].[NullableIntA] IN (1, 2) AND [e].[NullableIntA] IS NOT NULL) OR [e].[NullableIntA] = [e].[NullableIntB] OR ([e].[NullableIntA] IS NULL AND [e].[NullableIntB] IS NULL) +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE ([e].[NullableIntA] NOT IN (1, 2) OR [e].[NullableIntA] IS NULL) AND ([e].[NullableIntA] <> [e].[NullableIntB] OR [e].[NullableIntA] IS NULL OR [e].[NullableIntB] IS NULL) AND ([e].[NullableIntA] IS NOT NULL OR [e].[NullableIntB] IS NOT NULL) +"""); + } + + public override async Task Null_semantics_contains_with_nullable_item_and_inline_nullable_values_with_null(bool async) + { + await base.Null_semantics_contains_with_nullable_item_and_inline_nullable_values_with_null(async); + + AssertSql( +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IN (1, 2) OR [e].[NullableIntA] IS NULL OR [e].[NullableIntA] = [e].[NullableIntB] OR ([e].[NullableIntA] IS NULL AND [e].[NullableIntB] IS NULL) +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] NOT IN (1, 2) AND [e].[NullableIntA] IS NOT NULL AND ([e].[NullableIntA] <> [e].[NullableIntB] OR [e].[NullableIntA] IS NULL OR [e].[NullableIntB] IS NULL) AND ([e].[NullableIntA] IS NOT NULL OR [e].[NullableIntB] IS NOT NULL) +"""); + } + + public override async Task Null_semantics_contains_with_non_nullable_item_and_one_value(bool async) + { + await base.Null_semantics_contains_with_non_nullable_item_and_one_value(async); + + AssertSql( +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] = 1 +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] <> 1 +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE 0 = 1 +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] = [e].[NullableIntB] +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[IntA] <> [e].[NullableIntB] OR [e].[NullableIntB] IS NULL +"""); + } + + public override async Task Null_semantics_contains_with_nullable_item_and_one_value(bool async) + { + await base.Null_semantics_contains_with_nullable_item_and_one_value(async); + + AssertSql( +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] = 1 +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] <> 1 OR [e].[NullableIntA] IS NULL +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IS NULL +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] IS NOT NULL +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE [e].[NullableIntA] = [e].[NullableIntB] OR ([e].[NullableIntA] IS NULL AND [e].[NullableIntB] IS NULL) +""", + // +""" +SELECT [e].[Id] +FROM [Entities1] AS [e] +WHERE ([e].[NullableIntA] <> [e].[NullableIntB] OR [e].[NullableIntA] IS NULL OR [e].[NullableIntB] IS NULL) AND ([e].[NullableIntA] IS NOT NULL OR [e].[NullableIntB] IS NOT NULL) +"""); + } + + #endregion Contains with inline collection + public override async Task Null_semantics_contains_non_nullable_item_with_values(bool async) { await base.Null_semantics_contains_non_nullable_item_with_values(async); @@ -2590,7 +2821,7 @@ public override async Task Negated_contains_with_comparison_without_null_get_com """ SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC] FROM [Entities1] AS [e] -WHERE [e].[NullableIntA] NOT IN (1, 2, 3) +WHERE [e].[NullableIntA] NOT IN (3, 1, 2) """); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQueryOldSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQueryOldSqlServerTest.cs index 190d0155f33..5d18a463c8c 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQueryOldSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQueryOldSqlServerTest.cs @@ -50,7 +50,7 @@ public override async Task Inline_collection_of_nullable_ints_Contains_null(bool """ SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[String], [p].[Strings] FROM [PrimitiveCollectionsEntity] AS [p] -WHERE [p].[NullableInt] = 999 OR [p].[NullableInt] IS NULL +WHERE [p].[NullableInt] IS NULL OR [p].[NullableInt] = 999 """); } @@ -151,17 +151,41 @@ public override async Task Inline_collection_Contains_with_all_parameters(bool a AssertSql( """ +@__i_0='2' +@__j_1='999' + SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[String], [p].[Strings] FROM [PrimitiveCollectionsEntity] AS [p] -WHERE [p].[Id] IN (2, 999) +WHERE [p].[Id] IN (@__i_0, @__j_1) """); } - public override async Task Inline_collection_Contains_with_parameter_and_column_based_expression(bool async) + public override async Task Inline_collection_Contains_with_constant_and_parameter(bool async) { - await base.Inline_collection_Contains_with_parameter_and_column_based_expression(async); + await base.Inline_collection_Contains_with_constant_and_parameter(async); - AssertSql(); + AssertSql( +""" +@__j_0='999' + +SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[String], [p].[Strings] +FROM [PrimitiveCollectionsEntity] AS [p] +WHERE [p].[Id] IN (2, @__j_0) +"""); + } + + public override async Task Inline_collection_Contains_with_mixed_value_types(bool async) + { + await base.Inline_collection_Contains_with_mixed_value_types(async); + + AssertSql( +""" +@__i_0='11' + +SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[String], [p].[Strings] +FROM [PrimitiveCollectionsEntity] AS [p] +WHERE [p].[Int] IN (999, @__i_0, [p].[Id], [p].[Id] + [p].[Int]) +"""); } public override async Task Inline_collection_Contains_as_Any_with_predicate(bool async) @@ -223,7 +247,7 @@ public override async Task Parameter_collection_of_nullable_ints_Contains_nullab """ SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[String], [p].[Strings] FROM [PrimitiveCollectionsEntity] AS [p] -WHERE [p].[NullableInt] = 999 OR [p].[NullableInt] IS NULL +WHERE [p].[NullableInt] IS NULL OR [p].[NullableInt] = 999 """); } @@ -433,6 +457,13 @@ FROM [PrimitiveCollectionsEntity] AS [p] """); } + public override async Task Column_collection_equality_inline_collection_with_parameters(bool async) + { + await base.Column_collection_equality_inline_collection_with_parameters(async); + + AssertSql(); + } + public override Task Parameter_collection_in_subquery_Union_column_collection_as_compiled_query(bool async) => AssertTranslationFailed(() => base.Parameter_collection_in_subquery_Union_column_collection_as_compiled_query(async)); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs index 3f742f3ab5c..5dd1e483126 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs @@ -45,7 +45,7 @@ public override async Task Inline_collection_of_nullable_ints_Contains_null(bool """ SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[String], [p].[Strings] FROM [PrimitiveCollectionsEntity] AS [p] -WHERE [p].[NullableInt] = 999 OR [p].[NullableInt] IS NULL +WHERE [p].[NullableInt] IS NULL OR [p].[NullableInt] = 999 """); } @@ -144,26 +144,43 @@ public override async Task Inline_collection_Contains_with_all_parameters(bool a { await base.Inline_collection_Contains_with_all_parameters(async); - // See #30732 for making this better + AssertSql( +""" +@__i_0='2' +@__j_1='999' + +SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[String], [p].[Strings] +FROM [PrimitiveCollectionsEntity] AS [p] +WHERE [p].[Id] IN (@__i_0, @__j_1) +"""); + } + + public override async Task Inline_collection_Contains_with_constant_and_parameter(bool async) + { + await base.Inline_collection_Contains_with_constant_and_parameter(async); AssertSql( """ -@__p_0='[2,999]' (Size = 4000) +@__j_0='999' SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[String], [p].[Strings] FROM [PrimitiveCollectionsEntity] AS [p] -WHERE [p].[Id] IN ( - SELECT [p0].[value] - FROM OPENJSON(@__p_0) WITH ([value] int '$') AS [p0] -) +WHERE [p].[Id] IN (2, @__j_0) """); } - public override async Task Inline_collection_Contains_with_parameter_and_column_based_expression(bool async) + public override async Task Inline_collection_Contains_with_mixed_value_types(bool async) { - await base.Inline_collection_Contains_with_parameter_and_column_based_expression(async); + await base.Inline_collection_Contains_with_mixed_value_types(async); - AssertSql(); + AssertSql( +""" +@__i_0='11' + +SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[String], [p].[Strings] +FROM [PrimitiveCollectionsEntity] AS [p] +WHERE [p].[Int] IN (999, @__i_0, [p].[Id], [p].[Id] + [p].[Int]) +"""); } public override async Task Inline_collection_Contains_as_Any_with_predicate(bool async) @@ -804,6 +821,13 @@ FROM [PrimitiveCollectionsEntity] AS [p] """); } + public override async Task Column_collection_equality_inline_collection_with_parameters(bool async) + { + await base.Column_collection_equality_inline_collection_with_parameters(async); + + AssertSql(); + } + public override async Task Parameter_collection_in_subquery_Union_column_collection_as_compiled_query(bool async) { await base.Parameter_collection_in_subquery_Union_column_collection_as_compiled_query(async); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/TPCFiltersInheritanceQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/TPCFiltersInheritanceQuerySqlServerTest.cs index 472dbd42967..fbfde5ecaeb 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/TPCFiltersInheritanceQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/TPCFiltersInheritanceQuerySqlServerTest.cs @@ -124,7 +124,7 @@ UNION ALL SELECT [k].[Id], [k].[CountryId], [k].[Name], [k].[Species], [k].[EagleId], [k].[IsFlightless], NULL AS [Group], [k].[FoundOn], N'Kiwi' AS [Discriminator] FROM [Kiwi] AS [k] ) AS [t] -WHERE [t].[CountryId] = 1 AND [t].[CountryId] = 1 +WHERE [t].[CountryId] = 1 ORDER BY [t].[Species] """); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/TPTFiltersInheritanceQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/TPTFiltersInheritanceQuerySqlServerTest.cs index 4af39f70b41..ffab3a8a065 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/TPTFiltersInheritanceQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/TPTFiltersInheritanceQuerySqlServerTest.cs @@ -122,7 +122,7 @@ FROM [Animals] AS [a] LEFT JOIN [Birds] AS [b] ON [a].[Id] = [b].[Id] LEFT JOIN [Eagle] AS [e] ON [a].[Id] = [e].[Id] LEFT JOIN [Kiwi] AS [k] ON [a].[Id] = [k].[Id] -WHERE [a].[CountryId] = 1 AND [a].[CountryId] = 1 AND ([k].[Id] IS NOT NULL OR [e].[Id] IS NOT NULL) +WHERE [a].[CountryId] = 1 AND ([k].[Id] IS NOT NULL OR [e].[Id] IS NOT NULL) ORDER BY [a].[Species] """); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/TemporalFiltersInheritanceQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/TemporalFiltersInheritanceQuerySqlServerTest.cs index 3c922201d74..539fc3a5195 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/TemporalFiltersInheritanceQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/TemporalFiltersInheritanceQuerySqlServerTest.cs @@ -112,7 +112,7 @@ public override async Task Can_use_of_type_bird_predicate(bool async) """ SELECT [a].[Id], [a].[CountryId], [a].[Discriminator], [a].[Name], [a].[PeriodEnd], [a].[PeriodStart], [a].[Species], [a].[EagleId], [a].[IsFlightless], [a].[Group], [a].[FoundOn] FROM [Animals] FOR SYSTEM_TIME AS OF '2010-01-01T00:00:00.0000000' AS [a] -WHERE [a].[CountryId] = 1 AND [a].[CountryId] = 1 +WHERE [a].[CountryId] = 1 ORDER BY [a].[Species] """); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs b/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs index a61cfc6485b..626363ec376 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs @@ -336,7 +336,7 @@ public override void Scalar_Function_with_nested_InExpression_translation() SELECT [c].[Id], [c].[FirstName], [c].[LastName] FROM [Customers] AS [c] WHERE CASE - WHEN SUBSTRING([c].[FirstName], 0 + 1, 1) IN (N'A', N'B', N'C') AND SUBSTRING([c].[FirstName], 0 + 1, 1) IS NOT NULL THEN CAST(1 AS bit) + WHEN SUBSTRING([c].[FirstName], 0 + 1, 1) IN (N'A', N'B', N'C') THEN CAST(1 AS bit) ELSE CAST(0 AS bit) END IN (CAST(1 AS bit), CAST(0 AS bit)) """); diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/PrimitiveCollectionsQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/PrimitiveCollectionsQuerySqliteTest.cs index 6b6f446a3f9..f4cdfa407ea 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/PrimitiveCollectionsQuerySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/PrimitiveCollectionsQuerySqliteTest.cs @@ -47,7 +47,7 @@ public override async Task Inline_collection_of_nullable_ints_Contains_null(bool """ SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."String", "p"."Strings" FROM "PrimitiveCollectionsEntity" AS "p" -WHERE "p"."NullableInt" = 999 OR "p"."NullableInt" IS NULL +WHERE "p"."NullableInt" IS NULL OR "p"."NullableInt" = 999 """); } @@ -146,26 +146,43 @@ public override async Task Inline_collection_Contains_with_all_parameters(bool a { await base.Inline_collection_Contains_with_all_parameters(async); - // See #30732 for making this better + AssertSql( +""" +@__i_0='2' +@__j_1='999' + +SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."String", "p"."Strings" +FROM "PrimitiveCollectionsEntity" AS "p" +WHERE "p"."Id" IN (@__i_0, @__j_1) +"""); + } + + public override async Task Inline_collection_Contains_with_constant_and_parameter(bool async) + { + await base.Inline_collection_Contains_with_constant_and_parameter(async); AssertSql( """ -@__p_0='[2,999]' (Size = 7) +@__j_0='999' SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."String", "p"."Strings" FROM "PrimitiveCollectionsEntity" AS "p" -WHERE "p"."Id" IN ( - SELECT "p0"."value" - FROM json_each(@__p_0) AS "p0" -) +WHERE "p"."Id" IN (2, @__j_0) """); } - public override async Task Inline_collection_Contains_with_parameter_and_column_based_expression(bool async) + public override async Task Inline_collection_Contains_with_mixed_value_types(bool async) { - await base.Inline_collection_Contains_with_parameter_and_column_based_expression(async); + await base.Inline_collection_Contains_with_mixed_value_types(async); - AssertSql(); + AssertSql( +""" +@__i_0='11' + +SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."String", "p"."Strings" +FROM "PrimitiveCollectionsEntity" AS "p" +WHERE "p"."Int" IN (999, @__i_0, "p"."Id", "p"."Id" + "p"."Int") +"""); } public override async Task Inline_collection_Contains_as_Any_with_predicate(bool async) @@ -788,6 +805,13 @@ public override async Task Column_collection_equality_inline_collection(bool asy """); } + public override async Task Column_collection_equality_inline_collection_with_parameters(bool async) + { + await base.Column_collection_equality_inline_collection_with_parameters(async); + + AssertSql(); + } + public override async Task Parameter_collection_in_subquery_Count_as_compiled_query(bool async) { await base.Parameter_collection_in_subquery_Count_as_compiled_query(async); From 641975d5bd2dfbb6259d7c713ab7492bc5003549 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Sat, 10 Jun 2023 19:35:53 +0200 Subject: [PATCH 2/2] Make Contains over array of parameters work again on Cosmos --- .../Properties/CosmosStrings.Designer.cs | 14 ++ .../Properties/CosmosStrings.resx | 6 + .../Internal/CosmosContainsTranslator.cs | 59 ------- .../CosmosMethodCallTranslatorProvider.cs | 1 - ...ressionValuesExpandingExpressionVisitor.cs | 46 +++--- .../CosmosSqlTranslatingExpressionVisitor.cs | 97 ++++++----- .../Query/Internal/ISqlExpressionFactory.cs | 10 +- .../Query/Internal/InExpression.cs | 153 +++++++++++++++++- .../Query/Internal/QuerySqlGenerator.cs | 14 +- .../Query/Internal/SqlExpressionFactory.cs | 88 ++++++++-- .../Query/SqlExpressions/InExpression.cs | 46 +++--- .../Query/SqlNullabilityProcessor.cs | 2 +- .../Query/NorthwindWhereQueryCosmosTest.cs | 14 +- 13 files changed, 383 insertions(+), 167 deletions(-) delete mode 100644 src/EFCore.Cosmos/Query/Internal/CosmosContainsTranslator.cs diff --git a/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs b/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs index e4cd0e2ace6..42b366c6e14 100644 --- a/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs +++ b/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs @@ -233,6 +233,20 @@ public static string NullTypeMappingInSqlTree(object? sqlExpression) public static string OffsetRequiresLimit => GetString("OffsetRequiresLimit"); + /// + /// Exactly one of '{param1}' or '{param2}' must be set. + /// + public static string OneOfTwoValuesMustBeSet(object? param1, object? param2) + => string.Format( + GetString("OneOfTwoValuesMustBeSet", nameof(param1), nameof(param2)), + param1, param2); + + /// + /// Only constants or parameters are currently allowed in Contains. + /// + public static string OnlyConstantsAndParametersAllowedInContains + => GetString("OnlyConstantsAndParametersAllowedInContains"); + /// /// The entity of type '{entityType}' is mapped as a part of the document mapped to '{missingEntityType}', but there is no tracked entity of this type with the corresponding key value. Consider using 'DbContextOptionsBuilder.EnableSensitiveDataLogging' to see the key values. /// diff --git a/src/EFCore.Cosmos/Properties/CosmosStrings.resx b/src/EFCore.Cosmos/Properties/CosmosStrings.resx index 6f0de57243e..8b9de6ac560 100644 --- a/src/EFCore.Cosmos/Properties/CosmosStrings.resx +++ b/src/EFCore.Cosmos/Properties/CosmosStrings.resx @@ -229,6 +229,12 @@ Cosmos SQL does not allow Offset without Limit. Consider specifying a 'Take' operation on the query. + + Exactly one of '{param1}' or '{param2}' must be set. + + + Only constants or parameters are currently allowed in Contains. + The entity of type '{entityType}' is mapped as a part of the document mapped to '{missingEntityType}', but there is no tracked entity of this type with the corresponding key value. Consider using 'DbContextOptionsBuilder.EnableSensitiveDataLogging' to see the key values. diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosContainsTranslator.cs b/src/EFCore.Cosmos/Query/Internal/CosmosContainsTranslator.cs deleted file mode 100644 index 83269423772..00000000000 --- a/src/EFCore.Cosmos/Query/Internal/CosmosContainsTranslator.cs +++ /dev/null @@ -1,59 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; - -/// -/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to -/// the same compatibility standards as public APIs. It may be changed or removed without notice in -/// any release. You should only use it directly in your code with extreme caution and knowing that -/// doing so can result in application failures when updating to a new Entity Framework Core release. -/// -public class CosmosContainsTranslator : IMethodCallTranslator -{ - private readonly ISqlExpressionFactory _sqlExpressionFactory; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public CosmosContainsTranslator(ISqlExpressionFactory sqlExpressionFactory) - { - _sqlExpressionFactory = sqlExpressionFactory; - } - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public virtual SqlExpression? Translate( - SqlExpression? instance, - MethodInfo method, - IReadOnlyList arguments, - IDiagnosticsLogger logger) - { - if (method.IsGenericMethod - && method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains) - && ValidateValues(arguments[0])) - { - return _sqlExpressionFactory.In(arguments[1], arguments[0]); - } - - if (arguments.Count == 1 - && method.IsContainsMethod() - && instance != null - && ValidateValues(instance)) - { - return _sqlExpressionFactory.In(arguments[0], instance); - } - - return null; - } - - private static bool ValidateValues(SqlExpression values) - => values is SqlConstantExpression || values is SqlParameterExpression; -} diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosMethodCallTranslatorProvider.cs b/src/EFCore.Cosmos/Query/Internal/CosmosMethodCallTranslatorProvider.cs index a3ced0b0da5..c5bb70e0394 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosMethodCallTranslatorProvider.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosMethodCallTranslatorProvider.cs @@ -31,7 +31,6 @@ public CosmosMethodCallTranslatorProvider( { new CosmosEqualsTranslator(sqlExpressionFactory), new CosmosStringMethodTranslator(sqlExpressionFactory), - new CosmosContainsTranslator(sqlExpressionFactory), new CosmosRandomTranslator(sqlExpressionFactory), new CosmosMathTranslator(sqlExpressionFactory), new CosmosRegexTranslator(sqlExpressionFactory) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.InExpressionValuesExpandingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.InExpressionValuesExpandingExpressionVisitor.cs index 2e74cc6317b..3202766415d 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.InExpressionValuesExpandingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.InExpressionValuesExpandingExpressionVisitor.cs @@ -4,6 +4,7 @@ #nullable disable using System.Collections; +using Microsoft.EntityFrameworkCore.Cosmos.Internal; namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; @@ -26,36 +27,39 @@ public override Expression Visit(Expression expression) { if (expression is InExpression inExpression) { - var inValues = new List(); + var inValues = new List(); var hasNullValue = false; - CoreTypeMapping typeMapping = null; - switch (inExpression.Values) + switch (inExpression) { - case SqlConstantExpression sqlConstant: + case { ValuesParameter: SqlParameterExpression valuesParameter }: { - typeMapping = sqlConstant.TypeMapping; - var values = (IEnumerable)sqlConstant.Value; - foreach (var value in values) + var typeMapping = valuesParameter.TypeMapping; + + foreach (var value in (IEnumerable)_parametersValues[valuesParameter.Name]) { - if (value == null) + if (value is null) { hasNullValue = true; continue; } - inValues.Add(value); + inValues.Add(_sqlExpressionFactory.Constant(value, typeMapping)); } - } + break; + } - case SqlParameterExpression sqlParameter: + case { Values: IReadOnlyList values }: { - typeMapping = sqlParameter.TypeMapping; - var values = (IEnumerable)_parametersValues[sqlParameter.Name]; foreach (var value in values) { - if (value == null) + if (value is not (SqlConstantExpression or SqlParameterExpression)) + { + throw new InvalidOperationException(CosmosStrings.OnlyConstantsAndParametersAllowedInContains); + } + + if (IsNull(value)) { hasNullValue = true; continue; @@ -63,14 +67,16 @@ public override Expression Visit(Expression expression) inValues.Add(value); } - } + break; + } + + default: + throw new InvalidOperationException("IMPOSSIBLE"); } var updatedInExpression = inValues.Count > 0 - ? _sqlExpressionFactory.In( - (SqlExpression)Visit(inExpression.Item), - _sqlExpressionFactory.Constant(inValues, typeMapping)) + ? _sqlExpressionFactory.In((SqlExpression)Visit(inExpression.Item), inValues) : null; var nullCheckExpression = hasNullValue @@ -94,5 +100,9 @@ public override Expression Visit(Expression expression) return base.Visit(expression); } + + private bool IsNull(SqlExpression expression) + => expression is SqlConstantExpression { Value: null } + || expression is SqlParameterExpression { Name: string parameterName } && _parametersValues[parameterName] is null; } } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index e67d989751e..54b75a57ab3 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -510,45 +510,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp else if (method.IsGenericMethod && method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)) { - var enumerable = Visit(methodCallExpression.Arguments[0]); - var item = Visit(methodCallExpression.Arguments[1]); - - if (TryRewriteContainsEntity(enumerable, item ?? methodCallExpression.Arguments[1], out var result)) - { - return result; - } - - if (enumerable is SqlExpression sqlEnumerable - && item is SqlExpression sqlItem) - { - arguments = new[] { sqlEnumerable, sqlItem }; - } - else - { - return null; - } + return TranslateContains(methodCallExpression.Arguments[1], methodCallExpression.Arguments[0]); } else if (methodCallExpression.Arguments.Count == 1 && method.IsContainsMethod()) { - var enumerable = Visit(methodCallExpression.Object); - var item = Visit(methodCallExpression.Arguments[0]); - - if (TryRewriteContainsEntity(enumerable, item ?? methodCallExpression.Arguments[0], out var result)) - { - return result; - } - - if (enumerable is SqlExpression sqlEnumerable - && item is SqlExpression sqlItem) - { - sqlObject = sqlEnumerable; - arguments = new[] { sqlItem }; - } - else - { - return null; - } + return TranslateContains(methodCallExpression.Arguments[0], methodCallExpression.Object); } else { @@ -591,6 +558,64 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return translation; + Expression TranslateContains(Expression untranslatedItem, Expression untranslatedCollection) + { + var collection = Visit(untranslatedCollection); + var itemUnchecked = Visit(untranslatedItem); + + if (TryRewriteContainsEntity(collection, itemUnchecked ?? untranslatedItem, out var result)) + { + return result; + } + + if (itemUnchecked is not SqlExpression translatedItem) + { + return null; + } + + switch (collection) + { + // If the collection was an inline NewArrayExpression with constants only, we get a single constant for that array. + case SqlConstantExpression { Value: IEnumerable values, TypeMapping: var typeMapping }: + { + var translatedValues = values is IList iList + ? new List(iList.Count) + : new List(); + foreach (var value in values) + { + translatedValues.Add(_sqlExpressionFactory.Constant(value, typeMapping)); + } + return _sqlExpressionFactory.In(translatedItem, translatedValues); + } + + // If the collection was an inline NewArrayExpression with at least one non-constant, the NewArrayExpression makes it + // as-is to translation, where it (currently) cannot be translated. Identify this case and translate the elements. + case not SqlExpression when untranslatedCollection is NewArrayExpression { Expressions: var values }: + { + var translatedValues = new SqlExpression[values.Count]; + for (var i = 0; i < values.Count; i++) + { + if (Visit(values[i]) is not SqlExpression value) + { + return null; + } + + translatedValues[i] = value; + } + + return _sqlExpressionFactory.In(translatedItem, translatedValues); + } + + // If the collection was a captured variable (parameter), construct an InExpression over that; + // InExpressionValuesExpandingExpressionVisitor will expand the values as constants later. + case SqlParameterExpression sqlParameterExpression: + return _sqlExpressionFactory.In(translatedItem, sqlParameterExpression); + + default: + return null; + } + } + static Expression RemoveObjectConvert(Expression expression) => expression is UnaryExpression unaryExpression && (unaryExpression.NodeType == ExpressionType.Convert || unaryExpression.NodeType == ExpressionType.ConvertChecked) @@ -717,7 +742,7 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp _sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue())) : _sqlExpressionFactory.In( discriminatorColumn, - _sqlExpressionFactory.Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList())); + concreteEntityTypes.Select(et => _sqlExpressionFactory.Constant(et.GetDiscriminatorValue())).ToArray()); } } diff --git a/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs b/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs index 01c619fa612..6981ebbc73e 100644 --- a/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs +++ b/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs @@ -256,7 +256,15 @@ SqlConditionalExpression Condition( /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - InExpression In(SqlExpression item, SqlExpression values); + InExpression In(SqlExpression item, SqlParameterExpression valuesParameter); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + InExpression In(SqlExpression item, IReadOnlyList values); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to diff --git a/src/EFCore.Cosmos/Query/Internal/InExpression.cs b/src/EFCore.Cosmos/Query/Internal/InExpression.cs index 4a3339c950b..7b9bfc575bd 100644 --- a/src/EFCore.Cosmos/Query/Internal/InExpression.cs +++ b/src/EFCore.Cosmos/Query/Internal/InExpression.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.EntityFrameworkCore.Cosmos.Internal; + namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; /// @@ -19,12 +21,36 @@ public class InExpression : SqlExpression /// public InExpression( SqlExpression item, - SqlExpression values, + IReadOnlyList values, + CoreTypeMapping typeMapping) + : this(item, values, valuesParameter: null, typeMapping) + { + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public InExpression( + SqlExpression item, + SqlParameterExpression valuesParameter, CoreTypeMapping typeMapping) + : this(item, values: null, valuesParameter, typeMapping) + { + } + + private InExpression( + SqlExpression item, + IReadOnlyList? values, + SqlParameterExpression? valuesParameter, + CoreTypeMapping? typeMapping) : base(typeof(bool), typeMapping) { Item = item; Values = values; + ValuesParameter = valuesParameter; } /// @@ -35,13 +61,18 @@ public InExpression( /// public virtual SqlExpression Item { get; } + /// + /// The list of values to search the item in. + /// + public virtual IReadOnlyList? Values { get; } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual SqlExpression Values { get; } + public virtual SqlParameterExpression? ValuesParameter { get; } /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -52,22 +83,88 @@ public InExpression( protected override Expression VisitChildren(ExpressionVisitor visitor) { var newItem = (SqlExpression)visitor.Visit(Item); - var values = (SqlExpression)visitor.Visit(Values); - return Update(newItem, values); + SqlExpression[]? values = null; + if (Values is not null) + { + for (var i = 0; i < Values.Count; i++) + { + var value = Values[i]; + var newValue = (SqlExpression)visitor.Visit(value); + + if (newValue != value && values is null) + { + values = new SqlExpression[Values.Count]; + for (var j = 0; j < i; j++) + { + values[j] = Values[j]; + } + } + + if (values is not null) + { + values[i] = newValue; + } + } + } + + var valuesParameter = (SqlParameterExpression?)visitor.Visit(ValuesParameter); + + return Update(newItem, values ?? Values, valuesParameter); } + /// + /// Applies supplied type mapping to this expression. + /// + /// A relational type mapping to apply. + /// A new expression which has supplied type mapping. + public virtual InExpression ApplyTypeMapping(CoreTypeMapping? typeMapping) + => new(Item, Values, ValuesParameter, typeMapping); + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual InExpression Update(SqlExpression item, SqlExpression values) + public virtual InExpression Update(SqlExpression item, IReadOnlyList values) => item != Item || values != Values ? new InExpression(item, values, TypeMapping!) : this; + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual InExpression Update(SqlExpression item, SqlParameterExpression valuesParameter) + => item != Item || ValuesParameter != valuesParameter + ? new InExpression(item, valuesParameter, TypeMapping!) + : this; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual InExpression Update( + SqlExpression item, + IReadOnlyList? values, + SqlParameterExpression? valuesParameter) + { + if (!(values is null ^ valuesParameter is null)) + { + throw new ArgumentException( + CosmosStrings.OneOfTwoValuesMustBeSet(nameof(values), nameof(valuesParameter))); + } + + return item == Item && values == Values && valuesParameter == ValuesParameter + ? this + : new InExpression(item, values, valuesParameter, TypeMapping); + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -79,7 +176,30 @@ protected override void Print(ExpressionPrinter expressionPrinter) expressionPrinter.Visit(Item); expressionPrinter.Append(" IN "); expressionPrinter.Append("("); - expressionPrinter.Visit(Values); + + switch (this) + { + case { Values: not null }: + for (var i = 0; i < Values.Count; i++) + { + if (i > 0) + { + expressionPrinter.Append(", "); + } + + expressionPrinter.Visit(Values[i]); + } + + break; + + case { ValuesParameter: not null}: + expressionPrinter.Visit(ValuesParameter); + break; + + default: + throw new ArgumentOutOfRangeException(); + } + expressionPrinter.Append(")"); } @@ -98,7 +218,9 @@ public override bool Equals(object? obj) private bool Equals(InExpression inExpression) => base.Equals(inExpression) && Item.Equals(inExpression.Item) - && Values.Equals(inExpression.Values); + && (ValuesParameter?.Equals(inExpression.ValuesParameter) ?? inExpression.ValuesParameter == null) + && (ReferenceEquals(Values, inExpression.Values) + || (Values is not null && inExpression.Values is not null && Values.SequenceEqual(inExpression.Values))); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -107,5 +229,20 @@ private bool Equals(InExpression inExpression) /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public override int GetHashCode() - => HashCode.Combine(base.GetHashCode(), Item, Values); + { + var hash = new HashCode(); + hash.Add(base.GetHashCode()); + hash.Add(Item); + hash.Add(ValuesParameter); + + if (Values is not null) + { + for (var i = 0; i < Values.Count; i++) + { + hash.Add(Values[i]); + } + } + + return hash.ToHashCode(); + } } diff --git a/src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs b/src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs index 4d1bc4207cb..a432de2e417 100644 --- a/src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs +++ b/src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs @@ -528,13 +528,15 @@ protected sealed override Expression VisitIn(InExpression inExpression) /// protected virtual void GenerateIn(InExpression inExpression, bool negated) { + Check.DebugAssert( + inExpression.ValuesParameter is null, + "InExpression.ValuesParameter must have been expanded to constants before SQL generation (in " + + "InExpressionValuesExpandingExpressionVisitor)"); + Check.DebugAssert(inExpression.Values is not null, "Missing Values on InExpression"); + Visit(inExpression.Item); - _sqlBuilder.Append(negated ? " NOT IN " : " IN "); - _sqlBuilder.Append('('); - var valuesConstant = (SqlConstantExpression)inExpression.Values; - var valuesList = ((IEnumerable)valuesConstant.Value) - .Select(v => new SqlConstantExpression(Expression.Constant(v), valuesConstant.TypeMapping)).ToList(); - GenerateList(valuesList, e => Visit(e)); + _sqlBuilder.Append(negated ? " NOT IN (" : " IN ("); + GenerateList(inExpression.Values, e => Visit(e)); _sqlBuilder.Append(')'); } diff --git a/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs b/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs index a571a7914b7..d1ad7893f98 100644 --- a/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs +++ b/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs @@ -3,6 +3,7 @@ using System.Diagnostics.CodeAnalysis; using Microsoft.EntityFrameworkCore.Cosmos.Internal; +using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; using Microsoft.EntityFrameworkCore.Internal; namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; @@ -207,6 +208,72 @@ private SqlExpression ApplyTypeMappingOnSqlBinary( resultTypeMapping); } + private InExpression ApplyTypeMappingOnIn(InExpression inExpression) + { + var missingTypeMappingInValues = false; + + CoreTypeMapping? valuesTypeMapping = null; + switch (inExpression) + { + case { ValuesParameter: SqlParameterExpression parameter }: + valuesTypeMapping = parameter.TypeMapping; + break; + + case { Values: IReadOnlyList values }: + // Note: there could be conflicting type mappings inside the values; we take the first. + foreach (var value in values) + { + if (value.TypeMapping is null) + { + missingTypeMappingInValues = true; + } + else + { + valuesTypeMapping = value.TypeMapping; + } + } + + break; + + default: + throw new ArgumentOutOfRangeException(); + } + + var item = ApplyTypeMapping( + inExpression.Item, + valuesTypeMapping ?? _typeMappingSource.FindMapping(inExpression.Item.Type, _model)); + + switch (inExpression) + { + case { ValuesParameter: SqlParameterExpression parameter }: + inExpression = inExpression.Update(item, (SqlParameterExpression)ApplyTypeMapping(parameter, item.TypeMapping)); + break; + + case { Values: IReadOnlyList values }: + SqlExpression[]? newValues = null; + + if (missingTypeMappingInValues) + { + newValues = new SqlExpression[values.Count]; + + for (var i = 0; i < newValues.Length; i++) + { + newValues[i] = ApplyTypeMapping(values[i], item.TypeMapping); + } + } + + inExpression = inExpression.Update(item, newValues ?? values); + break; + + default: + throw new ArgumentOutOfRangeException(); + } + + return inExpression.TypeMapping == _boolTypeMapping + ? inExpression + : inExpression.ApplyTypeMapping(_boolTypeMapping); + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -473,15 +540,17 @@ public virtual SqlConditionalExpression Condition(SqlExpression test, SqlExpress /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual InExpression In(SqlExpression item, SqlExpression values) - { - var typeMapping = item.TypeMapping ?? _typeMappingSource.FindMapping(item.Type, _model); + public virtual InExpression In(SqlExpression item, IReadOnlyList values) + => ApplyTypeMappingOnIn(new InExpression(item, values, _boolTypeMapping)); - item = ApplyTypeMapping(item, typeMapping); - values = ApplyTypeMapping(values, typeMapping); - - return new InExpression(item, values, _boolTypeMapping); - } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual InExpression In(SqlExpression item, SqlParameterExpression valuesParameter) + => ApplyTypeMappingOnIn(new InExpression(item, valuesParameter, _boolTypeMapping)); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -538,8 +607,7 @@ private void AddDiscriminator(SelectExpression selectExpression, IEntityType ent .BindProperty(concreteEntityTypes[0].FindDiscriminatorProperty(), clientEval: false); selectExpression.ApplyPredicate( - In( - (SqlExpression)discriminatorColumn, Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()))); + In((SqlExpression)discriminatorColumn, concreteEntityTypes.Select(et => Constant(et.GetDiscriminatorValue())).ToArray())); } } } diff --git a/src/EFCore.Relational/Query/SqlExpressions/InExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/InExpression.cs index c4ec59efdd4..335a3e0e2fb 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/InExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/InExpression.cs @@ -205,36 +205,36 @@ protected override void Print(ExpressionPrinter expressionPrinter) expressionPrinter.Append(" IN "); expressionPrinter.Append("("); - switch (this) - { - case { Subquery: not null }: - using (expressionPrinter.Indent()) - { - expressionPrinter.Visit(Subquery); - } + switch (this) + { + case { Subquery: not null }: + using (expressionPrinter.Indent()) + { + expressionPrinter.Visit(Subquery); + } - break; + break; - case { Values: not null }: - for (var i = 0; i < Values.Count; i++) - { - if (i > 0) + case { Values: not null }: + for (var i = 0; i < Values.Count; i++) { - expressionPrinter.Append(", "); - } + if (i > 0) + { + expressionPrinter.Append(", "); + } - expressionPrinter.Visit(Values[i]); - } + expressionPrinter.Visit(Values[i]); + } - break; + break; - case { ValuesParameter: not null}: - expressionPrinter.Visit(ValuesParameter); - break; + case { ValuesParameter: not null}: + expressionPrinter.Visit(ValuesParameter); + break; - default: - throw new ArgumentOutOfRangeException(); - } + default: + throw new ArgumentOutOfRangeException(); + } expressionPrinter.Append(")"); } diff --git a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs index 9c73b434e18..3e14e44e7e8 100644 --- a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs +++ b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs @@ -887,7 +887,7 @@ InExpression ProcessInExpressionValues( foreach (var value in values) { - if (value == null && removeNulls) + if (value is null && removeNulls) { hasNull = true; continue; diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs index 4ead70a6ba1..4ca2ab54d38 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindWhereQueryCosmosTest.cs @@ -2745,11 +2745,17 @@ FROM root c public override async Task Array_of_parameters_Contains_OrElse_comparison_with_constant_gets_combined_to_one_in(bool async) { - // #31051 - await AssertTranslationFailed( - () => base.Array_of_parameters_Contains_OrElse_comparison_with_constant_gets_combined_to_one_in(async)); + await base.Array_of_parameters_Contains_OrElse_comparison_with_constant_gets_combined_to_one_in(async); - AssertSql(); + AssertSql( +""" +@__prm1_0='ALFKI' +@__prm2_1='ANATR' + +SELECT c +FROM root c +WHERE ((c["Discriminator"] = "Customer") AND (c["CustomerID"] IN (@__prm1_0, @__prm2_1) OR (c["CustomerID"] = "ANTON"))) +"""); } public override async Task Multiple_OrElse_on_same_column_with_null_parameter_comparison_converted_to_in(bool async)