diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index f91018d8bf645..97a63b3a95ab5 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -446,7 +446,7 @@ order by value desc; Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) > __sq_1.__value CrossJoin: - Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] + Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2))) AS SUM(partsupp.ps_supplycost * partsupp.ps_availqty)]] Inner Join: supplier.s_nationkey = nation.n_nationkey Inner Join: partsupp.ps_suppkey = supplier.s_suppkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] @@ -454,7 +454,7 @@ order by value desc; Filter: nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS __value, alias=__sq_1 - Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] + Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2))) AS SUM(partsupp.ps_supplycost * partsupp.ps_availqty)]] Inner Join: supplier.s_nationkey = nation.n_nationkey Inner Join: partsupp.ps_suppkey = supplier.s_suppkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] @@ -462,7 +462,7 @@ order by value desc; Filter: nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")]"# .to_string(); - assert_eq!(actual, expected); + assert_eq!(expected, actual); // assert data let results = execute_to_batches(&ctx, sql).await; diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 2073713ddd312..7590803c36699 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -103,8 +103,8 @@ fn optimize_internal( // ensure aggregate names don't change: // https://github.com/apache/arrow-datafusion/issues/3555 if matches!(expr, Expr::AggregateFunction { .. }) { - if let Some((alias, name)) = original_name.zip(expr.name().ok()) { - if alias != name { + if let Some(alias) = original_name { + if alias != expr.canonical_name() { return Ok(expr.alias(&alias)); } } diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 6dea1a243ff2d..056c79b1669fb 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -46,6 +46,17 @@ fn case_when() -> Result<()> { Ok(()) } +#[test] +fn case_when_aggregate() -> Result<()> { + let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8"; + let plan = test_sql(sql)?; + let expected = "Projection: test.col_utf8, SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\ + \n Aggregate: groupBy=[[test.col_utf8]], aggr=[[SUM(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\ + \n TableScan: test projection=[col_int32, col_utf8]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + #[test] fn unsigned_target_type() -> Result<()> { let sql = "SELECT * FROM test WHERE col_uint32 > 0";