Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions benchmarks/expected-plans/q20.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
Sort: supplier.s_name ASC NULLS LAST
Projection: supplier.s_name, supplier.s_address
LeftSemi Join: supplier.s_suppkey = __sq_2.ps_suppkey
LeftSemi Join: supplier.s_suppkey = __sq_1.ps_suppkey
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these changes imply the decorrelate passes used to be applied bottom up and after this PR they are applied top-down

Is that intentional? Or maybe I am misreading the diff 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😂 The original implementation mixed TopDown and BottomUp.
It used TopDown overall, but BottomUp when matching the sub-plan.
I think it was a small mistake in the original code, but it doesn't affect correctness.

Inner Join: supplier.s_nationkey = nation.n_nationkey
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey]
Filter: nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name]
SubqueryAlias: __sq_2
SubqueryAlias: __sq_1
Projection: partsupp.ps_suppkey AS ps_suppkey
Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value
Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey
LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey
LeftSemi Join: partsupp.ps_partkey = __sq_2.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]
SubqueryAlias: __sq_1
SubqueryAlias: __sq_2
Projection: part.p_partkey AS p_partkey
Filter: part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name]
Expand Down
20 changes: 10 additions & 10 deletions datafusion/core/tests/sql/subqueries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,16 @@ where c_acctbal < (
let actual = format!("{}", plan.display_indent());
let expected = "Sort: customer.c_custkey ASC NULLS LAST\
\n Projection: customer.c_custkey\
\n Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __sq_2.__value\
\n Inner Join: customer.c_custkey = __sq_2.o_custkey\
\n Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __sq_1.__value\
\n Inner Join: customer.c_custkey = __sq_1.o_custkey\
\n TableScan: customer projection=[c_custkey, c_acctbal]\
\n SubqueryAlias: __sq_2\
\n SubqueryAlias: __sq_1\
\n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value\
\n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]]\
\n Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __sq_1.__value\
\n Inner Join: orders.o_orderkey = __sq_1.l_orderkey\
\n Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __sq_2.__value\
\n Inner Join: orders.o_orderkey = __sq_2.l_orderkey\
\n TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]\
\n SubqueryAlias: __sq_1\
\n SubqueryAlias: __sq_2\
\n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS price AS __value\
\n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]]\
\n TableScan: lineitem projection=[l_orderkey, l_extendedprice]";
Expand Down Expand Up @@ -324,18 +324,18 @@ order by s_name;
let actual = format!("{}", plan.display_indent());
let expected = "Sort: supplier.s_name ASC NULLS LAST\
\n Projection: supplier.s_name, supplier.s_address\
\n LeftSemi Join: supplier.s_suppkey = __sq_2.ps_suppkey\
\n LeftSemi Join: supplier.s_suppkey = __sq_1.ps_suppkey\
\n Inner Join: supplier.s_nationkey = nation.n_nationkey\
\n TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey]\
\n Filter: nation.n_name = Utf8(\"CANADA\")\
\n TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8(\"CANADA\")]\
\n SubqueryAlias: __sq_2\
\n SubqueryAlias: __sq_1\
\n Projection: partsupp.ps_suppkey AS ps_suppkey\
\n Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value\
\n Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey\
\n LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey\
\n LeftSemi Join: partsupp.ps_partkey = __sq_2.p_partkey\
\n TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]\
\n SubqueryAlias: __sq_1\
\n SubqueryAlias: __sq_2\
\n Projection: part.p_partkey AS p_partkey\
\n Filter: part.p_name LIKE Utf8(\"forest%\")\
\n TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8(\"forest%\")]\
Expand Down
8 changes: 6 additions & 2 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1329,12 +1329,16 @@ impl SubqueryAlias {
/// If the value of `<predicate>` is true, the input row is passed to
/// the output. If the value of `<predicate>` is false, the row is
/// discarded.
///
/// A `Filter` should not be created directly; use `try_new()` instead.
/// Its fields are only `pub` to support pattern matching.
#[derive(Clone)]
#[non_exhaustive]
pub struct Filter {
/// The predicate expression, which must have Boolean type.
predicate: Expr,
pub predicate: Expr,
/// The incoming logical plan
input: Arc<LogicalPlan>,
pub input: Arc<LogicalPlan>,
Comment on lines +1339 to +1341
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made them `pub` because I need to pattern-match on them in rules.

Discussion about it in #4464.

cc @alamb @andygrove @tustvold .

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is ok -- it would be nice to add some comments explaining a Filter should not be created directly but instead use try_new() and that these fields are only pub to support pattern matching

}

impl Filter {
Expand Down
80 changes: 34 additions & 46 deletions datafusion/optimizer/src/decorrelate_where_exists.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
// specific language governing permissions and limitations
// under the License.

use crate::optimizer::ApplyOrder;
use crate::utils::{
conjunction, exprs_to_join_cols, find_join_exprs, split_conjunction,
verify_not_disjunction,
};
use crate::{utils, OptimizerConfig, OptimizerRule};
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{context, Result};
use datafusion_expr::{
logical_plan::{Filter, JoinType, Subquery},
Expand Down Expand Up @@ -81,27 +82,15 @@ impl OptimizerRule for DecorrelateWhereExists {
) -> Result<Option<LogicalPlan>> {
match plan {
LogicalPlan::Filter(filter) => {
let predicate = filter.predicate();
let filter_input = filter.input().as_ref();

// Apply optimizer rule to current input
let optimized_input = self
.try_optimize(filter_input, config)?
.unwrap_or_else(|| filter_input.clone());

let (subqueries, other_exprs) =
self.extract_subquery_exprs(predicate, config)?;
let optimized_plan = LogicalPlan::Filter(Filter::try_new(
predicate.clone(),
Arc::new(optimized_input),
)?);
self.extract_subquery_exprs(filter.predicate(), config)?;
if subqueries.is_empty() {
// regular filter, no subquery exists clause here
return Ok(Some(optimized_plan));
return Ok(None);
}

// iterate through all exists clauses in predicate, turning each into a join
let mut cur_input = filter_input.clone();
let mut cur_input = filter.input().as_ref().clone();
for subquery in subqueries {
if let Some(x) = optimize_exists(&subquery, &cur_input, &other_exprs)?
{
Expand All @@ -112,16 +101,17 @@ impl OptimizerRule for DecorrelateWhereExists {
}
Ok(Some(cur_input))
}
_ => {
// Apply the optimization to all inputs of the plan
Ok(Some(utils::optimize_children(self, plan, config)?))
}
_ => Ok(None),
}
}

/// Returns the name of this rule: the fixed string `"decorrelate_where_exists"`.
fn name(&self) -> &str {
"decorrelate_where_exists"
}

/// Reports the traversal order for this rule: always `Some(ApplyOrder::TopDown)`.
///
/// NOTE(review): presumably the optimizer driver uses this value to walk the
/// plan itself, which would explain why `try_optimize` no longer recurses into
/// child plans — confirm against the optimizer implementation.
fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
}
}

/// Takes a query like:
Expand Down Expand Up @@ -226,6 +216,15 @@ mod tests {
};
use std::ops::Add;

/// Test helper that hands a fresh `DecorrelateWhereExists` rule, together with
/// `plan` and `expected`, to `assert_optimized_plan_eq_display_indent`.
///
/// Always returns `Ok(())`, so a test can use the call as its tail expression.
fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
    // Build the rule under test first so the call below stays on one line.
    let rule = Arc::new(DecorrelateWhereExists::new());
    assert_optimized_plan_eq_display_indent(rule, plan, expected);
    Ok(())
}

/// Test for multiple exists subqueries in the same filter expression
#[test]
fn multiple_subqueries() -> Result<()> {
Expand All @@ -248,8 +247,7 @@ mod tests {
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;

assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
Ok(())
assert_plan_eq(&plan, expected)
}

/// Test recursive correlated subqueries
Expand Down Expand Up @@ -284,8 +282,7 @@ mod tests {
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"#;

assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
Ok(())
assert_plan_eq(&plan, expected)
}

/// Test for correlated exists subquery filter with additional subquery filters
Expand Down Expand Up @@ -313,8 +310,7 @@ mod tests {
Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;

assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
Ok(())
assert_plan_eq(&plan, expected)
}

/// Test for correlated exists subquery with no columns in schema
Expand All @@ -332,8 +328,7 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;

assert_optimization_skipped(&DecorrelateWhereExists::new(), &plan);
Ok(())
assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan)
}

/// Test for exists subquery with both columns in schema
Expand All @@ -351,8 +346,7 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;

assert_optimization_skipped(&DecorrelateWhereExists::new(), &plan);
Ok(())
assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan)
}

