From c7ef6187bfe653ca2051527a66b216bd1cd7a87a Mon Sep 17 00:00:00 2001 From: xudong963 Date: Sat, 9 Jul 2022 17:42:35 +0800 Subject: [PATCH 1/3] feat: add optimize rule: rewrite_disjunctive_predicate --- datafusion/core/src/execution/context.rs | 2 + datafusion/optimizer/src/lib.rs | 1 + .../src/rewrite_disjunctive_predicate.rs | 468 ++++++++++++++++++ 3 files changed, 471 insertions(+) create mode 100644 datafusion/optimizer/src/rewrite_disjunctive_predicate.rs diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 41964e33ac96a..96705bb0cab56 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -106,6 +106,7 @@ use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists; use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn; use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; +use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, SqlToRel}, @@ -1367,6 +1368,7 @@ impl SessionState { Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), Arc::new(ProjectionPushDown::new()), + Arc::new(RewriteDisjunctivePredicate::new()), ]; if config.config_options.get_bool(OPT_FILTER_NULL_JOIN_KEYS) { rules.push(Arc::new(FilterNullJoinKeys::default())); diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 588903ad08e27..6da67b6fc1327 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -33,6 +33,7 @@ pub mod single_distinct_to_groupby; pub mod subquery_filter_to_join; pub mod utils; +pub mod rewrite_disjunctive_predicate; #[cfg(test)] pub mod test; diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs new file mode 100644 index 0000000000000..1f11dd85e994c --- /dev/null +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -0,0 +1,468 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::Result; +use datafusion_expr::logical_plan::{ + Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, Distinct, Explain, + Filter, Join, Limit, Projection, Repartition, Sort, Subquery, SubqueryAlias, Union, + Window, +}; +use datafusion_expr::Expr::BinaryExpr; +use datafusion_expr::{Expr, LogicalPlan, Operator}; +use std::sync::Arc; + +#[derive(Clone, PartialEq, Debug)] +enum Predicate { + And { args: Vec }, + Or { args: Vec }, + Other { expr: Box }, +} + +fn predicate(expr: &Expr) -> Result { + match expr { + BinaryExpr { left, op, right } => match op { + Operator::And => { + let args = vec![predicate(left)?, predicate(right)?]; + Ok(Predicate::And { args }) + } + Operator::Or => { + let args = vec![predicate(left)?, predicate(right)?]; + Ok(Predicate::Or { args }) + } + _ => Ok(Predicate::Other { + expr: Box::new(BinaryExpr { + left: left.clone(), + op: *op, + right: right.clone(), + }), + }), + }, + _ => Ok(Predicate::Other { + expr: Box::new(expr.clone()), + }), + } +} + +fn normalize_predicate(predicate: &Predicate) -> Expr { + match predicate { + Predicate::And { args } => { + assert!(args.len() >= 2); + let left = normalize_predicate(&args[0]); + let right = normalize_predicate(&args[1]); + let mut and_expr = BinaryExpr { + left: Box::new(left), + op: Operator::And, + right: Box::new(right), + }; + for arg in args.iter().skip(2) { + and_expr = BinaryExpr { + left: Box::new(and_expr), + op: Operator::And, + right: Box::new(normalize_predicate(arg)), + }; + } + and_expr + } + Predicate::Or { args } => { + assert!(args.len() >= 2); + let left = normalize_predicate(&args[0]); + let right = normalize_predicate(&args[1]); + let mut or_expr = BinaryExpr { + left: Box::new(left), + op: Operator::Or, + right: Box::new(right), + }; + for arg in args.iter().skip(2) { + or_expr = BinaryExpr { + left: Box::new(or_expr), + op: Operator::Or, + right: Box::new(normalize_predicate(arg)), + }; + } + or_expr + } + Predicate::Other { expr } => *expr.clone(), + } +} + +fn rewrite_predicate(predicate: &Predicate) -> Predicate { + match predicate { + Predicate::And { args } => { + let mut rewritten_args = Vec::with_capacity(args.len()); + for arg in args.iter() { + rewritten_args.push(rewrite_predicate(arg)); + } + rewritten_args = flatten_and_predicates(&rewritten_args); + Predicate::And { + args: rewritten_args, + } + } + Predicate::Or { args } => { + let mut rewritten_args = vec![]; + for arg in args.iter() { + rewritten_args.push(rewrite_predicate(arg)); + } + rewritten_args = flatten_or_predicates(&rewritten_args); + delete_duplicate_predicates(&rewritten_args) + } + Predicate::Other { expr } => Predicate::Other { + expr: Box::new(*expr.clone()), + }, + } +} + +fn flatten_and_predicates(and_predicates: &[Predicate]) -> Vec { + let mut flattened_predicates = vec![]; + for predicate in and_predicates { + match predicate { + Predicate::And { args } => { + flattened_predicates + .extend_from_slice(flatten_and_predicates(args).as_slice()); + } + _ => { + flattened_predicates.push(predicate.clone()); + } + } + } + flattened_predicates +} + +fn flatten_or_predicates(or_predicates: &[Predicate]) -> Vec { + let mut flattened_predicates = vec![]; + for predicate in or_predicates { + match predicate { + Predicate::Or { args } => { + flattened_predicates + .extend_from_slice(flatten_or_predicates(args).as_slice()); + } + _ => { + flattened_predicates.push(predicate.clone()); + } + } + } + flattened_predicates +} + +fn delete_duplicate_predicates(or_predicates: &[Predicate]) -> Predicate { + let mut shortest_exprs: Vec = vec![]; + let mut shortest_exprs_len = 0; + // choose the shortest AND predicate + for or_predicate in or_predicates.iter() { + match or_predicate { + Predicate::And { args } => { + let args_num = args.len(); + if shortest_exprs.is_empty() || args_num < shortest_exprs_len { + shortest_exprs = (*args).clone(); + shortest_exprs_len = args_num; + } + } + _ => { + // if there is no AND predicate, it must be the shortest expression. + shortest_exprs = vec![or_predicate.clone()]; + break; + } + } + } + + // dedup shortest_exprs + shortest_exprs.dedup(); + + // Check each element in shortest_exprs to see if it's in all the OR arguments. + let mut exist_exprs: Vec = vec![]; + for expr in shortest_exprs.iter() { + let mut found = true; + for or_predicate in or_predicates.iter() { + match or_predicate { + Predicate::And { args } => { + if !args.contains(expr) { + found = false; + break; + } + } + _ => { + if or_predicate != expr { + found = false; + break; + } + } + } + } + if found { + exist_exprs.push((*expr).clone()); + } + } + if exist_exprs.is_empty() { + return Predicate::Or { + args: or_predicates.to_vec(), + }; + } + + // Rebuild the OR predicate. + // (A AND B) OR A will be optimized to A. + let mut new_or_predicates = vec![]; + for or_predicate in or_predicates.iter() { + match or_predicate { + Predicate::And { args } => { + let mut new_args = (*args).clone(); + new_args.retain(|expr| !exist_exprs.contains(expr)); + if !new_args.is_empty() { + if new_args.len() == 1 { + new_or_predicates.push(new_args[0].clone()); + } else { + new_or_predicates.push(Predicate::And { args: new_args }); + } + } else { + new_or_predicates.clear(); + break; + } + } + _ => { + if exist_exprs.contains(or_predicate) { + new_or_predicates.clear(); + break; + } + } + } + } + if !new_or_predicates.is_empty() { + if new_or_predicates.len() == 1 { + exist_exprs.push(new_or_predicates[0].clone()); + } else { + exist_exprs.push(Predicate::Or { + args: flatten_or_predicates(&new_or_predicates), + }); + } + } + + if exist_exprs.len() == 1 { + exist_exprs[0].clone() + } else { + Predicate::And { + args: flatten_and_predicates(&exist_exprs), + } + } +} + +#[derive(Default)] +pub struct RewriteDisjunctivePredicate; + +impl RewriteDisjunctivePredicate { + pub fn new() -> Self { + Self::default() + } + fn rewrite_disjunctive_predicate( + &self, + plan: &LogicalPlan, + _optimizer_config: &OptimizerConfig, + ) -> Result { + match plan { + LogicalPlan::Filter(filter) => { + let predicate = predicate(&filter.predicate)?; + let rewritten_predicate = rewrite_predicate(&predicate); + let rewritten_expr = normalize_predicate(&rewritten_predicate); + Ok(LogicalPlan::Filter(Filter { + predicate: rewritten_expr, + input: Arc::new(self.rewrite_disjunctive_predicate( + &filter.input, + _optimizer_config, + )?), + })) + } + LogicalPlan::Projection(project) => { + Ok(LogicalPlan::Projection(Projection { + expr: project.expr.clone(), + input: Arc::new(self.rewrite_disjunctive_predicate( + &project.input, + _optimizer_config, + )?), + schema: project.schema.clone(), + alias: project.alias.clone(), + })) + } + LogicalPlan::Window(window) => Ok(LogicalPlan::Window(Window { + input: Arc::new( + self.rewrite_disjunctive_predicate(&window.input, _optimizer_config)?, + ), + window_expr: window.window_expr.clone(), + schema: window.schema.clone(), + })), + LogicalPlan::Aggregate(aggregate) => Ok(LogicalPlan::Aggregate(Aggregate { + input: Arc::new(self.rewrite_disjunctive_predicate( + &aggregate.input, + _optimizer_config, + )?), + group_expr: aggregate.group_expr.clone(), + aggr_expr: aggregate.aggr_expr.clone(), + schema: aggregate.schema.clone(), + })), + LogicalPlan::Sort(sort) => Ok(LogicalPlan::Sort(Sort { + expr: sort.expr.clone(), + input: Arc::new( + self.rewrite_disjunctive_predicate(&sort.input, _optimizer_config)?, + ), + })), + LogicalPlan::Join(join) => Ok(LogicalPlan::Join(Join { + left: Arc::new( + self.rewrite_disjunctive_predicate(&join.left, _optimizer_config)?, + ), + right: Arc::new( + self.rewrite_disjunctive_predicate(&join.right, _optimizer_config)?, + ), + on: join.on.clone(), + filter: join.filter.clone(), + join_type: join.join_type, + join_constraint: join.join_constraint, + schema: join.schema.clone(), + null_equals_null: join.null_equals_null, + })), + LogicalPlan::CrossJoin(cross_join) => Ok(LogicalPlan::CrossJoin(CrossJoin { + left: Arc::new(self.rewrite_disjunctive_predicate( + &cross_join.left, + _optimizer_config, + )?), + right: Arc::new(self.rewrite_disjunctive_predicate( + &cross_join.right, + _optimizer_config, + )?), + schema: cross_join.schema.clone(), + })), + LogicalPlan::Repartition(repartition) => { + Ok(LogicalPlan::Repartition(Repartition { + input: Arc::new(self.rewrite_disjunctive_predicate( + &repartition.input, + _optimizer_config, + )?), + partitioning_scheme: repartition.partitioning_scheme.clone(), + })) + } + LogicalPlan::Union(union) => { + let inputs = union + .inputs + .iter() + .map(|input| { + self.rewrite_disjunctive_predicate(input, _optimizer_config) + }) + .collect::>>()?; + Ok(LogicalPlan::Union(Union { + inputs, + schema: union.schema.clone(), + alias: union.alias.clone(), + })) + } + LogicalPlan::TableScan(table_scan) => { + Ok(LogicalPlan::TableScan(table_scan.clone())) + } + LogicalPlan::EmptyRelation(empty_relation) => { + Ok(LogicalPlan::EmptyRelation(empty_relation.clone())) + } + LogicalPlan::Subquery(subquery) => Ok(LogicalPlan::Subquery(Subquery { + subquery: Arc::new(self.rewrite_disjunctive_predicate( + &subquery.subquery, + _optimizer_config, + )?), + })), + LogicalPlan::SubqueryAlias(subquery_alias) => { + Ok(LogicalPlan::SubqueryAlias(SubqueryAlias { + input: Arc::new(self.rewrite_disjunctive_predicate( + &subquery_alias.input, + _optimizer_config, + )?), + alias: subquery_alias.alias.clone(), + schema: subquery_alias.schema.clone(), + })) + } + LogicalPlan::Limit(limit) => Ok(LogicalPlan::Limit(Limit { + skip: limit.skip, + fetch: limit.fetch, + input: Arc::new( + self.rewrite_disjunctive_predicate(&limit.input, _optimizer_config)?, + ), + })), + LogicalPlan::CreateExternalTable(plan) => { + Ok(LogicalPlan::CreateExternalTable(plan.clone())) + } + LogicalPlan::CreateMemoryTable(plan) => { + Ok(LogicalPlan::CreateMemoryTable(CreateMemoryTable { + name: plan.name.clone(), + input: Arc::new( + self.rewrite_disjunctive_predicate( + &plan.input, + _optimizer_config, + )?, + ), + if_not_exists: plan.if_not_exists, + or_replace: plan.or_replace, + })) + } + LogicalPlan::CreateView(plan) => Ok(LogicalPlan::CreateView(CreateView { + name: plan.name.clone(), + input: Arc::new( + self.rewrite_disjunctive_predicate(&plan.input, _optimizer_config)?, + ), + or_replace: plan.or_replace, + definition: plan.definition.clone(), + })), + LogicalPlan::CreateCatalogSchema(plan) => { + Ok(LogicalPlan::CreateCatalogSchema(plan.clone())) + } + LogicalPlan::CreateCatalog(plan) => { + Ok(LogicalPlan::CreateCatalog(plan.clone())) + } + LogicalPlan::DropTable(plan) => Ok(LogicalPlan::DropTable(plan.clone())), + LogicalPlan::Values(plan) => Ok(LogicalPlan::Values(plan.clone())), + LogicalPlan::Explain(explain) => Ok(LogicalPlan::Explain(Explain { + verbose: explain.verbose, + plan: Arc::new( + self.rewrite_disjunctive_predicate(&explain.plan, _optimizer_config)?, + ), + stringified_plans: explain.stringified_plans.clone(), + schema: explain.schema.clone(), + })), + LogicalPlan::Analyze(analyze) => { + Ok(LogicalPlan::Analyze(Analyze { + verbose: analyze.verbose, + input: Arc::new(self.rewrite_disjunctive_predicate( + &analyze.input, + _optimizer_config, + )?), + schema: analyze.schema.clone(), + })) + } + LogicalPlan::Extension(plan) => Ok(LogicalPlan::Extension(plan.clone())), + LogicalPlan::Distinct(plan) => Ok(LogicalPlan::Distinct(Distinct { + input: Arc::new( + self.rewrite_disjunctive_predicate(&plan.input, _optimizer_config)?, + ), + })), + } + } +} + +impl OptimizerRule for RewriteDisjunctivePredicate { + fn optimize( + &self, + plan: &LogicalPlan, + optimizer_config: &OptimizerConfig, + ) -> Result { + self.rewrite_disjunctive_predicate(plan, optimizer_config) + } + + fn name(&self) -> &str { + "rewrite_disjunctive_predicate" + } +} From 15af00686758e4df684e344875eab2053e814a83 Mon Sep 17 00:00:00 2001 From: xudong963 Date: Mon, 11 Jul 2022 23:23:08 +0800 Subject: [PATCH 2/3] address comments and add tests --- datafusion/core/tests/sql/predicates.rs | 56 +++ .../src/rewrite_disjunctive_predicate.rs | 346 ++++++------------ 2 files changed, 172 insertions(+), 230 deletions(-) diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs index ea79e2b142d55..e6cb77d9a7c7e 100644 --- a/datafusion/core/tests/sql/predicates.rs +++ b/datafusion/core/tests/sql/predicates.rs @@ -386,3 +386,59 @@ async fn csv_in_set_test() -> Result<()> { assert_batches_sorted_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn multiple_or_predicates() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "lineitem").await?; + register_tpch_csv(&ctx, "part").await?; + let sql = "explain select + l_partkey + from + lineitem, + part + where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + )"; + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let plan = state.optimize(&plan)?; + // Note that we expect `#part.p_partkey = #lineitem.l_partkey` to have been + // factored out and appear only once in the following plan + let expected =vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: #lineitem.l_partkey [l_partkey:Int64]", + " Projection: #part.p_partkey = #lineitem.l_partkey AS BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]", + " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Int64(1) AND #lineitem.l_quantity <= Int64(11) AND #part.p_size BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Int64(10) AND #lineitem.l_quantity <= Int64(20) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Int64(20) AND #lineitem.l_quantity <= Int64(30) AND #part.p_size BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + " CrossJoin: [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + " TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]", + " TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) +} diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index 1f11dd85e994c..2321a6dffff81 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -17,11 +17,8 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; -use datafusion_expr::logical_plan::{ - Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, Distinct, Explain, - Filter, Join, Limit, Projection, Repartition, Sort, Subquery, SubqueryAlias, Union, - Window, -}; +use datafusion_expr::logical_plan::Filter; +use datafusion_expr::utils::from_plan; use datafusion_expr::Expr::BinaryExpr; use datafusion_expr::{Expr, LogicalPlan, Operator}; use std::sync::Arc; @@ -58,56 +55,35 @@ fn predicate(expr: &Expr) -> Result { } } -fn normalize_predicate(predicate: &Predicate) -> Expr { +fn normalize_predicate(predicate: Predicate) -> Expr { match predicate { Predicate::And { args } => { assert!(args.len() >= 2); - let left = normalize_predicate(&args[0]); - let right = normalize_predicate(&args[1]); - let mut and_expr = BinaryExpr { - left: Box::new(left), - op: Operator::And, - right: Box::new(right), - }; - for arg in args.iter().skip(2) { - and_expr = BinaryExpr { - left: Box::new(and_expr), - op: Operator::And, - right: Box::new(normalize_predicate(arg)), - }; - } - and_expr + args.into_iter() + .map(normalize_predicate) + .reduce(Expr::and) + .expect("had more than one arg") } Predicate::Or { args } => { assert!(args.len() >= 2); - let left = normalize_predicate(&args[0]); - let right = normalize_predicate(&args[1]); - let mut or_expr = BinaryExpr { - left: Box::new(left), - op: Operator::Or, - right: Box::new(right), - }; - for arg in args.iter().skip(2) { - or_expr = BinaryExpr { - left: Box::new(or_expr), - op: Operator::Or, - right: Box::new(normalize_predicate(arg)), - }; - } - or_expr + assert!(args.len() >= 2); + args.into_iter() + .map(normalize_predicate) + .reduce(Expr::or) + .expect("had more than one arg") } - Predicate::Other { expr } => *expr.clone(), + Predicate::Other { expr } => *expr, } } -fn rewrite_predicate(predicate: &Predicate) -> Predicate { +fn rewrite_predicate(predicate: Predicate) -> Predicate { match predicate { Predicate::And { args } => { let mut rewritten_args = Vec::with_capacity(args.len()); for arg in args.iter() { - rewritten_args.push(rewrite_predicate(arg)); + rewritten_args.push(rewrite_predicate(arg.clone())); } - rewritten_args = flatten_and_predicates(&rewritten_args); + rewritten_args = flatten_and_predicates(rewritten_args); Predicate::And { args: rewritten_args, } @@ -115,18 +91,20 @@ fn rewrite_predicate(predicate: &Predicate) -> Predicate { Predicate::Or { args } => { let mut rewritten_args = vec![]; for arg in args.iter() { - rewritten_args.push(rewrite_predicate(arg)); + rewritten_args.push(rewrite_predicate(arg.clone())); } - rewritten_args = flatten_or_predicates(&rewritten_args); + rewritten_args = flatten_or_predicates(rewritten_args); delete_duplicate_predicates(&rewritten_args) } Predicate::Other { expr } => Predicate::Other { - expr: Box::new(*expr.clone()), + expr: Box::new(*expr), }, } } -fn flatten_and_predicates(and_predicates: &[Predicate]) -> Vec { +fn flatten_and_predicates( + and_predicates: impl IntoIterator, +) -> Vec { let mut flattened_predicates = vec![]; for predicate in and_predicates { match predicate { @@ -135,14 +113,16 @@ fn flatten_and_predicates(and_predicates: &[Predicate]) -> Vec { .extend_from_slice(flatten_and_predicates(args).as_slice()); } _ => { - flattened_predicates.push(predicate.clone()); + flattened_predicates.push(predicate); } } } flattened_predicates } -fn flatten_or_predicates(or_predicates: &[Predicate]) -> Vec { +fn flatten_or_predicates( + or_predicates: impl IntoIterator, +) -> Vec { let mut flattened_predicates = vec![]; for predicate in or_predicates { match predicate { @@ -151,7 +131,7 @@ fn flatten_or_predicates(or_predicates: &[Predicate]) -> Vec { .extend_from_slice(flatten_or_predicates(args).as_slice()); } _ => { - flattened_predicates.push(predicate.clone()); + flattened_predicates.push(predicate); } } } @@ -185,23 +165,10 @@ fn delete_duplicate_predicates(or_predicates: &[Predicate]) -> Predicate { // Check each element in shortest_exprs to see if it's in all the OR arguments. let mut exist_exprs: Vec = vec![]; for expr in shortest_exprs.iter() { - let mut found = true; - for or_predicate in or_predicates.iter() { - match or_predicate { - Predicate::And { args } => { - if !args.contains(expr) { - found = false; - break; - } - } - _ => { - if or_predicate != expr { - found = false; - break; - } - } - } - } + let found = or_predicates.iter().all(|or_predicate| match or_predicate { + Predicate::And { args } => args.contains(expr), + _ => or_predicate == expr, + }); if found { exist_exprs.push((*expr).clone()); } @@ -244,7 +211,7 @@ fn delete_duplicate_predicates(or_predicates: &[Predicate]) -> Predicate { exist_exprs.push(new_or_predicates[0].clone()); } else { exist_exprs.push(Predicate::Or { - args: flatten_or_predicates(&new_or_predicates), + args: flatten_or_predicates(new_or_predicates), }); } } @@ -253,7 +220,7 @@ fn delete_duplicate_predicates(or_predicates: &[Predicate]) -> Predicate { exist_exprs[0].clone() } else { Predicate::And { - args: flatten_and_predicates(&exist_exprs), + args: flatten_and_predicates(exist_exprs), } } } @@ -273,8 +240,8 @@ impl RewriteDisjunctivePredicate { match plan { LogicalPlan::Filter(filter) => { let predicate = predicate(&filter.predicate)?; - let rewritten_predicate = rewrite_predicate(&predicate); - let rewritten_expr = normalize_predicate(&rewritten_predicate); + let rewritten_predicate = rewrite_predicate(predicate); + let rewritten_expr = normalize_predicate(rewritten_predicate); Ok(LogicalPlan::Filter(Filter { predicate: rewritten_expr, input: Arc::new(self.rewrite_disjunctive_predicate( @@ -283,172 +250,17 @@ impl RewriteDisjunctivePredicate { )?), })) } - LogicalPlan::Projection(project) => { - Ok(LogicalPlan::Projection(Projection { - expr: project.expr.clone(), - input: Arc::new(self.rewrite_disjunctive_predicate( - &project.input, - _optimizer_config, - )?), - schema: project.schema.clone(), - alias: project.alias.clone(), - })) - } - LogicalPlan::Window(window) => Ok(LogicalPlan::Window(Window { - input: Arc::new( - self.rewrite_disjunctive_predicate(&window.input, _optimizer_config)?, - ), - window_expr: window.window_expr.clone(), - schema: window.schema.clone(), - })), - LogicalPlan::Aggregate(aggregate) => Ok(LogicalPlan::Aggregate(Aggregate { - input: Arc::new(self.rewrite_disjunctive_predicate( - &aggregate.input, - _optimizer_config, - )?), - group_expr: aggregate.group_expr.clone(), - aggr_expr: aggregate.aggr_expr.clone(), - schema: aggregate.schema.clone(), - })), - LogicalPlan::Sort(sort) => Ok(LogicalPlan::Sort(Sort { - expr: sort.expr.clone(), - input: Arc::new( - self.rewrite_disjunctive_predicate(&sort.input, _optimizer_config)?, - ), - })), - LogicalPlan::Join(join) => Ok(LogicalPlan::Join(Join { - left: Arc::new( - self.rewrite_disjunctive_predicate(&join.left, _optimizer_config)?, - ), - right: Arc::new( - self.rewrite_disjunctive_predicate(&join.right, _optimizer_config)?, - ), - on: join.on.clone(), - filter: join.filter.clone(), - join_type: join.join_type, - join_constraint: join.join_constraint, - schema: join.schema.clone(), - null_equals_null: join.null_equals_null, - })), - LogicalPlan::CrossJoin(cross_join) => Ok(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(self.rewrite_disjunctive_predicate( - &cross_join.left, - _optimizer_config, - )?), - right: Arc::new(self.rewrite_disjunctive_predicate( - &cross_join.right, - _optimizer_config, - )?), - schema: cross_join.schema.clone(), - })), - LogicalPlan::Repartition(repartition) => { - Ok(LogicalPlan::Repartition(Repartition { - input: Arc::new(self.rewrite_disjunctive_predicate( - &repartition.input, - _optimizer_config, - )?), - partitioning_scheme: repartition.partitioning_scheme.clone(), - })) - } - LogicalPlan::Union(union) => { - let inputs = union - .inputs + _ => { + let expr = plan.expressions(); + let inputs = plan.inputs(); + let new_inputs = inputs .iter() .map(|input| { self.rewrite_disjunctive_predicate(input, _optimizer_config) }) - .collect::>>()?; - Ok(LogicalPlan::Union(Union { - inputs, - schema: union.schema.clone(), - alias: union.alias.clone(), - })) - } - LogicalPlan::TableScan(table_scan) => { - Ok(LogicalPlan::TableScan(table_scan.clone())) + .collect::>>()?; + from_plan(plan, &expr, &new_inputs) } - LogicalPlan::EmptyRelation(empty_relation) => { - Ok(LogicalPlan::EmptyRelation(empty_relation.clone())) - } - LogicalPlan::Subquery(subquery) => Ok(LogicalPlan::Subquery(Subquery { - subquery: Arc::new(self.rewrite_disjunctive_predicate( - &subquery.subquery, - _optimizer_config, - )?), - })), - LogicalPlan::SubqueryAlias(subquery_alias) => { - Ok(LogicalPlan::SubqueryAlias(SubqueryAlias { - input: Arc::new(self.rewrite_disjunctive_predicate( - &subquery_alias.input, - _optimizer_config, - )?), - alias: subquery_alias.alias.clone(), - schema: subquery_alias.schema.clone(), - })) - } - LogicalPlan::Limit(limit) => Ok(LogicalPlan::Limit(Limit { - skip: limit.skip, - fetch: limit.fetch, - input: Arc::new( - self.rewrite_disjunctive_predicate(&limit.input, _optimizer_config)?, - ), - })), - LogicalPlan::CreateExternalTable(plan) => { - Ok(LogicalPlan::CreateExternalTable(plan.clone())) - } - LogicalPlan::CreateMemoryTable(plan) => { - Ok(LogicalPlan::CreateMemoryTable(CreateMemoryTable { - name: plan.name.clone(), - input: Arc::new( - self.rewrite_disjunctive_predicate( - &plan.input, - _optimizer_config, - )?, - ), - if_not_exists: plan.if_not_exists, - or_replace: plan.or_replace, - })) - } - LogicalPlan::CreateView(plan) => Ok(LogicalPlan::CreateView(CreateView { - name: plan.name.clone(), - input: Arc::new( - self.rewrite_disjunctive_predicate(&plan.input, _optimizer_config)?, - ), - or_replace: plan.or_replace, - definition: plan.definition.clone(), - })), - LogicalPlan::CreateCatalogSchema(plan) => { - Ok(LogicalPlan::CreateCatalogSchema(plan.clone())) - } - LogicalPlan::CreateCatalog(plan) => { - Ok(LogicalPlan::CreateCatalog(plan.clone())) - } - LogicalPlan::DropTable(plan) => Ok(LogicalPlan::DropTable(plan.clone())), - LogicalPlan::Values(plan) => Ok(LogicalPlan::Values(plan.clone())), - LogicalPlan::Explain(explain) => Ok(LogicalPlan::Explain(Explain { - verbose: explain.verbose, - plan: Arc::new( - self.rewrite_disjunctive_predicate(&explain.plan, _optimizer_config)?, - ), - stringified_plans: explain.stringified_plans.clone(), - schema: explain.schema.clone(), - })), - LogicalPlan::Analyze(analyze) => { - Ok(LogicalPlan::Analyze(Analyze { - verbose: analyze.verbose, - input: Arc::new(self.rewrite_disjunctive_predicate( - &analyze.input, - _optimizer_config, - )?), - schema: analyze.schema.clone(), - })) - } - LogicalPlan::Extension(plan) => Ok(LogicalPlan::Extension(plan.clone())), - LogicalPlan::Distinct(plan) => Ok(LogicalPlan::Distinct(Distinct { - input: Arc::new( - self.rewrite_disjunctive_predicate(&plan.input, _optimizer_config)?, - ), - })), } } } @@ -457,7 +269,7 @@ impl OptimizerRule for RewriteDisjunctivePredicate { fn optimize( &self, plan: &LogicalPlan, - optimizer_config: &OptimizerConfig, + optimizer_config: &mut OptimizerConfig, ) -> Result { self.rewrite_disjunctive_predicate(plan, optimizer_config) } @@ -466,3 +278,77 @@ impl OptimizerRule for RewriteDisjunctivePredicate { "rewrite_disjunctive_predicate" } } + +#[cfg(test)] + +mod tests { + use crate::rewrite_disjunctive_predicate::{ + normalize_predicate, predicate, rewrite_predicate, Predicate, + }; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{and, col, lit, or}; + + #[test] + fn test_rewrite_predicate() -> Result<()> { + let equi_expr = col("t1.a").eq(col("t2.b")); + let gt_expr = col("t1.c").gt(lit(ScalarValue::Int8(Some(1)))); + let lt_expr = col("t1.d").lt(lit(ScalarValue::Int8(Some(2)))); + let expr = or( + and(equi_expr.clone(), gt_expr.clone()), + and(equi_expr.clone(), lt_expr.clone()), + ); + let predicate = predicate(&expr)?; + assert_eq!( + predicate, + Predicate::Or { + args: vec![ + Predicate::And { + args: vec![ + Predicate::Other { + expr: Box::new(equi_expr.clone()) + }, + Predicate::Other { + expr: Box::new(gt_expr.clone()) + } + ] + }, + Predicate::And { + args: vec![ + Predicate::Other { + expr: Box::new(equi_expr.clone()) + }, + Predicate::Other { + expr: Box::new(lt_expr.clone()) + } + ] + } + ] + } + ); + let rewritten_predicate = rewrite_predicate(predicate); + assert_eq!( + rewritten_predicate, + Predicate::And { + args: vec![ + Predicate::Other { + expr: Box::new(equi_expr.clone()) + }, + Predicate::Or { + args: vec![ + Predicate::Other { + expr: Box::new(gt_expr.clone()) + }, + Predicate::Other { + expr: Box::new(lt_expr.clone()) + } + ] + } + ] + } + ); + let rewritten_expr = normalize_predicate(rewritten_predicate); + assert_eq!(rewritten_expr, and(equi_expr, or(gt_expr, lt_expr))); + Ok(()) + } +} From 688571db3ca53c67330bd3bca850f90d25611b70 Mon Sep 17 00:00:00 2001 From: "xudong.w" Date: Tue, 26 Jul 2022 21:42:18 +0800 Subject: [PATCH 3/3] Update datafusion/optimizer/src/rewrite_disjunctive_predicate.rs Co-authored-by: Andrew Lamb --- datafusion/optimizer/src/rewrite_disjunctive_predicate.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index 2321a6dffff81..b68adef5ae3bb 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -65,7 +65,6 @@ fn normalize_predicate(predicate: Predicate) -> Expr { .expect("had more than one arg") } Predicate::Or { args } => { - assert!(args.len() >= 2); assert!(args.len() >= 2); args.into_iter() .map(normalize_predicate)