Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/binder/statement/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub enum BoundStatement {
Select(BoundSelect),
}

#[derive(Debug)]
#[derive(Debug, Clone, PartialEq)]
pub struct BoundSelect {
pub select_list: Vec<BoundExpr>,
pub from_table: Option<BoundTableRef>,
Expand Down
14 changes: 12 additions & 2 deletions src/binder/table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod join;
pub use join::*;
use sqlparser::ast::{TableFactor, TableWithJoins};

use super::{BindError, Binder};
use super::{BindError, Binder, BoundSelect};
use crate::catalog::{ColumnCatalog, ColumnId, TableCatalog, TableId};

pub static DEFAULT_DATABASE_NAME: &str = "postgres";
Expand All @@ -13,6 +13,7 @@ pub static DEFAULT_SCHEMA_NAME: &str = "postgres";
pub enum BoundTableRef {
Table(TableCatalog),
Join(Join),
Subquery(Box<BoundSelect>),
}

impl BoundTableRef {
Expand All @@ -22,6 +23,7 @@ impl BoundTableRef {
BoundTableRef::Join(join) => {
TableSchema::new_from_join(&join.left.schema(), &join.right.schema())
}
BoundTableRef::Subquery(subquery) => subquery.from_table.clone().unwrap().schema(),
}
}
}
Expand Down Expand Up @@ -118,7 +120,15 @@ impl Binder {

Ok(BoundTableRef::Table(table_catalog))
}
_ => panic!("unsupported table factor"),
TableFactor::Derived {
lateral: _,
subquery,
alias: _,
} => {
let table = self.bind_select(subquery)?;
Ok(BoundTableRef::Subquery(Box::new(table)))
}
_other => panic!("unsupported table factor: {:?}", _other),
}
}
}
11 changes: 6 additions & 5 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ use sqlparser::parser::ParserError;
use crate::binder::{BindError, Binder};
use crate::executor::{try_collect, ExecutorBuilder, ExecutorError};
use crate::optimizer::{
CollapseProject, EliminateLimits, HepBatch, HepBatchStrategy, HepOptimizer, InputRefRewriter,
LimitProjectTranspose, PhysicalRewriteRule, PlanRef, PlanRewriter, PushLimitIntoTableScan,
PushLimitThroughJoin, PushPredicateThroughJoin, PushPredicateThroughNonJoin,
PushProjectIntoTableScan, PushProjectThroughChild, RemoveNoopOperators, SimplifyCasts,
CollapseProject, CombineFilter, EliminateLimits, HepBatch, HepBatchStrategy, HepOptimizer,
InputRefRewriter, LimitProjectTranspose, PhysicalRewriteRule, PlanRef, PlanRewriter,
PushLimitIntoTableScan, PushLimitThroughJoin, PushPredicateThroughJoin,
PushPredicateThroughNonJoin, PushProjectIntoTableScan, PushProjectThroughChild,
RemoveNoopOperators, SimplifyCasts,
};
use crate::parser::parse;
use crate::planner::{LogicalPlanError, Planner};
Expand Down Expand Up @@ -86,7 +87,7 @@ impl Database {
HepBatch::new(
"Combine operators".to_string(),
HepBatchStrategy::fix_point_topdown(10),
vec![CollapseProject::create()],
vec![CollapseProject::create(), CombineFilter::create()],
),
HepBatch::new(
"One-time simplification".to_string(),
Expand Down
61 changes: 58 additions & 3 deletions src/optimizer/rules/combine_operators.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use super::util::is_subset_exprs;
use std::sync::Arc;

use super::util::{is_subset_exprs, reduce_conjunctive_predicate};
use super::RuleImpl;
use crate::optimizer::core::{OptExpr, Pattern, PatternChildrenPredicate, Rule, Substitute};
use crate::optimizer::PlanNodeType;
use crate::optimizer::core::{
OptExpr, OptExprNode, Pattern, PatternChildrenPredicate, Rule, Substitute,
};
use crate::optimizer::{Dummy, LogicalFilter, PlanNodeType};

lazy_static! {
static ref COLLAPSE_PROJECT_RULE: Pattern = {
Expand All @@ -13,6 +17,15 @@ lazy_static! {
}]),
}
};
static ref COMBINE_FILTERS: Pattern = {
Pattern {
predicate: |p| p.node_type() == PlanNodeType::LogicalFilter,
children: PatternChildrenPredicate::Predicate(vec![Pattern {
predicate: |p| p.node_type() == PlanNodeType::LogicalFilter,
children: PatternChildrenPredicate::None,
}]),
}
};
}

