From ee58c29117165a6594e23eb8a001220956af706b Mon Sep 17 00:00:00 2001 From: ygf11 Date: Wed, 30 Nov 2022 00:56:57 -0500 Subject: [PATCH 1/9] Support non-column join key in eliminating cross join to inner join --- datafusion/core/tests/sql/subqueries.rs | 1 + datafusion/expr/src/logical_plan/builder.rs | 29 +- .../optimizer/src/eliminate_cross_join.rs | 331 ++++++++++++++---- datafusion/sql/src/planner.rs | 30 +- 4 files changed, 302 insertions(+), 89 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 719c0c3d7a258..3039a3b49db52 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -465,6 +465,7 @@ order by value desc; \n TableScan: supplier projection=[s_suppkey, s_nationkey]\ \n Filter: nation.n_name = Utf8(\"GERMANY\")\ \n TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8(\"GERMANY\")]"; + assert_eq!(actual, expected); // assert data diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 71097f5a61c6a..b6ecbd532d8ff 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -42,8 +42,9 @@ use datafusion_common::{ ToDFSchema, }; use std::any::Any; +use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; -use std::{collections::HashMap, sync::Arc}; +use std::sync::Arc; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -1012,6 +1013,32 @@ pub fn table_scan( LogicalPlanBuilder::scan(name.unwrap_or(UNNAMED_TABLE), table_source, projection) } +/// Wrap projection for a plan, if the join keys contains normal expression. +pub fn wrap_projection_for_join_if_necessary( + join_keys: &[Expr], + input: LogicalPlan, +) -> Result<(LogicalPlan, bool)> { + let expr_join_keys = join_keys + .iter() + .flat_map(|expr| expr.try_into_col().is_err().then_some(expr)) + .cloned() + .collect::>(); + + let need_project = !expr_join_keys.is_empty(); + let plan = if need_project { + let mut projection = vec![Expr::Wildcard]; + projection.extend(expr_join_keys.into_iter()); + + LogicalPlanBuilder::from(input) + .project(projection)? + .build()? + } else { + input + }; + + Ok((plan, need_project)) +} + /// Basic TableSource implementation intended for use in tests and documentation. It is expected /// that users will provide their own TableSource implementations or use DataFusion's /// DefaultTableSource. diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 83fc9e164beb0..6c02428655199 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -18,16 +18,15 @@ //! Optimizer rule to eliminate cross join to inner join if join predicates are available in filters. use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, DataFusionError, Result}; +use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; +use datafusion_expr::utils::{can_hash, check_all_column_from_schema}; use datafusion_expr::{ and, expr::BinaryExpr, logical_plan::{CrossJoin, Filter, Join, JoinType, LogicalPlan}, - or, - utils::can_hash, - Projection, + or, Projection, }; -use datafusion_expr::{Expr, Operator}; - +use datafusion_expr::{Expr, ExprSchemable, Operator}; use std::collections::{HashMap, HashSet}; //use std::collections::HashMap; @@ -64,7 +63,7 @@ impl OptimizerRule for EliminateCrossJoin { LogicalPlan::Filter(filter) => { let input = (**filter.input()).clone(); - let mut possible_join_keys: Vec<(Column, Column)> = vec![]; + let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; let mut all_inputs: Vec = vec![]; match &input { LogicalPlan::Join(join) if (join.join_type == JoinType::Inner) => { @@ -88,9 +87,9 @@ impl OptimizerRule for EliminateCrossJoin { let predicate = filter.predicate(); // join keys are handled locally - let mut all_join_keys: HashSet<(Column, Column)> = HashSet::new(); + let mut all_join_keys: HashSet<(Expr, Expr)> = HashSet::new(); - extract_possible_join_keys(predicate, &mut possible_join_keys); + extract_possible_join_keys(predicate, &mut possible_join_keys)?; let mut left = all_inputs.remove(0); while !all_inputs.is_empty() { @@ -103,6 +102,7 @@ impl OptimizerRule for EliminateCrossJoin { } left = utils::optimize_children(self, &left, _optimizer_config)?; + if plan.schema() != left.schema() { left = LogicalPlan::Projection(Projection::new_from_schema( Arc::new(left.clone()), @@ -139,13 +139,15 @@ impl OptimizerRule for EliminateCrossJoin { fn flatten_join_inputs( plan: &LogicalPlan, - possible_join_keys: &mut Vec<(Column, Column)>, + possible_join_keys: &mut Vec<(Expr, Expr)>, all_inputs: &mut Vec, ) -> Result<()> { let children = match plan { LogicalPlan::Join(join) => { for join_keys in join.on.iter() { - possible_join_keys.push(join_keys.clone()); + let join_keys = join_keys.clone(); + possible_join_keys + .push((Expr::Column(join_keys.0), Expr::Column(join_keys.1))); } let left = &*(join.left); let right = &*(join.right); @@ -182,23 +184,49 @@ fn flatten_join_inputs( } fn find_inner_join( - left: &LogicalPlan, + left_input: &LogicalPlan, rights: &mut Vec, - possible_join_keys: &mut Vec<(Column, Column)>, - all_join_keys: &mut HashSet<(Column, Column)>, + possible_join_keys: &mut Vec<(Expr, Expr)>, + all_join_keys: &mut HashSet<(Expr, Expr)>, ) -> Result { - for (i, right) in rights.iter().enumerate() { + for (i, right_input) in rights.iter().enumerate() { let mut join_keys = vec![]; for (l, r) in &mut *possible_join_keys { - if left.schema().field_from_column(l).is_ok() - && right.schema().field_from_column(r).is_ok() - && can_hash(left.schema().field_from_column(l).unwrap().data_type()) - { + let left_using_columns = l.to_columns()?; + let right_using_columns = r.to_columns()?; + + // Conditions like a = 10, will be treated as filter. + if left_using_columns.is_empty() || right_using_columns.is_empty() { + continue; + } + + let l_is_left = check_all_column_from_schema( + &left_using_columns, + left_input.schema().clone(), + )?; + let r_is_right = check_all_column_from_schema( + &right_using_columns, + right_input.schema().clone(), + )?; + + let r_is_left_and_l_is_right = || { + let result = check_all_column_from_schema( + &right_using_columns, + left_input.schema().clone(), + )? && check_all_column_from_schema( + &left_using_columns, + right_input.schema().clone(), + )?; + + Result::Ok(result) + }; + + // Data type of l and r is same. + if l_is_left && r_is_right && can_hash(&l.get_type(left_input.schema())?) { join_keys.push((l.clone(), r.clone())); - } else if left.schema().field_from_column(r).is_ok() - && right.schema().field_from_column(l).is_ok() - && can_hash(left.schema().field_from_column(r).unwrap().data_type()) + } else if r_is_left_and_l_is_right()? + && can_hash(&l.get_type(right_input.schema())?) { join_keys.push((r.clone(), l.clone())); } @@ -206,14 +234,37 @@ fn find_inner_join( if !join_keys.is_empty() { all_join_keys.extend(join_keys.clone()); - let right = rights.remove(i); - let join_schema = Arc::new(build_join_schema(left, &right)?); + let right_input = rights.remove(i); + let join_schema = Arc::new(build_join_schema(left_input, &right_input)?); + + let (left_on, right_on): (Vec, Vec) = + join_keys.into_iter().unzip(); + let (new_left_input, _) = + wrap_projection_for_join_if_necessary(&left_on, left_input.clone())?; + let (new_right_input, _) = + wrap_projection_for_join_if_necessary(&right_on, right_input)?; + + let join_on = left_on + .iter() + .zip(right_on.iter()) + .map(|(left, right)| { + let left_key = left.try_into_col().or_else(|_| { + Result::Ok(Column::from_name(left.display_name()?)) + })?; + let right_key = right.try_into_col().or_else(|_| { + Result::Ok(Column::from_name(right.display_name()?)) + })?; + + Ok((left_key, right_key)) + }) + .collect::>>()?; + return Ok(LogicalPlan::Join(Join { - left: Arc::new(left.clone()), - right: Arc::new(right), + left: Arc::new(new_left_input), + right: Arc::new(new_right_input), join_type: JoinType::Inner, join_constraint: JoinConstraint::On, - on: join_keys, + on: join_on, filter: None, schema: join_schema, null_equals_null: false, @@ -221,10 +272,10 @@ fn find_inner_join( } } let right = rights.remove(0); - let join_schema = Arc::new(build_join_schema(left, &right)?); + let join_schema = Arc::new(build_join_schema(left_input, &right)?); Ok(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(left.clone()), + left: Arc::new(left_input.clone()), right: Arc::new(right), schema: join_schema, })) @@ -242,9 +293,9 @@ fn build_join_schema(left: &LogicalPlan, right: &LogicalPlan) -> Result, - vec1: &[(Column, Column)], - vec2: &[(Column, Column)], + accum: &mut Vec<(Expr, Expr)>, + vec1: &[(Expr, Expr)], + vec2: &[(Expr, Expr)], ) { for x1 in vec1.iter() { for x2 in vec2.iter() { @@ -256,38 +307,35 @@ fn intersect( } /// Extract join keys from a WHERE clause -fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Column, Column)>) { +fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Expr, Expr)>) -> Result<()> { if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr { match op { Operator::Eq => { - if let (Expr::Column(l), Expr::Column(r)) = - (left.as_ref(), right.as_ref()) + // Ensure that we don't add the same Join keys multiple times + if !(accum.contains(&(*left.clone(), *right.clone())) + || accum.contains(&(*right.clone(), *left.clone()))) { - // Ensure that we don't add the same Join keys multiple times - if !(accum.contains(&(l.clone(), r.clone())) - || accum.contains(&(r.clone(), l.clone()))) - { - accum.push((l.clone(), r.clone())); - } + accum.push((*left.clone(), *right.clone())); } } Operator::And => { - extract_possible_join_keys(left, accum); - extract_possible_join_keys(right, accum) + extract_possible_join_keys(left, accum)?; + extract_possible_join_keys(right, accum)? } // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. Operator::Or => { let mut left_join_keys = vec![]; let mut right_join_keys = vec![]; - extract_possible_join_keys(left, &mut left_join_keys); - extract_possible_join_keys(right, &mut right_join_keys); + extract_possible_join_keys(left, &mut left_join_keys)?; + extract_possible_join_keys(right, &mut right_join_keys)?; intersect(accum, &left_join_keys, &right_join_keys) } _ => (), - } + }; } + Ok(()) } /// Remove join expressions from a filter expression @@ -295,25 +343,22 @@ fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Column, Column)>) { /// Returns None otherwise fn remove_join_expressions( expr: &Expr, - join_columns: &HashSet<(Column, Column)>, + join_keys: &HashSet<(Expr, Expr)>, ) -> Result> { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => { - if join_columns.contains(&(l.clone(), r.clone())) - || join_columns.contains(&(r.clone(), l.clone())) - { - Ok(None) - } else { - Ok(Some(expr.clone())) - } + Operator::Eq => { + if join_keys.contains(&(*left.clone(), *right.clone())) + || join_keys.contains(&(*right.clone(), *left.clone())) + { + Ok(None) + } else { + Ok(Some(expr.clone())) } - _ => Ok(Some(expr.clone())), - }, + } Operator::And => { - let l = remove_join_expressions(left, join_columns)?; - let r = remove_join_expressions(right, join_columns)?; + let l = remove_join_expressions(left, join_keys)?; + let r = remove_join_expressions(right, join_keys)?; match (l, r) { (Some(ll), Some(rr)) => Ok(Some(and(ll, rr))), (Some(ll), _) => Ok(Some(ll)), @@ -323,8 +368,8 @@ fn remove_join_expressions( } // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. Operator::Or => { - let l = remove_join_expressions(left, join_columns)?; - let r = remove_join_expressions(right, join_columns)?; + let l = remove_join_expressions(left, join_keys)?; + let r = remove_join_expressions(right, join_keys)?; match (l, r) { (Some(ll), Some(rr)) => Ok(Some(or(ll, rr))), (Some(ll), _) => Ok(Some(ll)), @@ -1055,4 +1100,168 @@ mod tests { Ok(()) } + + #[test] + fn eliminate_cross_join_with_expr_and() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // could eliminate to inner join since filter has Join predicates + let plan = LogicalPlanBuilder::from(t1) + .cross_join(&t2)? + .filter(binary_expr( + (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)), + And, + col("t2.c").lt(lit(20u32)), + ))? + .build()?; + + let expected = vec![ + "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]", + " Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + ]; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn eliminate_cross_with_expr_or() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // could not eliminate to inner join since filter OR expression and there is no common + // Join predicates in left and right of OR expr. + let plan = LogicalPlanBuilder::from(t1) + .cross_join(&t2)? + .filter(binary_expr( + (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)), + Or, + col("t2.b").eq(col("t1.a")), + ))? + .build()?; + + let expected = vec![ + "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + ]; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn eliminate_cross_with_common_expr_and() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // could eliminate to inner join + let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)); + let plan = LogicalPlanBuilder::from(t1) + .cross_join(&t2)? + .filter(binary_expr( + binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(20u32))), + And, + binary_expr(common_join_key, And, col("t2.c").eq(lit(10u32))), + ))? + .build()?; + + let expected = vec![ + "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]", + " Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + ]; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn eliminate_cross_with_common_expr_or() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // could eliminate to inner join since Or predicates have common Join predicates + let common_join_key = (col("t1.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)); + let plan = LogicalPlanBuilder::from(t1) + .cross_join(&t2)? + .filter(binary_expr( + binary_expr(common_join_key.clone(), And, col("t2.c").lt(lit(15u32))), + Or, + binary_expr(common_join_key, And, col("t2.c").eq(lit(688u32))), + ))? + .build()?; + + let expected = vec![ + "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]", + " Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + ]; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn reorder_join_with_expr_key_multi_tables() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + let t3 = test_table_scan_with_name("t3")?; + + // could eliminate to inner join + let plan = LogicalPlanBuilder::from(t1) + .cross_join(&t2)? + .cross_join(&t3)? + .filter(binary_expr( + binary_expr( + (col("t3.a") + lit(100u32)).eq(col("t1.a") * lit(2u32)), + And, + col("t3.c").lt(lit(15u32)), + ), + And, + binary_expr( + (col("t3.a") + lit(100u32)).eq(col("t2.a") * lit(2u32)), + And, + col("t3.b").lt(lit(15u32)), + ), + ))? + .build()?; + + let expected = vec![ + "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]", + " Projection: t1.a, t1.b, t1.c, t3.a, t3.b, t3.c, t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]", + " Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]", + " Projection: t1.a, t1.b, t1.c, t1.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " Projection: t3.a, t3.b, t3.c, t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + ]; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 36b45390e47af..c614fbebb8a3d 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -45,7 +45,9 @@ use datafusion_common::{ use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like}; use datafusion_expr::expr_rewriter::normalize_col; use datafusion_expr::expr_rewriter::normalize_col_with_schemas; -use datafusion_expr::logical_plan::builder::{project_with_alias, with_alias}; +use datafusion_expr::logical_plan::builder::{ + project_with_alias, with_alias, wrap_projection_for_join_if_necessary, +}; use datafusion_expr::logical_plan::Join as HashJoin; use datafusion_expr::logical_plan::JoinConstraint as HashJoinConstraint; use datafusion_expr::logical_plan::{ @@ -3030,32 +3032,6 @@ fn extract_join_keys( Ok(()) } -/// Wrap projection for a plan, if the join keys contains normal expression. -fn wrap_projection_for_join_if_necessary( - join_keys: &[Expr], - input: LogicalPlan, -) -> Result<(LogicalPlan, bool)> { - let expr_join_keys = join_keys - .iter() - .flat_map(|expr| expr.try_into_col().is_err().then_some(expr)) - .cloned() - .collect::>(); - - let need_project = !expr_join_keys.is_empty(); - let plan = if need_project { - let mut projection = vec![Expr::Wildcard]; - projection.extend(expr_join_keys.into_iter()); - - LogicalPlanBuilder::from(input) - .project(projection)? - .build()? - } else { - input - }; - - Ok((plan, need_project)) -} - /// Ensure any column reference of the expression is unambiguous. /// Assume we have two schema: /// schema1: a, b ,c From a6fd19dc92bebe4ea9a55d2cc26833f14f37801b Mon Sep 17 00:00:00 2001 From: ygf11 Date: Wed, 30 Nov 2022 08:35:38 -0500 Subject: [PATCH 2/9] Add comment --- datafusion/core/tests/sql/subqueries.rs | 1 - datafusion/optimizer/src/eliminate_cross_join.rs | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 3039a3b49db52..719c0c3d7a258 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -465,7 +465,6 @@ order by value desc; \n TableScan: supplier projection=[s_suppkey, s_nationkey]\ \n Filter: nation.n_name = Utf8(\"GERMANY\")\ \n TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8(\"GERMANY\")]"; - assert_eq!(actual, expected); // assert data diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 6c02428655199..b71698f4311f9 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -222,7 +222,7 @@ fn find_inner_join( Result::Ok(result) }; - // Data type of l and r is same. + // Save join keys if l_is_left && r_is_right && can_hash(&l.get_type(left_input.schema())?) { join_keys.push((l.clone(), r.clone())); } else if r_is_left_and_l_is_right()? @@ -237,6 +237,7 @@ fn find_inner_join( let right_input = rights.remove(i); let join_schema = Arc::new(build_join_schema(left_input, &right_input)?); + // Wrap projection let (left_on, right_on): (Vec, Vec) = join_keys.into_iter().unzip(); let (new_left_input, _) = @@ -244,6 +245,7 @@ fn find_inner_join( let (new_right_input, _) = wrap_projection_for_join_if_necessary(&right_on, right_input)?; + // Build new join on let join_on = left_on .iter() .zip(right_on.iter()) From 799e3c49f0b37552c076ab1e485878d3b0c55ae6 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Wed, 30 Nov 2022 09:25:24 -0500 Subject: [PATCH 3/9] Make clippy happy --- datafusion/optimizer/src/eliminate_cross_join.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index b71698f4311f9..e6ffd3e6361aa 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -245,7 +245,7 @@ fn find_inner_join( let (new_right_input, _) = wrap_projection_for_join_if_necessary(&right_on, right_input)?; - // Build new join on + // Build new join on let join_on = left_on .iter() .zip(right_on.iter()) From 60c86d1391b0f36d2ac6aed3ba54fd6c0e1986fc Mon Sep 17 00:00:00 2001 From: ygf11 Date: Thu, 1 Dec 2022 21:34:40 -0500 Subject: [PATCH 4/9] Add tests --- datafusion/core/tests/sql/joins.rs | 94 +++++++++++++++++++ datafusion/core/tests/sql/mod.rs | 2 +- .../optimizer/src/eliminate_cross_join.rs | 31 ++++++ 3 files changed, 126 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 87fb594c79b3b..6af1dd9fe63cc 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2304,3 +2304,97 @@ async fn error_cross_join() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn reduce_cross_join_with_expr_join_key_all() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // reduce to inner join + let sql = + "select * from t1 cross join t2 where t1.t1_id + 11 = t2.t2_id"; + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let state = ctx.state(); + let plan = state.optimize(&plan)?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " CrossJoin: [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let expected = vec![ + "+-------+---------+--------+-------+---------+--------+", + "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", + "+-------+---------+--------+-------+---------+--------+", + "| 11 | a | 1 | 22 | y | 1 |", + "| 33 | c | 3 | 44 | x | 3 |", + "| 44 | d | 4 | 55 | w | 3 |", + "+-------+---------+--------+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} + +#[tokio::test] +async fn reduce_cross_join_with_expr_join_key_some() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // reduce to inner join + let sql = + "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = t2.t2_id"; + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let state = ctx.state(); + let plan = state.optimize(&plan)?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Filter: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " CrossJoin: [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let expected = vec![ + "+-------+---------+--------+-------+---------+--------+", + "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", + "+-------+---------+--------+-------+---------+--------+", + "| 11 | a | 1 | 22 | y | 1 |", + "| 33 | c | 3 | 44 | x | 3 |", + "| 44 | d | 4 | 55 | w | 3 |", + "+-------+---------+--------+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} \ No newline at end of file diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 1e1307672394a..84dce97f56f93 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -195,7 +195,7 @@ fn create_join_context( ])); let t1_data = RecordBatch::try_new( t1_schema, - vec![ + vec![ Arc::new(UInt32Array::from_slice([11, 22, 33, 44])), Arc::new(StringArray::from(vec![ Some("a"), diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index e6ffd3e6361aa..3ac3bfcb71381 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -59,6 +59,7 @@ impl OptimizerRule for EliminateCrossJoin { plan: &LogicalPlan, _optimizer_config: &mut OptimizerConfig, ) -> Result { + println!("EliminateCrossJoin"); match plan { LogicalPlan::Filter(filter) => { let input = (**filter.input()).clone(); @@ -1266,4 +1267,34 @@ mod tests { Ok(()) } + + #[test] + fn reorder_join() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // could eliminate to inner join + let plan = LogicalPlanBuilder::from(t1) + .cross_join(&t2)? + .filter((col("t1.a") + lit(11u32)).eq(col("t2.a")))? + .build()?; + + let expected = vec![ + "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]", + " Projection: t1.a, t1.b, t1.c, t3.a, t3.b, t3.c, t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]", + " Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]", + " Projection: t1.a, t1.b, t1.c, t1.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " Projection: t3.a, t3.b, t3.c, t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + ]; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } } From d148fce99d1e5786c8b2324bb88eb95fc2782e5e Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 2 Dec 2022 02:43:27 -0500 Subject: [PATCH 5/9] Add alias for cast expr join keys --- datafusion/core/tests/sql/joins.rs | 186 +++++++++--------- datafusion/expr/src/logical_plan/builder.rs | 53 ++++- .../optimizer/src/eliminate_cross_join.rs | 52 +---- datafusion/sql/src/planner.rs | 20 +- 4 files changed, 146 insertions(+), 165 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 6af1dd9fe63cc..701547d1228cd 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2305,96 +2305,96 @@ async fn error_cross_join() -> Result<()> { Ok(()) } -#[tokio::test] -async fn reduce_cross_join_with_expr_join_key_all() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - // reduce to inner join - let sql = - "select * from t1 cross join t2 where t1.t1_id + 11 = t2.t2_id"; - let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx - .create_logical_plan(&("explain ".to_owned() + sql)) - .expect(&msg); - let state = ctx.state(); - let plan = state.optimize(&plan)?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " CrossJoin: [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - let expected = vec![ - "+-------+---------+--------+-------+---------+--------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", - "+-------+---------+--------+-------+---------+--------+", - "| 11 | a | 1 | 22 | y | 1 |", - "| 33 | c | 3 | 44 | x | 3 |", - "| 44 | d | 4 | 55 | w | 3 |", - "+-------+---------+--------+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } - - Ok(()) -} - -#[tokio::test] -async fn reduce_cross_join_with_expr_join_key_some() -> Result<()> { - let test_repartition_joins = vec![true, false]; - for repartition_joins in test_repartition_joins { - let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - - // reduce to inner join - let sql = - "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = t2.t2_id"; - let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx - .create_logical_plan(&("explain ".to_owned() + sql)) - .expect(&msg); - let state = ctx.state(); - let plan = state.optimize(&plan)?; - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " CrossJoin: [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - ]; - let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - let expected = vec![ - "+-------+---------+--------+-------+---------+--------+", - "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", - "+-------+---------+--------+-------+---------+--------+", - "| 11 | a | 1 | 22 | y | 1 |", - "| 33 | c | 3 | 44 | x | 3 |", - "| 44 | d | 4 | 55 | w | 3 |", - "+-------+---------+--------+-------+---------+--------+", - ]; - - let results = execute_to_batches(&ctx, sql).await; - assert_batches_sorted_eq!(expected, &results); - } - - Ok(()) -} \ No newline at end of file +// #[tokio::test] +// async fn reduce_cross_join_with_expr_join_key_all() -> Result<()> { +// let test_repartition_joins = vec![true, false]; +// for repartition_joins in test_repartition_joins { +// let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + +// // reduce to inner join +// let sql = +// "select * from t1 cross join t2 where t1.t1_id + 11 = t2.t2_id"; +// let msg = format!("Creating logical plan for '{}'", sql); +// let plan = ctx +// .create_logical_plan(&("explain ".to_owned() + sql)) +// .expect(&msg); +// let state = ctx.state(); +// let plan = state.optimize(&plan)?; +// let expected = vec![ +// "Explain [plan_type:Utf8, plan:Utf8]", +// " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", +// " Filter: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", +// " CrossJoin: [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", +// " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", +// " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", +// ]; +// let formatted = plan.display_indent_schema().to_string(); +// let actual: Vec<&str> = formatted.trim().lines().collect(); +// assert_eq!( +// expected, actual, +// "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", +// expected, actual +// ); +// let expected = vec![ +// "+-------+---------+--------+-------+---------+--------+", +// "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", +// "+-------+---------+--------+-------+---------+--------+", +// "| 11 | a | 1 | 22 | y | 1 |", +// "| 33 | c | 3 | 44 | x | 3 |", +// "| 44 | d | 4 | 55 | w | 3 |", +// "+-------+---------+--------+-------+---------+--------+", +// ]; + +// let results = execute_to_batches(&ctx, sql).await; +// assert_batches_sorted_eq!(expected, &results); +// } + +// Ok(()) +// } + +// #[tokio::test] +// async fn reduce_cross_join_with_expr_join_key_some() -> Result<()> { +// let test_repartition_joins = vec![true, false]; +// for repartition_joins in test_repartition_joins { +// let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + +// // reduce to inner join +// let sql = +// "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = t2.t2_id"; +// let msg = format!("Creating logical plan for '{}'", sql); +// let plan = ctx +// .create_logical_plan(&("explain ".to_owned() + sql)) +// .expect(&msg); +// let state = ctx.state(); +// let plan = state.optimize(&plan)?; +// let expected = vec![ +// "Explain [plan_type:Utf8, plan:Utf8]", +// " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", +// " Filter: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", +// " CrossJoin: [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", +// " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", +// " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", +// ]; +// let formatted = plan.display_indent_schema().to_string(); +// let actual: Vec<&str> = formatted.trim().lines().collect(); +// assert_eq!( +// expected, actual, +// "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", +// expected, actual +// ); +// let expected = vec![ +// "+-------+---------+--------+-------+---------+--------+", +// "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", +// "+-------+---------+--------+-------+---------+--------+", +// "| 11 | a | 1 | 22 | y | 1 |", +// "| 33 | c | 3 | 44 | x | 3 |", +// "| 44 | d | 4 | 55 | w | 3 |", +// "+-------+---------+--------+-------+---------+--------+", +// ]; + +// let results = execute_to_batches(&ctx, sql).await; +// assert_batches_sorted_eq!(expected, &results); +// } + +// Ok(()) +// } \ No newline at end of file diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index b6ecbd532d8ff..2facaa0bf2abd 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1017,17 +1017,44 @@ pub fn table_scan( pub fn wrap_projection_for_join_if_necessary( join_keys: &[Expr], input: LogicalPlan, -) -> Result<(LogicalPlan, bool)> { - let expr_join_keys = join_keys +) -> Result<(LogicalPlan, Vec, bool)> { + let input_schema = input.schema(); + let alias_join_keys: Vec = join_keys .iter() - .flat_map(|expr| expr.try_into_col().is_err().then_some(expr)) - .cloned() - .collect::>(); + .map(|key| { + // The display_name() of cast expression will ignore the cast info, and show the inner expression name. + // If we do not add alais, it will throw same field name error in the schema when adding projection. + // For example: + // input scan : [a, b, c], + // join keys: [cast(a as int)] + // + // then a and cast(a as int) will use the same field name - `a` in projection schema. + if matches!(key, Expr::Cast(_)) + || matches!( + key, + Expr::TryCast { + expr: _, + data_type: _ + } + ) + { + let alias = format!("{:?}", key); + key.clone().alias(alias) + } else { + key.clone() + } + }) + .collect::>(); - let need_project = !expr_join_keys.is_empty(); + let need_project = join_keys.iter().any(|key| !matches!(key, Expr::Column(_))); let plan = if need_project { - let mut projection = vec![Expr::Wildcard]; - projection.extend(expr_join_keys.into_iter()); + let mut projection = expand_wildcard(input_schema, &input)?; + let join_key_items = alias_join_keys + .iter() + .flat_map(|expr| expr.try_into_col().is_err().then_some(expr)) + .cloned() + .collect::>(); + projection.extend(join_key_items); LogicalPlanBuilder::from(input) .project(projection)? @@ -1036,7 +1063,15 @@ pub fn wrap_projection_for_join_if_necessary( input }; - Ok((plan, need_project)) + let join_on = alias_join_keys + .into_iter() + .map(|key| { + key.try_into_col() + .or_else(|_| Ok(Column::from_name(key.display_name()?))) + }) + .collect::>>()?; + + Ok((plan, join_on, need_project)) } /// Basic TableSource implementation intended for use in tests and documentation. It is expected diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 3ac3bfcb71381..11be5dee9def7 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -241,26 +241,16 @@ fn find_inner_join( // Wrap projection let (left_on, right_on): (Vec, Vec) = join_keys.into_iter().unzip(); - let (new_left_input, _) = + let (new_left_input, new_left_on, _) = wrap_projection_for_join_if_necessary(&left_on, left_input.clone())?; - let (new_right_input, _) = + let (new_right_input, new_right_on, _) = wrap_projection_for_join_if_necessary(&right_on, right_input)?; // Build new join on - let join_on = left_on - .iter() - .zip(right_on.iter()) - .map(|(left, right)| { - let left_key = left.try_into_col().or_else(|_| { - Result::Ok(Column::from_name(left.display_name()?)) - })?; - let right_key = right.try_into_col().or_else(|_| { - Result::Ok(Column::from_name(right.display_name()?)) - })?; - - Ok((left_key, right_key)) - }) - .collect::>>()?; + let join_on = new_left_on + .into_iter() + .zip(new_right_on.into_iter()) + .collect::>(); return Ok(LogicalPlan::Join(Join { left: Arc::new(new_left_input), @@ -1267,34 +1257,4 @@ mod tests { Ok(()) } - - #[test] - fn reorder_join() -> Result<()> { - let t1 = test_table_scan_with_name("t1")?; - let t2 = test_table_scan_with_name("t2")?; - - // could eliminate to inner join - let plan = LogicalPlanBuilder::from(t1) - .cross_join(&t2)? - .filter((col("t1.a") + lit(11u32)).eq(col("t2.a")))? - .build()?; - - let expected = vec![ - "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32, a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]", - " Projection: t1.a, t1.b, t1.c, t3.a, t3.b, t3.c, t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]", - " Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]", - " Projection: t1.a, t1.b, t1.c, t1.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " Projection: t3.a, t3.b, t3.c, t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(&plan, expected); - - Ok(()) - } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index c614fbebb8a3d..f367304e578c5 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -778,26 +778,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .build() } else { // Wrap projection for left input if left join keys contain normal expression. - let (left_child, left_projected) = + let (left_child, left_join_keys, left_projected) = wrap_projection_for_join_if_necessary(&left_keys, left)?; - let left_join_keys = left_keys - .iter() - .map(|key| { - key.try_into_col() - .or_else(|_| Ok(Column::from_name(key.display_name()?))) - }) - .collect::>>()?; // Wrap projection for right input if right join keys contains normal expression. - let (right_child, right_projected) = + let (right_child, right_join_keys, right_projected) = wrap_projection_for_join_if_necessary(&right_keys, right)?; - let right_join_keys = right_keys - .iter() - .map(|key| { - key.try_into_col() - .or_else(|_| Ok(Column::from_name(key.display_name()?))) - }) - .collect::>>()?; let join_plan_builder = LogicalPlanBuilder::from(left_child).join( &right_child, @@ -805,7 +791,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { (left_join_keys, right_join_keys), join_filter, )?; - + // Remove temporary projected columns if necessary. if left_projected || right_projected { let final_join_result = join_schema From 34f883c626fa780fba454d153d41ddee0f04c4a0 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 2 Dec 2022 03:09:04 -0500 Subject: [PATCH 6/9] Add tests --- datafusion/core/tests/sql/joins.rs | 188 +++++++++--------- datafusion/core/tests/sql/mod.rs | 2 +- .../optimizer/src/eliminate_cross_join.rs | 3 +- datafusion/sql/src/planner.rs | 19 +- 4 files changed, 115 insertions(+), 97 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 701547d1228cd..fa49bfc3f72a1 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2305,96 +2305,98 @@ async fn error_cross_join() -> Result<()> { Ok(()) } -// #[tokio::test] -// async fn reduce_cross_join_with_expr_join_key_all() -> Result<()> { -// let test_repartition_joins = vec![true, false]; -// for repartition_joins in test_repartition_joins { -// let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - -// // reduce to inner join -// let sql = -// "select * from t1 cross join t2 where t1.t1_id + 11 = t2.t2_id"; -// let msg = format!("Creating logical plan for '{}'", sql); -// let plan = ctx -// .create_logical_plan(&("explain ".to_owned() + sql)) -// .expect(&msg); -// let state = ctx.state(); -// let plan = state.optimize(&plan)?; -// let expected = vec![ -// "Explain [plan_type:Utf8, plan:Utf8]", -// " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", -// " Filter: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", -// " CrossJoin: [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", -// " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", -// " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", -// ]; -// let formatted = plan.display_indent_schema().to_string(); -// let actual: Vec<&str> = formatted.trim().lines().collect(); -// assert_eq!( -// expected, actual, -// "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", -// expected, actual -// ); -// let expected = vec![ -// "+-------+---------+--------+-------+---------+--------+", -// "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", -// "+-------+---------+--------+-------+---------+--------+", -// "| 11 | a | 1 | 22 | y | 1 |", -// "| 33 | c | 3 | 44 | x | 3 |", -// "| 44 | d | 4 | 55 | w | 3 |", -// "+-------+---------+--------+-------+---------+--------+", -// ]; - -// let results = execute_to_batches(&ctx, sql).await; -// assert_batches_sorted_eq!(expected, &results); -// } - -// Ok(()) -// } - -// #[tokio::test] -// async fn reduce_cross_join_with_expr_join_key_some() -> Result<()> { -// let test_repartition_joins = vec![true, false]; -// for repartition_joins in test_repartition_joins { -// let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; - -// // reduce to inner join -// let sql = -// "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = t2.t2_id"; -// let msg = format!("Creating logical plan for '{}'", sql); -// let plan = ctx -// .create_logical_plan(&("explain ".to_owned() + sql)) -// .expect(&msg); -// let state = ctx.state(); -// let plan = state.optimize(&plan)?; -// let expected = vec![ -// "Explain [plan_type:Utf8, plan:Utf8]", -// " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", -// " Filter: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", -// " CrossJoin: [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", -// " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", -// " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", -// ]; -// let formatted = plan.display_indent_schema().to_string(); -// let actual: Vec<&str> = formatted.trim().lines().collect(); -// assert_eq!( -// expected, actual, -// "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", -// expected, actual -// ); -// let expected = vec![ -// "+-------+---------+--------+-------+---------+--------+", -// "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", -// "+-------+---------+--------+-------+---------+--------+", -// "| 11 | a | 1 | 22 | y | 1 |", -// "| 33 | c | 3 | 44 | x | 3 |", -// "| 44 | d | 4 | 55 | w | 3 |", -// "+-------+---------+--------+-------+---------+--------+", -// ]; - -// let results = execute_to_batches(&ctx, sql).await; -// assert_batches_sorted_eq!(expected, &results); -// } - -// Ok(()) -// } \ No newline at end of file +#[tokio::test] +async fn reduce_cross_join_with_expr_join_key_all() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // reduce to inner join + let sql = "select * from t1 cross join t2 where t1.t1_id + 12 = t2.t2_id + 1"; + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let state = ctx.state(); + let plan = state.optimize(&plan)?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Inner Join: t1.t1_id + Int64(12) = t2.t2_id + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t1.t1_id + Int64(12):Int64;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2.t2_id + Int64(1):Int64;N]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int, CAST(t1.t1_id AS Int64) + Int64(12) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t1.t1_id + Int64(12):Int64;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Projection: t2.t2_id, t2.t2_name, t2.t2_int, CAST(t2.t2_id AS Int64) + Int64(1) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2.t2_id + Int64(1):Int64;N]", + " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let expected = vec![ + "+-------+---------+--------+-------+---------+--------+", + "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |", + "+-------+---------+--------+-------+---------+--------+", + "| 11 | a | 1 | 22 | y | 1 |", + "| 33 | c | 3 | 44 | x | 3 |", + "| 44 | d | 4 | 55 | w | 3 |", + "+-------+---------+--------+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} + +#[tokio::test] +async fn reduce_cross_join_with_cast_expr_join_key() -> Result<()> { + let test_repartition_joins = vec![true, false]; + for repartition_joins in test_repartition_joins { + let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?; + + // reduce to inner join, t2.t2_id will insert cast. + let sql = + "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = t2.t2_id"; + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx + .create_logical_plan(&("explain ".to_owned() + sql)) + .expect(&msg); + let state = ctx.state(); + let plan = state.optimize(&plan)?; + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t2.t2_id, t1.t1_name [t1_id:UInt32;N, t2_id:UInt32;N, t1_name:Utf8;N]", + " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", + " Inner Join: t1.t1_id + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1.t1_id + Int64(11):Int64;N, t2_id:UInt32;N, CAST(t2.t2_id AS Int64):Int64;N]", + " Projection: t1.t1_id, t1.t1_name, CAST(t1.t1_id AS Int64) + Int64(11) [t1_id:UInt32;N, t1_name:Utf8;N, t1.t1_id + Int64(11):Int64;N]", + " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", + " Projection: t2.t2_id, CAST(t2.t2_id AS Int64) AS CAST(t2.t2_id AS Int64) [t2_id:UInt32;N, CAST(t2.t2_id AS Int64):Int64;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let expected = vec![ + "+-------+-------+---------+", + "| t1_id | t2_id | t1_name |", + "+-------+-------+---------+", + "| 11 | 22 | a |", + "| 33 | 44 | c |", + "| 44 | 55 | d |", + "+-------+-------+---------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + } + + Ok(()) +} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 84dce97f56f93..1e1307672394a 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -195,7 +195,7 @@ fn create_join_context( ])); let t1_data = RecordBatch::try_new( t1_schema, - vec![ + vec![ Arc::new(UInt32Array::from_slice([11, 22, 33, 44])), Arc::new(StringArray::from(vec![ Some("a"), diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 11be5dee9def7..264bff354697b 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -17,7 +17,7 @@ //! Optimizer rule to eliminate cross join to inner join if join predicates are available in filters. use crate::{utils, OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column, DFSchema, DataFusionError, Result}; +use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::utils::{can_hash, check_all_column_from_schema}; use datafusion_expr::{ @@ -59,7 +59,6 @@ impl OptimizerRule for EliminateCrossJoin { plan: &LogicalPlan, _optimizer_config: &mut OptimizerConfig, ) -> Result { - println!("EliminateCrossJoin"); match plan { LogicalPlan::Filter(filter) => { let input = (**filter.input()).clone(); diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index f367304e578c5..2e73039914059 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -791,7 +791,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { (left_join_keys, right_join_keys), join_filter, )?; - + // Remove temporary projected columns if necessary. if left_projected || right_projected { let final_join_result = join_schema @@ -5954,6 +5954,23 @@ mod tests { quick_test(sql, expected); } + #[test] + fn test_inner_join_with_cast_key() { + let sql = "SELECT person.id, person.age + FROM person + INNER JOIN orders + ON cast(person.id as Int) = cast(orders.customer_id as Int)"; + + let expected = "Projection: person.id, person.age\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + \n Inner Join: CAST(person.id AS Int32) = CAST(orders.customer_id AS Int32)\ + \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, CAST(person.id AS Int32) AS CAST(person.id AS Int32)\ + \n TableScan: person\ + \n Projection: orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered, CAST(orders.customer_id AS Int32) AS CAST(orders.customer_id AS Int32)\ + \n TableScan: orders"; + quick_test(sql, expected); + } + fn assert_field_not_found(err: DataFusionError, name: &str) { match err { DataFusionError::SchemaError { .. } => { From 85221fd86cd3c8cb1d14f4c859b72b3fccb7c8d8 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 2 Dec 2022 03:41:33 -0500 Subject: [PATCH 7/9] Add relative issue comment --- datafusion/expr/src/logical_plan/builder.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2facaa0bf2abd..5c42e9f72864e 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1029,6 +1029,7 @@ pub fn wrap_projection_for_join_if_necessary( // join keys: [cast(a as int)] // // then a and cast(a as int) will use the same field name - `a` in projection schema. + // https://github.com/apache/arrow-datafusion/issues/4478 if matches!(key, Expr::Cast(_)) || matches!( key, From d4357334c61301b678e68b413a790e0c7768b6fd Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 2 Dec 2022 05:00:42 -0500 Subject: [PATCH 8/9] Improve test --- datafusion/core/tests/sql/joins.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index fa49bfc3f72a1..92b77540e1233 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2360,7 +2360,7 @@ async fn reduce_cross_join_with_cast_expr_join_key() -> Result<()> { // reduce to inner join, t2.t2_id will insert cast. let sql = - "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = t2.t2_id"; + "select t1.t1_id, t2.t2_id, t1.t1_name from t1 cross join t2 where t1.t1_id + 11 = cast(t2.t2_id as BIGINT)"; let msg = format!("Creating logical plan for '{}'", sql); let plan = ctx .create_logical_plan(&("explain ".to_owned() + sql)) From 44d3099ea1e56f156e9119dd24da7c075eb7d296 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Wed, 7 Dec 2022 08:44:18 -0500 Subject: [PATCH 9/9] Improve use declarations --- datafusion/expr/src/lib.rs | 4 +++- datafusion/optimizer/src/eliminate_cross_join.rs | 13 ++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index e061b134552e3..3c18b04818c36 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -67,7 +67,9 @@ pub use function::{ }; pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; pub use logical_plan::{ - builder::{build_join_schema, union, UNNAMED_TABLE}, + builder::{ + build_join_schema, union, wrap_projection_for_join_if_necessary, UNNAMED_TABLE, + }, Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateMemoryTable, CreateView, CrossJoin, Distinct, DropTable, DropView, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit, diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index c9847fe80f3c0..8ca457771646a 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -18,16 +18,15 @@ //! Optimizer rule to eliminate cross join to inner join if join predicates are available in filters. use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; -use datafusion_expr::logical_plan::JoinConstraint; +use datafusion_expr::expr::{BinaryExpr, Expr}; +use datafusion_expr::logical_plan::{ + CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, +}; use datafusion_expr::utils::{can_hash, check_all_column_from_schema}; use datafusion_expr::{ - and, build_join_schema, - expr::BinaryExpr, - logical_plan::{CrossJoin, Filter, Join, JoinType, LogicalPlan}, - or, Projection, + and, build_join_schema, or, wrap_projection_for_join_if_necessary, ExprSchemable, + Operator, }; -use datafusion_expr::{Expr, ExprSchemable, Operator}; use std::collections::HashSet; use std::sync::Arc;