Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#region License

// The PostgreSQL License
//
// Copyright (C) 2016 The Npgsql Development Team
Expand All @@ -19,48 +20,115 @@
// AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
// ON AN "AS IS" BASIS, AND THE NPGSQL DEVELOPMENT TEAM HAS NO OBLIGATIONS
// TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.

#endregion

using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.Expressions;
using Microsoft.EntityFrameworkCore.Query.ExpressionVisitors;
using Microsoft.EntityFrameworkCore.Query.ExpressionVisitors.Internal;
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal;
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.Expressions;
using Remotion.Linq.Clauses.ResultOperators;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionVisitors
{
public class NpgsqlSqlTranslatingExpressionVisitor : SqlTranslatingExpressionVisitor
{
private readonly RelationalQueryModelVisitor _queryModelVisitor;
/// <summary>
/// The <see cref="MethodInfo"/> for <see cref="DbFunctionsExtensions.Like(DbFunctions,string,string)"/>.
/// </summary>
[NotNull] static readonly MethodInfo Like2MethodInfo =
typeof(DbFunctionsExtensions)
.GetRuntimeMethod(nameof(DbFunctionsExtensions.Like), new[] { typeof(DbFunctions), typeof(string), typeof(string) });

/// <summary>
/// The <see cref="MethodInfo"/> for <see cref="DbFunctionsExtensions.Like(DbFunctions,string,string, string)"/>.
/// </summary>
[NotNull] static readonly MethodInfo Like3MethodInfo =
typeof(DbFunctionsExtensions)
.GetRuntimeMethod(nameof(DbFunctionsExtensions.Like), new[] { typeof(DbFunctions), typeof(string), typeof(string), typeof(string) });

// ReSharper disable once InconsistentNaming
/// <summary>
/// The <see cref="MethodInfo"/> for <see cref="NpgsqlDbFunctionsExtensions.ILike(DbFunctions,string,string)"/>.
/// </summary>
[NotNull] static readonly MethodInfo ILike2MethodInfo =
typeof(NpgsqlDbFunctionsExtensions)
.GetRuntimeMethod(nameof(NpgsqlDbFunctionsExtensions.ILike), new[] { typeof(DbFunctions), typeof(string), typeof(string) });

// ReSharper disable once InconsistentNaming
/// <summary>
/// The <see cref="MethodInfo"/> for <see cref="NpgsqlDbFunctionsExtensions.ILike(DbFunctions,string,string,string)"/>.
/// </summary>
[NotNull] static readonly MethodInfo ILike3MethodInfo =
typeof(NpgsqlDbFunctionsExtensions)
.GetRuntimeMethod(nameof(NpgsqlDbFunctionsExtensions.ILike), new[] { typeof(DbFunctions), typeof(string), typeof(string), typeof(string) });

/// <summary>
/// The query model visitor.
/// </summary>
[NotNull] readonly RelationalQueryModelVisitor _queryModelVisitor;

/// <inheritdoc />
public NpgsqlSqlTranslatingExpressionVisitor(
[NotNull] SqlTranslatingExpressionVisitorDependencies dependencies,
[NotNull] RelationalQueryModelVisitor queryModelVisitor,
[CanBeNull] SelectExpression targetSelectExpression = null,
[CanBeNull] Expression topLevelPredicate = null,
bool inProjection = false)
: base(dependencies, queryModelVisitor, targetSelectExpression, topLevelPredicate, inProjection)
{
_queryModelVisitor = queryModelVisitor;
}
=> _queryModelVisitor = queryModelVisitor;

/// <inheritdoc />
protected override Expression VisitSubQuery(SubQueryExpression expression)
=> base.VisitSubQuery(expression) ?? VisitLikeAnyAll(expression) ?? VisitEqualsAny(expression);

/// <inheritdoc />
protected override Expression VisitBinary(BinaryExpression expression)
{
// Prefer the default EF Core translation if one exists
var result = base.VisitSubQuery(expression);
if (result != null)
return result;
if (expression.NodeType == ExpressionType.ArrayIndex)
{
var properties = MemberAccessBindingExpressionVisitor.GetPropertyPath(
expression.Left, _queryModelVisitor.QueryCompilationContext, out _);
if (properties.Count == 0)
return base.VisitBinary(expression);
var lastPropertyType = properties[properties.Count - 1].ClrType;
if (lastPropertyType.IsArray && lastPropertyType.GetArrayRank() == 1)
{
var left = Visit(expression.Left);
var right = Visit(expression.Right);

return left != null && right != null
? Expression.MakeBinary(ExpressionType.ArrayIndex, left, right)
: null;
}
}

return base.VisitBinary(expression);
}

/// <summary>
/// Visits a <see cref="SubQueryExpression"/> and attempts to translate a '= ANY' expression.
/// </summary>
/// <param name="expression">The expression to visit.</param>
/// <returns>
/// An '= ANY' expression or null.
/// </returns>
[CanBeNull]
protected virtual Expression VisitEqualsAny([NotNull] SubQueryExpression expression)
{
var subQueryModel = expression.QueryModel;
var fromExpression = subQueryModel.MainFromClause.FromExpression;

var properties = MemberAccessBindingExpressionVisitor.GetPropertyPath(
fromExpression, _queryModelVisitor.QueryCompilationContext, out var qsre);
fromExpression, _queryModelVisitor.QueryCompilationContext, out _);

if (properties.Count == 0)
return null;
Expand All @@ -76,33 +144,77 @@ protected override Expression VisitSubQuery(SubQueryExpression expression)
{
var containsItem = Visit(contains.Item);
if (containsItem != null)
return new ArrayAnyExpression(containsItem, Visit(fromExpression));
return new ArrayAnyAllExpression(ArrayComparisonType.ANY, "=", containsItem, Visit(fromExpression));
}
}

return null;
}