/// Combine two adjacent project operators into one.
Expand Down Expand Up @@ -53,3 +66,45 @@ impl Rule for CollapseProject {
}
}
}

/// Combine two adjacent filter operators into one.
#[derive(Clone)]
pub struct CombineFilter;

impl CombineFilter {
pub fn create() -> RuleImpl {
Self {}.into()
}
}

impl Rule for CombineFilter {
fn pattern(&self) -> &Pattern {
&COMBINE_FILTERS
}

fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) {
// TODO: handle column alias
let filter_opt_expr = opt_expr;
let next_filter_opt_expr = filter_opt_expr.children[0].clone();

let filter_expr = filter_opt_expr
.root
.get_plan_ref()
.as_logical_filter()
.unwrap()
.expr();
let next_filter_exprs = next_filter_opt_expr
.root
.get_plan_ref()
.as_logical_filter()
.unwrap()
.expr();
if let Some(expr) = reduce_conjunctive_predicate([filter_expr, next_filter_exprs].to_vec())
{
let new_filter_root =
OptExprNode::PlanRef(Arc::new(LogicalFilter::new(expr, Dummy::new_ref())));
let res = OptExpr::new(new_filter_root, next_filter_opt_expr.children);
result.opt_exprs.push(res);
}
}
}
1 change: 1 addition & 0 deletions src/optimizer/rules/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub enum RuleImpl {
RemoveNoopOperators,
// Combine operators
CollapseProject,
CombineFilter,
// Simplification
SimplifyCasts,
// Rewrite physical plan
Expand Down
17 changes: 16 additions & 1 deletion src/optimizer/rules/util.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::binder::BoundExpr;
use arrow::datatypes::DataType;
use sqlparser::ast::BinaryOperator;

use crate::binder::{BoundBinaryOp, BoundExpr};
use crate::catalog::ColumnCatalog;

/// Return true when left is subset of right, only compare table_id and column_id, so it's safe to
Expand All @@ -11,3 +14,15 @@ pub fn is_subset_cols(left: &[ColumnCatalog], right: &[ColumnCatalog]) -> bool {
pub fn is_subset_exprs(left: &[BoundExpr], right: &[BoundExpr]) -> bool {
left.iter().all(|l| right.contains(l))
}

/// Reduce multi predicates into a conjunctive predicate by AND
pub fn reduce_conjunctive_predicate(exprs: Vec<BoundExpr>) -> Option<BoundExpr> {
exprs.into_iter().reduce(|a, b| {
BoundExpr::BinaryOp(BoundBinaryOp {
op: BinaryOperator::And,
left: Box::new(a),
right: Box::new(b),
return_type: Some(DataType::Boolean),
})
})
}
4 changes: 4 additions & 0 deletions src/planner/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ impl Planner {
);
Ok(Arc::new(join))
}
BoundTableRef::Subquery(subquery) => {
let subquery = subquery.clone();
self.plan_select(*subquery)
}
}
}
}
20 changes: 20 additions & 0 deletions tests/planner/combine-operators.planner.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- CollapseProject & CombineFilter: combine adjacent projects and filters into one

select * from (select * from (select * from t1 where c < 2) where a > 1) where b > 7;

/*
original plan:
LogicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64]
LogicalFilter: expr t1.b:Int64 > Cast(7 as Int64)
LogicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64]
LogicalFilter: expr t1.a:Int64 > Cast(1 as Int64)
LogicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64]
LogicalFilter: expr t1.c:Int64 < Cast(2 as Int64)
LogicalTableScan: table: #t1, columns: [a, b, c]

optimized plan:
PhysicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64]
PhysicalFilter: expr t1.b:Int64 > 7 AND t1.a:Int64 > 1 AND t1.c:Int64 < 2
PhysicalTableScan: table: #t1, columns: [a, b, c]
*/

4 changes: 4 additions & 0 deletions tests/planner/combine-operators.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- sql: |
select * from (select * from (select * from t1 where c < 2) where a > 1) where b > 7;
desc: |
CollapseProject & CombineFilter: combine adjacent projects and filters into one
14 changes: 14 additions & 0 deletions tests/slt/subquery.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
query III
select * from (select * from t1 where a > 1) where b > 7;
----
2 8 1

query II
select b from (select a, b from t1 where a > 1) where b > 7;
----
8

query III
select * from (select * from (select * from t1 where c < 2) where a > 1) where b > 7;
----
2 8 1