From 4e4f4ed307ec61f5956eab0388fa9a497edb1903 Mon Sep 17 00:00:00 2001
From: ygf11
Date: Wed, 28 Dec 2022 02:04:08 -0500
Subject: [PATCH 1/7] Refactor extract_join_keys with split_conjunction

---
 .../src/extract_equijoin_predicate.rs         | 129 +++++++++---------
 1 file changed, 64 insertions(+), 65 deletions(-)

diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs
index 060cd82f63fa7..4ebf0566dacb1 100644
--- a/datafusion/optimizer/src/extract_equijoin_predicate.rs
+++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs
@@ -17,6 +17,7 @@
 //! Optimizer rule to extract equijoin expr from filter
 
 use crate::optimizer::ApplyOrder;
+use crate::utils::split_conjunction;
 use crate::{OptimizerConfig, OptimizerRule};
 use datafusion_common::DFSchema;
 use datafusion_common::Result;
@@ -24,6 +25,9 @@ use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
 use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
 use std::sync::Arc;
 
+// equijoin predicate
+type EquijoinPredicate = (Expr, Expr);
+
 /// Optimization rule that extract equijoin expr from the filter
 #[derive(Default)]
 pub struct ExtractEquijoinPredicate;
@@ -56,27 +60,22 @@ impl OptimizerRule for ExtractEquijoinPredicate {
                 let right_schema = right.schema();
 
                 filter.as_ref().map_or(Result::Ok(None), |expr| {
-                    let mut accum: Vec<(Expr, Expr)> = vec![];
-                    let mut accum_filter: Vec<Expr> = vec![];
-                    // TODO: avoding clone with split_conjunction
-                    extract_join_keys(
-                        expr.clone(),
-                        &mut accum,
-                        &mut accum_filter,
-                        left_schema,
-                        right_schema,
-                    )?;
-
-                    let optimized_plan = (!accum.is_empty()).then(|| {
+                    let (equijoin_predicates, non_equijoin_expr) =
+                        split_equi_and_none_equijoin_predicate(
+                            expr,
+                            left_schema,
+                            right_schema,
+                        )?;
+
+                    let optimized_plan = (!equijoin_predicates.is_empty()).then(|| {
                         let mut new_on = on.clone();
-                        new_on.extend(accum);
+                        new_on.extend(equijoin_predicates);
 
-                        let new_filter = accum_filter.into_iter().reduce(Expr::and);
                         LogicalPlan::Join(Join {
                             left: left.clone(),
                             right: right.clone(),
                             on: new_on,
-                            filter: new_filter,
+                            filter: non_equijoin_expr,
                             join_type: *join_type,
                             join_constraint: *join_constraint,
                             schema: schema.clone(),
@@ -100,30 +99,22 @@ impl OptimizerRule for ExtractEquijoinPredicate {
     }
 }
 
-/// Extracts equijoin ON condition be a single Eq or multiple conjunctive Eqs
-/// Filters matching this pattern are added to `accum`
-/// Filters that don't match this pattern are added to `accum_filter`
-/// Examples:
-/// ```text
-/// foo = bar => accum=[(foo, bar)] accum_filter=[]
-/// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[]
-/// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1]
-///
-/// For equijoin join key, assume we have tables -- a(c0, c1 c2) and b(c0, c1, c2):
-/// (a.c0 = 10) => accum=[], accum_filter=[a.c0 = 10]
-/// (a.c0 + 1 = b.c0 * 2) => accum=[(a.c0 + 1, b.c0 * 2)], accum_filter=[]
-/// (a.c0 + b.c0 = 10) => accum=[], accum_filter=[a.c0 + b.c0 = 10]
-/// ```
-fn extract_join_keys(
-    expr: Expr,
-    accum: &mut Vec<(Expr, Expr)>,
-    accum_filter: &mut Vec<Expr>,
+fn split_equi_and_none_equijoin_predicate(
+    expr: &Expr,
     left_schema: &Arc<DFSchema>,
     right_schema: &Arc<DFSchema>,
-) -> Result<()> {
-    match &expr {
-        Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
-            Operator::Eq => {
+) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
+    let filters = split_conjunction(expr);
+
+    let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
+    let mut accum_filters: Vec<Expr> = vec![];
+    for filter in filters {
+        match filter {
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::Eq,
+                right,
+            }) => {
                 let left = left.as_ref();
                 let right = right.as_ref();
 
@@ -139,48 +130,27 @@ fn extract_join_keys(
                     let right_expr_type = right_expr.get_type(right_schema)?;
 
                     if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
-                        accum.push((left_expr, right_expr));
+                        accum_join_keys.push((left_expr, right_expr));
                     } else {
-                        accum_filter.push(expr);
+                        accum_filters.push(filter.clone());
                     }
                 } else {
-                    accum_filter.push(expr);
-                }
-            }
-            Operator::And => {
-                if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = expr {
-                    extract_join_keys(
-                        *left,
-                        accum,
-                        accum_filter,
-                        left_schema,
-                        right_schema,
-                    )?;
-                    extract_join_keys(
-                        *right,
-                        accum,
-                        accum_filter,
-                        left_schema,
-                        right_schema,
-                    )?;
+                    accum_filters.push(filter.clone());
                 }
             }
-            _other => {
-                accum_filter.push(expr);
-            }
-        },
-        _other => {
-            accum_filter.push(expr);
+            _ => accum_filters.push(filter.clone()),
         }
     }
 
-    Ok(())
+    let result_filter = accum_filters.into_iter().reduce(Expr::and);
+    Ok((accum_join_keys, result_filter))
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::test::*;
+    use arrow::datatypes::DataType;
     use datafusion_common::Column;
     use datafusion_expr::{
         col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
@@ -387,4 +357,33 @@ mod tests {
 
         assert_plan_eq(&plan, expected)
     }
+
+    #[test]
+    fn join_with_alias_filter() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        let t1_schema = t1.schema().clone();
+        let t2_schema = t2.schema().clone();
+
+        // filter: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int32(2), UInt32) as t1.a + 1 = t2.a + 2
+        let filter = Expr::eq(
+            col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
+            col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
+        )
+        .alias("t1.a + 1 = t2.a + 2");
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                t2,
+                JoinType::Left,
+                (Vec::<Column>::new(), Vec::<Column>::new()),
+                Some(filter),
+            )?
+            .build()?;
+        let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
+        \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+        \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_plan_eq(&plan, expected)
+    }
 }

From 28aa76a41ae439a448414ee5ecb60e0008dde96c Mon Sep 17 00:00:00 2001
From: ygf11
Date: Wed, 28 Dec 2022 04:08:42 -0500
Subject: [PATCH 2/7] reorder the ExtractEquijoinPredicate rule

---
 datafusion/optimizer/src/optimizer.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 36968f2f18d65..4dd40d6a124c1 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -238,13 +238,13 @@ impl Optimizer {
         let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = vec![
             Arc::new(InlineTableScan::new()),
             Arc::new(TypeCoercion::new()),
-            Arc::new(ExtractEquijoinPredicate::new()),
             Arc::new(SimplifyExpressions::new()),
             Arc::new(UnwrapCastInComparison::new()),
             Arc::new(DecorrelateWhereExists::new()),
             Arc::new(DecorrelateWhereIn::new()),
             Arc::new(ScalarSubqueryToJoin::new()),
             Arc::new(SubqueryFilterToJoin::new()),
+            Arc::new(ExtractEquijoinPredicate::new()),
             // simplify expressions does not simplify expressions in subqueries, so we
             // run it again after running the optimizations that potentially converted
             // subqueries to joins

From 59d0a374c3215359619c971eb31da80d68bd5bce Mon Sep 17 00:00:00 2001
From: ygf11
Date: Wed, 28 Dec 2022 05:19:38 -0500
Subject: [PATCH 3/7] fix cargo test

---
 datafusion/core/tests/sql/joins.rs            | 48 +++++++++----------
 .../src/extract_equijoin_predicate.rs         | 18 +++----
 2 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index b6c78b0cf3814..5c4171484af59 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -2452,29 +2452,29 @@ async fn both_side_expr_key_inner_join() -> Result<()> {
             "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]",
             "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id]",
             "    CoalesceBatchesExec: target_batch_size=4096",
-            "      HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1.t1_id + Int64(12)\", index: 2 }, Column { name: \"t2.t2_id + Int64(1)\", index: 1 })]",
+            "      HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1.t1_id + UInt32(12)\", index: 2 }, Column { name: \"t2.t2_id + UInt32(1)\", index: 1 })]",
             "        CoalesceBatchesExec: target_batch_size=4096",
-            "          RepartitionExec: partitioning=Hash([Column { name: \"t1.t1_id + Int64(12)\", index: 2 }], 2)",
-            "            ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + CAST(12 AS UInt32) as t1.t1_id + Int64(12)]",
+            "          RepartitionExec: partitioning=Hash([Column { name: \"t1.t1_id + UInt32(12)\", index: 2 }], 2)",
+            "            ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as t1.t1_id + UInt32(12)]",
             "              RepartitionExec: partitioning=RoundRobinBatch(2)",
             "                MemoryExec: partitions=1, partition_sizes=[1]",
             "        CoalesceBatchesExec: target_batch_size=4096",
-            "          RepartitionExec: partitioning=Hash([Column { name: \"t2.t2_id + Int64(1)\", index: 1 }], 2)",
-            "            ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + CAST(1 AS UInt32) as t2.t2_id + Int64(1)]",
+            "          RepartitionExec: partitioning=Hash([Column { name: \"t2.t2_id + UInt32(1)\", index: 1 }], 2)",
+            "            ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as t2.t2_id + UInt32(1)]",
             "              RepartitionExec: partitioning=RoundRobinBatch(2)",
             "                MemoryExec: partitions=1, partition_sizes=[1]",
         ]
     } else {
         vec![
             "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]",
             "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id]",
             "    CoalesceBatchesExec: target_batch_size=4096",
-            "      HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1.t1_id + Int64(12)\", index: 2 }, Column { name: \"t2.t2_id + Int64(1)\", index: 1 })]",
+            "      HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1.t1_id + UInt32(12)\", index: 2 }, Column { name: \"t2.t2_id + UInt32(1)\", index: 1 })]",
             "        CoalescePartitionsExec",
-            "          ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + CAST(12 AS UInt32) as t1.t1_id + Int64(12)]",
+            "          ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 12 as t1.t1_id + UInt32(12)]",
             "            RepartitionExec: partitioning=RoundRobinBatch(2)",
             "              MemoryExec: partitions=1, partition_sizes=[1]",
-            "        ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + CAST(1 AS UInt32) as t2.t2_id + Int64(1)]",
+            "        ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 1 as t2.t2_id + UInt32(1)]",
             "          RepartitionExec: partitioning=RoundRobinBatch(2)",
             "            MemoryExec: partitions=1, partition_sizes=[1]",
         ]
@@ -2524,10 +2524,10 @@ async fn left_side_expr_key_inner_join() -> Result<()> {
             "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]",
             "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id]",
             "    CoalesceBatchesExec: target_batch_size=4096",
-            "      HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1.t1_id + Int64(11)\", index: 2 }, Column { name: \"t2_id\", index: 0 })]",
+            "      HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1.t1_id + UInt32(11)\", index: 2 }, Column { name: \"t2_id\", index: 0 })]",
             "        CoalesceBatchesExec: target_batch_size=4096",
-            "          RepartitionExec: partitioning=Hash([Column { name: \"t1.t1_id + Int64(11)\", index: 2 }], 2)",
-            "            ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + CAST(11 AS UInt32) as t1.t1_id + Int64(11)]",
+            "          RepartitionExec: partitioning=Hash([Column { name: \"t1.t1_id + UInt32(11)\", index: 2 }], 2)",
+            "            ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as t1.t1_id + UInt32(11)]",
             "              RepartitionExec: partitioning=RoundRobinBatch(2)",
             "                MemoryExec: partitions=1, partition_sizes=[1]",
             "        CoalesceBatchesExec: target_batch_size=4096",
@@ -2541,9 +2541,9 @@ async fn left_side_expr_key_inner_join() -> Result<()> {
             "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@3 as t2_id]",
             "    RepartitionExec: partitioning=RoundRobinBatch(2)",
             "      CoalesceBatchesExec: target_batch_size=4096",
-            "        HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1.t1_id + Int64(11)\", index: 2 }, Column { name: \"t2_id\", index: 0 })]",
+            "        HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1.t1_id + UInt32(11)\", index: 2 }, Column { name: \"t2_id\", index: 0 })]",
             "          CoalescePartitionsExec",
-            "            ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + CAST(11 AS UInt32) as t1.t1_id + Int64(11)]",
+            "            ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_id@0 + 11 as t1.t1_id + UInt32(11)]",
             "              RepartitionExec: partitioning=RoundRobinBatch(2)",
             "                MemoryExec: partitions=1, partition_sizes=[1]",
             "          MemoryExec: partitions=1, partition_sizes=[1]",
@@ -2594,14 +2594,14 @@ async fn right_side_expr_key_inner_join() -> Result<()> {
             "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]",
             "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@2 as t2_id]",
             "    CoalesceBatchesExec: target_batch_size=4096",
-            "      HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - Int64(11)\", index: 1 })]",
+            "      HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - UInt32(11)\", index: 1 })]",
             "        CoalesceBatchesExec: target_batch_size=4096",
             "          RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2)",
             "            RepartitionExec: partitioning=RoundRobinBatch(2)",
             "              MemoryExec: partitions=1, partition_sizes=[1]",
             "        CoalesceBatchesExec: target_batch_size=4096",
-            "          RepartitionExec: partitioning=Hash([Column { name: \"t2.t2_id - Int64(11)\", index: 1 }], 2)",
-            "            ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - CAST(11 AS UInt32) as t2.t2_id - Int64(11)]",
+            "          RepartitionExec: partitioning=Hash([Column { name: \"t2.t2_id - UInt32(11)\", index: 1 }], 2)",
+            "            ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as t2.t2_id - UInt32(11)]",
             "              RepartitionExec: partitioning=RoundRobinBatch(2)",
             "                MemoryExec: partitions=1, partition_sizes=[1]",
         ]
@@ -2610,9 +2610,9 @@ async fn right_side_expr_key_inner_join() -> Result<()> {
             "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, t1_name@1 as t1_name]",
             "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t2_id@2 as t2_id]",
             "    CoalesceBatchesExec: target_batch_size=4096",
-            "      HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - Int64(11)\", index: 1 })]",
+            "      HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - UInt32(11)\", index: 1 })]",
             "        MemoryExec: partitions=1, partition_sizes=[1]",
-            "        ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - CAST(11 AS UInt32) as t2.t2_id - Int64(11)]",
+            "        ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 11 as t2.t2_id - UInt32(11)]",
             "          RepartitionExec: partitioning=RoundRobinBatch(2)",
             "            MemoryExec: partitions=1, partition_sizes=[1]",
         ]
@@ -2662,14 +2662,14 @@ async fn select_wildcard_with_expr_key_inner_join() -> Result<()> {
             "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int]",
             "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int]",
             "    CoalesceBatchesExec: target_batch_size=4096",
-            "      HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - Int64(11)\", index: 3 })]",
+            "      HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - UInt32(11)\", index: 3 })]",
             "        CoalesceBatchesExec: target_batch_size=4096",
             "          RepartitionExec: partitioning=Hash([Column { name: \"t1_id\", index: 0 }], 2)",
             "            RepartitionExec: partitioning=RoundRobinBatch(2)",
             "              MemoryExec: partitions=1, partition_sizes=[1]",
             "        CoalesceBatchesExec: target_batch_size=4096",
-            "          RepartitionExec: partitioning=Hash([Column { name: \"t2.t2_id - Int64(11)\", index: 3 }], 2)",
-            "            ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, t2_id@0 - CAST(11 AS UInt32) as t2.t2_id - Int64(11)]",
+            "          RepartitionExec: partitioning=Hash([Column { name: \"t2.t2_id - UInt32(11)\", index: 3 }], 2)",
+            "            ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, t2_id@0 - 11 as t2.t2_id - UInt32(11)]",
             "              RepartitionExec: partitioning=RoundRobinBatch(2)",
             "                MemoryExec: partitions=1, partition_sizes=[1]",
         ]
@@ -2678,9 +2678,9 @@ async fn select_wildcard_with_expr_key_inner_join() -> Result<()> {
             "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int]",
             "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int]",
             "    CoalesceBatchesExec: target_batch_size=4096",
-            "      HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - Int64(11)\", index: 3 })]",
+            "      HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - UInt32(11)\", index: 3 })]",
             "        MemoryExec: partitions=1, partition_sizes=[1]",
-            "        ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, t2_id@0 - CAST(11 AS UInt32) as t2.t2_id - Int64(11)]",
+            "        ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as t2_name, t2_int@2 as t2_int, t2_id@0 - 11 as t2.t2_id - UInt32(11)]",
             "          RepartitionExec: partitioning=RoundRobinBatch(2)",
             "            MemoryExec: partitions=1, partition_sizes=[1]",
         ]
diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs
index 4ebf0566dacb1..2f7a20d6e230d 100644
--- a/datafusion/optimizer/src/extract_equijoin_predicate.rs
+++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs
@@ -61,7 +61,7 @@ impl OptimizerRule for ExtractEquijoinPredicate {
 
                 filter.as_ref().map_or(Result::Ok(None), |expr| {
                     let (equijoin_predicates, non_equijoin_expr) =
-                        split_equi_and_none_equijoin_predicate(
+                        split_eq_and_noneq_join_predicate(
                             expr,
                             left_schema,
                             right_schema,
@@ -99,17 +99,17 @@ impl OptimizerRule for ExtractEquijoinPredicate {
     }
 }
 
-fn split_equi_and_none_equijoin_predicate(
-    expr: &Expr,
+fn split_eq_and_noneq_join_predicate(
+    filter: &Expr,
     left_schema: &Arc<DFSchema>,
     right_schema: &Arc<DFSchema>,
 ) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
-    let filters = split_conjunction(expr);
+    let exprs = split_conjunction(filter);
 
     let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
     let mut accum_filters: Vec<Expr> = vec![];
-    for filter in filters {
-        match filter {
+    for expr in exprs {
+        match expr {
             Expr::BinaryExpr(BinaryExpr {
                 left,
                 op: Operator::Eq,
@@ -132,13 +132,13 @@ fn split_eq_and_noneq_join_predicate(
                     if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
                         accum_join_keys.push((left_expr, right_expr));
                     } else {
-                        accum_filters.push(filter.clone());
+                        accum_filters.push(expr.clone());
                    }
                 } else {
-                    accum_filters.push(filter.clone());
+                    accum_filters.push(expr.clone());
                 }
             }
-            _ => accum_filters.push(filter.clone()),
+            _ => accum_filters.push(expr.clone()),
         }
     }
 

From b6c560baaa04b779ed21c3fa40cc42baa560c02d Mon Sep 17 00:00:00 2001
From: ygf11
Date: Wed, 28 Dec 2022 05:40:42 -0500
Subject: [PATCH 4/7] trigger ci again


From f28fe2026bca01eddebd8f9792563ee81d89b473 Mon Sep 17 00:00:00 2001
From: ygf11
Date: Wed, 28 Dec 2022 20:41:41 -0500
Subject: [PATCH 5/7] trigger ci again


From 0552a69384ad0c491d0b30f72277b6e6a4b995ef Mon Sep 17 00:00:00 2001
From: ygf11
Date: Sun, 1 Jan 2023 02:55:13 -0500
Subject: [PATCH 6/7] add integration test

---
 datafusion/core/tests/dataframe.rs | 100 ++++++++++++++++++++++++++++-
 1 file changed, 97 insertions(+), 3 deletions(-)

diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index 00ba4524d02c0..499330e2b30d2 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -17,7 +17,7 @@
 
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow::{
-    array::{Int32Array, StringArray},
+    array::{Int32Array, StringArray, UInt32Array},
     record_batch::RecordBatch,
 };
 use datafusion::from_slice::FromSlice;
@@ -30,8 +30,7 @@ use datafusion::execution::context::SessionContext;
 use datafusion::prelude::CsvReadOptions;
 use datafusion::prelude::JoinType;
 use datafusion_expr::expr::{GroupingSet, Sort};
-use datafusion_expr::{avg, count, lit, sum};
-use datafusion_expr::{col, Expr};
+use datafusion_expr::{avg, col, count, lit, sum, Expr, ExprSchemable};
 
 #[tokio::test]
 async fn join() -> Result<()> {
@@ -352,6 +351,62 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn join_with_alias_filter() -> Result<()> {
+    let join_ctx = create_join_context()?;
+    let t1 = join_ctx.table("t1")?;
+    let t2 = join_ctx.table("t2")?;
+    let t1_schema = t1.schema().clone();
+    let t2_schema = t2.schema().clone();
+
+    // filter: t1.a + CAST(Int64(3), UInt32) = t2.a + CAST(Int32(1), UInt32), aliased as "t1.b + 1 = t2.a + 2"
+    let filter = Expr::eq(
+        col("t1.a") + lit(3i64).cast_to(&DataType::UInt32, &t1_schema)?,
+        col("t2.a") + lit(1i32).cast_to(&DataType::UInt32, &t2_schema)?,
+    )
+    .alias("t1.b + 1 = t2.a + 2");
+
+    let df = t1
+        .join(t2, JoinType::Inner, &[], &[], Some(filter))?
+        .select(vec![
+            col("t1.a"),
+            col("t2.a"),
+            col("t1.b"),
+            col("t1.c"),
+            col("t2.b"),
+            col("t2.c"),
+        ])?;
+    let optimized_plan = df.clone().into_optimized_plan()?;
+
+    let expected = vec![
+        "Projection: t1.a, t2.a, t1.b, t1.c, t2.b, t2.c [a:UInt32, a:UInt32, b:Utf8, c:Int32, b:Utf8, c:Int32]",
+        "  Inner Join: t1.a + UInt32(3) = t2.a + UInt32(1) [a:UInt32, b:Utf8, c:Int32, a:UInt32, b:Utf8, c:Int32]",
+        "    TableScan: t1 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32]",
+        "    TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32]",
+    ];
+
+    let formatted = optimized_plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let results = df.collect().await?;
+    let expected: Vec<&str> = vec![
+        "+----+----+---+----+---+---+",
+        "| a  | a  | b | c  | b | c |",
+        "+----+----+---+----+---+---+",
+        "| 11 | 13 | c | 30 | c | 3 |",
+        "| 1  | 3  | a | 10 | a | 1 |",
+        "+----+----+---+----+---+---+",
+    ];
+
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
 fn create_test_table() -> Result<DataFrame> {
     let schema = Arc::new(Schema::new(vec![
         Field::new("a", DataType::Utf8, false),
@@ -388,3 +443,42 @@ async fn aggregates_table(ctx: &SessionContext) -> Result<DataFrame> {
     )
     .await
 }
+
+fn create_join_context() -> Result<SessionContext> {
+    let t1 = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::UInt32, false),
+        Field::new("b", DataType::Utf8, false),
+        Field::new("c", DataType::Int32, false),
+    ]));
+    let t2 = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::UInt32, false),
+        Field::new("b", DataType::Utf8, false),
+        Field::new("c", DataType::Int32, false),
+    ]));
+
+    // define data.
+    let batch1 = RecordBatch::try_new(
+        t1,
+        vec![
+            Arc::new(UInt32Array::from_slice([1, 10, 11, 100])),
+            Arc::new(StringArray::from_slice(["a", "b", "c", "d"])),
+            Arc::new(Int32Array::from_slice([10, 20, 30, 40])),
+        ],
+    )?;
+    // define data.
+    let batch2 = RecordBatch::try_new(
+        t2,
+        vec![
+            Arc::new(UInt32Array::from_slice([3, 10, 13, 100])),
+            Arc::new(StringArray::from_slice(["a", "b", "c", "d"])),
+            Arc::new(Int32Array::from_slice([1, 2, 3, 4])),
+        ],
+    )?;
+
+    let ctx = SessionContext::new();
+
+    ctx.register_batch("t1", batch1)?;
+    ctx.register_batch("t2", batch2)?;
+
+    Ok(ctx)
+}

From 67548fc40625c6be6b2d2917f048f7392df58fe7 Mon Sep 17 00:00:00 2001
From: ygf11
Date: Sun, 1 Jan 2023 03:22:45 -0500
Subject: [PATCH 7/7] fix test

---
 datafusion/core/tests/dataframe.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index 499330e2b30d2..190248efe847b 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -23,12 +23,12 @@ use arrow::{
 use datafusion::from_slice::FromSlice;
 use std::sync::Arc;
 
-use datafusion::assert_batches_eq;
 use datafusion::dataframe::DataFrame;
 use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
 use datafusion::prelude::CsvReadOptions;
 use datafusion::prelude::JoinType;
+use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
 use datafusion_expr::expr::{GroupingSet, Sort};
 use datafusion_expr::{avg, col, count, lit, sum, Expr, ExprSchemable};
 
@@ -402,7 +402,7 @@ async fn join_with_alias_filter() -> Result<()> {
         "+----+----+---+----+---+---+",
     ];
 
-    assert_batches_eq!(expected, &results);
+    assert_batches_sorted_eq!(expected, &results);
 
     Ok(())
 }
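
For readers following the series: the core change in PATCH 1 is replacing the hand-rolled recursion of `extract_join_keys` with `split_conjunction` (imported from the optimizer crate's `utils` module), which flattens a tree of `AND`ed predicates into its leaf expressions so the rule can iterate instead of recurse. The sketch below illustrates that flatten-then-route shape. It is a self-contained toy, not DataFusion's real `Expr`, and the routing is deliberately simplified: in the actual rule an equality only becomes a join key after `find_valid_equijoin_key_pair` confirms its two sides resolve to opposite join inputs and `can_hash` accepts both key types.

```rust
// Toy model of the flatten-then-route pattern used by the rewritten rule.
// DataFusion's real `split_conjunction` takes `&Expr` and returns `Vec<&Expr>`.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    And(Box<Expr>, Box<Expr>),
    Eq(&'static str, &'static str),
    Gt(&'static str, i64),
}

/// Flatten nested `AND`s into their leaf predicates, left to right.
fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
    match expr {
        Expr::And(left, right) => {
            let mut leaves = split_conjunction(left);
            leaves.extend(split_conjunction(right));
            leaves
        }
        leaf => vec![leaf],
    }
}

fn main() {
    // t1.a = t2.a AND (t1.b = t2.b AND t1.c > 1)
    let filter = Expr::And(
        Box::new(Expr::Eq("t1.a", "t2.a")),
        Box::new(Expr::And(
            Box::new(Expr::Eq("t1.b", "t2.b")),
            Box::new(Expr::Gt("t1.c", 1)),
        )),
    );

    // Route each leaf: equalities become candidate equijoin keys,
    // everything else stays in the residual join filter.
    let mut join_keys = vec![];
    let mut residual = vec![];
    for leaf in split_conjunction(&filter) {
        match leaf {
            Expr::Eq(l, r) => join_keys.push((*l, *r)),
            other => residual.push(other.clone()),
        }
    }

    assert_eq!(join_keys, vec![("t1.a", "t2.a"), ("t1.b", "t2.b")]);
    assert_eq!(residual, vec![Expr::Gt("t1.c", 1)]);
}
```

In the patch itself the leftover predicates are folded back into one expression with `accum_filters.into_iter().reduce(Expr::and)`, which yields the `Option<Expr>` that the rule stores as the join's remaining `filter`.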