From 5ffdda08c522510dd853bca3fdbf2c741f8bd577 Mon Sep 17 00:00:00 2001 From: Andrea Canciani Date: Wed, 15 May 2024 07:41:21 +0200 Subject: [PATCH 1/3] Copy `CombineTerms` from `QueryableAggregateMethodTranslator` and remove the fragment part specific to `COUNT`. --- ...qliteQueryableAggregateMethodTranslator.cs | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs index 77a4dd7bf94..d99c8328069 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. using Microsoft.EntityFrameworkCore.Query.SqlExpressions; @@ -114,4 +114,21 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress => expression.TypeMapping?.Converter?.ProviderClrType ?? expression.TypeMapping?.ClrType ?? expression.Type; + + private SqlExpression CombineTerms(EnumerableExpression enumerableExpression, SqlExpression sqlExpression) + { + if (enumerableExpression.Predicate != null) + { + sqlExpression = _sqlExpressionFactory.Case( + new List { new(enumerableExpression.Predicate, sqlExpression) }, + elseResult: null); + } + + if (enumerableExpression.IsDistinct) + { + sqlExpression = new DistinctExpression(sqlExpression); + } + + return sqlExpression; + } } From 50b76b6f4114fbe29a74f1e2021853fd57de870f Mon Sep 17 00:00:00 2001 From: Andrea Canciani Date: Wed, 15 May 2024 07:42:58 +0200 Subject: [PATCH 2/3] Implement average aggregation for `decimal` in SQLite Contributes to #19635 --- .../SqliteQueryableAggregateMethodTranslator.cs | 11 ++++++++--- .../Storage/Internal/SqliteRelationalConnection.cs | 11 +++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs index d99c8328069..a89ef57b28d 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQueryableAggregateMethodTranslator.cs @@ -53,9 +53,14 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress var averageArgumentType = GetProviderType(averageSqlExpression); if (averageArgumentType == typeof(decimal)) { - throw new NotSupportedException( - SqliteStrings.AggregateOperationNotSupported( - nameof(Queryable.Average), averageArgumentType.ShortDisplayName())); + averageSqlExpression = CombineTerms(source, averageSqlExpression); + return _sqlExpressionFactory.Function( + "ef_avg", + [averageSqlExpression], + nullable: true, + argumentsPropagateNullability: [false], + averageSqlExpression.Type, + averageSqlExpression.TypeMapping); } break; diff --git a/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs b/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs index cf83ca55632..386e0eaa4f5 100644 --- a/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs +++ b/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs @@ -142,6 +142,17 @@ private void InitializeDbConnection(DbConnection connection) name: "ef_negate", (decimal? m) => -m, isDeterministic: true); + + sqliteConnection.CreateAggregate( + "ef_avg", + seed: (0m, 0ul), + ((decimal sum, ulong count) acc, decimal? value) => value is null + ? acc + : (acc.sum + value.Value, acc.count + 1), + ((decimal sum, ulong count) acc) => acc.count == 0 + ? default(decimal?) + : acc.sum / acc.count, + isDeterministic: true); } else { From afb9ceedff291d3408c713ac6054c652e550e37b Mon Sep 17 00:00:00 2001 From: Andrea Canciani Date: Wed, 15 May 2024 07:42:58 +0200 Subject: [PATCH 3/3] Update tests --- .../BuiltInDataTypesSqliteTest.cs | 11 ++- .../AdHocMiscellaneousQuerySqliteTest.cs | 67 ++++++++++++++- .../Query/Ef6GroupBySqliteTest.cs | 17 +++- ...thwindAggregateOperatorsQuerySqliteTest.cs | 84 ++++++++++++++----- 4 files changed, 146 insertions(+), 33 deletions(-) diff --git a/test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs index 1d92a24bc37..87c6e6f8133 100644 --- a/test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs @@ -936,7 +936,7 @@ public virtual void Cant_query_Max_of_converted_types() } [ConditionalFact] - public virtual void Cant_query_Average_of_converted_types() + public virtual void Can_query_Average_of_converted_types() { using var context = CreateContext(); context.Add( @@ -958,11 +958,10 @@ public virtual void Cant_query_Average_of_converted_types() context.SaveChanges(); Assert.Equal( - SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Average), typeof(decimal).ShortDisplayName()), - Assert.Throws( - () => context.Set() - .Where(e => e.PartitionId == 202) - .Average(e => e.TestNullableDecimal)).Message); + 1.000000000000002m, + context.Set() + .Where(e => e.PartitionId == 202) + .Average(e => e.TestNullableDecimal)); } [ConditionalFact] diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/AdHocMiscellaneousQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/AdHocMiscellaneousQuerySqliteTest.cs index a8750daf50a..7d49ffdfb98 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/AdHocMiscellaneousQuerySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/AdHocMiscellaneousQuerySqliteTest.cs @@ -20,7 +20,68 @@ INSERT INTO ZeroKey VALUES (NULL) """); public override async Task Average_with_cast() - => Assert.Equal( - SqliteStrings.AggregateOperationNotSupported("Average", "decimal"), - (await Assert.ThrowsAsync(base.Average_with_cast)).Message); + { + await base.Average_with_cast(); + + AssertSql( + """ +SELECT "p"."Id", "p"."DecimalColumn", "p"."DoubleColumn", "p"."FloatColumn", "p"."IntColumn", "p"."LongColumn", "p"."NullableDecimalColumn", "p"."NullableDoubleColumn", "p"."NullableFloatColumn", "p"."NullableIntColumn", "p"."NullableLongColumn", "p"."Price" +FROM "Prices" AS "p" +""", + // + """ +SELECT ef_avg("p"."Price") +FROM "Prices" AS "p" +""", + // + """ +SELECT AVG(CAST("p"."IntColumn" AS REAL)) +FROM "Prices" AS "p" +""", + // + """ +SELECT AVG(CAST("p"."NullableIntColumn" AS REAL)) +FROM "Prices" AS "p" +""", + // + """ +SELECT AVG(CAST("p"."LongColumn" AS REAL)) +FROM "Prices" AS "p" +""", + // + """ +SELECT AVG(CAST("p"."NullableLongColumn" AS REAL)) +FROM "Prices" AS "p" +""", + // + """ +SELECT CAST(AVG("p"."FloatColumn") AS REAL) +FROM "Prices" AS "p" +""", + // + """ +SELECT CAST(AVG("p"."NullableFloatColumn") AS REAL) +FROM "Prices" AS "p" +""", + // + """ +SELECT AVG("p"."DoubleColumn") +FROM "Prices" AS "p" +""", + // + """ +SELECT AVG("p"."NullableDoubleColumn") +FROM "Prices" AS "p" +""", + // + """ +SELECT ef_avg("p"."DecimalColumn") +FROM "Prices" AS "p" +""", + // + """ +SELECT ef_avg("p"."NullableDecimalColumn") +FROM "Prices" AS "p" +"""); + } } diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/Ef6GroupBySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/Ef6GroupBySqliteTest.cs index db7b6d52dd6..01aab7d8271 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/Ef6GroupBySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/Ef6GroupBySqliteTest.cs @@ -17,10 +17,16 @@ public Ef6GroupBySqliteTest(Ef6GroupBySqliteFixture fixture, ITestOutputHelper t } public override async Task Average_Grouped_from_LINQ_101(bool async) - => Assert.Equal( - SqliteStrings.AggregateOperationNotSupported("Average", "decimal"), - (await Assert.ThrowsAsync( - () => base.Average_Grouped_from_LINQ_101(async))).Message); + { + await base.Average_Grouped_from_LINQ_101(async); + + AssertSql( + """ +SELECT "p"."Category", ef_avg("p"."UnitPrice") AS "AveragePrice" +FROM "ProductForLinq" AS "p" +GROUP BY "p"."Category" +"""); + } public override async Task Max_Grouped_from_LINQ_101(bool async) => Assert.Equal( @@ -49,6 +55,9 @@ public override async Task Group_Join_from_LINQ_101(bool async) (await Assert.ThrowsAsync( () => base.Group_Join_from_LINQ_101(async))).Message); + private void AssertSql(params string[] expected) + => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); + public class Ef6GroupBySqliteFixture : Ef6GroupByFixtureBase, ITestSqlLoggerFactory { public TestSqlLoggerFactory TestSqlLoggerFactory diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqliteTest.cs index f2e19956f0c..8a4931109a9 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqliteTest.cs @@ -34,32 +34,76 @@ public override async Task Sum_with_division_on_decimal_no_significant_digits(bo .Message); public override async Task Average_with_division_on_decimal(bool async) - => Assert.Equal( - SqliteStrings.AggregateOperationNotSupported("Average", "decimal"), - (await Assert.ThrowsAsync( - async () => await base.Average_with_division_on_decimal(async))) - .Message); + { + await base.Average_with_division_on_decimal(async); + + AssertSql( + """ +SELECT ef_avg(ef_divide(CAST("o"."Quantity" AS TEXT), '2.09')) +FROM "Order Details" AS "o" +"""); + } public override async Task Average_with_division_on_decimal_no_significant_digits(bool async) - => Assert.Equal( - SqliteStrings.AggregateOperationNotSupported("Average", "decimal"), - (await Assert.ThrowsAsync( - async () => await base.Average_with_division_on_decimal_no_significant_digits(async))) - .Message); + { + await base.Average_with_division_on_decimal_no_significant_digits(async); + + AssertSql( + """ +SELECT ef_avg(ef_divide(CAST("o"."Quantity" AS TEXT), '2.0')) +FROM "Order Details" AS "o" +"""); + } + + public override async Task Average_over_max_subquery_is_client_eval(bool async) - => Assert.Equal( - SqliteStrings.AggregateOperationNotSupported("Average", "decimal"), - (await Assert.ThrowsAsync( - async () => await base.Average_over_max_subquery_is_client_eval(async))) - .Message); + { + await base.Average_over_max_subquery_is_client_eval(async); + + AssertSql( + """ +@__p_0='3' + +SELECT ef_avg(CAST(( + SELECT AVG(CAST(5 + ( + SELECT MAX("o0"."ProductID") + FROM "Order Details" AS "o0" + WHERE "o"."OrderID" = "o0"."OrderID") AS REAL)) + FROM "Orders" AS "o" + WHERE "c0"."CustomerID" = "o"."CustomerID") AS TEXT)) +FROM ( + SELECT "c"."CustomerID" + FROM "Customers" AS "c" + ORDER BY "c"."CustomerID" + LIMIT @__p_0 +) AS "c0" +"""); + } public override async Task Average_over_nested_subquery_is_client_eval(bool async) - => Assert.Equal( - SqliteStrings.AggregateOperationNotSupported("Average", "decimal"), - (await Assert.ThrowsAsync( - async () => await base.Average_over_nested_subquery_is_client_eval(async))) - .Message); + { + await base.Average_over_nested_subquery_is_client_eval(async); + + AssertSql( + """ +@__p_0='3' + +SELECT ef_avg(CAST(( + SELECT AVG(5.0 + ( + SELECT AVG(CAST("o0"."ProductID" AS REAL)) + FROM "Order Details" AS "o0" + WHERE "o"."OrderID" = "o0"."OrderID")) + FROM "Orders" AS "o" + WHERE "c0"."CustomerID" = "o"."CustomerID") AS TEXT)) +FROM ( + SELECT "c"."CustomerID" + FROM "Customers" AS "c" + ORDER BY "c"."CustomerID" + LIMIT @__p_0 +) AS "c0" +"""); + } public override async Task Multiple_collection_navigation_with_FirstOrDefault_chained(bool async) => Assert.Equal(