protected override Expression VisitBinary(BinaryExpression expression)
/// <summary>
/// Visits a <see cref="SubQueryExpression"/> and attempts to translate a LIKE/ILIKE ANY/ALL expression.
/// </summary>
/// <param name="expression">The expression to visit.</param>
/// <returns>
/// A 'LIKE ANY', 'LIKE ALL', 'ILIKE ANY', or 'ILIKE ALL' expression or null.
/// </returns>
[CanBeNull]
protected virtual Expression VisitLikeAnyAll([NotNull] SubQueryExpression expression)
{
if (expression.NodeType == ExpressionType.ArrayIndex)
{
var properties = MemberAccessBindingExpressionVisitor.GetPropertyPath(
expression.Left, _queryModelVisitor.QueryCompilationContext, out var qsre);
if (properties.Count == 0)
return base.VisitBinary(expression);
var lastPropertyType = properties[properties.Count - 1].ClrType;
if (lastPropertyType.IsArray && lastPropertyType.GetArrayRank() == 1)
{
var left = Visit(expression.Left);
var right = Visit(expression.Right);
var queryModel = expression.QueryModel;
var results = queryModel.ResultOperators;
var body = queryModel.BodyClauses;

return left != null && right != null
? Expression.MakeBinary(ExpressionType.ArrayIndex, left, right)
if (results.Count != 1)
return null;

ArrayComparisonType comparisonType;
MethodCallExpression call;
switch (results[0])
{
case AnyResultOperator _:
comparisonType = ArrayComparisonType.ANY;
call =
body.Count == 1 &&
body[0] is WhereClause whereClause &&
whereClause.Predicate is MethodCallExpression methocCall
? methocCall
: null;
}
break;

case AllResultOperator allResult:
comparisonType = ArrayComparisonType.ALL;
call = allResult.Predicate as MethodCallExpression;
break;

default:
return null;
}
return base.VisitBinary(expression);

if (call is null)
return null;

var source = queryModel.MainFromClause.FromExpression;

// ReSharper disable AssignNullToNotNullAttribute
switch (call.Method)
{
case MethodInfo m when m == Like2MethodInfo:
return new ArrayAnyAllExpression(comparisonType, "LIKE", Visit(call.Arguments[1]), Visit(source));

case MethodInfo m when m == Like3MethodInfo:
return new ArrayAnyAllExpression(comparisonType, "LIKE", Visit(call.Arguments[1]), Visit(source));

case MethodInfo m when m == ILike2MethodInfo:
return new ArrayAnyAllExpression(comparisonType, "ILIKE", Visit(call.Arguments[1]), Visit(source));

case MethodInfo m when m == ILike3MethodInfo:
return new ArrayAnyAllExpression(comparisonType, "ILIKE", Visit(call.Arguments[1]), Visit(source));

default:
return null;
}
// ReSharper restore AssignNullToNotNullAttribute
}
}
}
156 changes: 156 additions & 0 deletions src/EFCore.PG/Query/Expressions/Internal/ArrayAnyAllExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
#region License

// The PostgreSQL License
//
// Copyright (C) 2016 The Npgsql Development Team
//
// Permission to use, copy, modify, and distribute this software and its
// documentation for any purpose, without fee, and without a written
// agreement is hereby granted, provided that the above copyright notice
// and this paragraph and the following two paragraphs appear in all copies.
//
// IN NO EVENT SHALL THE NPGSQL DEVELOPMENT TEAM BE LIABLE TO ANY PARTY
// FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES,
// INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS
// DOCUMENTATION, EVEN IF THE NPGSQL DEVELOPMENT TEAM HAS BEEN ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//
// THE NPGSQL DEVELOPMENT TEAM SPECIFICALLY DISCLAIMS ANY WARRANTIES,
// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY
// AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS
// ON AN "AS IS" BASIS, AND THE NPGSQL DEVELOPMENT TEAM HAS NO OBLIGATIONS
// TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.

