Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 92 additions & 16 deletions src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
Expand All @@ -10,35 +11,47 @@
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Storage;
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal;
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
using ExpressionExtensions = Microsoft.EntityFrameworkCore.Query.ExpressionExtensions;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Internal
{
public class NpgsqlSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExpressionVisitor
{
[NotNull] static readonly MethodInfo Like2MethodInfo =
[NotNull]
static readonly MethodInfo Like2MethodInfo =
typeof(DbFunctionsExtensions)
.GetRuntimeMethod(nameof(DbFunctionsExtensions.Like), new[] { typeof(DbFunctions), typeof(string), typeof(string) });

// ReSharper disable once InconsistentNaming
[NotNull] static readonly MethodInfo ILike2MethodInfo =
[NotNull]
static readonly MethodInfo ILike2MethodInfo =
typeof(NpgsqlDbFunctionsExtensions)
.GetRuntimeMethod(nameof(NpgsqlDbFunctionsExtensions.ILike), new[] { typeof(DbFunctions), typeof(string), typeof(string) });

[NotNull] static readonly MethodInfo EnumerableAnyWithPredicate =
[NotNull]
static readonly MethodInfo EnumerableAnyWithPredicate =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(mi => mi.Name == nameof(Enumerable.Any) && mi.GetParameters().Length == 2);

[NotNull] static readonly MethodInfo EnumerableAll =
[NotNull]
static readonly MethodInfo EnumerableAll =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(mi => mi.Name == nameof(Enumerable.All) && mi.GetParameters().Length == 2);

[NotNull]
static readonly MethodInfo Contains =
typeof(Enumerable).GetTypeInfo().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Single(m => m.Name == nameof(Enumerable.Contains) && m.GetParameters().Length == 2);

readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory;
readonly NpgsqlJsonPocoTranslator _jsonPocoTranslator;

[NotNull]
readonly RelationalTypeMapping _boolMapping;

public NpgsqlSqlTranslatingExpressionVisitor(
RelationalSqlTranslatingExpressionVisitorDependencies dependencies,
IModel model,
Expand All @@ -47,6 +60,7 @@ public NpgsqlSqlTranslatingExpressionVisitor(
{
_sqlExpressionFactory = (NpgsqlSqlExpressionFactory)dependencies.SqlExpressionFactory;
_jsonPocoTranslator = ((NpgsqlMemberTranslatorProvider)Dependencies.MemberTranslatorProvider).JsonPocoTranslator;
_boolMapping = _sqlExpressionFactory.FindMapping(typeof(bool));
}

// PostgreSQL COUNT() always returns bigint, so we need to downcast to int
Expand Down Expand Up @@ -109,46 +123,108 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)

protected override Expression VisitMethodCall(MethodCallExpression methodCall)
{
var visited = base.VisitMethodCall(methodCall);
if (visited != null)
return visited;

// TODO: Handle List<>
if (methodCall.Arguments.Count > 0 && methodCall.Arguments[0].Type.IsArray)
return VisitArrayMethodCall(methodCall.Method, methodCall.Arguments);

return null;
}

/// <summary>
/// Identifies complex array-related constructs which cannot be translated in regular method translators, since
/// they require accessing lambdas.
/// </summary>
protected virtual Expression VisitArrayMethodCall(MethodInfo method, ReadOnlyCollection<Expression> arguments)
{
{
// Pattern match for .Where(e => new[] { "a", "b", "c" }.Any(p => EF.Functions.Like(e.SomeText, p))),
// which we translate to WHERE s.""SomeText"" LIKE ANY (ARRAY['a','b','c']) (see test Any_like)
// Note: NavigationExpander normalized Any(x) to Where(x).Any()
if (methodCall.Method.IsClosedFormOf(EnumerableAnyWithPredicate) &&
methodCall.Arguments[0].Type.IsArray &&
methodCall.Arguments[1] is LambdaExpression wherePredicate &&
if (method.IsClosedFormOf(EnumerableAnyWithPredicate) &&
arguments[1] is LambdaExpression wherePredicate &&
wherePredicate.Body is MethodCallExpression wherePredicateMethodCall && (
wherePredicateMethodCall.Method == Like2MethodInfo ||
wherePredicateMethodCall.Method == ILike2MethodInfo))
{
var array = (SqlExpression)Visit(methodCall.Arguments[0]);
var array = (SqlExpression)Visit(arguments[0]);
var match = (SqlExpression)Visit(wherePredicateMethodCall.Arguments[1]);

return _sqlExpressionFactory.ArrayAnyAll(match, array, ArrayComparisonType.Any,
wherePredicateMethodCall.Method == Like2MethodInfo ? "LIKE" : "ILIKE");
}

// Note: we also handle the above with equality instead of Like, see NpgsqlArrayMethodTranslator
}

// Same for All (but without the normalization
{
if (methodCall.Method.IsClosedFormOf(EnumerableAll) &&
methodCall.Arguments[0].Type.IsArray &&
methodCall.Arguments[1] is LambdaExpression wherePredicate &&
// Same for All (but without the normalization)
if (method.IsClosedFormOf(EnumerableAll) &&
arguments[1] is LambdaExpression wherePredicate &&
wherePredicate.Body is MethodCallExpression wherePredicateMethodCall && (
wherePredicateMethodCall.Method == Like2MethodInfo ||
wherePredicateMethodCall.Method == ILike2MethodInfo))
{
var array = (SqlExpression)Visit(methodCall.Arguments[0]);
var array = (SqlExpression)Visit(arguments[0]);
var match = (SqlExpression)Visit(wherePredicateMethodCall.Arguments[1]);

return _sqlExpressionFactory.ArrayAnyAll(match, array, ArrayComparisonType.All,
wherePredicateMethodCall.Method == Like2MethodInfo ? "LIKE" : "ILIKE");
}
}

// Note: we also handle the above with equality instead of Like, see NpgsqlArrayMethodTranslator
return base.VisitMethodCall(methodCall);
{
// Translate e => new[] { 4, 5 }.Any(p => e.SomeArray.Contains(p)),
// using array overlap (&&)
if (method.IsClosedFormOf(EnumerableAnyWithPredicate) &&
arguments[1] is LambdaExpression wherePredicate &&
wherePredicate.Body is MethodCallExpression wherePredicateMethodCall &&
wherePredicateMethodCall.Method.IsClosedFormOf(Contains) &&
wherePredicateMethodCall.Arguments[0].Type.IsArray &&
wherePredicateMethodCall.Arguments[1] is ParameterExpression parameterExpression &&
parameterExpression == wherePredicate.Parameters[0])
{
var array1 = (SqlExpression)Visit(arguments[0]);
var array2 = (SqlExpression)Visit(wherePredicateMethodCall.Arguments[0]);
var inferredMapping = ExpressionExtensions.InferTypeMapping(array1, array2);

return new SqlCustomBinaryExpression(
_sqlExpressionFactory.ApplyTypeMapping(array1, inferredMapping),
_sqlExpressionFactory.ApplyTypeMapping(array2, inferredMapping),
"&&",
typeof(bool),
_boolMapping);
}
}

{
// Translate e => new[] { 4, 5 }.All(p => e.SomeArray.Contains(p)),
// using array containment (<@)
if (method.IsClosedFormOf(EnumerableAll) &&
arguments[1] is LambdaExpression wherePredicate &&
wherePredicate.Body is MethodCallExpression wherePredicateMethodCall &&
wherePredicateMethodCall.Method.IsClosedFormOf(Contains) &&
wherePredicateMethodCall.Arguments[0].Type.IsArray &&
wherePredicateMethodCall.Arguments[1] is ParameterExpression parameterExpression &&
parameterExpression == wherePredicate.Parameters[0])
{
var array1 = (SqlExpression)Visit(arguments[0]);
var array2 = (SqlExpression)Visit(wherePredicateMethodCall.Arguments[0]);
var inferredMapping = ExpressionExtensions.InferTypeMapping(array1, array2);

return new SqlCustomBinaryExpression(
_sqlExpressionFactory.ApplyTypeMapping(array1, inferredMapping),
_sqlExpressionFactory.ApplyTypeMapping(array2, inferredMapping),
"<@",
typeof(bool),
_boolMapping);
}
}

return null;
}

/// <inheritdoc />
Expand Down
49 changes: 49 additions & 0 deletions test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,55 @@ public void Any_like_anonymous()
WHERE s.""SomeText"" LIKE ANY (@__patterns_0)");
}

[Fact]
public void Any_Contains()
{
using var ctx = CreateContext();

var results = ctx.SomeEntities
.Where(e => new[] { 2, 3 }.Any(p => e.SomeArray.Contains(p)))
.ToList();
Assert.Equal(1, Assert.Single(results).Id);

results = ctx.SomeEntities
.Where(e => new[] { 1, 2 }.Any(p => e.SomeArray.Contains(p)))
.ToList();
Assert.Empty(results);

AssertSql(
@"SELECT s.""Id"", s.""SomeArray"", s.""SomeBytea"", s.""SomeList"", s.""SomeMatrix"", s.""SomeText""
FROM ""SomeEntities"" AS s
WHERE (ARRAY[2,3]::integer[] && s.""SomeArray"")",
@"SELECT s.""Id"", s.""SomeArray"", s.""SomeBytea"", s.""SomeList"", s.""SomeMatrix"", s.""SomeText""
FROM ""SomeEntities"" AS s
WHERE (ARRAY[1,2]::integer[] && s.""SomeArray"")");
}

[Fact]
public void All_Contains()
{
using var ctx = CreateContext();

var results = ctx.SomeEntities
.Where(e => new[] { 5, 6 }.All(p => e.SomeArray.Contains(p)))
.ToList();
Assert.Equal(2, Assert.Single(results).Id);

results = ctx.SomeEntities
.Where(e => new[] { 4, 5, 6 }.All(p => e.SomeArray.Contains(p)))
.ToList();
Assert.Empty(results);

AssertSql(
@"SELECT s.""Id"", s.""SomeArray"", s.""SomeBytea"", s.""SomeList"", s.""SomeMatrix"", s.""SomeText""
FROM ""SomeEntities"" AS s
WHERE (ARRAY[5,6]::integer[] <@ s.""SomeArray"")",
//
@"SELECT s.""Id"", s.""SomeArray"", s.""SomeBytea"", s.""SomeList"", s.""SomeMatrix"", s.""SomeText""
FROM ""SomeEntities"" AS s
WHERE (ARRAY[4,5,6]::integer[] <@ s.""SomeArray"")");
}

#endregion

#region Support
Expand Down