diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index d4002997b79b4..fe737d6e81f77 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -37,7 +37,7 @@ use crate::logical_plan::{ use crate::optimizer::utils::exprlist_to_columns; use crate::prelude::JoinType; use crate::scalar::ScalarValue; -use crate::sql::utils::{make_decimal_type, normalize_ident}; +use crate::sql::utils::{make_decimal_type, normalize_ident, resolve_columns}; use crate::{ error::{DataFusionError, Result}, physical_plan::aggregates, @@ -1144,30 +1144,45 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { group_by_exprs: Vec, aggr_exprs: Vec, ) -> Result<(LogicalPlan, Vec, Option)> { + // create the aggregate plan + let plan = LogicalPlanBuilder::from(input.clone()) + .aggregate(group_by_exprs.clone(), aggr_exprs.clone())? + .build()?; + + // in this next section of code we are re-writing the projection to refer to columns + // output by the aggregate plan. For example, if the projection contains the expression + // `SUM(a)` then we replace that with a reference to a column `#SUM(a)` produced by + // the aggregate plan. + + // combine the original grouping and aggregate expressions into one list (note that + // we do not add the "having" expression since that is not part of the projection) let aggr_projection_exprs = group_by_exprs .iter() .chain(aggr_exprs.iter()) .cloned() .collect::>(); - let plan = LogicalPlanBuilder::from(input.clone()) - .aggregate(group_by_exprs, aggr_exprs)? - .build()?; + // now attempt to resolve columns and replace with fully-qualified columns + let aggr_projection_exprs = aggr_projection_exprs + .iter() + .map(|expr| resolve_columns(expr, &input)) + .collect::>>()?; - // After aggregation, these are all of the columns that will be - // available to next phases of planning. + // next we replace any expressions that are not a column with a column referencing + // an output column from the aggregate schema let column_exprs_post_aggr = aggr_projection_exprs .iter() .map(|expr| expr_as_column_expr(expr, &input)) .collect::>>()?; - // Rewrite the SELECT expression to use the columns produced by the - // aggregation. + // next we re-write the projection let select_exprs_post_aggr = select_exprs .iter() .map(|expr| rebase_expr(expr, &aggr_projection_exprs, &input)) .collect::>>()?; + // finally, we have some validation that the re-written projection can be resolved + // from the aggregate output columns check_columns_satisfy_exprs( &column_exprs_post_aggr, &select_exprs_post_aggr, diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs index cd1fb316b76d5..4acaa21ef73b4 100644 --- a/datafusion/core/src/sql/utils.rs +++ b/datafusion/core/src/sql/utils.rs @@ -155,6 +155,22 @@ pub(crate) fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result Result { + clone_with_replacement(expr, &|nested_expr| { + match nested_expr { + Expr::Column(col) => { + let field = plan.schema().field_from_column(col)?; + Ok(Some(Expr::Column(field.qualified_column()))) + } + _ => { + // keep recursing + Ok(None) + } + } + }) +} + /// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s. /// /// For example, the expression `a + b < 1` would require, as input, the 2 diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index 41f2471f6c9e7..e3da1b02195a5 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -211,6 +211,32 @@ async fn csv_query_having_without_group_by() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_group_by_substr() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + // there is an input column "c1" as well a projection expression aliased as "c1" + let sql = "SELECT substr(c1, 1, 1) c1 \ + FROM aggregate_test_100 \ + GROUP BY substr(c1, 1, 1) \ + "; + let actual = execute_to_batches(&ctx, sql).await; + #[rustfmt::skip] + let expected = vec![ + "+----+", + "| c1 |", + "+----+", + "| a |", + "| b |", + "| c |", + "| d |", + "| e |", + "+----+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_group_by_avg() -> Result<()> { let ctx = SessionContext::new();