diff --git a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs index df1267503..69b160ebe 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.ObjectModel; using System.Diagnostics; using System.Linq; using System.Linq.Expressions; @@ -10,35 +11,47 @@ using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Storage; using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal; using Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; -using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping; using ExpressionExtensions = Microsoft.EntityFrameworkCore.Query.ExpressionExtensions; namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Internal { public class NpgsqlSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExpressionVisitor { - [NotNull] static readonly MethodInfo Like2MethodInfo = + [NotNull] + static readonly MethodInfo Like2MethodInfo = typeof(DbFunctionsExtensions) .GetRuntimeMethod(nameof(DbFunctionsExtensions.Like), new[] { typeof(DbFunctions), typeof(string), typeof(string) }); // ReSharper disable once InconsistentNaming - [NotNull] static readonly MethodInfo ILike2MethodInfo = + [NotNull] + static readonly MethodInfo ILike2MethodInfo = typeof(NpgsqlDbFunctionsExtensions) .GetRuntimeMethod(nameof(NpgsqlDbFunctionsExtensions.ILike), new[] { typeof(DbFunctions), typeof(string), typeof(string) }); - [NotNull] static readonly MethodInfo EnumerableAnyWithPredicate = + [NotNull] + static readonly MethodInfo EnumerableAnyWithPredicate = typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) .Single(mi => mi.Name == nameof(Enumerable.Any) && mi.GetParameters().Length == 2); - [NotNull] static readonly MethodInfo EnumerableAll = + [NotNull] + static readonly MethodInfo EnumerableAll = typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) .Single(mi => mi.Name == nameof(Enumerable.All) && mi.GetParameters().Length == 2); + [NotNull] + static readonly MethodInfo Contains = + typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly) + .Single(m => m.Name == nameof(Enumerable.Contains) && m.GetParameters().Length == 2); + readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; readonly NpgsqlJsonPocoTranslator _jsonPocoTranslator; + [NotNull] + readonly RelationalTypeMapping _boolMapping; + public NpgsqlSqlTranslatingExpressionVisitor( RelationalSqlTranslatingExpressionVisitorDependencies dependencies, IModel model, @@ -47,6 +60,7 @@ public NpgsqlSqlTranslatingExpressionVisitor( { _sqlExpressionFactory = (NpgsqlSqlExpressionFactory)dependencies.SqlExpressionFactory; _jsonPocoTranslator = ((NpgsqlMemberTranslatorProvider)Dependencies.MemberTranslatorProvider).JsonPocoTranslator; + _boolMapping = _sqlExpressionFactory.FindMapping(typeof(bool)); } // PostgreSQL COUNT() always returns bigint, so we need to downcast to int @@ -109,37 +123,52 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) protected override Expression VisitMethodCall(MethodCallExpression methodCall) { + var visited = base.VisitMethodCall(methodCall); + if (visited != null) + return visited; + // TODO: Handle List<> + if (methodCall.Arguments.Count > 0 && methodCall.Arguments[0].Type.IsArray) + return VisitArrayMethodCall(methodCall.Method, methodCall.Arguments); + + return null; + } + /// + /// Identifies complex array-related constructs which cannot be translated in regular method translators, since + /// they require accessing lambdas. + /// + protected virtual Expression VisitArrayMethodCall(MethodInfo method, ReadOnlyCollection arguments) + { { // Pattern match for .Where(e => new[] { "a", "b", "c" }.Any(p => EF.Functions.Like(e.SomeText, p))), // which we translate to WHERE s.""SomeText"" LIKE ANY (ARRAY['a','b','c']) (see test Any_like) // Note: NavigationExpander normalized Any(x) to Where(x).Any() - if (methodCall.Method.IsClosedFormOf(EnumerableAnyWithPredicate) && - methodCall.Arguments[0].Type.IsArray && - methodCall.Arguments[1] is LambdaExpression wherePredicate && + if (method.IsClosedFormOf(EnumerableAnyWithPredicate) && + arguments[1] is LambdaExpression wherePredicate && wherePredicate.Body is MethodCallExpression wherePredicateMethodCall && ( wherePredicateMethodCall.Method == Like2MethodInfo || wherePredicateMethodCall.Method == ILike2MethodInfo)) { - var array = (SqlExpression)Visit(methodCall.Arguments[0]); + var array = (SqlExpression)Visit(arguments[0]); var match = (SqlExpression)Visit(wherePredicateMethodCall.Arguments[1]); return _sqlExpressionFactory.ArrayAnyAll(match, array, ArrayComparisonType.Any, wherePredicateMethodCall.Method == Like2MethodInfo ? "LIKE" : "ILIKE"); } + + // Note: we also handle the above with equality instead of Like, see NpgsqlArrayMethodTranslator } - // Same for All (but without the normalization { - if (methodCall.Method.IsClosedFormOf(EnumerableAll) && - methodCall.Arguments[0].Type.IsArray && - methodCall.Arguments[1] is LambdaExpression wherePredicate && + // Same for All (but without the normalization) + if (method.IsClosedFormOf(EnumerableAll) && + arguments[1] is LambdaExpression wherePredicate && wherePredicate.Body is MethodCallExpression wherePredicateMethodCall && ( wherePredicateMethodCall.Method == Like2MethodInfo || wherePredicateMethodCall.Method == ILike2MethodInfo)) { - var array = (SqlExpression)Visit(methodCall.Arguments[0]); + var array = (SqlExpression)Visit(arguments[0]); var match = (SqlExpression)Visit(wherePredicateMethodCall.Arguments[1]); return _sqlExpressionFactory.ArrayAnyAll(match, array, ArrayComparisonType.All, @@ -147,8 +176,55 @@ methodCall.Arguments[1] is LambdaExpression wherePredicate && } } - // Note: we also handle the above with equality instead of Like, see NpgsqlArrayMethodTranslator - return base.VisitMethodCall(methodCall); + { + // Translate e => new[] { 4, 5 }.Any(p => e.SomeArray.Contains(p)), + // using array overlap (&&) + if (method.IsClosedFormOf(EnumerableAnyWithPredicate) && + arguments[1] is LambdaExpression wherePredicate && + wherePredicate.Body is MethodCallExpression wherePredicateMethodCall && + wherePredicateMethodCall.Method.IsClosedFormOf(Contains) && + wherePredicateMethodCall.Arguments[0].Type.IsArray && + wherePredicateMethodCall.Arguments[1] is ParameterExpression parameterExpression && + parameterExpression == wherePredicate.Parameters[0]) + { + var array1 = (SqlExpression)Visit(arguments[0]); + var array2 = (SqlExpression)Visit(wherePredicateMethodCall.Arguments[0]); + var inferredMapping = ExpressionExtensions.InferTypeMapping(array1, array2); + + return new SqlCustomBinaryExpression( + _sqlExpressionFactory.ApplyTypeMapping(array1, inferredMapping), + _sqlExpressionFactory.ApplyTypeMapping(array2, inferredMapping), + "&&", + typeof(bool), + _boolMapping); + } + } + + { + // Translate e => new[] { 4, 5 }.All(p => e.SomeArray.Contains(p)), + // using array containment (<@) + if (method.IsClosedFormOf(EnumerableAll) && + arguments[1] is LambdaExpression wherePredicate && + wherePredicate.Body is MethodCallExpression wherePredicateMethodCall && + wherePredicateMethodCall.Method.IsClosedFormOf(Contains) && + wherePredicateMethodCall.Arguments[0].Type.IsArray && + wherePredicateMethodCall.Arguments[1] is ParameterExpression parameterExpression && + parameterExpression == wherePredicate.Parameters[0]) + { + var array1 = (SqlExpression)Visit(arguments[0]); + var array2 = (SqlExpression)Visit(wherePredicateMethodCall.Arguments[0]); + var inferredMapping = ExpressionExtensions.InferTypeMapping(array1, array2); + + return new SqlCustomBinaryExpression( + _sqlExpressionFactory.ApplyTypeMapping(array1, inferredMapping), + _sqlExpressionFactory.ApplyTypeMapping(array2, inferredMapping), + "<@", + typeof(bool), + _boolMapping); + } + } + + return null; } /// diff --git a/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs b/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs index 0c29d89e0..b817a4113 100644 --- a/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs @@ -288,6 +288,55 @@ public void Any_like_anonymous() WHERE s.""SomeText"" LIKE ANY (@__patterns_0)"); } + [Fact] + public void Any_Contains() + { + using var ctx = CreateContext(); + + var results = ctx.SomeEntities + .Where(e => new[] { 2, 3 }.Any(p => e.SomeArray.Contains(p))) + .ToList(); + Assert.Equal(1, Assert.Single(results).Id); + + results = ctx.SomeEntities + .Where(e => new[] { 1, 2 }.Any(p => e.SomeArray.Contains(p))) + .ToList(); + Assert.Empty(results); + + AssertSql( + @"SELECT s.""Id"", s.""SomeArray"", s.""SomeBytea"", s.""SomeList"", s.""SomeMatrix"", s.""SomeText"" +FROM ""SomeEntities"" AS s +WHERE (ARRAY[2,3]::integer[] && s.""SomeArray"")", + @"SELECT s.""Id"", s.""SomeArray"", s.""SomeBytea"", s.""SomeList"", s.""SomeMatrix"", s.""SomeText"" +FROM ""SomeEntities"" AS s +WHERE (ARRAY[1,2]::integer[] && s.""SomeArray"")"); + } + + [Fact] + public void All_Contains() + { + using var ctx = CreateContext(); + + var results = ctx.SomeEntities + .Where(e => new[] { 5, 6 }.All(p => e.SomeArray.Contains(p))) + .ToList(); + Assert.Equal(2, Assert.Single(results).Id); + + results = ctx.SomeEntities + .Where(e => new[] { 4, 5, 6 }.All(p => e.SomeArray.Contains(p))) + .ToList(); + Assert.Empty(results); + + AssertSql( + @"SELECT s.""Id"", s.""SomeArray"", s.""SomeBytea"", s.""SomeList"", s.""SomeMatrix"", s.""SomeText"" +FROM ""SomeEntities"" AS s +WHERE (ARRAY[5,6]::integer[] <@ s.""SomeArray"")", + // + @"SELECT s.""Id"", s.""SomeArray"", s.""SomeBytea"", s.""SomeList"", s.""SomeMatrix"", s.""SomeText"" +FROM ""SomeEntities"" AS s +WHERE (ARRAY[4,5,6]::integer[] <@ s.""SomeArray"")"); + } + #endregion #region Support