From 8e87787d451a818f5c6f3a61f0c5ac56e4d72a69 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Tue, 21 Mar 2023 15:50:26 +0800 Subject: [PATCH 1/6] add analyzer rule replace_grouping_func --- datafusion/core/tests/sql/group_by.rs | 22 ++ datafusion/expr/src/aggregate_function.rs | 10 + datafusion/expr/src/signature.rs | 2 + .../expr/src/type_coercion/aggregates.rs | 4 + .../expr/src/type_coercion/functions.rs | 3 + datafusion/expr/src/utils.rs | 4 +- .../src/{analyzer.rs => analyzer/mod.rs} | 7 +- .../src/analyzer/replace_grouping_func.rs | 213 ++++++++++++++++++ .../physical-expr/src/aggregate/build_in.rs | 15 +- .../physical-expr/src/aggregate/grouping.rs | 93 -------- datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/expressions/mod.rs | 1 - datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/logical_plan/to_proto.rs | 2 + 15 files changed, 278 insertions(+), 103 deletions(-) rename datafusion/optimizer/src/{analyzer.rs => analyzer/mod.rs} (97%) create mode 100644 datafusion/optimizer/src/analyzer/replace_grouping_func.rs delete mode 100644 datafusion/physical-expr/src/aggregate/grouping.rs diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index a92eaf0f4d311..a7e882a89e906 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -905,3 +905,25 @@ async fn csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> { assert_batches_sorted_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn csv_query_group_by_with_grouping_functions() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT GROUPING(c1), GROUPING_ID(c1), c1, avg(c12) FROM aggregate_test_100 GROUP BY GROUPING SETS((c1),(c1))"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+", + "| c1 | 
AVG(aggregate_test_100.c12) |", + "+----+-----------------------------+", + "| a | 0.48754517466109415 |", + "| b | 0.41040709263815384 |", + "| c | 0.6600456536439784 |", + "| d | 0.48855379387549824 |", + "| e | 0.48600669271341534 |", + "+----+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index b7fb7d47d2970..3bc6677a90e7b 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -63,6 +63,8 @@ pub enum AggregateFunction { ApproxMedian, /// Grouping Grouping, + /// GroupingID + GroupingID, } impl fmt::Display for AggregateFunction { @@ -101,6 +103,7 @@ impl FromStr for AggregateFunction { } "approx_median" => AggregateFunction::ApproxMedian, "grouping" => AggregateFunction::Grouping, + "grouping_id" => AggregateFunction::GroupingID, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {name}" @@ -158,6 +161,7 @@ pub fn return_type( Ok(coerced_data_types[0].clone()) } AggregateFunction::Grouping => Ok(DataType::Int32), + AggregateFunction::GroupingID => Ok(DataType::Int32), } } @@ -220,5 +224,11 @@ pub fn signature(fun: &AggregateFunction) -> Signature { .collect(), Volatility::Immutable, ), + AggregateFunction::GroupingID => { + Signature { + type_signature: TypeSignature::Arbitrary, + volatility: Volatility::Immutable, + } + } } } diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 19909cf2fbf44..7e565949e7384 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -56,6 +56,8 @@ pub enum TypeSignature { Any(usize), /// One of a list of signatures OneOf(Vec), + /// Arbitrary number of arguments of arbitrary types + Arbitrary } ///The Signature of a function defines its supported input types as well as its volatility. 
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 3ad197afb64a0..4b6ac91f089a1 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -217,6 +217,7 @@ pub fn coerce_types( } AggregateFunction::Median => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), + AggregateFunction::GroupingID => Ok(input_types.to_vec()), } } @@ -263,6 +264,9 @@ fn check_arg_count( ))); } } + TypeSignature::Arbitrary => { + return Ok(()) + } _ => { return Err(DataFusionError::Internal(format!( "Aggregate functions do not support this {signature:?}" diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index a038fdcc92d0d..1bdec8a4891ce 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -94,6 +94,9 @@ fn get_valid_types( .filter_map(|t| get_valid_types(t, current_types).ok()) .flatten() .collect::>(), + TypeSignature::Arbitrary => { + vec![current_types.to_vec()] + } }; Ok(valid_types) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index ea8607feeb0c3..53c02f2d2a1dc 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -385,7 +385,7 @@ pub fn find_window_exprs(exprs: &[Expr]) -> Vec { /// Search the provided `Expr`'s, and all of their nested `Expr`, for any that /// pass the provided test. The returned `Expr`'s are deduplicated and returned /// in order of appearance (depth first). -fn find_exprs_in_exprs(exprs: &[Expr], test_fn: &F) -> Vec +pub fn find_exprs_in_exprs(exprs: &[Expr], test_fn: &F) -> Vec where F: Fn(&Expr) -> bool, { @@ -403,7 +403,7 @@ where /// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the /// provided test. The returned `Expr`'s are deduplicated and returned in order /// of appearance (depth first). 
-fn find_exprs_in_expr(expr: &Expr, test_fn: &F) -> Vec +pub fn find_exprs_in_expr(expr: &Expr, test_fn: &F) -> Vec where F: Fn(&Expr) -> bool, { diff --git a/datafusion/optimizer/src/analyzer.rs b/datafusion/optimizer/src/analyzer/mod.rs similarity index 97% rename from datafusion/optimizer/src/analyzer.rs rename to datafusion/optimizer/src/analyzer/mod.rs index f2a1ba9d64bba..bb7369dfd1df2 100644 --- a/datafusion/optimizer/src/analyzer.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +mod replace_grouping_func; + use crate::rewrite::TreeNodeRewritable; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; @@ -23,6 +25,7 @@ use datafusion_expr::{Expr, LogicalPlan}; use log::{debug, trace}; use std::sync::Arc; use std::time::Instant; +use crate::analyzer::replace_grouping_func::ReplaceGroupingFunc; /// `AnalyzerRule` transforms the unresolved ['LogicalPlan']s and unresolved ['Expr']s into /// the resolved form. @@ -49,7 +52,9 @@ impl Default for Analyzer { impl Analyzer { /// Create a new analyzer using the recommended list of rules pub fn new() -> Self { - let rules = vec![]; + let rules: Vec> = vec![ + Arc::new(ReplaceGroupingFunc::new()), + ]; Self::with_rules(rules) } diff --git a/datafusion/optimizer/src/analyzer/replace_grouping_func.rs b/datafusion/optimizer/src/analyzer/replace_grouping_func.rs new file mode 100644 index 0000000000000..d7e3ebe248ba8 --- /dev/null +++ b/datafusion/optimizer/src/analyzer/replace_grouping_func.rs @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::analyzer::AnalyzerRule; +use crate::rewrite::TreeNodeRewritable; +use arrow::datatypes::DataType; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::expr_rewriter::rewrite_expr; +use datafusion_expr::utils::find_exprs_in_expr; +use datafusion_expr::{ + aggregate_function, bitwise_and, bitwise_shift_right, cast, col, lit, Filter, + GroupingSet, Sort, +}; +use datafusion_expr::{Aggregate, Expr, LogicalPlan}; + +use datafusion_common::{Column, DataFusionError, Result}; + +use hashbrown::HashSet; + +pub struct ReplaceGroupingFunc; + +impl ReplaceGroupingFunc { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +const INTERNAL_GROUPING_COLUMN: &str = "_grouping_id"; + +impl AnalyzerRule for ReplaceGroupingFunc { + fn analyze( + &self, + plan: &LogicalPlan, + _config: &ConfigOptions, + ) -> datafusion_common::Result { + plan.clone().transform_up(&|plan| match plan { + LogicalPlan::Aggregate(Aggregate { + input, + aggr_expr, + group_expr, + .. 
+ }) if contains_grouping_funcs_in_exprs(&aggr_expr) => { + let gid_column = Column { + relation: None, + name: INTERNAL_GROUPING_COLUMN.to_owned(), + }; + let distinct_group_by = distinct_group_exprs(&group_expr); + let new_agg_expr = aggr_expr + .into_iter() + .map(|expr| { + replace_grouping_func(expr, &distinct_group_by, gid_column.clone()) + }) + .collect::>>()?; + Ok(Some(LogicalPlan::Aggregate(Aggregate::try_new( + input, + group_expr, + new_agg_expr, + )?))) + } + LogicalPlan::Filter(Filter { predicate, .. }) + if contains_grouping_funcs(&predicate) => + { + Ok(None) + } + LogicalPlan::Sort(Sort { expr, .. }) + if contains_grouping_funcs_in_exprs(&expr) => + { + Ok(None) + } + _ => Ok(None), + }) + } + fn name(&self) -> &str { + "replace_grouping_func" + } +} + +pub fn distinct_group_exprs(group_expr: &[Expr]) -> Vec { + let mut dedup_expr = Vec::new(); + let mut dedup_set = HashSet::new(); + group_expr.iter().for_each(|expr| match expr { + Expr::GroupingSet(grouping_set) => match grouping_set { + GroupingSet::Rollup(exprs) => exprs.iter().for_each(|e| { + if !dedup_set.contains(e) { + dedup_expr.push(e.clone()); + dedup_set.insert(e.clone()); + } + }), + GroupingSet::Cube(exprs) => exprs.iter().for_each(|e| { + if !dedup_set.contains(e) { + dedup_expr.push(e.clone()); + dedup_set.insert(e.clone()); + } + }), + GroupingSet::GroupingSets(groups) => groups.iter().flatten().for_each(|e| { + if !dedup_set.contains(e) { + dedup_expr.push(e.clone()); + dedup_set.insert(e.clone()); + } + }), + }, + _ => { + if !dedup_set.contains(expr) { + dedup_expr.push(expr.clone()); + dedup_set.insert(expr.clone()); + } + } + }); + dedup_expr +} + +fn contains_grouping_funcs_in_exprs(aggr_expr: &[Expr]) -> bool { + aggr_expr.iter().any(|expr| { + !find_exprs_in_expr(expr, &|nested_expr| { + matches!( + nested_expr, + Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::Grouping, + .. 
+ }) | Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::GroupingID, + .. + }) + ) + }) + .is_empty() + }) +} + +fn contains_grouping_funcs(expr: &Expr) -> bool { + !find_exprs_in_expr(expr, &|nested_expr| { + matches!( + nested_expr, + Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::Grouping, + .. + }) | Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::GroupingID, + .. + }) + ) + }) + .is_empty() +} + +fn replace_grouping_func( + expr: Expr, + group_by_exprs: &[Expr], + gid_column: Column, +) -> Result { + rewrite_expr(expr, |expr| { + let display_name = expr.display_name()?; + match expr { + Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::Grouping, + args, + .. + }) => { + let grouping_col = &args[0]; + match group_by_exprs.iter().position(|e| e == grouping_col) { + Some(idx) => Ok(cast( + bitwise_and( + bitwise_shift_right( + col(gid_column.clone()), + lit((group_by_exprs.len() - 1 - idx) as u32), + ), + lit(1), + ), + DataType::Binary, + ).alias(display_name)), + None => Err(DataFusionError::Plan(format!( + "Column of GROUPING({:?}) can't be found in GROUP BY columns {:?}", + grouping_col, group_by_exprs + ))), + } + } + Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::GroupingID, + args, + .. 
+ }) => { + if group_by_exprs.is_empty() + || (group_by_exprs.len() == args.len() + && group_by_exprs.iter().zip(args.iter()).all(|(g, a)| g == a)) + { + Ok(col(gid_column.clone()).alias(display_name)) + } else { + Err(DataFusionError::Plan(format!( + "Columns of GROUPING_ID({:?}) does not match GROUP BY columns {:?}", + args, group_by_exprs + ))) + } + } + _ => Ok(expr), + } + }) +} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index b3dbef7dfdf53..774208b6b6cba 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -62,11 +62,6 @@ pub fn create_aggregate_expr( input_phy_exprs[0].clone(), name, )), - (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( - input_phy_exprs[0].clone(), - name, - return_type, - )), (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new( input_phy_exprs[0].clone(), name, @@ -250,6 +245,16 @@ pub fn create_aggregate_expr( "MEDIAN(DISTINCT) aggregations are not available".to_string(), )); } + (AggregateFunction::Grouping, _) => { + return Err(DataFusionError::Plan( + "GROUPING() aggregations are not evaluable".to_string(), + )); + } + (AggregateFunction::GroupingID, _) => { + return Err(DataFusionError::Plan( + "GROUPING_ID() aggregations are not evaluable".to_string(), + )); + } }) } diff --git a/datafusion/physical-expr/src/aggregate/grouping.rs b/datafusion/physical-expr/src/aggregate/grouping.rs deleted file mode 100644 index 9ddd17c035e84..0000000000000 --- a/datafusion/physical-expr/src/aggregate/grouping.rs +++ /dev/null @@ -1,93 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::sync::Arc; - -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; - -use crate::expressions::format_state_name; - -/// GROUPING aggregate expression -/// Returns the amount of non-null values of the given expression. -#[derive(Debug)] -pub struct Grouping { - name: String, - data_type: DataType, - nullable: bool, - expr: Arc, -} - -impl Grouping { - /// Create a new GROUPING aggregate function. 
- pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for Grouping { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "grouping"), - self.data_type.clone(), - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn create_accumulator(&self) -> Result> { - Err(DataFusionError::NotImplemented( - "physical plan is not yet implemented for GROUPING aggregate function" - .to_owned(), - )) - } - - fn name(&self) -> &str { - &self.name - } -} diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index c42a5c03b3060..9a25bc3d7a97a 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -35,7 +35,6 @@ pub(crate) mod correlation; pub(crate) mod count; pub(crate) mod count_distinct; pub(crate) mod covariance; -pub(crate) mod grouping; pub(crate) mod median; #[macro_use] pub(crate) mod min_max; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 63fb7b7d37ad5..5efda1c6c1342 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -52,7 +52,6 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; -pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; pub use 
crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 6209989deff0e..86054ec23108f 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2309,6 +2309,7 @@ pub enum AggregateFunction { ApproxPercentileContWithWeight = 16, Grouping = 17, Median = 18, + GroupingID = 19, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2337,6 +2338,7 @@ impl AggregateFunction { "APPROX_PERCENTILE_CONT_WITH_WEIGHT" } AggregateFunction::Grouping => "GROUPING", + AggregateFunction::GroupingID => "GROUPING_ID", AggregateFunction::Median => "MEDIAN", } } @@ -2363,6 +2365,7 @@ impl AggregateFunction { Some(Self::ApproxPercentileContWithWeight) } "GROUPING" => Some(Self::Grouping), + "GROUPING_ID" => Some(Self::GroupingID), "MEDIAN" => Some(Self::Median), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index aa416e63b8a63..119648e3559e4 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -502,6 +502,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian, protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::Median => Self::Median, + protobuf::AggregateFunction::GroupingID => Self::GroupingID, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index a794c9bd06051..ab5cc529565c1 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -371,6 +371,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::ApproxMedian => Self::ApproxMedian, AggregateFunction::Grouping => Self::Grouping, AggregateFunction::Median => 
Self::Median, + AggregateFunction::GroupingID => Self::GroupingID, } } } @@ -630,6 +631,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::GroupingID => protobuf::AggregateFunction::GroupingID, }; let aggregate_expr = protobuf::AggregateExprNode { From 6a9fbf3e8b64255bbefc924a0112fcc9b36595d4 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Tue, 21 Mar 2023 16:44:35 +0800 Subject: [PATCH 2/6] add VirtualColumn --- .../core/src/datasource/listing/helpers.rs | 1 + datafusion/core/src/physical_plan/planner.rs | 3 +++ datafusion/core/tests/sql/group_by.rs | 17 +++++++++++++++++ datafusion/expr/src/expr.rs | 5 +++++ datafusion/expr/src/expr_rewriter.rs | 1 + datafusion/expr/src/expr_schema.rs | 2 ++ datafusion/expr/src/expr_visitor.rs | 1 + datafusion/expr/src/utils.rs | 3 ++- .../src/analyzer/replace_grouping_func.rs | 13 +++++-------- .../src/simplify_expressions/expr_simplifier.rs | 1 + datafusion/proto/src/logical_plan/to_proto.rs | 4 ++-- datafusion/sql/src/utils.rs | 1 + 12 files changed, 41 insertions(+), 11 deletions(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 7ed6326a907aa..77527fee5722c 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -82,6 +82,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> { Expr::Literal(_) | Expr::Alias(_, _) | Expr::OuterReferenceColumn(_, _) + | Expr::VirtualColumn(_, _) | Expr::ScalarVariable(_, _) | Expr::Not(_) | Expr::IsNotNull(_) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 51653450a6996..cbd682a1261b8 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -349,6 +349,9 @@ fn 
create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::OuterReferenceColumn(_, _) => Err(DataFusionError::Internal( "Create physical name does not support OuterReferenceColumn".to_string(), )), + Expr::VirtualColumn(_dt, c) => { + Ok(c.to_string()) + } } } diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index ed66ce7bd6357..ae4b58fbd0abb 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -911,6 +911,23 @@ async fn csv_query_group_by_with_grouping_functions() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; let sql = "SELECT GROUPING(c1), GROUPING_ID(c1), c1, avg(c12) FROM aggregate_test_100 GROUP BY GROUPING SETS((c1),(c1))"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + let expected = vec![ + "Projection: GROUPING(aggregate_test_100.c1), GROUPINGID(aggregate_test_100.c1), aggregate_test_100.c1, AVG(aggregate_test_100.c12) [GROUPING(aggregate_test_100.c1):Int32;N, GROUPINGID(aggregate_test_100.c1):Int32;N, c1:Utf8, AVG(aggregate_test_100.c12):Float64;N]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1), (aggregate_test_100.c1))]], aggr=[[CAST(CAST(_virtual_grouping_id AS _virtual_grouping_id AS Int64) & Int64(1) AS Binary) AS GROUPING(aggregate_test_100.c1), _virtual_grouping_id AS _virtual_grouping_id AS GROUPINGID(aggregate_test_100.c1), AVG(aggregate_test_100.c12)]] [c1:Utf8, GROUPING(aggregate_test_100.c1):Binary, GROUPINGID(aggregate_test_100.c1):Int32, AVG(aggregate_test_100.c12):Float64;N]", + " Projection: _virtual_grouping_id, aggregate_test_100.c1, aggregate_test_100.c12 [_virtual_grouping_id:Int32, c1:Utf8, c12:Float64]", + " TableScan: aggregate_test_100 projection=[c1, c12] [c1:Utf8, c12:Float64]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> 
= formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------+", diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 01860ae7d411d..414ee8583d287 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -223,6 +223,8 @@ pub enum Expr { /// A place holder which hold a reference to a qualified field /// in the outer query, used for correlated sub queries. OuterReferenceColumn(DataType, Column), + /// A virtual column used by the system internally + VirtualColumn(DataType, String), } /// Binary expression @@ -600,6 +602,7 @@ impl Expr { Expr::TryCast { .. } => "TryCast", Expr::WindowFunction { .. } => "WindowFunction", Expr::Wildcard => "Wildcard", + Expr::VirtualColumn(..) => "VirtualColumn", } } @@ -1081,6 +1084,7 @@ impl fmt::Debug for Expr { } }, Expr::Placeholder { id, .. } => write!(f, "{id}"), + Expr::VirtualColumn(_, c) => write!(f, "_virtual_{}", c), } } } @@ -1364,6 +1368,7 @@ fn create_name(e: &Expr) -> Result { "Create name does not support qualified wildcard".to_string(), )), Expr::Placeholder { id, .. } => Ok((*id).to_string()), + Expr::VirtualColumn(_, c) => Ok(format!("_virtual_{}", c)), } } diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index b4e82be5781fd..56e2bbfb040ef 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -122,6 +122,7 @@ impl ExprRewritable for Expr { Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name), Expr::Column(_) => self.clone(), Expr::OuterReferenceColumn(_, _) => self.clone(), + Expr::VirtualColumn(_, _) => self.clone(), Expr::Exists { .. 
} => self.clone(), Expr::InSubquery { expr, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index fafda79a6f61d..2f671dc638a2d 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -66,6 +66,7 @@ impl ExprSchemable for Expr { Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), + Expr::VirtualColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), Expr::Literal(l) => Ok(l.get_datatype()), Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), @@ -212,6 +213,7 @@ impl ExprSchemable for Expr { | Expr::IsNotUnknown(_) | Expr::Exists { .. } | Expr::Placeholder { .. } => Ok(true), + | Expr::VirtualColumn(_, _) => Ok(false), Expr::InSubquery { expr, .. } => expr.nullable(input_schema), Expr::ScalarSubquery(subquery) => { Ok(subquery.subquery.schema().field(0).is_nullable()) diff --git a/datafusion/expr/src/expr_visitor.rs b/datafusion/expr/src/expr_visitor.rs index 84ca6f7ed9dfb..faf228513f6e2 100644 --- a/datafusion/expr/src/expr_visitor.rs +++ b/datafusion/expr/src/expr_visitor.rs @@ -136,6 +136,7 @@ impl ExprVisitable for Expr { Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) + | Expr::VirtualColumn(_, _) | Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Exists { .. } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 142939b64e35d..4fd8376cac331 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -135,7 +135,8 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::QualifiedWildcard { .. } | Expr::GetIndexedField { .. } | Expr::Placeholder { .. } - | Expr::OuterReferenceColumn { .. } => {} + | Expr::OuterReferenceColumn { .. 
} + | Expr::VirtualColumn {..} => {} } Ok(()) }) diff --git a/datafusion/optimizer/src/analyzer/replace_grouping_func.rs b/datafusion/optimizer/src/analyzer/replace_grouping_func.rs index dfaf6308961f5..8d64d13c6ddf0 100644 --- a/datafusion/optimizer/src/analyzer/replace_grouping_func.rs +++ b/datafusion/optimizer/src/analyzer/replace_grouping_func.rs @@ -41,7 +41,7 @@ impl ReplaceGroupingFunc { } } -const INTERNAL_GROUPING_COLUMN: &str = "_grouping_id"; +const INTERNAL_GROUPING_COLUMN: &str = "grouping_id"; impl AnalyzerRule for ReplaceGroupingFunc { fn analyze( @@ -56,10 +56,7 @@ impl AnalyzerRule for ReplaceGroupingFunc { group_expr, .. }) if contains_grouping_funcs_in_exprs(&aggr_expr) => { - let gid_column = Column { - relation: None, - name: INTERNAL_GROUPING_COLUMN.to_owned(), - }; + let gid_column = Expr::VirtualColumn(DataType::Int32, INTERNAL_GROUPING_COLUMN.to_string()); let distinct_group_by = distinct_group_exprs(&group_expr); let new_agg_expr = aggr_expr .into_iter() @@ -166,7 +163,7 @@ fn contains_grouping_funcs(expr: &Expr) -> bool { fn replace_grouping_func( expr: Expr, group_by_exprs: &[Expr], - gid_column: Column, + gid_column: Expr, ) -> Result { rewrite_expr(expr, |expr| { let display_name = expr.display_name()?; @@ -181,7 +178,7 @@ fn replace_grouping_func( Some(idx) => Ok(cast( bitwise_and( bitwise_shift_right( - col(gid_column.clone()), + gid_column.clone(), lit((group_by_exprs.len() - 1 - idx) as u32), ), lit(1), @@ -203,7 +200,7 @@ fn replace_grouping_func( || (group_by_exprs.len() == args.len() && group_by_exprs.iter().zip(args.iter()).all(|(g, a)| g == a)) { - Ok(col(gid_column.clone()).alias(display_name)) + Ok(gid_column.clone().alias(display_name)) } else { Err(DataFusionError::Plan(format!( "Columns of GROUPING_ID({:?}) does not match GROUP BY columns {:?}", diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 
bfac8da643dba..66a5a365162bf 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -255,6 +255,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::ScalarVariable(_, _) | Expr::Column(_) | Expr::OuterReferenceColumn(_, _) + | Expr::VirtualColumn(_, _) | Expr::Exists { .. } | Expr::InSubquery { .. } | Expr::ScalarSubquery(_) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 1048b1ed127f5..980662aa07157 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -856,10 +856,10 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Expr::Wildcard => Self { expr_type: Some(ExprType::Wildcard(true)), }, - Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::OuterReferenceColumn{..} => { + Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::OuterReferenceColumn{..} | Expr::VirtualColumn{..} => { // we would need to add logical plan operators to datafusion.proto to support this // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 - return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); + return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported | Exp:VirtualColumn not supported".to_string())); } Expr::GetIndexedField(GetIndexedField { key, expr }) => Self { diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 91cef6d4712e7..25b9ce680a234 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -362,6 +362,7 @@ where ))), Expr::Column { .. 
} | Expr::OuterReferenceColumn(_, _) + | Expr::VirtualColumn(_, _) | Expr::Literal(_) | Expr::ScalarVariable(_, _) | Expr::Exists { .. } From 1cee4f452e389fe7867f8ca7a00b59bb3fa7407a Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Tue, 21 Mar 2023 18:57:28 +0800 Subject: [PATCH 3/6] fix common_subexpr_eliminate --- datafusion/core/tests/sql/group_by.rs | 5 ++--- datafusion/optimizer/src/analyzer/replace_grouping_func.rs | 4 ++-- datafusion/optimizer/src/common_subexpr_eliminate.rs | 7 +++++++ 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index ae4b58fbd0abb..40135d74c29ac 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -917,9 +917,8 @@ async fn csv_query_group_by_with_grouping_functions() -> Result<()> { let plan = dataframe.into_optimized_plan()?; let expected = vec![ "Projection: GROUPING(aggregate_test_100.c1), GROUPINGID(aggregate_test_100.c1), aggregate_test_100.c1, AVG(aggregate_test_100.c12) [GROUPING(aggregate_test_100.c1):Int32;N, GROUPINGID(aggregate_test_100.c1):Int32;N, c1:Utf8, AVG(aggregate_test_100.c12):Float64;N]", - " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1), (aggregate_test_100.c1))]], aggr=[[CAST(CAST(_virtual_grouping_id AS _virtual_grouping_id AS Int64) & Int64(1) AS Binary) AS GROUPING(aggregate_test_100.c1), _virtual_grouping_id AS _virtual_grouping_id AS GROUPINGID(aggregate_test_100.c1), AVG(aggregate_test_100.c12)]] [c1:Utf8, GROUPING(aggregate_test_100.c1):Binary, GROUPINGID(aggregate_test_100.c1):Int32, AVG(aggregate_test_100.c12):Float64;N]", - " Projection: _virtual_grouping_id, aggregate_test_100.c1, aggregate_test_100.c12 [_virtual_grouping_id:Int32, c1:Utf8, c12:Float64]", - " TableScan: aggregate_test_100 projection=[c1, c12] [c1:Utf8, c12:Float64]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1), (aggregate_test_100.c1))]], 
aggr=[[CAST(_virtual_grouping_id & UInt32(1) AS Binary) AS GROUPING(aggregate_test_100.c1), _virtual_grouping_id AS GROUPINGID(aggregate_test_100.c1), AVG(aggregate_test_100.c12)]] [c1:Utf8, GROUPING(aggregate_test_100.c1):Binary, GROUPINGID(aggregate_test_100.c1):UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " TableScan: aggregate_test_100 projection=[c1, c12] [c1:Utf8, c12:Float64]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); diff --git a/datafusion/optimizer/src/analyzer/replace_grouping_func.rs b/datafusion/optimizer/src/analyzer/replace_grouping_func.rs index 8d64d13c6ddf0..823886652e6b2 100644 --- a/datafusion/optimizer/src/analyzer/replace_grouping_func.rs +++ b/datafusion/optimizer/src/analyzer/replace_grouping_func.rs @@ -56,7 +56,7 @@ impl AnalyzerRule for ReplaceGroupingFunc { group_expr, .. }) if contains_grouping_funcs_in_exprs(&aggr_expr) => { - let gid_column = Expr::VirtualColumn(DataType::Int32, INTERNAL_GROUPING_COLUMN.to_string()); + let gid_column = Expr::VirtualColumn(DataType::UInt32, INTERNAL_GROUPING_COLUMN.to_string()); let distinct_group_by = distinct_group_exprs(&group_expr); let new_agg_expr = aggr_expr .into_iter() @@ -181,7 +181,7 @@ fn replace_grouping_func( gid_column.clone(), lit((group_by_exprs.len() - 1 - idx) as u32), ), - lit(1), + lit(1u32), ), DataType::Binary, ).alias(display_name)), diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 33bf676db1281..fef5d15f78567 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -18,6 +18,7 @@ //! Eliminate common sub-expression. 
use std::collections::{BTreeSet, HashMap}; +use std::convert::identity; use std::sync::Arc; use arrow::datatypes::DataType; @@ -537,8 +538,14 @@ impl ExprRewriter for CommonSubexprRewriter<'_> { if self.curr_index >= self.id_array.len() { return Ok(expr); } + if matches!(expr, Expr::VirtualColumn(_, _)) { + return Ok(expr); + } let (series_number, id) = &self.id_array[self.curr_index]; + if id.eq("_virtual_grouping_id") { + return Ok(expr); + } self.curr_index += 1; // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. let expr_set_item = self.expr_set.get(id).ok_or_else(|| { From b3c41a1ead01a9a195c8b3193e3c8fabab738fd5 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Mon, 27 Mar 2023 12:47:40 +0800 Subject: [PATCH 4/6] Add ResolveGroupingAnalytics analyzer rule --- .../core/src/datasource/listing/helpers.rs | 3 +- .../core/src/physical_plan/aggregates/mod.rs | 142 +++-- .../src/physical_plan/aggregates/row_hash.rs | 5 +- datafusion/core/src/physical_plan/planner.rs | 242 ++------ datafusion/core/tests/sql/group_by.rs | 533 +++++++++++++++++- datafusion/expr/src/aggregate_function.rs | 17 +- datafusion/expr/src/expr.rs | 72 ++- datafusion/expr/src/expr_rewriter.rs | 3 +- datafusion/expr/src/expr_schema.rs | 5 +- datafusion/expr/src/expr_visitor.rs | 5 +- datafusion/expr/src/logical_plan/plan.rs | 34 +- .../expr/src/type_coercion/aggregates.rs | 2 +- datafusion/expr/src/utils.rs | 201 +++++-- datafusion/optimizer/src/analyzer/mod.rs | 7 +- .../src/analyzer/replace_grouping_func.rs | 214 ------- .../analyzer/resolve_grouping_analytics.rs | 213 +++++++ .../optimizer/src/common_subexpr_eliminate.rs | 7 - datafusion/optimizer/src/push_down_filter.rs | 4 +- .../optimizer/src/push_down_projection.rs | 2 + .../simplify_expressions/expr_simplifier.rs | 3 +- .../src/single_distinct_to_groupby.rs | 6 +- .../physical-expr/src/aggregate/build_in.rs | 6 +- .../physical-expr/src/expressions/column.rs | 71 +++ 
.../physical-expr/src/expressions/mod.rs | 2 +- datafusion/physical-expr/src/planner.rs | 20 +- .../proto/src/logical_plan/from_proto.rs | 2 +- datafusion/proto/src/logical_plan/to_proto.rs | 8 +- datafusion/proto/src/physical_plan/mod.rs | 8 +- datafusion/sql/src/utils.rs | 3 +- 29 files changed, 1301 insertions(+), 539 deletions(-) delete mode 100644 datafusion/optimizer/src/analyzer/replace_grouping_func.rs create mode 100644 datafusion/optimizer/src/analyzer/resolve_grouping_analytics.rs diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 77527fee5722c..a1857ca171146 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -82,7 +82,8 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> { Expr::Literal(_) | Expr::Alias(_, _) | Expr::OuterReferenceColumn(_, _) - | Expr::VirtualColumn(_, _) + | Expr::HiddenColumn(_, _) + | Expr::HiddenExpr(_, _) | Expr::ScalarVariable(_, _) | Expr::Not(_) | Expr::IsNotNull(_) diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index c41cc438c8987..6b132dc0d79cf 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -86,7 +86,11 @@ pub enum AggregateMode { #[derive(Clone, Debug, Default)] pub struct PhysicalGroupBy { /// Distinct (Physical Expr, Alias) in the grouping set - expr: Vec<(Arc, String)>, + grouping_set_expr: Vec<(Arc, String)>, + /// Hidden grouping set expr in the grouping set + hidden_grouping_set_expr: Vec<(Arc, String)>, + /// Distinct result expr for the grouping set, used to generate output schema + result_expr: Vec<(Arc, String)>, /// Corresponding NULL expressions for expr null_expr: Vec<(Arc, String)>, /// Null mask for each group in this grouping set. 
Each group is @@ -99,12 +103,16 @@ pub struct PhysicalGroupBy { impl PhysicalGroupBy { /// Create a new `PhysicalGroupBy` pub fn new( - expr: Vec<(Arc, String)>, + grouping_set_expr: Vec<(Arc, String)>, + hidden_grouping_set_expr: Vec<(Arc, String)>, + result_expr: Vec<(Arc, String)>, null_expr: Vec<(Arc, String)>, groups: Vec>, ) -> Self { Self { - expr, + grouping_set_expr, + hidden_grouping_set_expr, + result_expr, null_expr, groups, } @@ -115,7 +123,9 @@ impl PhysicalGroupBy { pub fn new_single(expr: Vec<(Arc, String)>) -> Self { let num_exprs = expr.len(); Self { - expr, + grouping_set_expr: expr.clone(), + hidden_grouping_set_expr: vec![], + result_expr: expr, null_expr: vec![], groups: vec![vec![false; num_exprs]], } @@ -128,7 +138,12 @@ impl PhysicalGroupBy { /// Returns the group expressions pub fn expr(&self) -> &[(Arc, String)] { - &self.expr + &self.grouping_set_expr + } + + /// Returns the group result expressions + pub fn result_expr(&self) -> &[(Arc, String)] { + &self.result_expr } /// Returns the null expressions @@ -136,6 +151,11 @@ impl PhysicalGroupBy { &self.null_expr } + /// Returns the hidden grouping set expressions + pub fn hidden_grouping_set_expr(&self) -> &[(Arc, String)] { + &self.hidden_grouping_set_expr + } + /// Returns the group null masks pub fn groups(&self) -> &[Vec] { &self.groups @@ -143,7 +163,7 @@ impl PhysicalGroupBy { /// Returns true if this `PhysicalGroupBy` has no group expressions pub fn is_empty(&self) -> bool { - self.expr.is_empty() + self.grouping_set_expr.is_empty() } } @@ -196,7 +216,7 @@ impl AggregateExec { ) -> Result { let schema = create_schema( &input.schema(), - &group_by.expr, + group_by.result_expr(), &aggr_expr, group_by.contains_null(), mode, @@ -205,7 +225,7 @@ impl AggregateExec { let schema = Arc::new(schema); let mut alias_map: HashMap> = HashMap::new(); - for (expression, name) in group_by.expr.iter() { + for (expression, name) in group_by.result_expr().iter() { if let Some(column) = 
expression.as_any().downcast_ref::() { let new_col_idx = schema.index_of(name)?; // When the column name is the same, but index does not equal, treat it as Alias @@ -243,7 +263,7 @@ impl AggregateExec { // Update column indices. Since the group by columns come first in the output schema, their // indices are simply 0..self.group_expr(len). self.group_by - .expr() + .result_expr() .iter() .enumerate() .map(|(index, (_col, name))| { @@ -275,7 +295,7 @@ impl AggregateExec { let batch_size = context.session_config().batch_size(); let input = self.input.execute(partition, Arc::clone(&context))?; let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - if self.group_by.expr.is_empty() { + if self.group_by.result_expr().is_empty() { Ok(StreamType::AggregateStream(AggregateStream::new( self.mode, self.schema.clone(), @@ -418,7 +438,7 @@ impl ExecutionPlan for AggregateExec { write!(f, "AggregateExec: mode={:?}", self.mode)?; let g: Vec = if self.group_by.groups.len() == 1 { self.group_by - .expr + .grouping_set_expr .iter() .map(|(e, alias)| { let e = e.to_string(); @@ -447,7 +467,8 @@ impl ExecutionPlan for AggregateExec { e } } else { - let (e, alias) = &self.group_by.expr[idx]; + let (e, alias) = + &self.group_by.grouping_set_expr[idx]; let e = e.to_string(); if &e != alias { format!("{e} as {alias}") @@ -484,7 +505,7 @@ impl ExecutionPlan for AggregateExec { // - aggregations somtimes also preserve invariants such as min, max... 
match self.mode { AggregateMode::Final | AggregateMode::FinalPartitioned - if self.group_by.expr.is_empty() => + if self.group_by.result_expr().is_empty() => { Statistics { num_rows: Some(1), @@ -671,8 +692,8 @@ fn evaluate_group_by( group_by: &PhysicalGroupBy, batch: &RecordBatch, ) -> Result>> { - let exprs: Vec = group_by - .expr + let exprs_value: Vec = group_by + .grouping_set_expr .iter() .map(|(expr, _)| { let value = expr.evaluate(batch)?; @@ -680,7 +701,7 @@ fn evaluate_group_by( }) .collect::>>()?; - let null_exprs: Vec = group_by + let null_exprs_value: Vec = group_by .null_expr .iter() .map(|(expr, _)| { @@ -689,23 +710,61 @@ fn evaluate_group_by( }) .collect::>>()?; - Ok(group_by - .groups - .iter() - .map(|group| { - group - .iter() - .enumerate() - .map(|(idx, is_null)| { - if *is_null { - null_exprs[idx].clone() - } else { - exprs[idx].clone() - } - }) - .collect() - }) - .collect()) + if !group_by.hidden_grouping_set_expr().is_empty() { + let hidden_exprs_value: Vec = group_by + .hidden_grouping_set_expr + .iter() + .map(|(expr, _)| { + let value = expr.evaluate(batch)?; + Ok(value.into_array(batch.num_rows())) + }) + .collect::>>()?; + + let chunk_size = hidden_exprs_value.len() / group_by.groups.len(); + let hidden_expr_value_chunks = + hidden_exprs_value.chunks(chunk_size).collect::>(); + + Ok(group_by + .groups + .iter() + .enumerate() + .map(|(groud_id, group)| { + let mut group_data = group + .iter() + .enumerate() + .map(|(idx, is_null)| { + if *is_null { + null_exprs_value[idx].clone() + } else { + exprs_value[idx].clone() + } + }) + .collect::>(); + for data in hidden_expr_value_chunks[groud_id] { + group_data.push(data.clone()); + } + group_data + }) + .collect()) + } else { + Ok(group_by + .groups + .iter() + .map(|group| { + group + .iter() + .enumerate() + .map(|(idx, is_null)| { + if *is_null { + null_exprs_value[idx].clone() + } else { + exprs_value[idx].clone() + } + }) + .collect::>() + }) + .collect()) + } } #[cfg(test)] @@ -775,7 
+834,12 @@ mod tests { let input_schema = input.schema(); let grouping_set = PhysicalGroupBy { - expr: vec![ + grouping_set_expr: vec![ + (col("a", &input_schema)?, "a".to_string()), + (col("b", &input_schema)?, "b".to_string()), + ], + hidden_grouping_set_expr: vec![], + result_expr: vec![ (col("a", &input_schema)?, "a".to_string()), (col("b", &input_schema)?, "b".to_string()), ], @@ -890,9 +954,11 @@ mod tests { let input_schema = input.schema(); let grouping_set = PhysicalGroupBy { - expr: vec![(col("a", &input_schema)?, "a".to_string())], + grouping_set_expr: vec![(col("a", &input_schema)?, "a".to_string())], + hidden_grouping_set_expr: vec![], null_expr: vec![], groups: vec![vec![false]], + result_expr: vec![(col("a", &input_schema)?, "a".to_string())], }; let aggregates: Vec> = vec![Arc::new(Avg::new( @@ -929,7 +995,7 @@ mod tests { let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); let final_group: Vec<(Arc, String)> = grouping_set - .expr + .result_expr() .iter() .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone()))) .collect::>()?; @@ -1119,9 +1185,11 @@ mod tests { let groups_none = PhysicalGroupBy::default(); let groups_some = PhysicalGroupBy { - expr: vec![(col("a", &input_schema)?, "a".to_string())], + grouping_set_expr: vec![(col("a", &input_schema)?, "a".to_string())], + hidden_grouping_set_expr: vec![], null_expr: vec![], groups: vec![vec![false]], + result_expr: vec![(col("a", &input_schema)?, "a".to_string())], }; // something that allocates within the aggregator diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs index 612b707cc19e5..ea3c97481d8be 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs @@ -134,7 +134,7 @@ impl GroupedHashAggregateStream { ) -> Result { let timer = baseline_metrics.elapsed_compute().timer(); - let mut start_idx = 
group_by.expr.len(); + let mut start_idx = group_by.result_expr().len(); let mut row_aggr_expr = vec![]; let mut row_agg_indices = vec![]; let mut row_aggregate_expressions = vec![]; @@ -175,7 +175,8 @@ impl GroupedHashAggregateStream { let row_aggr_schema = aggr_state_schema(&row_aggr_expr)?; - let group_schema = group_schema(&schema, group_by.expr.len()); + let group_schema = group_schema(&schema, group_by.result_expr().len()); + let row_converter = RowConverter::new( group_schema .fields() diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index cbd682a1261b8..0bce37947df28 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -68,11 +68,10 @@ use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessar use datafusion_expr::{logical_plan, StringifiedPlan}; use datafusion_expr::{WindowFrame, WindowFrameBound}; use datafusion_optimizer::utils::unalias; -use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr::expressions::{HiddenColumn, Literal}; use datafusion_sql::utils::window_expr_common_partition_keys; use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; -use itertools::Itertools; use log::{debug, trace}; use std::collections::HashMap; use std::fmt::Write; @@ -334,6 +333,8 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Ok(format!("{expr} SIMILAR TO {pattern}{escape}")) } } + Expr::HiddenColumn(_dt, c) => Ok(format!("#{}", c)), + Expr::HiddenExpr(expr, _) => Ok(create_physical_name(expr, false)?), Expr::Sort { .. 
} => Err(DataFusionError::Internal( "Create physical name does not support sort expression".to_string(), )), @@ -349,9 +350,6 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::OuterReferenceColumn(_, _) => Err(DataFusionError::Internal( "Create physical name does not support OuterReferenceColumn".to_string(), )), - Expr::VirtualColumn(_dt, c) => { - Ok(c.to_string()) - } } } @@ -703,7 +701,7 @@ impl DefaultPhysicalPlanner { final_group .iter() .enumerate() - .map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone())) + .map(|(i, expr)| (expr.clone(), groups.result_expr()[i].1.clone())) .collect() ); @@ -1274,20 +1272,12 @@ impl DefaultPhysicalPlanner { session_state, ) } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => create_cube_physical_expr( - exprs, - input_dfschema, - input_schema, - session_state, - ), - Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { - create_rollup_physical_expr( - exprs, - input_dfschema, - input_schema, - session_state, - ) - } + Expr::GroupingSet(GroupingSet::Cube(_exprs)) => Err(DataFusionError::Internal( + "Unsupported logical plan: GroupingSet::Cube should be replaced to GroupingSet::GroupingSets".to_string(), + )), + Expr::GroupingSet(GroupingSet::Rollup(_exprs)) => Err(DataFusionError::Internal( + "Unsupported logical plan: GroupingSet::Rollup should be replaced to GroupingSet::GroupingSets".to_string(), + )), expr => Ok(PhysicalGroupBy::new_single(vec![tuple_err(( self.create_physical_expr( expr, @@ -1336,145 +1326,81 @@ fn merge_grouping_set_physical_expr( session_state: &SessionState, ) -> Result { let num_groups = grouping_sets.len(); - let mut all_exprs: Vec = vec![]; - let mut grouping_set_expr: Vec<(Arc, String)> = vec![]; - let mut null_exprs: Vec<(Arc, String)> = vec![]; + let mut all_normal_exprs: Vec = vec![]; + let mut all_hidden_result_exprs: Vec = vec![]; + + let mut grouping_set_phy_expr: Vec<(Arc, String)> = vec![]; + let mut hidden_grouping_set_phy_expr: Vec<(Arc, String)> = vec![]; + 
let mut null_phy_exprs: Vec<(Arc, String)> = vec![]; + + let mut hidden_grouping_set_result_phy_expr: Vec<(Arc, String)> = + vec![]; + let mut grouping_set_result_phy_expr: Vec<(Arc, String)> = vec![]; for expr in grouping_sets.iter().flatten() { - if !all_exprs.contains(expr) { - all_exprs.push(expr.clone()); + if let Expr::HiddenExpr(first, second) = expr { + if let Expr::HiddenColumn(dt, _) = second.as_ref() { + hidden_grouping_set_phy_expr.push(get_physical_expr_pair( + first, + input_dfschema, + input_schema, + session_state, + )?); + + if !all_hidden_result_exprs.contains(second) { + all_hidden_result_exprs.push(*second.clone()); + let hidden_column_name = second.display_name()?; + // The second element in the hidden expr should be converted to a physic HiddenColumn + hidden_grouping_set_result_phy_expr.push(( + Arc::new(HiddenColumn::new(&hidden_column_name, dt)), + hidden_column_name, + )); + } + } else { + return Err(DataFusionError::Internal( + "The second part of the Expr::HiddenExpr should be a Expr::HiddenColumn" + .to_string(), + )); + } + } else if !all_normal_exprs.contains(expr) { + all_normal_exprs.push(expr.clone()); - grouping_set_expr.push(get_physical_expr_pair( + let phy_expr = get_physical_expr_pair( expr, input_dfschema, input_schema, session_state, - )?); - - null_exprs.push(get_null_physical_expr_pair( + )?; + grouping_set_phy_expr.push(phy_expr.clone()); + null_phy_exprs.push(get_null_physical_expr_pair( expr, input_dfschema, input_schema, session_state, )?); + grouping_set_result_phy_expr.push(phy_expr); } } + grouping_set_result_phy_expr.append(&mut hidden_grouping_set_result_phy_expr); let mut merged_sets: Vec> = Vec::with_capacity(num_groups); - for expr_group in grouping_sets.iter() { - let group: Vec = all_exprs + let group: Vec = all_normal_exprs .iter() .map(|expr| !expr_group.contains(expr)) .collect(); - merged_sets.push(group) } Ok(PhysicalGroupBy::new( - grouping_set_expr, - null_exprs, + grouping_set_phy_expr, + 
hidden_grouping_set_phy_expr, + grouping_set_result_phy_expr, + null_phy_exprs, merged_sets, )) } -/// Expand and align a CUBE expression. This is a special case of GROUPING SETS -/// (see https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS) -fn create_cube_physical_expr( - exprs: &[Expr], - input_dfschema: &DFSchema, - input_schema: &Schema, - session_state: &SessionState, -) -> Result { - let num_of_exprs = exprs.len(); - let num_groups = num_of_exprs * num_of_exprs; - - let mut null_exprs: Vec<(Arc, String)> = - Vec::with_capacity(num_of_exprs); - let mut all_exprs: Vec<(Arc, String)> = - Vec::with_capacity(num_of_exprs); - - for expr in exprs { - null_exprs.push(get_null_physical_expr_pair( - expr, - input_dfschema, - input_schema, - session_state, - )?); - - all_exprs.push(get_physical_expr_pair( - expr, - input_dfschema, - input_schema, - session_state, - )?) - } - - let mut groups: Vec> = Vec::with_capacity(num_groups); - - groups.push(vec![false; num_of_exprs]); - - for null_count in 1..=num_of_exprs { - for null_idx in (0..num_of_exprs).combinations(null_count) { - let mut next_group: Vec = vec![false; num_of_exprs]; - null_idx.into_iter().for_each(|i| next_group[i] = true); - groups.push(next_group); - } - } - - Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups)) -} - -/// Expand and align a ROLLUP expression. 
This is a special case of GROUPING SETS -/// (see https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS) -fn create_rollup_physical_expr( - exprs: &[Expr], - input_dfschema: &DFSchema, - input_schema: &Schema, - session_state: &SessionState, -) -> Result { - let num_of_exprs = exprs.len(); - - let mut null_exprs: Vec<(Arc, String)> = - Vec::with_capacity(num_of_exprs); - let mut all_exprs: Vec<(Arc, String)> = - Vec::with_capacity(num_of_exprs); - - let mut groups: Vec> = Vec::with_capacity(num_of_exprs + 1); - - for expr in exprs { - null_exprs.push(get_null_physical_expr_pair( - expr, - input_dfschema, - input_schema, - session_state, - )?); - - all_exprs.push(get_physical_expr_pair( - expr, - input_dfschema, - input_schema, - session_state, - )?) - } - - for total in 0..=num_of_exprs { - let mut group: Vec = Vec::with_capacity(num_of_exprs); - - for index in 0..num_of_exprs { - if index < total { - group.push(false); - } else { - group.push(true); - } - } - - groups.push(group) - } - - Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups)) -} - /// For a given logical expr, get a properly typed NULL ScalarValue physical expression fn get_null_physical_expr_pair( expr: &Expr, @@ -1932,60 +1858,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_create_cube_expr() -> Result<()> { - let logical_plan = test_csv_scan().await?.build()?; - - let plan = plan(&logical_plan).await?; - - let exprs = vec![col("c1"), col("c2"), col("c3")]; - - let physical_input_schema = plan.schema(); - let physical_input_schema = physical_input_schema.as_ref(); - let logical_input_schema = logical_plan.schema(); - let session_state = make_session_state(); - - let cube = create_cube_physical_expr( - &exprs, - logical_input_schema, - physical_input_schema, - &session_state, - ); - - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, 
"c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; - - assert_eq!(format!("{cube:?}"), expected); - - Ok(()) - } - - #[tokio::test] - async fn test_create_rollup_expr() -> Result<()> { - let logical_plan = test_csv_scan().await?.build()?; - - let plan = plan(&logical_plan).await?; - - let exprs = vec![col("c1"), col("c2"), col("c3")]; - - let physical_input_schema = plan.schema(); - let physical_input_schema = physical_input_schema.as_ref(); - let logical_input_schema = logical_plan.schema(); - let session_state = make_session_state(); - - let rollup = create_rollup_physical_expr( - &exprs, - logical_input_schema, - physical_input_schema, - &session_state, - ); - - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; - - assert_eq!(format!("{rollup:?}"), expected); - - Ok(()) - } - #[tokio::test] async fn test_create_not() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index 40135d74c29ac..02d53b0c9bafe 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -110,6 +110,28 @@ async fn csv_query_group_by_boolean() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_group_by_boolean2() -> Result<()> { + let ctx = SessionContext::new(); + 
register_aggregate_simple_csv(&ctx).await?; + + let sql = + "SELECT COUNT(*), c3 FROM aggregate_simple GROUP BY c3 ORDER BY COUNT(*) DESC"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+-----------------+-------+", + "| COUNT(UInt8(1)) | c3 |", + "+-----------------+-------+", + "| 9 | true |", + "| 6 | false |", + "+-----------------+-------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + #[tokio::test] async fn csv_query_group_by_two_columns() -> Result<()> { let ctx = SessionContext::new(); @@ -907,17 +929,17 @@ async fn csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> { } #[tokio::test] -async fn csv_query_group_by_with_grouping_functions() -> Result<()> { +async fn group_by_with_dup_group_set() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; - let sql = "SELECT GROUPING(c1), GROUPING_ID(c1), c1, avg(c12) FROM aggregate_test_100 GROUP BY GROUPING SETS((c1),(c1))"; + let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY GROUPING SETS((c1),(c1),())"; let msg = format!("Creating logical plan for '{sql}'"); let dataframe = ctx.sql(sql).await.expect(&msg); - let plan = dataframe.into_optimized_plan()?; + let plan = dataframe.clone().into_optimized_plan()?; let expected = vec![ - "Projection: GROUPING(aggregate_test_100.c1), GROUPINGID(aggregate_test_100.c1), aggregate_test_100.c1, AVG(aggregate_test_100.c12) [GROUPING(aggregate_test_100.c1):Int32;N, GROUPINGID(aggregate_test_100.c1):Int32;N, c1:Utf8, AVG(aggregate_test_100.c12):Float64;N]", - " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1), (aggregate_test_100.c1))]], aggr=[[CAST(_virtual_grouping_id & UInt32(1) AS Binary) AS GROUPING(aggregate_test_100.c1), _virtual_grouping_id AS GROUPINGID(aggregate_test_100.c1), AVG(aggregate_test_100.c12)]] [c1:Utf8, GROUPING(aggregate_test_100.c1):Binary, GROUPINGID(aggregate_test_100.c1):UInt32, AVG(aggregate_test_100.c12):Float64;N]", + 
"Projection: aggregate_test_100.c1, AVG(aggregate_test_100.c12) [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1, UInt32(1) AS #grouping_set_id), (aggregate_test_100.c1, UInt32(2) AS #grouping_set_id), (UInt32(3) AS #grouping_set_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, #grouping_set_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", " TableScan: aggregate_test_100 projection=[c1, c12] [c1:Utf8, c12:Float64]", ]; let formatted = plan.display_indent_schema().to_string(); @@ -932,6 +954,12 @@ async fn csv_query_group_by_with_grouping_functions() -> Result<()> { "+----+-----------------------------+", "| c1 | AVG(aggregate_test_100.c12) |", "+----+-----------------------------+", + "| | 0.5089725099127211 |", + "| a | 0.48754517466109415 |", + "| b | 0.41040709263815384 |", + "| c | 0.6600456536439784 |", + "| d | 0.48855379387549824 |", + "| e | 0.48600669271341534 |", "| a | 0.48754517466109415 |", "| b | 0.41040709263815384 |", "| c | 0.6600456536439784 |", @@ -942,3 +970,498 @@ async fn csv_query_group_by_with_grouping_functions() -> Result<()> { assert_batches_sorted_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn group_by_with_grouping_id_func() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, c2, c3, GROUPING_ID(c1, c2, c3), avg(c12) FROM \ + (select c1, c2, '0' as c3, c12 from aggregate_test_100) + GROUP BY GROUPING SETS((c1, c2), (c1, c3), (c3),())"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: aggregate_test_100.c1, aggregate_test_100.c2, c3, #grouping_id AS GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2,c3), AVG(aggregate_test_100.c12) [c1:Utf8, c2:UInt32, c3:Utf8, GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2,c3):UInt32;N, 
AVG(aggregate_test_100.c12):Float64;N]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1, aggregate_test_100.c2, UInt32(1) AS #grouping_id), (aggregate_test_100.c1, c3, UInt32(2) AS #grouping_id), (c3, UInt32(6) AS #grouping_id), (UInt32(7) AS #grouping_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, c2:UInt32, c3:Utf8, #grouping_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " Projection: aggregate_test_100.c1, aggregate_test_100.c2, Utf8(\"0\") AS c3, aggregate_test_100.c12 [c1:Utf8, c2:UInt32, c3:Utf8, c12:Float64]", + " TableScan: aggregate_test_100 projection=[c1, c2, c12] [c1:Utf8, c2:UInt32, c12:Float64]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+----+-------------------------------------------------------------+-----------------------------+", + "| c1 | c2 | c3 | GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2,c3) | AVG(aggregate_test_100.c12) |", + "+----+----+----+-------------------------------------------------------------+-----------------------------+", + "| | | | 7 | 0.5089725099127211 |", + "| | | 0 | 6 | 0.5089725099127211 |", + "| a | | 0 | 2 | 0.48754517466109415 |", + "| a | 1 | | 1 | 0.4693685626367209 |", + "| a | 2 | | 1 | 0.5945188963859894 |", + "| a | 3 | | 1 | 0.5996111195922015 |", + "| a | 4 | | 1 | 0.3653038379118398 |", + "| a | 5 | | 1 | 0.3497223654469457 |", + "| b | | 0 | 2 | 0.41040709263815384 |", + "| b | 1 | | 1 | 0.16148594845154118 |", + "| b | 2 | | 1 | 0.5857678873564655 |", + "| b | 3 | | 1 | 0.42804338065410286 |", + "| b | 4 | | 1 | 0.33400957036260354 |", + "| b | 5 | | 1 | 0.4888141504446429 |", + "| c | | 0 | 2 | 0.6600456536439784 |", + "| c | 1 | | 1 | 0.6430620563927849 |", + "| c | 2 | | 1 | 
0.7736013221256991 |", + "| c | 3 | | 1 | 0.421733279717472 |", + "| c | 4 | | 1 | 0.6827805579021969 |", + "| c | 5 | | 1 | 0.7277229477969185 |", + "| d | | 0 | 2 | 0.48855379387549824 |", + "| d | 1 | | 1 | 0.49931809179640024 |", + "| d | 2 | | 1 | 0.5181987328311988 |", + "| d | 3 | | 1 | 0.586369575965718 |", + "| d | 4 | | 1 | 0.49575895804943215 |", + "| d | 5 | | 1 | 0.2488799233225611 |", + "| e | | 0 | 2 | 0.48600669271341534 |", + "| e | 1 | | 1 | 0.780297346359783 |", + "| e | 2 | | 1 | 0.660795726704708 |", + "| e | 3 | | 1 | 0.5165824734324667 |", + "| e | 4 | | 1 | 0.2720288398836001 |", + "| e | 5 | | 1 | 0.29536905073188496 |", + "+----+----+----+-------------------------------------------------------------+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn group_by_with_multi_grouping_funcs() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, c2, GROUPING(C1), GROUPING(C2), GROUPING_ID(c1, c2), avg(c12) FROM aggregate_test_100 GROUP BY CUBE(c1, c2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: aggregate_test_100.c1, aggregate_test_100.c2, CAST((#grouping_id AS #grouping_id AS #grouping_id AS #grouping_id >> UInt32(1)) & UInt32(1) AS UInt8) AS GROUPING(aggregate_test_100.c1), CAST(#grouping_id AS #grouping_id AS #grouping_id AS #grouping_id & UInt32(1) AS UInt8) AS GROUPING(aggregate_test_100.c2), #grouping_id AS #grouping_id AS #grouping_id AS #grouping_id AS GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2), AVG(aggregate_test_100.c12) [c1:Utf8, c2:UInt32, GROUPING(aggregate_test_100.c1):UInt8;N, GROUPING(aggregate_test_100.c2):UInt8;N, GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2):UInt32;N, AVG(aggregate_test_100.c12):Float64;N]", + " 
Aggregate: groupBy=[[GROUPING SETS ((UInt32(3) AS #grouping_id), (aggregate_test_100.c1, UInt32(1) AS #grouping_id), (aggregate_test_100.c2, UInt32(2) AS #grouping_id), (aggregate_test_100.c1, aggregate_test_100.c2, UInt32(0) AS #grouping_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, c2:UInt32, #grouping_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " TableScan: aggregate_test_100 projection=[c1, c2, c12] [c1:Utf8, c2:UInt32, c12:Float64]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+---------------------------------+---------------------------------+----------------------------------------------------------+-----------------------------+", + "| c1 | c2 | GROUPING(aggregate_test_100.c1) | GROUPING(aggregate_test_100.c2) | GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2) | AVG(aggregate_test_100.c12) |", + "+----+----+---------------------------------+---------------------------------+----------------------------------------------------------+-----------------------------+", + "| | | 1 | 1 | 3 | 0.5089725099127211 |", + "| | 1 | 1 | 0 | 2 | 0.5108939802619781 |", + "| | 2 | 1 | 0 | 2 | 0.6545641966127662 |", + "| | 3 | 1 | 0 | 2 | 0.5245329062820169 |", + "| | 4 | 1 | 0 | 2 | 0.40234192123489837 |", + "| | 5 | 1 | 0 | 2 | 0.4312272637333415 |", + "| a | | 0 | 1 | 1 | 0.48754517466109415 |", + "| a | 1 | 0 | 0 | 0 | 0.4693685626367209 |", + "| a | 2 | 0 | 0 | 0 | 0.5945188963859894 |", + "| a | 3 | 0 | 0 | 0 | 0.5996111195922015 |", + "| a | 4 | 0 | 0 | 0 | 0.3653038379118398 |", + "| a | 5 | 0 | 0 | 0 | 0.3497223654469457 |", + "| b | | 0 | 1 | 1 | 0.41040709263815384 |", + "| b | 1 | 0 | 0 | 0 | 0.16148594845154118 |", + "| b | 2 | 0 | 0 | 0 | 0.5857678873564655 |", + 
"| b | 3 | 0 | 0 | 0 | 0.42804338065410286 |", + "| b | 4 | 0 | 0 | 0 | 0.33400957036260354 |", + "| b | 5 | 0 | 0 | 0 | 0.4888141504446429 |", + "| c | | 0 | 1 | 1 | 0.6600456536439784 |", + "| c | 1 | 0 | 0 | 0 | 0.6430620563927849 |", + "| c | 2 | 0 | 0 | 0 | 0.7736013221256991 |", + "| c | 3 | 0 | 0 | 0 | 0.421733279717472 |", + "| c | 4 | 0 | 0 | 0 | 0.6827805579021969 |", + "| c | 5 | 0 | 0 | 0 | 0.7277229477969185 |", + "| d | | 0 | 1 | 1 | 0.48855379387549824 |", + "| d | 1 | 0 | 0 | 0 | 0.49931809179640024 |", + "| d | 2 | 0 | 0 | 0 | 0.5181987328311988 |", + "| d | 3 | 0 | 0 | 0 | 0.586369575965718 |", + "| d | 4 | 0 | 0 | 0 | 0.49575895804943215 |", + "| d | 5 | 0 | 0 | 0 | 0.2488799233225611 |", + "| e | | 0 | 1 | 1 | 0.48600669271341534 |", + "| e | 1 | 0 | 0 | 0 | 0.780297346359783 |", + "| e | 2 | 0 | 0 | 0 | 0.660795726704708 |", + "| e | 3 | 0 | 0 | 0 | 0.5165824734324667 |", + "| e | 4 | 0 | 0 | 0 | 0.2720288398836001 |", + "| e | 5 | 0 | 0 | 0 | 0.29536905073188496 |", + "+----+----+---------------------------------+---------------------------------+----------------------------------------------------------+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn group_by_with_dup_group_set_and_grouping_func() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, avg(c12), GROUPING(C1) FROM aggregate_test_100 GROUP BY GROUPING SETS((c1),(c1),())"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: aggregate_test_100.c1, AVG(aggregate_test_100.c12), CAST(#grouping_id & UInt32(1) AS UInt8) AS GROUPING(aggregate_test_100.c1) [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, GROUPING(aggregate_test_100.c1):UInt8;N]", + " Aggregate: groupBy=[[GROUPING SETS 
((aggregate_test_100.c1, UInt32(0) AS #grouping_id, UInt32(1) AS #grouping_set_id), (aggregate_test_100.c1, UInt32(0) AS #grouping_id, UInt32(2) AS #grouping_set_id), (UInt32(1) AS #grouping_id, UInt32(3) AS #grouping_set_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, #grouping_id:UInt32, #grouping_set_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " TableScan: aggregate_test_100 projection=[c1, c12] [c1:Utf8, c12:Float64]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+---------------------------------+", + "| c1 | AVG(aggregate_test_100.c12) | GROUPING(aggregate_test_100.c1) |", + "+----+-----------------------------+---------------------------------+", + "| | 0.5089725099127211 | 1 |", + "| a | 0.48754517466109415 | 0 |", + "| a | 0.48754517466109415 | 0 |", + "| b | 0.41040709263815384 | 0 |", + "| b | 0.41040709263815384 | 0 |", + "| c | 0.6600456536439784 | 0 |", + "| c | 0.6600456536439784 | 0 |", + "| d | 0.48855379387549824 | 0 |", + "| d | 0.48855379387549824 | 0 |", + "| e | 0.48600669271341534 | 0 |", + "| e | 0.48600669271341534 | 0 |", + "+----+-----------------------------+---------------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn group_by_with_grouping_func_and_having() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, avg(c12), GROUPING(C1) FROM aggregate_test_100 \ + GROUP BY GROUPING SETS((c1),(c1),()) HAVING GROUPING(C1) = 1 and avg(c12) > 0"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = 
dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: aggregate_test_100.c1, AVG(aggregate_test_100.c12), CAST(#grouping_id & UInt32(1) AS UInt8) AS GROUPING(aggregate_test_100.c1) [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, GROUPING(aggregate_test_100.c1):UInt8;N]", + " Filter: (#grouping_id & UInt32(1)) = UInt32(1) AND AVG(aggregate_test_100.c12) > Float64(0) [c1:Utf8, #grouping_id:UInt32, #grouping_set_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1, UInt32(0) AS #grouping_id, UInt32(1) AS #grouping_set_id), (aggregate_test_100.c1, UInt32(0) AS #grouping_id, UInt32(2) AS #grouping_set_id), (UInt32(1) AS #grouping_id, UInt32(3) AS #grouping_set_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, #grouping_id:UInt32, #grouping_set_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " TableScan: aggregate_test_100 projection=[c1, c12] [c1:Utf8, c12:Float64]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+---------------------------------+", + "| c1 | AVG(aggregate_test_100.c12) | GROUPING(aggregate_test_100.c1) |", + "+----+-----------------------------+---------------------------------+", + "| | 0.5089725099127211 | 1 |", + "+----+-----------------------------+---------------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn group_by_with_grouping_func_as_expr() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, avg(c12), GROUPING(C1) + GROUPING(C2) as grouping_lvl FROM aggregate_test_100 \ + GROUP BY GROUPING SETS((c1),(c1),(c1, c2))\ + ORDER 
BY CASE WHEN grouping_lvl = 0 THEN 0 ELSE 1 END"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Sort: CASE WHEN grouping_lvl = UInt8(0) THEN Int64(0) ELSE Int64(1) END AS CASE WHEN grouping_lvl = Int64(0) THEN Int64(0) ELSE Int64(1) END ASC NULLS LAST [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, grouping_lvl:UInt8;N]", + " Projection: aggregate_test_100.c1, AVG(aggregate_test_100.c12), CAST((#grouping_id >> UInt32(1)) & UInt32(1) AS UInt8) + CAST(#grouping_id & UInt32(1) AS UInt8) AS grouping_lvl [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, grouping_lvl:UInt8;N]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1, UInt32(1) AS #grouping_id, UInt32(1) AS #grouping_set_id), (aggregate_test_100.c1, UInt32(1) AS #grouping_id, UInt32(2) AS #grouping_set_id), (aggregate_test_100.c1, aggregate_test_100.c2, UInt32(0) AS #grouping_id, UInt32(3) AS #grouping_set_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, c2:UInt32, #grouping_id:UInt32, #grouping_set_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " TableScan: aggregate_test_100 projection=[c1, c2, c12] [c1:Utf8, c2:UInt32, c12:Float64]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+--------------+", + "| c1 | AVG(aggregate_test_100.c12) | grouping_lvl |", + "+----+-----------------------------+--------------+", + "| a | 0.3497223654469457 | 0 |", + "| a | 0.3653038379118398 | 0 |", + "| a | 0.4693685626367209 | 0 |", + "| a | 0.48754517466109415 | 1 |", + "| a | 0.48754517466109415 | 1 |", + "| a | 0.5945188963859894 | 0 |", + "| a | 
0.5996111195922015 | 0 |", + "| b | 0.16148594845154118 | 0 |", + "| b | 0.33400957036260354 | 0 |", + "| b | 0.41040709263815384 | 1 |", + "| b | 0.41040709263815384 | 1 |", + "| b | 0.42804338065410286 | 0 |", + "| b | 0.4888141504446429 | 0 |", + "| b | 0.5857678873564655 | 0 |", + "| c | 0.421733279717472 | 0 |", + "| c | 0.6430620563927849 | 0 |", + "| c | 0.6600456536439784 | 1 |", + "| c | 0.6600456536439784 | 1 |", + "| c | 0.6827805579021969 | 0 |", + "| c | 0.7277229477969185 | 0 |", + "| c | 0.7736013221256991 | 0 |", + "| d | 0.2488799233225611 | 0 |", + "| d | 0.48855379387549824 | 1 |", + "| d | 0.48855379387549824 | 1 |", + "| d | 0.49575895804943215 | 0 |", + "| d | 0.49931809179640024 | 0 |", + "| d | 0.5181987328311988 | 0 |", + "| d | 0.586369575965718 | 0 |", + "| e | 0.2720288398836001 | 0 |", + "| e | 0.29536905073188496 | 0 |", + "| e | 0.48600669271341534 | 1 |", + "| e | 0.48600669271341534 | 1 |", + "| e | 0.5165824734324667 | 0 |", + "| e | 0.660795726704708 | 0 |", + "| e | 0.780297346359783 | 0 |", + "+----+-----------------------------+--------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn group_by_with_grouping_func_and_order_by() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, avg(c12), GROUPING_ID(C1, c2) FROM aggregate_test_100 \ + GROUP BY CUBE(c1,c2) ORDER BY GROUPING_ID(C1, c2) DESC"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Sort: GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2) DESC NULLS FIRST [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2):UInt32;N]", + " Projection: aggregate_test_100.c1, AVG(aggregate_test_100.c12), #grouping_id AS GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2) 
[c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2):UInt32;N]", + " Aggregate: groupBy=[[GROUPING SETS ((UInt32(3) AS #grouping_id), (aggregate_test_100.c1, UInt32(1) AS #grouping_id), (aggregate_test_100.c2, UInt32(2) AS #grouping_id), (aggregate_test_100.c1, aggregate_test_100.c2, UInt32(0) AS #grouping_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, c2:UInt32, #grouping_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " TableScan: aggregate_test_100 projection=[c1, c2, c12] [c1:Utf8, c2:UInt32, c12:Float64]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+----------------------------------------------------------+", + "| c1 | AVG(aggregate_test_100.c12) | GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2) |", + "+----+-----------------------------+----------------------------------------------------------+", + "| | 0.5089725099127211 | 3 |", + "| | 0.6545641966127662 | 2 |", + "| | 0.5245329062820169 | 2 |", + "| | 0.4312272637333415 | 2 |", + "| | 0.5108939802619781 | 2 |", + "| | 0.40234192123489837 | 2 |", + "| b | 0.41040709263815384 | 1 |", + "| a | 0.48754517466109415 | 1 |", + "| d | 0.48855379387549824 | 1 |", + "| c | 0.6600456536439784 | 1 |", + "| e | 0.48600669271341534 | 1 |", + "| e | 0.5165824734324667 | 0 |", + "| c | 0.6430620563927849 | 0 |", + "| c | 0.6827805579021969 | 0 |", + "| b | 0.5857678873564655 | 0 |", + "| c | 0.7277229477969185 | 0 |", + "| e | 0.2720288398836001 | 0 |", + "| e | 0.780297346359783 | 0 |", + "| e | 0.29536905073188496 | 0 |", + "| b | 0.42804338065410286 | 0 |", + "| b | 0.4888141504446429 | 0 |", + "| b | 0.33400957036260354 | 0 |", + "| a | 
0.3653038379118398 | 0 |", + "| d | 0.5181987328311988 | 0 |", + "| a | 0.5945188963859894 | 0 |", + "| c | 0.7736013221256991 | 0 |", + "| b | 0.16148594845154118 | 0 |", + "| a | 0.5996111195922015 | 0 |", + "| d | 0.586369575965718 | 0 |", + "| d | 0.2488799233225611 | 0 |", + "| a | 0.4693685626367209 | 0 |", + "| d | 0.49931809179640024 | 0 |", + "| e | 0.660795726704708 | 0 |", + "| a | 0.3497223654469457 | 0 |", + "| d | 0.49575895804943215 | 0 |", + "| c | 0.421733279717472 | 0 |", + "+----+-----------------------------+----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn group_by_rollup_with_count_wildcard_and_order_by() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, c2, c3, COUNT(*) \ + FROM aggregate_test_100 \ + WHERE c1 IN ('a', 'b', NULL) \ + GROUP BY c1, ROLLUP (c2, c3) \ + ORDER BY c1, c2, c3"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Sort: aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST [c1:Utf8, c2:UInt32, c3:Int8, COUNT(UInt8(1)):Int64;N]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1), (aggregate_test_100.c1, aggregate_test_100.c2), (aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3))]], aggr=[[COUNT(UInt8(1))]] [c1:Utf8, c2:UInt32, c3:Int8, COUNT(UInt8(1)):Int64;N]", + " Filter: aggregate_test_100.c1 = Utf8(NULL) OR aggregate_test_100.c1 = Utf8(\"b\") OR aggregate_test_100.c1 = Utf8(\"a\") [c1:Utf8, c2:UInt32, c3:Int8]", + " TableScan: aggregate_test_100 projection=[c1, c2, c3], partial_filters=[aggregate_test_100.c1 = Utf8(NULL) OR aggregate_test_100.c1 = Utf8(\"b\") OR aggregate_test_100.c1 = Utf8(\"a\")] [c1:Utf8, c2:UInt32, c3:Int8]", + 
]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+------+-----------------+", + "| c1 | c2 | c3 | COUNT(UInt8(1)) |", + "+----+----+------+-----------------+", + "| a | 1 | -85 | 1 |", + "| a | 1 | -56 | 1 |", + "| a | 1 | -25 | 1 |", + "| a | 1 | -5 | 1 |", + "| a | 1 | 83 | 1 |", + "| a | 1 | | 5 |", + "| a | 2 | -48 | 1 |", + "| a | 2 | -43 | 1 |", + "| a | 2 | 45 | 1 |", + "| a | 2 | | 3 |", + "| a | 3 | -72 | 1 |", + "| a | 3 | -12 | 1 |", + "| a | 3 | 13 | 2 |", + "| a | 3 | 14 | 1 |", + "| a | 3 | 17 | 1 |", + "| a | 3 | | 6 |", + "| a | 4 | -101 | 1 |", + "| a | 4 | -54 | 1 |", + "| a | 4 | -38 | 1 |", + "| a | 4 | 65 | 1 |", + "| a | 4 | | 4 |", + "| a | 5 | -101 | 1 |", + "| a | 5 | -31 | 1 |", + "| a | 5 | 36 | 1 |", + "| a | 5 | | 3 |", + "| a | | | 21 |", + "| b | 1 | 12 | 1 |", + "| b | 1 | 29 | 1 |", + "| b | 1 | 54 | 1 |", + "| b | 1 | | 3 |", + "| b | 2 | -60 | 1 |", + "| b | 2 | 31 | 1 |", + "| b | 2 | 63 | 1 |", + "| b | 2 | 68 | 1 |", + "| b | 2 | | 4 |", + "| b | 3 | -101 | 1 |", + "| b | 3 | 17 | 1 |", + "| b | 3 | | 2 |", + "| b | 4 | -117 | 1 |", + "| b | 4 | -111 | 1 |", + "| b | 4 | -59 | 1 |", + "| b | 4 | 17 | 1 |", + "| b | 4 | 47 | 1 |", + "| b | 4 | | 5 |", + "| b | 5 | -82 | 1 |", + "| b | 5 | -44 | 1 |", + "| b | 5 | -5 | 1 |", + "| b | 5 | 62 | 1 |", + "| b | 5 | 68 | 1 |", + "| b | 5 | | 5 |", + "| b | | | 19 |", + "+----+----+------+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn invalid_grouping_func() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, avg(c12), GROUPING(c3) FROM aggregate_test_100 GROUP BY GROUPING SETS((c1),(c2),())"; + let 
msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let err = dataframe.into_optimized_plan().err().unwrap(); + assert_eq!( + "Plan(\"Column of GROUPING(aggregate_test_100.c3) can't be found in GROUP BY columns [aggregate_test_100.c1, aggregate_test_100.c2]\")", + &format!("{err:?}") + ); + + Ok(()) +} + +#[tokio::test] +async fn invalid_grouping_id_func() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, avg(c12), GROUPING_ID(c1) FROM aggregate_test_100 GROUP BY GROUPING SETS((c1),(c2),())"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let err = dataframe.into_optimized_plan().err().unwrap(); + assert_eq!( + "Plan(\"Columns of GROUPING_ID([aggregate_test_100.c1]) does not match GROUP BY columns [aggregate_test_100.c1, aggregate_test_100.c2]\")", + &format!("{err:?}") + ); + + Ok(()) +} + +#[tokio::test] +async fn invalid_grouping_id_func2() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + + // The column ordering of the GROUPING_ID() matters + let sql = "SELECT c1, avg(c12), GROUPING_ID(c2, c1) FROM aggregate_test_100 GROUP BY CUBE(c1,c2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let err = dataframe.into_optimized_plan().err().unwrap(); + assert_eq!( + "Plan(\"Columns of GROUPING_ID([aggregate_test_100.c2, aggregate_test_100.c1]) does not match GROUP BY columns [aggregate_test_100.c1, aggregate_test_100.c2]\")", + &format!("{err:?}") + ); + + Ok(()) +} diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 7b41010ee0867..ad5db1c0a14da 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -64,13 +64,18 @@ pub enum AggregateFunction { /// Grouping Grouping, /// GroupingID - 
GroupingID, + GroupingId, } impl fmt::Display for AggregateFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // uppercase of the debug. - write!(f, "{}", format!("{self:?}").to_uppercase()) + match self { + AggregateFunction::GroupingId => { + write!(f, "GROUPING_ID") + } + _ => write!(f, "{}", format!("{self:?}").to_uppercase()), + } } } @@ -103,7 +108,7 @@ impl FromStr for AggregateFunction { } "approx_median" => AggregateFunction::ApproxMedian, "grouping" => AggregateFunction::Grouping, - "grouping_id" => AggregateFunction::GroupingID, + "grouping_id" => AggregateFunction::GroupingId, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {name}" @@ -160,8 +165,8 @@ pub fn return_type( AggregateFunction::ApproxMedian | AggregateFunction::Median => { Ok(coerced_data_types[0].clone()) } - AggregateFunction::Grouping => Ok(DataType::Int32), - AggregateFunction::GroupingID => Ok(DataType::Int32), + AggregateFunction::Grouping => Ok(DataType::UInt8), + AggregateFunction::GroupingId => Ok(DataType::UInt32), } } @@ -224,7 +229,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature { .collect(), Volatility::Immutable, ), - AggregateFunction::GroupingID => Signature { + AggregateFunction::GroupingId => Signature { type_signature: TypeSignature::Arbitrary, volatility: Volatility::Immutable, }, diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index a67432f9ba295..4e36b4d67c115 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -21,7 +21,7 @@ use crate::aggregate_function; use crate::built_in_function; use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; -use crate::utils::{expr_to_columns, find_out_reference_exprs}; +use crate::utils::{expr_to_columns, find_hidden_columns, find_out_reference_exprs}; use crate::window_frame; use crate::window_function; use crate::AggregateUDF; @@ -223,8 +223,10 @@ pub enum Expr { /// A place holder which hold a reference to a 
qualified field /// in the outer query, used for correlated sub queries. OuterReferenceColumn(DataType, Column), - /// A virtual column used by the system internally - VirtualColumn(DataType, String), + /// A hidden column used by the system internally + HiddenColumn(DataType, String), + /// A hidden expr pair used by the system internally, evaluated to a HiddenColumn + HiddenExpr(Box, Box), } /// Binary expression @@ -507,14 +509,18 @@ impl GroupingSet { /// Return all distinct exprs in the grouping set. For `CUBE` and `ROLLUP` this /// is just the underlying list of exprs. For `GROUPING SET` we need to deduplicate /// the exprs in the underlying sets. - pub fn distinct_expr(&self) -> Vec { + pub fn distinct_expr(&self, include_hidden: bool) -> Vec { match self { GroupingSet::Rollup(exprs) => exprs.clone(), GroupingSet::Cube(exprs) => exprs.clone(), GroupingSet::GroupingSets(groups) => { let mut exprs: Vec = vec![]; for exp in groups.iter().flatten() { - if !exprs.contains(exp) { + if let Expr::HiddenExpr(_, second) = exp { + if include_hidden && !exprs.contains(second) { + exprs.push(*second.clone()); + } + } else if !exprs.contains(exp) { exprs.push(exp.clone()); } } @@ -522,6 +528,48 @@ impl GroupingSet { } } } + + pub fn contains_duplicate_grouping(&self) -> bool { + match self { + GroupingSet::Rollup(_) => false, + GroupingSet::Cube(_) => false, + GroupingSet::GroupingSets(groups) => { + let exclude_hidden = groups + .clone() + .into_iter() + .map(|group| { + group + .into_iter() + .filter(|e| !matches!(e, Expr::HiddenExpr(_, _))) + .collect::>() + }) + .collect::>(); + let exclude_hidden_len = exclude_hidden.len(); + let distinct_set = exclude_hidden.into_iter().collect::>(); + exclude_hidden_len != distinct_set.len() + } + } + } + + pub fn contains_hidden_expr(&self) -> bool { + match self { + GroupingSet::Rollup(_) => false, + GroupingSet::Cube(_) => false, + GroupingSet::GroupingSets(groups) => groups + .iter() + .flatten() + .any(|e| matches!(e, 
Expr::HiddenExpr(_, _))), + } + } + + /// Return the number of input exprs for `ROLLUP`/`CUBE`, or the number of groups for `GROUPING SETS` + pub fn input_expr_len(&self) -> usize { + match self { + GroupingSet::Rollup(exprs) => exprs.len(), + GroupingSet::Cube(exprs) => exprs.len(), + GroupingSet::GroupingSets(groups) => groups.len(), + } + } } /// Fixed seed for the hashing so that Ords are consistent across runs @@ -602,7 +650,8 @@ impl Expr { Expr::TryCast { .. } => "TryCast", Expr::WindowFunction { .. } => "WindowFunction", Expr::Wildcard => "Wildcard", - Expr::VirtualColumn(..) => "VirtualColumn", + Expr::HiddenColumn(..) => "HiddenColumn", + Expr::HiddenExpr(..) => "HiddenExpr", } } @@ -797,6 +846,11 @@ impl Expr { pub fn contains_outer(&self) -> bool { !find_out_reference_exprs(self).is_empty() } + + /// Return true when the expression contains hidden columns. + pub fn contains_hidden_columns(&self) -> bool { + !find_hidden_columns(self).is_empty() + } } impl Not for Expr { @@ -1084,7 +1138,8 @@ impl fmt::Debug for Expr { } }, Expr::Placeholder { id, .. } => write!(f, "{id}"), - Expr::VirtualColumn(_, c) => write!(f, "_virtual_{}", c), + Expr::HiddenColumn(_, c) => write!(f, "#{}", c), + Expr::HiddenExpr(first, _) => write!(f, "{}", first), } } } @@ -1368,7 +1423,8 @@ fn create_name(e: &Expr) -> Result { "Create name does not support qualified wildcard".to_string(), )), Expr::Placeholder { id, .. 
} => Ok((*id).to_string()), - Expr::VirtualColumn(_, c) => Ok(format!("_virtual_{}", c)), + Expr::HiddenColumn(_, c) => Ok(format!("#{}", c)), + Expr::HiddenExpr(first, _) => Ok(format!("#{}", first)), } } diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index 56e2bbfb040ef..5853890d54786 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -122,7 +122,8 @@ impl ExprRewritable for Expr { Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name), Expr::Column(_) => self.clone(), Expr::OuterReferenceColumn(_, _) => self.clone(), - Expr::VirtualColumn(_, _) => self.clone(), + Expr::HiddenColumn(_, _) => self.clone(), + Expr::HiddenExpr(_, _) => self.clone(), Expr::Exists { .. } => self.clone(), Expr::InSubquery { expr, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 2f671dc638a2d..d9b5e25155cc5 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -66,7 +66,8 @@ impl ExprSchemable for Expr { Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), - Expr::VirtualColumn(ty, _) => Ok(ty.clone()), + Expr::HiddenColumn(ty, _) => Ok(ty.clone()), + Expr::HiddenExpr(_, second) => second.get_type(schema), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), Expr::Literal(l) => Ok(l.get_datatype()), Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), @@ -213,7 +214,7 @@ impl ExprSchemable for Expr { | Expr::IsNotUnknown(_) | Expr::Exists { .. } | Expr::Placeholder { .. } => Ok(true), - | Expr::VirtualColumn(_, _) => Ok(false), + Expr::HiddenColumn(_, _) | Expr::HiddenExpr(_, _) => Ok(false), Expr::InSubquery { expr, .. 
} => expr.nullable(input_schema), Expr::ScalarSubquery(subquery) => { Ok(subquery.subquery.schema().field(0).is_nullable()) diff --git a/datafusion/expr/src/expr_visitor.rs b/datafusion/expr/src/expr_visitor.rs index faf228513f6e2..3426b4decedf6 100644 --- a/datafusion/expr/src/expr_visitor.rs +++ b/datafusion/expr/src/expr_visitor.rs @@ -118,7 +118,8 @@ impl ExprVisitable for Expr { | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) - | Expr::InSubquery { expr, .. } => expr.accept(visitor), + | Expr::InSubquery { expr, .. } + | Expr::HiddenExpr(expr, _) => expr.accept(visitor), Expr::GetIndexedField(GetIndexedField { expr, .. }) => expr.accept(visitor), Expr::GroupingSet(GroupingSet::Rollup(exprs)) => exprs .iter() @@ -136,7 +137,7 @@ impl ExprVisitable for Expr { Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) - | Expr::VirtualColumn(_, _) + | Expr::HiddenColumn(_, _) | Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Exists { .. } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index da490db39538f..7a31790e018d8 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -24,8 +24,8 @@ use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::plan; use crate::utils::{ - enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, from_plan, - grouping_set_expr_count, grouping_set_to_exprlist, + distinct_group_exprs, exprlist_to_fields, find_out_reference_exprs, from_plan, + grouping_set_expr_count, }; use crate::{ build_join_schema, Expr, ExprSchemable, TableProviderFilterPushDown, TableSource, @@ -1804,25 +1804,39 @@ pub struct Aggregate { } impl Aggregate { - /// Create a new aggregate operator. 
+ /// Create a new aggregate operator, the group_expr might contain multiple [Expr::GroupingSet] expressions pub fn try_new( input: Arc, group_expr: Vec, aggr_expr: Vec, ) -> Result { - let group_expr = enumerate_grouping_sets(group_expr)?; - let grouping_expr: Vec = grouping_set_to_exprlist(group_expr.as_slice())?; - let all_expr = grouping_expr.iter().chain(aggr_expr.iter()); + if group_expr.is_empty() && aggr_expr.is_empty() { + return Err(DataFusionError::Plan( + "Aggregate requires at least one grouping or aggregate expression" + .to_string(), + )); + } + let distinct_grouping_expr: Vec = + distinct_group_exprs(group_expr.as_slice(), true); + + let all_expr = distinct_grouping_expr.iter().chain(aggr_expr.iter()); validate_unique_names("Aggregations", all_expr.clone())?; - let schema = DFSchema::new_with_metadata( + let schema = Arc::new(DFSchema::new_with_metadata( exprlist_to_fields(all_expr, &input)?, input.schema().metadata().clone(), - )?; - Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema)) + )?); + + Ok(Self { + input, + group_expr, + aggr_expr, + schema, + }) } /// Create a new aggregate operator using the provided schema to avoid the overhead of - /// building the schema again when the schema is already known. + /// building the schema again when the schema is already known, + /// The group_expr can not contain multiple [Expr::GroupingSet] expressions. /// /// This method should only be called when you are absolutely sure that the schema being /// provided is correct for the aggregate. If in doubt, call [try_new](Self::try_new) instead. 
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 1afce05ea6424..97b9eb87affa3 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -217,7 +217,7 @@ pub fn coerce_types( } AggregateFunction::Median => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), - AggregateFunction::GroupingID => Ok(input_types.to_vec()), + AggregateFunction::GroupingId => Ok(input_types.to_vec()), } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 70e78def83d47..1209aec9f96ae 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -54,6 +54,11 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result Ok(()) } +/// Check whether the group_expr contains [Expr::GroupingSet]. +pub fn contains_grouping_set(group_expr: &[Expr]) -> bool { + group_expr.iter().any(|e| matches!(e, Expr::GroupingSet(_))) +} + /// Count the number of distinct exprs in a list of group by expressions. If the /// first element is a `GroupingSet` expression then it must be the only expr. 
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { @@ -64,7 +69,7 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { .to_string(), )); } - Ok(grouping_set.distinct_expr().len()) + Ok(grouping_set.distinct_expr(true).len()) } else { Ok(group_expr.len()) } @@ -119,6 +124,17 @@ fn check_grouping_set_size_limit(size: usize) -> Result<()> { Ok(()) } +/// check the number of distinct expressions contained in the grouping_set when using group id +fn check_grouping_set_distinct_expression_size_limit(size: usize) -> Result<()> { + // we use u32 to represent the grouping id + let max_expression_set_size = 32; + if size > max_expression_set_size { + return Err(DataFusionError::Plan(format!("The number of distinct group_expression in grouping_set exceeds the maximum limit {} when using group id, found {}", max_expression_set_size, size))); + } + + Ok(()) +} + /// check the number of grouping_set contained in the grouping sets fn check_grouping_sets_size_limit(size: usize) -> Result<()> { let max_grouping_sets_size = 4096; @@ -192,14 +208,10 @@ fn cross_join_grouping_sets( /// (person.id, person.age, person.state),\ /// (person.id, person.age, person.state, person.birth_date)\ /// ) -pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { - let has_grouping_set = group_expr - .iter() - .any(|expr| matches!(expr, Expr::GroupingSet(_))); - if !has_grouping_set || group_expr.len() == 1 { - return Ok(group_expr); +pub fn enumerate_grouping_sets(group_expr: &[Expr]) -> Result> { + if !contains_grouping_set(group_expr) { + return Ok(group_expr.to_vec()); } - // only process mix grouping sets let partial_sets = group_expr .iter() .map(|expr| { @@ -246,22 +258,94 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { ))]) } -/// Find all distinct exprs in a list of group by expressions. If the -/// first element is a `GroupingSet` expression then it must be the only expr. 
-pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { - if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { - if group_expr.len() > 1 { - return Err(DataFusionError::Plan( - "Invalid group by expressions, GroupingSet must be the only expression" - .to_string(), - )); +/// Generate the grouping ids for each group in this grouping set. +/// Each group id represents the level of grouping which combines the GROUPING() function +/// for several columns into one by assigning each column a bit. +/// +/// For example, we have the Group By columns (person.id, person.age, person.salary), +/// then the Grouping Set (person.id, person.age) will be represented as '001', the selected +/// column is set to '0' and the unselected is set to '1' +pub fn generate_grouping_ids(grouping_set: &GroupingSet) -> Result> { + match grouping_set { + GroupingSet::Rollup(_) => Ok(vec![]), + GroupingSet::Cube(_) => Ok(vec![]), + GroupingSet::GroupingSets(groups) => { + let distinct_exprs = grouping_set.distinct_expr(false); + check_grouping_set_distinct_expression_size_limit(distinct_exprs.len())?; + Ok(groups + .iter() + .map(|group| { + let mut mask = 0u32; + distinct_exprs.iter().for_each(|expr| { + mask = (mask << 1) + (if !group.contains(expr) { 1 } else { 0 }) + }); + mask + }) + .collect::>()) } - Ok(grouping_set.distinct_expr()) - } else { - Ok(group_expr.to_vec()) } } +/// Add hidden grouping set expression to each group in the grouping_set +pub fn add_hidden_grouping_set_expr( + grouping_set: &mut GroupingSet, + hidden_grouping_expr: F, +) -> Result<()> +where + F: Fn(usize) -> Expr, +{ + if let GroupingSet::GroupingSets(groups) = grouping_set { + groups + .iter_mut() + .enumerate() + .for_each(|(idx, expr)| expr.push(hidden_grouping_expr(idx))); + } + Ok(()) +} + +/// Find all distinct exprs in a list of group by expressions. 
+pub fn distinct_group_exprs(group_expr: &[Expr], include_hidden: bool) -> Vec { + let mut dedup_expr = Vec::new(); + let mut dedup_set = HashSet::new(); + let mut dedup_hidden_expr = Vec::new(); + let mut dedup_hidden_set = HashSet::new(); + group_expr.iter().for_each(|expr| match expr { + Expr::GroupingSet(grouping_set) => match grouping_set { + GroupingSet::Rollup(exprs) => exprs.iter().for_each(|e| { + if !dedup_set.contains(e) { + dedup_expr.push(e.clone()); + dedup_set.insert(e.clone()); + } + }), + GroupingSet::Cube(exprs) => exprs.iter().for_each(|e| { + if !dedup_set.contains(e) { + dedup_expr.push(e.clone()); + dedup_set.insert(e.clone()); + } + }), + GroupingSet::GroupingSets(groups) => groups.iter().flatten().for_each(|e| { + if let Expr::HiddenExpr(_, second) = e { + if include_hidden && !dedup_hidden_set.contains(second.as_ref()) { + dedup_hidden_expr.push(*second.clone()); + dedup_hidden_set.insert(*second.clone()); + } + } else if !dedup_set.contains(e) { + dedup_expr.push(e.clone()); + dedup_set.insert(e.clone()); + } + }), + }, + _ => { + if !dedup_set.contains(expr) { + dedup_expr.push(expr.clone()); + dedup_set.insert(expr.clone()); + } + } + }); + dedup_expr.append(&mut dedup_hidden_expr); + dedup_expr +} + /// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { @@ -312,7 +396,8 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::GetIndexedField { .. } | Expr::Placeholder { .. } | Expr::OuterReferenceColumn { .. } - | Expr::VirtualColumn {..} => {} + | Expr::HiddenColumn { .. } + | Expr::HiddenExpr { .. } => {} } Ok(()) }) @@ -568,6 +653,14 @@ pub fn find_out_reference_exprs(expr: &Expr) -> Vec { }) } +/// Collect all deeply nested `Expr::OuterReferenceColumn`. They are returned in order of occurrence +/// (depth first), with duplicates omitted. 
+pub fn find_hidden_columns(expr: &Expr) -> Vec { + find_exprs_in_expr(expr, &|nested_expr| { + matches!(nested_expr, Expr::HiddenColumn { .. }) + }) +} + /// Search the provided `Expr`'s, and all of their nested `Expr`, for any that /// pass the provided test. The returned `Expr`'s are deduplicated and returned /// in order of appearance (depth first). @@ -1433,22 +1526,25 @@ mod tests { let grouping_set = grouping_set(vec![multi_cols]); // 1. col - let sets = enumerate_grouping_sets(vec![simple_col.clone()])?; + let sets = enumerate_grouping_sets(&vec![simple_col.clone()])?; let result = format!("{sets:?}"); assert_eq!("[simple_col]", &result); // 2. cube - let sets = enumerate_grouping_sets(vec![cube.clone()])?; + let sets = enumerate_grouping_sets(&vec![cube.clone()])?; let result = format!("{sets:?}"); - assert_eq!("[CUBE (col1, col2, col3)]", &result); + assert_eq!("[GROUPING SETS ((), (col1), (col2), (col1, col2), (col3), (col1, col3), (col2, col3), (col1, col2, col3))]", &result); // 3. rollup - let sets = enumerate_grouping_sets(vec![rollup.clone()])?; + let sets = enumerate_grouping_sets(&vec![rollup.clone()])?; let result = format!("{sets:?}"); - assert_eq!("[ROLLUP (col1, col2, col3)]", &result); + assert_eq!( + "[GROUPING SETS ((), (col1), (col1, col2), (col1, col2, col3))]", + &result + ); // 4. col + cube - let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?; + let sets = enumerate_grouping_sets(&vec![simple_col.clone(), cube.clone()])?; let result = format!("{sets:?}"); assert_eq!( "[GROUPING SETS (\ @@ -1464,7 +1560,7 @@ mod tests { ); // 5. col + rollup - let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?; + let sets = enumerate_grouping_sets(&vec![simple_col.clone(), rollup.clone()])?; let result = format!("{sets:?}"); assert_eq!( "[GROUPING SETS (\ @@ -1477,7 +1573,7 @@ mod tests { // 6. 
col + grouping_set let sets = - enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?; + enumerate_grouping_sets(&vec![simple_col.clone(), grouping_set.clone()])?; let result = format!("{sets:?}"); assert_eq!( "[GROUPING SETS (\ @@ -1486,11 +1582,9 @@ mod tests { ); // 7. col + grouping_set + rollup - let sets = enumerate_grouping_sets(vec![ - simple_col.clone(), - grouping_set, - rollup.clone(), - ])?; + let sets = enumerate_grouping_sets( + vec![simple_col.clone(), grouping_set, rollup.clone()].as_slice(), + )?; let result = format!("{sets:?}"); assert_eq!( "[GROUPING SETS (\ @@ -1502,7 +1596,7 @@ mod tests { ); // 8. col + cube + rollup - let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?; + let sets = enumerate_grouping_sets(&vec![simple_col, cube, rollup])?; let result = format!("{sets:?}"); assert_eq!( "[GROUPING SETS (\ @@ -1543,4 +1637,41 @@ mod tests { Ok(()) } + + #[test] + fn test_generate_grouping_ids() -> Result<()> { + // 001 + let multi_cols1 = vec![col("col1"), col("col2")]; + // 010 + let multi_cols2 = vec![col("col1"), col("col3")]; + // 100 + let multi_cols3 = vec![col("col2"), col("col3")]; + // 000 + let multi_cols4 = vec![col("col1"), col("col2"), col("col3")]; + // 011 + let multi_cols5 = vec![col("col1")]; + // 101 + let multi_cols6 = vec![col("col2")]; + // 110 + let multi_cols7 = vec![col("col3")]; + // 011 + let multi_cols8 = vec![col("col1"), col("col1"), col("col1")]; + + let grouping_set = GroupingSet::GroupingSets(vec![ + multi_cols1, + multi_cols2, + multi_cols3, + multi_cols4, + multi_cols5, + multi_cols6, + multi_cols7, + multi_cols8, + ]); + + let grouping_id = generate_grouping_ids(&grouping_set)?; + let grouping_id_result = format!("{grouping_id:?}"); + assert_eq!("[1, 2, 4, 0, 3, 5, 6, 3]", &grouping_id_result); + + Ok(()) + } } diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 116ed3feae862..9986f47e45208 100644 --- 
a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -16,11 +16,10 @@ // under the License. mod count_wildcard_rule; -mod replace_grouping_func; - +mod resolve_grouping_analytics; use crate::analyzer::count_wildcard_rule::CountWildcardRule; -use crate::analyzer::replace_grouping_func::ReplaceGroupingFunc; +use crate::analyzer::resolve_grouping_analytics::ResolveGroupingAnalytics; use crate::rewrite::TreeNodeRewritable; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; @@ -56,8 +55,8 @@ impl Analyzer { /// Create a new analyzer using the recommended list of rules pub fn new() -> Self { let rules: Vec> = vec![ + Arc::new(ResolveGroupingAnalytics::new()), Arc::new(CountWildcardRule::new()), - Arc::new(ReplaceGroupingFunc::new()), ]; Self::with_rules(rules) } diff --git a/datafusion/optimizer/src/analyzer/replace_grouping_func.rs b/datafusion/optimizer/src/analyzer/replace_grouping_func.rs deleted file mode 100644 index 823886652e6b2..0000000000000 --- a/datafusion/optimizer/src/analyzer/replace_grouping_func.rs +++ /dev/null @@ -1,214 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use crate::analyzer::AnalyzerRule; -use crate::rewrite::TreeNodeRewritable; -use arrow::datatypes::DataType; -use datafusion_common::config::ConfigOptions; -use datafusion_expr::expr::AggregateFunction; -use datafusion_expr::expr_rewriter::rewrite_expr; -use datafusion_expr::utils::find_exprs_in_expr; -use datafusion_expr::{ - aggregate_function, bitwise_and, bitwise_shift_right, cast, col, lit, Filter, - GroupingSet, Sort, -}; -use datafusion_expr::{Aggregate, Expr, LogicalPlan}; - -use datafusion_common::{Column, DataFusionError, Result}; - -use hashbrown::HashSet; - -pub struct ReplaceGroupingFunc; - -impl ReplaceGroupingFunc { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -const INTERNAL_GROUPING_COLUMN: &str = "grouping_id"; - -impl AnalyzerRule for ReplaceGroupingFunc { - fn analyze( - &self, - plan: &LogicalPlan, - _config: &ConfigOptions, - ) -> datafusion_common::Result { - plan.clone().transform_up(&|plan| match plan { - LogicalPlan::Aggregate(Aggregate { - input, - aggr_expr, - group_expr, - .. - }) if contains_grouping_funcs_in_exprs(&aggr_expr) => { - let gid_column = Expr::VirtualColumn(DataType::UInt32, INTERNAL_GROUPING_COLUMN.to_string()); - let distinct_group_by = distinct_group_exprs(&group_expr); - let new_agg_expr = aggr_expr - .into_iter() - .map(|expr| { - replace_grouping_func( - expr, - &distinct_group_by, - gid_column.clone(), - ) - }) - .collect::>>()?; - Ok(Some(LogicalPlan::Aggregate(Aggregate::try_new( - input, - group_expr, - new_agg_expr, - )?))) - } - LogicalPlan::Filter(Filter { predicate, .. }) - if contains_grouping_funcs(&predicate) => - { - Ok(None) - } - LogicalPlan::Sort(Sort { expr, .. 
}) - if contains_grouping_funcs_in_exprs(&expr) => - { - Ok(None) - } - _ => Ok(None), - }) - } - fn name(&self) -> &str { - "replace_grouping_func" - } -} - -pub fn distinct_group_exprs(group_expr: &[Expr]) -> Vec { - let mut dedup_expr = Vec::new(); - let mut dedup_set = HashSet::new(); - group_expr.iter().for_each(|expr| match expr { - Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => exprs.iter().for_each(|e| { - if !dedup_set.contains(e) { - dedup_expr.push(e.clone()); - dedup_set.insert(e.clone()); - } - }), - GroupingSet::Cube(exprs) => exprs.iter().for_each(|e| { - if !dedup_set.contains(e) { - dedup_expr.push(e.clone()); - dedup_set.insert(e.clone()); - } - }), - GroupingSet::GroupingSets(groups) => groups.iter().flatten().for_each(|e| { - if !dedup_set.contains(e) { - dedup_expr.push(e.clone()); - dedup_set.insert(e.clone()); - } - }), - }, - _ => { - if !dedup_set.contains(expr) { - dedup_expr.push(expr.clone()); - dedup_set.insert(expr.clone()); - } - } - }); - dedup_expr -} - -fn contains_grouping_funcs_in_exprs(aggr_expr: &[Expr]) -> bool { - aggr_expr.iter().any(|expr| { - !find_exprs_in_expr(expr, &|nested_expr| { - matches!( - nested_expr, - Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Grouping, - .. - }) | Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::GroupingID, - .. - }) - ) - }) - .is_empty() - }) -} - -fn contains_grouping_funcs(expr: &Expr) -> bool { - !find_exprs_in_expr(expr, &|nested_expr| { - matches!( - nested_expr, - Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Grouping, - .. - }) | Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::GroupingID, - .. 
- }) - ) - }) - .is_empty() -} - -fn replace_grouping_func( - expr: Expr, - group_by_exprs: &[Expr], - gid_column: Expr, -) -> Result { - rewrite_expr(expr, |expr| { - let display_name = expr.display_name()?; - match expr { - Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Grouping, - args, - .. - }) => { - let grouping_col = &args[0]; - match group_by_exprs.iter().position(|e| e == grouping_col) { - Some(idx) => Ok(cast( - bitwise_and( - bitwise_shift_right( - gid_column.clone(), - lit((group_by_exprs.len() - 1 - idx) as u32), - ), - lit(1u32), - ), - DataType::Binary, - ).alias(display_name)), - None => Err(DataFusionError::Plan(format!( - "Column of GROUPING({:?}) can't be found in GROUP BY columns {:?}", - grouping_col, group_by_exprs - ))), - } - } - Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::GroupingID, - args, - .. - }) => { - if group_by_exprs.is_empty() - || (group_by_exprs.len() == args.len() - && group_by_exprs.iter().zip(args.iter()).all(|(g, a)| g == a)) - { - Ok(gid_column.clone().alias(display_name)) - } else { - Err(DataFusionError::Plan(format!( - "Columns of GROUPING_ID({:?}) does not match GROUP BY columns {:?}", - args, group_by_exprs - ))) - } - } - _ => Ok(expr), - } - }) -} diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_analytics.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_analytics.rs new file mode 100644 index 0000000000000..9197be355808e --- /dev/null +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_analytics.rs @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::analyzer::AnalyzerRule; +use crate::rewrite::TreeNodeRewritable; +use arrow::datatypes::DataType; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::expr_rewriter::rewrite_expr; +use datafusion_expr::utils::{ + add_hidden_grouping_set_expr, contains_grouping_set, distinct_group_exprs, + enumerate_grouping_sets, generate_grouping_ids, +}; +use datafusion_expr::{ + aggregate_function, bitwise_and, bitwise_shift_right, cast, lit, Projection, +}; +use datafusion_expr::{Aggregate, Expr, LogicalPlan}; +use std::sync::Arc; + +use datafusion_common::{DataFusionError, Result}; + +pub struct ResolveGroupingAnalytics; + +impl ResolveGroupingAnalytics { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +// Internal column used to represent the grouping_id, used by the grouping functions. +// It is "spark_grouping_id" in Spark +const INTERNAL_GROUPING_ID: &str = "grouping_id"; +// Internal column used to represent different grouping sets when there are duplicated grouping sets +const INTERNAL_GROUPING_SET_ID: &str = "grouping_set_id"; + +impl AnalyzerRule for ResolveGroupingAnalytics { + fn analyze( + &self, + plan: &LogicalPlan, + _config: &ConfigOptions, + ) -> datafusion_common::Result { + plan.clone().transform_down(&|plan| match plan { + LogicalPlan::Aggregate(Aggregate { + input, + aggr_expr, + group_expr, + .. 
+ }) if contains_grouping_set(&group_expr) => { + let mut expanded_grouping = enumerate_grouping_sets(&group_expr)?; + let mut new_project_exec = vec![]; + if let [Expr::GroupingSet(ref mut grouping_set)] = expanded_grouping.as_mut_slice() { + if !grouping_set.contains_hidden_expr() { + let new_agg_expr = if contains_grouping_funcs_as_agg_expr(&aggr_expr) { + let gid_column = Expr::HiddenColumn( + DataType::UInt32, + INTERNAL_GROUPING_ID.to_string(), + ); + let hidden_name = gid_column.display_name()?; + let grouping_ids = generate_grouping_ids(grouping_set)?; + let hidden_grouping_expr = |group_set_idx: usize| Expr::HiddenExpr(Box::new(lit(grouping_ids[group_set_idx]) + .alias(hidden_name.clone())), Box::new(gid_column.clone())); + add_hidden_grouping_set_expr(grouping_set, hidden_grouping_expr)?; + + let distinct_group_by = distinct_group_exprs(&group_expr, false); + let mut new_agg_expr = vec![]; + aggr_expr.into_iter().try_for_each(|expr| { + let new_expr = replace_grouping_func( + expr.clone(), + &distinct_group_by, + gid_column.clone(), + )?; + // The grouping func is rewrited to a normal expr, not the AggregateFunction anymore, remove it from the aggr_expr + if new_expr.ne(&expr) { + new_project_exec.push(new_expr); + } else { + new_agg_expr.push(new_expr); + } + Ok::<(), DataFusionError>(()) + })?; + new_agg_expr + } else { + aggr_expr + }; + if grouping_set.contains_duplicate_grouping() { + let grouping_set_id_column = Expr::HiddenColumn( + DataType::UInt32, + INTERNAL_GROUPING_SET_ID.to_string(), + ); + let hidden_name = grouping_set_id_column.display_name()?; + let hidden_grouping_expr = |group_set_idx: usize| Expr::HiddenExpr(Box::new(lit((group_set_idx + 1) as u32) + .alias(hidden_name.clone())), Box::new(grouping_set_id_column.clone())); + add_hidden_grouping_set_expr(grouping_set, hidden_grouping_expr)?; + } + + let aggregate = Aggregate::try_new( + input, + vec![Expr::GroupingSet(grouping_set.clone())], + new_agg_expr, + )?; + let agg_schema = 
aggregate.schema.clone(); + let new_agg = LogicalPlan::Aggregate(aggregate); + if !new_project_exec.is_empty() { + let mut expr: Vec = agg_schema + .fields() + .iter() + .map(|field| field.qualified_column()) + .map(Expr::Column) + .collect(); + expr.append(&mut new_project_exec); + Ok(Some(LogicalPlan::Projection(Projection::try_new(expr, Arc::new(new_agg))?))) + } else { + Ok(Some(new_agg)) + } + } else { + Ok(None) + } + } else { + Err(DataFusionError::Plan( + "Invalid group by expressions, GroupingSet must be the only expression" + .to_string(), + )) + } + } + _ => Ok(None), + }) + } + fn name(&self) -> &str { + "resolve_grouping_analytics" + } +} + +fn contains_grouping_funcs_as_agg_expr(aggr_expr: &[Expr]) -> bool { + aggr_expr.iter().any(|expr| { + matches!( + expr, + Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::Grouping, + .. + }) | Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::GroupingId, + .. + }) + ) + }) +} + +fn replace_grouping_func( + expr: Expr, + group_by_exprs: &[Expr], + gid_column: Expr, +) -> Result { + rewrite_expr(expr, |expr| { + let display_name = expr.display_name()?; + match expr { + Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::Grouping, + args, + .. + }) => { + let grouping_col = &args[0]; + match group_by_exprs.iter().position(|e| e == grouping_col) { + Some(idx) => Ok(cast( + bitwise_and( + bitwise_shift_right( + gid_column.clone(), + lit((group_by_exprs.len() - 1 - idx) as u32), + ), + lit(1u32), + ), + DataType::UInt8, + ).alias(display_name)), + None => Err(DataFusionError::Plan(format!( + "Column of GROUPING({:?}) can't be found in GROUP BY columns {:?}", + grouping_col, group_by_exprs + ))), + } + } + Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::GroupingId, + args, + .. 
+ }) => { + if group_by_exprs.is_empty() + || (group_by_exprs.len() == args.len() + && group_by_exprs.iter().zip(args.iter()).all(|(g, a)| g == a)) + { + Ok(gid_column.clone().alias(display_name)) + } else { + Err(DataFusionError::Plan(format!( + "Columns of GROUPING_ID({:?}) does not match GROUP BY columns {:?}", + args, group_by_exprs + ))) + } + } + _ => Ok(expr), + } + }) +} diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index fef5d15f78567..33bf676db1281 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -18,7 +18,6 @@ //! Eliminate common sub-expression. use std::collections::{BTreeSet, HashMap}; -use std::convert::identity; use std::sync::Arc; use arrow::datatypes::DataType; @@ -538,14 +537,8 @@ impl ExprRewriter for CommonSubexprRewriter<'_> { if self.curr_index >= self.id_array.len() { return Ok(expr); } - if matches!(expr, Expr::VirtualColumn(_, _)) { - return Ok(expr); - } let (series_number, id) = &self.id_array[self.curr_index]; - if id.eq("_virtual_grouping_id") { - return Ok(expr); - } self.curr_index += 1; // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. 
let expr_set_item = self.expr_set.get(id).ok_or_else(|| { diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 55c77e51e2d3d..b8d1fdd1397fc 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -649,7 +649,9 @@ impl OptimizerRule for PushDownFilter { let mut push_predicates = vec![]; for expr in predicates { let cols = expr.to_columns()?; - if cols.iter().all(|c| group_expr_columns.contains(c)) { + if !expr.contains_hidden_columns() + && cols.iter().all(|c| group_expr_columns.contains(c)) + { push_predicates.push(expr); } else { keep_predicates.push(expr); diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 767077aa0c027..f21900497a2a2 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -329,6 +329,8 @@ impl OptimizerRule for PushDownProjection { let new_proj = plan.with_new_inputs(&[filter.input.as_ref().clone()])?; child_plan.with_new_inputs(&[new_proj])? + } else if filter.predicate.contains_hidden_columns() { + return Ok(None); } else { let mut required_columns = HashSet::new(); exprlist_to_columns(&projection.expr, &mut required_columns)?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 66a5a365162bf..dd94d93691bd5 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -255,7 +255,8 @@ impl<'a> ConstEvaluator<'a> { | Expr::ScalarVariable(_, _) | Expr::Column(_) | Expr::OuterReferenceColumn(_, _) - | Expr::VirtualColumn(_, _) + | Expr::HiddenColumn(_, _) + | Expr::HiddenExpr(_, _) | Expr::Exists { .. } | Expr::InSubquery { .. 
} | Expr::ScalarSubquery(_) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index cee31b5b33522..e0cda718262d3 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -20,6 +20,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, Result}; +use datafusion_expr::utils::contains_grouping_set; use datafusion_expr::{ col, expr::AggregateFunction, @@ -82,11 +83,6 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { } } -/// Check if the first expr is [Expr::GroupingSet]. -fn contains_grouping_set(expr: &[Expr]) -> bool { - matches!(expr.first(), Some(Expr::GroupingSet(_))) -} - impl OptimizerRule for SingleDistinctToGroupBy { fn try_optimize( &self, diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 774208b6b6cba..f0ab7baf33741 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -247,12 +247,12 @@ pub fn create_aggregate_expr( } (AggregateFunction::Grouping, _) => { return Err(DataFusionError::Plan( - "GROUPING() aggregations are not evaluable".to_string(), + "GROUPING() aggregations are not evaluable, should be converted by the Analyzer".to_string(), )); } - (AggregateFunction::GroupingID, _) => { + (AggregateFunction::GroupingId, _) => { return Err(DataFusionError::Plan( - "GROUPING_ID() aggregations are not evaluable".to_string(), + "GROUPING_ID() aggregations are not evaluable, should be converted by the Analyzer".to_string(), )); } }) diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index eb2be5ef217c1..43c06f22b565a 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs 
@@ -202,6 +202,77 @@ impl PartialEq for UnKnownColumn { } } +#[derive(Debug, Hash, PartialEq, Eq, Clone)] +pub struct HiddenColumn { + name: String, + data_type: DataType, +} + +impl HiddenColumn { + /// Create a new hidden column + pub fn new(name: &str, data_type: &DataType) -> Self { + Self { + name: name.to_owned(), + data_type: data_type.clone(), + } + } + + /// Get the column name + pub fn name(&self) -> &str { + &self.name + } +} + +impl std::fmt::Display for HiddenColumn { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + +impl PhysicalExpr for HiddenColumn { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn std::any::Any { + self + } + + /// Get the data type of this expression, given the schema of the input + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.data_type.clone()) + } + + /// Decide whehter this expression is nullable, given the schema of the input + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + /// Evaluate the expression + fn evaluate(&self, _batch: &RecordBatch) -> Result { + Err(DataFusionError::Plan( + "HiddenColumn::evaluate() should not be called".to_owned(), + )) + } + + fn children(&self) -> Vec> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } +} + +impl PartialEq for HiddenColumn { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self == x) + .unwrap_or(false) + } +} + /// Create a column expression pub fn col(name: &str, schema: &Schema) -> Result> { Ok(Arc::new(Column::new_with_schema(name, schema)?)) diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 5efda1c6c1342..2ac11c23cc521 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -76,7 +76,7 @@ pub use 
case::{case, CaseExpr}; pub use cast::{ cast, cast_column, cast_with_options, CastExpr, DEFAULT_DATAFUSION_CAST_OPTIONS, }; -pub use column::{col, Column, UnKnownColumn}; +pub use column::{col, Column, HiddenColumn, UnKnownColumn}; pub use datetime::DateTimeIntervalExpr; pub use get_indexed_field::GetIndexedFieldExpr; pub use in_list::{in_list, InListExpr}; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 1fbd73b3ba01c..d63a6a897b4a8 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -26,7 +26,9 @@ use crate::{ PhysicalExpr, }; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + unqualified_field_not_found, DFSchema, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::expr::Cast; use datafusion_expr::{ binary_expr, Between, BinaryExpr, Expr, GetIndexedField, Like, Operator, TryCast, @@ -474,6 +476,22 @@ pub fn create_physical_expr( expressions::in_list(value_expr, list_exprs, negated, input_schema) } }, + Expr::HiddenColumn(_, _) => { + let hidden_col_name = e.display_name()?; + let col_idx = + input_dfschema.index_of_column_by_name(None, &hidden_col_name)?; + if let Some(idx) = col_idx { + Ok(Arc::new(Column::new(&hidden_col_name, idx))) + } else { + Err(unqualified_field_not_found( + &hidden_col_name, + input_dfschema, + )) + } + } + Expr::HiddenExpr(expr, _) => { + create_physical_expr(expr, input_dfschema, input_schema, execution_props) + } other => Err(DataFusionError::NotImplemented(format!( "Physical plan does not support logical expression {other:?}" ))), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 9d9b289d33377..ab0c2915c687f 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -502,7 +502,7 @@ impl From for 
AggregateFunction { protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian, protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::Median => Self::Median, - protobuf::AggregateFunction::GroupingId => Self::GroupingID, + protobuf::AggregateFunction::GroupingId => Self::GroupingId, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 980662aa07157..fd087ac1f94db 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -371,7 +371,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::ApproxMedian => Self::ApproxMedian, AggregateFunction::Grouping => Self::Grouping, AggregateFunction::Median => Self::Median, - AggregateFunction::GroupingID => Self::GroupingId, + AggregateFunction::GroupingId => Self::GroupingId, } } } @@ -631,7 +631,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, AggregateFunction::Median => protobuf::AggregateFunction::Median, - AggregateFunction::GroupingID => protobuf::AggregateFunction::GroupingId, + AggregateFunction::GroupingId => protobuf::AggregateFunction::GroupingId, }; let aggregate_expr = protobuf::AggregateExprNode { @@ -856,10 +856,10 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { Expr::Wildcard => Self { expr_type: Some(ExprType::Wildcard(true)), }, - Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::OuterReferenceColumn{..} | Expr::VirtualColumn{..} => { + Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. 
} | Expr::OuterReferenceColumn{..} | Expr::HiddenColumn{..} | Expr::HiddenExpr{..}=> { // we would need to add logical plan operators to datafusion.proto to support this // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 - return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported | Exp:VirtualColumn not supported".to_string())); + return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported | Exp:HiddenColumn not supported | Exp:HiddenExpr not supported".to_string())); } Expr::GetIndexedField(GetIndexedField { key, expr }) => Self { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 92986b0b39b26..9f0494122e9e0 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -449,7 +449,13 @@ impl AsExecutionPlan for PhysicalPlanNode { Ok(Arc::new(AggregateExec::try_new( agg_mode, - PhysicalGroupBy::new(group_expr, null_expr, groups), + PhysicalGroupBy::new( + group_expr.clone(), + vec![], + group_expr, + null_expr, + groups, + ), physical_aggr_expr, input, Arc::new((&input_schema).try_into()?), diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 25b9ce680a234..8c0cd0e6a53f0 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -362,7 +362,8 @@ where ))), Expr::Column { .. } | Expr::OuterReferenceColumn(_, _) - | Expr::VirtualColumn(_, _) + | Expr::HiddenColumn(_, _) + | Expr::HiddenExpr(_, _) | Expr::Literal(_) | Expr::ScalarVariable(_, _) | Expr::Exists { .. 
} From 4e065c9adc3e7db18bc66f56a28285eaa5ddcbae Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Mon, 27 Mar 2023 14:38:16 +0800 Subject: [PATCH 5/6] fix failed UT, make the sort result stable --- datafusion/core/tests/sql/group_by.rs | 46 +++++++++++++-------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index 02d53b0c9bafe..dcb6002cac6d7 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -1262,13 +1262,13 @@ async fn group_by_with_grouping_func_and_order_by() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; let sql = "SELECT c1, avg(c12), GROUPING_ID(C1, c2) FROM aggregate_test_100 \ - GROUP BY CUBE(c1,c2) ORDER BY GROUPING_ID(C1, c2) DESC"; + GROUP BY CUBE(c1,c2) ORDER BY GROUPING_ID(C1, c2) DESC, avg(c12) ASC"; let msg = format!("Creating logical plan for '{sql}'"); let dataframe = ctx.sql(sql).await.expect(&msg); let plan = dataframe.clone().into_optimized_plan()?; let expected = vec![ - "Sort: GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2) DESC NULLS FIRST [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2):UInt32;N]", + "Sort: GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2) DESC NULLS FIRST, AVG(aggregate_test_100.c12) ASC NULLS LAST [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2):UInt32;N]", " Projection: aggregate_test_100.c1, AVG(aggregate_test_100.c12), #grouping_id AS GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2) [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2):UInt32;N]", " Aggregate: groupBy=[[GROUPING SETS ((UInt32(3) AS #grouping_id), (aggregate_test_100.c1, UInt32(1) AS #grouping_id), (aggregate_test_100.c2, UInt32(2) AS #grouping_id), (aggregate_test_100.c1, 
aggregate_test_100.c2, UInt32(0) AS #grouping_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, c2:UInt32, #grouping_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", " TableScan: aggregate_test_100 projection=[c1, c2, c12] [c1:Utf8, c2:UInt32, c12:Float64]", @@ -1287,41 +1287,41 @@ async fn group_by_with_grouping_func_and_order_by() -> Result<()> { "| c1 | AVG(aggregate_test_100.c12) | GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2) |", "+----+-----------------------------+----------------------------------------------------------+", "| | 0.5089725099127211 | 3 |", - "| | 0.6545641966127662 | 2 |", - "| | 0.5245329062820169 | 2 |", + "| | 0.40234192123489837 | 2 |", "| | 0.4312272637333415 | 2 |", "| | 0.5108939802619781 | 2 |", - "| | 0.40234192123489837 | 2 |", + "| | 0.5245329062820169 | 2 |", + "| | 0.6545641966127662 | 2 |", "| b | 0.41040709263815384 | 1 |", + "| e | 0.48600669271341534 | 1 |", "| a | 0.48754517466109415 | 1 |", "| d | 0.48855379387549824 | 1 |", "| c | 0.6600456536439784 | 1 |", - "| e | 0.48600669271341534 | 1 |", - "| e | 0.5165824734324667 | 0 |", - "| c | 0.6430620563927849 | 0 |", - "| c | 0.6827805579021969 | 0 |", - "| b | 0.5857678873564655 | 0 |", - "| c | 0.7277229477969185 | 0 |", + "| b | 0.16148594845154118 | 0 |", + "| d | 0.2488799233225611 | 0 |", "| e | 0.2720288398836001 | 0 |", - "| e | 0.780297346359783 | 0 |", "| e | 0.29536905073188496 | 0 |", - "| b | 0.42804338065410286 | 0 |", - "| b | 0.4888141504446429 | 0 |", "| b | 0.33400957036260354 | 0 |", + "| a | 0.3497223654469457 | 0 |", "| a | 0.3653038379118398 | 0 |", + "| c | 0.421733279717472 | 0 |", + "| b | 0.42804338065410286 | 0 |", + "| a | 0.4693685626367209 | 0 |", + "| b | 0.4888141504446429 | 0 |", + "| d | 0.49575895804943215 | 0 |", + "| d | 0.49931809179640024 | 0 |", + "| e | 0.5165824734324667 | 0 |", "| d | 0.5181987328311988 | 0 |", + "| b | 0.5857678873564655 | 0 |", + "| d | 0.586369575965718 | 0 |", "| a | 0.5945188963859894 | 0 |", 
- "| c | 0.7736013221256991 | 0 |", - "| b | 0.16148594845154118 | 0 |", "| a | 0.5996111195922015 | 0 |", - "| d | 0.586369575965718 | 0 |", - "| d | 0.2488799233225611 | 0 |", - "| a | 0.4693685626367209 | 0 |", - "| d | 0.49931809179640024 | 0 |", + "| c | 0.6430620563927849 | 0 |", "| e | 0.660795726704708 | 0 |", - "| a | 0.3497223654469457 | 0 |", - "| d | 0.49575895804943215 | 0 |", - "| c | 0.421733279717472 | 0 |", + "| c | 0.6827805579021969 | 0 |", + "| c | 0.7277229477969185 | 0 |", + "| c | 0.7736013221256991 | 0 |", + "| e | 0.780297346359783 | 0 |", "+----+-----------------------------+----------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); From f6059e8d469a37ac18cb36f38de096546dbf5bfd Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Tue, 28 Mar 2023 14:20:27 +0800 Subject: [PATCH 6/6] minor change --- datafusion/expr/src/utils.rs | 2 +- .../optimizer/src/analyzer/resolve_grouping_analytics.rs | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index fcaff7f1b24af..ca7c30d326ca6 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -58,7 +58,7 @@ pub fn contains_grouping_set(group_expr: &[Expr]) -> bool { group_expr.iter().any(|e| matches!(e, Expr::GroupingSet(_))) } -/// Check whether the group_expr contains [Expr::GroupingSet] without any hidden exprs. +/// Check whether the group_expr contains [Expr::GroupingSet] without any hidden expr. 
pub fn contains_grouping_set_without_hidden_expr(group_expr: &[Expr]) -> bool { group_expr.iter().any(|e| matches!(e, Expr::GroupingSet(grouping_set) if !grouping_set.contains_hidden_expr())) } diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_analytics.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_analytics.rs index 269b39712aff2..a6c8c2476a55d 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_analytics.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_analytics.rs @@ -21,9 +21,8 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_expr::expr::AggregateFunction; use datafusion_expr::utils::{ - add_hidden_grouping_set_expr, - contains_grouping_set_without_hidden_expr, distinct_group_exprs, - enumerate_grouping_sets, generate_grouping_ids, + add_hidden_grouping_set_expr, contains_grouping_set_without_hidden_expr, + distinct_group_exprs, enumerate_grouping_sets, generate_grouping_ids, }; use datafusion_expr::{ aggregate_function, bitwise_and, bitwise_shift_right, cast, lit, Projection,