From 183b9671ff50f20bd5fe3cc6716eaa7f9841fc73 Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 1 Dec 2024 23:05:43 +0800 Subject: [PATCH 1/4] draft for `EliminateUnnecessaryGroupByKeys`. --- .../eliminate_unnecessary_group_by_keys.rs | 184 ++++++++++++++++++ datafusion/optimizer/src/lib.rs | 6 +- 2 files changed, 187 insertions(+), 3 deletions(-) create mode 100644 datafusion/optimizer/src/eliminate_unnecessary_group_by_keys.rs diff --git a/datafusion/optimizer/src/eliminate_unnecessary_group_by_keys.rs b/datafusion/optimizer/src/eliminate_unnecessary_group_by_keys.rs new file mode 100644 index 0000000000000..59a64e2ff5c93 --- /dev/null +++ b/datafusion/optimizer/src/eliminate_unnecessary_group_by_keys.rs @@ -0,0 +1,184 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`EliminateDuplicatedExpr`] Removes redundant expressions + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{Column, HashSet, Result}; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::{Aggregate, Expr, LogicalPlanBuilder, Sort, SortExpr}; +use indexmap::IndexSet; +use std::hash::{Hash, Hasher}; + +/// Optimization rule that eliminate unnecessary group by keys +#[derive(Default, Debug)] +pub struct EliminateUnnecessaryGroupByKeys { + column_group_keys: HashSet, +} + +impl EliminateUnnecessaryGroupByKeys { + pub fn new() -> Self { + Self { + column_group_keys: HashSet::new(), + } + } +} + +impl OptimizerRule for EliminateUnnecessaryGroupByKeys { + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + LogicalPlan::Aggregate(agg) => { + let len = agg.group_expr.len(); + + // Collect column group keys + let mut column_group_keys = HashSet::new(); + for group_key in agg.group_expr.iter() { + if let Expr::Column(col) = group_key { + column_group_keys.insert(col.clone()); + } + } + + // If no group keys, just return + if column_group_keys.is_empty() { + return Ok(Transformed::no(LogicalPlan::Aggregate(agg))); + } + + // Try to eliminate the unnecessary group keys + let mut keep_group_by_keys = Vec::new(); + for group_key in agg.group_expr.iter() { + if matches!(&group_key, Expr::BinaryExpr(_)) + || matches!(&group_key, Expr::ScalarFunction(_)) + { + // If all of the cols in `column_group_keys`, we should eliminate this key. + // For example, `a + 1` in `group by a, a + 1` should be eliminated. + let cols_in_key = group_key.column_refs(); + + if cols_in_key.is_empty() + || cols_in_key + .iter() + .any(|col| !column_group_keys.contains(*col)) + { + keep_group_by_keys.push(group_key.clone()); + } + } else { + keep_group_by_keys.push(group_key.clone()); + } + } + + if len != keep_group_by_keys.len() { + let projection_expr = agg.group_expr.into_iter().chain(agg.aggr_expr.clone()); + let new_plan = LogicalPlanBuilder::from(agg.input) + .aggregate(keep_group_by_keys, agg.aggr_expr)? + .project(projection_expr)? + .build()?; + + Ok(Transformed::yes(new_plan)) + } else { + Ok(Transformed::no(LogicalPlan::Aggregate(agg))) + } + } + _ => Ok(Transformed::no(plan)), + } + } + + fn name(&self) -> &str { + "eliminate_unnecessary_group_by_keys" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{test::*, Optimizer, OptimizerContext}; + use datafusion_expr::{ + binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Operator, + }; + + use datafusion_functions_aggregate::expr_fn::count; + use std::sync::Arc; + + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { + crate::test::assert_optimized_plan_eq( + Arc::new(EliminateUnnecessaryGroupByKeys::new()), + plan, + expected, + ) + } + + #[test] + fn eliminate_binary_group_by_keys() { + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a"), binary_expr(col("a"), Operator::Plus, lit(1))], + vec![count(col("c"))], + ) + .unwrap() + .build() + .unwrap(); + + let opt_context = OptimizerContext::new().with_max_passes(1); + let optimizer = + Optimizer::with_rules(vec![Arc::new(EliminateUnnecessaryGroupByKeys::new())]); + let optimized_plan = optimizer + .optimize( + plan, + &opt_context, + |_plan: &LogicalPlan, _rule: &dyn OptimizerRule| {}, + ) + .unwrap(); + println!("{optimized_plan}"); + // let expected = "Limit: skip=5, fetch=10\ + // \n Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST\ + // \n TableScan: test"; + // assert_optimized_plan_eq(plan, expected) + } + + // #[test] + // fn eliminate_sort_exprs_with_options() -> Result<()> { + // let table_scan = test_table_scan().unwrap(); + // let sort_exprs = vec![ + // col("a").sort(true, true), + // col("b").sort(true, false), + // col("a").sort(false, false), + // col("b").sort(false, true), + // ]; + // let plan = LogicalPlanBuilder::from(table_scan) + // .sort(sort_exprs)? + // .limit(5, Some(10))? + // .build()?; + // let expected = "Limit: skip=5, fetch=10\ + // \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\ + // \n TableScan: test"; + // assert_optimized_plan_eq(plan, expected) + // } + // } +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 263770b81fcdc..6ccb0fbdfdab5 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -55,11 +55,11 @@ pub mod replace_distinct_aggregate; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; -pub mod unwrap_cast_in_comparison; -pub mod utils; - #[cfg(test)] pub mod test; +pub mod unwrap_cast_in_comparison; +pub mod utils; +pub mod eliminate_unnecessary_group_by_keys; pub use analyzer::{Analyzer, AnalyzerRule}; pub use optimizer::{Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule}; From 5523f989b473380ccfaf2679359651017b4f6e19 Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 1 Dec 2024 23:41:12 +0800 Subject: [PATCH 2/4] ensure `EliminateUnnecessaryGroupByKeys` will work. --- .../src/eliminate_unnecessary_group_by_keys.rs | 13 ++++++------- datafusion/optimizer/src/optimizer.rs | 2 ++ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_unnecessary_group_by_keys.rs b/datafusion/optimizer/src/eliminate_unnecessary_group_by_keys.rs index 59a64e2ff5c93..05508a5f5d708 100644 --- a/datafusion/optimizer/src/eliminate_unnecessary_group_by_keys.rs +++ b/datafusion/optimizer/src/eliminate_unnecessary_group_by_keys.rs @@ -28,15 +28,11 @@ use std::hash::{Hash, Hasher}; /// Optimization rule that eliminate unnecessary group by keys #[derive(Default, Debug)] -pub struct EliminateUnnecessaryGroupByKeys { - column_group_keys: HashSet, -} +pub struct EliminateUnnecessaryGroupByKeys {} impl EliminateUnnecessaryGroupByKeys { pub fn new() -> Self { - Self { - column_group_keys: HashSet::new(), - } + Self {} } } @@ -76,6 +72,8 @@ impl OptimizerRule for EliminateUnnecessaryGroupByKeys { for group_key in agg.group_expr.iter() { if matches!(&group_key, Expr::BinaryExpr(_)) || matches!(&group_key, Expr::ScalarFunction(_)) + || matches!(&group_key, Expr::Cast(_)) + || matches!(&group_key, Expr::TryCast(_)) { // If all of the cols in `column_group_keys`, we should eliminate this key. // For example, `a + 1` in `group by a, a + 1` should be eliminated. @@ -94,7 +92,8 @@ impl OptimizerRule for EliminateUnnecessaryGroupByKeys { } if len != keep_group_by_keys.len() { - let projection_expr = agg.group_expr.into_iter().chain(agg.aggr_expr.clone()); + let projection_expr = + agg.group_expr.into_iter().chain(agg.aggr_expr.clone()); let new_plan = LogicalPlanBuilder::from(agg.input) .aggregate(keep_group_by_keys, agg.aggr_expr)? .project(projection_expr)? diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 975150cd61220..333e8a6e96b4c 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -42,6 +42,7 @@ use crate::eliminate_limit::EliminateLimit; use crate::eliminate_nested_union::EliminateNestedUnion; use crate::eliminate_one_union::EliminateOneUnion; use crate::eliminate_outer_join::EliminateOuterJoin; +use crate::eliminate_unnecessary_group_by_keys::EliminateUnnecessaryGroupByKeys; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; use crate::filter_null_join_keys::FilterNullJoinKeys; use crate::optimize_projections::OptimizeProjections; @@ -250,6 +251,7 @@ impl Optimizer { Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), + Arc::new(EliminateUnnecessaryGroupByKeys::new()), Arc::new(EliminateFilter::new()), Arc::new(EliminateCrossJoin::new()), Arc::new(CommonSubexprEliminate::new()), From 4fa399eb18092c129525b39baa3c6978c5cc3af6 Mon Sep 17 00:00:00 2001 From: kamille Date: Sun, 1 Dec 2024 23:51:15 +0800 Subject: [PATCH 3/4] quick dbg. --- datafusion/core/src/execution/session_state.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index e99cf82223815..3b099fb650ded 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -720,6 +720,7 @@ impl SessionState { logical_plan: &LogicalPlan, ) -> datafusion_common::Result> { let logical_plan = self.optimize(logical_plan)?; + println!("{logical_plan}"); self.query_planner .create_physical_plan(&logical_plan, self) .await From 57f67dc3d2b10eaf1e09c8acf21bb9b444c5d2e2 Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 2 Dec 2024 00:01:18 +0800 Subject: [PATCH 4/4] remove dbg log. --- datafusion/core/src/execution/session_state.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 3b099fb650ded..e99cf82223815 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -720,7 +720,6 @@ impl SessionState { logical_plan: &LogicalPlan, ) -> datafusion_common::Result> { let logical_plan = self.optimize(logical_plan)?; - println!("{logical_plan}"); self.query_planner .create_physical_plan(&logical_plan, self) .await