/// Test for correlated exists subquery not equal
Expand All @@ -370,8 +364,7 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;

assert_optimization_skipped(&DecorrelateWhereExists::new(), &plan);
Ok(())
assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan)
}

/// Test for correlated exists subquery less than
Expand All @@ -391,7 +384,7 @@ mod tests {

let expected = r#"can't optimize < column comparison"#;

assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected);
Ok(())
}

Expand All @@ -416,7 +409,7 @@ mod tests {

let expected = r#"Optimizing disjunctions not supported!"#;

assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected);
Ok(())
}

Expand All @@ -434,8 +427,7 @@ mod tests {
.project(vec![col("customer.c_custkey")])?
.build()?;

assert_optimization_skipped(&DecorrelateWhereExists::new(), &plan);
Ok(())
assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan)
}

/// Test for correlated exists expressions
Expand All @@ -459,8 +451,7 @@ mod tests {
TableScan: customer [c_custkey:Int64, c_name:Utf8]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;

assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
Ok(())
assert_plan_eq(&plan, expected)
}

/// Test for correlated exists subquery filter with additional filters
Expand All @@ -483,8 +474,7 @@ mod tests {
TableScan: customer [c_custkey:Int64, c_name:Utf8]
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#;

assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
Ok(())
assert_plan_eq(&plan, expected)
}

/// Test for correlated exists subquery filter with disjunctions
Expand All @@ -511,8 +501,7 @@ mod tests {
TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
TableScan: customer [c_custkey:Int64, c_name:Utf8]"#;

assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
Ok(())
assert_plan_eq(&plan, expected)
}

/// Test for correlated EXISTS subquery filter
Expand All @@ -535,8 +524,7 @@ mod tests {
TableScan: test [a:UInt32, b:UInt32, c:UInt32]
TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#;

assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected);
Ok(())
assert_plan_eq(&plan, expected)
}

/// Test for single exists subquery filter
Expand All @@ -550,7 +538,7 @@ mod tests {

let expected = "cannot optimize non-correlated subquery";

assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected);
Ok(())
}

Expand All @@ -565,7 +553,7 @@ mod tests {

let expected = "cannot optimize non-correlated subquery";

assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected);
assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected);
Ok(())
}
}
Loading