diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlNullabilityProcessor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlNullabilityProcessor.cs
index ab8a6f48727..5ecfd0d51e5 100644
--- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlNullabilityProcessor.cs
+++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlNullabilityProcessor.cs
@@ -84,6 +84,30 @@ protected virtual SqlExpression VisitRegexp(
return regexpExpression.Update(match, pattern);
}
+ ///
+ protected override SqlExpression VisitSqlFunction(
+ SqlFunctionExpression sqlFunctionExpression,
+ bool allowOptimizedExpansion,
+ out bool nullable)
+ {
+ var result = base.VisitSqlFunction(sqlFunctionExpression, allowOptimizedExpansion, out nullable);
+
+ if (result is SqlFunctionExpression resultFunctionExpression
+ && resultFunctionExpression.IsBuiltIn
+ && string.Equals(resultFunctionExpression.Name, "ef_sum", StringComparison.OrdinalIgnoreCase))
+ {
+ nullable = false;
+
+ var sqlExpressionFactory = Dependencies.SqlExpressionFactory;
+ return sqlExpressionFactory.Coalesce(
+ result,
+ sqlExpressionFactory.Constant(0, resultFunctionExpression.TypeMapping),
+ resultFunctionExpression.TypeMapping);
+ }
+
+ return result;
+ }
+
#pragma warning disable EF1001
///
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteQueryableAggregateMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteQueryableAggregateMethodTranslator.cs
index 9619d494709..2872a2ccb24 100644
--- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteQueryableAggregateMethodTranslator.cs
+++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteQueryableAggregateMethodTranslator.cs
@@ -54,9 +54,14 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress
var averageArgumentType = GetProviderType(averageSqlExpression);
if (averageArgumentType == typeof(decimal))
{
- throw new NotSupportedException(
- SqliteStrings.AggregateOperationNotSupported(
- nameof(Queryable.Average), averageArgumentType.ShortDisplayName()));
+ averageSqlExpression = CombineTerms(source, averageSqlExpression);
+ return _sqlExpressionFactory.Function(
+ "ef_avg",
+ [averageSqlExpression],
+ nullable: true,
+ argumentsPropagateNullability: [false],
+ averageSqlExpression.Type,
+ averageSqlExpression.TypeMapping);
}
break;
@@ -100,8 +105,14 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress
var sumArgumentType = GetProviderType(sumSqlExpression);
if (sumArgumentType == typeof(decimal))
{
- throw new NotSupportedException(
- SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Sum), sumArgumentType.ShortDisplayName()));
+ sumSqlExpression = CombineTerms(source, sumSqlExpression);
+ return _sqlExpressionFactory.Function(
+ "ef_sum",
+ [sumSqlExpression],
+ nullable: true,
+ argumentsPropagateNullability: [false],
+ sumSqlExpression.Type,
+ sumSqlExpression.TypeMapping);
}
break;
@@ -115,4 +126,21 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress
=> expression.TypeMapping?.Converter?.ProviderClrType
?? expression.TypeMapping?.ClrType
?? expression.Type;
+
+ private SqlExpression CombineTerms(EnumerableExpression enumerableExpression, SqlExpression sqlExpression)
+ {
+ if (enumerableExpression.Predicate != null)
+ {
+ sqlExpression = _sqlExpressionFactory.Case(
+ new List { new(enumerableExpression.Predicate, sqlExpression) },
+ elseResult: null);
+ }
+
+ if (enumerableExpression.IsDistinct)
+ {
+ sqlExpression = new DistinctExpression(sqlExpression);
+ }
+
+ return sqlExpression;
+ }
}
diff --git a/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs b/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs
index cf83ca55632..21773c9fbec 100644
--- a/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs
+++ b/src/EFCore.Sqlite.Core/Storage/Internal/SqliteRelationalConnection.cs
@@ -142,6 +142,25 @@ private void InitializeDbConnection(DbConnection connection)
name: "ef_negate",
(decimal? m) => -m,
isDeterministic: true);
+
+ sqliteConnection.CreateAggregate(
+ "ef_avg",
+ seed: (0m, 0ul),
+ ((decimal sum, ulong count) acc, decimal? value) => value is null
+ ? acc
+ : (acc.sum + value.Value, acc.count + 1),
+ ((decimal sum, ulong count) acc) => acc.count == 0
+ ? default(decimal?)
+ : acc.sum / acc.count,
+ isDeterministic: true);
+
+ sqliteConnection.CreateAggregate(
+ "ef_sum",
+ seed: null,
+ (decimal? sum, decimal? value) => value is null
+ ? sum
+ : sum is null ? value : sum.Value + value.Value,
+ isDeterministic: true);
}
else
{
diff --git a/test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs
index 1d92a24bc37..e4b265921cb 100644
--- a/test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs
+++ b/test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs
@@ -936,7 +936,7 @@ public virtual void Cant_query_Max_of_converted_types()
}
[ConditionalFact]
- public virtual void Cant_query_Average_of_converted_types()
+ public virtual void Can_query_Average_of_converted_types()
{
using var context = CreateContext();
context.Add(
@@ -958,15 +958,14 @@ public virtual void Cant_query_Average_of_converted_types()
context.SaveChanges();
Assert.Equal(
- SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Average), typeof(decimal).ShortDisplayName()),
- Assert.Throws(
- () => context.Set()
- .Where(e => e.PartitionId == 202)
- .Average(e => e.TestNullableDecimal)).Message);
+ 1.000000000000002m,
+ context.Set()
+ .Where(e => e.PartitionId == 202)
+ .Average(e => e.TestNullableDecimal));
}
[ConditionalFact]
- public virtual void Cant_query_Sum_of_converted_types()
+ public virtual void Can_query_Sum_of_converted_types()
{
using var context = CreateContext();
context.Add(
@@ -988,11 +987,10 @@ public virtual void Cant_query_Sum_of_converted_types()
context.SaveChanges();
Assert.Equal(
- SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Sum), typeof(decimal).ShortDisplayName()),
- Assert.Throws(
- () => context.Set()
- .Where(e => e.PartitionId == 203)
- .Sum(e => e.TestDecimal)).Message);
+ 2.000000000000002m,
+ context.Set()
+ .Where(e => e.PartitionId == 203)
+ .Sum(e => e.TestDecimal));
}
[ConditionalFact]
diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/AdHocMiscellaneousQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/AdHocMiscellaneousQuerySqliteTest.cs
index a8750daf50a..7d49ffdfb98 100644
--- a/test/EFCore.Sqlite.FunctionalTests/Query/AdHocMiscellaneousQuerySqliteTest.cs
+++ b/test/EFCore.Sqlite.FunctionalTests/Query/AdHocMiscellaneousQuerySqliteTest.cs
@@ -20,7 +20,68 @@ INSERT INTO ZeroKey VALUES (NULL)
""");
public override async Task Average_with_cast()
- => Assert.Equal(
- SqliteStrings.AggregateOperationNotSupported("Average", "decimal"),
- (await Assert.ThrowsAsync(base.Average_with_cast)).Message);
+ {
+ await base.Average_with_cast();
+
+ AssertSql(
+ """
+SELECT "p"."Id", "p"."DecimalColumn", "p"."DoubleColumn", "p"."FloatColumn", "p"."IntColumn", "p"."LongColumn", "p"."NullableDecimalColumn", "p"."NullableDoubleColumn", "p"."NullableFloatColumn", "p"."NullableIntColumn", "p"."NullableLongColumn", "p"."Price"
+FROM "Prices" AS "p"
+""",
+ //
+ """
+SELECT ef_avg("p"."Price")
+FROM "Prices" AS "p"
+""",
+ //
+ """
+SELECT AVG(CAST("p"."IntColumn" AS REAL))
+FROM "Prices" AS "p"
+""",
+ //
+ """
+SELECT AVG(CAST("p"."NullableIntColumn" AS REAL))
+FROM "Prices" AS "p"
+""",
+ //
+ """
+SELECT AVG(CAST("p"."LongColumn" AS REAL))
+FROM "Prices" AS "p"
+""",
+ //
+ """
+SELECT AVG(CAST("p"."NullableLongColumn" AS REAL))
+FROM "Prices" AS "p"
+""",
+ //
+ """
+SELECT CAST(AVG("p"."FloatColumn") AS REAL)
+FROM "Prices" AS "p"
+""",
+ //
+ """
+SELECT CAST(AVG("p"."NullableFloatColumn") AS REAL)
+FROM "Prices" AS "p"
+""",
+ //
+ """
+SELECT AVG("p"."DoubleColumn")
+FROM "Prices" AS "p"
+""",
+ //
+ """
+SELECT AVG("p"."NullableDoubleColumn")
+FROM "Prices" AS "p"
+""",
+ //
+ """
+SELECT ef_avg("p"."DecimalColumn")
+FROM "Prices" AS "p"
+""",
+ //
+ """
+SELECT ef_avg("p"."NullableDecimalColumn")
+FROM "Prices" AS "p"
+""");
+ }
}
diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/Ef6GroupBySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/Ef6GroupBySqliteTest.cs
index db7b6d52dd6..01aab7d8271 100644
--- a/test/EFCore.Sqlite.FunctionalTests/Query/Ef6GroupBySqliteTest.cs
+++ b/test/EFCore.Sqlite.FunctionalTests/Query/Ef6GroupBySqliteTest.cs
@@ -17,10 +17,16 @@ public Ef6GroupBySqliteTest(Ef6GroupBySqliteFixture fixture, ITestOutputHelper t
}
public override async Task Average_Grouped_from_LINQ_101(bool async)
- => Assert.Equal(
- SqliteStrings.AggregateOperationNotSupported("Average", "decimal"),
- (await Assert.ThrowsAsync(
- () => base.Average_Grouped_from_LINQ_101(async))).Message);
+ {
+ await base.Average_Grouped_from_LINQ_101(async);
+
+ AssertSql(
+ """
+SELECT "p"."Category", ef_avg("p"."UnitPrice") AS "AveragePrice"
+FROM "ProductForLinq" AS "p"
+GROUP BY "p"."Category"
+""");
+ }
public override async Task Max_Grouped_from_LINQ_101(bool async)
=> Assert.Equal(
@@ -49,6 +55,9 @@ public override async Task Group_Join_from_LINQ_101(bool async)
(await Assert.ThrowsAsync(
() => base.Group_Join_from_LINQ_101(async))).Message);
+ private void AssertSql(params string[] expected)
+ => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
+
public class Ef6GroupBySqliteFixture : Ef6GroupByFixtureBase, ITestSqlLoggerFactory
{
public TestSqlLoggerFactory TestSqlLoggerFactory
diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqliteTest.cs
index f2e19956f0c..2c07e6a4fec 100644
--- a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqliteTest.cs
+++ b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqliteTest.cs
@@ -20,46 +20,98 @@ public NorthwindAggregateOperatorsQuerySqliteTest(
}
public override async Task Sum_with_division_on_decimal(bool async)
- => Assert.Equal(
- SqliteStrings.AggregateOperationNotSupported("Sum", "decimal"),
- (await Assert.ThrowsAsync(
- async () => await base.Sum_with_division_on_decimal(async)))
- .Message);
+ {
+ await base.Sum_with_division_on_decimal(async);
+
+ AssertSql(
+ """
+SELECT COALESCE(ef_sum(ef_divide(CAST("o"."Quantity" AS TEXT), '2.09')), '0.0')
+FROM "Order Details" AS "o"
+""");
+ }
public override async Task Sum_with_division_on_decimal_no_significant_digits(bool async)
- => Assert.Equal(
- SqliteStrings.AggregateOperationNotSupported("Sum", "decimal"),
- (await Assert.ThrowsAsync(
- async () => await base.Sum_with_division_on_decimal_no_significant_digits(async)))
- .Message);
+ {
+ await base.Sum_with_division_on_decimal_no_significant_digits(async);
+
+ AssertSql(
+ """
+SELECT COALESCE(ef_sum(ef_divide(CAST("o"."Quantity" AS TEXT), '2.0')), '0.0')
+FROM "Order Details" AS "o"
+""");
+ }
public override async Task Average_with_division_on_decimal(bool async)
- => Assert.Equal(
- SqliteStrings.AggregateOperationNotSupported("Average", "decimal"),
- (await Assert.ThrowsAsync(
- async () => await base.Average_with_division_on_decimal(async)))
- .Message);
+ {
+ await base.Average_with_division_on_decimal(async);
+
+ AssertSql(
+ """
+SELECT ef_avg(ef_divide(CAST("o"."Quantity" AS TEXT), '2.09'))
+FROM "Order Details" AS "o"
+""");
+ }
public override async Task Average_with_division_on_decimal_no_significant_digits(bool async)
- => Assert.Equal(
- SqliteStrings.AggregateOperationNotSupported("Average", "decimal"),
- (await Assert.ThrowsAsync(
- async () => await base.Average_with_division_on_decimal_no_significant_digits(async)))
- .Message);
+ {
+ await base.Average_with_division_on_decimal_no_significant_digits(async);
+
+ AssertSql(
+ """
+SELECT ef_avg(ef_divide(CAST("o"."Quantity" AS TEXT), '2.0'))
+FROM "Order Details" AS "o"
+""");
+ }
+
+
public override async Task Average_over_max_subquery_is_client_eval(bool async)
- => Assert.Equal(
- SqliteStrings.AggregateOperationNotSupported("Average", "decimal"),
- (await Assert.ThrowsAsync(
- async () => await base.Average_over_max_subquery_is_client_eval(async)))
- .Message);
+ {
+ await base.Average_over_max_subquery_is_client_eval(async);
+
+ AssertSql(
+ """
+@__p_0='3'
+
+SELECT ef_avg(CAST((
+ SELECT AVG(CAST(5 + (
+ SELECT MAX("o0"."ProductID")
+ FROM "Order Details" AS "o0"
+ WHERE "o"."OrderID" = "o0"."OrderID") AS REAL))
+ FROM "Orders" AS "o"
+ WHERE "c0"."CustomerID" = "o"."CustomerID") AS TEXT))
+FROM (
+ SELECT "c"."CustomerID"
+ FROM "Customers" AS "c"
+ ORDER BY "c"."CustomerID"
+ LIMIT @__p_0
+) AS "c0"
+""");
+ }
public override async Task Average_over_nested_subquery_is_client_eval(bool async)
- => Assert.Equal(
- SqliteStrings.AggregateOperationNotSupported("Average", "decimal"),
- (await Assert.ThrowsAsync(
- async () => await base.Average_over_nested_subquery_is_client_eval(async)))
- .Message);
+ {
+ await base.Average_over_nested_subquery_is_client_eval(async);
+
+ AssertSql(
+ """
+@__p_0='3'
+
+SELECT ef_avg(CAST((
+ SELECT AVG(5.0 + (
+ SELECT AVG(CAST("o0"."ProductID" AS REAL))
+ FROM "Order Details" AS "o0"
+ WHERE "o"."OrderID" = "o0"."OrderID"))
+ FROM "Orders" AS "o"
+ WHERE "c0"."CustomerID" = "o"."CustomerID") AS TEXT))
+FROM (
+ SELECT "c"."CustomerID"
+ FROM "Customers" AS "c"
+ ORDER BY "c"."CustomerID"
+ LIMIT @__p_0
+) AS "c0"
+""");
+ }
public override async Task Multiple_collection_navigation_with_FirstOrDefault_chained(bool async)
=> Assert.Equal(
@@ -203,16 +255,20 @@ ELSE 0
""");
}
+ public override async Task Type_casting_inside_sum(bool async)
+ {
+ await base.Type_casting_inside_sum(async);
+
+ AssertSql(
+ """
+SELECT COALESCE(ef_sum(CAST("o"."Discount" AS TEXT)), '0.0')
+FROM "Order Details" AS "o"
+""");
+ }
+
private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
protected override void ClearLog()
=> Fixture.TestSqlLoggerFactory.Clear();
-
- public override async Task Type_casting_inside_sum(bool async)
- => Assert.Equal(
- SqliteStrings.AggregateOperationNotSupported("Sum", "decimal"),
- (await Assert.ThrowsAsync(
- async () => await base.Type_casting_inside_sum(async)))
- .Message);
}
diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs
index a470258565c..e22de3f047f 100644
--- a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs
+++ b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindGroupByQuerySqliteTest.cs
@@ -59,10 +59,22 @@ public override Task GroupBy_aggregate_from_multiple_query_in_same_projection_3(
() => base.GroupBy_aggregate_from_multiple_query_in_same_projection_3(async));
public override async Task Odata_groupby_empty_key(bool async)
- => Assert.Equal(
- SqliteStrings.AggregateOperationNotSupported("Sum", "decimal"),
- (await Assert.ThrowsAsync(
- () => base.Odata_groupby_empty_key(async))).Message);
+ {
+ await base.Odata_groupby_empty_key(async);
+
+ AssertSql(
+ """
+SELECT 'TotalAmount' AS "Name", COALESCE(ef_sum(CAST("o0"."OrderID" AS TEXT)), '0.0') AS "Value"
+FROM (
+ SELECT "o"."OrderID", 1 AS "Key"
+ FROM "Orders" AS "o"
+) AS "o0"
+GROUP BY "o0"."Key"
+""");
+ }
+
+ private void AssertSql(params string[] expected)
+ => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
private static async Task AssertApplyNotSupported(Func query)
=> Assert.Equal(