diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 6ab1dd1fc9ac5..e8158e632469e 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -17,7 +17,7 @@ //! Eliminate common sub-expression. -use crate::{OptimizerConfig, OptimizerRule}; +use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::{ @@ -25,7 +25,6 @@ use datafusion_expr::{ expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}, expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}, logical_plan::{Aggregate, Filter, LogicalPlan, Projection, Sort, Window}, - utils::from_plan, Expr, ExprSchemable, }; use std::collections::{BTreeSet, HashMap}; @@ -54,13 +53,200 @@ type Identifier = String; /// be eliminated. pub struct CommonSubexprEliminate {} +impl CommonSubexprEliminate { + fn rewrite_expr( + &self, + exprs_list: &[&[Expr]], + arrays_list: &[&[Vec<(usize, String)>]], + input: &LogicalPlan, + expr_set: &mut ExprSet, + optimizer_config: &mut OptimizerConfig, + ) -> Result<(Vec>, LogicalPlan)> { + let mut affected_id = BTreeSet::::new(); + + let rewrite_exprs = exprs_list + .iter() + .zip(arrays_list.iter()) + .map(|(exprs, arrays)| { + exprs + .iter() + .cloned() + .zip(arrays.iter()) + .map(|(expr, id_array)| { + replace_common_expr(expr, id_array, expr_set, &mut affected_id) + }) + .collect::>>() + }) + .collect::>>()?; + + let mut new_input = self.optimize(input, optimizer_config)?; + if !affected_id.is_empty() { + new_input = build_project_plan(new_input, affected_id, expr_set)?; + } + + Ok((rewrite_exprs, new_input)) + } +} + impl OptimizerRule for CommonSubexprEliminate { fn optimize( &self, plan: &LogicalPlan, optimizer_config: &mut OptimizerConfig, ) -> Result { - optimize(plan, optimizer_config) + let mut expr_set = ExprSet::new(); + + match plan { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + alias, + }) => { + let input_schema = Arc::clone(input.schema()); + let arrays = to_arrays(expr, input_schema, &mut expr_set)?; + + let (mut new_expr, new_input) = self.rewrite_expr( + &[expr], + &[&arrays], + input, + &mut expr_set, + optimizer_config, + )?; + + Ok(LogicalPlan::Projection(Projection::try_new_with_schema( + pop_expr(&mut new_expr)?, + Arc::new(new_input), + schema.clone(), + alias.clone(), + )?)) + } + LogicalPlan::Filter(filter) => { + let input = filter.input(); + let predicate = filter.predicate(); + let input_schema = Arc::clone(input.schema()); + let mut id_array = vec![]; + expr_to_identifier( + predicate, + &mut expr_set, + &mut id_array, + input_schema, + )?; + + let (mut new_expr, new_input) = self.rewrite_expr( + &[&[predicate.clone()]], + &[&[id_array]], + filter.input(), + &mut expr_set, + optimizer_config, + )?; + + if let Some(predicate) = pop_expr(&mut new_expr)?.pop() { + Ok(LogicalPlan::Filter(Filter::try_new( + predicate, + Arc::new(new_input), + )?)) + } else { + Err(DataFusionError::Internal( + "Failed to pop predicate expr".to_string(), + )) + } + } + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) => { + let input_schema = Arc::clone(input.schema()); + let arrays = to_arrays(window_expr, input_schema, &mut expr_set)?; + + let (mut new_expr, new_input) = self.rewrite_expr( + &[window_expr], + &[&arrays], + input, + &mut expr_set, + optimizer_config, + )?; + + Ok(LogicalPlan::Window(Window { + input: Arc::new(new_input), + window_expr: pop_expr(&mut new_expr)?, + schema: schema.clone(), + })) + } + LogicalPlan::Aggregate(Aggregate { + group_expr, + aggr_expr, + input, + schema, + }) => { + let input_schema = Arc::clone(input.schema()); + let group_arrays = + to_arrays(group_expr, Arc::clone(&input_schema), &mut expr_set)?; + let aggr_arrays = to_arrays(aggr_expr, input_schema, &mut expr_set)?; + + let (mut new_expr, new_input) = self.rewrite_expr( + &[group_expr, aggr_expr], + &[&group_arrays, &aggr_arrays], + input, + &mut expr_set, + optimizer_config, + )?; + // note the reversed pop order. + let new_aggr_expr = pop_expr(&mut new_expr)?; + let new_group_expr = pop_expr(&mut new_expr)?; + + Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + Arc::new(new_input), + new_group_expr, + new_aggr_expr, + schema.clone(), + )?)) + } + LogicalPlan::Sort(Sort { expr, input, fetch }) => { + let input_schema = Arc::clone(input.schema()); + let arrays = to_arrays(expr, input_schema, &mut expr_set)?; + + let (mut new_expr, new_input) = self.rewrite_expr( + &[expr], + &[&arrays], + input, + &mut expr_set, + optimizer_config, + )?; + + Ok(LogicalPlan::Sort(Sort { + expr: pop_expr(&mut new_expr)?, + input: Arc::new(new_input), + fetch: *fetch, + })) + } + LogicalPlan::Join { .. } + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Union(_) + | LogicalPlan::TableScan { .. } + | LogicalPlan::Values(_) + | LogicalPlan::EmptyRelation(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::CreateExternalTable(_) + | LogicalPlan::Explain { .. } + | LogicalPlan::Analyze { .. } + | LogicalPlan::CreateMemoryTable(_) + | LogicalPlan::CreateView(_) + | LogicalPlan::CreateCatalogSchema(_) + | LogicalPlan::CreateCatalog(_) + | LogicalPlan::DropTable(_) + | LogicalPlan::DropView(_) + | LogicalPlan::SetVariable(_) + | LogicalPlan::Distinct(_) + | LogicalPlan::Extension { .. } => { + // apply the optimization to all inputs of the plan + utils::optimize_children(self, plan, optimizer_config) + } + } } fn name(&self) -> &str { @@ -81,167 +267,6 @@ impl CommonSubexprEliminate { } } -fn optimize( - plan: &LogicalPlan, - optimizer_config: &OptimizerConfig, -) -> Result { - let mut expr_set = ExprSet::new(); - - match plan { - LogicalPlan::Projection(Projection { - expr, - input, - schema, - alias, - }) => { - let input_schema = Arc::clone(input.schema()); - let arrays = to_arrays(expr, input_schema, &mut expr_set)?; - - let (mut new_expr, new_input) = rewrite_expr( - &[expr], - &[&arrays], - input, - &mut expr_set, - optimizer_config, - )?; - - Ok(LogicalPlan::Projection(Projection::try_new_with_schema( - pop_expr(&mut new_expr)?, - Arc::new(new_input), - schema.clone(), - alias.clone(), - )?)) - } - LogicalPlan::Filter(filter) => { - let input = filter.input(); - let predicate = filter.predicate(); - let input_schema = Arc::clone(input.schema()); - let mut id_array = vec![]; - expr_to_identifier(predicate, &mut expr_set, &mut id_array, input_schema)?; - - let (mut new_expr, new_input) = rewrite_expr( - &[&[predicate.clone()]], - &[&[id_array]], - filter.input(), - &mut expr_set, - optimizer_config, - )?; - - if let Some(predicate) = pop_expr(&mut new_expr)?.pop() { - Ok(LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_input), - )?)) - } else { - Err(DataFusionError::Internal( - "Failed to pop predicate expr".to_string(), - )) - } - } - LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) => { - let input_schema = Arc::clone(input.schema()); - let arrays = to_arrays(window_expr, input_schema, &mut expr_set)?; - - let (mut new_expr, new_input) = rewrite_expr( - &[window_expr], - &[&arrays], - input, - &mut expr_set, - optimizer_config, - )?; - - Ok(LogicalPlan::Window(Window { - input: Arc::new(new_input), - window_expr: pop_expr(&mut new_expr)?, - schema: schema.clone(), - })) - } - LogicalPlan::Aggregate(Aggregate { - group_expr, - aggr_expr, - input, - schema, - }) => { - let input_schema = Arc::clone(input.schema()); - let group_arrays = - to_arrays(group_expr, Arc::clone(&input_schema), &mut expr_set)?; - let aggr_arrays = to_arrays(aggr_expr, input_schema, &mut expr_set)?; - - let (mut new_expr, new_input) = rewrite_expr( - &[group_expr, aggr_expr], - &[&group_arrays, &aggr_arrays], - input, - &mut expr_set, - optimizer_config, - )?; - // note the reversed pop order. - let new_aggr_expr = pop_expr(&mut new_expr)?; - let new_group_expr = pop_expr(&mut new_expr)?; - - Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema( - Arc::new(new_input), - new_group_expr, - new_aggr_expr, - schema.clone(), - )?)) - } - LogicalPlan::Sort(Sort { expr, input, fetch }) => { - let input_schema = Arc::clone(input.schema()); - let arrays = to_arrays(expr, input_schema, &mut expr_set)?; - - let (mut new_expr, new_input) = rewrite_expr( - &[expr], - &[&arrays], - input, - &mut expr_set, - optimizer_config, - )?; - - Ok(LogicalPlan::Sort(Sort { - expr: pop_expr(&mut new_expr)?, - input: Arc::new(new_input), - fetch: *fetch, - })) - } - LogicalPlan::Join { .. } - | LogicalPlan::CrossJoin(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Union(_) - | LogicalPlan::TableScan { .. } - | LogicalPlan::Values(_) - | LogicalPlan::EmptyRelation(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) - | LogicalPlan::CreateExternalTable(_) - | LogicalPlan::Explain { .. } - | LogicalPlan::Analyze { .. } - | LogicalPlan::CreateMemoryTable(_) - | LogicalPlan::CreateView(_) - | LogicalPlan::CreateCatalogSchema(_) - | LogicalPlan::CreateCatalog(_) - | LogicalPlan::DropTable(_) - | LogicalPlan::DropView(_) - | LogicalPlan::SetVariable(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Extension { .. } => { - // apply the optimization to all inputs of the plan - let expr = plan.expressions(); - let inputs = plan.inputs(); - let new_inputs = inputs - .iter() - .map(|input_plan| optimize(input_plan, optimizer_config)) - .collect::>>()?; - - from_plan(plan, &expr, &new_inputs) - } - } -} - fn pop_expr(new_expr: &mut Vec>) -> Result> { new_expr .pop() @@ -285,7 +310,7 @@ fn build_project_plan( _ => { return Err(DataFusionError::Internal( "expr_set invalid state".to_string(), - )) + )); } } } @@ -307,39 +332,6 @@ fn build_project_plan( )?)) } -#[inline] -fn rewrite_expr( - exprs_list: &[&[Expr]], - arrays_list: &[&[Vec<(usize, String)>]], - input: &LogicalPlan, - expr_set: &mut ExprSet, - optimizer_config: &OptimizerConfig, -) -> Result<(Vec>, LogicalPlan)> { - let mut affected_id = BTreeSet::::new(); - - let rewrote_exprs = exprs_list - .iter() - .zip(arrays_list.iter()) - .map(|(exprs, arrays)| { - exprs - .iter() - .cloned() - .zip(arrays.iter()) - .map(|(expr, id_array)| { - replace_common_expr(expr, id_array, expr_set, &mut affected_id) - }) - .collect::>>() - }) - .collect::>>()?; - - let mut new_input = optimize(input, optimizer_config)?; - if !affected_id.is_empty() { - new_input = build_project_plan(new_input, affected_id, expr_set)?; - } - - Ok((rewrote_exprs, new_input)) -} - /// Go through an expression tree and generate identifier. /// /// An identifier contains information of the expression itself and its sub-expression.