From 3219bf53c1c2c389060dd26e13f4675bed47982e Mon Sep 17 00:00:00 2001 From: ygf11 Date: Mon, 13 Feb 2023 07:00:56 -0500 Subject: [PATCH 1/6] Support non-tuple expression for exists-subquery to join --- datafusion/core/tests/sql/joins.rs | 30 +-- .../optimizer/src/decorrelate_where_exists.rs | 234 +++++++++++------- .../optimizer/src/decorrelate_where_in.rs | 31 +-- datafusion/optimizer/src/utils.rs | 33 ++- 4 files changed, 198 insertions(+), 130 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index bcb90d4517846..e154d2b247c3e 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2255,19 +2255,20 @@ async fn right_semi_join() -> Result<()> { let dataframe = ctx.sql(sql).await.expect(&msg); let physical_plan = dataframe.create_physical_plan().await?; let expected = if repartition_joins { - vec![ "SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]", - " SortExec: [t1_id@0 ASC NULLS LAST]", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", + vec!["SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]", + " SortExec: [t1_id@0 ASC NULLS LAST]", + " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int]", + " CoalesceBatchesExec: target_batch_size=4096", + " HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2", + " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name]", + " MemoryExec: partitions=1, partition_sizes=[1]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", + " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[1]", ] } else { vec![ @@ -2275,7 +2276,8 @@ async fn right_semi_join() -> Result<()> { " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int]", " CoalesceBatchesExec: target_batch_size=4096", " HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }", - " MemoryExec: partitions=1, partition_sizes=[1]", + " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name]", + " MemoryExec: partitions=1, partition_sizes=[1]", " MemoryExec: partitions=1, partition_sizes=[1]", ] }; diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index bdc4afe901f4f..3c5e256de5796 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -17,15 +17,16 @@ use crate::optimizer::ApplyOrder; use crate::utils::{ - conjunction, exprs_to_join_cols, find_join_exprs, split_conjunction, - verify_not_disjunction, + conjunction, extract_join_filters, + split_conjunction, }; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{context, Result}; +use datafusion_common::{Column, DataFusionError, Result}; use datafusion_expr::{ logical_plan::{Filter, JoinType, Subquery}, Expr, LogicalPlan, LogicalPlanBuilder, }; +use std::collections::BTreeSet; use std::sync::Arc; /// Optimizer rule for rewriting subquery filters to joins @@ -144,55 +145,76 @@ fn optimize_exists( query_info: &SubqueryInfo, outer_input: &LogicalPlan, ) -> Result> { - let subqry_filter = match query_info.query.subquery.as_ref() { + let maybe_subqury_filter = match query_info.query.subquery.as_ref() { LogicalPlan::Distinct(subqry_distinct) => match subqry_distinct.input.as_ref() { LogicalPlan::Projection(subqry_proj) => { - Filter::try_from_plan(&subqry_proj.input) + &subqry_proj.input } _ => { - // Subquery currently only supports distinct or projection return Ok(None); } }, - LogicalPlan::Projection(subqry_proj) => Filter::try_from_plan(&subqry_proj.input), + LogicalPlan::Projection(subqry_proj) => { + &subqry_proj.input + } _ => { // Subquery currently only supports distinct or projection return Ok(None); } } - .map_err(|e| context!("cannot optimize non-correlated subquery", e))?; - - // split into filters - let subqry_filter_exprs = split_conjunction(&subqry_filter.predicate); - verify_not_disjunction(&subqry_filter_exprs)?; - - // Grab column names to join on - let (col_exprs, other_subqry_exprs) = - find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema())?; - let (outer_cols, subqry_cols, join_filters) = - exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false)?; - if subqry_cols.is_empty() || outer_cols.is_empty() { - // cannot optimize non-correlated subquery + .as_ref(); + + // extract join filters + let (join_filters, subquery_input) = extract_join_filters(maybe_subqury_filter)?; + // cannot optimize non-correlated subquery + if join_filters.is_empty() { return Ok(None); } - // build subquery side of join - the thing the subquery was querying - let mut subqry_plan = LogicalPlanBuilder::from(subqry_filter.input.as_ref().clone()); - if let Some(expr) = conjunction(other_subqry_exprs) { - subqry_plan = subqry_plan.filter(expr)? // if the subquery had additional expressions, restore them + let input_schema = subquery_input.schema(); + let subquery_cols: BTreeSet = + join_filters + .iter() + .try_fold(BTreeSet::new(), |mut cols, expr| { + let using_cols: Vec = expr + .to_columns()? + .into_iter() + .filter(|col| input_schema.field_from_column(col).is_ok()) + .collect::<_>(); + + cols.extend(using_cols); + Result::<_, DataFusionError>::Ok(cols) + })?; + + // cannot optimize non-correlated subquery + if subquery_cols.is_empty() { + return Ok(None); } - let subqry_plan = subqry_plan.build()?; - let join_keys = (subqry_cols, outer_cols); + let projection_exprs: Vec = + subquery_cols.into_iter().map(Expr::Column).collect(); + + let right = LogicalPlanBuilder::from(subquery_input) + .project(projection_exprs)? + .build()?; + + let join_filter = conjunction(join_filters); // join our sub query into the main plan let join_type = match query_info.negated { true => JoinType::LeftAnti, false => JoinType::LeftSemi, }; + let new_plan = LogicalPlanBuilder::from(outer_input.clone()) - .join(subqry_plan, join_type, join_keys, join_filters)? + .join( + right, + join_type, + (Vec::::new(), Vec::::new()), + join_filter, + )? .build()?; + Ok(Some(new_plan)) } @@ -241,12 +263,15 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8] - LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + let expected = + "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -276,12 +301,15 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - LeftSemi Join: orders.o_orderkey = lineitem.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"#; + let expected = + "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n LeftSemi Join: Filter: lineitem.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ + \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_plan_eq(&plan, expected) } @@ -305,11 +333,13 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + let expected = + "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -363,9 +393,15 @@ mod tests { let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) .filter(exists(sq))? .project(vec![col("customer.c_custkey")])? - .build()?; + .build()?; - assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan) + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey != orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_plan_eq(&plan, expected) } /// Test for correlated exists subquery less than @@ -383,10 +419,14 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = r#"can't optimize < column comparison"#; + let expected = + "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey < orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); - Ok(()) + assert_plan_eq(&plan, expected) } /// Test for correlated exists subquery filter with subquery disjunction @@ -408,10 +448,15 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = r#"Optimizing disjunctions not supported!"#; + let expected = + "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey OR orders.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_plan_eq(&plan, expected) - assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); - Ok(()) } /// Test for correlated exists without projection @@ -446,11 +491,12 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - // Doesn't matter we projected an expression, just that we returned a result - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + let expected = + "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -469,11 +515,13 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] - LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + let expected = + "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -520,10 +568,12 @@ mod tests { .project(vec![col("test.c")])? .build()?; - let expected = r#"Projection: test.c [c:UInt32] - LeftSemi Join: test.a = sq.a [a:UInt32, b:UInt32, c:UInt32] - TableScan: test [a:UInt32, b:UInt32, c:UInt32] - TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#; + let expected = + "Projection: test.c [c:UInt32]\ + \n LeftSemi Join: Filter: test.a = sq.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq.a [a:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } @@ -537,10 +587,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "cannot optimize non-correlated subquery"; - - assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); - Ok(()) + assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan) } /// Test for single NOT exists subquery filter @@ -552,10 +599,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "cannot optimize non-correlated subquery"; - - assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); - Ok(()) + assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan) } #[test] @@ -582,19 +626,39 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ + let expected = + "Projection: test.b [b:UInt32]\ \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; + \n Projection: sq1.a [a:UInt32]\ + \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq2.a [a:UInt32]\ + \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelateWhereExists::new()), - &plan, - expected, - ); - Ok(()) + assert_plan_eq(&plan, expected) + } + + #[test] + fn exists_subquery_expr_filter() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter((lit(1u32) + col("sq.a")).gt(col("test.a") * lit(2u32)))? + .project(vec![lit(1u32)])? + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq.a [a:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_plan_eq(&plan, expected) } } diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 35164dcadcd60..c8ff65f125234 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -17,12 +17,11 @@ use crate::alias::AliasGenerator; use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, only_or_err, split_conjunction}; +use crate::utils::{conjunction, extract_join_filters, only_or_err, split_conjunction}; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{context, Column, DataFusionError, Result}; use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col}; use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; -use datafusion_expr::utils::check_all_column_from_schema; use datafusion_expr::{Expr, Filter, LogicalPlan, LogicalPlanBuilder}; use log::debug; use std::collections::{BTreeSet, HashMap}; @@ -220,34 +219,6 @@ fn optimize_where_in( Ok(new_plan) } -fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec, LogicalPlan)> { - if let LogicalPlan::Filter(plan_filter) = maybe_filter { - let input_schema = plan_filter.input.schema(); - let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); - - let mut join_filters: Vec = vec![]; - let mut subquery_filters: Vec = vec![]; - for expr in subquery_filter_exprs { - let cols = expr.to_columns()?; - if check_all_column_from_schema(&cols, input_schema.clone()) { - subquery_filters.push(expr.clone()); - } else { - join_filters.push(expr.clone()) - } - } - - // if the subquery still has filter expressions, restore them. - let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone()); - if let Some(expr) = conjunction(subquery_filters) { - plan = plan.filter(expr)? - } - - Ok((join_filters, plan.build()?)) - } else { - Ok((vec![], maybe_filter.clone())) - } -} - fn remove_duplicated_filter(filters: Vec, in_predicate: Expr) -> Vec { filters .into_iter() diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 4d9d10d51a7d1..920a898df3ba6 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -23,10 +23,11 @@ use datafusion_common::{DFSchema, Result}; use datafusion_expr::expr::{BinaryExpr, Sort}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; use datafusion_expr::expr_visitor::inspect_expr_pre; +use datafusion_expr::logical_plan::LogicalPlanBuilder; +use datafusion_expr::utils::{check_all_column_from_schema, from_plan}; use datafusion_expr::{ and, logical_plan::{Filter, LogicalPlan}, - utils::from_plan, Expr, Operator, }; use std::collections::HashSet; @@ -468,6 +469,36 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { }) } +pub(crate) fn extract_join_filters( + maybe_filter: &LogicalPlan, +) -> Result<(Vec, LogicalPlan)> { + if let LogicalPlan::Filter(plan_filter) = maybe_filter { + let input_schema = plan_filter.input.schema(); + let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); + + let mut join_filters: Vec = vec![]; + let mut subquery_filters: Vec = vec![]; + for expr in subquery_filter_exprs { + let cols = expr.to_columns()?; + if check_all_column_from_schema(&cols, input_schema.clone()) { + subquery_filters.push(expr.clone()); + } else { + join_filters.push(expr.clone()) + } + } + + // if the subquery still has filter expressions, restore them. + let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone()); + if let Some(expr) = conjunction(subquery_filters) { + plan = plan.filter(expr)? + } + + Ok((join_filters, plan.build()?)) + } else { + Ok((vec![], maybe_filter.clone())) + } +} + #[cfg(test)] mod tests { use super::*; From e4d2c3d2a9694ec6866acd6216e4f328f455b32c Mon Sep 17 00:00:00 2001 From: ygf11 Date: Tue, 14 Feb 2023 02:42:33 -0500 Subject: [PATCH 2/6] fix tests --- benchmarks/expected-plans/q21.txt | 12 +++++++----- benchmarks/expected-plans/q22.txt | 3 ++- benchmarks/expected-plans/q4.txt | 5 +++-- datafusion/optimizer/tests/integration-test.rs | 15 +++++++++------ 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/benchmarks/expected-plans/q21.txt b/benchmarks/expected-plans/q21.txt index 3ef6269dee48a..a91632df4e479 100644 --- a/benchmarks/expected-plans/q21.txt +++ b/benchmarks/expected-plans/q21.txt @@ -14,8 +14,10 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST TableScan: orders projection=[o_orderkey, o_orderstatus] Filter: nation.n_name = Utf8("SAUDI ARABIA") TableScan: nation projection=[n_nationkey, n_name] - SubqueryAlias: l2 - TableScan: lineitem projection=[l_orderkey, l_suppkey] - SubqueryAlias: l3 - Filter: lineitem.l_receiptdate > lineitem.l_commitdate - TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate] \ No newline at end of file + Projection: l2.l_orderkey, l2.l_suppkey + SubqueryAlias: l2 + TableScan: lineitem projection=[l_orderkey, l_suppkey] + Projection: l3.l_orderkey, l3.l_suppkey + SubqueryAlias: l3 + Filter: lineitem.l_receiptdate > lineitem.l_commitdate + TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate] \ No newline at end of file diff --git a/benchmarks/expected-plans/q22.txt b/benchmarks/expected-plans/q22.txt index 0fd7a590ac194..11b438085a0bb 100644 --- a/benchmarks/expected-plans/q22.txt +++ b/benchmarks/expected-plans/q22.txt @@ -8,7 +8,8 @@ Sort: custsale.cntrycode ASC NULLS LAST LeftAnti Join: customer.c_custkey = orders.o_custkey Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) TableScan: customer projection=[c_custkey, c_phone, c_acctbal] - TableScan: orders projection=[o_custkey] + Projection: orders.o_custkey + TableScan: orders projection=[o_custkey] SubqueryAlias: __scalar_sq_1 Projection: AVG(customer.c_acctbal) AS __value Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]] diff --git a/benchmarks/expected-plans/q4.txt b/benchmarks/expected-plans/q4.txt index 3610ae175adc2..e677f3a988a61 100644 --- a/benchmarks/expected-plans/q4.txt +++ b/benchmarks/expected-plans/q4.txt @@ -4,5 +4,6 @@ Sort: orders.o_orderpriority ASC NULLS LAST LeftSemi Join: orders.o_orderkey = lineitem.l_orderkey Filter: orders.o_orderdate >= Date32("8582") AND orders.o_orderdate < Date32("8674") TableScan: orders projection=[o_orderkey, o_orderdate, o_orderpriority] - Filter: lineitem.l_commitdate < lineitem.l_receiptdate - TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate] \ No newline at end of file + Projection: lineitem.l_orderkey + Filter: lineitem.l_commitdate < lineitem.l_receiptdate + TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate] \ No newline at end of file diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index f901e33d41d67..6d5b23238d30a 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -121,8 +121,9 @@ fn semi_join_with_join_filter() -> Result<()> { let expected = "Projection: test.col_utf8\ \n LeftSemi Join: test.col_int32 = t2.col_int32 Filter: test.col_uint32 != t2.col_uint32\ \n TableScan: test projection=[col_int32, col_uint32, col_utf8]\ - \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32, col_uint32, col_utf8]"; + \n Projection: t2.col_int32, t2.col_uint32\ + \n SubqueryAlias: t2\ + \n TableScan: test projection=[col_int32, col_uint32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -137,8 +138,9 @@ fn anti_join_with_join_filter() -> Result<()> { let expected = "Projection: test.col_utf8\ \n LeftAnti Join: test.col_int32 = t2.col_int32 Filter: test.col_uint32 != t2.col_uint32\ \n TableScan: test projection=[col_int32, col_uint32, col_utf8]\ - \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32, col_uint32, col_utf8]"; + \n Projection: t2.col_int32, t2.col_uint32\ + \n SubqueryAlias: t2\ + \n TableScan: test projection=[col_int32, col_uint32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -152,8 +154,9 @@ fn where_exists_distinct() -> Result<()> { let expected = "Projection: test.col_int32\ \n LeftSemi Join: test.col_int32 = t2.col_int32\ \n TableScan: test projection=[col_int32]\ - \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32]"; + \n Projection: t2.col_int32\ + \n SubqueryAlias: t2\ + \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } From 5e7f8e257f85070e04fde392d7c15879a6346fb3 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Tue, 14 Feb 2023 06:53:09 -0500 Subject: [PATCH 3/6] add tests --- datafusion/core/tests/sql/joins.rs | 171 +++++++++++++++++- .../optimizer/src/decorrelate_where_exists.rs | 158 ++++++++-------- 2 files changed, 244 insertions(+), 85 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index e154d2b247c3e..1422f5f9a7706 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2187,7 +2187,6 @@ async fn left_anti_join() -> Result<()> { } #[tokio::test] -#[ignore = "Test ignored, will be enabled after fixing the anti join plan bug"] // https://github.com/apache/arrow-datafusion/issues/4366 async fn error_left_anti_join() -> Result<()> { let test_repartition_joins = vec![true, false]; @@ -3395,3 +3394,173 @@ async fn left_as_inner_table_nested_loop_join() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn exists_subquery_to_join_expr_filter() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // exists subquery to LeftSemi join + let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Projection: t2.t2_id [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 22 | b | 2 |", + "| 33 | c | 3 |", + "| 44 | d | 4 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} + +#[tokio::test] +async fn exists_subquery_to_join_inner_filter() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // exists subquery to LeftSemi join + let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2 AND t2.t2_int < 3)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + // `t2.t2_int < 3` will be kept in the subquery filter. + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Projection: t2.t2_id [t2_id:UInt32;N]", + " Filter: t2.t2_int < UInt32(3) [t2_id:UInt32;N, t2_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 44 | d | 4 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} + +#[tokio::test] +async fn exists_subquery_to_join_outer_filter() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // exists subquery to LeftSemi join + let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2 AND t1.t1_int < 3)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + // `t1.t1_int < 3` will be moved to the filter of t1. + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: t1.t1_int < UInt32(3) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Projection: t2.t2_id [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 22 | b | 2 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} + +#[tokio::test] +async fn not_exists_subquery_to_join_expr_filter() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // not exists subquery to LeftAnti join + let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Projection: t2.t2_id [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index 3c5e256de5796..fe75eef732930 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -186,11 +186,6 @@ fn optimize_exists( Result::<_, DataFusionError>::Ok(cols) })?; - // cannot optimize non-correlated subquery - if subquery_cols.is_empty() { - return Ok(None); - } - let projection_exprs: Vec = subquery_cols.into_iter().map(Expr::Column).collect(); @@ -263,16 +258,14 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = - "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -301,16 +294,14 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = - "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: lineitem.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ - \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n LeftSemi Join: Filter: lineitem.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ + \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_plan_eq(&plan, expected) } @@ -333,23 +324,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = - "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } - /// Test for correlated exists subquery with no columns in schema #[test] fn exists_subquery_no_cols() -> Result<()> { let sq = Arc::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) - .filter(col("customer.c_custkey").eq(col("customer.c_custkey")))? + .filter(col("customer.c_custkey").eq(lit(1u32)))? .project(vec![col("orders.o_custkey")])? .build()?, ); @@ -359,7 +348,14 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan) + // `customer.c_custkey = 1` will pushdown by other rule. + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: []\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_plan_eq(&plan, expected) } /// Test for exists subquery with both columns in schema @@ -396,10 +392,10 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey != orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n LeftSemi Join: Filter: customer.c_custkey != orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -419,12 +415,11 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = - "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey < orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey < orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -448,12 +443,11 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = - "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey OR orders.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey OR orders.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) @@ -491,12 +485,11 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = - "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -515,13 +508,12 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = - "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -568,12 +560,11 @@ mod tests { .project(vec![col("test.c")])? .build()?; - let expected = - "Projection: test.c [c:UInt32]\ - \n LeftSemi Join: Filter: test.a = sq.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n Projection: sq.a [a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Projection: test.c [c:UInt32]\ + \n LeftSemi Join: Filter: test.a = sq.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq.a [a:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } @@ -626,16 +617,15 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = - "Projection: test.b [b:UInt32]\ - \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n Projection: sq1.a [a:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n Projection: sq2.a [a:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; + let expected = "Projection: test.b [b:UInt32]\ + \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq1.a [a:UInt32]\ + \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq2.a [a:UInt32]\ + \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } @@ -654,10 +644,10 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n Projection: sq.a [a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq.a [a:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } From 9bc38cc9b618282f0f15b252d38fe2d5370c3262 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Tue, 14 Feb 2023 07:28:40 -0500 Subject: [PATCH 4/6] add comments --- .../optimizer/src/decorrelate_where_exists.rs | 73 +++++++++---------- datafusion/optimizer/src/utils.rs | 6 ++ .../optimizer/tests/integration-test.rs | 6 +- 3 files changed, 42 insertions(+), 43 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index fe75eef732930..7e0c163f2c1d0 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -16,10 +16,7 @@ // under the License. use crate::optimizer::ApplyOrder; -use crate::utils::{ - conjunction, extract_join_filters, - split_conjunction, -}; +use crate::utils::{conjunction, extract_join_filters, split_conjunction}; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DataFusionError, Result}; use datafusion_expr::{ @@ -147,16 +144,12 @@ fn optimize_exists( ) -> Result> { let maybe_subqury_filter = match query_info.query.subquery.as_ref() { LogicalPlan::Distinct(subqry_distinct) => match subqry_distinct.input.as_ref() { - LogicalPlan::Projection(subqry_proj) => { - &subqry_proj.input - } + LogicalPlan::Projection(subqry_proj) => &subqry_proj.input, _ => { return Ok(None); } }, - LogicalPlan::Projection(subqry_proj) => { - &subqry_proj.input - } + LogicalPlan::Projection(subqry_proj) => &subqry_proj.input, _ => { // Subquery currently only supports distinct or projection return Ok(None); @@ -201,6 +194,7 @@ fn optimize_exists( false => JoinType::LeftSemi, }; + // TODO: add Distinct if the original plan is a Distinct. let new_plan = LogicalPlanBuilder::from(outer_input.clone()) .join( right, @@ -259,13 +253,13 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -295,13 +289,13 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: lineitem.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ - \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n LeftSemi Join: Filter: lineitem.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ + \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_plan_eq(&plan, expected) } @@ -324,7 +318,7 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ @@ -348,12 +342,12 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - // `customer.c_custkey = 1` will pushdown by other rule. + // Other rule will pushdown `customer.c_custkey = 1`, let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ \n LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n Projection: []\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -389,7 +383,7 @@ mod tests { let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) .filter(exists(sq))? .project(vec![col("customer.c_custkey")])? - .build()?; + .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ \n LeftSemi Join: Filter: customer.c_custkey != orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ @@ -450,7 +444,6 @@ mod tests { \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) - } /// Test for correlated exists without projection @@ -485,11 +478,11 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -508,12 +501,12 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 920a898df3ba6..2670401216fd8 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -469,6 +469,12 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { }) } +/// Extract join predicates from the correclated subquery. +/// The join predicate means that the expression references columns +/// from both the subquery and outer table or only from the outer table. +/// +/// Returns join predicates and subquery(extracted). +/// ``` pub(crate) fn extract_join_filters( maybe_filter: &LogicalPlan, ) -> Result<(Vec, LogicalPlan)> { diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 6d5b23238d30a..eac849e347811 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -123,7 +123,7 @@ fn semi_join_with_join_filter() -> Result<()> { \n TableScan: test projection=[col_int32, col_uint32, col_utf8]\ \n Projection: t2.col_int32, t2.col_uint32\ \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32, col_uint32]"; + \n TableScan: test projection=[col_int32, col_uint32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -140,7 +140,7 @@ fn anti_join_with_join_filter() -> Result<()> { \n TableScan: test projection=[col_int32, col_uint32, col_utf8]\ \n Projection: t2.col_int32, t2.col_uint32\ \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32, col_uint32]"; + \n TableScan: test projection=[col_int32, col_uint32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -156,7 +156,7 @@ fn where_exists_distinct() -> Result<()> { \n TableScan: test projection=[col_int32]\ \n Projection: t2.col_int32\ \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32]"; + \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } From a2f5b51120ca3b8fe90fdb436f98f1fc678d9ba6 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Tue, 14 Feb 2023 07:48:41 -0500 Subject: [PATCH 5/6] fix tests --- .../optimizer/src/decorrelate_where_exists.rs | 120 +++++++++--------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index 7e0c163f2c1d0..3f6b160fa9a07 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -253,13 +253,13 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -289,13 +289,13 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: lineitem.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ - \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n LeftSemi Join: Filter: lineitem.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ + \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_plan_eq(&plan, expected) } @@ -319,11 +319,11 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -344,10 +344,10 @@ mod tests { // Other rule will pushdown `customer.c_custkey = 1`, let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: []\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: []\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -386,10 +386,10 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey != orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n LeftSemi Join: Filter: customer.c_custkey != orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -410,10 +410,10 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey < orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n LeftSemi Join: Filter: customer.c_custkey < orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -438,10 +438,10 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey OR orders.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey OR orders.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -479,10 +479,10 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -502,11 +502,11 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -554,10 +554,10 @@ mod tests { .build()?; let expected = "Projection: test.c [c:UInt32]\ - \n LeftSemi Join: Filter: test.a = sq.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n Projection: sq.a [a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + \n LeftSemi Join: Filter: test.a = sq.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq.a [a:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } @@ -611,14 +611,14 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n Projection: sq1.a [a:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n Projection: sq2.a [a:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; + \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq1.a [a:UInt32]\ + \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq2.a [a:UInt32]\ + \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } @@ -637,10 +637,10 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n Projection: sq.a [a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq.a [a:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } From f2435aa028937a9d331753f28491e0d0d44fed05 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Sat, 18 Feb 2023 06:01:46 -0500 Subject: [PATCH 6/6] fix test comment --- datafusion/core/tests/sql/joins.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 8ab65c2e8d57e..6d1b1e91b66ef 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2187,8 +2187,8 @@ async fn left_anti_join() -> Result<()> { } #[tokio::test] -// https://github.com/apache/arrow-datafusion/issues/4366 async fn error_left_anti_join() -> Result<()> { + // https://github.com/apache/arrow-datafusion/issues/4366 let test_repartition_joins = vec![true, false]; for repartition_joins in test_repartition_joins { let ctx = create_left_semi_anti_join_context_with_null_ids(