Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2837,3 +2837,96 @@ async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> {

Ok(())
}

/// Checks a `WHERE <expr> IN (subquery)` query whose left operand is an
/// expression over a column (`t1_id + 11`).
///
/// The expected optimized plan below contains a `LeftSemi Join` whose join
/// condition is that (cast) expression compared with `t2.t2_id`; the query is
/// then executed and its rows are compared with the expected table.
#[tokio::test]
async fn reduce_where_in_to_expr_equijoin() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


let sql = "select t1.t1_id, t1.t1_name, t1.t1_int \
from t1 \
where t1_id + 11 in (select t2_id from t2)";

// assert logical plan
// The plan is obtained by running the query prefixed with `explain`.
let msg = format!("Creating logical plan for '{}'", sql);
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan().unwrap();

// Expected plan, one entry per output line of `display_indent_schema`.
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Projection: t2.t2_id [t2_id:UInt32;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];

// Compare the rendered plan line by line; on mismatch both sides are
// pretty-printed in the failure message.
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);

// assert query results (the plain query, without `explain`)
let expected = vec![
"+-------+---------+--------+",
"| t1_id | t1_name | t1_int |",
"+-------+---------+--------+",
"| 11 | a | 1 |",
"| 33 | c | 3 |",
"| 44 | d | 4 |",
"+-------+---------+--------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}

/// Checks a `WHERE <expr> IN (subquery)` query whose left operand is a
/// constant expression (`1 + 10`) that references no column.
///
/// The expected optimized plan carries the comparison as a join `Filter`
/// (`UInt32(11) = t2.t2_id`) on a `LeftSemi Join`; the query is then executed
/// and its rows are compared with the expected table.
#[tokio::test]
async fn reduce_where_in_to_non_equijoin() -> Result<()> {
    let ctx = create_join_context("t1_id", "t2_id", false)?;

    let sql = "select t1.t1_id, t1.t1_name, t1.t1_int \
               from t1 \
               where 1 + 10 in (select t2_id from t2)";

    // Plan check: run the query under `explain` and take the optimized plan.
    let msg = format!("Creating logical plan for '{}'", sql);
    let explain_sql = format!("explain {}", sql);
    let plan = ctx
        .sql(&explain_sql)
        .await
        .expect(&msg)
        .into_optimized_plan()
        .unwrap();

    let expected_plan = vec![
        "Explain [plan_type:Utf8, plan:Utf8]",
        " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
        " LeftSemi Join: Filter: UInt32(11) = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
        " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
        " Projection: t2.t2_id [t2_id:UInt32;N]",
        " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
    ];

    // Render the plan with schema information and compare it line by line.
    let rendered = plan.display_indent_schema().to_string();
    let actual_plan = rendered.trim().lines().collect::<Vec<&str>>();
    assert_eq!(
        expected_plan, actual_plan,
        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
        expected_plan, actual_plan
    );

    // Result check: execute the plain query and compare the sorted output.
    let expected_rows = vec![
        "+-------+---------+--------+",
        "| t1_id | t1_name | t1_int |",
        "+-------+---------+--------+",
        "| 11 | a | 1 |",
        "| 22 | b | 2 |",
        "| 33 | c | 3 |",
        "| 44 | d | 4 |",
        "+-------+---------+--------+",
    ];

    let batches = execute_to_batches(&ctx, sql).await;
    assert_batches_sorted_eq!(expected_rows, &batches);

    Ok(())
}
72 changes: 58 additions & 14 deletions datafusion/optimizer/src/subquery_filter_to_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,9 @@ impl OptimizerRule for SubqueryFilterToJoin {
subquery,
negated,
} => {
let right_input = self.try_optimize(
&subquery.subquery,
_config
)?.unwrap_or_else(||subquery.subquery.as_ref().clone());
let right_input = self
.try_optimize(&subquery.subquery, _config)?
.unwrap_or_else(|| subquery.subquery.as_ref().clone());
let right_schema = right_input.schema();
if right_schema.fields().len() != 1 {
return Err(DataFusionError::Plan(
Expand All @@ -108,13 +107,19 @@ impl OptimizerRule for SubqueryFilterToJoin {
};

let right_key = right_schema.field(0).qualified_column();
let left_key = match *expr.clone() {
Expr::Column(col) => col,
_ => return Err(DataFusionError::NotImplemented(
"Filtering by expression not implemented for InSubquery"
.to_string(),
)),
};
let left_key = *expr.clone();
// TODO: save the predicate to the join filter and let another rule decide whether it is
// an equi or non-equi predicate.
let (on, filter) =
// When the left side is a constant expression such as `1`, the join predicate
// becomes `1 = right_key`; it is better to put it in the join filter.
if left_key.to_columns()?.is_empty() {
let equi_expr =
Expr::eq(*expr.clone(), Expr::Column(right_key));
(vec![], Some(equi_expr))
} else {
Comment on lines +113 to +120
Copy link
Member

@jackwener jackwener Dec 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, I think it's a good improvement but maybe isn't suitable to put in this rule.🤔
It could be better to handle separate/handle on/joinfilter in a new rule(we also can add more improvement in it).

Just a thought, I think we also can add a TODO, and handle them as a future ticket.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be better to handle separate/handle on/joinfilter in a new rule(we also can add more improvement in it).

Agree.

There is another PR, #4711, which will separate equi and non-equi predicates from the filter.
We can add `left-key = right-key` to the join filter in this rule and let the other rule do the separating.

(vec![(left_key, Expr::Column(right_key))], None)
};

let join_type = if *negated {
JoinType::LeftAnti
Expand All @@ -131,8 +136,8 @@ impl OptimizerRule for SubqueryFilterToJoin {
Ok(LogicalPlan::Join(Join {
left: Arc::new(input),
right: Arc::new(right_input),
on: vec![(Expr::Column(left_key), Expr::Column(right_key))],
filter: None,
on,
filter,
join_type,
join_constraint: JoinConstraint::On,
schema: Arc::new(schema),
Expand All @@ -143,7 +148,7 @@ impl OptimizerRule for SubqueryFilterToJoin {
"Unknown expression while rewriting subquery to joins"
.to_string(),
)),
}
},
);

// In case of expressions which could not be rewritten
Expand Down Expand Up @@ -418,4 +423,43 @@ mod tests {

assert_optimized_plan_equal(&plan, expected)
}

/// A single IN-subquery filter whose left operand is an expression over a
/// column (`c + 10`): the expected plan has a `LeftSemi Join` that uses the
/// expression as its join key.
#[test]
fn in_subquery_to_expr_equijoin() -> Result<()> {
    let scan = test_table_scan()?;
    let builder = LogicalPlanBuilder::from(scan);

    // Filter on `c + 10 IN (sq)`, then keep only `test.b`.
    let membership = in_subquery(col("c") + lit(10i32), test_subquery_with_name("sq")?);
    let plan = builder
        .filter(membership)?
        .project(vec![col("test.b")])?
        .build()?;

    // One literal per plan line; `concat!` joins them into a single `&'static str`.
    let expected = concat!(
        "Projection: test.b [b:UInt32]",
        "\n LeftSemi Join: test.c + Int32(10) = sq.c [a:UInt32, b:UInt32, c:UInt32]",
        "\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]",
        "\n Projection: sq.c [c:UInt32]",
        "\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]",
    );

    assert_optimized_plan_equal(&plan, expected)
}

/// A single IN-subquery filter whose left operand is a literal (`10`): the
/// expected plan has a `LeftSemi Join` that carries the comparison as a join
/// `Filter` instead of a join key.
#[test]
fn in_subquery_to_non_equijoin() -> Result<()> {
    let scan = test_table_scan()?;
    let builder = LogicalPlanBuilder::from(scan);

    // Filter on `10 IN (sq)`, then keep only `test.b`.
    let membership = in_subquery(lit(10i32), test_subquery_with_name("sq")?);
    let plan = builder
        .filter(membership)?
        .project(vec![col("test.b")])?
        .build()?;

    // One literal per plan line; `concat!` joins them into a single `&'static str`.
    let expected = concat!(
        "Projection: test.b [b:UInt32]",
        "\n LeftSemi Join: Filter: Int32(10) = sq.c [a:UInt32, b:UInt32, c:UInt32]",
        "\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]",
        "\n Projection: sq.c [c:UInt32]",
        "\n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]",
    );

    assert_optimized_plan_equal(&plan, expected)
}
}