#endregion

using System;
using System.Linq.Expressions;
using JetBrains.Annotations;
using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Sql.Internal;
using Npgsql.EntityFrameworkCore.PostgreSQL.Utilities;

namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal
{
/// <summary>
/// Represents a PostgreSQL array ANY or ALL expression.
/// </summary>
/// <example>
/// 1 = ANY ('{0,1,2}'), 'cat' LIKE ANY ('{a%,b%,c%}')
/// </example>
/// <remarks>
/// See https://www.postgresql.org/docs/current/static/functions-comparisons.html
/// </remarks>
public class ArrayAnyAllExpression : Expression, IEquatable<ArrayAnyAllExpression>
{
/// <inheritdoc />
public override ExpressionType NodeType { get; } = ExpressionType.Extension;

/// <inheritdoc />
public override Type Type { get; } = typeof(bool);

/// <summary>
/// The value to test against the <see cref="Array"/>.
/// </summary>
public virtual Expression Operand { get; }

/// <summary>
/// The array of values or patterns to test for the <see cref="Operand"/>.
/// </summary>
public virtual Expression Array { get; }

/// <summary>
/// The operator.
/// </summary>
public virtual string Operator { get; }

/// <summary>
/// The comparison type.
/// </summary>
public virtual ArrayComparisonType ArrayComparisonType { get; }

/// <summary>
/// Constructs a <see cref="ArrayAnyAllExpression"/>.
/// </summary>
/// <param name="arrayComparisonType">The comparison type.</param>
/// <param name="operatorSymbol">The operator symbol to the array expression.</param>
/// <param name="operand">The value to find.</param>
/// <param name="array">The array to search.</param>
/// <exception cref="ArgumentNullException" />
public ArrayAnyAllExpression(
ArrayComparisonType arrayComparisonType,
[NotNull] string operatorSymbol,
[NotNull] Expression operand,
[NotNull] Expression array)
{
Check.NotNull(array, nameof(operatorSymbol));
Check.NotNull(operand, nameof(operand));
Check.NotNull(array, nameof(array));

ArrayComparisonType = arrayComparisonType;
Operator = operatorSymbol;
Operand = operand;
Array = array;
}

/// <inheritdoc />
protected override Expression Accept(ExpressionVisitor visitor)
=> visitor is NpgsqlQuerySqlGenerator npsgqlGenerator
? npsgqlGenerator.VisitArrayAnyAll(this)
: base.Accept(visitor);

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
{
if (!(visitor.Visit(Operand) is Expression operand))
throw new ArgumentException($"The {nameof(operand)} of a {nameof(ArrayAnyAllExpression)} cannot be null.");

if (!(visitor.Visit(Array) is Expression collection))
throw new ArgumentException($"The {nameof(collection)} of a {nameof(ArrayAnyAllExpression)} cannot be null.");

return
operand == Operand && collection == Array
? this
: new ArrayAnyAllExpression(ArrayComparisonType, Operator, operand, collection);
}

/// <inheritdoc />
public override bool Equals(object obj)
=> obj is ArrayAnyAllExpression likeAnyExpression && Equals(likeAnyExpression);

/// <inheritdoc />
public bool Equals(ArrayAnyAllExpression other)
=> Operand.Equals(other?.Operand) &&
Operator.Equals(other?.Operator) &&
ArrayComparisonType.Equals(other?.ArrayComparisonType) &&
Array.Equals(other?.Array);

/// <inheritdoc />
public override int GetHashCode()
=> unchecked((397 * Operand.GetHashCode()) ^
(397 * Operator.GetHashCode()) ^
(397 * ArrayComparisonType.GetHashCode()) ^
(397 * Array.GetHashCode()));

/// <inheritdoc />
public override string ToString()
=> $"{Operand} {Operator} {ArrayComparisonType.ToString()} ({Array})";
}

/// <summary>
/// Represents whether an array comparison is ANY or ALL.
/// </summary>
public enum ArrayComparisonType
{
// ReSharper disable once InconsistentNaming
/// <summary>
/// Represents an ANY array comparison.
/// </summary>
ANY,

// ReSharper disable once InconsistentNaming
/// <summary>
/// Represents an ALL array comparison.
/// </summary>
ALL
}
}
Loading