From 4b8bc067ecfc90329a78df3b183924b037107f55 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Wed, 17 May 2023 14:51:50 +0800 Subject: [PATCH 1/3] More scalar subquery support --- benchmarks/expected-plans/q2.txt | 26 +- datafusion/core/tests/sql/subqueries.rs | 28 +- .../optimizer/src/scalar_subquery_to_join.rs | 478 +++++++++++------- 3 files changed, 315 insertions(+), 217 deletions(-) diff --git a/benchmarks/expected-plans/q2.txt b/benchmarks/expected-plans/q2.txt index 2bc3f732bf234..c503bd2e0b713 100644 --- a/benchmarks/expected-plans/q2.txt +++ b/benchmarks/expected-plans/q2.txt @@ -3,12 +3,12 @@ +---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | logical_plan | Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST | | | Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, supplier.s_comment | -| | Inner Join: partsupp.ps_supplycost = __scalar_sq_1.__value, part.p_partkey = __scalar_sq_1.ps_partkey | -| | Projection: part.p_partkey, part.p_mfgr, partsupp.ps_supplycost, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, nation.n_name | +| | Inner Join: part.p_partkey = __scalar_sq_1.ps_partkey, partsupp.ps_supplycost = __scalar_sq_1.__value | +| | Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, partsupp.ps_supplycost, nation.n_name | | | Inner Join: nation.n_regionkey = region.r_regionkey | -| | Projection: part.p_partkey, part.p_mfgr, partsupp.ps_supplycost, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, nation.n_name, nation.n_regionkey | +| | Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, partsupp.ps_supplycost, nation.n_name, nation.n_regionkey | | | Inner Join: supplier.s_nationkey = nation.n_nationkey | -| | Projection: part.p_partkey, part.p_mfgr, partsupp.ps_supplycost, supplier.s_name, supplier.s_address, supplier.s_nationkey, supplier.s_phone, supplier.s_acctbal, supplier.s_comment | +| | Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_nationkey, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, partsupp.ps_supplycost | | | Inner Join: partsupp.ps_suppkey = supplier.s_suppkey | | | Projection: part.p_partkey, part.p_mfgr, partsupp.ps_suppkey, partsupp.ps_supplycost | | | Inner Join: part.p_partkey = partsupp.ps_partkey | @@ -38,22 +38,22 @@ | | TableScan: region projection=[r_regionkey, r_name] | | physical_plan | SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST] | | | SortExec: expr=[s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST] | -| | ProjectionExec: expr=[s_acctbal@6 as s_acctbal, s_name@3 as s_name, n_name@8 as n_name, p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_address@4 as s_address, s_phone@5 as s_phone, s_comment@7 as s_comment] | +| | ProjectionExec: expr=[s_acctbal@5 as s_acctbal, s_name@2 as s_name, n_name@8 as n_name, p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_address@3 as s_address, s_phone@4 as s_phone, s_comment@6 as s_comment] | | | CoalesceBatchesExec: target_batch_size=8192 | -| | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "ps_supplycost", index: 2 }, Column { name: "__value", index: 1 }), (Column { name: "p_partkey", index: 0 }, Column { name: "ps_partkey", index: 0 })] | +| | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "p_partkey", index: 0 }, Column { name: "ps_partkey", index: 0 }), (Column { name: "ps_supplycost", index: 7 }, Column { name: "__value", index: 1 })] | | | CoalesceBatchesExec: target_batch_size=8192 | -| | RepartitionExec: partitioning=Hash([Column { name: "ps_supplycost", index: 2 }, Column { name: "p_partkey", index: 0 }], 2), input_partitions=2 | -| | ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, ps_supplycost@2 as ps_supplycost, s_name@3 as s_name, s_address@4 as s_address, s_phone@5 as s_phone, s_acctbal@6 as s_acctbal, s_comment@7 as s_comment, n_name@8 as n_name] | +| | RepartitionExec: partitioning=Hash([Column { name: "p_partkey", index: 0 }, Column { name: "ps_supplycost", index: 7 }], 2), input_partitions=2 | +| | ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_name@2 as s_name, s_address@3 as s_address, s_phone@4 as s_phone, s_acctbal@5 as s_acctbal, s_comment@6 as s_comment, ps_supplycost@7 as ps_supplycost, n_name@8 as n_name] | | | CoalesceBatchesExec: target_batch_size=8192 | | | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "n_regionkey", index: 9 }, Column { name: "r_regionkey", index: 0 })] | | | CoalesceBatchesExec: target_batch_size=8192 | | | RepartitionExec: partitioning=Hash([Column { name: "n_regionkey", index: 9 }], 2), input_partitions=2 | -| | ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, ps_supplycost@2 as ps_supplycost, s_name@3 as s_name, s_address@4 as s_address, s_phone@6 as s_phone, s_acctbal@7 as s_acctbal, s_comment@8 as s_comment, n_name@10 as n_name, n_regionkey@11 as n_regionkey] | +| | ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_name@2 as s_name, s_address@3 as s_address, s_phone@5 as s_phone, s_acctbal@6 as s_acctbal, s_comment@7 as s_comment, ps_supplycost@8 as ps_supplycost, n_name@10 as n_name, n_regionkey@11 as n_regionkey] | | | CoalesceBatchesExec: target_batch_size=8192 | -| | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 5 }, Column { name: "n_nationkey", index: 0 })] | +| | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "s_nationkey", index: 4 }, Column { name: "n_nationkey", index: 0 })] | | | CoalesceBatchesExec: target_batch_size=8192 | -| | RepartitionExec: partitioning=Hash([Column { name: "s_nationkey", index: 5 }], 2), input_partitions=2 | -| | ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, ps_supplycost@3 as ps_supplycost, s_name@5 as s_name, s_address@6 as s_address, s_nationkey@7 as s_nationkey, s_phone@8 as s_phone, s_acctbal@9 as s_acctbal, s_comment@10 as s_comment] | +| | RepartitionExec: partitioning=Hash([Column { name: "s_nationkey", index: 4 }], 2), input_partitions=2 | +| | ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_name@5 as s_name, s_address@6 as s_address, s_nationkey@7 as s_nationkey, s_phone@8 as s_phone, s_acctbal@9 as s_acctbal, s_comment@10 as s_comment, ps_supplycost@3 as ps_supplycost] | | | CoalesceBatchesExec: target_batch_size=8192 | | | HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "ps_suppkey", index: 2 }, Column { name: "s_suppkey", index: 0 })] | | | CoalesceBatchesExec: target_batch_size=8192 | @@ -85,7 +85,7 @@ | | FilterExec: r_name@1 = EUROPE | | | MemoryExec: partitions=0, partition_sizes=[] | | | CoalesceBatchesExec: target_batch_size=8192 | -| | RepartitionExec: partitioning=Hash([Column { name: "__value", index: 1 }, Column { name: "ps_partkey", index: 0 }], 2), input_partitions=2 | +| | RepartitionExec: partitioning=Hash([Column { name: "ps_partkey", index: 0 }, Column { name: "__value", index: 1 }], 2), input_partitions=2 | | | ProjectionExec: expr=[ps_partkey@0 as ps_partkey, MIN(partsupp.ps_supplycost)@1 as __value] | | | AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[MIN(partsupp.ps_supplycost)] | | | CoalesceBatchesExec: target_batch_size=8192 | diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index b93e0e0c8b9f8..30e02b2854895 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -382,13 +382,13 @@ async fn aggregated_correlated_scalar_subquery() -> Result<()> { let plan = dataframe.into_optimized_plan()?; let expected = vec![ - "Projection: t1.t1_id, () AS t2_sum [t1_id:UInt32;N, t2_sum:UInt64;N]", - " Subquery: [SUM(t2.t2_int):UInt64;N]", - " Projection: SUM(t2.t2_int) [SUM(t2.t2_int):UInt64;N]", - " Aggregate: groupBy=[[]], aggr=[[SUM(t2.t2_int)]] [SUM(t2.t2_int):UInt64;N]", - " Filter: t2.t2_id = outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", + "Projection: t1.t1_id, __scalar_sq_1.__value AS t2_sum [t1_id:UInt32;N, t2_sum:UInt64;N]", + " Left Join: t1.t1_id = __scalar_sq_1.t2_id [t1_id:UInt32;N, t2_id:UInt32;N, __value:UInt64;N]", + " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", + " SubqueryAlias: __scalar_sq_1 [t2_id:UInt32;N, __value:UInt64;N]", + " Projection: t2.t2_id, SUM(t2.t2_int) AS __value [t2_id:UInt32;N, __value:UInt64;N]", + " Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(t2.t2_int)]] [t2_id:UInt32;N, SUM(t2.t2_int):UInt64;N]", + " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -429,13 +429,13 @@ async fn aggregated_correlated_scalar_subquery_with_extra_group_by_constant() -> let plan = dataframe.into_optimized_plan()?; let expected = vec![ - "Projection: t1.t1_id, () AS t2_sum [t1_id:UInt32;N, t2_sum:UInt64;N]", - " Subquery: [SUM(t2.t2_int):UInt64;N]", - " Projection: SUM(t2.t2_int) [SUM(t2.t2_int):UInt64;N]", - " Aggregate: groupBy=[[t2.t2_id, Utf8(\"a\")]], aggr=[[SUM(t2.t2_int)]] [t2_id:UInt32;N, Utf8(\"a\"):Utf8, SUM(t2.t2_int):UInt64;N]", - " Filter: t2.t2_id = outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", + "Projection: t1.t1_id, __scalar_sq_1.__value AS t2_sum [t1_id:UInt32;N, t2_sum:UInt64;N]", + " Left Join: t1.t1_id = __scalar_sq_1.t2_id [t1_id:UInt32;N, t2_id:UInt32;N, __value:UInt64;N]", + " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", + " SubqueryAlias: __scalar_sq_1 [t2_id:UInt32;N, __value:UInt64;N]", + " Projection: t2.t2_id, SUM(t2.t2_int) AS __value [t2_id:UInt32;N, __value:UInt64;N]", + " Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(t2.t2_int)]] [t2_id:UInt32;N, SUM(t2.t2_int):UInt64;N]", + " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index dfaa17213b466..02bff98ae2e98 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -19,13 +19,13 @@ use crate::alias::AliasGenerator; use crate::optimizer::ApplyOrder; use crate::utils::{ collect_subquery_cols, conjunction, extract_join_filters, only_or_err, - replace_qualified_name, split_conjunction, + replace_qualified_name, }; use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; use datafusion_common::{context, Column, Result}; -use datafusion_expr::expr::BinaryExpr; use datafusion_expr::logical_plan::{JoinType, Subquery}; -use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Operator}; +use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use log::debug; use std::sync::Arc; @@ -45,51 +45,19 @@ impl ScalarSubqueryToJoin { /// /// # Arguments /// * `predicate` - A conjunction to split and search - /// * `optimizer_config` - For generating unique subquery aliases /// - /// Returns a tuple (subqueries, non-subquery expressions) + /// Returns a tuple (subqueries, rewrite expression) fn extract_subquery_exprs( &self, predicate: &Expr, - config: &dyn OptimizerConfig, - ) -> Result<(Vec, Vec)> { - let filters = split_conjunction(predicate); // TODO: disjunctions - - let mut subqueries = vec![]; - let mut others = vec![]; - for it in filters.iter() { - match it { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l_query = Subquery::try_from_expr(left); - let r_query = Subquery::try_from_expr(right); - if l_query.is_err() && r_query.is_err() { - others.push((*it).clone()); - continue; - } - let mut recurse = - |q: Result<&Subquery>, expr: Expr, lhs: bool| -> Result<()> { - let subquery = match q { - Ok(subquery) => subquery, - _ => return Ok(()), - }; - let subquery_plan = self - .try_optimize(&subquery.subquery, config)? - .map(Arc::new) - .unwrap_or_else(|| subquery.subquery.clone()); - let new_subquery = subquery.with_plan(subquery_plan); - let res = SubqueryInfo::new(new_subquery, expr, *op, lhs); - subqueries.push(res); - Ok(()) - }; - recurse(l_query, (**right).clone(), false)?; - recurse(r_query, (**left).clone(), true)?; - // TODO: if subquery doesn't get optimized, optimized children are lost - } - _ => others.push((*it).clone()), - } - } - - Ok((subqueries, others)) + alias_gen: Arc, + ) -> Result<(Vec<(Subquery, String)>, Expr)> { + let mut extract = ExtractScalarSubQuery { + sub_query_info: vec![], + alias_gen, + }; + let new_expr = predicate.clone().rewrite(&mut extract)?; + Ok((extract.sub_query_info, new_expr)) } } @@ -97,23 +65,23 @@ impl OptimizerRule for ScalarSubqueryToJoin { fn try_optimize( &self, plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _config: &dyn OptimizerConfig, ) -> Result> { match plan { LogicalPlan::Filter(filter) => { - let (subqueries, other_exprs) = - self.extract_subquery_exprs(&filter.predicate, config)?; + let (subqueries, expr) = + self.extract_subquery_exprs(&filter.predicate, self.alias.clone())?; if subqueries.is_empty() { // regular filter, no subquery exists clause here return Ok(None); } - // iterate through all subqueries in predicate, turning each into a join + // iterate through all subqueries in predicate, turning each into a left join let mut cur_input = filter.input.as_ref().clone(); - for subquery in subqueries { + for (subquery, alias) in subqueries { if let Some(optimized_subquery) = - optimize_scalar(&subquery, &cur_input, &other_exprs, &self.alias)? + optimize_scalar(&subquery, &cur_input, &alias)? { cur_input = optimized_subquery; } else { @@ -121,8 +89,38 @@ impl OptimizerRule for ScalarSubqueryToJoin { return Ok(None); } } - Ok(Some(cur_input)) + let new_plan = LogicalPlanBuilder::from(cur_input); + Ok(Some(new_plan.filter(expr)?.build()?)) } + LogicalPlan::Projection(projection) => { + let mut all_subqueryies = vec![]; + let mut rewrite_exprs = vec![]; + for expr in projection.expr.iter() { + let (subqueries, expr) = + self.extract_subquery_exprs(expr, self.alias.clone())?; + all_subqueryies.extend(subqueries); + rewrite_exprs.push(expr); + } + if all_subqueryies.is_empty() { + // regular projection, no subquery exists clause here + return Ok(None); + } + // iterate through all subqueries in predicate, turning each into a left join + let mut cur_input = projection.input.as_ref().clone(); + for (subquery, alias) in all_subqueryies { + if let Some(optimized_subquery) = + optimize_scalar(&subquery, &cur_input, &alias)? + { + cur_input = optimized_subquery; + } else { + // if we can't handle all of the subqueries then bail for now + return Ok(None); + } + } + let new_plan = LogicalPlanBuilder::from(cur_input); + Ok(Some(new_plan.project(rewrite_exprs)?.build()?)) + } + _ => Ok(None), } } @@ -136,6 +134,34 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } +struct ExtractScalarSubQuery { + sub_query_info: Vec<(Subquery, String)>, + alias_gen: Arc, +} + +impl TreeNodeRewriter for ExtractScalarSubQuery { + type N = Expr; + + fn pre_visit(&mut self, expr: &Expr) -> Result { + match expr { + Expr::ScalarSubquery(_) => Ok(RewriteRecursion::Mutate), + _ => Ok(RewriteRecursion::Continue), + } + } + + fn mutate(&mut self, expr: Expr) -> Result { + match expr { + Expr::ScalarSubquery(subquery) => { + let subqry_alias = self.alias_gen.next("__scalar_sq"); + self.sub_query_info.push((subquery, subqry_alias.clone())); + let scalar_column = "__value"; + Ok(Expr::Column(Column::new(Some(subqry_alias), scalar_column))) + } + _ => Ok(expr), + } + } +} + /// Takes a query like: /// /// ```text @@ -147,7 +173,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { /// /// ```text /// select c.id from customers c -/// inner join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id +/// left join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id /// where c.balance > o.val /// ``` /// @@ -162,7 +188,8 @@ impl OptimizerRule for ScalarSubqueryToJoin { /// /// ```text /// select c.id from customers c -/// inner join (select avg(total) as val from orders) a on (c.balance > a.val) +/// cross join (select avg(total) as val from orders) a +/// where c.balance > a.val /// ``` /// /// # Arguments @@ -170,27 +197,21 @@ impl OptimizerRule for ScalarSubqueryToJoin { /// * `query_info` - The subquery portion of the `where` (select avg(total) from orders) /// * `filter_input` - The non-subquery portion (from customers) /// * `outer_others` - Any additional parts to the `where` expression (and c.x = y) -/// * `optimizer_config` - Used to generate unique subquery aliases +/// * `subquery_alias` - Subquery aliases fn optimize_scalar( - query_info: &SubqueryInfo, + subquery: &Subquery, filter_input: &LogicalPlan, - outer_others: &[Expr], - alias: &AliasGenerator, + subquery_alias: &str, ) -> Result> { - let subquery = query_info.query.subquery.as_ref(); - debug!( - "optimizing: -{}", - subquery.display_indent() - ); - let proj = match &subquery { + let subquery_plan = subquery.subquery.as_ref(); + let proj = match &subquery_plan { LogicalPlan::Projection(proj) => proj, _ => { // this rule does not support this type of scalar subquery // TODO support more types debug!( "cannot translate this type of scalar subquery to a join: {}", - subquery.display_indent() + subquery_plan.display_indent() ); return Ok(None); } @@ -198,7 +219,7 @@ fn optimize_scalar( let proj = only_or_err(proj.expr.as_slice()) .map_err(|e| context!("exactly one expression should be projected", e))?; let proj = Expr::Alias(Box::new(proj.clone()), "__value".to_string()); - let sub_inputs = subquery.inputs(); + let sub_inputs = subquery_plan.inputs(); let sub_input = only_or_err(sub_inputs.as_slice()) .map_err(|e| context!("Exactly one input is expected. Is this a join?", e))?; @@ -209,7 +230,7 @@ fn optimize_scalar( // TODO support more types debug!( "cannot translate this type of scalar subquery to a join: {}", - subquery.display_indent() + subquery_plan.display_indent() ); return Ok(None); } @@ -218,11 +239,10 @@ fn optimize_scalar( // extract join filters let (join_filters, subquery_input) = extract_join_filters(&aggr.input)?; // Only operate if one column is present and the other closed upon from outside scope - let subqry_alias = alias.next("__scalar_sq"); let input_schema = subquery_input.schema(); let subqry_cols = collect_subquery_cols(&join_filters, input_schema.clone())?; let join_filter = conjunction(join_filters).map_or(Ok(None), |filter| { - replace_qualified_name(filter, &subqry_cols, &subqry_alias).map(Option::Some) + replace_qualified_name(filter, &subqry_cols, subquery_alias).map(Option::Some) })?; let group_by: Vec<_> = subqry_cols @@ -240,89 +260,39 @@ fn optimize_scalar( let subqry_plan = subqry_plan .aggregate(group_by, aggr.aggr_expr.clone())? .project(proj)? - .alias(subqry_alias.clone())? + .alias(subquery_alias.to_string())? .build()?; - let qry_expr = Expr::Column(Column::new(Some(subqry_alias), "__value".to_string())); - - // if correlated subquery's operation is column equality, put the clause into join on clause. - let mut restore_where_clause = true; - - let mut outer_keys = vec![]; - let mut subquery_keys = vec![]; - if let (Operator::Eq, Expr::Column(column)) = (query_info.op, &query_info.expr) { - // only do this optimization for correlated subquery - if !query_info.query.outer_ref_columns.is_empty() { - outer_keys.push(column.clone()); - subquery_keys.push(qry_expr.try_into_col().unwrap()); - restore_where_clause = false; - } - } - let join_keys = (outer_keys, subquery_keys); // join our sub query into the main plan let new_plan = LogicalPlanBuilder::from(filter_input.clone()); - let mut new_plan = if join_filter.is_none() && join_keys.0.is_empty() { + let new_plan = if join_filter.is_none() { // if not correlated, group down to 1 row and cross join on that (preserving row count) new_plan.cross_join(subqry_plan)? } else { - // inner join if correlated, grouping by the join keys so we don't change row count - new_plan.join(subqry_plan, JoinType::Inner, join_keys, join_filter)? + // left join if correlated, grouping by the join keys so we don't change row count + new_plan.join( + subqry_plan, + JoinType::Left, + (Vec::::new(), Vec::::new()), + join_filter, + )? }; - // restore where in condition - if restore_where_clause { - let filter_expr = if query_info.expr_on_left { - Expr::BinaryExpr(BinaryExpr::new( - Box::new(query_info.expr.clone()), - query_info.op, - Box::new(qry_expr), - )) - } else { - Expr::BinaryExpr(BinaryExpr::new( - Box::new(qry_expr), - query_info.op, - Box::new(query_info.expr.clone()), - )) - }; - new_plan = new_plan.filter(filter_expr)?; - } - - // if the main query had additional expressions, restore them - if let Some(expr) = conjunction(outer_others.to_vec()) { - new_plan = new_plan.filter(expr)? - } - let new_plan = new_plan.build()?; - Ok(Some(new_plan)) -} - -struct SubqueryInfo { - query: Subquery, - expr: Expr, - op: Operator, - expr_on_left: bool, -} - -impl SubqueryInfo { - pub fn new(query: Subquery, expr: Expr, op: Operator, expr_on_left: bool) -> Self { - Self { - query, - expr, - op, - expr_on_left, - } - } + Ok(Some(new_plan.build()?)) } #[cfg(test)] mod tests { use super::*; + use crate::eliminate_cross_join::EliminateCrossJoin; + use crate::eliminate_outer_join::EliminateOuterJoin; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; use crate::test::*; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ col, lit, logical_plan::LogicalPlanBuilder, max, min, out_ref_col, - scalar_subquery, sum, + scalar_subquery, sum, Between, }; use std::ops::Add; @@ -356,15 +326,14 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: Int32(1) < __scalar_sq_2.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N, o_custkey:Int64, __value:Int64;N]\ - \n Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N, o_custkey:Int64, __value:Int64;N]\ - \n Filter: Int32(1) < __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n Filter: Int32(1) < __scalar_sq_1.__value AND Int32(1) < __scalar_sq_2.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N, o_custkey:Int64, __value:Int64;N]\ + \n Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N, o_custkey:Int64, __value:Int64;N]\ + \n Left Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ + \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n SubqueryAlias: __scalar_sq_2 [o_custkey:Int64, __value:Int64;N]\ \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ @@ -372,6 +341,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![ Arc::new(ScalarSubqueryToJoin::new()), + Arc::new(EliminateOuterJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), ], &plan, @@ -421,8 +391,8 @@ mod tests { \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Float64;N]\ \n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value [o_custkey:Int64, __value:Float64;N]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]] [o_custkey:Int64, SUM(orders.o_totalprice):Float64;N]\ - \n Filter: orders.o_totalprice < __scalar_sq_2.__value [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N]\ - \n Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N]\ + \n Filter: orders.o_totalprice < __scalar_sq_2.__value [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64;N, __value:Float64;N]\ + \n Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64;N, __value:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n SubqueryAlias: __scalar_sq_2 [l_orderkey:Int64, __value:Float64;N]\ \n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS __value [l_orderkey:Int64, __value:Float64;N]\ @@ -432,6 +402,7 @@ mod tests { vec![ Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateOuterJoin::new()), ], &plan, expected, @@ -460,18 +431,20 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.__value, customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ - \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n Filter: customer.c_custkey = __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ + \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![ Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateOuterJoin::new()), ], &plan, expected, @@ -510,6 +483,7 @@ mod tests { vec![ Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateCrossJoin::new()), ], &plan, expected, @@ -534,19 +508,19 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ - \n CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [__value:Int64;N]\ - \n Projection: MAX(orders.o_custkey) AS __value [__value:Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ - \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n Inner Join: customer.c_custkey = __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __scalar_sq_1 [__value:Int64;N]\ + \n Projection: MAX(orders.o_custkey) AS __value [__value:Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ + \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![ Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateCrossJoin::new()), ], &plan, expected, @@ -662,6 +636,7 @@ mod tests { vec![ Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateOuterJoin::new()), ], &plan, expected, @@ -694,17 +669,19 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.__value, customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MAX(orders.o_custkey) + Int32(1) AS __value [o_custkey:Int64, __value:Int64;N]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n Filter: customer.c_custkey = __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ + \n Projection: orders.o_custkey, MAX(orders.o_custkey) + Int32(1) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![ Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateOuterJoin::new()), ], &plan, expected, @@ -762,19 +739,19 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]\ - \n Filter: customer.c_custkey >= __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ - \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n Filter: customer.c_custkey >= __scalar_sq_1.__value AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ + \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![ Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateOuterJoin::new()), ], &plan, expected, @@ -805,8 +782,8 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]\ - \n Inner Join: customer.c_custkey = __scalar_sq_1.__value, customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Int64;N]\ + \n Filter: customer.c_custkey = __scalar_sq_1.__value AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ @@ -816,6 +793,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![ Arc::new(ScalarSubqueryToJoin::new()), + Arc::new(EliminateOuterJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), ], &plan, @@ -829,7 +807,10 @@ mod tests { fn scalar_subquery_disjunction() -> Result<()> { let sq = Arc::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) - .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? + .filter( + out_ref_col(DataType::Int64, "customer.c_custkey") + .eq(col("orders.o_custkey")), + )? .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? .project(vec![max(col("orders.o_custkey"))])? .build()?, @@ -844,19 +825,20 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - // unoptimized plan because we don't support disjunctions yet - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - Filter: customer.c_custkey = () OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] - Subquery: [MAX(orders.o_custkey):Int64;N] - Projection: MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N] - Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = __scalar_sq_1.__value OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n Left Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ + \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + assert_multi_rules_optimized_plan_eq_display_indent( vec![ Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateCrossJoin::new()), ], &plan, expected, @@ -881,8 +863,8 @@ mod tests { .build()?; let expected = "Projection: test.c [c:UInt32]\ - \n Filter: test.c < __scalar_sq_1.__value [a:UInt32, b:UInt32, c:UInt32, a:UInt32, __value:UInt32;N]\ - \n Inner Join: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, __value:UInt32;N]\ + \n Filter: test.c < __scalar_sq_1.__value [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, __value:UInt32;N]\ + \n Inner Join: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, __value:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __scalar_sq_1 [a:UInt32, __value:UInt32;N]\ \n Projection: sq.a, MIN(sq.c) AS __value [a:UInt32, __value:UInt32;N]\ @@ -893,6 +875,7 @@ mod tests { vec![ Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateOuterJoin::new()), ], &plan, expected, @@ -928,6 +911,7 @@ mod tests { vec![ Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateCrossJoin::new()), ], &plan, expected, @@ -950,10 +934,123 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ - \n CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [__value:Int64;N]\ + \n Inner Join: customer.c_custkey = __scalar_sq_1.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __scalar_sq_1 [__value:Int64;N]\ + \n Projection: MAX(orders.o_custkey) AS __value [__value:Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_multi_rules_optimized_plan_eq_display_indent( + vec![ + Arc::new(ScalarSubqueryToJoin::new()), + Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateCrossJoin::new()), + ], + &plan, + expected, + ); + Ok(()) + } + + #[test] + fn correlated_scalar_subquery_in_between_clause() -> Result<()> { + let sq1 = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + out_ref_col(DataType::Int64, "customer.c_custkey") + .eq(col("orders.o_custkey")), + )? + .aggregate(Vec::::new(), vec![min(col("orders.o_custkey"))])? + .project(vec![min(col("orders.o_custkey"))])? + .build()?, + ); + let sq2 = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .filter( + out_ref_col(DataType::Int64, "customer.c_custkey") + .eq(col("orders.o_custkey")), + )? + .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let between_expr = Expr::Between(Between { + expr: Box::new(col("customer.c_custkey")), + negated: false, + low: Box::new(scalar_subquery(sq1)), + high: Box::new(scalar_subquery(sq2)), + }); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(between_expr)? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.__value AND __scalar_sq_2.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N, o_custkey:Int64;N, __value:Int64;N]\ + \n Left Join: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N, o_custkey:Int64;N, __value:Int64;N]\ + \n Left Join: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64;N, __value:Int64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __scalar_sq_1 [o_custkey:Int64, __value:Int64;N]\ + \n Projection: orders.o_custkey, MIN(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MIN(orders.o_custkey)]] [o_custkey:Int64, MIN(orders.o_custkey):Int64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n SubqueryAlias: __scalar_sq_2 [o_custkey:Int64, __value:Int64;N]\ + \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_multi_rules_optimized_plan_eq_display_indent( + vec![ + Arc::new(ScalarSubqueryToJoin::new()), + Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateOuterJoin::new()), + ], + &plan, + expected, + ); + Ok(()) + } + + #[test] + fn uncorrelated_scalar_subquery_in_between_clause() -> Result<()> { + let sq1 = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .aggregate(Vec::::new(), vec![min(col("orders.o_custkey"))])? + .project(vec![min(col("orders.o_custkey"))])? + .build()?, + ); + let sq2 = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let between_expr = Expr::Between(Between { + expr: Box::new(col("customer.c_custkey")), + negated: false, + low: Box::new(scalar_subquery(sq1)), + high: Box::new(scalar_subquery(sq2)), + }); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(between_expr)? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.__value AND __scalar_sq_2.__value [c_custkey:Int64, c_name:Utf8, __value:Int64;N, __value:Int64;N]\ + \n CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N, __value:Int64;N]\ + \n CrossJoin: [c_custkey:Int64, c_name:Utf8, __value:Int64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __scalar_sq_1 [__value:Int64;N]\ + \n Projection: MIN(orders.o_custkey) AS __value [__value:Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[MIN(orders.o_custkey)]] [MIN(orders.o_custkey):Int64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n SubqueryAlias: __scalar_sq_2 [__value:Int64;N]\ \n Projection: MAX(orders.o_custkey) AS __value [__value:Int64;N]\ \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; @@ -962,6 +1059,7 @@ mod tests { vec![ Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), + Arc::new(EliminateOuterJoin::new()), ], &plan, expected, From 0499366e9a697cd1bbdb2a88c5688133dd4d1a98 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Wed, 17 May 2023 18:05:48 +0800 Subject: [PATCH 2/3] fix simple scalar subquery --- datafusion/core/tests/sql/subqueries.rs | 32 ++++++++++++++++++ .../optimizer/src/scalar_subquery_to_join.rs | 33 ++++++++++++------- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 30e02b2854895..7d1b84c9387c2 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -690,3 +690,35 @@ async fn support_union_subquery() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn simple_uncorrelated_scalar_subquery() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", true)?; + + let sql = "select (select count(*) from t1) as b"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Projection: __scalar_sq_1.__value AS b [b:Int64;N]", + " SubqueryAlias: __scalar_sq_1 [__value:Int64;N]", + " Projection: COUNT(UInt8(1)) AS __value [__value:Int64;N]", + " Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]", + " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // assert data + let results = execute_to_batches(&ctx, sql).await; + let expected = vec!["+---+", "| b |", "+---+", "| 4 |", "+---+"]; + assert_batches_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 02bff98ae2e98..26f86c607a22b 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -25,7 +25,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; use datafusion_common::{context, Column, Result}; use datafusion_expr::logical_plan::{JoinType, Subquery}; -use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use log::debug; use std::sync::Arc; @@ -264,21 +264,32 @@ fn optimize_scalar( .build()?; // join our sub query into the main plan - let new_plan = LogicalPlanBuilder::from(filter_input.clone()); let new_plan = if join_filter.is_none() { - // if not correlated, group down to 1 row and cross join on that (preserving row count) - new_plan.cross_join(subqry_plan)? + match filter_input { + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: _, + }) => subqry_plan, + _ => { + // if not correlated, group down to 1 row and cross join on that (preserving row count) + LogicalPlanBuilder::from(filter_input.clone()) + .cross_join(subqry_plan)? + .build()? + } + } } else { // left join if correlated, grouping by the join keys so we don't change row count - new_plan.join( - subqry_plan, - JoinType::Left, - (Vec::::new(), Vec::::new()), - join_filter, - )? + LogicalPlanBuilder::from(filter_input.clone()) + .join( + subqry_plan, + JoinType::Left, + (Vec::::new(), Vec::::new()), + join_filter, + )? + .build()? }; - Ok(Some(new_plan.build()?)) + Ok(Some(new_plan)) } #[cfg(test)] From 44cf728334d421885e40c54c01b4a83212fa866a Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Wed, 17 May 2023 18:18:50 +0800 Subject: [PATCH 3/3] add more UT --- datafusion/core/tests/sql/subqueries.rs | 43 +++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 7d1b84c9387c2..640628e0b5006 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -722,3 +722,46 @@ async fn simple_uncorrelated_scalar_subquery() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn simple_uncorrelated_scalar_subquery2() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", true)?; + + let sql = "select (select count(*) from t1) as b, (select count(1) from t2) as c"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Projection: __scalar_sq_1.__value AS b, __scalar_sq_2.__value AS c [b:Int64;N, c:Int64;N]", + " CrossJoin: [__value:Int64;N, __value:Int64;N]", + " SubqueryAlias: __scalar_sq_1 [__value:Int64;N]", + " Projection: COUNT(UInt8(1)) AS __value [__value:Int64;N]", + " Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]", + " TableScan: t1 projection=[t1_id] [t1_id:UInt32;N]", + " SubqueryAlias: __scalar_sq_2 [__value:Int64;N]", + " Projection: COUNT(Int64(1)) AS __value [__value:Int64;N]", + " Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] [COUNT(Int64(1)):Int64;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + // assert data + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+---+---+", + "| b | c |", + "+---+---+", + "| 4 | 4 |", + "+---+---+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +}