diff --git a/src/binder/statement/mod.rs b/src/binder/statement/mod.rs index 2e44836..59999c3 100644 --- a/src/binder/statement/mod.rs +++ b/src/binder/statement/mod.rs @@ -10,7 +10,7 @@ pub enum BoundStatement { Select(BoundSelect), } -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub struct BoundSelect { pub select_list: Vec, pub from_table: Option, diff --git a/src/binder/table/mod.rs b/src/binder/table/mod.rs index 888e274..266b7a4 100644 --- a/src/binder/table/mod.rs +++ b/src/binder/table/mod.rs @@ -3,7 +3,7 @@ mod join; pub use join::*; use sqlparser::ast::{TableFactor, TableWithJoins}; -use super::{BindError, Binder}; +use super::{BindError, Binder, BoundSelect}; use crate::catalog::{ColumnCatalog, ColumnId, TableCatalog, TableId}; pub static DEFAULT_DATABASE_NAME: &str = "postgres"; @@ -13,6 +13,7 @@ pub static DEFAULT_SCHEMA_NAME: &str = "postgres"; pub enum BoundTableRef { Table(TableCatalog), Join(Join), + Subquery(Box), } impl BoundTableRef { @@ -22,6 +23,7 @@ impl BoundTableRef { BoundTableRef::Join(join) => { TableSchema::new_from_join(&join.left.schema(), &join.right.schema()) } + BoundTableRef::Subquery(subquery) => subquery.from_table.clone().unwrap().schema(), } } } @@ -118,7 +120,15 @@ impl Binder { Ok(BoundTableRef::Table(table_catalog)) } - _ => panic!("unsupported table factor"), + TableFactor::Derived { + lateral: _, + subquery, + alias: _, + } => { + let table = self.bind_select(subquery)?; + Ok(BoundTableRef::Subquery(Box::new(table))) + } + _other => panic!("unsupported table factor: {:?}", _other), } } } diff --git a/src/db.rs b/src/db.rs index ed8c38b..4cb1e21 100644 --- a/src/db.rs +++ b/src/db.rs @@ -8,10 +8,11 @@ use sqlparser::parser::ParserError; use crate::binder::{BindError, Binder}; use crate::executor::{try_collect, ExecutorBuilder, ExecutorError}; use crate::optimizer::{ - CollapseProject, EliminateLimits, HepBatch, HepBatchStrategy, HepOptimizer, InputRefRewriter, - LimitProjectTranspose, PhysicalRewriteRule, PlanRef, PlanRewriter, PushLimitIntoTableScan, - PushLimitThroughJoin, PushPredicateThroughJoin, PushPredicateThroughNonJoin, - PushProjectIntoTableScan, PushProjectThroughChild, RemoveNoopOperators, SimplifyCasts, + CollapseProject, CombineFilter, EliminateLimits, HepBatch, HepBatchStrategy, HepOptimizer, + InputRefRewriter, LimitProjectTranspose, PhysicalRewriteRule, PlanRef, PlanRewriter, + PushLimitIntoTableScan, PushLimitThroughJoin, PushPredicateThroughJoin, + PushPredicateThroughNonJoin, PushProjectIntoTableScan, PushProjectThroughChild, + RemoveNoopOperators, SimplifyCasts, }; use crate::parser::parse; use crate::planner::{LogicalPlanError, Planner}; @@ -86,7 +87,7 @@ impl Database { HepBatch::new( "Combine operators".to_string(), HepBatchStrategy::fix_point_topdown(10), - vec![CollapseProject::create()], + vec![CollapseProject::create(), CombineFilter::create()], ), HepBatch::new( "One-time simplification".to_string(), diff --git a/src/optimizer/rules/combine_operators.rs b/src/optimizer/rules/combine_operators.rs index 9330f85..a8eb194 100644 --- a/src/optimizer/rules/combine_operators.rs +++ b/src/optimizer/rules/combine_operators.rs @@ -1,7 +1,11 @@ -use super::util::is_subset_exprs; +use std::sync::Arc; + +use super::util::{is_subset_exprs, reduce_conjunctive_predicate}; use super::RuleImpl; -use crate::optimizer::core::{OptExpr, Pattern, PatternChildrenPredicate, Rule, Substitute}; -use crate::optimizer::PlanNodeType; +use crate::optimizer::core::{ + OptExpr, OptExprNode, Pattern, PatternChildrenPredicate, Rule, Substitute, +}; +use crate::optimizer::{Dummy, LogicalFilter, PlanNodeType}; lazy_static! { static ref COLLAPSE_PROJECT_RULE: Pattern = { @@ -13,6 +17,15 @@ lazy_static! { }]), } }; + static ref COMBINE_FILTERS: Pattern = { + Pattern { + predicate: |p| p.node_type() == PlanNodeType::LogicalFilter, + children: PatternChildrenPredicate::Predicate(vec![Pattern { + predicate: |p| p.node_type() == PlanNodeType::LogicalFilter, + children: PatternChildrenPredicate::None, + }]), + } + }; } /// Combine two adjacent project operators into one. @@ -53,3 +66,45 @@ impl Rule for CollapseProject { } } } + +/// Combine two adjacent filter operators into one. +#[derive(Clone)] +pub struct CombineFilter; + +impl CombineFilter { + pub fn create() -> RuleImpl { + Self {}.into() + } +} + +impl Rule for CombineFilter { + fn pattern(&self) -> &Pattern { + &COMBINE_FILTERS + } + + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + // TODO: handle column alias + let filter_opt_expr = opt_expr; + let next_filter_opt_expr = filter_opt_expr.children[0].clone(); + + let filter_expr = filter_opt_expr + .root + .get_plan_ref() + .as_logical_filter() + .unwrap() + .expr(); + let next_filter_exprs = next_filter_opt_expr + .root + .get_plan_ref() + .as_logical_filter() + .unwrap() + .expr(); + if let Some(expr) = reduce_conjunctive_predicate([filter_expr, next_filter_exprs].to_vec()) + { + let new_filter_root = + OptExprNode::PlanRef(Arc::new(LogicalFilter::new(expr, Dummy::new_ref()))); + let res = OptExpr::new(new_filter_root, next_filter_opt_expr.children); + result.opt_exprs.push(res); + } + } +} diff --git a/src/optimizer/rules/mod.rs b/src/optimizer/rules/mod.rs index 8b2dd2c..f99d365 100644 --- a/src/optimizer/rules/mod.rs +++ b/src/optimizer/rules/mod.rs @@ -35,6 +35,7 @@ pub enum RuleImpl { RemoveNoopOperators, // Combine operators CollapseProject, + CombineFilter, // Simplification SimplifyCasts, // Rewrite physical plan diff --git a/src/optimizer/rules/util.rs b/src/optimizer/rules/util.rs index 65ca723..ab4d116 100644 --- a/src/optimizer/rules/util.rs +++ b/src/optimizer/rules/util.rs @@ -1,4 +1,7 @@ -use crate::binder::BoundExpr; +use arrow::datatypes::DataType; +use sqlparser::ast::BinaryOperator; + +use crate::binder::{BoundBinaryOp, BoundExpr}; use crate::catalog::ColumnCatalog; /// Return true when left is subset of right, only compare table_id and column_id, so it's safe to @@ -11,3 +14,15 @@ pub fn is_subset_cols(left: &[ColumnCatalog], right: &[ColumnCatalog]) -> bool { pub fn is_subset_exprs(left: &[BoundExpr], right: &[BoundExpr]) -> bool { left.iter().all(|l| right.contains(l)) } + +/// Reduce multi predicates into a conjunctive predicate by AND +pub fn reduce_conjunctive_predicate(exprs: Vec) -> Option { + exprs.into_iter().reduce(|a, b| { + BoundExpr::BinaryOp(BoundBinaryOp { + op: BinaryOperator::And, + left: Box::new(a), + right: Box::new(b), + return_type: Some(DataType::Boolean), + }) + }) +} diff --git a/src/planner/select.rs b/src/planner/select.rs index 0d84657..61f5993 100644 --- a/src/planner/select.rs +++ b/src/planner/select.rs @@ -70,6 +70,10 @@ impl Planner { ); Ok(Arc::new(join)) } + BoundTableRef::Subquery(subquery) => { + let subquery = subquery.clone(); + self.plan_select(*subquery) + } } } } diff --git a/tests/planner/combine-operators.planner.sql b/tests/planner/combine-operators.planner.sql new file mode 100644 index 0000000..edf3c0d --- /dev/null +++ b/tests/planner/combine-operators.planner.sql @@ -0,0 +1,20 @@ +-- CollapseProject & CombineFilter: combine adjacent projects and filters into one + +select * from (select * from (select * from t1 where c < 2) where a > 1) where b > 7; + +/* +original plan: +LogicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64] + LogicalFilter: expr t1.b:Int64 > Cast(7 as Int64) + LogicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64] + LogicalFilter: expr t1.a:Int64 > Cast(1 as Int64) + LogicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64] + LogicalFilter: expr t1.c:Int64 < Cast(2 as Int64) + LogicalTableScan: table: #t1, columns: [a, b, c] + +optimized plan: +PhysicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64] + PhysicalFilter: expr t1.b:Int64 > 7 AND t1.a:Int64 > 1 AND t1.c:Int64 < 2 + PhysicalTableScan: table: #t1, columns: [a, b, c] +*/ + diff --git a/tests/planner/combine-operators.yml b/tests/planner/combine-operators.yml new file mode 100644 index 0000000..7424bfb --- /dev/null +++ b/tests/planner/combine-operators.yml @@ -0,0 +1,4 @@ +- sql: | + select * from (select * from (select * from t1 where c < 2) where a > 1) where b > 7; + desc: | + CollapseProject & CombineFilter: combine adjacent projects and filters into one diff --git a/tests/slt/subquery.slt b/tests/slt/subquery.slt new file mode 100644 index 0000000..fe7af17 --- /dev/null +++ b/tests/slt/subquery.slt @@ -0,0 +1,14 @@ +query III +select * from (select * from t1 where a > 1) where b > 7; +---- +2 8 1 + +query II +select b from (select a, b from t1 where a > 1) where b > 7; +---- +8 + +query III +select * from (select * from (select * from t1 where c < 2) where a > 1) where b > 7; +---- +2 8 1