diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index 00ba4524d02c0..190248efe847b 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -17,21 +17,20 @@
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow::{
-    array::{Int32Array, StringArray},
+    array::{Int32Array, StringArray, UInt32Array},
     record_batch::RecordBatch,
 };
 use datafusion::from_slice::FromSlice;
 use std::sync::Arc;

-use datafusion::assert_batches_eq;
 use datafusion::dataframe::DataFrame;
 use datafusion::error::Result;
 use datafusion::execution::context::SessionContext;
 use datafusion::prelude::CsvReadOptions;
 use datafusion::prelude::JoinType;
+use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
 use datafusion_expr::expr::{GroupingSet, Sort};
-use datafusion_expr::{avg, count, lit, sum};
-use datafusion_expr::{col, Expr};
+use datafusion_expr::{avg, col, count, lit, sum, Expr, ExprSchemable};

 #[tokio::test]
 async fn join() -> Result<()> {
@@ -352,6 +351,62 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> {
     Ok(())
 }

+#[tokio::test]
+async fn join_with_alias_filter() -> Result<()> {
+    let join_ctx = create_join_context()?;
+    let t1 = join_ctx.table("t1")?;
+    let t2 = join_ctx.table("t2")?;
+    let t1_schema = t1.schema().clone();
+    let t2_schema = t2.schema().clone();
+
+    // filter: t1.a + CAST(Int64(3), UInt32) = t2.a + CAST(Int32(1), UInt32), aliased as t1.b + 1 = t2.a + 2
+    let filter = Expr::eq(
+        col("t1.a") + lit(3i64).cast_to(&DataType::UInt32, &t1_schema)?,
+        col("t2.a") + lit(1i32).cast_to(&DataType::UInt32, &t2_schema)?,
+    )
+    .alias("t1.b + 1 = t2.a + 2");
+
+    let df = t1
+        .join(t2, JoinType::Inner, &[], &[], Some(filter))?
+        .select(vec![
+            col("t1.a"),
+            col("t2.a"),
+            col("t1.b"),
+            col("t1.c"),
+            col("t2.b"),
+            col("t2.c"),
+        ])?;
+    let optimized_plan = df.clone().into_optimized_plan()?;
+
+    let expected = vec![
+        "Projection: t1.a, t2.a, t1.b, t1.c, t2.b, t2.c [a:UInt32, a:UInt32, b:Utf8, c:Int32, b:Utf8, c:Int32]",
+        "  Inner Join: t1.a + UInt32(3) = t2.a + UInt32(1) [a:UInt32, b:Utf8, c:Int32, a:UInt32, b:Utf8, c:Int32]",
+        "    TableScan: t1 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32]",
+        "    TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32]",
+    ];
+
+    let formatted = optimized_plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let results = df.collect().await?;
+    let expected: Vec<&str> = vec![
+        "+----+----+---+----+---+---+",
+        "| a  | a  | b | c  | b | c |",
+        "+----+----+---+----+---+---+",
+        "| 11 | 13 | c | 30 | c | 3 |",
+        "| 1  | 3  | a | 10 | a | 1 |",
+        "+----+----+---+----+---+---+",
+    ];
+
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
 fn create_test_table() -> Result<DataFrame> {
     let schema = Arc::new(Schema::new(vec![
         Field::new("a", DataType::Utf8, false),
@@ -388,3 +443,42 @@ async fn aggregates_table(ctx: &SessionContext) -> Result<DataFrame> {
     )
     .await
 }
+
+fn create_join_context() -> Result<SessionContext> {
+    let t1 = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::UInt32, false),
+        Field::new("b", DataType::Utf8, false),
+        Field::new("c", DataType::Int32, false),
+    ]));
+    let t2 = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::UInt32, false),
+        Field::new("b", DataType::Utf8, false),
+        Field::new("c", DataType::Int32, false),
+    ]));
+
+    // define data for t1
+    let batch1 = RecordBatch::try_new(
+        t1,
+        vec![
+            Arc::new(UInt32Array::from_slice([1, 10, 11, 100])),
+            Arc::new(StringArray::from_slice(["a", "b", "c", "d"])),
+            Arc::new(Int32Array::from_slice([10, 20, 30, 40])),
+        ],
+    )?;
+    // define data for t2
+    let batch2 = RecordBatch::try_new(
+        t2,
+        vec![
+            Arc::new(UInt32Array::from_slice([3, 10, 13, 100])),
+            Arc::new(StringArray::from_slice(["a", "b", "c", "d"])),
+            Arc::new(Int32Array::from_slice([1, 2, 3, 4])),
+        ],
+    )?;
+
+    let ctx = SessionContext::new();
+
+    ctx.register_batch("t1", batch1)?;
+    ctx.register_batch("t2", batch2)?;
+
+    Ok(ctx)
+}
diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs
index 060cd82f63fa7..2f7a20d6e230d 100644
--- a/datafusion/optimizer/src/extract_equijoin_predicate.rs
+++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs
@@ -17,6 +17,7 @@
 //! Optimizer rule to extract equijoin expr from filter

 use crate::optimizer::ApplyOrder;
+use crate::utils::split_conjunction;
 use crate::{OptimizerConfig, OptimizerRule};
 use datafusion_common::DFSchema;
 use datafusion_common::Result;
@@ -24,6 +25,9 @@
 use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
 use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
 use std::sync::Arc;

+// An equijoin predicate: a (left key, right key) pair of join key expressions
+type EquijoinPredicate = (Expr, Expr);
+
 /// Optimization rule that extract equijoin expr from the filter
 #[derive(Default)]
 pub struct ExtractEquijoinPredicate;
@@ -56,27 +60,22 @@ impl OptimizerRule for ExtractEquijoinPredicate {
             let right_schema = right.schema();

             filter.as_ref().map_or(Result::Ok(None), |expr| {
-                let mut accum: Vec<(Expr, Expr)> = vec![];
-                let mut accum_filter: Vec<Expr> = vec![];
-                // TODO: avoding clone with split_conjunction
-                extract_join_keys(
-                    expr.clone(),
-                    &mut accum,
-                    &mut accum_filter,
-                    left_schema,
-                    right_schema,
-                )?;
-
-                let optimized_plan = (!accum.is_empty()).then(|| {
+                let (equijoin_predicates, non_equijoin_expr) =
+                    split_eq_and_noneq_join_predicate(
+                        expr,
+                        left_schema,
+                        right_schema,
+                    )?;
+
+                let optimized_plan = (!equijoin_predicates.is_empty()).then(|| {
                     let mut new_on = on.clone();
-                    new_on.extend(accum);
+                    new_on.extend(equijoin_predicates);

-                    let new_filter = accum_filter.into_iter().reduce(Expr::and);
                     LogicalPlan::Join(Join {
                         left: left.clone(),
                         right: right.clone(),
                         on: new_on,
-                        filter: new_filter,
+                        filter: non_equijoin_expr,
                         join_type: *join_type,
                         join_constraint: *join_constraint,
                         schema: schema.clone(),
@@ -100,30 +99,22 @@ impl OptimizerRule for ExtractEquijoinPredicate {
     }
 }

-/// Extracts equijoin ON condition be a single Eq or multiple conjunctive Eqs
-/// Filters matching this pattern are added to `accum`
-/// Filters that don't match this pattern are added to `accum_filter`
-/// Examples:
-/// ```text
-/// foo = bar => accum=[(foo, bar)] accum_filter=[]
-/// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[]
-/// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1]
-///
-/// For equijoin join key, assume we have tables -- a(c0, c1 c2) and b(c0, c1, c2):
-/// (a.c0 = 10) => accum=[], accum_filter=[a.c0 = 10]
-/// (a.c0 + 1 = b.c0 * 2) => accum=[(a.c0 + 1, b.c0 * 2)], accum_filter=[]
-/// (a.c0 + b.c0 = 10) => accum=[], accum_filter=[a.c0 + b.c0 = 10]
-/// ```
-fn extract_join_keys(
-    expr: Expr,
-    accum: &mut Vec<(Expr, Expr)>,
-    accum_filter: &mut Vec<Expr>,
+fn split_eq_and_noneq_join_predicate(
+    filter: &Expr,
     left_schema: &Arc<DFSchema>,
     right_schema: &Arc<DFSchema>,
-) -> Result<()> {
-    match &expr {
-        Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
-            Operator::Eq => {
+) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
+    let exprs = split_conjunction(filter);
+
+    let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
+    let mut accum_filters: Vec<Expr> = vec![];
+    for expr in exprs {
+        match expr {
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::Eq,
+                right,
+            }) => {
                 let left = left.as_ref();
                 let right = right.as_ref();

@@ -139,48 +130,27 @@ fn extract_join_keys(
                     let right_expr_type = right_expr.get_type(right_schema)?;

                     if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
-                        accum.push((left_expr, right_expr));
+                        accum_join_keys.push((left_expr, right_expr));
                     } else {
-                        accum_filter.push(expr);
+                        accum_filters.push(expr.clone());
                     }
                 } else {
-                    accum_filter.push(expr);
-                }
-            }
-            Operator::And => {
-                if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = expr {
-                    extract_join_keys(
-                        *left,
-                        accum,
-                        accum_filter,
-                        left_schema,
-                        right_schema,
-                    )?;
-                    extract_join_keys(
-                        *right,
-                        accum,
-                        accum_filter,
-                        left_schema,
-                        right_schema,
-                    )?;
+                    accum_filters.push(expr.clone());
                 }
             }
-            _other => {
-                accum_filter.push(expr);
-            }
-        },
-        _other => {
-            accum_filter.push(expr);
+            _ => accum_filters.push(expr.clone()),
         }
     }

-    Ok(())
+    let result_filter = accum_filters.into_iter().reduce(Expr::and);
+    Ok((accum_join_keys, result_filter))
 }

 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::test::*;
+    use arrow::datatypes::DataType;
     use datafusion_common::Column;
     use datafusion_expr::{
         col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
@@ -387,4 +357,33 @@ mod tests {

         assert_plan_eq(&plan, expected)
     }
+
+    #[test]
+    fn join_with_alias_filter() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        let t1_schema = t1.schema().clone();
+        let t2_schema = t2.schema().clone();
+
+        // filter: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int32(2), UInt32), aliased as t1.a + 1 = t2.a + 2
+        let filter = Expr::eq(
+            col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
+            col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
+        )
+        .alias("t1.a + 1 = t2.a + 2");
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                t2,
+                JoinType::Left,
+                (Vec::<Column>::new(), Vec::<Column>::new()),
+                Some(filter),
+            )?
+            .build()?;
+        let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
+        \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+        \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_plan_eq(&plan, expected)
+    }
 }
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 5d2dd45af3cb1..95015c4a58bc6 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -237,12 +237,12 @@ impl Optimizer {
         let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = vec![
             Arc::new(InlineTableScan::new()),
             Arc::new(TypeCoercion::new()),
-            Arc::new(ExtractEquijoinPredicate::new()),
             Arc::new(SimplifyExpressions::new()),
             Arc::new(UnwrapCastInComparison::new()),
             Arc::new(DecorrelateWhereExists::new()),
             Arc::new(DecorrelateWhereIn::new()),
             Arc::new(ScalarSubqueryToJoin::new()),
+            Arc::new(ExtractEquijoinPredicate::new()),
             // simplify expressions does not simplify expressions in subqueries, so we
             // run it again after running the optimizations that potentially converted
             // subqueries to joins
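
Note (not part of the patch): a minimal, self-contained sketch of the technique the new split_eq_and_noneq_join_predicate adopts, namely flattening the join filter into its AND-ed conjuncts (the role split_conjunction plays in the patch) and then partitioning them into equality key pairs versus one residual filter. The Node type and helper names below are illustrative stand-ins, not DataFusion APIs; the real rule additionally validates each pair with find_valid_equijoin_key_pair and can_hash before accepting it as a join key.

// Toy stand-in for DataFusion's `Expr`, just enough to show the idea.
#[derive(Clone, Debug)]
enum Node {
    Col(&'static str),
    Eq(Box<Node>, Box<Node>),
    Gt(Box<Node>, Box<Node>),
    And(Box<Node>, Box<Node>),
}

// Role of `split_conjunction`: flatten `a AND b AND c` into [a, b, c].
fn conjuncts(e: &Node, out: &mut Vec<Node>) {
    match e {
        Node::And(l, r) => {
            conjuncts(l, out);
            conjuncts(r, out);
        }
        other => out.push(other.clone()),
    }
}

// Role of `split_eq_and_noneq_join_predicate`: keep `=` conjuncts as key pairs,
// fold everything else back into a single residual filter with AND.
fn split(filter: &Node) -> (Vec<(Node, Node)>, Option<Node>) {
    let mut parts = vec![];
    conjuncts(filter, &mut parts);

    let mut keys = vec![];
    let mut rest = vec![];
    for p in parts {
        match p {
            Node::Eq(l, r) => keys.push((*l, *r)),
            other => rest.push(other),
        }
    }
    let rest = rest
        .into_iter()
        .reduce(|a, b| Node::And(Box::new(a), Box::new(b)));
    (keys, rest)
}

fn main() {
    // t1.a = t2.a AND t1.c > t2.c
    let filter = Node::And(
        Box::new(Node::Eq(
            Box::new(Node::Col("t1.a")),
            Box::new(Node::Col("t2.a")),
        )),
        Box::new(Node::Gt(
            Box::new(Node::Col("t1.c")),
            Box::new(Node::Col("t2.c")),
        )),
    );
    let (keys, rest) = split(&filter);
    println!("join keys: {keys:?}"); // [(Col("t1.a"), Col("t2.a"))]
    println!("residual:  {rest:?}"); // Some(Gt(Col("t1.c"), Col("t2.c")))
}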