diff --git a/doc/mapping/array.md b/doc/mapping/array.md index 11aeca56c..ed152e430 100644 --- a/doc/mapping/array.md +++ b/doc/mapping/array.md @@ -7,29 +7,64 @@ PostgreSQL has the unique feature of supporting [*array data types*](https://www # Mapping arrays -Simply define a regular .NET array or `List<>` property, and the provider +Npgsql maps PostgreSQL arrays to generic `T[]` and `List` types: ```c# public class Post { public int Id { get; set; } - public string Name { get; set; } - public string[] Tags { get; set; } - public List AlternativeTags { get; set; } + public string[] SomeArray { get; set; } + public List SomeList { get; set; } } ``` -The provider will create `text[]` columns for the above two properties, and will properly detect changes in them - if you load an array and change one of its elements, calling `SaveChanges()` will automatically update the row in the database accordingly. +The provider will create `text[]` columns for the above two properties, and will properly detect changes in them—if you load an array and change one of its elements, calling `SaveChanges()` will automatically update the row in the database accordingly. # Operation translation -The provider can also translate CLR array operations to the corresponding SQL operation; this allows you to efficiently work with arrays by evaluating operations in the database and avoids pulling all the data. The following table lists the range operations that currently get translated. If you run into a missing operation, please open an issue. +The provider translates many operations on `T[]` and `List` to corresponding SQL operations. This allows arrays to be worked with efficiently by evaluating operations in the database. -Note that operation translation on `List<>` is limited at this time, but will be improved in the future. It's recommended to use an array for now. +The following table lists the operations that are currently translated. If you run into a missing operation, please open an issue. -| C# expression | SQL generated by Npgsql | -|------------------------------------------------------------|-------------------------| -| `.Where(c => c.SomeArray[1] = "foo")` | [`WHERE "c"."SomeArray"[1] = 'foo'`](https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-ACCESSING) -| `.Where(c => c.SomeArray.SequenceEqual(new[] { 1, 2, 3 })` | [`WHERE "c"."SomeArray" = ARRAY[1, 2, 3])`](https://www.postgresql.org/docs/current/static/arrays.html) -| `.Where(c => c.SomeArray.Contains(3))` | [`WHERE 3 = ANY("c"."SomeArray")`](https://www.postgresql.org/docs/current/static/functions-comparisons.html#AEN21104) -| `.Where(c => c.SomeArray.Length == 3)` | [`WHERE array_length("c"."SomeArray, 1) = 3`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| C# expression | SQL generated by Npgsql | +|------------------------------------------------------------------|-------------------------| +| `.Where(c => c.SomeArray[0] == "foo")` | [`WHERE "c"."SomeArray"[1] = 'foo'`](https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-ACCESSING) +| `.Where(c => c.SomeList[0] == "foo")` | [`WHERE "c"."SomeList"[1] = 'foo'`](https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-ACCESSING) +| `.Where(c => c.SomeArray.ElementAt(0) == "foo")` | [`WHERE "c"."SomeArray"[1] = 'foo'`](https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-ACCESSING) +| `.Where(c => c.SomeList.ElementAt(0) == "foo")` | [`WHERE "c"."SomeList"[1] = 'foo'`](https://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-ACCESSING) +| `.Where(x => x.SomeArray.Length == 1)` | [`WHERE array_length(x."SomeArray", 1) = 1`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Where(x => x.SomeList.Count == 1)` | [`WHERE array_length(x."SomeList", 1) = 1`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Where(x => x.SomeArray.Count() == 1)` | [`WHERE array_length(x."SomeArray", 1) = 1`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Where(x => x.SomeList.Count() == 1)` | [`WHERE array_length(x."SomeList", 1) = 1`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Where(x => x.SomeArray == x.SomeList)` | [`WHERE x."SomeArray" = x."SomeList"`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Where(x => x.SomeArray.Equals(x.SomeList))` | [`WHERE x."SomeArray" = x."SomeList"`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Where(x => x.SomeArray.SequenceEquals(x.SomeList))` | [`WHERE x."SomeArray" = x."SomeList"`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Where(x => x.SomeArray.Contains("foo"))` | [`WHERE 'foo' = ANY (x."SomeArray")`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Where(x => x.SomeList.Contains("foo"))` | [`WHERE 'foo' = ANY (x."SomeList")`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Select(x => x.SomeArray.Append("foo"))` | [`SELECT x."SomeArray" \|\| 'foo'`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Select(x => x.SomeList.Append("foo"))` | [`SELECT x."SomeList" \|\| 'foo'`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Select(x => x.SomeArray.Prepend("foo"))` | [`SELECT 'foo' \|\| x."SomeArray"`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Select(x => x.SomeList.Prepend("foo"))` | [`SELECT 'foo' \|\| x."SomeList"`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Select(x => x.SomeArray.Concat(x.SomeList))` | [`SELECT x."SomeArray" \|\| x."SomeList"`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Select(x => x.SomeList.Concat(x.SomeArray))` | [`SELECT x."SomeList" \|\| x."SomeArray"`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Select(x => EF.Functions.ArrayToString(x.SomeArray, ","))` | [`SELECT array_to_string(x."SomeArray", ',')`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Select(x => EF.Functions.ArrayToString(x.SomeList, ",", "*"))` | [`SELECT array_to_string(x."SomeList", ',', '*')`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Select(x => Array.IndexOf(x.SomeArray, "foo"))` | [`SELECT COALESCE(array_position(x."SomeArray", 'foo'), -1)`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) +| `.Select(x => x.SomeList.IndexOf("foo"))` | [`SELECT COALESCE(array_position(x."SomeList", 'foo'), -1)`](https://www.postgresql.org/docs/current/static/functions-array.html#ARRAY-FUNCTIONS-TABLE) + +# Pattern translation + +The provider has special translations for certain patterns of operations. These pattern-based translations are more susceptible to client-evaluation than standard translations. + +The following sections describe the patterns that are currently translated. If you find that one of these patterns is being evaluated on the client, please open an issue. + + +## LIKE ANY + +## LIKE ALL + +## EXISTS + +## @> + +## && \ No newline at end of file diff --git a/src/EFCore.PG/Extensions/NpgsqlArrayExtensions.cs b/src/EFCore.PG/Extensions/NpgsqlArrayExtensions.cs new file mode 100644 index 000000000..674152580 --- /dev/null +++ b/src/EFCore.PG/Extensions/NpgsqlArrayExtensions.cs @@ -0,0 +1,154 @@ +#region License + +// The PostgreSQL License +// +// Copyright (C) 2016 The Npgsql Development Team +// +// Permission to use, copy, modify, and distribute this software and its +// documentation for any purpose, without fee, and without a written +// agreement is hereby granted, provided that the above copyright notice +// and this paragraph and the following two paragraphs appear in all copies. +// +// IN NO EVENT SHALL THE NPGSQL DEVELOPMENT TEAM BE LIABLE TO ANY PARTY +// FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, +// INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS +// DOCUMENTATION, EVEN IF THE NPGSQL DEVELOPMENT TEAM HAS BEEN ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +// +// THE NPGSQL DEVELOPMENT TEAM SPECIFICALLY DISCLAIMS ANY WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY +// AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS +// ON AN "AS IS" BASIS, AND THE NPGSQL DEVELOPMENT TEAM HAS NO OBLIGATIONS +// TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. + +#endregion + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using JetBrains.Annotations; + +// ReSharper disable once CheckNamespace +namespace Microsoft.EntityFrameworkCore +{ + /// + /// Provides extension methods for supporting PostgreSQL translation. + /// + public static class NpgsqlArrayExtensions + { + /// + /// Concatenates elements using the supplied delimiter. + /// + /// The DbFunctions instance. + /// The list to convert to a string in which to locate the value. + /// The value used to delimit the elements. + /// The type of the elements of . + /// + /// The string concatenation of the elements with the supplied delimiter. + /// + /// + /// This method is only intended for use via SQL translation as part of an EF Core LINQ query. + /// + public static string ArrayToString([CanBeNull] this DbFunctions _, [NotNull] T[] array, [CanBeNull] string delimiter) + => throw ClientEvaluationNotSupportedException(); + + /// + /// Concatenates elements using the supplied delimiter. + /// + /// The DbFunctions instance. + /// The list to convert to a string in which to locate the value. + /// The value used to delimit the elements. + /// The type of the elements of . + /// + /// The string concatenation of the elements with the supplied delimiter. + /// + /// + /// This method is only intended for use via SQL translation as part of an EF Core LINQ query. + /// + public static string ArrayToString([CanBeNull] this DbFunctions _, [NotNull] List list, [CanBeNull] string delimiter) + => throw ClientEvaluationNotSupportedException(); + + /// + /// Concatenates elements using the supplied delimiter and the string representation for null elements. + /// + /// The DbFunctions instance. + /// The list to convert to a string in which to locate the value. + /// The value used to delimit the elements. + /// The value used to represent a null value. + /// The type of the elements of . + /// + /// The string concatenation of the elements with the supplied delimiter and null string. + /// + /// + /// This method is only intended for use via SQL translation as part of an EF Core LINQ query. + /// + public static string ArrayToString([CanBeNull] this DbFunctions _, [NotNull] T[] array, [CanBeNull] string delimiter, [CanBeNull] string nullString) + => throw ClientEvaluationNotSupportedException(); + + /// + /// Concatenates elements using the supplied delimiter and the string representation for null elements. + /// + /// The DbFunctions instance. + /// The list to convert to a string in which to locate the value. + /// The value used to delimit the elements. + /// The value used to represent a null value. + /// The type of the elements of . + /// + /// The string concatenation of the elements with the supplied delimiter and null string. + /// + /// + /// This method is only intended for use via SQL translation as part of an EF Core LINQ query. + /// + public static string ArrayToString([CanBeNull] this DbFunctions _, [NotNull] List list, [CanBeNull] string delimiter, [CanBeNull] string nullString) + => throw ClientEvaluationNotSupportedException(); + + /// + /// Converts the input string into an array using the supplied delimiter and the string representation for null elements. + /// + /// The DbFunctions instance. + /// The input string of delimited values. + /// The value that delimits the elements. + /// The value that represents a null value. + /// The type of the elements in the resulting array. + /// + /// The array resulting from splitting the input string based on the supplied delimiter and null string. + /// + /// + /// This method is only intended for use via SQL translation as part of an EF Core LINQ query. + /// + public static T[] StringToArray([CanBeNull] this DbFunctions _, [NotNull] string input, [CanBeNull] string delimiter, [CanBeNull] string nullString) + => throw ClientEvaluationNotSupportedException(); + + /// + /// Converts the input string into a using the supplied delimiter and the string representation for null elements. + /// + /// The DbFunctions instance. + /// The input string of delimited values. + /// The value that delimits the elements. + /// The value that represents a null value. + /// The type of the elements in the resulting array. + /// + /// The list resulting from splitting the input string based on the supplied delimiter and null string. + /// + /// + /// This method is only intended for use via SQL translation as part of an EF Core LINQ query. + /// + public static List StringToList([CanBeNull] this DbFunctions _, [NotNull] string input, [CanBeNull] string delimiter, [CanBeNull] string nullString) + => throw ClientEvaluationNotSupportedException(); + + #region Utilities + + /// + /// Helper method to throw a with the name of the throwing method. + /// + /// The method that throws the exception. + /// + /// A . + /// + [NotNull] + static NotSupportedException ClientEvaluationNotSupportedException([CallerMemberName] string method = default) + => new NotSupportedException($"{method} is only intended for use via SQL translation as part of an EF Core LINQ query."); + + #endregion + } +} diff --git a/src/EFCore.PG/Extensions/NpgsqlServiceCollectionExtensions.cs b/src/EFCore.PG/Extensions/NpgsqlServiceCollectionExtensions.cs index ee76be5df..2025056dc 100644 --- a/src/EFCore.PG/Extensions/NpgsqlServiceCollectionExtensions.cs +++ b/src/EFCore.PG/Extensions/NpgsqlServiceCollectionExtensions.cs @@ -51,6 +51,7 @@ // ReSharper disable once CheckNamespace namespace Microsoft.Extensions.DependencyInjection { + // ReSharper disable once UnusedMember.Global public static class NpgsqlEntityFrameworkServicesBuilderExtensions { /// @@ -106,6 +107,8 @@ public static IServiceCollection AddEntityFrameworkNpgsql([NotNull] this IServic .TryAdd() .TryAdd() .TryAdd() + .TryAdd() + .TryAdd() .TryAdd() .TryAdd() .TryAdd(p => p.GetService()) diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayFragmentTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayFragmentTranslator.cs new file mode 100644 index 000000000..37025ac01 --- /dev/null +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayFragmentTranslator.cs @@ -0,0 +1,278 @@ +#region License + +// The PostgreSQL License +// +// Copyright (C) 2016 The Npgsql Development Team +// +// Permission to use, copy, modify, and distribute this software and its +// documentation for any purpose, without fee, and without a written +// agreement is hereby granted, provided that the above copyright notice +// and this paragraph and the following two paragraphs appear in all copies. +// +// IN NO EVENT SHALL THE NPGSQL DEVELOPMENT TEAM BE LIABLE TO ANY PARTY +// FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, +// INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS +// DOCUMENTATION, EVEN IF THE NPGSQL DEVELOPMENT TEAM HAS BEEN ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +// +// THE NPGSQL DEVELOPMENT TEAM SPECIFICALLY DISCLAIMS ANY WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY +// AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS +// ON AN "AS IS" BASIS, AND THE NPGSQL DEVELOPMENT TEAM HAS NO OBLIGATIONS +// TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. + +#endregion + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Query.Expressions; +using Microsoft.EntityFrameworkCore.Query.ExpressionTranslators; +using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal; +using Remotion.Linq; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Clauses.ResultOperators; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal +{ + /// + /// Provides translation services for + /// and + /// as PostgreSQL array operations. + /// + public class NpgsqlArrayFragmentTranslator : IExpressionFragmentTranslator + { + /// + [CanBeNull] + public Expression Translate(Expression expression) + { + if (!(expression is SubQueryExpression subQuery)) + return null; + + var model = subQuery.QueryModel; + + if (ConcatResult(model) is Expression concat) + return concat; + + // TODO: catches too much. +// if (ContainsResult(model) is Expression contains) +// return contains; + + if (CountResult(model) is Expression count) + return count; + + if (model.BodyClauses.Count != 1) + return null; + + if (!(model.BodyClauses[0] is WhereClause where)) + return null; + + if (!(where.Predicate is BinaryExpression b)) + return null; + + if (!TryFindArray(b, out Expression from, out ArrayPosition position) || from is null) + return null; + + var operand = position is ArrayPosition.Left ? b.Right : b.Left; + + // In PostgreSQL, the array is on the right. Flip the sign if needed. + bool flip = position is ArrayPosition.Left; + + // ReSharper disable once SwitchStatementMissingSomeCases + switch (b.NodeType) + { + case ExpressionType.Equal: + return new ArrayAnyAllExpression(ArrayComparisonType.ANY, "=", operand, from); + + case ExpressionType.NotEqual: + return new ArrayAnyAllExpression(ArrayComparisonType.ANY, "<>", operand, from); + + case ExpressionType.LessThan: + return new ArrayAnyAllExpression(ArrayComparisonType.ANY, flip ? ">" : "<", operand, from); + + case ExpressionType.LessThanOrEqual: + return new ArrayAnyAllExpression(ArrayComparisonType.ANY, flip ? ">=" : "<=", operand, from); + + case ExpressionType.GreaterThan: + return new ArrayAnyAllExpression(ArrayComparisonType.ANY, flip ? "<" : ">", operand, from); + + case ExpressionType.GreaterThanOrEqual: + return new ArrayAnyAllExpression(ArrayComparisonType.ANY, flip ? "<=" : ">=", operand, from); + + default: + return null; + } + } + + #region SubQueries + + [CanBeNull] + static Expression ContainsResult([NotNull] QueryModel model) + { + if (model.BodyClauses.Count != 0) + return null; + + if (model.ResultOperators.Count != 1) + return null; + + if (!(model.ResultOperators[0] is ContainsResultOperator contains)) + return null; + + if (!(model.MainFromClause.FromExpression is Expression from)) + return null; + + return + IsArrayOrList(from.Type) + ? new ArrayAnyAllExpression(ArrayComparisonType.ANY, "=", contains.Item, from) + : null; + } + + /// + /// Visits an array-based count expression: {array}.Length, {list}.Count, {array|list}.Count(), {array|list}.Count({predicate}). + /// + /// The query model to visit. + /// + /// An expression or null. + /// + [CanBeNull] + static Expression CountResult([NotNull] QueryModel model) + { + // TODO: handle count operation with predicate. + if (model.BodyClauses.Count != 0) + return null; + + if (model.ResultOperators.Count != 1) + return null; + + if (!(model.ResultOperators[0] is CountResultOperator _)) + return null; + + if (!(model.MainFromClause.FromExpression is Expression from)) + return null; + + if (!IsArrayOrList(from.Type)) + return null; + + return from.Type.IsArray + ? (Expression)Expression.ArrayLength(from) + : new SqlFunctionExpression("array_length", typeof(int), new[] { from, Expression.Constant(1) }); + } + + /// + /// Visits an array-based concatenation expression: {array|value} || {array|value}. + /// + /// The query model to visit. + /// + /// An expression or null. + /// + [CanBeNull] + static Expression ConcatResult([NotNull] QueryModel model) + { + if (model.BodyClauses.Count != 0) + return null; + + if (model.ResultOperators.Count != 1) + return null; + + if (!(model.ResultOperators[0] is ConcatResultOperator concat)) + return null; + + if (!(model.MainFromClause.FromExpression is Expression from)) + return null; + + if (!IsArrayOrList(from.Type)) + return null; + + return + IsArrayOrList(concat.Source2.Type) + ? new CustomBinaryExpression(from, concat.Source2, "||", from.Type) + : null; + } + + #endregion + + #region Helpers + + /// + /// Try to return the array expression and its position in the . + /// + /// The expression to visit. + /// The array expression, if found. + /// The postion of the array. + /// + /// True if the array was found; otherwise, false. + /// + static bool TryFindArray([NotNull] BinaryExpression binaryExpression, [CanBeNull] out Expression array, out ArrayPosition position) + { + if (TryFindArray(binaryExpression.Left, out array)) + { + position = ArrayPosition.Left; + return true; + } + + if (TryFindArray(binaryExpression.Right, out array)) + { + position = ArrayPosition.Right; + return true; + } + + position = ArrayPosition.None; + return false; + } + + /// + /// Try to return the array expression. + /// + /// The expression to visit. + /// The array expression, if found. + /// + /// True if the array was found; otherwise, false. + /// + static bool TryFindArray([NotNull] Expression expression, [CanBeNull] out Expression array) + { + switch (expression) + { + // Is one side a qsre pointing to an array? + case QuerySourceReferenceExpression qsre + when qsre.ReferencedQuerySource is MainFromClause mfc && + mfc.FromExpression is Expression from && + IsArrayOrList(from.Type): + array = from; + return true; + + // Is the expression a parameter array? + case ParameterExpression param + when IsArrayOrList(param.Type): + array = param; + return true; + + default: + array = null; + return false; + } + } + + /// + /// Describes the position of an array in a . + /// + private enum ArrayPosition + { + None, + Left, + Right + } + + /// + /// Tests if the type is an array or a . + /// + /// The type to test. + /// + /// True if is an array or a ; otherwise, false. + /// + static bool IsArrayOrList([NotNull] Type type) => type.IsArray || type.IsGenericType && typeof(List<>) == type.GetGenericTypeDefinition(); + + #endregion + } +} diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayTranslator.cs new file mode 100644 index 000000000..b804b5f44 --- /dev/null +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArrayTranslator.cs @@ -0,0 +1,197 @@ +#region License + +// The PostgreSQL License +// +// Copyright (C) 2016 The Npgsql Development Team +// +// Permission to use, copy, modify, and distribute this software and its +// documentation for any purpose, without fee, and without a written +// agreement is hereby granted, provided that the above copyright notice +// and this paragraph and the following two paragraphs appear in all copies. +// +// IN NO EVENT SHALL THE NPGSQL DEVELOPMENT TEAM BE LIABLE TO ANY PARTY +// FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, +// INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS +// DOCUMENTATION, EVEN IF THE NPGSQL DEVELOPMENT TEAM HAS BEEN ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +// +// THE NPGSQL DEVELOPMENT TEAM SPECIFICALLY DISCLAIMS ANY WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY +// AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS +// ON AN "AS IS" BASIS, AND THE NPGSQL DEVELOPMENT TEAM HAS NO OBLIGATIONS +// TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. + +#endregion + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Query.Expressions; +using Microsoft.EntityFrameworkCore.Query.ExpressionTranslators; +using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal +{ + /// + /// Provides translation services for PostgreSQL array operators mapped to methods declared on + /// , , , and . + /// + /// + /// See: https://www.postgresql.org/docs/current/static/functions-array.html + /// + public class NpgsqlArrayTranslator : IMethodCallTranslator + { + /// + [CanBeNull] + public Expression Translate(MethodCallExpression expression) + { + if (!IsTypeSupported(expression)) + return null; + + switch (expression.Method.Name) + { + #region EnumerableStaticMethods + + case nameof(Enumerable.ElementAt): + return Expression.MakeIndex(expression.Arguments[0], GetIndexer(expression.Arguments[0].Type, expression.Arguments.Skip(1)), new[] { expression.Arguments[1] }); + + case nameof(Enumerable.Append): + return new CustomBinaryExpression(expression.Arguments[0], expression.Arguments[1], "||", expression.Arguments[0].Type); + + case nameof(Enumerable.Prepend): + return new CustomBinaryExpression(expression.Arguments[1], expression.Arguments[0], "||", expression.Arguments[0].Type); + + case nameof(Enumerable.SequenceEqual): + return Expression.MakeBinary(ExpressionType.Equal, expression.Arguments[0], expression.Arguments[1]); + + #endregion + + #region NpgsqlArrayExtensions + + case nameof(NpgsqlArrayExtensions.ArrayToString): + return new SqlFunctionExpression("array_to_string", typeof(string), expression.Arguments.Skip(1)); + + case nameof(NpgsqlArrayExtensions.StringToArray): + return new SqlFunctionExpression("string_to_array", expression.Method.ReturnType, expression.Arguments.Skip(1)); + + case nameof(NpgsqlArrayExtensions.StringToList): + return new SqlFunctionExpression("string_to_array", expression.Method.ReturnType, expression.Arguments.Skip(1)); + + #endregion + + #region ArrayStaticMethods + + case nameof(Array.IndexOf) + when expression.Method.DeclaringType == typeof(Array): + return + new SqlFunctionExpression( + "COALESCE", + typeof(int), + new Expression[] + { + new SqlFunctionExpression("array_position", typeof(int), expression.Arguments), + Expression.Constant(-1) + }); + + #endregion + + #region ListInstanceMethods + + case "get_Item" when expression.Object is Expression instance: + return Expression.MakeIndex(instance, GetIndexer(instance.Type, expression.Arguments), expression.Arguments); + + case nameof(IList.IndexOf) when IsArrayOrList(expression.Method.DeclaringType): + return + new SqlFunctionExpression( + "COALESCE", + typeof(int), + new Expression[] + { + new SqlFunctionExpression("array_position", typeof(int), new[] { expression.Object, expression.Arguments[0] }), + Expression.Constant(-1) + }); + + #endregion + + #region StringInstanceMethods + + case "get_Chars" when expression.Object is Expression instance: + return Expression.MakeIndex(instance, GetIndexer(instance.Type, expression.Arguments), expression.Arguments); + + #endregion + + default: + return null; + } + } + + #region Helpers + + /// + /// Tests if the instance or argument types are supported. + /// + /// The to test. + /// + /// True if the instance or argument types are supported; otherwise, false. + /// + static bool IsTypeSupported([NotNull] MethodCallExpression expression) + { + Type declaringType = expression.Method.DeclaringType; + + // Methods declared here are always translated. + if (declaringType == typeof(NpgsqlArrayExtensions)) + return true; + + // Methods not declared here are never translated. + if (!IsArrayOrList(declaringType) && + declaringType != typeof(Array) && + declaringType != typeof(Enumerable)) + return false; + + switch (expression.Object) + { + // Instance methods are only translated for T[] and List. + case Expression instance: + return IsArrayOrList(instance.Type); + + // Extension methods may only be translated when a parameter is T[] or List + case null: + // Static method with no parameters? Return null. + // Static method on T[] or List? + return expression.Arguments.Count != 0 && IsArrayOrList(expression.Arguments[0].Type); + } + } + + /// + /// Tests if the type is an array or a . + /// + /// The type to test. + /// + /// True if is an array or a ; otherwise, false. + /// + static bool IsArrayOrList(Type type) => type.IsArray || type == typeof(string) || type.IsGenericType && typeof(List<>) == type.GetGenericTypeDefinition(); + + /// + /// Finds the for the indexer property with the given parameters, or null if none is found. + /// + /// The type to search. + /// The indexer parameters. + /// + /// The or null. + /// + [CanBeNull] + static PropertyInfo GetIndexer([NotNull] Type type, [NotNull] [ItemNotNull] IEnumerable arguments) + => type.GetRuntimeProperties() + .Where(x => x.Name == type.GetCustomAttribute()?.MemberName) + .Select(x => (Indexer: x, Parameters: x.GetGetMethod().GetParameters().Select(y => y.ParameterType))) + .SingleOrDefault(x => x.Parameters.SequenceEqual(arguments.Select(y => y.Type))) + .Indexer; + + #endregion + } +} diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArraySequenceEqualTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlCompositeExpressionFragmentTranslator.cs similarity index 52% rename from src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArraySequenceEqualTranslator.cs rename to src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlCompositeExpressionFragmentTranslator.cs index 3fb959ff5..42fc81838 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlArraySequenceEqualTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlCompositeExpressionFragmentTranslator.cs @@ -1,4 +1,5 @@ #region License + // The PostgreSQL License // // Copyright (C) 2016 The Npgsql Development Team @@ -19,43 +20,42 @@ // AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS // ON AN "AS IS" BASIS, AND THE NPGSQL DEVELOPMENT TEAM HAS NO OBLIGATIONS // TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. + #endregion -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; +using System.Collections.Generic; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Query.ExpressionTranslators; namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal { /// - /// Translates Enumerable.SequenceEqual on arrays into PostgreSQL array equality operations. + /// A composite expression fragment translator that dispatches to multiple specialized translators specific to Npgsql. /// - /// - /// https://www.postgresql.org/docs/current/static/functions-array.html - /// - public class NpgsqlArraySequenceEqualTranslator : IMethodCallTranslator + public class NpgsqlCompositeExpressionFragmentTranslator : RelationalCompositeExpressionFragmentTranslator { - static readonly MethodInfo SequenceEqualMethodInfo = typeof(Enumerable).GetTypeInfo().GetDeclaredMethods(nameof(Enumerable.SequenceEqual)).Single(m => - m.IsGenericMethodDefinition && - m.GetParameters().Length == 2 - ); - - [CanBeNull] - public Expression Translate(MethodCallExpression methodCallExpression) + /// + /// The default expression fragment translators registered by the Npgsql provider. + /// + static readonly IExpressionFragmentTranslator[] ExpressionFragmentTranslators = { - var method = methodCallExpression.Method; - if (method.IsGenericMethod && - ReferenceEquals(method.GetGenericMethodDefinition(), SequenceEqualMethodInfo) && - methodCallExpression.Arguments.All(a => a.Type.IsArray)) - { - return Expression.MakeBinary(ExpressionType.Equal, - methodCallExpression.Arguments[0], - methodCallExpression.Arguments[1]); - } + new NpgsqlArrayFragmentTranslator() + }; - return null; + /// + public NpgsqlCompositeExpressionFragmentTranslator( + [NotNull] RelationalCompositeExpressionFragmentTranslatorDependencies dependencies) + : base(dependencies) + { + // ReSharper disable once DoNotCallOverridableMethodsInConstructor + AddTranslators(ExpressionFragmentTranslators); } + + /// + /// Adds additional dispatches to the translators list. + /// + /// The translators. + public new virtual void AddTranslators([NotNull] IEnumerable translators) + => base.AddTranslators(translators); } } diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlCompositeMethodCallTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlCompositeMethodCallTranslator.cs index 131ac6988..44d96dd9b 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlCompositeMethodCallTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlCompositeMethodCallTranslator.cs @@ -40,7 +40,6 @@ public class NpgsqlCompositeMethodCallTranslator : RelationalCompositeMethodCall /// [NotNull] [ItemNotNull] static readonly IMethodCallTranslator[] MethodCallTranslators = { - new NpgsqlArraySequenceEqualTranslator(), new NpgsqlConvertTranslator(), new NpgsqlStringSubstringTranslator(), new NpgsqlLikeTranslator(), @@ -60,7 +59,8 @@ public class NpgsqlCompositeMethodCallTranslator : RelationalCompositeMethodCall new NpgsqlRegexIsMatchTranslator(), new NpgsqlFullTextSearchMethodTranslator(), new NpgsqlRangeTranslator(), - new NpgsqlNetworkTranslator() + new NpgsqlNetworkTranslator(), + new NpgsqlArrayTranslator() }; /// diff --git a/src/EFCore.PG/Query/ExpressionVisitors/NpgsqlExistsToAnyRewritingExpressionVisitor.cs b/src/EFCore.PG/Query/ExpressionVisitors/NpgsqlExistsToAnyRewritingExpressionVisitor.cs new file mode 100644 index 000000000..45a46793c --- /dev/null +++ b/src/EFCore.PG/Query/ExpressionVisitors/NpgsqlExistsToAnyRewritingExpressionVisitor.cs @@ -0,0 +1,87 @@ +#region License + +// The PostgreSQL License +// +// Copyright (C) 2016 The Npgsql Development Team +// +// Permission to use, copy, modify, and distribute this software and its +// documentation for any purpose, without fee, and without a written +// agreement is hereby granted, provided that the above copyright notice +// and this paragraph and the following two paragraphs appear in all copies. +// +// IN NO EVENT SHALL THE NPGSQL DEVELOPMENT TEAM BE LIABLE TO ANY PARTY +// FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, +// INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS +// DOCUMENTATION, EVEN IF THE NPGSQL DEVELOPMENT TEAM HAS BEEN ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +// +// THE NPGSQL DEVELOPMENT TEAM SPECIFICALLY DISCLAIMS ANY WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY +// AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS +// ON AN "AS IS" BASIS, AND THE NPGSQL DEVELOPMENT TEAM HAS NO OBLIGATIONS +// TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. + +#endregion + +using System; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Extensions.Internal; +using Microsoft.EntityFrameworkCore.Query.ExpressionVisitors; +using Remotion.Linq; +using Remotion.Linq.Clauses; +using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Clauses.ResultOperators; +using Remotion.Linq.Parsing.ExpressionVisitors; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionVisitors +{ + /// + /// An expression rewriter for . + /// + public class NpgsqlExistsToAnyRewritingExpressionVisitor : ExpressionVisitorBase + { + /// + /// The generic for . + /// + [NotNull] static readonly MethodInfo Exists = + typeof(Array).GetRuntimeMethods().Single(x => x.Name == nameof(Array.Exists)); + + /// + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + if (!methodCallExpression.Method.MethodIsClosedFormOf(Exists)) + return methodCallExpression; + + if (!(methodCallExpression.Arguments[0] is Expression array)) + return methodCallExpression; + + if (!(methodCallExpression.Arguments[1] is LambdaExpression predicate)) + return methodCallExpression; + + var mainFromClause = + new MainFromClause( + "", + array.Type.GetElementType(), + array); + + var qsre = new QuerySourceReferenceExpression(mainFromClause); + var queryModel = new QueryModel(mainFromClause, new SelectClause(qsre)); + + var where = + new WhereClause( + ReplacingExpressionVisitor.Replace( + predicate.Parameters[0], + qsre, + predicate.Body)); + + queryModel.BodyClauses.Add(where); + queryModel.ResultOperators.Add(new AnyResultOperator()); + queryModel.ResultTypeOverride = typeof(bool); + + return new SubQueryExpression(queryModel); + } + } +} diff --git a/src/EFCore.PG/Query/ExpressionVisitors/NpgsqlSqlTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/ExpressionVisitors/NpgsqlSqlTranslatingExpressionVisitor.cs index c50263d19..23b1f215b 100644 --- a/src/EFCore.PG/Query/ExpressionVisitors/NpgsqlSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.PG/Query/ExpressionVisitors/NpgsqlSqlTranslatingExpressionVisitor.cs @@ -23,6 +23,9 @@ #endregion +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -39,8 +42,13 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionVisitors { + /// + /// The default relational LINQ translating expression visitor for Npgsql. + /// public class NpgsqlSqlTranslatingExpressionVisitor : SqlTranslatingExpressionVisitor { + #region MethodInfoFields + /// /// The for . /// @@ -71,11 +79,19 @@ public class NpgsqlSqlTranslatingExpressionVisitor : SqlTranslatingExpressionVis typeof(NpgsqlDbFunctionsExtensions) .GetRuntimeMethod(nameof(NpgsqlDbFunctionsExtensions.ILike), new[] { typeof(DbFunctions), typeof(string), typeof(string), typeof(string) }); + #endregion + /// - /// The query model visitor. + /// The current query model visitor. /// [NotNull] readonly RelationalQueryModelVisitor _queryModelVisitor; + /// + /// The current query compilation context. + /// + [NotNull] + RelationalQueryCompilationContext Context => _queryModelVisitor.QueryCompilationContext; + /// public NpgsqlSqlTranslatingExpressionVisitor( [NotNull] SqlTranslatingExpressionVisitorDependencies dependencies, @@ -86,135 +102,240 @@ public NpgsqlSqlTranslatingExpressionVisitor( : base(dependencies, queryModelVisitor, targetSelectExpression, topLevelPredicate, inProjection) => _queryModelVisitor = queryModelVisitor; - /// - protected override Expression VisitSubQuery(SubQueryExpression expression) - => base.VisitSubQuery(expression) ?? VisitLikeAnyAll(expression) ?? VisitEqualsAny(expression); + #region Overrides /// protected override Expression VisitBinary(BinaryExpression expression) + => expression.NodeType is ExpressionType.ArrayIndex && + IsSafeToVisit(expression, Context) + ? Expression.ArrayAccess( + Visit(expression.Left) ?? expression.Left, + Visit(expression.Right) ?? expression.Right) + : base.VisitBinary(expression); + + /// + [CanBeNull] + protected override Expression VisitExtension(Expression expression) { - if (expression.NodeType == ExpressionType.ArrayIndex) + switch (expression) { - var properties = MemberAccessBindingExpressionVisitor.GetPropertyPath( - expression.Left, _queryModelVisitor.QueryCompilationContext, out _); - if (properties.Count == 0) - return base.VisitBinary(expression); - var lastPropertyType = properties[properties.Count - 1].ClrType; - if (lastPropertyType.IsArray && lastPropertyType.GetArrayRank() == 1) - { - var left = Visit(expression.Left); - var right = Visit(expression.Right); - - return left != null && right != null - ? Expression.MakeBinary(ExpressionType.ArrayIndex, left, right) - : null; - } - } + case SqlFunctionExpression e: + return + new SqlFunctionExpression( + e.FunctionName, + e.Type, + e.Schema, + e.Arguments.Select(x => Visit(x) ?? x)); - return base.VisitBinary(expression); + case PgFunctionExpression e: + return + new PgFunctionExpression( + e.Instance, + e.FunctionName, + e.Schema, + e.Type, + e.PositionalArguments.Select(x => Visit(x) ?? x), + e.NamedArguments.ToDictionary(x => x.Key, x => Visit(x.Value) ?? x.Value)); + + case CustomBinaryExpression e: + return + new CustomBinaryExpression( + Visit(e.Left) ?? e.Left, + Visit(e.Right) ?? e.Right, + e.Operator, + e.Type); + + case CustomUnaryExpression e: + return + new CustomUnaryExpression( + Visit(e.Operand) ?? e.Operand, + e.Operator, + e.Type, + e.Postfix); + + default: + return base.VisitExtension(expression); + } } - /// - /// Visits a and attempts to translate a '= ANY' expression. - /// - /// The expression to visit. - /// - /// An '= ANY' expression or null. - /// + /// [CanBeNull] - protected virtual Expression VisitEqualsAny([NotNull] SubQueryExpression expression) - { - var subQueryModel = expression.QueryModel; - var fromExpression = subQueryModel.MainFromClause.FromExpression; + protected override Expression VisitSubQuery(SubQueryExpression expression) + => base.VisitSubQuery(expression) ?? VisitArraySubQuery(expression); - var properties = MemberAccessBindingExpressionVisitor.GetPropertyPath( - fromExpression, _queryModelVisitor.QueryCompilationContext, out _); + /// + protected override Expression VisitUnary(UnaryExpression expression) + => Visit(expression.Operand) is Expression operand + ? Expression.MakeUnary(expression.NodeType, operand, expression.Type) + : base.VisitUnary(expression); - if (properties.Count == 0) - return null; - var lastPropertyType = properties[properties.Count - 1].ClrType; - if (lastPropertyType.IsArray && lastPropertyType.GetArrayRank() == 1 && subQueryModel.ResultOperators.Count > 0) - { - // Translate someArray.Length - if (subQueryModel.ResultOperators.First() is CountResultOperator) - return Expression.ArrayLength(Visit(fromExpression)); - - // Translate someArray.Contains(someValue) - if (subQueryModel.ResultOperators.First() is ContainsResultOperator contains) - { - var containsItem = Visit(contains.Item); - if (containsItem != null) - return new ArrayAnyAllExpression(ArrayComparisonType.ANY, "=", containsItem, Visit(fromExpression)); - } - } + #endregion - return null; - } + #region ArraySubQueries /// - /// Visits a and attempts to translate a LIKE/ILIKE ANY/ALL expression. + /// Visits an array-based subquery. /// - /// The expression to visit. + /// The subquery expression. /// - /// A 'LIKE ANY', 'LIKE ALL', 'ILIKE ANY', or 'ILIKE ALL' expression or null. + /// An expression or null. /// [CanBeNull] - protected virtual Expression VisitLikeAnyAll([NotNull] SubQueryExpression expression) + protected virtual Expression VisitArraySubQuery([NotNull] SubQueryExpression expression) { var queryModel = expression.QueryModel; - var results = queryModel.ResultOperators; + var from = queryModel.MainFromClause.FromExpression; var body = queryModel.BodyClauses; + var results = queryModel.ResultOperators; + + // TODO: what causes the from expression to not be visitable? + // Only handle subqueries when the from expression is visitable. + if (!(Visit(from) is Expression array)) + return null; + // Only handle types mapped to PostgreSQL arrays. + if (!IsArrayOrList(array.Type)) + return null; + + // TODO: when is there more than one result operator? + // Only handle singular result operators. if (results.Count != 1) return null; - ArrayComparisonType comparisonType; - MethodCallExpression call; switch (results[0]) { case AnyResultOperator _: - comparisonType = ArrayComparisonType.ANY; - call = - body.Count == 1 && - body[0] is WhereClause whereClause && - whereClause.Predicate is MethodCallExpression methocCall - ? methocCall - : null; - break; - - case AllResultOperator allResult: - comparisonType = ArrayComparisonType.ALL; - call = allResult.Predicate as MethodCallExpression; - break; + return VisitArrayAny(array, body); + + case AllResultOperator allResultOperator: + return VisitArrayAll(array, allResultOperator); + + case ContainsResultOperator contains: + return new ArrayAnyAllExpression(ArrayComparisonType.ANY, "=", Visit(contains.Item) ?? contains.Item, array); default: return null; } + } + + /// + /// Visits an array-based ANY comparison: {operand} {operator} ANY ({array}). + /// + /// The array expression. + /// The body clauses. + /// + /// An expression or null. + /// + [CanBeNull] + protected virtual Expression VisitArrayAny(Expression array, [NotNull] ObservableCollection body) + { + var predicate = + body.Count == 1 && + body[0] is WhereClause whereClause + ? whereClause.Predicate + : null; - if (call is null) + if (predicate is null) return null; - var source = queryModel.MainFromClause.FromExpression; + return + VisitArrayLike(array, predicate, ArrayComparisonType.ANY) ?? + VisitArrayContains(array, predicate, ArrayComparisonType.ANY); + } + + /// + /// Visits an array-based ALL comparison: {operand} {operator} ALL ({array}). + /// + /// The array expression. + /// The result operator. + /// + /// An expression or null. + /// + [CanBeNull] + protected virtual Expression VisitArrayAll([NotNull] Expression array, [NotNull] AllResultOperator allResultOperator) + => VisitArrayLike(array, allResultOperator.Predicate, ArrayComparisonType.ALL) ?? + VisitArrayContains(array, allResultOperator.Predicate, ArrayComparisonType.ALL); + + /// + /// Visits an array-based comparison for an LIKE or ILIKE expression: {operand} {LIKE|ILIKE} {ANY|ALL} ({array}). + /// + /// The array expression. + /// The method call expression. + /// The array comparison type. + /// + /// An expression or null. + /// + [CanBeNull] + protected virtual Expression VisitArrayLike([NotNull] Expression array, [NotNull] Expression predicate, ArrayComparisonType comparisonType) + { + if (!(predicate is MethodCallExpression call)) + return null; + + var operand = Visit(call.Arguments[1]) ?? call.Arguments[1]; - // ReSharper disable AssignNullToNotNullAttribute switch (call.Method) { case MethodInfo m when m == Like2MethodInfo: - return new ArrayAnyAllExpression(comparisonType, "LIKE", Visit(call.Arguments[1]), Visit(source)); + return new ArrayAnyAllExpression(comparisonType, "LIKE", operand, array); case MethodInfo m when m == Like3MethodInfo: - return new ArrayAnyAllExpression(comparisonType, "LIKE", Visit(call.Arguments[1]), Visit(source)); + return new ArrayAnyAllExpression(comparisonType, "LIKE", operand, array); case MethodInfo m when m == ILike2MethodInfo: - return new ArrayAnyAllExpression(comparisonType, "ILIKE", Visit(call.Arguments[1]), Visit(source)); + return new ArrayAnyAllExpression(comparisonType, "ILIKE", operand, array); case MethodInfo m when m == ILike3MethodInfo: - return new ArrayAnyAllExpression(comparisonType, "ILIKE", Visit(call.Arguments[1]), Visit(source)); + return new ArrayAnyAllExpression(comparisonType, "ILIKE", operand, array); default: return null; } - // ReSharper restore AssignNullToNotNullAttribute } + + /// + /// Visits an array-based comparison for a containment expression: {operand} = {ANY|ALL} ({array}). + /// + /// The array expression. + /// The method call expression. + /// The array comparison type. + /// + /// An expression or null. + /// + [CanBeNull] + protected virtual Expression VisitArrayContains([NotNull] Expression array, [NotNull] Expression predicate, ArrayComparisonType comparisonType) + { + if (!(Visit(predicate) is ArrayAnyAllExpression expression) || !expression.IsContainsExpression) + return null; + + var inner = Visit(expression.Array) ?? expression.Array; + + return new CustomBinaryExpression(array, inner, comparisonType == ArrayComparisonType.ALL ? "<@" : "&&", typeof(bool)); + } + + #endregion + + #region Helpers + + /// + /// Tests if the type is an array or a . + /// + /// The type to test. + /// + /// True if is an array or a ; otherwise, false. + /// + static bool IsArrayOrList([NotNull] Type type) => type.IsArray || type.IsGenericType && typeof(List<>) == type.GetGenericTypeDefinition(); + + /// + /// True if the expression is safe to visitat this stage. + /// + /// The expression to check + /// The context to use. + /// + /// True to visit this expression; otherwise false. + /// + static bool IsSafeToVisit(BinaryExpression expression, RelationalQueryCompilationContext context) + => MemberAccessBindingExpressionVisitor.GetPropertyPath(expression.Left, context, out _).Count != 0; + + #endregion } } diff --git a/src/EFCore.PG/Query/Expressions/Internal/ArrayAnyAllExpression.cs b/src/EFCore.PG/Query/Expressions/Internal/ArrayAnyAllExpression.cs index 04054d3a4..4cbe035d8 100644 --- a/src/EFCore.PG/Query/Expressions/Internal/ArrayAnyAllExpression.cs +++ b/src/EFCore.PG/Query/Expressions/Internal/ArrayAnyAllExpression.cs @@ -26,7 +26,7 @@ using System; using System.Linq.Expressions; using JetBrains.Annotations; -using Npgsql.EntityFrameworkCore.PostgreSQL.Query.Sql.Internal; +using Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionVisitors; using Npgsql.EntityFrameworkCore.PostgreSQL.Utilities; namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal @@ -43,24 +43,27 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Expressions.Internal public class ArrayAnyAllExpression : Expression, IEquatable { /// - public override ExpressionType NodeType { get; } = ExpressionType.Extension; + public override ExpressionType NodeType => ExpressionType.Extension; /// - public override Type Type { get; } = typeof(bool); + public override Type Type => typeof(bool); /// /// The value to test against the . /// + [NotNull] public virtual Expression Operand { get; } /// /// The array of values or patterns to test for the . /// + [NotNull] public virtual Expression Array { get; } /// /// The operator. /// + [NotNull] public virtual string Operator { get; } /// @@ -68,6 +71,11 @@ public class ArrayAnyAllExpression : Expression, IEquatable public virtual ArrayComparisonType ArrayComparisonType { get; } + /// + /// True if this instance represents: {operand} = ANY ({array})". + /// + public bool IsContainsExpression => ArrayComparisonType is ArrayComparisonType.ANY && Operator is "="; + /// /// Constructs a . /// @@ -94,23 +102,27 @@ public ArrayAnyAllExpression( /// protected override Expression Accept(ExpressionVisitor visitor) - => visitor is NpgsqlQuerySqlGenerator npsgqlGenerator - ? npsgqlGenerator.VisitArrayAnyAll(this) - : base.Accept(visitor); + { + switch (visitor) + { + case NpgsqlSqlTranslatingExpressionVisitor npgsqlVisitor: + return VisitChildren(npgsqlVisitor); + + default: + return base.Accept(visitor) ?? this; + } + } /// protected override Expression VisitChildren(ExpressionVisitor visitor) { - if (!(visitor.Visit(Operand) is Expression operand)) - throw new ArgumentException($"The {nameof(operand)} of a {nameof(ArrayAnyAllExpression)} cannot be null."); - - if (!(visitor.Visit(Array) is Expression collection)) - throw new ArgumentException($"The {nameof(collection)} of a {nameof(ArrayAnyAllExpression)} cannot be null."); + var operand = visitor.Visit(Operand) ?? Operand; + var array = visitor.Visit(Array) ?? Array; return - operand == Operand && collection == Array - ? this - : new ArrayAnyAllExpression(ArrayComparisonType, Operator, operand, collection); + operand != Operand || array != Array + ? new ArrayAnyAllExpression(ArrayComparisonType, Operator, operand, array) + : this; } /// diff --git a/src/EFCore.PG/Query/Internal/NpgsqlQueryOptimizer.cs b/src/EFCore.PG/Query/Internal/NpgsqlQueryOptimizer.cs new file mode 100644 index 000000000..a10d1425f --- /dev/null +++ b/src/EFCore.PG/Query/Internal/NpgsqlQueryOptimizer.cs @@ -0,0 +1,58 @@ +#region License + +// The PostgreSQL License +// +// Copyright (C) 2016 The Npgsql Development Team +// +// Permission to use, copy, modify, and distribute this software and its +// documentation for any purpose, without fee, and without a written +// agreement is hereby granted, provided that the above copyright notice +// and this paragraph and the following two paragraphs appear in all copies. +// +// IN NO EVENT SHALL THE NPGSQL DEVELOPMENT TEAM BE LIABLE TO ANY PARTY +// FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, +// INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS +// DOCUMENTATION, EVEN IF THE NPGSQL DEVELOPMENT TEAM HAS BEEN ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +// +// THE NPGSQL DEVELOPMENT TEAM SPECIFICALLY DISCLAIMS ANY WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY +// AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS +// ON AN "AS IS" BASIS, AND THE NPGSQL DEVELOPMENT TEAM HAS NO OBLIGATIONS +// TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. + +#endregion + +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.Internal; +using Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionVisitors; +using Remotion.Linq; + +namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.Internal +{ + /// + /// The default relational LINQ query optimizer for Npgsql. + /// + public class NpgsqlQueryOptimizer : QueryOptimizer + { + /// + /// The default expression visitors registered by the Npgsql provider. + /// + static readonly ExpressionVisitor[] ExpressionVisitors = + { + new NpgsqlExistsToAnyRewritingExpressionVisitor() + }; + + /// + public override void Optimize(QueryCompilationContext queryCompilationContext, QueryModel queryModel) + { + base.Optimize(queryCompilationContext, queryModel); + + for (int i = 0; i < ExpressionVisitors.Length; i++) + { + queryModel.TransformExpressions(ExpressionVisitors[i].Visit); + } + } + } +} diff --git a/src/EFCore.PG/Query/Sql/Internal/NpgsqlQuerySqlGenerator.cs b/src/EFCore.PG/Query/Sql/Internal/NpgsqlQuerySqlGenerator.cs index f9728bfe2..1ff54ce37 100644 --- a/src/EFCore.PG/Query/Sql/Internal/NpgsqlQuerySqlGenerator.cs +++ b/src/EFCore.PG/Query/Sql/Internal/NpgsqlQuerySqlGenerator.cs @@ -24,7 +24,6 @@ #endregion using System; -using System.Diagnostics; using System.Linq.Expressions; using System.Text.RegularExpressions; using JetBrains.Annotations; @@ -41,24 +40,29 @@ public class NpgsqlQuerySqlGenerator : DefaultQuerySqlGenerator { readonly bool _reverseNullOrderingEnabled; + /// protected override string TypedTrueLiteral { get; } = "TRUE::bool"; + /// protected override string TypedFalseLiteral { get; } = "FALSE::bool"; + /// public NpgsqlQuerySqlGenerator( [NotNull] QuerySqlGeneratorDependencies dependencies, [NotNull] SelectExpression selectExpression, bool reverseNullOrderingEnabled) : base(dependencies, selectExpression) - { - _reverseNullOrderingEnabled = reverseNullOrderingEnabled; - } + => _reverseNullOrderingEnabled = reverseNullOrderingEnabled; + + #region Generators + /// protected override void GenerateTop(SelectExpression selectExpression) { // No TOP() in PostgreSQL, see GenerateLimitOffset } + /// protected override void GenerateLimitOffset(SelectExpression selectExpression) { Check.NotNull(selectExpression, nameof(selectExpression)); @@ -81,6 +85,82 @@ protected override void GenerateLimitOffset(SelectExpression selectExpression) } } + /// + /// PostgreSQL array indexing is 1-based. If the index happens to be a constant, + /// just increment it. Otherwise, append a +1 in the SQL. + /// + protected virtual Expression GenerateOneBasedIndexExpression(Expression expression) + => expression is ConstantExpression constantExpression + ? Expression.Constant(Convert.ToInt32(constantExpression.Value) + 1) + : (Expression)Expression.Add(expression, Expression.Constant(1)); + + /// + protected override string GenerateOperator(Expression expression) + { + switch (expression.NodeType) + { + case ExpressionType.Add: + if (expression.Type == typeof(string)) + return " || "; + goto default; + case ExpressionType.And: + if (expression.Type == typeof(bool)) + return " AND "; + goto default; + case ExpressionType.Or: + if (expression.Type == typeof(bool)) + return " OR "; + goto default; + default: + return base.GenerateOperator(expression); + } + } + + /// + protected override void GenerateOrdering(Ordering ordering) + { + base.GenerateOrdering(ordering); + if (_reverseNullOrderingEnabled) + Sql.Append( + ordering.OrderingDirection == OrderingDirection.Asc + ? " NULLS FIRST" + : " NULLS LAST"); + } + + #endregion + + #region Visitors + + /// + protected override Expression VisitExtension(Expression expression) + { + switch (expression) + { + case ArrayAnyAllExpression arrayAnyAllExpression: + return VisitArrayAnyAll(arrayAnyAllExpression); + + default: + return base.VisitExtension(expression); + } + } + + /// + /// Produces expressions like: 1 = ANY ('{0,1,2}') or 'cat' LIKE ANY ('{a%,b%,c%}'). + /// + protected virtual Expression VisitArrayAnyAll(ArrayAnyAllExpression arrayAnyAllExpression) + { + Visit(arrayAnyAllExpression.Operand); + Sql.Append(' '); + Sql.Append(arrayAnyAllExpression.Operator); + Sql.Append(' '); + Sql.Append(arrayAnyAllExpression.ArrayComparisonType.ToString()); + Sql.Append(" ("); + Visit(arrayAnyAllExpression.Array); + Sql.Append(')'); + return arrayAnyAllExpression; + } + + /// public override Expression VisitSqlFunction(SqlFunctionExpression sqlFunctionExpression) { var expr = base.VisitSqlFunction(sqlFunctionExpression); @@ -108,6 +188,7 @@ public override Expression VisitSqlFunction(SqlFunctionExpression sqlFunctionExp return expr; } + /// protected override Expression VisitBinary(BinaryExpression expression) { switch (expression.NodeType) @@ -126,17 +207,15 @@ protected override Expression VisitBinary(BinaryExpression expression) return exp; } - break; + goto default; } - case ExpressionType.ArrayIndex: - VisitArrayIndex(expression); - return expression; + default: + return base.VisitBinary(expression); } - - return base.VisitBinary(expression); } + /// protected override Expression VisitUnary(UnaryExpression expression) { if (expression.NodeType == ExpressionType.ArrayLength) @@ -148,68 +227,67 @@ protected override Expression VisitUnary(UnaryExpression expression) return base.VisitUnary(expression); } - protected virtual void VisitArrayIndex([NotNull] BinaryExpression expression) + /// + protected override Expression VisitIndex(IndexExpression expression) { - Debug.Assert(expression.NodeType == ExpressionType.ArrayIndex); - - if (expression.Left.Type == typeof(byte[])) + // bytea cannot be subscripted + if (expression.Object.Type == typeof(byte[])) + return + VisitSqlFunction( + new SqlFunctionExpression( + "get_byte", + typeof(byte), + new[] { expression.Object, expression.Arguments[0] })); + + // text cannot be subscripted + if (expression.Object.Type == typeof(string)) + return + VisitSqlFunction( + new SqlFunctionExpression( + "ascii", + typeof(int), + new[] + { + new SqlFunctionExpression( + "substr", + typeof(char), + new[] + { + expression.Object, + GenerateOneBasedIndexExpression(expression.Arguments[0]), + Expression.Constant(1) + }) + })); + +// // TODO: discussion: https://github.com/npgsql/Npgsql.EntityFrameworkCore.PostgreSQL/issues/450 +// VisitSqlFunction( +// new SqlFunctionExpression( +// "get_byte", +// typeof(int), +// new[] +// { +// new CustomUnaryExpression(expression.Object, "::bytea", typeof(int), true), +// expression.Arguments[0] +// })); + + Visit(expression.Object); + for (int i = 0; i < expression.Arguments.Count; i++) { - // bytea cannot be subscripted, but there's get_byte - VisitSqlFunction(new SqlFunctionExpression("get_byte", typeof(byte), - new[] { expression.Left, expression.Right })); - return; + Sql.Append('['); + Visit(GenerateOneBasedIndexExpression(expression.Arguments[i])); + Sql.Append(']'); } - if (expression.Left.Type == typeof(string)) - { - // text cannot be subscripted, use substr - // PostgreSQL substr() is 1-based. - - VisitSqlFunction(new SqlFunctionExpression("substr", typeof(char), - new[] { expression.Left, expression.Right, Expression.Constant(1) })); - return; - } - - // Regular array from here - Visit(expression.Left); - Sql.Append('['); - Visit(GenerateOneBasedIndexExpression(expression.Right)); - Sql.Append(']'); - } - - /// - /// Produces expressions like: 1 = ANY ('{0,1,2}') or 'cat' LIKE ANY ('{a%,b%,c%}'). - /// - public Expression VisitArrayAnyAll(ArrayAnyAllExpression arrayAnyAllExpression) - { - Visit(arrayAnyAllExpression.Operand); - Sql.Append(' '); - Sql.Append(arrayAnyAllExpression.Operator); - Sql.Append(' '); - Sql.Append(arrayAnyAllExpression.ArrayComparisonType.ToString()); - Sql.Append(" ("); - Visit(arrayAnyAllExpression.Array); - Sql.Append(')'); - return arrayAnyAllExpression; + return expression; } - /// - /// PostgreSQL array indexing is 1-based. If the index happens to be a constant, - /// just increment it. Otherwise, append a +1 in the SQL. - /// - static Expression GenerateOneBasedIndexExpression(Expression expression) - => expression is ConstantExpression constantExpression - ? Expression.Constant(Convert.ToInt32(constantExpression.Value) + 1) - : (Expression)Expression.Add(expression, Expression.Constant(1)); - /// /// See: http://www.postgresql.org/docs/current/static/functions-matching.html /// - public Expression VisitRegexMatch([NotNull] RegexMatchExpression regexMatchExpression) + public virtual Expression VisitRegexMatch([NotNull] RegexMatchExpression regexMatchExpression) { Check.NotNull(regexMatchExpression, nameof(regexMatchExpression)); var options = regexMatchExpression.Options; - Visit(regexMatchExpression.Match); Sql.Append(" ~ "); @@ -223,33 +301,27 @@ public Expression VisitRegexMatch([NotNull] RegexMatchExpression regexMatchExpre Sql.Append("('(?"); if (options.HasFlag(RegexOptions.IgnoreCase)) Sql.Append('i'); - if (options.HasFlag(RegexOptions.Multiline)) Sql.Append('n'); else if (!options.HasFlag(RegexOptions.Singleline)) + // In .NET's default mode, . doesn't match newlines but PostgreSQL it does. Sql.Append('p'); - if (options.HasFlag(RegexOptions.IgnorePatternWhitespace)) Sql.Append('x'); - Sql.Append(")' || "); Visit(regexMatchExpression.Pattern); Sql.Append(')'); - return regexMatchExpression; } - public Expression VisitAtTimeZone([NotNull] AtTimeZoneExpression atTimeZoneExpression) + public virtual Expression VisitAtTimeZone([NotNull] AtTimeZoneExpression atTimeZoneExpression) { Check.NotNull(atTimeZoneExpression, nameof(atTimeZoneExpression)); - Visit(atTimeZoneExpression.TimestampExpression); - Sql.Append(" AT TIME ZONE '"); Sql.Append(atTimeZoneExpression.TimeZone); Sql.Append('\''); - return atTimeZoneExpression; } @@ -259,13 +331,10 @@ public virtual Expression VisitILike(ILikeExpression iLikeExpression) //var parentTypeMapping = _typeMapping; //_typeMapping = InferTypeMappingFromColumn(iLikeExpression.Match) ?? parentTypeMapping; - Visit(iLikeExpression.Match); - Sql.Append(" ILIKE "); Visit(iLikeExpression.Pattern); - if (iLikeExpression.EscapeChar != null) { Sql.Append(" ESCAPE "); @@ -273,67 +342,27 @@ public virtual Expression VisitILike(ILikeExpression iLikeExpression) } //_typeMapping = parentTypeMapping; - return iLikeExpression; } - public Expression VisitExplicitStoreTypeCast([NotNull] ExplicitStoreTypeCastExpression castExpression) + public virtual Expression VisitExplicitStoreTypeCast([NotNull] ExplicitStoreTypeCastExpression castExpression) { Sql.Append("CAST("); //var parentTypeMapping = _typeMapping; //_typeMapping = InferTypeMappingFromColumn(castExpression.Operand); - Visit(castExpression.Operand); - Sql.Append(" AS ") .Append(castExpression.StoreType) .Append(")"); //_typeMapping = parentTypeMapping; - return castExpression; } - protected override string GenerateOperator(Expression expression) - { - switch (expression.NodeType) - { - case ExpressionType.Add: - if (expression.Type == typeof(string)) - return " || "; - goto default; - - case ExpressionType.And: - if (expression.Type == typeof(bool)) - return " AND "; - goto default; - - case ExpressionType.Or: - if (expression.Type == typeof(bool)) - return " OR "; - goto default; - - default: - return base.GenerateOperator(expression); - } - } - - protected override void GenerateOrdering(Ordering ordering) - { - base.GenerateOrdering(ordering); - - if (_reverseNullOrderingEnabled) - Sql.Append( - ordering.OrderingDirection == OrderingDirection.Asc - ? " NULLS FIRST" - : " NULLS LAST"); - } - public virtual Expression VisitCustomBinary(CustomBinaryExpression expression) { Check.NotNull(expression, nameof(expression)); - Sql.Append('('); Visit(expression.Left); Sql.Append(' '); @@ -341,14 +370,12 @@ public virtual Expression VisitCustomBinary(CustomBinaryExpression expression) Sql.Append(' '); Visit(expression.Right); Sql.Append(')'); - return expression; } public virtual Expression VisitCustomUnary(CustomUnaryExpression expression) { Check.NotNull(expression, nameof(expression)); - if (expression.Postfix) { Visit(expression.Operand); @@ -366,22 +393,18 @@ public virtual Expression VisitCustomUnary(CustomUnaryExpression expression) public virtual Expression VisitPgFunction(PgFunctionExpression e) { //var parentTypeMapping = _typeMapping; - //_typeMapping = null; var wroteSchema = false; - if (e.Instance != null) { Visit(e.Instance); - Sql.Append("."); } else if (!string.IsNullOrWhiteSpace(e.Schema)) { Sql.Append(SqlGenerator.DelimitIdentifier(e.Schema)) .Append("."); - wroteSchema = true; } @@ -389,32 +412,28 @@ public virtual Expression VisitPgFunction(PgFunctionExpression e) wroteSchema ? SqlGenerator.DelimitIdentifier(e.FunctionName) : e.FunctionName); - Sql.Append("("); //_typeMapping = null; - GenerateList(e.PositionalArguments); - bool hasArguments = e.PositionalArguments.Count > 0 && e.NamedArguments.Count > 0; - foreach (var kv in e.NamedArguments) { if (hasArguments) Sql.Append(", "); else hasArguments = true; - Sql.Append(kv.Key) .Append(" => "); - Visit(kv.Value); } Sql.Append(")"); - //_typeMapping = parentTypeMapping; + //_typeMapping = parentTypeMapping; return e; } + + #endregion } } diff --git a/src/EFCore.PG/Storage/Internal/Mapping/NpgsqlListTypeMapping.cs b/src/EFCore.PG/Storage/Internal/Mapping/NpgsqlListTypeMapping.cs index 1f024448f..866bdeaec 100644 --- a/src/EFCore.PG/Storage/Internal/Mapping/NpgsqlListTypeMapping.cs +++ b/src/EFCore.PG/Storage/Internal/Mapping/NpgsqlListTypeMapping.cs @@ -1,4 +1,5 @@ #region License + // The PostgreSQL License // // Copyright (C) 2016 The Npgsql Development Team @@ -19,9 +20,11 @@ // AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS // ON AN "AS IS" BASIS, AND THE NPGSQL DEVELOPMENT TEAM HAS NO OBLIGATIONS // TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. + #endregion using System; +using System.Collections; using System.Text; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Storage.ValueConversion; @@ -36,49 +39,55 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping /// public class NpgsqlListTypeMapping : RelationalTypeMapping { + /// + /// The CLR type of the list items. + /// public RelationalTypeMapping ElementMapping { get; } /// - /// Creates the default array mapping (i.e. for the single-dimensional CLR array type) + /// Creates the default list mapping. /// public NpgsqlListTypeMapping(RelationalTypeMapping elementMapping, Type listType) - : this(elementMapping.StoreType + "[]", elementMapping, listType) - {} + : this(elementMapping.StoreType + "[]", elementMapping, listType) {} + /// NpgsqlListTypeMapping(string storeType, RelationalTypeMapping elementMapping, Type listType) - : base(new RelationalTypeMappingParameters( - new CoreTypeMappingParameters(listType, null, CreateComparer(elementMapping, listType)), storeType - )) - { - ElementMapping = elementMapping; - } + : base( + new RelationalTypeMappingParameters( + new CoreTypeMappingParameters(listType, null, CreateComparer(elementMapping, listType)), storeType)) + => ElementMapping = elementMapping; + /// protected NpgsqlListTypeMapping(RelationalTypeMappingParameters parameters, RelationalTypeMapping elementMapping) - : base(parameters) {} + : base(parameters) + => ElementMapping = elementMapping; + /// public override RelationalTypeMapping Clone(string storeType, int? size) => new NpgsqlListTypeMapping(StoreType, ElementMapping, ClrType); + /// public override CoreTypeMapping Clone(ValueConverter converter) => new NpgsqlListTypeMapping(Parameters.WithComposedConverter(converter), ElementMapping); + /// protected override string GenerateNonNullSqlLiteral(object value) { - // TODO: Duplicated from NpgsqlArrayTypeMapping - var arr = (Array)value; + var list = (IList)value; - if (arr.Rank != 1) + if (list.GetType().GenericTypeArguments[0] != ElementMapping.ClrType) throw new NotSupportedException("Multidimensional array literals aren't supported"); var sb = new StringBuilder(); sb.Append("ARRAY["); - for (var i = 0; i < arr.Length; i++) + for (var i = 0; i < list.Count; i++) { - sb.Append(ElementMapping.GenerateSqlLiteral(arr.GetValue(i))); - if (i < arr.Length - 1) - sb.Append(","); + if (i > 0) + sb.Append(','); + sb.Append(ElementMapping.GenerateSqlLiteral(list[i])); } - sb.Append("]"); + + sb.Append(']'); return sb.ToString(); } @@ -148,7 +157,7 @@ static List Snapshot(List source, ValueComparer elementComp class SingleDimComparerWithIEquatable : ValueComparer> where TElem : IEquatable { - public SingleDimComparerWithIEquatable(): base( + public SingleDimComparerWithIEquatable() : base( (a, b) => Compare(a, b), o => o.GetHashCode(), // TODO: Need to get hash code of elements... source => DoSnapshot(source)) {} @@ -171,6 +180,7 @@ static bool Compare(List a, List b) continue; return false; } + if (!elem1.Equals(elem2)) return false; } @@ -215,6 +225,7 @@ static bool Compare(List a, List b) continue; return false; } + if (!elem1.Equals(elem2)) return false; } diff --git a/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs b/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs index 2412ae223..8c37eb4e5 100644 --- a/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/ArrayQueryTest.cs @@ -10,10 +10,25 @@ namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query { - public class ArrayQueryTest : IClassFixture + public class ArrayQueryTest : IClassFixture { + #region ArrayTests + + #region Roundtrip + + [Fact] + public void Array_Roundtrip() + { + using (var ctx = CreateContext()) + { + var x = ctx.SomeEntities.Single(e => e.Id == 1); + Assert.Equal(new[] { 3, 4 }, x.SomeArray); + Assert.Equal(new List { 3, 4 }, x.SomeList); + } + } + [Fact] - public void Roundtrip() + public void List_Roundtrip() { using (var ctx = CreateContext()) { @@ -23,65 +38,310 @@ public void Roundtrip() } } + #endregion + + #region Indexers + [Fact] - public void Index_with_constant() + public void Array_Index_with_constant() { using (var ctx = CreateContext()) { var actual = ctx.SomeEntities.Where(e => e.SomeArray[0] == 3).ToList(); Assert.Equal(1, actual.Count); - AssertContainsInSql(@"WHERE (e.""SomeArray""[1]) = 3"); + AssertContainsInSql(@"WHERE e.""SomeArray""[1] = 3"); + } + } + + [Fact] + public void List_Index_with_constant() + { + using (var ctx = CreateContext()) + { + var actual = ctx.SomeEntities.Where(e => e.SomeList[0] == 3).ToList(); + Assert.Equal(1, actual.Count); + AssertContainsInSql(@"WHERE e.""SomeList""[1] = 3"); + } + } + + [Fact] + public void Array_Index_bytea_with_constant() + { + using (var ctx = CreateContext()) + { + var actual = ctx.SomeEntities.Where(e => e.SomeBytea[0] == 3).ToList(); + Assert.Equal(1, actual.Count); + AssertContainsInSql(@"WHERE get_byte(e.""SomeBytea"", 0) = 3"); + } + } + + [Fact] + public void String_Index_text_with_constant_char_as_int() + { + using (var ctx = CreateContext()) + { + var actual = ctx.SomeEntities.Where(e => e.SomeString[0] == 'T').ToList(); + Assert.Equal(1, actual.Count); + AssertContainsInSql(@"WHERE ascii(substr(e.""SomeString"", 1, 1)) = 84"); + } + } + + [Fact] + public void String_Index_text_with_constant_string() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(e => e.SomeString[0].ToString() == "T").ToList(); + AssertContainsInSql(@"WHERE CAST(ascii(substr(e.""SomeString"", 1, 1)) AS text) = 'T'"); + } + } + + [Fact] + public void Array_ElementAt_with_constant() + { + using (var ctx = CreateContext()) + { + var actual = ctx.SomeEntities.Where(e => e.SomeArray.ElementAt(0) == 3).ToList(); + Assert.Equal(1, actual.Count); + AssertContainsInSql(@"WHERE e.""SomeArray""[1] = 3"); + } + } + + [Fact] + public void List_ElementAt_with_constant() + { + using (var ctx = CreateContext()) + { + var actual = ctx.SomeEntities.Where(e => e.SomeList.ElementAt(0) == 3).ToList(); + Assert.Equal(1, actual.Count); + AssertContainsInSql(@"WHERE e.""SomeList""[1] = 3"); + } + } + + [Fact] + public void Array_ElementAt_bytea_with_constant() + { + using (var ctx = CreateContext()) + { + var actual = ctx.SomeEntities.Where(e => e.SomeBytea.ElementAt(0) == 3).ToList(); + Assert.Equal(1, actual.Count); + AssertContainsInSql(@"WHERE get_byte(e.""SomeBytea"", 0) = 3"); + } + } + + [Fact] + public void String_ElementAt_text_with_constant_char_as_int() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(e => e.SomeString.ElementAt(0) == 'T').ToList(); + AssertContainsInSql(@"WHERE ascii(substr(e.""SomeString"", 1, 1)) = 84"); + } + } + + [Fact] + public void String_ElementAt_text_with_constant_string() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(e => e.SomeString.ElementAt(0).ToString() == "T").ToList(); + AssertContainsInSql(@"WHERE CAST(ascii(substr(e.""SomeString"", 1, 1)) AS text) = 'T'"); } } [Fact] - public void Index_with_non_constant() + public void Array_Index_with_non_constant() { using (var ctx = CreateContext()) { + // ReSharper disable once ConvertToConstant.Local var x = 0; var actual = ctx.SomeEntities.Where(e => e.SomeArray[x] == 3).ToList(); Assert.Equal(1, actual.Count); - AssertContainsInSql(@"WHERE (e.""SomeArray""[@__x_0 + 1]) = 3"); + AssertContainsInSql(@"WHERE e.""SomeArray""[@__x_0 + 1] = 3"); } } [Fact] - public void Index_bytea_with_constant() + public void List_Index_with_non_constant() { using (var ctx = CreateContext()) { - var actual = ctx.SomeEntities.Where(e => e.SomeBytea[0] == 3).ToList(); + // ReSharper disable once ConvertToConstant.Local + var x = 0; + var actual = ctx.SomeEntities.Where(e => e.SomeList[x] == 3).ToList(); + Assert.Equal(1, actual.Count); + AssertContainsInSql(@"WHERE e.""SomeList""[@__x_0 + 1] = 3"); + } + } + + [Fact] + public void Array_Index_bytea_with_non_constant() + { + using (var ctx = CreateContext()) + { + // ReSharper disable once ConvertToConstant.Local + var x = 0; + var actual = ctx.SomeEntities.Where(e => e.SomeBytea[x] == 3).ToList(); + Assert.Equal(1, actual.Count); + AssertContainsInSql(@"WHERE get_byte(e.""SomeBytea"", @__x_0) = 3"); + } + } + + [Fact] + public void String_Index_text_with_non_constant_char_as_int() + { + using (var ctx = CreateContext()) + { + // ReSharper disable once ConvertToConstant.Local + var x = 0; + var actual = ctx.SomeEntities.Where(e => e.SomeString[x] == 'T').ToList(); + Assert.Equal(1, actual.Count); + AssertContainsInSql(@"WHERE ascii(substr(e.""SomeString"", @__x_0 + 1, 1)) = 84"); + } + } + + [Fact] + public void String_Index_text_with_non_constant_string() + { + using (var ctx = CreateContext()) + { + // ReSharper disable once ConvertToConstant.Local + var x = 0; + var _ = ctx.SomeEntities.Where(e => e.SomeString[x].ToString() == "T").ToList(); + AssertContainsInSql(@"WHERE CAST(ascii(substr(e.""SomeString"", @__x_0 + 1, 1)) AS text) = 'T'"); + } + } + + [Fact] + public void Array_ElementAt_with_non_constant() + { + using (var ctx = CreateContext()) + { + // ReSharper disable once ConvertToConstant.Local + var x = 0; + var actual = ctx.SomeEntities.Where(e => e.SomeArray.ElementAt(x) == 3).ToList(); + Assert.Equal(1, actual.Count); + AssertContainsInSql(@"WHERE e.""SomeArray""[@__x_0 + 1] = 3"); + } + } + + [Fact] + public void List_ElementAt_with_non_constant() + { + using (var ctx = CreateContext()) + { + // ReSharper disable once ConvertToConstant.Local + var x = 0; + var actual = ctx.SomeEntities.Where(e => e.SomeList.ElementAt(x) == 3).ToList(); + Assert.Equal(1, actual.Count); + AssertContainsInSql(@"WHERE e.""SomeList""[@__x_0 + 1] = 3"); + } + } + + [Fact] + public void Array_ElementAt_bytea_with_non_constant() + { + using (var ctx = CreateContext()) + { + // ReSharper disable once ConvertToConstant.Local + var x = 0; + var actual = ctx.SomeEntities.Where(e => e.SomeBytea.ElementAt(x) == 3).ToList(); + Assert.Equal(1, actual.Count); + AssertContainsInSql(@"WHERE get_byte(e.""SomeBytea"", @__x_0) = 3"); + } + } + + [Fact] + public void String_ElementAt_text_with_non_constant_char_as_int() + { + using (var ctx = CreateContext()) + { + // ReSharper disable once ConvertToConstant.Local + var x = 0; + var actual = ctx.SomeEntities.Where(e => e.SomeString.ElementAt(x) == 'T').ToList(); Assert.Equal(1, actual.Count); - AssertContainsInSql(@"WHERE (get_byte(e.""SomeBytea"", 0)) = 3"); + AssertContainsInSql(@"WHERE ascii(substr(e.""SomeString"", @__x_0 + 1, 1)) = 84"); + } + } + + [Fact] + public void String_ElementAt_text_with_non_constant_sting() + { + using (var ctx = CreateContext()) + { + // ReSharper disable once ConvertToConstant.Local + var x = 0; + var _ = ctx.SomeEntities.Where(e => e.SomeString.ElementAt(x).ToString() == "T").ToList(); + AssertContainsInSql(@"WHERE CAST(ascii(substr(e.""SomeString"", @__x_0 + 1, 1)) AS text) = 'T'"); } } [Fact] - public void Index_multidimensional() + public void Array_Index_multidimensional() { using (var ctx = CreateContext()) { // Operations on multidimensional arrays aren't mapped to SQL yet - var actual = ctx.SomeEntities.Where(e => e.SomeMatrix[0,0] == 5).ToList(); + var actual = ctx.SomeEntities.Where(e => e.SomeMatrix[0, 0] == 5).ToList(); Assert.Equal(1, actual.Count); } } + #endregion + + #region Equality + [Fact] - public void SequenceEqual_with_parameter() + public void Array_Equal_with_parameter() { using (var ctx = CreateContext()) { - var arr = new[] { 3, 4 }; - var x = ctx.SomeEntities.Single(e => e.SomeArray.SequenceEqual(arr)); - Assert.Equal(new[] { 3, 4 }, x.SomeArray); - AssertContainsInSql(@"WHERE e.""SomeArray"" = @"); + var array = new[] { 3, 4 }; + var x = ctx.SomeEntities.Single(e => e.SomeArray.Equals(array)); + Assert.Equal(array, x.SomeArray); + AssertContainsInSql(@"WHERE e.""SomeArray"" = @__array_0"); + } + } + + [Fact] + public void List_Equal_with_parameter() + { + using (var ctx = CreateContext()) + { + var list = new List { 3, 4 }; + var x = ctx.SomeEntities.Single(e => e.SomeList.Equals(list)); + Assert.Equal(list, x.SomeList); + AssertContainsInSql(@"WHERE e.""SomeList"" = @__list_0"); + } + } + + [Fact] + public void Array_SequenceEqual_with_parameter() + { + using (var ctx = CreateContext()) + { + var array = new[] { 3, 4 }; + var x = ctx.SomeEntities.Single(e => e.SomeArray.Equals(array)); + Assert.Equal(array, x.SomeArray); + AssertContainsInSql(@"WHERE e.""SomeArray"" = @__array_0"); + } + } + + [Fact] + public void List_SequenceEqual_with_parameter() + { + using (var ctx = CreateContext()) + { + var list = new List { 3, 4 }; + var x = ctx.SomeEntities.Single(e => e.SomeList.SequenceEqual(list)); + Assert.Equal(list, x.SomeList); + AssertContainsInSql(@"WHERE e.""SomeList"" = @__list_0"); } } [Fact] - public void SequenceEqual_with_array_literal() + public void Array_SequenceEqual_with_literal() { using (var ctx = CreateContext()) { @@ -92,7 +352,22 @@ public void SequenceEqual_with_array_literal() } [Fact] - public void Contains_with_literal() + public void List_SequenceEqual_with_literal() + { + using (var ctx = CreateContext()) + { + var x = ctx.SomeEntities.Single(e => e.SomeList.SequenceEqual(new List { 3, 4 })); + Assert.Equal(new List { 3, 4 }, x.SomeList); + AssertContainsInSql(@"WHERE e.""SomeList"" = ARRAY[3,4]"); + } + } + + #endregion + + #region Containment + + [Fact] + public void Array_Contains_with_literal() { using (var ctx = CreateContext()) { @@ -103,10 +378,22 @@ public void Contains_with_literal() } [Fact] - public void Contains_with_parameter() + public void List_Contains_with_literal() + { + using (var ctx = CreateContext()) + { + var x = ctx.SomeEntities.Single(e => e.SomeList.Contains(3)); + Assert.Equal(new[] { 3, 4 }, x.SomeList); + AssertContainsInSql(@"WHERE 3 = ANY (e.""SomeList"")"); + } + } + + [Fact] + public void Array_Contains_with_parameter() { using (var ctx = CreateContext()) { + // ReSharper disable once ConvertToConstant.Local var p = 3; var x = ctx.SomeEntities.Single(e => e.SomeArray.Contains(p)); Assert.Equal(new[] { 3, 4 }, x.SomeArray); @@ -115,7 +402,20 @@ public void Contains_with_parameter() } [Fact] - public void Contains_with_column() + public void List_Contains_with_parameter() + { + using (var ctx = CreateContext()) + { + // ReSharper disable once ConvertToConstant.Local + var p = 3; + var x = ctx.SomeEntities.Single(e => e.SomeList.Contains(p)); + Assert.Equal(new[] { 3, 4 }, x.SomeList); + AssertContainsInSql(@"WHERE @__p_0 = ANY (e.""SomeList"")"); + } + } + + [Fact] + public void Array_Contains_with_column() { using (var ctx = CreateContext()) { @@ -126,121 +426,607 @@ public void Contains_with_column() } [Fact] - public void Length() + public void List_Contains_with_column() { using (var ctx = CreateContext()) { - var x = ctx.SomeEntities.Single(e => e.SomeArray.Length == 2); - Assert.Equal(new[] { 3, 4 }, x.SomeArray); - AssertContainsInSql(@"WHERE array_length(e.""SomeArray"", 1) = 2"); + var x = ctx.SomeEntities.Single(e => e.SomeList.Contains(e.Id + 2)); + Assert.Equal(new List { 3, 4 }, x.SomeList); + AssertContainsInSql(@"WHERE e.""Id"" + 2 = ANY (e.""SomeList"")"); } } - [Fact(Skip="https://github.com/aspnet/EntityFramework/issues/9242")] - public void Length_on_EF_Property() + [Fact] + public void Array_All_Contains_List() { using (var ctx = CreateContext()) { - // TODO: This fails - var x = ctx.SomeEntities.Single(e => EF.Property(e, nameof(SomeArrayEntity.SomeArray)).Length == 2); - Assert.Equal(new[] { 3, 4 }, x.SomeArray); - AssertContainsInSql(@"WHERE array_length(e.""SomeArray"", 1) = 2"); + var _ = ctx.SomeEntities.Where(x => x.SomeList.All(y => x.SomeArray.Contains(y))).ToList(); + AssertContainsInSql(@"WHERE (x.""SomeList"" <@ x.""SomeArray"") = TRUE"); } } [Fact] - public void Length_on_literal_not_translated() + public void List_All_Contains_Array() { using (var ctx = CreateContext()) { - var x = ctx.SomeEntities.Where(e => new[] { 1, 2, 3 }.Length == e.Id).ToList(); - AssertDoesNotContainInSql("array_length"); + var _ = ctx.SomeEntities.Where(x => x.SomeArray.All(y => x.SomeList.Contains(y))).ToList(); + AssertContainsInSql(@"WHERE (x.""SomeArray"" <@ x.""SomeList"") = TRUE"); } } - #region Support - - ArrayFixture Fixture { get; } - - public ArrayQueryTest(ArrayFixture fixture) + [Fact] + public void Array_Any_Contains_List() { - Fixture = fixture; - Fixture.TestSqlLoggerFactory.Clear(); + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => x.SomeArray.Any(y => x.SomeList.Contains(y))).ToList(); + AssertContainsInSql(@"WHERE (x.""SomeArray"" && x.""SomeList"") = TRUE"); + } } - ArrayContext CreateContext() => Fixture.CreateContext(); - - void AssertContainsInSql(string expected) - => Assert.Contains(expected, Fixture.TestSqlLoggerFactory.Sql); - - void AssertDoesNotContainInSql(string expected) - => Assert.DoesNotContain(expected, Fixture.TestSqlLoggerFactory.Sql); - - #endregion Support - } - - public class ArrayContext : DbContext - { - public DbSet SomeEntities { get; set; } - public ArrayContext(DbContextOptions options) : base(options) {} - protected override void OnModelCreating(ModelBuilder builder) + [Fact] + public void List_Any_Contains_Array() { - + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => x.SomeList.Any(y => x.SomeArray.Contains(y))).ToList(); + AssertContainsInSql(@"WHERE (x.""SomeList"" && x.""SomeArray"") = TRUE"); + } } - } - public class SomeArrayEntity - { - public int Id { get; set; } - public int[] SomeArray { get; set; } - public int[,] SomeMatrix { get; set; } - public List SomeList { get; set; } - public byte[] SomeBytea { get; set; } - public string SomeText { get; set; } - } + #endregion - public class ArrayFixture : IDisposable - { - readonly DbContextOptions _options; - public TestSqlLoggerFactory TestSqlLoggerFactory { get; } = new TestSqlLoggerFactory(); + #region Count - public ArrayFixture() + [Fact] + public void Array_Length() { - _testStore = NpgsqlTestStore.CreateScratch(); - _options = new DbContextOptionsBuilder() - .UseNpgsql(_testStore.ConnectionString, b => b.ApplyConfiguration()) - .UseInternalServiceProvider( - new ServiceCollection() - .AddEntityFrameworkNpgsql() - .AddSingleton(TestSqlLoggerFactory) - .BuildServiceProvider()) - .Options; - using (var ctx = CreateContext()) { - ctx.Database.EnsureCreated(); - ctx.SomeEntities.Add(new SomeArrayEntity - { - Id=1, - SomeArray = new[] { 3, 4 }, - SomeBytea = new byte[] { 3, 4 }, - SomeMatrix = new[,] { { 5, 6 }, { 7, 8 } }, - SomeList = new List { 3, 4 } - }); - ctx.SomeEntities.Add(new SomeArrayEntity - { - Id=2, - SomeArray = new[] { 5, 6, 7 }, - SomeBytea = new byte[] { 5, 6, 7 }, - SomeMatrix = new[,] { { 10, 11 }, { 12, 13 } }, - SomeList = new List { 3, 4 } - }); - ctx.SaveChanges(); + var x = ctx.SomeEntities.Single(e => e.SomeArray.Length == 2); + Assert.Equal(new[] { 3, 4 }, x.SomeArray); + AssertContainsInSql(@"WHERE array_length(e.""SomeArray"", 1) = 2"); } } - readonly NpgsqlTestStore _testStore; - public ArrayContext CreateContext() => new ArrayContext(_options); - public void Dispose() => _testStore.Dispose(); + [Fact] + public void List_Length() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(e => e.SomeList.Count == 0).ToArray(); + AssertContainsInSql(@"WHERE array_length(e.""SomeList"", 1) = 0"); + } + } + + [Fact] + public void Array_Count() + { + using (var ctx = CreateContext()) + { + // ReSharper disable once UseCollectionCountProperty + var _ = ctx.SomeEntities.Where(e => e.SomeArray.Count() == 1).ToArray(); + AssertContainsInSql(@"WHERE array_length(e.""SomeArray"", 1) = 1"); + } + } + + [Fact] + public void List_Count() + { + using (var ctx = CreateContext()) + { + // ReSharper disable once UseCollectionCountProperty + var _ = ctx.SomeEntities.Where(e => e.SomeList.Count() == 1).ToArray(); + AssertContainsInSql(@"WHERE array_length(e.""SomeList"", 1) = 1"); + } + } + + [Fact(Skip = "https://github.com/aspnet/EntityFramework/issues/9242")] + public void Array_Length_on_EF_Property() + { + using (var ctx = CreateContext()) + { + // TODO: This fails + var x = ctx.SomeEntities.Single(e => EF.Property(e, nameof(SomeArrayEntity.SomeArray)).Length == 2); + Assert.Equal(new[] { 3, 4 }, x.SomeArray); + AssertContainsInSql(@"WHERE array_length(e.""SomeArray"", 1) = 2"); + } + } + + [Fact(Skip = "https://github.com/aspnet/EntityFramework/issues/9242")] + public void List_Length_on_EF_Property() + { + using (var ctx = CreateContext()) + { + // TODO: This fails + var x = ctx.SomeEntities.Single(e => EF.Property>(e, nameof(SomeArrayEntity.SomeList)).Count == 2); + Assert.Equal(new List { 3, 4 }, x.SomeList); + AssertContainsInSql(@"WHERE array_length(e.""SomeList"", 1) = 2"); + } + } + + [Fact] + public void Array_Length_on_literal_not_translated() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(e => new[] { 1, 2, 3 }.Length == e.Id).ToList(); + AssertContainsInSql(@"WHERE 3 = e.""Id"""); + AssertDoesNotContainInSql("array_length"); + } + } + + [Fact] + public void List_Length_on_literal_not_translated() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(e => new List { 1, 2, 3 }.Count == e.Id).ToList(); + AssertContainsInSql(@"WHERE @__Count_0 = e.""Id"""); + AssertDoesNotContainInSql("array_length"); + } + } + + #endregion + + #region Concatenation + + [Fact] + public void Array_Concat_with_array_column() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => e.SomeArray.Concat(e.SomeArray)).ToList(); + AssertContainsInSql(@"SELECT (e.""SomeArray"" || e.""SomeArray"")"); + } + } + + [Fact] + public void List_Concat_with_list_column() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => e.SomeList.Concat(e.SomeList)).ToList(); + AssertContainsInSql(@"SELECT (e.""SomeList"" || e.""SomeList"")"); + } + } + + [Fact] + public void Array_Concat_with_list_column() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => e.SomeArray.Concat(e.SomeList)).ToList(); + AssertContainsInSql(@"SELECT (e.""SomeArray"" || e.""SomeList"")"); + } + } + + [Fact] + public void List_Concat_with_array_column() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => e.SomeList.Concat(e.SomeArray)).ToList(); + AssertContainsInSql(@"SELECT (e.""SomeList"" || e.""SomeArray"")"); + } + } + +// .NET 4.6.1 doesn't include the Enumerable.Append and Enumerable.Prepend functions... +#if !NET461 + [Fact] + public void Array_Append_constant() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => e.SomeArray.Append(0)).ToList(); + AssertContainsInSql(@"SELECT (e.""SomeArray"" || 0)"); + } + } + + [Fact] + public void List_Append_constant() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => e.SomeList.Append(0)).ToList(); + AssertContainsInSql(@"SELECT (e.""SomeList"" || 0)"); + } + } + + [Fact] + public void Array_Prepend_constant() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => e.SomeArray.Prepend(0)).ToList(); + AssertContainsInSql(@"SELECT (0 || e.""SomeArray"")"); + } + } + + [Fact] + public void List_Prepend_constant() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => e.SomeList.Prepend(0)).ToList(); + AssertContainsInSql(@"SELECT (0 || e.""SomeList"")"); + } + } + +#endif + + #endregion + + #region IndexOf + + [Fact] + public void Array_IndexOf_constant() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => Array.IndexOf(e.SomeArray, 0)).ToList(); + AssertContainsInSql(@"SELECT COALESCE(array_position(e.""SomeArray"", 0), -1)"); + } + } + + [Fact] + public void List_IndexOf_constant() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => e.SomeList.IndexOf(0)).ToList(); + AssertContainsInSql(@"SELECT COALESCE(array_position(e.""SomeList"", 0), -1)"); + } + } + + #endregion + + #region StringConversion + + [Fact] + public void Array_ArrayToString() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => EF.Functions.ArrayToString(e.SomeArray, ",")).ToList(); + AssertContainsInSql(@"SELECT array_to_string(e.""SomeArray"", ',')"); + } + } + + [Fact] + public void List_ArrayToString() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => EF.Functions.ArrayToString(e.SomeList, ",")).ToList(); + AssertContainsInSql(@"SELECT array_to_string(e.""SomeList"", ',')"); + } + } + + [Fact] + public void Array_ArrayToString_with_null() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => EF.Functions.ArrayToString(e.SomeArray, ",", "*")).ToList(); + AssertContainsInSql(@"SELECT array_to_string(e.""SomeArray"", ',', '*')"); + } + } + + [Fact] + public void List_ArrayToString_with_null() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Select(e => EF.Functions.ArrayToString(e.SomeList, ",", "*")).ToList(); + AssertContainsInSql(@"SELECT array_to_string(e.""SomeList"", ',', '*')"); + } + } + + [Fact] + public void Array_StringToArray() + { + using (var ctx = CreateContext()) + { + var _ = + ctx.SomeEntities + .Select(e => EF.Functions.ArrayToString(e.SomeArray, ",", "*")) + .Select(e => EF.Functions.StringToArray(e, ",", "*")).ToList(); + + AssertContainsInSql(@"SELECT string_to_array(array_to_string(e.""SomeArray"", ',', '*'), ',', '*')"); + } + } + + [Fact] + public void List_StringToList() + { + using (var ctx = CreateContext()) + { + var _ = + ctx.SomeEntities + .Select(e => EF.Functions.ArrayToString(e.SomeList, ",", "*")) + .Select(e => EF.Functions.StringToList(e, ",", "*")).ToList(); + + AssertContainsInSql(@"SELECT string_to_array(array_to_string(e.""SomeList"", ',', '*'), ',', '*')"); + } + } + + #endregion + + #region Exists + + [Fact] + public void Array_Exists_equals_with_literal_constant() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => Array.Exists(x.SomeArray, y => y == 1)).ToList(); + AssertContainsInSql(@"WHERE 1 = ANY (x.""SomeArray"") = TRUE"); + } + } + + [Fact] + public void List_Exists_equals_with_literal_constant() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => x.SomeList.Exists(y => y == 1)).ToList(); + AssertContainsInSql(@"WHERE 1 = ANY (x.""SomeList"") = TRUE"); + } + } + + [Fact] + public void Array_Exists_not_equal_with_literal_constant() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => Array.Exists(x.SomeArray, y => y != 1)).ToList(); + AssertContainsInSql(@"WHERE 1 <> ANY (x.""SomeArray"") = TRUE"); + } + } + + [Fact] + public void List_Exists_not_equal_with_literal_constant() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => x.SomeList.Exists(y => y != 1)).ToList(); + AssertContainsInSql(@"WHERE 1 <> ANY (x.""SomeList"") = TRUE"); + } + } + + [Fact] + public void Array_Exists_equals_with_parameter_array_and_column_array_element() + { + using (var ctx = CreateContext()) + { + var array = new[] { 0, 1, 2 }; + var _ = ctx.SomeEntities.Where(x => Array.Exists(array, y => y == x.SomeArray[0])).ToList(); + AssertContainsInSql(@"WHERE x.""SomeArray""[1] IN (0, 1, 2)"); + } + } + + [Fact] + public void List_Exists_equals_with_parameter_array_and_column_list_element() + { + using (var ctx = CreateContext()) + { + var list = new List { 0, 1, 2 }; + var _ = ctx.SomeEntities.Where(x => list.Exists(y => y == x.SomeList[0])).ToList(); + AssertContainsInSql(@"WHERE x.""SomeList""[1] IN (0, 1, 2)"); + } + } + + [Fact] + public void Array_Exists_equals_with_parameter_array_and_column_array_element_flipped() + { + using (var ctx = CreateContext()) + { + var array = new[] { 0, 1, 2 }; + var _ = ctx.SomeEntities.Where(x => Array.Exists(array, y => x.SomeArray[0] == y)).ToList(); + AssertContainsInSql(@"WHERE x.""SomeArray""[1] IN (0, 1, 2)"); + } + } + + [Fact] + public void List_Exists_equals_with_parameter_array_and_column_list_element_flipped() + { + using (var ctx = CreateContext()) + { + var list = new List { 0, 1, 2 }; + var _ = ctx.SomeEntities.Where(x => list.Exists(y => x.SomeList[0] == y)).ToList(); + AssertContainsInSql(@"WHERE x.""SomeList""[1] IN (0, 1, 2)"); + } + } + + [Fact] + public void Array_Exists_equals_with_literal_constant_flipped() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => Array.Exists(x.SomeArray, y => 1 == y)).ToList(); + AssertContainsInSql(@"WHERE 1 = ANY (x.""SomeArray"") = TRUE"); + } + } + + [Fact] + public void List_Exists_equals_with_literal_constant_flipped() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => x.SomeList.Exists(y => 1 == y)).ToList(); + AssertContainsInSql(@"WHERE 1 = ANY (x.""SomeList"") = TRUE"); + } + } + + [Fact] + public void Array_Exists_less_than_with_literal_constant() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => Array.Exists(x.SomeArray, y => y < 1)).ToList(); + AssertContainsInSql(@"WHERE 1 > ANY (x.""SomeArray"") = TRUE"); + } + } + + [Fact] + public void List_Exists_less_than_with_literal_constant() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => x.SomeList.Exists(y => y < 1)).ToList(); + AssertContainsInSql(@"WHERE 1 > ANY (x.""SomeList"") = TRUE"); + } + } + + [Fact] + public void Array_Exists_less_than_with_literal_constant_flipped() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => Array.Exists(x.SomeArray, y => 1 > y)).ToList(); + AssertContainsInSql(@"WHERE 1 > ANY (x.""SomeArray"") = TRUE"); + } + } + + [Fact] + public void List_Exists_less_than_with_literal_constant_flipped() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => x.SomeList.Exists(y => 1 > y)).ToList(); + AssertContainsInSql(@"WHERE 1 > ANY (x.""SomeList"") = TRUE"); + } + } + + [Fact] + public void Array_Exists_equals_with_column_list_element() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => Array.Exists(x.SomeArray, y => y == x.SomeList[0])).ToList(); + AssertContainsInSql(@"WHERE x.""SomeList""[1] = ANY (x.""SomeArray"") = TRUE"); + } + } + + [Fact] + public void List_Exists_equals_with_column_array_element() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => x.SomeList.Exists(y => y == x.SomeArray[0])).ToList(); + AssertContainsInSql(@"WHERE x.""SomeArray""[1] = ANY (x.""SomeList"") = TRUE"); + } + } + + [Fact] + public void Array_Exists_equals_with_column_list_element_flipped() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => Array.Exists(x.SomeArray, y => x.SomeList[0] == y)).ToList(); + AssertContainsInSql(@"WHERE x.""SomeList""[1] = ANY (x.""SomeArray"") = TRUE"); + } + } + + [Fact] + public void List_Exists_equals_with_column_array_element_flipped() + { + using (var ctx = CreateContext()) + { + var _ = ctx.SomeEntities.Where(x => x.SomeList.Exists(y => x.SomeArray[0] == y)).ToList(); + AssertContainsInSql(@"WHERE x.""SomeArray""[1] = ANY (x.""SomeList"") = TRUE"); + } + } + + #endregion + + #endregion + + #region Support + + ArrayFixture Fixture { get; } + + public ArrayQueryTest(ArrayFixture fixture) + { + Fixture = fixture; + Fixture.TestSqlLoggerFactory.Clear(); + } + + ArrayContext CreateContext() => Fixture.CreateContext(); + + void AssertContainsInSql(string expected) + => Assert.Contains(expected, Fixture.TestSqlLoggerFactory.Sql); + + void AssertDoesNotContainInSql(string expected) + => Assert.DoesNotContain(expected, Fixture.TestSqlLoggerFactory.Sql); + + public class ArrayContext : DbContext + { + public DbSet SomeEntities { get; set; } + public ArrayContext(DbContextOptions options) : base(options) {} + protected override void OnModelCreating(ModelBuilder builder) {} + } + + public class SomeArrayEntity + { + public int Id { get; set; } + public int[] SomeArray { get; set; } + public int[,] SomeMatrix { get; set; } + public List SomeList { get; set; } + public byte[] SomeBytea { get; set; } + + // ReSharper disable once UnusedMember.Global + public string SomeString { get; set; } + } + + public class ArrayFixture : IDisposable + { + readonly DbContextOptions _options; + public TestSqlLoggerFactory TestSqlLoggerFactory { get; } = new TestSqlLoggerFactory(); + + public ArrayFixture() + { + _testStore = NpgsqlTestStore.CreateScratch(); + _options = new DbContextOptionsBuilder() + .UseNpgsql(_testStore.ConnectionString, b => b.ApplyConfiguration()) + .UseInternalServiceProvider( + new ServiceCollection() + .AddEntityFrameworkNpgsql() + .AddSingleton(TestSqlLoggerFactory) + .BuildServiceProvider()) + .Options; + + using (var ctx = CreateContext()) + { + ctx.Database.EnsureCreated(); + ctx.SomeEntities.Add(new SomeArrayEntity + { + Id = 1, + SomeArray = new[] { 3, 4 }, + SomeBytea = new byte[] { 3, 4 }, + SomeMatrix = new[,] { { 5, 6 }, { 7, 8 } }, + SomeList = new List { 3, 4 } + }); + ctx.SomeEntities.Add(new SomeArrayEntity + { + Id = 2, + SomeArray = new[] { 5, 6, 7 }, + SomeBytea = new byte[] { 5, 6, 7 }, + SomeMatrix = new[,] { { 10, 11 }, { 12, 13 } }, + SomeList = new List { 3, 4 } + }); + ctx.SaveChanges(); + } + } + + readonly NpgsqlTestStore _testStore; + public ArrayContext CreateContext() => new ArrayContext(_options); + public void Dispose() => _testStore.Dispose(); + } + + #endregion } }