diff --git a/src/db.rs b/src/db.rs index aa9f17a..4fb5ae5 100644 --- a/src/db.rs +++ b/src/db.rs @@ -15,7 +15,7 @@ use crate::optimizer::{ RemoveNoopOperators, SimplifyCasts, }; use crate::parser::parse; -use crate::planner::{LogicalPlanError, Planner}; +use crate::planner::{LogicalPlanError, Planner, PlannerContext}; use crate::storage::{CsvStorage, Storage, StorageError, StorageImpl}; use crate::util::pretty_plan_tree_string; @@ -54,7 +54,7 @@ impl Database { Ok(data) } - fn default_optimizer(&self, root: PlanRef) -> HepOptimizer { + fn default_optimizer(&self, root: PlanRef, planner_context: PlannerContext) -> HepOptimizer { // the order of rules is important and affects the rule matching logic let batches = vec![ HepBatch::new( @@ -101,7 +101,7 @@ impl Database { ), ]; - HepOptimizer::new(batches, root) + HepOptimizer::new(batches, root, planner_context) } pub async fn run(&self, sql: &str) -> Result, DatabaseError> { @@ -123,7 +123,7 @@ impl Database { println!("bound_stmt:\n{:#?}\n", bound_stmt); // 3. convert bound stmts to logical plan - let planner = Planner {}; + let mut planner = Planner::default(); let logical_plan = planner.plan(bound_stmt)?; println!( "original_plan:\n{}\n", @@ -131,7 +131,7 @@ impl Database { ); // 4. optimize logical plan to physical plan - let mut optimizer = self.default_optimizer(logical_plan); + let mut optimizer = self.default_optimizer(logical_plan, planner.context); let physical_plan = optimizer.find_best(); println!( "optimized_plan:\n{}\n", @@ -165,7 +165,7 @@ impl Database { let bound_stmt = binder.bind(&stats[0])?; let mut explain_str = String::new(); - let planner = Planner {}; + let mut planner = Planner::default(); let logical_plan = planner.plan(bound_stmt)?; _ = write!( explain_str, @@ -173,7 +173,7 @@ impl Database { pretty_plan_tree_string(&*logical_plan) ); - let mut optimizer = self.default_optimizer(logical_plan); + let mut optimizer = self.default_optimizer(logical_plan, planner.context); let physical_plan = optimizer.find_best(); _ = write!( explain_str, diff --git a/src/executor/mod.rs b/src/executor/mod.rs index e01c762..a0971bc 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -26,8 +26,8 @@ use self::project::ProjectExecutor; use self::table_scan::TableScanExecutor; use crate::optimizer::{ PhysicalCrossJoin, PhysicalFilter, PhysicalHashAgg, PhysicalHashJoin, PhysicalLimit, - PhysicalOrder, PhysicalProject, PhysicalSimpleAgg, PhysicalTableScan, PlanNode, PlanRef, - PlanTreeNode, PlanVisitor, + PhysicalOrder, PhysicalProject, PhysicalSimpleAgg, PhysicalTableScan, PlanRef, PlanTreeNode, + PlanVisitor, }; use crate::storage::{StorageError, StorageImpl}; @@ -107,7 +107,7 @@ impl PlanVisitor for ExecutorBuilder { right_child: self.visit(plan.right()).unwrap(), join_type: plan.join_type(), join_condition: plan.join_condition(), - join_output_schema: plan.output_columns(), + join_output_schema: plan.join_output_columns(), } .execute(), ) @@ -118,7 +118,7 @@ impl PlanVisitor for ExecutorBuilder { CrossJoinExecutor { left_child: self.visit(plan.left()).unwrap(), right_child: self.visit(plan.right()).unwrap(), - join_output_schema: plan.output_columns(), + join_output_schema: plan.join_output_columns(), } .execute(), ) @@ -251,7 +251,7 @@ mod executor_test { println!("bound_stmt = {:#?}", bound_stmt); // convert bound stmts to logical plan - let planner = Planner {}; + let mut planner = Planner::default(); let logical_plan = planner.plan(bound_stmt)?; println!("logical_plan = {:#?}", logical_plan); let mut input_ref_rewriter = InputRefRewriter::default(); diff --git a/src/optimizer/core/rule.rs b/src/optimizer/core/rule.rs index a0654fa..e5feb89 100644 --- a/src/optimizer/core/rule.rs +++ b/src/optimizer/core/rule.rs @@ -1,6 +1,7 @@ use enum_dispatch::enum_dispatch; use super::{OptExpr, Pattern}; +use crate::planner::PlannerContext; /// A rule is to transform logically equivalent expression. There are two kinds of rules: /// @@ -13,7 +14,7 @@ pub trait Rule { /// Apply the rule and write the transformation result to `Substitute`. /// The pattern tree determines the opt_expr tree internal nodes type. - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute); + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, planner_context: &PlannerContext); } /// Define the transformed plans diff --git a/src/optimizer/heuristic/optimizer.rs b/src/optimizer/heuristic/optimizer.rs index dbd4e57..19c6011 100644 --- a/src/optimizer/heuristic/optimizer.rs +++ b/src/optimizer/heuristic/optimizer.rs @@ -4,17 +4,23 @@ use super::matcher::HepMatcher; use crate::optimizer::core::{PatternMatcher, Rule, Substitute}; use crate::optimizer::rules::RuleImpl; use crate::optimizer::PlanRef; +use crate::planner::PlannerContext; use crate::util::pretty_plan_tree_string; pub struct HepOptimizer { batches: Vec, graph: HepGraph, + planner_context: PlannerContext, } impl HepOptimizer { - pub fn new(batches: Vec, root: PlanRef) -> Self { + pub fn new(batches: Vec, root: PlanRef, planner_context: PlannerContext) -> Self { let graph = HepGraph::new(root); - Self { batches, graph } + Self { + batches, + graph, + planner_context, + } } pub fn find_best(&mut self) -> PlanRef { @@ -97,7 +103,7 @@ impl HepOptimizer { if let Some(opt_expr) = matcher.match_opt_expr() { let mut substitute = Substitute::default(); let opt_expr_root = opt_expr.root.clone(); - rule.apply(opt_expr, &mut substitute); + rule.apply(opt_expr, &mut substitute, &self.planner_context); if !substitute.opt_exprs.is_empty() { assert!(substitute.opt_exprs.len() == 1); @@ -172,7 +178,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![PhysicalRewriteRule::create()], ); - let mut planner = HepOptimizer::new(vec![batch], root); + let mut planner = HepOptimizer::new(vec![batch], root, Default::default()); let new_plan = planner.find_best(); assert_eq!( new_plan.as_physical_project().unwrap().logical().exprs()[0], diff --git a/src/optimizer/plan_node/dummy.rs b/src/optimizer/plan_node/dummy.rs index 48faa61..cc46925 100644 --- a/src/optimizer/plan_node/dummy.rs +++ b/src/optimizer/plan_node/dummy.rs @@ -22,9 +22,13 @@ impl PlanNode for Dummy { vec![] } - fn output_columns(&self) -> Vec { + fn output_columns(&self, _base_table_id: String) -> Vec { vec![] } + + fn get_based_table_id(&self) -> crate::catalog::TableId { + "Dummy".to_string() + } } impl PlanTreeNode for Dummy { diff --git a/src/optimizer/plan_node/logical_agg.rs b/src/optimizer/plan_node/logical_agg.rs index cfd4d11..e9ce1b5 100644 --- a/src/optimizer/plan_node/logical_agg.rs +++ b/src/optimizer/plan_node/logical_agg.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{PlanNode, PlanRef, PlanTreeNode}; use crate::binder::BoundExpr; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalAgg { @@ -36,16 +36,24 @@ impl LogicalAgg { impl PlanNode for LogicalAgg { fn referenced_columns(&self) -> Vec { - self.output_columns() + self.group_by + .iter() + .chain(self.agg_funcs.iter()) + .flat_map(|e| e.get_referenced_column_catalog()) + .collect::>() } - fn output_columns(&self) -> Vec { + fn output_columns(&self, base_table_id: String) -> Vec { self.group_by .iter() .chain(self.agg_funcs.iter()) - .flat_map(|e| e.get_referenced_column_catalog()) + .map(|e| e.output_column_catalog_for_alias_table(base_table_id.clone())) .collect::>() } + + fn get_based_table_id(&self) -> TableId { + self.children()[0].get_based_table_id() + } } impl PlanTreeNode for LogicalAgg { diff --git a/src/optimizer/plan_node/logical_filter.rs b/src/optimizer/plan_node/logical_filter.rs index 8b9604b..1d2cbf5 100644 --- a/src/optimizer/plan_node/logical_filter.rs +++ b/src/optimizer/plan_node/logical_filter.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{PlanNode, PlanRef, PlanTreeNode}; use crate::binder::BoundExpr; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalFilter { @@ -32,8 +32,12 @@ impl PlanNode for LogicalFilter { self.expr.get_referenced_column_catalog() } - fn output_columns(&self) -> Vec { - self.children()[0].output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.children()[0].output_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.children()[0].get_based_table_id() } } diff --git a/src/optimizer/plan_node/logical_join.rs b/src/optimizer/plan_node/logical_join.rs index 340df5c..b8c0b85 100644 --- a/src/optimizer/plan_node/logical_join.rs +++ b/src/optimizer/plan_node/logical_join.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{PlanNode, PlanRef, PlanTreeNode}; use crate::binder::{JoinCondition, JoinType}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalJoin { @@ -30,7 +30,8 @@ impl LogicalJoin { join_condition, join_output_columns: vec![], }; - join.join_output_columns = join.join_output_columns_internal(); + let base_table_id = join.get_based_table_id(); + join.join_output_columns = join.join_output_columns_internal(base_table_id); join } @@ -79,7 +80,7 @@ impl LogicalJoin { /// /// So in the left child schema, b's fields is nullable, therefore we should use left join /// schema directly, rather than set b's fields as non-nullable. - fn join_output_columns_internal(&self) -> Vec { + fn join_output_columns_internal(&self, base_table_id: String) -> Vec { let (left_join_keys_force_nullable, right_join_keys_force_nullable) = match self.join_type { JoinType::Inner => (false, false), JoinType::Left => (false, true), @@ -89,7 +90,7 @@ impl LogicalJoin { }; let left_fields = self .left - .output_columns() + .output_columns(base_table_id.clone()) .iter() .map(|c| { c.clone_with_nullable( @@ -101,7 +102,7 @@ impl LogicalJoin { .collect::>(); let right_fields = self .right - .output_columns() + .output_columns(base_table_id) .iter() .map(|c| { c.clone_with_nullable( @@ -139,8 +140,12 @@ impl PlanNode for LogicalJoin { } } - fn output_columns(&self) -> Vec { - self.join_output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.join_output_columns_internal(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.children()[0].get_based_table_id() } } @@ -204,8 +209,9 @@ mod tests { let cond = build_join_condition_eq("t1", "b1", "t2", "b1"); let plan = LogicalJoin::new(t1.clone(), t2.clone(), JoinType::Inner, cond.clone()); + let based_table_id = plan.get_based_table_id(); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), build_columns_catalog("t2", vec!["a2", "b1", "c2"], false), @@ -215,7 +221,7 @@ mod tests { let plan = LogicalJoin::new(t1.clone(), t2.clone(), JoinType::Left, cond.clone()); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), build_columns_catalog("t2", vec!["a2", "b1", "c2"], true), @@ -225,7 +231,7 @@ mod tests { let plan = LogicalJoin::new(t1.clone(), t2.clone(), JoinType::Right, cond.clone()); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), build_columns_catalog("t2", vec!["a2", "b1", "c2"], false), @@ -235,7 +241,7 @@ mod tests { let plan = LogicalJoin::new(t1, t2, JoinType::Full, cond); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id), vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), build_columns_catalog("t2", vec!["a2", "b1", "c2"], true), @@ -282,8 +288,9 @@ mod tests { JoinType::Inner, cond2.clone(), ); + let based_table_id = plan.get_based_table_id(); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), @@ -307,7 +314,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), @@ -331,7 +338,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -355,7 +362,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -380,7 +387,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), @@ -404,7 +411,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), @@ -428,7 +435,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -452,7 +459,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -477,7 +484,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -501,7 +508,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -525,7 +532,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -544,7 +551,7 @@ mod tests { cond2, ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), diff --git a/src/optimizer/plan_node/logical_limit.rs b/src/optimizer/plan_node/logical_limit.rs index 7121f82..46bfd3d 100644 --- a/src/optimizer/plan_node/logical_limit.rs +++ b/src/optimizer/plan_node/logical_limit.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{PlanNode, PlanRef, PlanTreeNode}; use crate::binder::BoundExpr; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalLimit { @@ -39,8 +39,12 @@ impl PlanNode for LogicalLimit { vec![] } - fn output_columns(&self) -> Vec { - self.children()[0].output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.children()[0].output_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.children()[0].get_based_table_id() } } diff --git a/src/optimizer/plan_node/logical_order.rs b/src/optimizer/plan_node/logical_order.rs index f11f8cc..b9555d7 100644 --- a/src/optimizer/plan_node/logical_order.rs +++ b/src/optimizer/plan_node/logical_order.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{PlanNode, PlanRef, PlanTreeNode}; use crate::binder::BoundOrderBy; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalOrder { @@ -33,8 +33,12 @@ impl PlanNode for LogicalOrder { .collect::>() } - fn output_columns(&self) -> Vec { - self.children()[0].output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.children()[0].output_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.children()[0].get_based_table_id() } } diff --git a/src/optimizer/plan_node/logical_project.rs b/src/optimizer/plan_node/logical_project.rs index d53d393..3a56edc 100644 --- a/src/optimizer/plan_node/logical_project.rs +++ b/src/optimizer/plan_node/logical_project.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{PlanNode, PlanRef, PlanTreeNode}; use crate::binder::BoundExpr; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalProject { @@ -29,15 +29,22 @@ impl LogicalProject { impl PlanNode for LogicalProject { fn referenced_columns(&self) -> Vec { - self.output_columns() + self.exprs + .iter() + .flat_map(|e| e.get_referenced_column_catalog()) + .collect::>() } - fn output_columns(&self) -> Vec { + fn output_columns(&self, base_table_id: String) -> Vec { self.exprs .iter() - .flat_map(|e| e.get_referenced_column_catalog()) + .map(|e| e.output_column_catalog_for_alias_table(base_table_id.clone())) .collect::>() } + + fn get_based_table_id(&self) -> TableId { + self.children()[0].get_based_table_id() + } } impl PlanTreeNode for LogicalProject { diff --git a/src/optimizer/plan_node/logical_table_scan.rs b/src/optimizer/plan_node/logical_table_scan.rs index 88efe95..6335d36 100644 --- a/src/optimizer/plan_node/logical_table_scan.rs +++ b/src/optimizer/plan_node/logical_table_scan.rs @@ -59,11 +59,22 @@ impl LogicalTableScan { impl PlanNode for LogicalTableScan { fn referenced_columns(&self) -> Vec { - self.output_columns() + self.columns() } - fn output_columns(&self) -> Vec { - self.columns() + fn output_columns(&self, _: String) -> Vec { + if let Some(alias) = self.table_alias() { + self.columns() + .iter() + .map(|c| c.clone_with_table_id(alias.clone())) + .collect() + } else { + self.columns() + } + } + + fn get_based_table_id(&self) -> TableId { + self.table_id.clone() } } diff --git a/src/optimizer/plan_node/mod.rs b/src/optimizer/plan_node/mod.rs index 877365f..c6bda56 100644 --- a/src/optimizer/plan_node/mod.rs +++ b/src/optimizer/plan_node/mod.rs @@ -41,7 +41,7 @@ pub use physical_simple_agg::*; pub use physical_table_scan::*; pub use plan_node_traits::*; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; /// The common trait over all plan nodes. Used by optimizer framework which will treat all node as /// `dyn PlanNode`. Meanwhile, we split the trait into lots of sub-traits so that we can easily use @@ -49,11 +49,15 @@ use crate::catalog::ColumnCatalog; pub trait PlanNode: WithPlanNodeType + PlanTreeNode + Downcast + Debug + Display + Send + Sync { - /// All columns that appears in BoundExprs from this plan node. + /// Return column catalog that appears in BoundExprs which used in current PlanNode. fn referenced_columns(&self) -> Vec; - /// All columns that appears in output RecordBatch from this plan node. - fn output_columns(&self) -> Vec; + /// Return output column catalog which converted from `BoundExpr`. + fn output_columns(&self, base_table_id: String) -> Vec; + + // Get this PlanNode based TableId which could be TableScan Id or Join left child based table + // id. + fn get_based_table_id(&self) -> TableId; } impl_downcast!(PlanNode); diff --git a/src/optimizer/plan_node/physical_cross_join.rs b/src/optimizer/plan_node/physical_cross_join.rs index e02479f..137abff 100644 --- a/src/optimizer/plan_node/physical_cross_join.rs +++ b/src/optimizer/plan_node/physical_cross_join.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{LogicalJoin, PlanNode, PlanRef, PlanTreeNode}; use crate::binder::JoinType; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalCrossJoin { @@ -30,6 +30,10 @@ impl PhysicalCrossJoin { pub fn logical(&self) -> &LogicalJoin { &self.logical } + + pub fn join_output_columns(&self) -> Vec { + self.logical.join_output_columns() + } } impl PlanNode for PhysicalCrossJoin { @@ -37,8 +41,12 @@ impl PlanNode for PhysicalCrossJoin { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() } } diff --git a/src/optimizer/plan_node/physical_filter.rs b/src/optimizer/plan_node/physical_filter.rs index 38a57fd..bb27b81 100644 --- a/src/optimizer/plan_node/physical_filter.rs +++ b/src/optimizer/plan_node/physical_filter.rs @@ -2,7 +2,7 @@ use core::fmt; use std::sync::Arc; use super::{LogicalFilter, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalFilter { @@ -24,8 +24,12 @@ impl PlanNode for PhysicalFilter { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() } } diff --git a/src/optimizer/plan_node/physical_hash_agg.rs b/src/optimizer/plan_node/physical_hash_agg.rs index 188ef5f..977be58 100644 --- a/src/optimizer/plan_node/physical_hash_agg.rs +++ b/src/optimizer/plan_node/physical_hash_agg.rs @@ -2,7 +2,7 @@ use std::fmt; use std::sync::Arc; use super::{LogicalAgg, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalHashAgg { @@ -24,8 +24,12 @@ impl PlanNode for PhysicalHashAgg { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() } } diff --git a/src/optimizer/plan_node/physical_hash_join.rs b/src/optimizer/plan_node/physical_hash_join.rs index 54141a6..fe60023 100644 --- a/src/optimizer/plan_node/physical_hash_join.rs +++ b/src/optimizer/plan_node/physical_hash_join.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{LogicalJoin, PlanNode, PlanRef, PlanTreeNode}; use crate::binder::{JoinCondition, JoinType}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalHashJoin { @@ -34,6 +34,10 @@ impl PhysicalHashJoin { pub fn logical(&self) -> &LogicalJoin { &self.logical } + + pub fn join_output_columns(&self) -> Vec { + self.logical.join_output_columns() + } } impl PlanNode for PhysicalHashJoin { @@ -41,8 +45,12 @@ impl PlanNode for PhysicalHashJoin { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() } } diff --git a/src/optimizer/plan_node/physical_limit.rs b/src/optimizer/plan_node/physical_limit.rs index 83e99b0..5f285d6 100644 --- a/src/optimizer/plan_node/physical_limit.rs +++ b/src/optimizer/plan_node/physical_limit.rs @@ -2,7 +2,7 @@ use core::fmt; use std::sync::Arc; use super::{LogicalLimit, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalLimit { @@ -24,8 +24,12 @@ impl PlanNode for PhysicalLimit { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() } } diff --git a/src/optimizer/plan_node/physical_order.rs b/src/optimizer/plan_node/physical_order.rs index fb13256..e3298c4 100644 --- a/src/optimizer/plan_node/physical_order.rs +++ b/src/optimizer/plan_node/physical_order.rs @@ -2,7 +2,7 @@ use core::fmt; use std::sync::Arc; use super::{LogicalOrder, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalOrder { @@ -24,8 +24,12 @@ impl PlanNode for PhysicalOrder { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() } } diff --git a/src/optimizer/plan_node/physical_project.rs b/src/optimizer/plan_node/physical_project.rs index ea5c92e..aa3e73b 100644 --- a/src/optimizer/plan_node/physical_project.rs +++ b/src/optimizer/plan_node/physical_project.rs @@ -2,7 +2,7 @@ use std::fmt; use std::sync::Arc; use super::{LogicalProject, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalProject { @@ -24,8 +24,12 @@ impl PlanNode for PhysicalProject { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() } } diff --git a/src/optimizer/plan_node/physical_simple_agg.rs b/src/optimizer/plan_node/physical_simple_agg.rs index bf78a21..6b521d4 100644 --- a/src/optimizer/plan_node/physical_simple_agg.rs +++ b/src/optimizer/plan_node/physical_simple_agg.rs @@ -2,7 +2,7 @@ use std::fmt; use std::sync::Arc; use super::{LogicalAgg, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalSimpleAgg { @@ -24,8 +24,12 @@ impl PlanNode for PhysicalSimpleAgg { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() } } diff --git a/src/optimizer/plan_node/physical_table_scan.rs b/src/optimizer/plan_node/physical_table_scan.rs index 5521a62..1633380 100644 --- a/src/optimizer/plan_node/physical_table_scan.rs +++ b/src/optimizer/plan_node/physical_table_scan.rs @@ -2,7 +2,7 @@ use std::fmt; use std::sync::Arc; use super::{LogicalTableScan, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalTableScan { @@ -24,8 +24,12 @@ impl PlanNode for PhysicalTableScan { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() } } diff --git a/src/optimizer/rules/column_pruning.rs b/src/optimizer/rules/column_pruning.rs index d8388c0..c761548 100644 --- a/src/optimizer/rules/column_pruning.rs +++ b/src/optimizer/rules/column_pruning.rs @@ -7,8 +7,9 @@ use crate::binder::{BoundColumnRef, BoundExpr}; use crate::optimizer::core::{ OptExpr, OptExprNode, Pattern, PatternChildrenPredicate, Rule, Substitute, }; -use crate::optimizer::rules::util::is_subset_cols; +use crate::optimizer::rules::util::is_superset_cols; use crate::optimizer::{Dummy, LogicalProject, LogicalTableScan, PlanNodeType}; +use crate::planner::PlannerContext; lazy_static! { static ref PUSH_PROJECT_INTO_TABLE_SCAN_RULE: Pattern = { @@ -57,7 +58,7 @@ impl Rule for PushProjectIntoTableScan { &PUSH_PROJECT_INTO_TABLE_SCAN_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let project_opt_expr_root = opt_expr.root; let table_scan_opt_expr = opt_expr.children[0].clone(); let project_node = project_opt_expr_root @@ -122,7 +123,7 @@ impl Rule for PushProjectThroughChild { &PUSH_PROJECT_THROUGH_CHILD_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, planner_context: &PlannerContext) { let project_opt_expr_root = opt_expr.root; let child_opt_expr = opt_expr.children[0].clone(); @@ -130,50 +131,70 @@ impl Rule for PushProjectThroughChild { let project_cols = project_plan_ref.referenced_columns(); let child_plan_ref = child_opt_expr.root.get_plan_ref(); let child_cols = child_plan_ref.referenced_columns(); - let required_cols = [project_cols, child_cols].concat(); - - let child_children_columns = child_plan_ref + let mut required_cols = [project_cols, child_cols].concat(); + let mut child_children_cols = child_plan_ref .children() .iter() - .flat_map(|c| c.output_columns()) + .flat_map(|c| { + c.output_columns( + planner_context + .find_subquery_alias(c) + .unwrap_or_else(|| c.get_based_table_id()), + ) + }) .collect::>(); - // if child_children_columns more than required_cols, pushdown extra projection. - if !is_subset_cols(&child_children_columns, &required_cols) { + // distinct cols + required_cols = required_cols.into_iter().unique().collect(); + child_children_cols = child_children_cols.into_iter().unique().collect(); + + // println!("required_cols: {:?}", required_cols); + // println!("child_children_cols: {:?}", child_children_cols); + + // if child_children_cols more than required_cols, pushdown extra projection. + if is_superset_cols(&child_children_cols, &required_cols) { let new_child_opt_expr_children = child_plan_ref .children() .iter() .zip_eq(child_opt_expr.children.iter()) .map(|(child_child_plan, child_child_opt_expr)| { + // Note: resolve base_table_id to calc real ColumnCatalog for subquery + // such as: select a, t2.v1 as max_b from t1 cross join (select max(b) as v1 + // from t1) t2; + // `t2.v1` should be resolved in child_child_plan output_columns. + let base_table_id = planner_context + .find_subquery_alias(child_child_plan) + .unwrap_or_else(|| child_child_plan.get_based_table_id()); + let mut child_child_output_cols = + child_child_plan.output_columns(base_table_id); + // for child's child, filter corresponding required columns + let mut required_cols_in_child_child = child_child_output_cols + .clone() + .into_iter() + .filter(|c| required_cols.contains(c)) + .collect::>(); + + // distinct cols + child_child_output_cols = + child_child_output_cols.into_iter().unique().collect(); + required_cols_in_child_child = + required_cols_in_child_child.into_iter().unique().collect(); + // println!("child_child_output_cols: {:?}", child_child_output_cols); + // println!( + // "required_cols_in_child_child: {:?}", + // required_cols_in_child_child + // ); + // if child's child cols more than required_cols, pushdown extra projection. - if !is_subset_cols(&child_child_plan.output_columns(), &required_cols) { - // for child's child, filter corresponding required columns - let exprs = child_child_plan - .output_columns() - .iter() - .filter(|c| required_cols.contains(c)) - .map(|c| { - BoundExpr::ColumnRef(BoundColumnRef { - column_catalog: c.clone(), - }) - }) - .collect::>(); - - // FIXME: resolve alias corresponding real ColumnCatalog - // such as: select a, t2.v1 as max_b from t1 cross join (select max(b) as v1 - // from t1) t2; - // t2.v1 should be resolved to t1.b which means this exprs only use t1.b - // column. - // assert!(exprs.is_empty(), "pruned project exprs should not be empty"); - - if exprs.is_empty() { - child_child_opt_expr.clone() - } else { - let new_project = LogicalProject::new(exprs, Dummy::new_ref()); - OptExpr { - root: OptExprNode::PlanRef(Arc::new(new_project)), - children: vec![child_child_opt_expr.clone()], - } + if is_superset_cols(&child_child_output_cols, &required_cols_in_child_child) { + let exprs = required_cols_in_child_child + .into_iter() + .map(|c| BoundExpr::ColumnRef(BoundColumnRef { column_catalog: c })) + .collect(); + let new_project = LogicalProject::new(exprs, Dummy::new_ref()); + OptExpr { + root: OptExprNode::PlanRef(Arc::new(new_project)), + children: vec![child_child_opt_expr.clone()], } } else { child_child_opt_expr.clone() @@ -212,7 +233,7 @@ impl Rule for RemoveNoopOperators { &REMOVE_NOOP_OPERATORS_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { // eliminate no-op project for those children type: project{input: project/aggregate} let project_opt_expr_root = opt_expr.root; let project_plan_ref = project_opt_expr_root.get_plan_ref(); @@ -284,7 +305,7 @@ LogicalProject: exprs [t1.a:Nullable(Int32) + 1] HepBatchStrategy::fix_point_topdown(100), vec![PushProjectIntoTableScan::create()], ); - let mut optimizer = HepOptimizer::new(vec![batch], logical_plan); + let mut optimizer = HepOptimizer::new(vec![batch], logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); @@ -349,7 +370,8 @@ LogicalProject: exprs [employee.id:Nullable(Int32), employee.first_name:Nullable HepBatchStrategy::fix_point_topdown(100), vec![RemoveNoopOperators::create()], ); - let mut optimizer = HepOptimizer::new(vec![batch, final_batch], logical_plan); + let mut optimizer = + HepOptimizer::new(vec![batch, final_batch], logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); diff --git a/src/optimizer/rules/combine_operators.rs b/src/optimizer/rules/combine_operators.rs index a8eb194..6bf0a58 100644 --- a/src/optimizer/rules/combine_operators.rs +++ b/src/optimizer/rules/combine_operators.rs @@ -6,6 +6,7 @@ use crate::optimizer::core::{ OptExpr, OptExprNode, Pattern, PatternChildrenPredicate, Rule, Substitute, }; use crate::optimizer::{Dummy, LogicalFilter, PlanNodeType}; +use crate::planner::PlannerContext; lazy_static! { static ref COLLAPSE_PROJECT_RULE: Pattern = { @@ -43,7 +44,7 @@ impl Rule for CollapseProject { &COLLAPSE_PROJECT_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { // TODO: handle column alias let project_opt_expr = opt_expr; let next_project_opt_expr = project_opt_expr.children[0].clone(); @@ -82,7 +83,7 @@ impl Rule for CombineFilter { &COMBINE_FILTERS } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { // TODO: handle column alias let filter_opt_expr = opt_expr; let next_filter_opt_expr = filter_opt_expr.children[0].clone(); diff --git a/src/optimizer/rules/mod.rs b/src/optimizer/rules/mod.rs index f99d365..52f5c3c 100644 --- a/src/optimizer/rules/mod.rs +++ b/src/optimizer/rules/mod.rs @@ -17,6 +17,7 @@ pub use simplification::*; use strum_macros::AsRefStr; use crate::optimizer::core::{OptExpr, Pattern, Rule, Substitute}; +use crate::planner::PlannerContext; #[enum_dispatch(Rule)] #[derive(Clone, AsRefStr)] @@ -105,7 +106,7 @@ mod rule_test_util { let mut binder = Binder::new(Arc::new(catalog)); let bound_stmt = binder.bind(&stats[0]).unwrap(); - let planner = Planner {}; + let mut planner = Planner::default(); planner.plan(bound_stmt).unwrap() } } diff --git a/src/optimizer/rules/physical_rewrite.rs b/src/optimizer/rules/physical_rewrite.rs index 6561369..6133cd3 100644 --- a/src/optimizer/rules/physical_rewrite.rs +++ b/src/optimizer/rules/physical_rewrite.rs @@ -1,6 +1,7 @@ use super::RuleImpl; use crate::optimizer::core::*; use crate::optimizer::{PhysicalRewriter, PlanRewriter}; +use crate::planner::PlannerContext; lazy_static! { static ref PATTERN: Pattern = { @@ -25,7 +26,7 @@ impl Rule for PhysicalRewriteRule { &PATTERN } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let mut rewriter = PhysicalRewriter::default(); let plan = opt_expr.to_plan_ref(); let new_plan = rewriter.rewrite(plan); diff --git a/src/optimizer/rules/pushdown_limit.rs b/src/optimizer/rules/pushdown_limit.rs index 077f2e6..fc2ee2a 100644 --- a/src/optimizer/rules/pushdown_limit.rs +++ b/src/optimizer/rules/pushdown_limit.rs @@ -6,6 +6,7 @@ use crate::optimizer::core::{ OptExpr, OptExprNode, Pattern, PatternChildrenPredicate, Rule, Substitute, }; use crate::optimizer::{Dummy, LogicalLimit, LogicalTableScan, PlanNodeType}; +use crate::planner::PlannerContext; lazy_static! { static ref LIMIT_PROJECT_TRANSPOSE_RULE: Pattern = { @@ -61,7 +62,7 @@ impl Rule for LimitProjectTranspose { &LIMIT_PROJECT_TRANSPOSE_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let limit_opt_expr_root = opt_expr.root; let project_opt_expr = opt_expr.children[0].clone(); @@ -90,7 +91,7 @@ impl Rule for EliminateLimits { &ELIMINATE_LIMITS_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let limit_opt_expr_root = opt_expr.root; let next_limit_opt_expr = opt_expr.children[0].clone(); let next_limit_opt_expr_root = next_limit_opt_expr.root; @@ -156,7 +157,7 @@ impl Rule for PushLimitThroughJoin { &PUSH_LIMIT_THROUGH_JOIN_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let limit_opt_expr_root = opt_expr.root; let limit_node = limit_opt_expr_root .get_plan_ref() @@ -237,7 +238,7 @@ impl Rule for PushLimitIntoTableScan { &PUSH_LIMIT_INTO_TABLE_SCAN_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let limit_opt_expr_root = opt_expr.root; let limit_node = limit_opt_expr_root .get_plan_ref() @@ -306,7 +307,7 @@ LogicalProject: exprs [t1.a:Nullable(Int32)] HepBatchStrategy::fix_point_topdown(100), vec![LimitProjectTranspose::create()], ); - let mut optimizer = HepOptimizer::new(vec![batch], logical_plan); + let mut optimizer = HepOptimizer::new(vec![batch], logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); @@ -364,7 +365,7 @@ LogicalProject: exprs [t1.a:Nullable(Int32)] EliminateLimits::create(), ], ); - let mut optimizer = HepOptimizer::new(vec![batch], logical_plan); + let mut optimizer = HepOptimizer::new(vec![batch], logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); @@ -394,7 +395,7 @@ LogicalProject: exprs [t1.a:Nullable(Int32)] PushLimitIntoTableScan::create(), ], ); - let mut optimizer = HepOptimizer::new(vec![batch], logical_plan); + let mut optimizer = HepOptimizer::new(vec![batch], logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); diff --git a/src/optimizer/rules/pushdown_predicates.rs b/src/optimizer/rules/pushdown_predicates.rs index de81882..a294ea2 100644 --- a/src/optimizer/rules/pushdown_predicates.rs +++ b/src/optimizer/rules/pushdown_predicates.rs @@ -12,6 +12,7 @@ use crate::catalog::ColumnCatalog; use crate::optimizer::core::*; use crate::optimizer::expr_rewriter::ExprRewriter; use crate::optimizer::{Dummy, LogicalFilter, LogicalJoin, PlanNodeType}; +use crate::planner::PlannerContext; lazy_static! { static ref PUSH_PREDICATE_THROUGH_JOIN: Pattern = { @@ -105,15 +106,17 @@ impl Rule for PushPredicateThroughJoin { &PUSH_PREDICATE_THROUGH_JOIN } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let join_opt_expr = opt_expr.children[0].clone(); let join_node = join_opt_expr.root.get_plan_ref().as_logical_join().unwrap(); if !self.can_push_through(join_node.join_type()) { return; } - let left_output_cols = join_node.left().output_columns(); - let right_output_cols = join_node.right().output_columns(); + let left = join_node.left(); + let left_output_cols = left.output_columns(left.get_based_table_id()); + let right = join_node.right(); + let right_output_cols = right.output_columns(right.get_based_table_id()); let filter_opt_expr = opt_expr; let join_left_opt_expr = join_opt_expr.children[0].clone(); @@ -203,7 +206,7 @@ impl Rule for PushPredicateThroughNonJoin { &PUSH_PREDICATE_THROUGH_NON_JOIN } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let filter_opt_expr = opt_expr; let child_opt_expr = filter_opt_expr.children[0].clone(); let child_node = child_opt_expr.root.get_plan_ref(); @@ -354,7 +357,7 @@ LogicalProject: exprs [t1.a:Nullable(Int32), t1.b:Nullable(Int32), t1.c:Nullable HepBatchStrategy::fix_point_topdown(100), vec![PushPredicateThroughJoin::create()], ); - let mut optimizer = HepOptimizer::new(vec![batch], logical_plan); + let mut optimizer = HepOptimizer::new(vec![batch], logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); @@ -404,7 +407,7 @@ LogicalProject: exprs [t1.a:Nullable(Int32), t1.b:Nullable(Int32), t1.c:Nullable ), ]; - let mut optimizer = HepOptimizer::new(batches, logical_plan); + let mut optimizer = HepOptimizer::new(batches, logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); diff --git a/src/optimizer/rules/simplification.rs b/src/optimizer/rules/simplification.rs index c1ee58c..450e71d 100644 --- a/src/optimizer/rules/simplification.rs +++ b/src/optimizer/rules/simplification.rs @@ -8,6 +8,7 @@ use crate::optimizer::{ LogicalAgg, LogicalFilter, LogicalJoin, LogicalLimit, LogicalOrder, LogicalProject, PlanRef, PlanRewriter, }; +use crate::planner::PlannerContext; lazy_static! { static ref SIMPLIFY_CASTS_RULE: Pattern = { @@ -32,7 +33,7 @@ impl Rule for SimplifyCasts { &SIMPLIFY_CASTS_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let mut rewriter = SimplifyCastsRewriter::default(); let plan = opt_expr.to_plan_ref(); let new_plan = rewriter.rewrite(plan); diff --git a/src/optimizer/rules/util.rs b/src/optimizer/rules/util.rs index ab4d116..bce56ea 100644 --- a/src/optimizer/rules/util.rs +++ b/src/optimizer/rules/util.rs @@ -6,10 +6,17 @@ use crate::catalog::ColumnCatalog; /// Return true when left is subset of right, only compare table_id and column_id, so it's safe to /// used for join output cols with nullable columns. +/// If left equals right, return true. pub fn is_subset_cols(left: &[ColumnCatalog], right: &[ColumnCatalog]) -> bool { left.iter().all(|l| right.contains(l)) } +/// Return true when left is superset of right. +/// If left equals right, return false. +pub fn is_superset_cols(left: &[ColumnCatalog], right: &[ColumnCatalog]) -> bool { + right.iter().all(|r| left.contains(r)) && left.len() > right.len() +} + /// Return true when left is subset of right pub fn is_subset_exprs(left: &[BoundExpr], right: &[BoundExpr]) -> bool { left.iter().all(|l| right.contains(l)) diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 3ab0b8c..2348b1e 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -1,13 +1,35 @@ mod select; mod util; +use std::collections::HashMap; + use crate::binder::BoundStatement; use crate::optimizer::PlanRef; -pub struct Planner {} +#[derive(Default)] +pub struct Planner { + pub context: PlannerContext, +} + +#[derive(Default, Debug)] +pub struct PlannerContext { + // subquery alias to subquery plan + pub subquery_context: HashMap, +} + +impl PlannerContext { + pub fn find_subquery_alias(&self, plan_ref: &PlanRef) -> Option { + for (alias, p) in &self.subquery_context { + if p == plan_ref { + return Some(alias.clone()); + } + } + None + } +} impl Planner { - pub fn plan(&self, stmt: BoundStatement) -> Result { + pub fn plan(&mut self, stmt: BoundStatement) -> Result { match stmt { BoundStatement::Select(stmt) => self.plan_select(stmt), } @@ -107,12 +129,15 @@ mod planner_test { #[test] fn test_plan_select_works() { let stmt = build_test_select_stmt(); - let p = Planner {}; + let mut p = Planner::default(); let node = p.plan(stmt); assert!(node.is_ok()); let plan_ref = node.unwrap(); assert_eq!(plan_ref.node_type(), PlanNodeType::LogicalLimit); - assert_eq!(plan_ref.output_columns().len(), 1); + assert_eq!( + plan_ref.output_columns(plan_ref.get_based_table_id()).len(), + 1 + ); dbg!(plan_ref); } @@ -123,7 +148,7 @@ mod planner_test { // inner join t2 on t1.c1=t2.c1 // left join t3 on t2.c1=t3.c1 let stmt = build_test_select_stmt_with_multiple_joins(); - let p = Planner {}; + let mut p = Planner::default(); let node = p.plan(stmt); assert!(node.is_ok()); let plan_ref = node.unwrap(); @@ -174,7 +199,7 @@ mod planner_test { #[test] fn test_plan_select_distinct_works() { let stmt = build_test_select_distinct_stmt(); - let p = Planner {}; + let mut p = Planner::default(); let node = p.plan(stmt); assert!(node.is_ok()); let plan_ref = node.unwrap(); diff --git a/src/planner/select.rs b/src/planner/select.rs index 5cd2870..655dc97 100644 --- a/src/planner/select.rs +++ b/src/planner/select.rs @@ -6,7 +6,7 @@ use crate::binder::{BoundSelect, BoundTableRef}; use crate::optimizer::*; impl Planner { - pub fn plan_select(&self, stmt: BoundSelect) -> Result { + pub fn plan_select(&mut self, stmt: BoundSelect) -> Result { let mut plan: PlanRef; if let Some(table_ref) = stmt.from_table { @@ -48,7 +48,7 @@ impl Planner { Ok(plan) } - fn plan_table_ref(&self, table_ref: &BoundTableRef) -> Result { + fn plan_table_ref(&mut self, table_ref: &BoundTableRef) -> Result { match table_ref { BoundTableRef::Table(table) => Ok(Arc::new(LogicalTableScan::new( table.catalog.id.clone(), @@ -73,7 +73,11 @@ impl Planner { } BoundTableRef::Subquery(subquery) => { let subquery = subquery.clone(); - self.plan_select(*subquery.query) + let plan_ref = self.plan_select(*subquery.query)?; + self.context + .subquery_context + .insert(subquery.alias, plan_ref.clone()); + Ok(plan_ref) } } } diff --git a/tests/planner/column-pruning.planner.sql b/tests/planner/column-pruning.planner.sql index 267201e..6c80419 100644 --- a/tests/planner/column-pruning.planner.sql +++ b/tests/planner/column-pruning.planner.sql @@ -94,3 +94,25 @@ PhysicalProject: exprs [employee.id:Int64, employee.first_name:Utf8, department. PhysicalTableScan: table: #state, columns: [state_code, state_name] */ +-- PushProjectThroughChild: column pruning across subquery + +select a, t2.v1 as max_b from t1 cross join (select max(b) as v1 from t1) t2 + +/* +original plan: +LogicalProject: exprs [t1.a:Int64, (t2.v1:Int64) as t1.max_b] + LogicalJoin: type Cross, cond None + LogicalTableScan: table: #t1, columns: [a, b, c] + LogicalProject: exprs [((Max(t1.b:Int64):Int64) as t1.v1) as t2.v1] + LogicalAgg: agg_funcs [Max(t1.b:Int64):Int64] group_by [] + LogicalTableScan: table: #t1, columns: [a, b, c] + +optimized plan: +PhysicalProject: exprs [t1.a:Int64, (t2.v1:Int64) as t1.max_b] + PhysicalCrossJoin: type Cross + PhysicalTableScan: table: #t1, columns: [a] + PhysicalProject: exprs [((Max(t1.b:Int64):Int64) as t1.v1) as t2.v1] + PhysicalSimpleAgg: agg_funcs [Max(t1.b:Int64):Int64] group_by [] + PhysicalTableScan: table: #t1, columns: [b] +*/ + diff --git a/tests/planner/column-pruning.yml b/tests/planner/column-pruning.yml index 0fd8751..a15986f 100644 --- a/tests/planner/column-pruning.yml +++ b/tests/planner/column-pruning.yml @@ -25,3 +25,7 @@ desc: | PushProjectThroughChild: column pruning across multiple join +- sql: | + select a, t2.v1 as max_b from t1 cross join (select max(b) as v1 from t1) t2 + desc: | + PushProjectThroughChild: column pruning across subquery