diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index e021cdc59c436..b9b4c5cf31d58 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -3518,3 +3518,122 @@ async fn not_exists_subquery_to_join_expr_filter() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn exists_distinct_subquery_to_join() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT DISTINCT t2_int FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} + +#[tokio::test] +async fn exists_distinct_subquery_to_join_with_expr() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // `t2_id + t2_int` is in the subquery project. + let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT DISTINCT t2_id + t2_int, t2_int FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} + +#[tokio::test] +async fn exists_distinct_subquery_to_join_with_literal() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // `1` is in the subquery project. + let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT DISTINCT 1, t2_int FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index 023629b97ee02..ffc12c1582351 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -16,14 +16,16 @@ // under the License. use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, extract_join_filters, split_conjunction}; +use crate::utils::{ + collect_subquery_cols, conjunction, extract_join_filters, split_conjunction, +}; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DataFusionError, Result}; use datafusion_expr::{ - logical_plan::{Filter, JoinType, Subquery}, + logical_plan::{Distinct, Filter, JoinType, Subquery}, Expr, LogicalPlan, LogicalPlanBuilder, }; -use std::collections::BTreeSet; + use std::sync::Arc; /// Optimizer rule for rewriting subquery filters to joins @@ -142,69 +144,70 @@ fn optimize_exists( query_info: &SubqueryInfo, outer_input: &LogicalPlan, ) -> Result> { - let maybe_subqury_filter = match query_info.query.subquery.as_ref() { - LogicalPlan::Distinct(subqry_distinct) => match subqry_distinct.input.as_ref() { - LogicalPlan::Projection(subqry_proj) => &subqry_proj.input, - _ => { - return Ok(None); - } - }, - LogicalPlan::Projection(subqry_proj) => &subqry_proj.input, - _ => { - // Subquery currently only supports distinct or projection - return Ok(None); - } - } - .as_ref(); + let subquery = query_info.query.subquery.as_ref(); + if let Some((join_filter, optimized_subquery)) = optimize_subquery(subquery)? { + // join our sub query into the main plan + let join_type = match query_info.negated { + true => JoinType::LeftAnti, + false => JoinType::LeftSemi, + }; + + let new_plan = LogicalPlanBuilder::from(outer_input.clone()) + .join( + optimized_subquery, + join_type, + (Vec::::new(), Vec::::new()), + Some(join_filter), + )? + .build()?; - // extract join filters - let (join_filters, subquery_input) = extract_join_filters(maybe_subqury_filter)?; - // cannot optimize non-correlated subquery - if join_filters.is_empty() { - return Ok(None); + Ok(Some(new_plan)) + } else { + Ok(None) } - - let input_schema = subquery_input.schema(); - let subquery_cols: BTreeSet = - join_filters - .iter() - .try_fold(BTreeSet::new(), |mut cols, expr| { - let using_cols: Vec = expr - .to_columns()? +} +/// Optimize the subquery and extract the possible join filter. +/// This function can't optimize non-correlated subquery, and will return None. +fn optimize_subquery(subquery: &LogicalPlan) -> Result> { + match subquery { + LogicalPlan::Distinct(subqry_distinct) => { + let distinct_input = &subqry_distinct.input; + let optimized_plan = + optimize_subquery(distinct_input)?.map(|(filters, right)| { + ( + filters, + LogicalPlan::Distinct(Distinct { + input: Arc::new(right), + }), + ) + }); + Ok(optimized_plan) + } + LogicalPlan::Projection(projection) => { + // extract join filters + let (join_filters, subquery_input) = extract_join_filters(&projection.input)?; + // cannot optimize non-correlated subquery + if join_filters.is_empty() { + return Ok(None); + } + let input_schema = subquery_input.schema(); + let project_exprs: Vec = + collect_subquery_cols(&join_filters, input_schema.clone())? .into_iter() - .filter(|col| input_schema.has_column(col)) - .collect::<_>(); - - cols.extend(using_cols); - Result::<_, DataFusionError>::Ok(cols) + .map(Expr::Column) + .collect(); + let right = LogicalPlanBuilder::from(subquery_input) + .project(project_exprs)? + .build()?; + + // join_filters is not empty. + let join_filter = conjunction(join_filters).ok_or_else(|| { + DataFusionError::Internal("join filters should not be empty".to_string()) })?; - - let projection_exprs: Vec = - subquery_cols.into_iter().map(Expr::Column).collect(); - - let right = LogicalPlanBuilder::from(subquery_input) - .project(projection_exprs)? - .build()?; - - let join_filter = conjunction(join_filters); - - // join our sub query into the main plan - let join_type = match query_info.negated { - true => JoinType::LeftAnti, - false => JoinType::LeftSemi, - }; - - // TODO: add Distinct if the original plan is a Distinct. - let new_plan = LogicalPlanBuilder::from(outer_input.clone()) - .join( - right, - join_type, - (Vec::::new(), Vec::::new()), - join_filter, - )? - .build()?; - - Ok(Some(new_plan)) + Ok(Some((join_filter, right))) + } + _ => Ok(None), + } } struct SubqueryInfo { @@ -670,4 +673,76 @@ mod tests { assert_plan_eq(&plan, expected) } + + #[test] + fn exists_distinct_subquery() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter((lit(1u32) + col("sq.a")).gt(col("test.a") * lit(2u32)))? + .project(vec![col("sq.c")])? + .distinct()? + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Distinct: [a:UInt32]\ + \n Projection: sq.a [a:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_plan_eq(&plan, expected) + } + + #[test] + fn exists_distinct_expr_subquery() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter((lit(1u32) + col("sq.a")).gt(col("test.a") * lit(2u32)))? + .project(vec![col("sq.b") + col("sq.c")])? + .distinct()? + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Distinct: [a:UInt32]\ + \n Projection: sq.a [a:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_plan_eq(&plan, expected) + } + + #[test] + fn exists_distinct_subquery_with_literal() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter((lit(1u32) + col("sq.a")).gt(col("test.a") * lit(2u32)))? + .project(vec![lit(1u32), col("sq.c")])? + .distinct()? + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Distinct: [a:UInt32]\ + \n Projection: sq.a [a:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_plan_eq(&plan, expected) + } } diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 0aa3bac9ebae9..5eb7e99e79804 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -17,9 +17,12 @@ use crate::alias::AliasGenerator; use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, extract_join_filters, only_or_err, split_conjunction}; +use crate::utils::{ + collect_subquery_cols, conjunction, extract_join_filters, only_or_err, + split_conjunction, +}; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{context, Column, DataFusionError, Result}; +use datafusion_common::{context, Column, Result}; use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col}; use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; use datafusion_expr::{Expr, Filter, LogicalPlan, LogicalPlanBuilder}; @@ -159,19 +162,7 @@ fn optimize_where_in( // replace qualified name with subquery alias. let subquery_alias = alias.next("__correlated_sq"); let input_schema = subquery_input.schema(); - let mut subquery_cols: BTreeSet = - join_filters - .iter() - .try_fold(BTreeSet::new(), |mut cols, expr| { - let using_cols: Vec = expr - .to_columns()? - .into_iter() - .filter(|col| input_schema.has_column(col)) - .collect::<_>(); - - cols.extend(using_cols); - Result::<_, DataFusionError>::Ok(cols) - })?; + let mut subquery_cols = collect_subquery_cols(&join_filters, input_schema.clone())?; let join_filter = conjunction(join_filters).map_or(Ok(None), |filter| { replace_qualified_name(filter, &subquery_cols, &subquery_alias).map(Option::Some) })?; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index c6c03ad793e3f..617c123d837cf 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,7 +18,7 @@ //! Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{plan_err, Column, DFSchemaRef}; +use datafusion_common::{plan_err, Column, DFSchemaRef, DataFusionError}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::expr::{BinaryExpr, Sort}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; @@ -30,7 +30,7 @@ use datafusion_expr::{ logical_plan::{Filter, LogicalPlan}, Expr, Operator, }; -use std::collections::HashSet; +use std::collections::{BTreeSet, HashSet}; use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke @@ -505,6 +505,23 @@ pub(crate) fn extract_join_filters( } } +pub(crate) fn collect_subquery_cols( + exprs: &[Expr], + subquery_schema: DFSchemaRef, +) -> Result> { + exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { + let mut using_cols: Vec = vec![]; + for col in expr.to_columns()?.into_iter() { + if subquery_schema.has_column(&col) { + using_cols.push(col); + } + } + + cols.extend(using_cols); + Result::<_, DataFusionError>::Ok(cols) + }) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 7e05cffe87aaf..0b9134c8b84ff 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -151,8 +151,9 @@ fn where_exists_distinct() -> Result<()> { let plan = test_sql(sql)?; let expected = "LeftSemi Join: test.col_int32 = t2.col_int32\ \n TableScan: test projection=[col_int32]\ - \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32]"; + \n Aggregate: groupBy=[[t2.col_int32]], aggr=[[]]\ + \n SubqueryAlias: t2\ + \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) }