diff --git a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs
index df1267503..69b160ebe 100644
--- a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs
+++ b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs
@@ -1,4 +1,5 @@
using System;
+using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
@@ -10,35 +11,47 @@
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Infrastructure;
+using Microsoft.EntityFrameworkCore.Storage;
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;
-using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
using ExpressionExtensions = Microsoft.EntityFrameworkCore.Query.ExpressionExtensions;
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Internal
{
public class NpgsqlSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExpressionVisitor
{
- [NotNull] static readonly MethodInfo Like2MethodInfo =
+ [NotNull]
+ static readonly MethodInfo Like2MethodInfo =
typeof(DbFunctionsExtensions)
.GetRuntimeMethod(nameof(DbFunctionsExtensions.Like), new[] { typeof(DbFunctions), typeof(string), typeof(string) });
// ReSharper disable once InconsistentNaming
- [NotNull] static readonly MethodInfo ILike2MethodInfo =
+ [NotNull]
+ static readonly MethodInfo ILike2MethodInfo =
typeof(NpgsqlDbFunctionsExtensions)
.GetRuntimeMethod(nameof(NpgsqlDbFunctionsExtensions.ILike), new[] { typeof(DbFunctions), typeof(string), typeof(string) });
- [NotNull] static readonly MethodInfo EnumerableAnyWithPredicate =
+ [NotNull]
+ static readonly MethodInfo EnumerableAnyWithPredicate =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(mi => mi.Name == nameof(Enumerable.Any) && mi.GetParameters().Length == 2);
- [NotNull] static readonly MethodInfo EnumerableAll =
+ [NotNull]
+ static readonly MethodInfo EnumerableAll =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(mi => mi.Name == nameof(Enumerable.All) && mi.GetParameters().Length == 2);
+ [NotNull]
+ static readonly MethodInfo Contains =
+ typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
+ .Single(m => m.Name == nameof(Enumerable.Contains) && m.GetParameters().Length == 2);
+
readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
readonly NpgsqlJsonPocoTranslator _jsonPocoTranslator;
+ [NotNull]
+ readonly RelationalTypeMapping _boolMapping;
+
public NpgsqlSqlTranslatingExpressionVisitor(
RelationalSqlTranslatingExpressionVisitorDependencies dependencies,
IModel model,
@@ -47,6 +60,7 @@ public NpgsqlSqlTranslatingExpressionVisitor(
{
_sqlExpressionFactory = (NpgsqlSqlExpressionFactory)dependencies.SqlExpressionFactory;
_jsonPocoTranslator = ((NpgsqlMemberTranslatorProvider)Dependencies.MemberTranslatorProvider).JsonPocoTranslator;
+ _boolMapping = _sqlExpressionFactory.FindMapping(typeof(bool));
}
// PostgreSQL COUNT() always returns bigint, so we need to downcast to int
@@ -109,37 +123,52 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
protected override Expression VisitMethodCall(MethodCallExpression methodCall)
{
+ var visited = base.VisitMethodCall(methodCall);
+ if (visited != null)
+ return visited;
+
// TODO: Handle List<>
+ if (methodCall.Arguments.Count > 0 && methodCall.Arguments[0].Type.IsArray)
+ return VisitArrayMethodCall(methodCall.Method, methodCall.Arguments);
+
+ return null;
+ }
+ ///
+ /// Identifies complex array-related constructs which cannot be translated in regular method translators, since
+ /// they require accessing lambdas.
+ ///
+ protected virtual Expression VisitArrayMethodCall(MethodInfo method, ReadOnlyCollection arguments)
+ {
{
// Pattern match for .Where(e => new[] { "a", "b", "c" }.Any(p => EF.Functions.Like(e.SomeText, p))),
// which we translate to WHERE s.""SomeText"" LIKE ANY (ARRAY['a','b','c']) (see test Any_like)
// Note: NavigationExpander normalized Any(x) to Where(x).Any()
- if (methodCall.Method.IsClosedFormOf(EnumerableAnyWithPredicate) &&
- methodCall.Arguments[0].Type.IsArray &&
- methodCall.Arguments[1] is LambdaExpression wherePredicate &&
+ if (method.IsClosedFormOf(EnumerableAnyWithPredicate) &&
+ arguments[1] is LambdaExpression wherePredicate &&
wherePredicate.Body is MethodCallExpression wherePredicateMethodCall && (
wherePredicateMethodCall.Method == Like2MethodInfo ||
wherePredicateMethodCall.Method == ILike2MethodInfo))
{
- var array = (SqlExpression)Visit(methodCall.Arguments[0]);
+ var array = (SqlExpression)Visit(arguments[0]);
var match = (SqlExpression)Visit(wherePredicateMethodCall.Arguments[1]);
return _sqlExpressionFactory.ArrayAnyAll(match, array, ArrayComparisonType.Any,
wherePredicateMethodCall.Method == Like2MethodInfo ? "LIKE" : "ILIKE");
}
+
+ // Note: we also handle the above with equality instead of Like, see NpgsqlArrayMethodTranslator
}
- // Same for All (but without the normalization
{
- if (methodCall.Method.IsClosedFormOf(EnumerableAll) &&
- methodCall.Arguments[0].Type.IsArray &&
- methodCall.Arguments[1] is LambdaExpression wherePredicate &&
+ // Same for All (but without the normalization)
+ if (method.IsClosedFormOf(EnumerableAll) &&
+ arguments[1] is LambdaExpression wherePredicate &&
wherePredicate.Body is MethodCallExpression wherePredicateMethodCall && (
wherePredicateMethodCall.Method == Like2MethodInfo ||
wherePredicateMethodCall.Method == ILike2MethodInfo))
{
- var array = (SqlExpression)Visit(methodCall.Arguments[0]);
+ var array = (SqlExpression)Visit(arguments[0]);
var match = (SqlExpression)Visit(wherePredicateMethodCall.Arguments[1]);
return _sqlExpressionFactory.ArrayAnyAll(match, array, ArrayComparisonType.All,
@@ -147,8 +176,55 @@ methodCall.Arguments[1] is LambdaExpression wherePredicate &&
}
}
- // Note: we also handle the above with equality instead of Like, see NpgsqlArrayMethodTranslator
- return base.VisitMethodCall(methodCall);
+ {
+ // Translate e => new[] { 4, 5 }.Any(p => e.SomeArray.Contains(p)),
+ // using array overlap (&&)
+ if (method.IsClosedFormOf(EnumerableAnyWithPredicate) &&
+ arguments[1] is LambdaExpression wherePredicate &&
+ wherePredicate.Body is MethodCallExpression wherePredicateMethodCall &&
+ wherePredicateMethodCall.Method.IsClosedFormOf(Contains) &&
+ wherePredicateMethodCall.Arguments[0].Type.IsArray &&
+ wherePredicateMethodCall.Arguments[1] is ParameterExpression parameterExpression &&
+ parameterExpression == wherePredicate.Parameters[0])
+ {
+ var array1 = (SqlExpression)Visit(arguments[0]);
+ var array2 = (SqlExpression)Visit(wherePredicateMethodCall.Arguments[0]);
+ var inferredMapping = ExpressionExtensions.InferTypeMapping(array1, array2);
+
+ return new SqlCustomBinaryExpression(
+ _sqlExpressionFactory.ApplyTypeMapping(array1, inferredMapping),
+ _sqlExpressionFactory.ApplyTypeMapping(array2, inferredMapping),
+ "&&",
+ typeof(bool),
+ _boolMapping);
+ }
+ }
+
+ {
+ // Translate e => new[] { 4, 5 }.All(p => e.SomeArray.Contains(p)),
+ // using array containment (<@)
+ if (method.IsClosedFormOf(EnumerableAll) &&
+ arguments[1] is LambdaExpression wherePredicate &&
+ wherePredicate.Body is MethodCallExpression wherePredicateMethodCall &&
+ wherePredicateMethodCall.Method.IsClosedFormOf(Contains) &&
+ wherePredicateMethodCall.Arguments[0].Type.IsArray &&
+ wherePredicateMethodCall.Arguments[1] is ParameterExpression parameterExpression &&
+ parameterExpression == wherePredicate.Parameters[0])
+ {
+ var array1 = (SqlExpression)Visit(arguments[0]);
+ var array2 = (SqlExpression)Visit(wherePredicateMethodCall.Arguments[0]);
+ var inferredMapping = ExpressionExtensions.InferTypeMapping(array1, array2);
+
+ return new SqlCustomBinaryExpression(
+ _sqlExpressionFactory.ApplyTypeMapping(array1, inferredMapping),
+ _sqlExpressionFactory.ApplyTypeMapping(array2, inferredMapping),
+ "<@",
+ typeof(bool),
+ _boolMapping);
+ }
+ }
+
+ return null;
}
///
diff --git a/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs b/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs
index 0c29d89e0..b817a4113 100644
--- a/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs
+++ b/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs
@@ -288,6 +288,55 @@ public void Any_like_anonymous()
WHERE s.""SomeText"" LIKE ANY (@__patterns_0)");
}
+ [Fact]
+ public void Any_Contains()
+ {
+ using var ctx = CreateContext();
+
+ var results = ctx.SomeEntities
+ .Where(e => new[] { 2, 3 }.Any(p => e.SomeArray.Contains(p)))
+ .ToList();
+ Assert.Equal(1, Assert.Single(results).Id);
+
+ results = ctx.SomeEntities
+ .Where(e => new[] { 1, 2 }.Any(p => e.SomeArray.Contains(p)))
+ .ToList();
+ Assert.Empty(results);
+
+ AssertSql(
+ @"SELECT s.""Id"", s.""SomeArray"", s.""SomeBytea"", s.""SomeList"", s.""SomeMatrix"", s.""SomeText""
+FROM ""SomeEntities"" AS s
+WHERE (ARRAY[2,3]::integer[] && s.""SomeArray"")",
+ @"SELECT s.""Id"", s.""SomeArray"", s.""SomeBytea"", s.""SomeList"", s.""SomeMatrix"", s.""SomeText""
+FROM ""SomeEntities"" AS s
+WHERE (ARRAY[1,2]::integer[] && s.""SomeArray"")");
+ }
+
+ [Fact]
+ public void All_Contains()
+ {
+ using var ctx = CreateContext();
+
+ var results = ctx.SomeEntities
+ .Where(e => new[] { 5, 6 }.All(p => e.SomeArray.Contains(p)))
+ .ToList();
+ Assert.Equal(2, Assert.Single(results).Id);
+
+ results = ctx.SomeEntities
+ .Where(e => new[] { 4, 5, 6 }.All(p => e.SomeArray.Contains(p)))
+ .ToList();
+ Assert.Empty(results);
+
+ AssertSql(
+ @"SELECT s.""Id"", s.""SomeArray"", s.""SomeBytea"", s.""SomeList"", s.""SomeMatrix"", s.""SomeText""
+FROM ""SomeEntities"" AS s
+WHERE (ARRAY[5,6]::integer[] <@ s.""SomeArray"")",
+ //
+ @"SELECT s.""Id"", s.""SomeArray"", s.""SomeBytea"", s.""SomeList"", s.""SomeMatrix"", s.""SomeText""
+FROM ""SomeEntities"" AS s
+WHERE (ARRAY[4,5,6]::integer[] <@ s.""SomeArray"")");
+ }
+
#endregion
#region Support