diff --git a/benchmarks/expected-plans/q21.txt b/benchmarks/expected-plans/q21.txt index 3ef6269dee48a..a91632df4e479 100644 --- a/benchmarks/expected-plans/q21.txt +++ b/benchmarks/expected-plans/q21.txt @@ -14,8 +14,10 @@ Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST TableScan: orders projection=[o_orderkey, o_orderstatus] Filter: nation.n_name = Utf8("SAUDI ARABIA") TableScan: nation projection=[n_nationkey, n_name] - SubqueryAlias: l2 - TableScan: lineitem projection=[l_orderkey, l_suppkey] - SubqueryAlias: l3 - Filter: lineitem.l_receiptdate > lineitem.l_commitdate - TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate] \ No newline at end of file + Projection: l2.l_orderkey, l2.l_suppkey + SubqueryAlias: l2 + TableScan: lineitem projection=[l_orderkey, l_suppkey] + Projection: l3.l_orderkey, l3.l_suppkey + SubqueryAlias: l3 + Filter: lineitem.l_receiptdate > lineitem.l_commitdate + TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate] \ No newline at end of file diff --git a/benchmarks/expected-plans/q22.txt b/benchmarks/expected-plans/q22.txt index 0fd7a590ac194..11b438085a0bb 100644 --- a/benchmarks/expected-plans/q22.txt +++ b/benchmarks/expected-plans/q22.txt @@ -8,7 +8,8 @@ Sort: custsale.cntrycode ASC NULLS LAST LeftAnti Join: customer.c_custkey = orders.o_custkey Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) TableScan: customer projection=[c_custkey, c_phone, c_acctbal] - TableScan: orders projection=[o_custkey] + Projection: orders.o_custkey + TableScan: orders projection=[o_custkey] SubqueryAlias: __scalar_sq_1 Projection: AVG(customer.c_acctbal) AS __value Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]] diff --git a/benchmarks/expected-plans/q4.txt b/benchmarks/expected-plans/q4.txt index 3610ae175adc2..e677f3a988a61 100644 --- a/benchmarks/expected-plans/q4.txt +++ b/benchmarks/expected-plans/q4.txt @@ -4,5 +4,6 @@ Sort: orders.o_orderpriority ASC NULLS LAST LeftSemi Join: orders.o_orderkey = lineitem.l_orderkey Filter: orders.o_orderdate >= Date32("8582") AND orders.o_orderdate < Date32("8674") TableScan: orders projection=[o_orderkey, o_orderdate, o_orderpriority] - Filter: lineitem.l_commitdate < lineitem.l_receiptdate - TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate] \ No newline at end of file + Projection: lineitem.l_orderkey + Filter: lineitem.l_commitdate < lineitem.l_receiptdate + TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate] \ No newline at end of file diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 3c0aa8b3f3358..6d1b1e91b66ef 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2187,9 +2187,8 @@ async fn left_anti_join() -> Result<()> { } #[tokio::test] -#[ignore = "Test ignored, will be enabled after fixing the anti join plan bug"] -// https://github.com/apache/arrow-datafusion/issues/4366 async fn error_left_anti_join() -> Result<()> { + // https://github.com/apache/arrow-datafusion/issues/4366 let test_repartition_joins = vec![true, false]; for repartition_joins in test_repartition_joins { let ctx = create_left_semi_anti_join_context_with_null_ids( @@ -2255,19 +2254,20 @@ async fn right_semi_join() -> Result<()> { let dataframe = ctx.sql(sql).await.expect(&msg); let physical_plan = dataframe.create_physical_plan().await?; let expected = if repartition_joins { - vec![ "SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]", - " SortExec: expr=[t1_id@0 ASC NULLS LAST]", - " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int]", - " CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", - " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", - " MemoryExec: partitions=1, partition_sizes=[1]", + vec!["SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]", + " SortExec: expr=[t1_id@0 ASC NULLS LAST]", + " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int]", + " CoalesceBatchesExec: target_batch_size=4096", + " HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"t2_id\", index: 0 }], 2), input_partitions=2", + " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name]", + " MemoryExec: partitions=1, partition_sizes=[1]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2), input_partitions=2", + " RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[1]", ] } else { vec![ @@ -2275,7 +2275,8 @@ async fn right_semi_join() -> Result<()> { " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int]", " CoalesceBatchesExec: target_batch_size=4096", " HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(Column { name: \"t2_id\", index: 0 }, Column { name: \"t1_id\", index: 0 })], filter=BinaryExpr { left: Column { name: \"t2_name\", index: 1 }, op: NotEq, right: Column { name: \"t1_name\", index: 0 } }", - " MemoryExec: partitions=1, partition_sizes=[1]", + " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name]", + " MemoryExec: partitions=1, partition_sizes=[1]", " MemoryExec: partitions=1, partition_sizes=[1]", ] }; @@ -3393,3 +3394,173 @@ async fn left_as_inner_table_nested_loop_join() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn exists_subquery_to_join_expr_filter() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // exists subquery to LeftSemi join + let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Projection: t2.t2_id [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 22 | b | 2 |", + "| 33 | c | 3 |", + "| 44 | d | 4 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} + +#[tokio::test] +async fn exists_subquery_to_join_inner_filter() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // exists subquery to LeftSemi join + let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2 AND t2.t2_int < 3)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + // `t2.t2_int < 3` will be kept in the subquery filter. + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Projection: t2.t2_id [t2_id:UInt32;N]", + " Filter: t2.t2_int < UInt32(3) [t2_id:UInt32;N, t2_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 44 | d | 4 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} + +#[tokio::test] +async fn exists_subquery_to_join_outer_filter() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // exists subquery to LeftSemi join + let sql = "SELECT * FROM t1 WHERE EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2 AND t1.t1_int < 3)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + // `t1.t1_int < 3` will be moved to the filter of t1. + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: t1.t1_int < UInt32(3) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Projection: t2.t2_id [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 22 | b | 2 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} + +#[tokio::test] +async fn not_exists_subquery_to_join_expr_filter() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // not exists subquery to LeftAnti join + let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT t2_id FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Projection: t2.t2_id [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index bdc4afe901f4f..3f6b160fa9a07 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -16,16 +16,14 @@ // under the License. use crate::optimizer::ApplyOrder; -use crate::utils::{ - conjunction, exprs_to_join_cols, find_join_exprs, split_conjunction, - verify_not_disjunction, -}; +use crate::utils::{conjunction, extract_join_filters, split_conjunction}; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{context, Result}; +use datafusion_common::{Column, DataFusionError, Result}; use datafusion_expr::{ logical_plan::{Filter, JoinType, Subquery}, Expr, LogicalPlan, LogicalPlanBuilder, }; +use std::collections::BTreeSet; use std::sync::Arc; /// Optimizer rule for rewriting subquery filters to joins @@ -144,55 +142,68 @@ fn optimize_exists( query_info: &SubqueryInfo, outer_input: &LogicalPlan, ) -> Result> { - let subqry_filter = match query_info.query.subquery.as_ref() { + let maybe_subqury_filter = match query_info.query.subquery.as_ref() { LogicalPlan::Distinct(subqry_distinct) => match subqry_distinct.input.as_ref() { - LogicalPlan::Projection(subqry_proj) => { - Filter::try_from_plan(&subqry_proj.input) - } + LogicalPlan::Projection(subqry_proj) => &subqry_proj.input, _ => { - // Subquery currently only supports distinct or projection return Ok(None); } }, - LogicalPlan::Projection(subqry_proj) => Filter::try_from_plan(&subqry_proj.input), + LogicalPlan::Projection(subqry_proj) => &subqry_proj.input, _ => { // Subquery currently only supports distinct or projection return Ok(None); } } - .map_err(|e| context!("cannot optimize non-correlated subquery", e))?; - - // split into filters - let subqry_filter_exprs = split_conjunction(&subqry_filter.predicate); - verify_not_disjunction(&subqry_filter_exprs)?; - - // Grab column names to join on - let (col_exprs, other_subqry_exprs) = - find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema())?; - let (outer_cols, subqry_cols, join_filters) = - exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false)?; - if subqry_cols.is_empty() || outer_cols.is_empty() { - // cannot optimize non-correlated subquery + .as_ref(); + + // extract join filters + let (join_filters, subquery_input) = extract_join_filters(maybe_subqury_filter)?; + // cannot optimize non-correlated subquery + if join_filters.is_empty() { return Ok(None); } - // build subquery side of join - the thing the subquery was querying - let mut subqry_plan = LogicalPlanBuilder::from(subqry_filter.input.as_ref().clone()); - if let Some(expr) = conjunction(other_subqry_exprs) { - subqry_plan = subqry_plan.filter(expr)? // if the subquery had additional expressions, restore them - } - let subqry_plan = subqry_plan.build()?; + let input_schema = subquery_input.schema(); + let subquery_cols: BTreeSet = + join_filters + .iter() + .try_fold(BTreeSet::new(), |mut cols, expr| { + let using_cols: Vec = expr + .to_columns()? + .into_iter() + .filter(|col| input_schema.field_from_column(col).is_ok()) + .collect::<_>(); + + cols.extend(using_cols); + Result::<_, DataFusionError>::Ok(cols) + })?; + + let projection_exprs: Vec = + subquery_cols.into_iter().map(Expr::Column).collect(); + + let right = LogicalPlanBuilder::from(subquery_input) + .project(projection_exprs)? + .build()?; - let join_keys = (subqry_cols, outer_cols); + let join_filter = conjunction(join_filters); // join our sub query into the main plan let join_type = match query_info.negated { true => JoinType::LeftAnti, false => JoinType::LeftSemi, }; + + // TODO: add Distinct if the original plan is a Distinct. let new_plan = LogicalPlanBuilder::from(outer_input.clone()) - .join(subqry_plan, join_type, join_keys, join_filters)? + .join( + right, + join_type, + (Vec::::new(), Vec::::new()), + join_filter, + )? .build()?; + Ok(Some(new_plan)) } @@ -241,13 +252,14 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8] - LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; - + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -276,13 +288,14 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - LeftSemi Join: orders.o_orderkey = lineitem.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"#; - + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: orders.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n LeftSemi Join: Filter: lineitem.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ + \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_plan_eq(&plan, expected) } @@ -305,21 +318,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } - /// Test for correlated exists subquery with no columns in schema #[test] fn exists_subquery_no_cols() -> Result<()> { let sq = Arc::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) - .filter(col("customer.c_custkey").eq(col("customer.c_custkey")))? + .filter(col("customer.c_custkey").eq(lit(1u32)))? .project(vec![col("orders.o_custkey")])? .build()?, ); @@ -329,7 +342,14 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan) + // Other rule will pushdown `customer.c_custkey = 1`, + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: []\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_plan_eq(&plan, expected) } /// Test for exists subquery with both columns in schema @@ -365,7 +385,13 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan) + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey != orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_plan_eq(&plan, expected) } /// Test for correlated exists subquery less than @@ -383,10 +409,13 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = r#"can't optimize < column comparison"#; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey < orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); - Ok(()) + assert_plan_eq(&plan, expected) } /// Test for correlated exists subquery filter with subquery disjunction @@ -408,10 +437,13 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = r#"Optimizing disjunctions not supported!"#; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey OR orders.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); - Ok(()) + assert_plan_eq(&plan, expected) } /// Test for correlated exists without projection @@ -446,11 +478,11 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - // Doesn't matter we projected an expression, just that we returned a result - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -469,11 +501,12 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] - LeftSemi Join: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n Projection: orders.o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_plan_eq(&plan, expected) } @@ -520,10 +553,11 @@ mod tests { .project(vec![col("test.c")])? .build()?; - let expected = r#"Projection: test.c [c:UInt32] - LeftSemi Join: test.a = sq.a [a:UInt32, b:UInt32, c:UInt32] - TableScan: test [a:UInt32, b:UInt32, c:UInt32] - TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#; + let expected = "Projection: test.c [c:UInt32]\ + \n LeftSemi Join: Filter: test.a = sq.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq.a [a:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } @@ -537,10 +571,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "cannot optimize non-correlated subquery"; - - assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); - Ok(()) + assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan) } /// Test for single NOT exists subquery filter @@ -552,10 +583,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "cannot optimize non-correlated subquery"; - - assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); - Ok(()) + assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan) } #[test] @@ -583,18 +611,37 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; + \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.a = sq2.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.a = sq1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq1.a [a:UInt32]\ + \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq2.a [a:UInt32]\ + \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelateWhereExists::new()), - &plan, - expected, - ); - Ok(()) + assert_plan_eq(&plan, expected) + } + + #[test] + fn exists_subquery_expr_filter() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter((lit(1u32) + col("sq.a")).gt(col("test.a") * lit(2u32)))? + .project(vec![lit(1u32)])? + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq.a [a:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_plan_eq(&plan, expected) } } diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 7a9a75ff45bbd..c8ff65f125234 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -17,12 +17,11 @@ use crate::alias::AliasGenerator; use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, only_or_err, split_conjunction}; +use crate::utils::{conjunction, extract_join_filters, only_or_err, split_conjunction}; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{context, Column, DataFusionError, Result}; use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col}; use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; -use datafusion_expr::utils::check_all_columns_from_schema; use datafusion_expr::{Expr, Filter, LogicalPlan, LogicalPlanBuilder}; use log::debug; use std::collections::{BTreeSet, HashMap}; @@ -220,34 +219,6 @@ fn optimize_where_in( Ok(new_plan) } -fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec, LogicalPlan)> { - if let LogicalPlan::Filter(plan_filter) = maybe_filter { - let input_schema = plan_filter.input.schema(); - let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); - - let mut join_filters: Vec = vec![]; - let mut subquery_filters: Vec = vec![]; - for expr in subquery_filter_exprs { - let cols = expr.to_columns()?; - if check_all_columns_from_schema(&cols, input_schema.clone())? { - subquery_filters.push(expr.clone()); - } else { - join_filters.push(expr.clone()) - } - } - - // if the subquery still has filter expressions, restore them. - let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone()); - if let Some(expr) = conjunction(subquery_filters) { - plan = plan.filter(expr)? - } - - Ok((join_filters, plan.build()?)) - } else { - Ok((vec![], maybe_filter.clone())) - } -} - fn remove_duplicated_filter(filters: Vec, in_predicate: Expr) -> Vec { filters .into_iter() diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 4d9d10d51a7d1..747fa62089eda 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -23,10 +23,11 @@ use datafusion_common::{DFSchema, Result}; use datafusion_expr::expr::{BinaryExpr, Sort}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; use datafusion_expr::expr_visitor::inspect_expr_pre; +use datafusion_expr::logical_plan::LogicalPlanBuilder; +use datafusion_expr::utils::{check_all_columns_from_schema, from_plan}; use datafusion_expr::{ and, logical_plan::{Filter, LogicalPlan}, - utils::from_plan, Expr, Operator, }; use std::collections::HashSet; @@ -468,6 +469,42 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { }) } +/// Extract join predicates from the correclated subquery. +/// The join predicate means that the expression references columns +/// from both the subquery and outer table or only from the outer table. +/// +/// Returns join predicates and subquery(extracted). +/// ``` +pub(crate) fn extract_join_filters( + maybe_filter: &LogicalPlan, +) -> Result<(Vec, LogicalPlan)> { + if let LogicalPlan::Filter(plan_filter) = maybe_filter { + let input_schema = plan_filter.input.schema(); + let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); + + let mut join_filters: Vec = vec![]; + let mut subquery_filters: Vec = vec![]; + for expr in subquery_filter_exprs { + let cols = expr.to_columns()?; + if check_all_columns_from_schema(&cols, input_schema.clone())? { + subquery_filters.push(expr.clone()); + } else { + join_filters.push(expr.clone()) + } + } + + // if the subquery still has filter expressions, restore them. + let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone()); + if let Some(expr) = conjunction(subquery_filters) { + plan = plan.filter(expr)? + } + + Ok((join_filters, plan.build()?)) + } else { + Ok((vec![], maybe_filter.clone())) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index f901e33d41d67..eac849e347811 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -121,8 +121,9 @@ fn semi_join_with_join_filter() -> Result<()> { let expected = "Projection: test.col_utf8\ \n LeftSemi Join: test.col_int32 = t2.col_int32 Filter: test.col_uint32 != t2.col_uint32\ \n TableScan: test projection=[col_int32, col_uint32, col_utf8]\ - \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32, col_uint32, col_utf8]"; + \n Projection: t2.col_int32, t2.col_uint32\ + \n SubqueryAlias: t2\ + \n TableScan: test projection=[col_int32, col_uint32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -137,8 +138,9 @@ fn anti_join_with_join_filter() -> Result<()> { let expected = "Projection: test.col_utf8\ \n LeftAnti Join: test.col_int32 = t2.col_int32 Filter: test.col_uint32 != t2.col_uint32\ \n TableScan: test projection=[col_int32, col_uint32, col_utf8]\ - \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32, col_uint32, col_utf8]"; + \n Projection: t2.col_int32, t2.col_uint32\ + \n SubqueryAlias: t2\ + \n TableScan: test projection=[col_int32, col_uint32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -152,8 +154,9 @@ fn where_exists_distinct() -> Result<()> { let expected = "Projection: test.col_int32\ \n LeftSemi Join: test.col_int32 = t2.col_int32\ \n TableScan: test projection=[col_int32]\ - \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32]"; + \n Projection: t2.col_int32\ + \n SubqueryAlias: t2\ + \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) }