From 2ce552b3776bb4e39917e5e7f63c87e9d1dcce8b Mon Sep 17 00:00:00 2001 From: ygf11 Date: Mon, 20 Feb 2023 06:47:04 -0500 Subject: [PATCH 1/7] Refactor DecorrelateWhereExists and add back Distinct if needs --- datafusion/core/tests/sql/joins.rs | 80 ++++++ .../optimizer/src/decorrelate_where_exists.rs | 248 ++++++++++++++---- .../optimizer/tests/integration-test.rs | 2 +- 3 files changed, 272 insertions(+), 58 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 6d1b1e91b66ef..1eb3f7319e23a 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -3564,3 +3564,83 @@ async fn not_exists_subquery_to_join_expr_filter() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn exists_distinct_subquery_to_join() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT DISTINCT t2_int FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Distinct: [t2_id:UInt32;N, t2_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} + +#[tokio::test] +async fn exists_distinct_subquery_to_join_dedup_cols() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT DISTINCT t2_id, t2_int FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Distinct: [t2_id:UInt32;N, t2_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index 3f6b160fa9a07..cf8d24be12b36 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -20,7 +20,7 @@ use crate::utils::{conjunction, extract_join_filters, split_conjunction}; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DataFusionError, Result}; use datafusion_expr::{ - logical_plan::{Filter, JoinType, Subquery}, + logical_plan::{Distinct, Filter, JoinType, Subquery}, Expr, LogicalPlan, LogicalPlanBuilder, }; use std::collections::BTreeSet; @@ -142,69 +142,102 @@ fn optimize_exists( query_info: &SubqueryInfo, outer_input: &LogicalPlan, ) -> Result> { - let maybe_subqury_filter = match query_info.query.subquery.as_ref() { - LogicalPlan::Distinct(subqry_distinct) => match subqry_distinct.input.as_ref() { - LogicalPlan::Projection(subqry_proj) => &subqry_proj.input, - _ => { - return Ok(None); - } - }, - LogicalPlan::Projection(subqry_proj) => &subqry_proj.input, - _ => { - // Subquery currently only supports distinct or projection - return Ok(None); - } - } - .as_ref(); + let subquery = query_info.query.subquery.as_ref(); + if let Some((join_filter, optimized_subquery)) = optimize_subquery(subquery, false)? { + // join our sub query into the main plan + let join_type = match query_info.negated { + true => JoinType::LeftAnti, + false => JoinType::LeftSemi, + }; + + let new_plan = LogicalPlanBuilder::from(outer_input.clone()) + .join( + optimized_subquery, + join_type, + (Vec::::new(), Vec::::new()), + Some(join_filter), + )? + .build()?; - // extract join filters - let (join_filters, subquery_input) = extract_join_filters(maybe_subqury_filter)?; - // cannot optimize non-correlated subquery - if join_filters.is_empty() { - return Ok(None); + Ok(Some(new_plan)) + } else { + Ok(None) } +} +/// Optimize the subquery and extract the possible join filter. +/// This function can't optimize non-correlated subquery, and will return None. +/// +/// `keep_original_project` means if we should keep the exprs of the original project in the optimized plan. +/// Except for `DISTINCT`, other plan nodes will ignore these exprs. +fn optimize_subquery( + subquery: &LogicalPlan, + keep_original_project: bool, +) -> Result> { + match subquery { + LogicalPlan::Distinct(subqry_distinct) => { + let distinct_input = &subqry_distinct.input; + let optimized_plan = + optimize_subquery(distinct_input, true)?.map(|(filters, right)| { + ( + filters, + LogicalPlan::Distinct(Distinct { + input: Arc::new(right), + }), + ) + }); + Ok(optimized_plan) + } + LogicalPlan::Projection(projection) => { + // extract join filters + let (join_filters, subquery_input) = extract_join_filters(&projection.input)?; + // cannot optimize non-correlated subquery + if join_filters.is_empty() { + return Ok(None); + } - let input_schema = subquery_input.schema(); - let subquery_cols: BTreeSet = - join_filters - .iter() - .try_fold(BTreeSet::new(), |mut cols, expr| { - let using_cols: Vec = expr - .to_columns()? + let input_schema = subquery_input.schema(); + let subquery_cols: BTreeSet = + join_filters + .iter() + .try_fold(BTreeSet::new(), |mut cols, expr| { + let using_cols: Vec = expr + .to_columns()? + .into_iter() + .filter(|col| input_schema.field_from_column(col).is_ok()) + .collect::<_>(); + + cols.extend(using_cols); + Result::<_, DataFusionError>::Ok(cols) + })?; + + let mut project_exprs: Vec = + subquery_cols.into_iter().map(Expr::Column).collect(); + let original_project_exprs = &projection.expr; + if keep_original_project && !original_project_exprs.is_empty() { + let exprs = original_project_exprs + .iter() + .filter(|expr| !matches!(expr, Expr::Literal(_))); + // merge + project_exprs = project_exprs .into_iter() - .filter(|col| input_schema.field_from_column(col).is_ok()) - .collect::<_>(); + .chain(exprs.cloned()) + .collect::>(); - cols.extend(using_cols); - Result::<_, DataFusionError>::Ok(cols) - })?; - - let projection_exprs: Vec = - subquery_cols.into_iter().map(Expr::Column).collect(); - - let right = LogicalPlanBuilder::from(subquery_input) - .project(projection_exprs)? - .build()?; - - let join_filter = conjunction(join_filters); + project_exprs.dedup(); + } - // join our sub query into the main plan - let join_type = match query_info.negated { - true => JoinType::LeftAnti, - false => JoinType::LeftSemi, - }; + let right = LogicalPlanBuilder::from(subquery_input) + .project(project_exprs)? + .build()?; - // TODO: add Distinct if the original plan is a Distinct. - let new_plan = LogicalPlanBuilder::from(outer_input.clone()) - .join( - right, - join_type, - (Vec::::new(), Vec::::new()), - join_filter, - )? - .build()?; - - Ok(Some(new_plan)) + // join_filters is not empty. + let join_filter = conjunction(join_filters).ok_or_else(|| { + DataFusionError::Internal("join filters should not be empty".to_string()) + })?; + Ok(Some((join_filter, right))) + } + _ => Ok(None), + } } struct SubqueryInfo { @@ -644,4 +677,105 @@ mod tests { assert_plan_eq(&plan, expected) } + + #[test] + fn exists_distinct_subquery() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter((lit(1u32) + col("sq.a")).gt(col("test.a") * lit(2u32)))? + .project(vec![col("sq.c")])? + .distinct()? + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Distinct: [a:UInt32, c:UInt32]\ + \n Projection: sq.a, sq.c [a:UInt32, c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_plan_eq(&plan, expected) + } + + #[test] + fn exists_distinct_subquery_with_literal() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter((lit(1u32) + col("sq.a")).gt(col("test.a") * lit(2u32)))? + .project(vec![lit(1u32), col("sq.c")])? + .distinct()? + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Distinct: [a:UInt32, c:UInt32]\ + \n Projection: sq.a, sq.c [a:UInt32, c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_plan_eq(&plan, expected) + } + + #[test] + fn exists_distinct_subquery_dedup_cols() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + // "sq.a" both in projection and filter + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter((lit(1u32) + col("sq.a")).gt(col("test.a") * lit(2u32)))? + .project(vec![lit("sq.a"), col("sq.c")])? + .distinct()? + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Distinct: [a:UInt32, c:UInt32]\ + \n Projection: sq.a, sq.c [a:UInt32, c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_plan_eq(&plan, expected) + } + + #[test] + fn exists_subquery_constant_filter() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter(lit(1u32).eq(lit(1u32)))? + .project(vec![lit("sq.a"), col("sq.c")])? + .distinct()? + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + // constant filter is also non-correlated subquery + let expected = "Projection: test.b [b:UInt32]\ + \n Filter: EXISTS () [a:UInt32, b:UInt32, c:UInt32]\ + \n Subquery: [Utf8(\"sq.a\"):Utf8, c:UInt32]\ + \n Distinct: [Utf8(\"sq.a\"):Utf8, c:UInt32]\ + \n Projection: Utf8(\"sq.a\"), sq.c [Utf8(\"sq.a\"):Utf8, c:UInt32]\ + \n Filter: UInt32(1) = UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + + assert_plan_eq(&plan, expected) + } } diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index eac849e347811..c942c2d7109f9 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -154,7 +154,7 @@ fn where_exists_distinct() -> Result<()> { let expected = "Projection: test.col_int32\ \n LeftSemi Join: test.col_int32 = t2.col_int32\ \n TableScan: test projection=[col_int32]\ - \n Projection: t2.col_int32\ + \n Distinct:\ \n SubqueryAlias: t2\ \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); From 45f321299e5d371a639c1360ccace192a4280f6e Mon Sep 17 00:00:00 2001 From: ygf11 Date: Mon, 20 Feb 2023 07:45:06 -0500 Subject: [PATCH 2/7] extract collect_subquery_cols --- .../optimizer/src/decorrelate_where_exists.rs | 25 ++++++------------- .../optimizer/src/decorrelate_where_in.rs | 21 +++++----------- datafusion/optimizer/src/utils.rs | 21 ++++++++++++++-- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index cf8d24be12b36..c6c1cc2c2a9db 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -16,14 +16,16 @@ // under the License. use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, extract_join_filters, split_conjunction}; +use crate::utils::{ + collect_subquery_cols, conjunction, extract_join_filters, split_conjunction, +}; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DataFusionError, Result}; use datafusion_expr::{ logical_plan::{Distinct, Filter, JoinType, Subquery}, Expr, LogicalPlan, LogicalPlanBuilder, }; -use std::collections::BTreeSet; + use std::sync::Arc; /// Optimizer rule for rewriting subquery filters to joins @@ -196,22 +198,11 @@ fn optimize_subquery( } let input_schema = subquery_input.schema(); - let subquery_cols: BTreeSet = - join_filters - .iter() - .try_fold(BTreeSet::new(), |mut cols, expr| { - let using_cols: Vec = expr - .to_columns()? - .into_iter() - .filter(|col| input_schema.field_from_column(col).is_ok()) - .collect::<_>(); - - cols.extend(using_cols); - Result::<_, DataFusionError>::Ok(cols) - })?; - let mut project_exprs: Vec = - subquery_cols.into_iter().map(Expr::Column).collect(); + collect_subquery_cols(&join_filters, input_schema.clone())? + .into_iter() + .map(Expr::Column) + .collect(); let original_project_exprs = &projection.expr; if keep_original_project && !original_project_exprs.is_empty() { let exprs = original_project_exprs diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index c8ff65f125234..29f5d10de3070 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -17,9 +17,12 @@ use crate::alias::AliasGenerator; use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, extract_join_filters, only_or_err, split_conjunction}; +use crate::utils::{ + collect_subquery_cols, conjunction, extract_join_filters, only_or_err, + split_conjunction, +}; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{context, Column, DataFusionError, Result}; +use datafusion_common::{context, Column, Result}; use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col}; use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; use datafusion_expr::{Expr, Filter, LogicalPlan, LogicalPlanBuilder}; @@ -159,19 +162,7 @@ fn optimize_where_in( // replace qualified name with subquery alias. let subquery_alias = alias.next("__correlated_sq"); let input_schema = subquery_input.schema(); - let mut subquery_cols: BTreeSet = - join_filters - .iter() - .try_fold(BTreeSet::new(), |mut cols, expr| { - let using_cols: Vec = expr - .to_columns()? - .into_iter() - .filter(|col| input_schema.field_from_column(col).is_ok()) - .collect::<_>(); - - cols.extend(using_cols); - Result::<_, DataFusionError>::Ok(cols) - })?; + let mut subquery_cols = collect_subquery_cols(&join_filters, input_schema.clone())?; let join_filter = conjunction(join_filters).map_or(Ok(None), |filter| { replace_qualified_name(filter, &subquery_cols, &subquery_alias).map(Option::Some) })?; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 747fa62089eda..6017fefc436f2 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,7 +18,7 @@ //! Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{plan_err, Column, DFSchemaRef}; +use datafusion_common::{plan_err, Column, DFSchemaRef, DataFusionError}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::expr::{BinaryExpr, Sort}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; @@ -30,7 +30,7 @@ use datafusion_expr::{ logical_plan::{Filter, LogicalPlan}, Expr, Operator, }; -use std::collections::HashSet; +use std::collections::{BTreeSet, HashSet}; use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke @@ -505,6 +505,23 @@ pub(crate) fn extract_join_filters( } } +pub(crate) fn collect_subquery_cols( + exprs: &[Expr], + subquery_schema: DFSchemaRef, +) -> Result> { + exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { + let mut using_cols: Vec = vec![]; + for col in expr.to_columns()?.into_iter() { + if subquery_schema.is_column_from_schema(&col)? { + using_cols.push(col); + } + } + + cols.extend(using_cols); + Result::<_, DataFusionError>::Ok(cols) + }) +} + #[cfg(test)] mod tests { use super::*; From af25a6e4c38ec000d91ea3853b25f622c6abe96e Mon Sep 17 00:00:00 2001 From: ygf11 Date: Wed, 22 Feb 2023 02:45:01 -0500 Subject: [PATCH 3/7] Add back aggregate --- datafusion/expr/src/logical_plan/plan.rs | 18 +++++++- .../optimizer/src/decorrelate_where_exists.rs | 45 +++++++++++++------ 2 files changed, 49 insertions(+), 14 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d86a44e5dccd5..1fc01112be139 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -24,7 +24,7 @@ use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::plan; use crate::utils::{ - self, exprlist_to_fields, from_plan, grouping_set_expr_count, + self, expand_wildcard, exprlist_to_fields, from_plan, grouping_set_expr_count, grouping_set_to_exprlist, }; use crate::{ @@ -1773,6 +1773,22 @@ impl Aggregate { _ => plan_err!("Could not coerce into Aggregate!"), } } + + pub fn is_distinct(&self) -> datafusion_common::Result { + let group_expr_size = self.group_expr.len(); + if !self.aggr_expr.is_empty() || group_expr_size != self.schema.fields().len() { + return Ok(false); + } + + let expected_group_exprs = expand_wildcard(&self.schema, self.input.as_ref())?; + let expected_expr_set = expected_group_exprs.iter().collect::>(); + let group_expr_set = self.group_expr.iter().collect::>(); + Ok(group_expr_set + .intersection(&expected_expr_set) + .collect::>() + .len() + == group_expr_size) + } } /// Sorts its input according to a list of sort expressions. diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index c6c1cc2c2a9db..bc313d04cce08 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -21,8 +21,9 @@ use crate::utils::{ }; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DataFusionError, Result}; +use datafusion_expr::utils::expand_wildcard; use datafusion_expr::{ - logical_plan::{Distinct, Filter, JoinType, Subquery}, + logical_plan::{Aggregate, Distinct, Filter, JoinType, Subquery}, Expr, LogicalPlan, LogicalPlanBuilder, }; @@ -188,6 +189,27 @@ fn optimize_subquery( ) }); Ok(optimized_plan) + } // Aggregate maybe a distinct + LogicalPlan::Aggregate(aggregate) => { + if !aggregate.is_distinct()? { + return Ok(None); + } + + if let Some((join_filter, plan)) = + optimize_subquery(aggregate.input.as_ref(), true)? + { + let input_schema = plan.schema().clone(); + let group_expr = expand_wildcard(&input_schema, &plan)?; + let aggregate = LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + Arc::new(plan), + group_expr, + vec![], + input_schema, + )?); + Ok(Some((join_filter, aggregate))) + } else { + Ok(None) + } } LogicalPlan::Projection(projection) => { // extract join filters @@ -743,29 +765,26 @@ mod tests { } #[test] - fn exists_subquery_constant_filter() -> Result<()> { + fn exists_subquery_infer_distinct() -> Result<()> { let table_scan = test_table_scan()?; let subquery_scan = test_table_scan_with_name("sq")?; let subquery = LogicalPlanBuilder::from(subquery_scan) - .filter(lit(1u32).eq(lit(1u32)))? - .project(vec![lit("sq.a"), col("sq.c")])? - .distinct()? + .filter(col("sq.a").gt(col("test.b")))? + .project(vec![col("sq.a"), col("sq.c")])? + .aggregate(vec![col("a"), col("c")], Vec::::new())? .build()?; let plan = LogicalPlanBuilder::from(table_scan) .filter(exists(Arc::new(subquery)))? .project(vec![col("test.b")])? .build()?; - // constant filter is also non-correlated subquery let expected = "Projection: test.b [b:UInt32]\ - \n Filter: EXISTS () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [Utf8(\"sq.a\"):Utf8, c:UInt32]\ - \n Distinct: [Utf8(\"sq.a\"):Utf8, c:UInt32]\ - \n Projection: Utf8(\"sq.a\"), sq.c [Utf8(\"sq.a\"):Utf8, c:UInt32]\ - \n Filter: UInt32(1) = UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + \n LeftSemi Join: Filter: sq.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Aggregate: groupBy=[[sq.a, sq.c]], aggr=[[]] [a:UInt32, c:UInt32]\ + \n Projection: sq.a, sq.c [a:UInt32, c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } From dbaf64cb8581d53c68cc5b94047e455458872d36 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Wed, 22 Feb 2023 02:47:25 -0500 Subject: [PATCH 4/7] fix comment --- datafusion/optimizer/src/decorrelate_where_exists.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index bc313d04cce08..cc1a541408fe6 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -189,7 +189,7 @@ fn optimize_subquery( ) }); Ok(optimized_plan) - } // Aggregate maybe a distinct + } // Aggregate may be a distinct LogicalPlan::Aggregate(aggregate) => { if !aggregate.is_distinct()? { return Ok(None); From 47b1cd5b4c03fc755c1fe702a1dc58b80df72679 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 3 Mar 2023 01:28:34 -0500 Subject: [PATCH 5/7] fix tests --- datafusion/core/tests/sql/joins.rs | 63 +++++++++--- datafusion/expr/src/logical_plan/plan.rs | 32 +++--- .../optimizer/src/decorrelate_where_exists.rs | 97 +++---------------- .../optimizer/tests/integration-test.rs | 5 +- 4 files changed, 85 insertions(+), 112 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 6749baa8ae2bd..b9b4c5cf31d58 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -3532,11 +3532,10 @@ async fn exists_distinct_subquery_to_join() -> Result<()> { let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Distinct: [t2_id:UInt32;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", + " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -3560,23 +3559,63 @@ async fn exists_distinct_subquery_to_join() -> Result<()> { } #[tokio::test] -async fn exists_distinct_subquery_to_join_dedup_cols() -> Result<()> { +async fn exists_distinct_subquery_to_join_with_expr() -> Result<()> { let test_repartition_joins = vec![true, false]; for repartition_joins in test_repartition_joins { let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT DISTINCT t2_id, t2_int FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; + // `t2_id + t2_int` is in the subquery project. + let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT DISTINCT t2_id + t2_int, t2_int FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; let msg = format!("Creating logical plan for '{sql}'"); let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); let plan = dataframe.into_optimized_plan()?; let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Distinct: [t2_id:UInt32;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", + " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} + +#[tokio::test] +async fn exists_distinct_subquery_to_join_with_literal() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // `1` is in the subquery project. + let sql = "SELECT * FROM t1 WHERE NOT EXISTS(SELECT DISTINCT 1, t2_int FROM t2 WHERE t1.t1_id + 1 > t2.t2_id * 2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan()?; + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " LeftAnti Join: Filter: CAST(t1.t1_id AS Int64) + Int64(1) > CAST(t2.t2_id AS Int64) * Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Aggregate: groupBy=[[t2.t2_id]], aggr=[[]] [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index c94accfa8d48f..cd7dbf6c50fb8 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -24,7 +24,7 @@ use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::plan; use crate::utils::{ - self, expand_wildcard, exprlist_to_fields, from_plan, grouping_set_expr_count, + self, exprlist_to_fields, from_plan, grouping_set_expr_count, grouping_set_to_exprlist, }; use crate::{ @@ -1777,21 +1777,21 @@ impl Aggregate { } } - pub fn is_distinct(&self) -> datafusion_common::Result { - let group_expr_size = self.group_expr.len(); - if !self.aggr_expr.is_empty() || group_expr_size != self.schema.fields().len() { - return Ok(false); - } - - let expected_group_exprs = expand_wildcard(&self.schema, self.input.as_ref())?; - let expected_expr_set = expected_group_exprs.iter().collect::>(); - let group_expr_set = self.group_expr.iter().collect::>(); - Ok(group_expr_set - .intersection(&expected_expr_set) - .collect::>() - .len() - == group_expr_size) - } + // pub fn is_distinct(&self) -> datafusion_common::Result { + // let group_expr_size = self.group_expr.len(); + // if !self.aggr_expr.is_empty() || group_expr_size != self.schema.fields().len() { + // return Ok(false); + // } + + // let expected_group_exprs = expand_wildcard(&self.schema, self.input.as_ref())?; + // let expected_expr_set = expected_group_exprs.iter().collect::>(); + // let group_expr_set = self.group_expr.iter().collect::>(); + // Ok(group_expr_set + // .intersection(&expected_expr_set) + // .collect::>() + // .len() + // == group_expr_size) + // } } /// Sorts its input according to a list of sort expressions. diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index 082b5715c8802..8d81c2d38197a 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -21,9 +21,8 @@ use crate::utils::{ }; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DataFusionError, Result}; -use datafusion_expr::utils::expand_wildcard; use datafusion_expr::{ - logical_plan::{Aggregate, Distinct, Filter, JoinType, Subquery}, + logical_plan::{Distinct, Filter, JoinType, Subquery}, Expr, LogicalPlan, LogicalPlanBuilder, }; @@ -146,7 +145,7 @@ fn optimize_exists( outer_input: &LogicalPlan, ) -> Result> { let subquery = query_info.query.subquery.as_ref(); - if let Some((join_filter, optimized_subquery)) = optimize_subquery(subquery, false)? { + if let Some((join_filter, optimized_subquery)) = optimize_subquery(subquery)? { // join our sub query into the main plan let join_type = match query_info.negated { true => JoinType::LeftAnti, @@ -169,18 +168,12 @@ fn optimize_exists( } /// Optimize the subquery and extract the possible join filter. /// This function can't optimize non-correlated subquery, and will return None. -/// -/// `keep_original_project` means if we should keep the exprs of the original project in the optimized plan. -/// Except for `DISTINCT`, other plan nodes will ignore these exprs. -fn optimize_subquery( - subquery: &LogicalPlan, - keep_original_project: bool, -) -> Result> { +fn optimize_subquery(subquery: &LogicalPlan) -> Result> { match subquery { LogicalPlan::Distinct(subqry_distinct) => { let distinct_input = &subqry_distinct.input; let optimized_plan = - optimize_subquery(distinct_input, true)?.map(|(filters, right)| { + optimize_subquery(distinct_input)?.map(|(filters, right)| { ( filters, LogicalPlan::Distinct(Distinct { @@ -189,27 +182,6 @@ fn optimize_subquery( ) }); Ok(optimized_plan) - } // Aggregate may be a distinct - LogicalPlan::Aggregate(aggregate) => { - if !aggregate.is_distinct()? { - return Ok(None); - } - - if let Some((join_filter, plan)) = - optimize_subquery(aggregate.input.as_ref(), true)? - { - let input_schema = plan.schema().clone(); - let group_expr = expand_wildcard(&input_schema, &plan)?; - let aggregate = LogicalPlan::Aggregate(Aggregate::try_new_with_schema( - Arc::new(plan), - group_expr, - vec![], - input_schema, - )?); - Ok(Some((join_filter, aggregate))) - } else { - Ok(None) - } } LogicalPlan::Projection(projection) => { // extract join filters @@ -220,24 +192,11 @@ fn optimize_subquery( } let input_schema = subquery_input.schema(); - let mut project_exprs: Vec = + let project_exprs: Vec = collect_subquery_cols(&join_filters, input_schema.clone())? .into_iter() .map(Expr::Column) .collect(); - let original_project_exprs = &projection.expr; - if keep_original_project && !original_project_exprs.is_empty() { - let exprs = original_project_exprs - .iter() - .filter(|expr| !matches!(expr, Expr::Literal(_))); - // merge - project_exprs = project_exprs - .into_iter() - .chain(exprs.cloned()) - .collect::>(); - - project_exprs.dedup(); - } let right = LogicalPlanBuilder::from(subquery_input) .project(project_exprs)? @@ -734,20 +693,20 @@ mod tests { let expected = "Projection: test.b [b:UInt32]\ \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n Distinct: [a:UInt32, c:UInt32]\ - \n Projection: sq.a, sq.c [a:UInt32, c:UInt32]\ + \n Distinct: [a:UInt32]\ + \n Projection: sq.a [a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } #[test] - fn exists_distinct_subquery_with_literal() -> Result<()> { + fn exists_distinct_expr_subquery() -> Result<()> { let table_scan = test_table_scan()?; let subquery_scan = test_table_scan_with_name("sq")?; let subquery = LogicalPlanBuilder::from(subquery_scan) .filter((lit(1u32) + col("sq.a")).gt(col("test.a") * lit(2u32)))? - .project(vec![lit(1u32), col("sq.c")])? + .project(vec![col("sq.b") + col("sq.c")])? .distinct()? .build()?; let plan = LogicalPlanBuilder::from(table_scan) @@ -758,21 +717,20 @@ mod tests { let expected = "Projection: test.b [b:UInt32]\ \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n Distinct: [a:UInt32, c:UInt32]\ - \n Projection: sq.a, sq.c [a:UInt32, c:UInt32]\ + \n Distinct: [a:UInt32]\ + \n Projection: sq.a [a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } #[test] - fn exists_distinct_subquery_dedup_cols() -> Result<()> { + fn exists_distinct_subquery_with_literal() -> Result<()> { let table_scan = test_table_scan()?; let subquery_scan = test_table_scan_with_name("sq")?; - // "sq.a" both in projection and filter let subquery = LogicalPlanBuilder::from(subquery_scan) .filter((lit(1u32) + col("sq.a")).gt(col("test.a") * lit(2u32)))? - .project(vec![lit("sq.a"), col("sq.c")])? + .project(vec![lit(1u32), col("sq.c")])? .distinct()? .build()?; let plan = LogicalPlanBuilder::from(table_scan) @@ -783,33 +741,8 @@ mod tests { let expected = "Projection: test.b [b:UInt32]\ \n LeftSemi Join: Filter: UInt32(1) + sq.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n Distinct: [a:UInt32, c:UInt32]\ - \n Projection: sq.a, sq.c [a:UInt32, c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_plan_eq(&plan, expected) - } - - #[test] - fn exists_subquery_infer_distinct() -> Result<()> { - let table_scan = test_table_scan()?; - let subquery_scan = test_table_scan_with_name("sq")?; - - let subquery = LogicalPlanBuilder::from(subquery_scan) - .filter(col("sq.a").gt(col("test.b")))? - .project(vec![col("sq.a"), col("sq.c")])? - .aggregate(vec![col("a"), col("c")], Vec::::new())? - .build()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(exists(Arc::new(subquery)))? - .project(vec![col("test.b")])? - .build()?; - - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: sq.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n Aggregate: groupBy=[[sq.a, sq.c]], aggr=[[]] [a:UInt32, c:UInt32]\ - \n Projection: sq.a, sq.c [a:UInt32, c:UInt32]\ + \n Distinct: [a:UInt32]\ + \n Projection: sq.a [a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 7e05cffe87aaf..0b9134c8b84ff 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -151,8 +151,9 @@ fn where_exists_distinct() -> Result<()> { let plan = test_sql(sql)?; let expected = "LeftSemi Join: test.col_int32 = t2.col_int32\ \n TableScan: test projection=[col_int32]\ - \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32]"; + \n Aggregate: groupBy=[[t2.col_int32]], aggr=[[]]\ + \n SubqueryAlias: t2\ + \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } From 0b4d252ac60e74be72875a6903035a5c3f1bc819 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 3 Mar 2023 02:01:07 -0500 Subject: [PATCH 6/7] remove unused method --- datafusion/expr/src/logical_plan/plan.rs | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index cd7dbf6c50fb8..c3ef861eb3b40 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1776,22 +1776,6 @@ impl Aggregate { _ => plan_err!("Could not coerce into Aggregate!"), } } - - // pub fn is_distinct(&self) -> datafusion_common::Result { - // let group_expr_size = self.group_expr.len(); - // if !self.aggr_expr.is_empty() || group_expr_size != self.schema.fields().len() { - // return Ok(false); - // } - - // let expected_group_exprs = expand_wildcard(&self.schema, self.input.as_ref())?; - // let expected_expr_set = expected_group_exprs.iter().collect::>(); - // let group_expr_set = self.group_expr.iter().collect::>(); - // Ok(group_expr_set - // .intersection(&expected_expr_set) - // .collect::>() - // .len() - // == group_expr_size) - // } } /// Sorts its input according to a list of sort expressions. From 573c48bbea503a33f79ceb5aa3a7c61c7135f35c Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 3 Mar 2023 02:57:50 -0500 Subject: [PATCH 7/7] fix cargo fmt --- datafusion/optimizer/src/decorrelate_where_exists.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index 8d81c2d38197a..ffc12c1582351 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -190,14 +190,12 @@ fn optimize_subquery(subquery: &LogicalPlan) -> Result = collect_subquery_cols(&join_filters, input_schema.clone())? .into_iter() .map(Expr::Column) .collect(); - let right = LogicalPlanBuilder::from(subquery_input) .project(project_exprs)? .build()?;