diff --git a/src/binder/statement/mod.rs b/src/binder/statement/mod.rs index c9a3789..adf5b5a 100644 --- a/src/binder/statement/mod.rs +++ b/src/binder/statement/mod.rs @@ -95,11 +95,14 @@ impl Binder { let select_distinct = select.distinct; // bind where clause - let where_clause = select + let mut where_clause = select .selection .as_ref() .map(|expr| self.bind_expr(expr)) .transpose()?; + if let Some(expr) = &mut where_clause { + self.rewrite_scalar_subquery(expr, &mut from_table) + } // bind group by clause let group_by = select diff --git a/tests/planner/column-pruning.planner.sql b/tests/planner/column-pruning.planner.sql index b1a5b7b..3ad28d1 100644 --- a/tests/planner/column-pruning.planner.sql +++ b/tests/planner/column-pruning.planner.sql @@ -200,3 +200,28 @@ PhysicalProject: exprs [t1.a:Int64, (subquery_0.subquery_0_scalar_v0:Nullable(In PhysicalTableScan: table: #t1, columns: [b] */ +-- PushProjectThroughChild: column pruning across scalar subquery in where expr + +select t1.a, t1.b from t1 where a >= (select max(a) from t1); + +/* +original plan: +LogicalProject: exprs [t1.a:Int64, t1.b:Int64] + LogicalFilter: expr t1.a:Int64 >= subquery_0.subquery_0_scalar_v0:Nullable(Int64) + LogicalJoin: type Cross, cond None + LogicalTableScan: table: #t1, columns: [a, b, c] + LogicalProject: exprs [(Max(t1.a:Int64):Int64) as subquery_0.subquery_0_scalar_v0] + LogicalAgg: agg_funcs [Max(t1.a:Int64):Int64] group_by [] + LogicalTableScan: table: #t1, columns: [a, b, c] + +optimized plan: +PhysicalProject: exprs [t1.a:Int64, t1.b:Int64] + PhysicalFilter: expr t1.a:Int64 >= subquery_0.subquery_0_scalar_v0:Nullable(Int64) + PhysicalProject: exprs [t1.a:Nullable(Int64), t1.b:Nullable(Int64), subquery_0.subquery_0_scalar_v0:Nullable(Int64)] + PhysicalCrossJoin: type Cross + PhysicalTableScan: table: #t1, columns: [a, b] + PhysicalProject: exprs [(Max(t1.a:Int64):Int64) as subquery_0.subquery_0_scalar_v0] + PhysicalSimpleAgg: agg_funcs [Max(t1.a:Int64):Int64] group_by [] + PhysicalTableScan: table: #t1, columns: [a] +*/ + diff --git a/tests/planner/column-pruning.yml b/tests/planner/column-pruning.yml index 2c684bb..2103d74 100644 --- a/tests/planner/column-pruning.yml +++ b/tests/planner/column-pruning.yml @@ -44,3 +44,8 @@ select a, (select max(b) from t1) + (select min(b) from t1) as mix_b from t1; desc: | PushProjectThroughChild: column pruning across multiple scalar subquery + +- sql: | + select t1.a, t1.b from t1 where a >= (select max(a) from t1); + desc: | + PushProjectThroughChild: column pruning across scalar subquery in where expr diff --git a/tests/slt/subquery.slt b/tests/slt/subquery.slt index eff5588..24a4c03 100644 --- a/tests/slt/subquery.slt +++ b/tests/slt/subquery.slt @@ -71,3 +71,15 @@ select a, (select max(b) from t1) + (select min(b) from t1) as mix_b from t1; 1 12 2 12 2 12 + +query I +select t1.a, t1.b from t1 where a >= (select max(a) from t1); +---- +2 7 +2 8 + +query I +select t1.a, t1.b from t1 where a >= (select max(a) from t1) and b = (select max(b) from t1); +---- +2 7 +2 8