From e1d5181f72038603521a9280a5dd208bcd0a72b5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 4 Oct 2022 10:36:54 -0600 Subject: [PATCH 1/4] Fix aggregate type coercion bug --- datafusion/core/tests/sql/subqueries.rs | 6 +++--- datafusion/optimizer/src/type_coercion.rs | 4 +++- datafusion/optimizer/tests/integration-test.rs | 11 +++++++++++ 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 4b4f23e13bfa0..5bef97edaac64 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -446,7 +446,7 @@ order by value desc; Projection: #partsupp.ps_partkey, #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value Filter: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) > #__sq_1.__value CrossJoin: - Aggregate: groupBy=[[#partsupp.ps_partkey]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]] + Aggregate: groupBy=[[#partsupp.ps_partkey]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2))) AS SUM(partsupp.ps_supplycost * partsupp.ps_availqty)]] Inner Join: #supplier.s_nationkey = #nation.n_nationkey Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] @@ -454,7 +454,7 @@ order by value desc; Filter: #nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("GERMANY")] Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS __value, alias=__sq_1 - Aggregate: groupBy=[[]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]] + Aggregate: groupBy=[[]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2))) AS SUM(partsupp.ps_supplycost * partsupp.ps_availqty)]] Inner Join: #supplier.s_nationkey = #nation.n_nationkey Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] @@ -462,7 +462,7 @@ order by value desc; Filter: #nation.n_name = Utf8("GERMANY") TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("GERMANY")]"# .to_string(); - assert_eq!(actual, expected); + assert_eq!(expected, actual); // assert data let results = execute_to_batches(&ctx, sql).await; diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index c7f107b5ec886..6a34f49532aeb 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -103,7 +103,9 @@ fn optimize_internal( // ensure aggregate names don't change: // https://github.com/apache/arrow-datafusion/issues/3555 if matches!(expr, Expr::AggregateFunction { .. }) { - if let Some((alias, name)) = original_name.zip(expr.name().ok()) { + if let Some((alias, name)) = + original_name.zip(Some(expr.canonical_name().replace("#", ""))) + { if alias != name { return Ok(expr.alias(&alias)); } diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index e7245c06c1021..373c0c95e7317 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -45,6 +45,17 @@ fn case_when() -> Result<()> { Ok(()) } +#[test] +fn case_when_aggregate() -> Result<()> { + let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8"; + let plan = test_sql(sql)?; + let expected = "Projection: #test.col_utf8, #SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\ + \n Aggregate: groupBy=[[#test.col_utf8]], aggr=[[SUM(CASE WHEN #test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\ + \n TableScan: test projection=[col_int32, col_utf8]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + #[test] fn unsigned_target_type() -> Result<()> { let sql = "SELECT * FROM test WHERE col_uint32 > 0"; From 693b81562701745a76e19afea173cda988e8416d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 4 Oct 2022 12:31:19 -0600 Subject: [PATCH 2/4] Clippy --- datafusion/optimizer/src/type_coercion.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 6a34f49532aeb..b8664b60ae2a0 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -103,9 +103,13 @@ fn optimize_internal( // ensure aggregate names don't change: // https://github.com/apache/arrow-datafusion/issues/3555 if matches!(expr, Expr::AggregateFunction { .. }) { - if let Some((alias, name)) = - original_name.zip(Some(expr.canonical_name().replace("#", ""))) - { + if let Some(alias) = original_name { + let name = expr + .canonical_name() + // TODO remove this hack - there is a difference in `expr.name()` + // and `expr.canonical_name()` with the use of '#' to prefix + // column names + .replace('#', ""); if alias != name { return Ok(expr.alias(&alias)); } From 2057c8d2d22ceb26fe2f027518044fb71241b131 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 4 Oct 2022 12:31:44 -0600 Subject: [PATCH 3/4] docs --- datafusion/optimizer/src/type_coercion.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index b8664b60ae2a0..be49c30d134cf 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -106,7 +106,7 @@ fn optimize_internal( if let Some(alias) = original_name { let name = expr .canonical_name() - // TODO remove this hack - there is a difference in `expr.name()` + // TODO remove this hack - there is a difference between `expr.name()` // and `expr.canonical_name()` with the use of '#' to prefix // column names .replace('#', ""); From bf832867358011cc77148cfb6df13c38a2c8d667 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 4 Oct 2022 22:32:20 -0600 Subject: [PATCH 4/4] remove hack --- datafusion/optimizer/src/type_coercion.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 5329fbd472eb1..7590803c36699 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -104,13 +104,7 @@ fn optimize_internal( // https://github.com/apache/arrow-datafusion/issues/3555 if matches!(expr, Expr::AggregateFunction { .. }) { if let Some(alias) = original_name { - let name = expr - .canonical_name() - // TODO remove this hack - there is a difference between `expr.name()` - // and `expr.canonical_name()` with the use of '#' to prefix - // column names - .replace('#', ""); - if alias != name { + if alias != expr.canonical_name() { return Ok(expr.alias(&alias)); } }