diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 7e701b56dfc53..236eab1aa3c84 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2837,3 +2837,96 @@ async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn reduce_where_in_to_expr_equijoin() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int \ + from t1 \ + where t1_id + 11 in (select t2_id from t2)"; + + // assert logical plan + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Projection: t2.t2_id [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "| 33 | c | 3 |", + "| 44 | d | 4 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn reduce_where_in_to_non_equijoin() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int \ + from t1 \ + where 1 + 10 in (select t2_id from t2)"; + + // assert logical plan + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: Filter: UInt32(11) = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Projection: t2.t2_id [t2_id:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "| 22 | b | 2 |", + "| 33 | c | 3 |", + "| 44 | d | 4 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/optimizer/src/subquery_filter_to_join.rs b/datafusion/optimizer/src/subquery_filter_to_join.rs index 696f911a10208..2324be90cfa67 100644 --- a/datafusion/optimizer/src/subquery_filter_to_join.rs +++ b/datafusion/optimizer/src/subquery_filter_to_join.rs @@ -95,10 +95,9 @@ impl OptimizerRule for SubqueryFilterToJoin { subquery, negated, } => { - let right_input = self.try_optimize( - &subquery.subquery, - _config - )?.unwrap_or_else(||subquery.subquery.as_ref().clone()); + let right_input = self + .try_optimize(&subquery.subquery, _config)? + .unwrap_or_else(|| subquery.subquery.as_ref().clone()); let right_schema = right_input.schema(); if right_schema.fields().len() != 1 { return Err(DataFusionError::Plan( @@ -108,13 +107,19 @@ impl OptimizerRule for SubqueryFilterToJoin { }; let right_key = right_schema.field(0).qualified_column(); - let left_key = match *expr.clone() { - Expr::Column(col) => col, - _ => return Err(DataFusionError::NotImplemented( - "Filtering by expression not implemented for InSubquery" - .to_string(), - )), - }; + let left_key = *expr.clone(); + // TODO: save the predicate to join-filter and let the other rule decide it is + // a equi or non-equi predicate. + let (on, filter) = + // When left is a constant expression, like 1, + // the join predicate will be `1 = right_key`, it is better to add it to filter. + if left_key.to_columns()?.is_empty() { + let equi_expr = + Expr::eq(*expr.clone(), Expr::Column(right_key)); + (vec![], Some(equi_expr)) + } else { + (vec![(left_key, Expr::Column(right_key))], None) + }; let join_type = if *negated { JoinType::LeftAnti @@ -131,8 +136,8 @@ impl OptimizerRule for SubqueryFilterToJoin { Ok(LogicalPlan::Join(Join { left: Arc::new(input), right: Arc::new(right_input), - on: vec![(Expr::Column(left_key), Expr::Column(right_key))], - filter: None, + on, + filter, join_type, join_constraint: JoinConstraint::On, schema: Arc::new(schema), @@ -143,7 +148,7 @@ impl OptimizerRule for SubqueryFilterToJoin { "Unknown expression while rewriting subquery to joins" .to_string(), )), - } + }, ); // In case of expressions which could not be rewritten @@ -418,4 +423,43 @@ mod tests { assert_optimized_plan_equal(&plan, expected) } + + /// Test for single IN subquery filter with expr equijoin + #[test] + fn in_subquery_to_expr_equijoin() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(in_subquery( + col("c") + lit(10i32), + test_subquery_with_name("sq")?, + ))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: test.c + Int32(10) = sq.c [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq.c [c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } + + /// Test for single IN subquery filter with non equijoin + #[test] + fn in_subquery_to_non_equijoin() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(in_subquery(lit(10i32), test_subquery_with_name("sq")?))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: Int32(10) = sq.c [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: sq.c [c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_equal(&plan, expected) + } }