From 9eb0c6a705becd186899d5e519f09a7dd176552e Mon Sep 17 00:00:00 2001 From: Fedomn Date: Fri, 21 Oct 2022 00:46:17 +0800 Subject: [PATCH 1/4] refactor(optimizer): add PlannerContext for Optimizer to resolve subquery output incorrect columns issue Signed-off-by: Fedomn --- src/db.rs | 14 +++++----- src/executor/mod.rs | 2 +- src/optimizer/core/rule.rs | 3 +- src/optimizer/heuristic/optimizer.rs | 14 +++++++--- src/optimizer/rules/combine_operators.rs | 5 ++-- src/optimizer/rules/mod.rs | 3 +- src/optimizer/rules/physical_rewrite.rs | 3 +- src/optimizer/rules/pushdown_limit.rs | 15 +++++----- src/optimizer/rules/pushdown_predicates.rs | 9 +++--- src/optimizer/rules/simplification.rs | 3 +- src/planner/mod.rs | 32 ++++++++++++++++++---- src/planner/select.rs | 10 +++++-- 12 files changed, 76 insertions(+), 37 deletions(-) diff --git a/src/db.rs b/src/db.rs index aa9f17a..4fb5ae5 100644 --- a/src/db.rs +++ b/src/db.rs @@ -15,7 +15,7 @@ use crate::optimizer::{ RemoveNoopOperators, SimplifyCasts, }; use crate::parser::parse; -use crate::planner::{LogicalPlanError, Planner}; +use crate::planner::{LogicalPlanError, Planner, PlannerContext}; use crate::storage::{CsvStorage, Storage, StorageError, StorageImpl}; use crate::util::pretty_plan_tree_string; @@ -54,7 +54,7 @@ impl Database { Ok(data) } - fn default_optimizer(&self, root: PlanRef) -> HepOptimizer { + fn default_optimizer(&self, root: PlanRef, planner_context: PlannerContext) -> HepOptimizer { // the order of rules is important and affects the rule matching logic let batches = vec![ HepBatch::new( @@ -101,7 +101,7 @@ impl Database { ), ]; - HepOptimizer::new(batches, root) + HepOptimizer::new(batches, root, planner_context) } pub async fn run(&self, sql: &str) -> Result, DatabaseError> { @@ -123,7 +123,7 @@ impl Database { println!("bound_stmt:\n{:#?}\n", bound_stmt); // 3. 
convert bound stmts to logical plan - let planner = Planner {}; + let mut planner = Planner::default(); let logical_plan = planner.plan(bound_stmt)?; println!( "original_plan:\n{}\n", @@ -131,7 +131,7 @@ impl Database { ); // 4. optimize logical plan to physical plan - let mut optimizer = self.default_optimizer(logical_plan); + let mut optimizer = self.default_optimizer(logical_plan, planner.context); let physical_plan = optimizer.find_best(); println!( "optimized_plan:\n{}\n", @@ -165,7 +165,7 @@ impl Database { let bound_stmt = binder.bind(&stats[0])?; let mut explain_str = String::new(); - let planner = Planner {}; + let mut planner = Planner::default(); let logical_plan = planner.plan(bound_stmt)?; _ = write!( explain_str, @@ -173,7 +173,7 @@ impl Database { pretty_plan_tree_string(&*logical_plan) ); - let mut optimizer = self.default_optimizer(logical_plan); + let mut optimizer = self.default_optimizer(logical_plan, planner.context); let physical_plan = optimizer.find_best(); _ = write!( explain_str, diff --git a/src/executor/mod.rs b/src/executor/mod.rs index e01c762..43db615 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -251,7 +251,7 @@ mod executor_test { println!("bound_stmt = {:#?}", bound_stmt); // convert bound stmts to logical plan - let planner = Planner {}; + let mut planner = Planner::default(); let logical_plan = planner.plan(bound_stmt)?; println!("logical_plan = {:#?}", logical_plan); let mut input_ref_rewriter = InputRefRewriter::default(); diff --git a/src/optimizer/core/rule.rs b/src/optimizer/core/rule.rs index a0654fa..e5feb89 100644 --- a/src/optimizer/core/rule.rs +++ b/src/optimizer/core/rule.rs @@ -1,6 +1,7 @@ use enum_dispatch::enum_dispatch; use super::{OptExpr, Pattern}; +use crate::planner::PlannerContext; /// A rule is to transform logically equivalent expression. There are two kinds of rules: /// @@ -13,7 +14,7 @@ pub trait Rule { /// Apply the rule and write the transformation result to `Substitute`. 
/// The pattern tree determines the opt_expr tree internal nodes type. - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute); + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, planner_context: &PlannerContext); } /// Define the transformed plans diff --git a/src/optimizer/heuristic/optimizer.rs b/src/optimizer/heuristic/optimizer.rs index dbd4e57..19c6011 100644 --- a/src/optimizer/heuristic/optimizer.rs +++ b/src/optimizer/heuristic/optimizer.rs @@ -4,17 +4,23 @@ use super::matcher::HepMatcher; use crate::optimizer::core::{PatternMatcher, Rule, Substitute}; use crate::optimizer::rules::RuleImpl; use crate::optimizer::PlanRef; +use crate::planner::PlannerContext; use crate::util::pretty_plan_tree_string; pub struct HepOptimizer { batches: Vec, graph: HepGraph, + planner_context: PlannerContext, } impl HepOptimizer { - pub fn new(batches: Vec, root: PlanRef) -> Self { + pub fn new(batches: Vec, root: PlanRef, planner_context: PlannerContext) -> Self { let graph = HepGraph::new(root); - Self { batches, graph } + Self { + batches, + graph, + planner_context, + } } pub fn find_best(&mut self) -> PlanRef { @@ -97,7 +103,7 @@ impl HepOptimizer { if let Some(opt_expr) = matcher.match_opt_expr() { let mut substitute = Substitute::default(); let opt_expr_root = opt_expr.root.clone(); - rule.apply(opt_expr, &mut substitute); + rule.apply(opt_expr, &mut substitute, &self.planner_context); if !substitute.opt_exprs.is_empty() { assert!(substitute.opt_exprs.len() == 1); @@ -172,7 +178,7 @@ mod tests { HepBatchStrategy::once_topdown(), vec![PhysicalRewriteRule::create()], ); - let mut planner = HepOptimizer::new(vec![batch], root); + let mut planner = HepOptimizer::new(vec![batch], root, Default::default()); let new_plan = planner.find_best(); assert_eq!( new_plan.as_physical_project().unwrap().logical().exprs()[0], diff --git a/src/optimizer/rules/combine_operators.rs b/src/optimizer/rules/combine_operators.rs index a8eb194..6bf0a58 100644 --- 
a/src/optimizer/rules/combine_operators.rs +++ b/src/optimizer/rules/combine_operators.rs @@ -6,6 +6,7 @@ use crate::optimizer::core::{ OptExpr, OptExprNode, Pattern, PatternChildrenPredicate, Rule, Substitute, }; use crate::optimizer::{Dummy, LogicalFilter, PlanNodeType}; +use crate::planner::PlannerContext; lazy_static! { static ref COLLAPSE_PROJECT_RULE: Pattern = { @@ -43,7 +44,7 @@ impl Rule for CollapseProject { &COLLAPSE_PROJECT_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { // TODO: handle column alias let project_opt_expr = opt_expr; let next_project_opt_expr = project_opt_expr.children[0].clone(); @@ -82,7 +83,7 @@ impl Rule for CombineFilter { &COMBINE_FILTERS } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { // TODO: handle column alias let filter_opt_expr = opt_expr; let next_filter_opt_expr = filter_opt_expr.children[0].clone(); diff --git a/src/optimizer/rules/mod.rs b/src/optimizer/rules/mod.rs index f99d365..52f5c3c 100644 --- a/src/optimizer/rules/mod.rs +++ b/src/optimizer/rules/mod.rs @@ -17,6 +17,7 @@ pub use simplification::*; use strum_macros::AsRefStr; use crate::optimizer::core::{OptExpr, Pattern, Rule, Substitute}; +use crate::planner::PlannerContext; #[enum_dispatch(Rule)] #[derive(Clone, AsRefStr)] @@ -105,7 +106,7 @@ mod rule_test_util { let mut binder = Binder::new(Arc::new(catalog)); let bound_stmt = binder.bind(&stats[0]).unwrap(); - let planner = Planner {}; + let mut planner = Planner::default(); planner.plan(bound_stmt).unwrap() } } diff --git a/src/optimizer/rules/physical_rewrite.rs b/src/optimizer/rules/physical_rewrite.rs index 6561369..6133cd3 100644 --- a/src/optimizer/rules/physical_rewrite.rs +++ b/src/optimizer/rules/physical_rewrite.rs @@ -1,6 +1,7 @@ use super::RuleImpl; use 
crate::optimizer::core::*; use crate::optimizer::{PhysicalRewriter, PlanRewriter}; +use crate::planner::PlannerContext; lazy_static! { static ref PATTERN: Pattern = { @@ -25,7 +26,7 @@ impl Rule for PhysicalRewriteRule { &PATTERN } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let mut rewriter = PhysicalRewriter::default(); let plan = opt_expr.to_plan_ref(); let new_plan = rewriter.rewrite(plan); diff --git a/src/optimizer/rules/pushdown_limit.rs b/src/optimizer/rules/pushdown_limit.rs index 077f2e6..fc2ee2a 100644 --- a/src/optimizer/rules/pushdown_limit.rs +++ b/src/optimizer/rules/pushdown_limit.rs @@ -6,6 +6,7 @@ use crate::optimizer::core::{ OptExpr, OptExprNode, Pattern, PatternChildrenPredicate, Rule, Substitute, }; use crate::optimizer::{Dummy, LogicalLimit, LogicalTableScan, PlanNodeType}; +use crate::planner::PlannerContext; lazy_static! { static ref LIMIT_PROJECT_TRANSPOSE_RULE: Pattern = { @@ -61,7 +62,7 @@ impl Rule for LimitProjectTranspose { &LIMIT_PROJECT_TRANSPOSE_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let limit_opt_expr_root = opt_expr.root; let project_opt_expr = opt_expr.children[0].clone(); @@ -90,7 +91,7 @@ impl Rule for EliminateLimits { &ELIMINATE_LIMITS_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let limit_opt_expr_root = opt_expr.root; let next_limit_opt_expr = opt_expr.children[0].clone(); let next_limit_opt_expr_root = next_limit_opt_expr.root; @@ -156,7 +157,7 @@ impl Rule for PushLimitThroughJoin { &PUSH_LIMIT_THROUGH_JOIN_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, 
_planner_context: &PlannerContext) { let limit_opt_expr_root = opt_expr.root; let limit_node = limit_opt_expr_root .get_plan_ref() @@ -237,7 +238,7 @@ impl Rule for PushLimitIntoTableScan { &PUSH_LIMIT_INTO_TABLE_SCAN_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let limit_opt_expr_root = opt_expr.root; let limit_node = limit_opt_expr_root .get_plan_ref() @@ -306,7 +307,7 @@ LogicalProject: exprs [t1.a:Nullable(Int32)] HepBatchStrategy::fix_point_topdown(100), vec![LimitProjectTranspose::create()], ); - let mut optimizer = HepOptimizer::new(vec![batch], logical_plan); + let mut optimizer = HepOptimizer::new(vec![batch], logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); @@ -364,7 +365,7 @@ LogicalProject: exprs [t1.a:Nullable(Int32)] EliminateLimits::create(), ], ); - let mut optimizer = HepOptimizer::new(vec![batch], logical_plan); + let mut optimizer = HepOptimizer::new(vec![batch], logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); @@ -394,7 +395,7 @@ LogicalProject: exprs [t1.a:Nullable(Int32)] PushLimitIntoTableScan::create(), ], ); - let mut optimizer = HepOptimizer::new(vec![batch], logical_plan); + let mut optimizer = HepOptimizer::new(vec![batch], logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); diff --git a/src/optimizer/rules/pushdown_predicates.rs b/src/optimizer/rules/pushdown_predicates.rs index de81882..6bdcddf 100644 --- a/src/optimizer/rules/pushdown_predicates.rs +++ b/src/optimizer/rules/pushdown_predicates.rs @@ -12,6 +12,7 @@ use crate::catalog::ColumnCatalog; use crate::optimizer::core::*; use crate::optimizer::expr_rewriter::ExprRewriter; use crate::optimizer::{Dummy, LogicalFilter, LogicalJoin, PlanNodeType}; +use crate::planner::PlannerContext; lazy_static! 
{ static ref PUSH_PREDICATE_THROUGH_JOIN: Pattern = { @@ -105,7 +106,7 @@ impl Rule for PushPredicateThroughJoin { &PUSH_PREDICATE_THROUGH_JOIN } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let join_opt_expr = opt_expr.children[0].clone(); let join_node = join_opt_expr.root.get_plan_ref().as_logical_join().unwrap(); if !self.can_push_through(join_node.join_type()) { @@ -203,7 +204,7 @@ impl Rule for PushPredicateThroughNonJoin { &PUSH_PREDICATE_THROUGH_NON_JOIN } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let filter_opt_expr = opt_expr; let child_opt_expr = filter_opt_expr.children[0].clone(); let child_node = child_opt_expr.root.get_plan_ref(); @@ -354,7 +355,7 @@ LogicalProject: exprs [t1.a:Nullable(Int32), t1.b:Nullable(Int32), t1.c:Nullable HepBatchStrategy::fix_point_topdown(100), vec![PushPredicateThroughJoin::create()], ); - let mut optimizer = HepOptimizer::new(vec![batch], logical_plan); + let mut optimizer = HepOptimizer::new(vec![batch], logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); @@ -404,7 +405,7 @@ LogicalProject: exprs [t1.a:Nullable(Int32), t1.b:Nullable(Int32), t1.c:Nullable ), ]; - let mut optimizer = HepOptimizer::new(batches, logical_plan); + let mut optimizer = HepOptimizer::new(batches, logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); diff --git a/src/optimizer/rules/simplification.rs b/src/optimizer/rules/simplification.rs index c1ee58c..450e71d 100644 --- a/src/optimizer/rules/simplification.rs +++ b/src/optimizer/rules/simplification.rs @@ -8,6 +8,7 @@ use crate::optimizer::{ LogicalAgg, LogicalFilter, LogicalJoin, LogicalLimit, LogicalOrder, LogicalProject, PlanRef, PlanRewriter, }; +use crate::planner::PlannerContext; lazy_static! 
{ static ref SIMPLIFY_CASTS_RULE: Pattern = { @@ -32,7 +33,7 @@ impl Rule for SimplifyCasts { &SIMPLIFY_CASTS_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let mut rewriter = SimplifyCastsRewriter::default(); let plan = opt_expr.to_plan_ref(); let new_plan = rewriter.rewrite(plan); diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 3ab0b8c..6abc525 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -1,13 +1,35 @@ mod select; mod util; +use std::collections::HashMap; + use crate::binder::BoundStatement; use crate::optimizer::PlanRef; -pub struct Planner {} +#[derive(Default)] +pub struct Planner { + pub context: PlannerContext, +} + +#[derive(Default, Debug)] +pub struct PlannerContext { + // subquery alias to subquery plan + pub subquery_context: HashMap, +} + +impl PlannerContext { + pub fn find_subquery_alias(&self, plan_ref: &PlanRef) -> Option { + for (alias, p) in &self.subquery_context { + if p == plan_ref { + return Some(alias.clone()); + } + } + None + } +} impl Planner { - pub fn plan(&self, stmt: BoundStatement) -> Result { + pub fn plan(&mut self, stmt: BoundStatement) -> Result { match stmt { BoundStatement::Select(stmt) => self.plan_select(stmt), } @@ -107,7 +129,7 @@ mod planner_test { #[test] fn test_plan_select_works() { let stmt = build_test_select_stmt(); - let p = Planner {}; + let mut p = Planner::default(); let node = p.plan(stmt); assert!(node.is_ok()); let plan_ref = node.unwrap(); @@ -123,7 +145,7 @@ mod planner_test { // inner join t2 on t1.c1=t2.c1 // left join t3 on t2.c1=t3.c1 let stmt = build_test_select_stmt_with_multiple_joins(); - let p = Planner {}; + let mut p = Planner::default(); let node = p.plan(stmt); assert!(node.is_ok()); let plan_ref = node.unwrap(); @@ -174,7 +196,7 @@ mod planner_test { #[test] fn test_plan_select_distinct_works() { let stmt = build_test_select_distinct_stmt(); - let 
p = Planner {}; + let mut p = Planner::default(); let node = p.plan(stmt); assert!(node.is_ok()); let plan_ref = node.unwrap(); diff --git a/src/planner/select.rs b/src/planner/select.rs index 5cd2870..655dc97 100644 --- a/src/planner/select.rs +++ b/src/planner/select.rs @@ -6,7 +6,7 @@ use crate::binder::{BoundSelect, BoundTableRef}; use crate::optimizer::*; impl Planner { - pub fn plan_select(&self, stmt: BoundSelect) -> Result { + pub fn plan_select(&mut self, stmt: BoundSelect) -> Result { let mut plan: PlanRef; if let Some(table_ref) = stmt.from_table { @@ -48,7 +48,7 @@ impl Planner { Ok(plan) } - fn plan_table_ref(&self, table_ref: &BoundTableRef) -> Result { + fn plan_table_ref(&mut self, table_ref: &BoundTableRef) -> Result { match table_ref { BoundTableRef::Table(table) => Ok(Arc::new(LogicalTableScan::new( table.catalog.id.clone(), @@ -73,7 +73,11 @@ impl Planner { } BoundTableRef::Subquery(subquery) => { let subquery = subquery.clone(); - self.plan_select(*subquery.query) + let plan_ref = self.plan_select(*subquery.query)?; + self.context + .subquery_context + .insert(subquery.alias, plan_ref.clone()); + Ok(plan_ref) } } } From 2c5aa3b61771fcfb85f0d017244885a8933a6f35 Mon Sep 17 00:00:00 2001 From: Fedomn Date: Fri, 21 Oct 2022 13:00:27 +0800 Subject: [PATCH 2/4] refactor(optimizer): fix PushProjectThroughChild rule across subquery Signed-off-by: Fedomn --- src/optimizer/plan_node/dummy.rs | 8 ++ src/optimizer/plan_node/logical_agg.rs | 14 ++- src/optimizer/plan_node/logical_filter.rs | 10 +- src/optimizer/plan_node/logical_join.rs | 46 +++++++- src/optimizer/plan_node/logical_limit.rs | 10 +- src/optimizer/plan_node/logical_order.rs | 10 +- src/optimizer/plan_node/logical_project.rs | 13 ++- src/optimizer/plan_node/logical_table_scan.rs | 15 +++ src/optimizer/plan_node/mod.rs | 9 +- .../plan_node/physical_cross_join.rs | 10 +- src/optimizer/plan_node/physical_filter.rs | 10 +- src/optimizer/plan_node/physical_hash_agg.rs | 10 +- 
src/optimizer/plan_node/physical_hash_join.rs | 10 +- src/optimizer/plan_node/physical_limit.rs | 10 +- src/optimizer/plan_node/physical_order.rs | 10 +- src/optimizer/plan_node/physical_project.rs | 10 +- .../plan_node/physical_simple_agg.rs | 10 +- .../plan_node/physical_table_scan.rs | 10 +- src/optimizer/rules/column_pruning.rs | 103 +++++++++++------- src/optimizer/rules/util.rs | 7 ++ tests/planner/column-pruning.planner.sql | 22 ++++ tests/planner/column-pruning.yml | 4 + 22 files changed, 305 insertions(+), 56 deletions(-) diff --git a/src/optimizer/plan_node/dummy.rs b/src/optimizer/plan_node/dummy.rs index 48faa61..9e86fb2 100644 --- a/src/optimizer/plan_node/dummy.rs +++ b/src/optimizer/plan_node/dummy.rs @@ -25,6 +25,14 @@ impl PlanNode for Dummy { fn output_columns(&self) -> Vec { vec![] } + + fn output_new_columns(&self, _base_table_id: String) -> Vec { + vec![] + } + + fn get_based_table_id(&self) -> crate::catalog::TableId { + "Dummy".to_string() + } } impl PlanTreeNode for Dummy { diff --git a/src/optimizer/plan_node/logical_agg.rs b/src/optimizer/plan_node/logical_agg.rs index cfd4d11..596293b 100644 --- a/src/optimizer/plan_node/logical_agg.rs +++ b/src/optimizer/plan_node/logical_agg.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{PlanNode, PlanRef, PlanTreeNode}; use crate::binder::BoundExpr; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalAgg { @@ -46,6 +46,18 @@ impl PlanNode for LogicalAgg { .flat_map(|e| e.get_referenced_column_catalog()) .collect::>() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.group_by + .iter() + .chain(self.agg_funcs.iter()) + .map(|e| e.output_column_catalog_for_alias_table(base_table_id.clone())) + .collect::>() + } + + fn get_based_table_id(&self) -> TableId { + self.children()[0].get_based_table_id() + } } impl PlanTreeNode for LogicalAgg { diff --git a/src/optimizer/plan_node/logical_filter.rs 
b/src/optimizer/plan_node/logical_filter.rs index 8b9604b..57462c0 100644 --- a/src/optimizer/plan_node/logical_filter.rs +++ b/src/optimizer/plan_node/logical_filter.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{PlanNode, PlanRef, PlanTreeNode}; use crate::binder::BoundExpr; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalFilter { @@ -35,6 +35,14 @@ impl PlanNode for LogicalFilter { fn output_columns(&self) -> Vec { self.children()[0].output_columns() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.children()[0].output_new_columns(base_table_id.clone()) + } + + fn get_based_table_id(&self) -> TableId { + self.children()[0].get_based_table_id() + } } impl PlanTreeNode for LogicalFilter { diff --git a/src/optimizer/plan_node/logical_join.rs b/src/optimizer/plan_node/logical_join.rs index 340df5c..b357c42 100644 --- a/src/optimizer/plan_node/logical_join.rs +++ b/src/optimizer/plan_node/logical_join.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{PlanNode, PlanRef, PlanTreeNode}; use crate::binder::{JoinCondition, JoinType}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalJoin { @@ -114,6 +114,42 @@ impl LogicalJoin { vec![left_fields, right_fields].concat() } + + fn join_output_new_columns_internal(&self, base_table_id: String) -> Vec { + let (left_join_keys_force_nullable, right_join_keys_force_nullable) = match self.join_type { + JoinType::Inner => (false, false), + JoinType::Left => (false, true), + JoinType::Right => (true, false), + JoinType::Full => (true, true), + JoinType::Cross => (true, true), + }; + let left_fields = self + .left + .output_new_columns(base_table_id.clone()) + .iter() + .map(|c| { + c.clone_with_nullable( + // if force nullable is false, use the original value + // to handle some original fields that are nullable + left_join_keys_force_nullable || 
c.nullable, + ) + }) + .collect::>(); + let right_fields = self + .right + .output_new_columns(base_table_id.clone()) + .iter() + .map(|c| { + c.clone_with_nullable( + // if force nullable is false, use the original value + // to handle some original fields that are nullable + right_join_keys_force_nullable || c.nullable, + ) + }) + .collect::>(); + + vec![left_fields, right_fields].concat() + } } impl PlanNode for LogicalJoin { @@ -142,6 +178,14 @@ impl PlanNode for LogicalJoin { fn output_columns(&self) -> Vec { self.join_output_columns() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.join_output_new_columns_internal(base_table_id.clone()) + } + + fn get_based_table_id(&self) -> TableId { + self.children()[0].get_based_table_id() + } } impl PlanTreeNode for LogicalJoin { diff --git a/src/optimizer/plan_node/logical_limit.rs b/src/optimizer/plan_node/logical_limit.rs index 7121f82..8a33879 100644 --- a/src/optimizer/plan_node/logical_limit.rs +++ b/src/optimizer/plan_node/logical_limit.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{PlanNode, PlanRef, PlanTreeNode}; use crate::binder::BoundExpr; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalLimit { @@ -42,6 +42,14 @@ impl PlanNode for LogicalLimit { fn output_columns(&self) -> Vec { self.children()[0].output_columns() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.children()[0].output_new_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.children()[0].get_based_table_id() + } } impl PlanTreeNode for LogicalLimit { diff --git a/src/optimizer/plan_node/logical_order.rs b/src/optimizer/plan_node/logical_order.rs index f11f8cc..a7ed8c4 100644 --- a/src/optimizer/plan_node/logical_order.rs +++ b/src/optimizer/plan_node/logical_order.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{PlanNode, PlanRef, PlanTreeNode}; use crate::binder::BoundOrderBy; -use 
crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalOrder { @@ -36,6 +36,14 @@ impl PlanNode for LogicalOrder { fn output_columns(&self) -> Vec { self.children()[0].output_columns() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.children()[0].output_new_columns(base_table_id) + } + + fn get_based_table_id(&self) -> TableId { + self.children()[0].get_based_table_id() + } } impl PlanTreeNode for LogicalOrder { diff --git a/src/optimizer/plan_node/logical_project.rs b/src/optimizer/plan_node/logical_project.rs index d53d393..d5caaf3 100644 --- a/src/optimizer/plan_node/logical_project.rs +++ b/src/optimizer/plan_node/logical_project.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{PlanNode, PlanRef, PlanTreeNode}; use crate::binder::BoundExpr; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalProject { @@ -38,6 +38,17 @@ impl PlanNode for LogicalProject { .flat_map(|e| e.get_referenced_column_catalog()) .collect::>() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.exprs + .iter() + .map(|e| e.output_column_catalog_for_alias_table(base_table_id.clone())) + .collect::>() + } + + fn get_based_table_id(&self) -> TableId { + self.children()[0].get_based_table_id() + } } impl PlanTreeNode for LogicalProject { diff --git a/src/optimizer/plan_node/logical_table_scan.rs b/src/optimizer/plan_node/logical_table_scan.rs index 88efe95..c17fc37 100644 --- a/src/optimizer/plan_node/logical_table_scan.rs +++ b/src/optimizer/plan_node/logical_table_scan.rs @@ -65,6 +65,21 @@ impl PlanNode for LogicalTableScan { fn output_columns(&self) -> Vec { self.columns() } + + fn output_new_columns(&self, _: String) -> Vec { + if let Some(alias) = self.table_alias() { + self.columns() + .iter() + .map(|c| c.clone_with_table_id(alias.clone())) + .collect() + } else { + self.columns() + } + } + + fn 
get_based_table_id(&self) -> TableId { + self.table_id.clone() + } } impl PlanTreeNode for LogicalTableScan { diff --git a/src/optimizer/plan_node/mod.rs b/src/optimizer/plan_node/mod.rs index 877365f..07cfc73 100644 --- a/src/optimizer/plan_node/mod.rs +++ b/src/optimizer/plan_node/mod.rs @@ -41,7 +41,7 @@ pub use physical_simple_agg::*; pub use physical_table_scan::*; pub use plan_node_traits::*; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; /// The common trait over all plan nodes. Used by optimizer framework which will treat all node as /// `dyn PlanNode`. Meanwhile, we split the trait into lots of sub-traits so that we can easily use @@ -50,10 +50,17 @@ pub trait PlanNode: WithPlanNodeType + PlanTreeNode + Downcast + Debug + Display + Send + Sync { /// All columns that appears in BoundExprs from this plan node. + /// FIXME: should be changed to return BoundColumnRef, which makes more sense. fn referenced_columns(&self) -> Vec; /// All columns that appears in output RecordBatch from this plan node. + /// FIXME: should be changed to return BoundColumnRef, which makes more sense. fn output_columns(&self) -> Vec; + + /// Return the output column catalog, which is converted from `BoundExpr`. 
+ fn output_new_columns(&self, base_table_id: String) -> Vec; + + fn get_based_table_id(&self) -> TableId; } impl_downcast!(PlanNode); diff --git a/src/optimizer/plan_node/physical_cross_join.rs b/src/optimizer/plan_node/physical_cross_join.rs index e02479f..d342afe 100644 --- a/src/optimizer/plan_node/physical_cross_join.rs +++ b/src/optimizer/plan_node/physical_cross_join.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{LogicalJoin, PlanNode, PlanRef, PlanTreeNode}; use crate::binder::JoinType; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalCrossJoin { @@ -40,6 +40,14 @@ impl PlanNode for PhysicalCrossJoin { fn output_columns(&self) -> Vec { self.logical.output_columns() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.logical().output_new_columns(base_table_id.clone()) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() + } } impl PlanTreeNode for PhysicalCrossJoin { diff --git a/src/optimizer/plan_node/physical_filter.rs b/src/optimizer/plan_node/physical_filter.rs index 38a57fd..56dd085 100644 --- a/src/optimizer/plan_node/physical_filter.rs +++ b/src/optimizer/plan_node/physical_filter.rs @@ -2,7 +2,7 @@ use core::fmt; use std::sync::Arc; use super::{LogicalFilter, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalFilter { @@ -27,6 +27,14 @@ impl PlanNode for PhysicalFilter { fn output_columns(&self) -> Vec { self.logical.output_columns() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.logical().output_new_columns(base_table_id.clone()) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() + } } impl PlanTreeNode for PhysicalFilter { diff --git a/src/optimizer/plan_node/physical_hash_agg.rs b/src/optimizer/plan_node/physical_hash_agg.rs index 
188ef5f..fb4f17b 100644 --- a/src/optimizer/plan_node/physical_hash_agg.rs +++ b/src/optimizer/plan_node/physical_hash_agg.rs @@ -2,7 +2,7 @@ use std::fmt; use std::sync::Arc; use super::{LogicalAgg, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalHashAgg { @@ -27,6 +27,14 @@ impl PlanNode for PhysicalHashAgg { fn output_columns(&self) -> Vec { self.logical.output_columns() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.logical().output_new_columns(base_table_id.clone()) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() + } } impl PlanTreeNode for PhysicalHashAgg { diff --git a/src/optimizer/plan_node/physical_hash_join.rs b/src/optimizer/plan_node/physical_hash_join.rs index 54141a6..45938e5 100644 --- a/src/optimizer/plan_node/physical_hash_join.rs +++ b/src/optimizer/plan_node/physical_hash_join.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::{LogicalJoin, PlanNode, PlanRef, PlanTreeNode}; use crate::binder::{JoinCondition, JoinType}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalHashJoin { @@ -44,6 +44,14 @@ impl PlanNode for PhysicalHashJoin { fn output_columns(&self) -> Vec { self.logical.output_columns() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.logical().output_new_columns(base_table_id.clone()) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() + } } impl PlanTreeNode for PhysicalHashJoin { diff --git a/src/optimizer/plan_node/physical_limit.rs b/src/optimizer/plan_node/physical_limit.rs index 83e99b0..f24ef21 100644 --- a/src/optimizer/plan_node/physical_limit.rs +++ b/src/optimizer/plan_node/physical_limit.rs @@ -2,7 +2,7 @@ use core::fmt; use std::sync::Arc; use super::{LogicalLimit, PlanNode, PlanRef, PlanTreeNode}; -use 
crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalLimit { @@ -27,6 +27,14 @@ impl PlanNode for PhysicalLimit { fn output_columns(&self) -> Vec { self.logical.output_columns() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.logical().output_new_columns(base_table_id.clone()) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() + } } impl PlanTreeNode for PhysicalLimit { diff --git a/src/optimizer/plan_node/physical_order.rs b/src/optimizer/plan_node/physical_order.rs index fb13256..b7d406b 100644 --- a/src/optimizer/plan_node/physical_order.rs +++ b/src/optimizer/plan_node/physical_order.rs @@ -2,7 +2,7 @@ use core::fmt; use std::sync::Arc; use super::{LogicalOrder, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalOrder { @@ -27,6 +27,14 @@ impl PlanNode for PhysicalOrder { fn output_columns(&self) -> Vec { self.logical.output_columns() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.logical().output_new_columns(base_table_id.clone()) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() + } } impl PlanTreeNode for PhysicalOrder { diff --git a/src/optimizer/plan_node/physical_project.rs b/src/optimizer/plan_node/physical_project.rs index ea5c92e..6899f24 100644 --- a/src/optimizer/plan_node/physical_project.rs +++ b/src/optimizer/plan_node/physical_project.rs @@ -2,7 +2,7 @@ use std::fmt; use std::sync::Arc; use super::{LogicalProject, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalProject { @@ -27,6 +27,14 @@ impl PlanNode for PhysicalProject { fn output_columns(&self) -> Vec { self.logical.output_columns() } + + fn output_new_columns(&self, base_table_id: 
String) -> Vec { + self.logical().output_new_columns(base_table_id.clone()) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() + } } impl PlanTreeNode for PhysicalProject { diff --git a/src/optimizer/plan_node/physical_simple_agg.rs b/src/optimizer/plan_node/physical_simple_agg.rs index bf78a21..839c43b 100644 --- a/src/optimizer/plan_node/physical_simple_agg.rs +++ b/src/optimizer/plan_node/physical_simple_agg.rs @@ -2,7 +2,7 @@ use std::fmt; use std::sync::Arc; use super::{LogicalAgg, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalSimpleAgg { @@ -27,6 +27,14 @@ impl PlanNode for PhysicalSimpleAgg { fn output_columns(&self) -> Vec { self.logical.output_columns() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.logical().output_new_columns(base_table_id.clone()) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() + } } impl PlanTreeNode for PhysicalSimpleAgg { diff --git a/src/optimizer/plan_node/physical_table_scan.rs b/src/optimizer/plan_node/physical_table_scan.rs index 5521a62..d5d705f 100644 --- a/src/optimizer/plan_node/physical_table_scan.rs +++ b/src/optimizer/plan_node/physical_table_scan.rs @@ -2,7 +2,7 @@ use std::fmt; use std::sync::Arc; use super::{LogicalTableScan, PlanNode, PlanRef, PlanTreeNode}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct PhysicalTableScan { @@ -27,6 +27,14 @@ impl PlanNode for PhysicalTableScan { fn output_columns(&self) -> Vec { self.logical.output_columns() } + + fn output_new_columns(&self, base_table_id: String) -> Vec { + self.logical().output_new_columns(base_table_id.clone()) + } + + fn get_based_table_id(&self) -> TableId { + self.logical().get_based_table_id() + } } impl PlanTreeNode for PhysicalTableScan { diff --git 
a/src/optimizer/rules/column_pruning.rs b/src/optimizer/rules/column_pruning.rs index d8388c0..bd1174a 100644 --- a/src/optimizer/rules/column_pruning.rs +++ b/src/optimizer/rules/column_pruning.rs @@ -7,8 +7,9 @@ use crate::binder::{BoundColumnRef, BoundExpr}; use crate::optimizer::core::{ OptExpr, OptExprNode, Pattern, PatternChildrenPredicate, Rule, Substitute, }; -use crate::optimizer::rules::util::is_subset_cols; +use crate::optimizer::rules::util::is_superset_cols; use crate::optimizer::{Dummy, LogicalProject, LogicalTableScan, PlanNodeType}; +use crate::planner::PlannerContext; lazy_static! { static ref PUSH_PROJECT_INTO_TABLE_SCAN_RULE: Pattern = { @@ -57,7 +58,7 @@ impl Rule for PushProjectIntoTableScan { &PUSH_PROJECT_INTO_TABLE_SCAN_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { let project_opt_expr_root = opt_expr.root; let table_scan_opt_expr = opt_expr.children[0].clone(); let project_node = project_opt_expr_root @@ -122,7 +123,7 @@ impl Rule for PushProjectThroughChild { &PUSH_PROJECT_THROUGH_CHILD_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, planner_context: &PlannerContext) { let project_opt_expr_root = opt_expr.root; let child_opt_expr = opt_expr.children[0].clone(); @@ -130,50 +131,71 @@ impl Rule for PushProjectThroughChild { let project_cols = project_plan_ref.referenced_columns(); let child_plan_ref = child_opt_expr.root.get_plan_ref(); let child_cols = child_plan_ref.referenced_columns(); - let required_cols = [project_cols, child_cols].concat(); - - let child_children_columns = child_plan_ref + let mut required_cols = [project_cols, child_cols].concat(); + let mut child_children_cols = child_plan_ref .children() .iter() - .flat_map(|c| c.output_columns()) + .flat_map(|c| { + c.output_new_columns( + planner_context + 
.find_subquery_alias(c) + .unwrap_or(c.get_based_table_id()), + ) + }) .collect::>(); - // if child_children_columns more than required_cols, pushdown extra projection. - if !is_subset_cols(&child_children_columns, &required_cols) { + // distinct cols + required_cols = required_cols.into_iter().unique().collect(); + child_children_cols = child_children_cols.into_iter().unique().collect(); + + // println!("required_cols: {:?}", required_cols); + // println!("child_children_cols: {:?}", child_children_cols); + + // if child_children_cols more than required_cols, pushdown extra projection. + if is_superset_cols(&child_children_cols, &required_cols) { let new_child_opt_expr_children = child_plan_ref .children() .iter() .zip_eq(child_opt_expr.children.iter()) .map(|(child_child_plan, child_child_opt_expr)| { + // Note: resolve base_table_id to calc real ColumnCatalog for subquery + // such as: select a, t2.v1 as max_b from t1 cross join (select max(b) as v1 + // from t1) t2; + // `t2.v1` should be resolved in child_child_plan output_columns. + let base_table_id = planner_context + .find_subquery_alias(child_child_plan) + .unwrap_or(child_child_plan.get_based_table_id()) + .clone(); + let mut child_child_output_cols = + child_child_plan.output_new_columns(base_table_id); + // for child's child, filter corresponding required columns + let mut required_cols_in_child_child = child_child_output_cols + .clone() + .into_iter() + .filter(|c| required_cols.contains(c)) + .collect::>(); + + // distinct cols + child_child_output_cols = + child_child_output_cols.into_iter().unique().collect(); + required_cols_in_child_child = + required_cols_in_child_child.into_iter().unique().collect(); + // println!("child_child_output_cols: {:?}", child_child_output_cols); + // println!( + // "required_cols_in_child_child: {:?}", + // required_cols_in_child_child + // ); + // if child's child cols more than required_cols, pushdown extra projection. 
- if !is_subset_cols(&child_child_plan.output_columns(), &required_cols) { - // for child's child, filter corresponding required columns - let exprs = child_child_plan - .output_columns() - .iter() - .filter(|c| required_cols.contains(c)) - .map(|c| { - BoundExpr::ColumnRef(BoundColumnRef { - column_catalog: c.clone(), - }) - }) - .collect::>(); - - // FIXME: resolve alias corresponding real ColumnCatalog - // such as: select a, t2.v1 as max_b from t1 cross join (select max(b) as v1 - // from t1) t2; - // t2.v1 should be resolved to t1.b which means this exprs only use t1.b - // column. - // assert!(exprs.is_empty(), "pruned project exprs should not be empty"); - - if exprs.is_empty() { - child_child_opt_expr.clone() - } else { - let new_project = LogicalProject::new(exprs, Dummy::new_ref()); - OptExpr { - root: OptExprNode::PlanRef(Arc::new(new_project)), - children: vec![child_child_opt_expr.clone()], - } + if is_superset_cols(&child_child_output_cols, &required_cols_in_child_child) { + let exprs = required_cols_in_child_child + .into_iter() + .map(|c| BoundExpr::ColumnRef(BoundColumnRef { column_catalog: c })) + .collect(); + let new_project = LogicalProject::new(exprs, Dummy::new_ref()); + OptExpr { + root: OptExprNode::PlanRef(Arc::new(new_project)), + children: vec![child_child_opt_expr.clone()], } } else { child_child_opt_expr.clone() @@ -212,7 +234,7 @@ impl Rule for RemoveNoopOperators { &REMOVE_NOOP_OPERATORS_RULE } - fn apply(&self, opt_expr: OptExpr, result: &mut Substitute) { + fn apply(&self, opt_expr: OptExpr, result: &mut Substitute, _planner_context: &PlannerContext) { // eliminate no-op project for those children type: project{input: project/aggregate} let project_opt_expr_root = opt_expr.root; let project_plan_ref = project_opt_expr_root.get_plan_ref(); @@ -284,7 +306,7 @@ LogicalProject: exprs [t1.a:Nullable(Int32) + 1] HepBatchStrategy::fix_point_topdown(100), vec![PushProjectIntoTableScan::create()], ); - let mut optimizer = 
HepOptimizer::new(vec![batch], logical_plan); + let mut optimizer = HepOptimizer::new(vec![batch], logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); @@ -349,7 +371,8 @@ LogicalProject: exprs [employee.id:Nullable(Int32), employee.first_name:Nullable HepBatchStrategy::fix_point_topdown(100), vec![RemoveNoopOperators::create()], ); - let mut optimizer = HepOptimizer::new(vec![batch, final_batch], logical_plan); + let mut optimizer = + HepOptimizer::new(vec![batch, final_batch], logical_plan, Default::default()); let optimized_plan = optimizer.find_best(); diff --git a/src/optimizer/rules/util.rs b/src/optimizer/rules/util.rs index ab4d116..bce56ea 100644 --- a/src/optimizer/rules/util.rs +++ b/src/optimizer/rules/util.rs @@ -6,10 +6,17 @@ use crate::catalog::ColumnCatalog; /// Return true when left is subset of right, only compare table_id and column_id, so it's safe to /// used for join output cols with nullable columns. +/// If left equals right, return true. pub fn is_subset_cols(left: &[ColumnCatalog], right: &[ColumnCatalog]) -> bool { left.iter().all(|l| right.contains(l)) } +/// Return true when left is superset of right. +/// If left equals right, return false. +pub fn is_superset_cols(left: &[ColumnCatalog], right: &[ColumnCatalog]) -> bool { + right.iter().all(|r| left.contains(r)) && left.len() > right.len() +} + /// Return true when left is subset of right pub fn is_subset_exprs(left: &[BoundExpr], right: &[BoundExpr]) -> bool { left.iter().all(|l| right.contains(l)) diff --git a/tests/planner/column-pruning.planner.sql b/tests/planner/column-pruning.planner.sql index 267201e..6c80419 100644 --- a/tests/planner/column-pruning.planner.sql +++ b/tests/planner/column-pruning.planner.sql @@ -94,3 +94,25 @@ PhysicalProject: exprs [employee.id:Int64, employee.first_name:Utf8, department. 
PhysicalTableScan: table: #state, columns: [state_code, state_name] */ +-- PushProjectThroughChild: column pruning across subquery + +select a, t2.v1 as max_b from t1 cross join (select max(b) as v1 from t1) t2 + +/* +original plan: +LogicalProject: exprs [t1.a:Int64, (t2.v1:Int64) as t1.max_b] + LogicalJoin: type Cross, cond None + LogicalTableScan: table: #t1, columns: [a, b, c] + LogicalProject: exprs [((Max(t1.b:Int64):Int64) as t1.v1) as t2.v1] + LogicalAgg: agg_funcs [Max(t1.b:Int64):Int64] group_by [] + LogicalTableScan: table: #t1, columns: [a, b, c] + +optimized plan: +PhysicalProject: exprs [t1.a:Int64, (t2.v1:Int64) as t1.max_b] + PhysicalCrossJoin: type Cross + PhysicalTableScan: table: #t1, columns: [a] + PhysicalProject: exprs [((Max(t1.b:Int64):Int64) as t1.v1) as t2.v1] + PhysicalSimpleAgg: agg_funcs [Max(t1.b:Int64):Int64] group_by [] + PhysicalTableScan: table: #t1, columns: [b] +*/ + diff --git a/tests/planner/column-pruning.yml b/tests/planner/column-pruning.yml index 0fd8751..a15986f 100644 --- a/tests/planner/column-pruning.yml +++ b/tests/planner/column-pruning.yml @@ -25,3 +25,7 @@ desc: | PushProjectThroughChild: column pruning across multiple join +- sql: | + select a, t2.v1 as max_b from t1 cross join (select max(b) as v1 from t1) t2 + desc: | + PushProjectThroughChild: column pruning across subquery From 05545747016bdafec991bf64ed707514b795f2d5 Mon Sep 17 00:00:00 2001 From: Fedomn Date: Fri, 21 Oct 2022 22:01:00 +0800 Subject: [PATCH 3/4] refactor(optimizer): remove original output_columns method in PlanNode Signed-off-by: Fedomn --- src/executor/mod.rs | 8 +- src/optimizer/plan_node/dummy.rs | 4 - src/optimizer/plan_node/logical_agg.rs | 4 - src/optimizer/plan_node/logical_filter.rs | 6 +- src/optimizer/plan_node/logical_join.rs | 83 +++++-------------- src/optimizer/plan_node/logical_limit.rs | 4 - src/optimizer/plan_node/logical_order.rs | 4 - src/optimizer/plan_node/logical_project.rs | 4 - 
src/optimizer/plan_node/logical_table_scan.rs | 4 - src/optimizer/plan_node/mod.rs | 9 +- .../plan_node/physical_cross_join.rs | 10 +-- src/optimizer/plan_node/physical_filter.rs | 6 +- src/optimizer/plan_node/physical_hash_agg.rs | 6 +- src/optimizer/plan_node/physical_hash_join.rs | 10 +-- src/optimizer/plan_node/physical_limit.rs | 6 +- src/optimizer/plan_node/physical_order.rs | 6 +- src/optimizer/plan_node/physical_project.rs | 6 +- .../plan_node/physical_simple_agg.rs | 6 +- .../plan_node/physical_table_scan.rs | 6 +- src/optimizer/rules/column_pruning.rs | 5 +- src/optimizer/rules/pushdown_predicates.rs | 6 +- src/planner/mod.rs | 7 +- 22 files changed, 60 insertions(+), 150 deletions(-) diff --git a/src/executor/mod.rs b/src/executor/mod.rs index 43db615..a0971bc 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -26,8 +26,8 @@ use self::project::ProjectExecutor; use self::table_scan::TableScanExecutor; use crate::optimizer::{ PhysicalCrossJoin, PhysicalFilter, PhysicalHashAgg, PhysicalHashJoin, PhysicalLimit, - PhysicalOrder, PhysicalProject, PhysicalSimpleAgg, PhysicalTableScan, PlanNode, PlanRef, - PlanTreeNode, PlanVisitor, + PhysicalOrder, PhysicalProject, PhysicalSimpleAgg, PhysicalTableScan, PlanRef, PlanTreeNode, + PlanVisitor, }; use crate::storage::{StorageError, StorageImpl}; @@ -107,7 +107,7 @@ impl PlanVisitor for ExecutorBuilder { right_child: self.visit(plan.right()).unwrap(), join_type: plan.join_type(), join_condition: plan.join_condition(), - join_output_schema: plan.output_columns(), + join_output_schema: plan.join_output_columns(), } .execute(), ) @@ -118,7 +118,7 @@ impl PlanVisitor for ExecutorBuilder { CrossJoinExecutor { left_child: self.visit(plan.left()).unwrap(), right_child: self.visit(plan.right()).unwrap(), - join_output_schema: plan.output_columns(), + join_output_schema: plan.join_output_columns(), } .execute(), ) diff --git a/src/optimizer/plan_node/dummy.rs b/src/optimizer/plan_node/dummy.rs index 
9e86fb2..e6f81da 100644 --- a/src/optimizer/plan_node/dummy.rs +++ b/src/optimizer/plan_node/dummy.rs @@ -22,10 +22,6 @@ impl PlanNode for Dummy { vec![] } - fn output_columns(&self) -> Vec { - vec![] - } - fn output_new_columns(&self, _base_table_id: String) -> Vec { vec![] } diff --git a/src/optimizer/plan_node/logical_agg.rs b/src/optimizer/plan_node/logical_agg.rs index 596293b..01c91a4 100644 --- a/src/optimizer/plan_node/logical_agg.rs +++ b/src/optimizer/plan_node/logical_agg.rs @@ -36,10 +36,6 @@ impl LogicalAgg { impl PlanNode for LogicalAgg { fn referenced_columns(&self) -> Vec { - self.output_columns() - } - - fn output_columns(&self) -> Vec { self.group_by .iter() .chain(self.agg_funcs.iter()) diff --git a/src/optimizer/plan_node/logical_filter.rs b/src/optimizer/plan_node/logical_filter.rs index 57462c0..66c2b68 100644 --- a/src/optimizer/plan_node/logical_filter.rs +++ b/src/optimizer/plan_node/logical_filter.rs @@ -32,12 +32,8 @@ impl PlanNode for LogicalFilter { self.expr.get_referenced_column_catalog() } - fn output_columns(&self) -> Vec { - self.children()[0].output_columns() - } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.children()[0].output_new_columns(base_table_id.clone()) + self.children()[0].output_new_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/logical_join.rs b/src/optimizer/plan_node/logical_join.rs index b357c42..079c3f5 100644 --- a/src/optimizer/plan_node/logical_join.rs +++ b/src/optimizer/plan_node/logical_join.rs @@ -30,7 +30,8 @@ impl LogicalJoin { join_condition, join_output_columns: vec![], }; - join.join_output_columns = join.join_output_columns_internal(); + let base_table_id = join.get_based_table_id(); + join.join_output_columns = join.join_output_columns_internal(base_table_id); join } @@ -79,43 +80,7 @@ impl LogicalJoin { /// /// So in the left child schema, b's fields is nullable, therefore we should use left join /// schema directly, 
rather than set b's fields as non-nullable. - fn join_output_columns_internal(&self) -> Vec { - let (left_join_keys_force_nullable, right_join_keys_force_nullable) = match self.join_type { - JoinType::Inner => (false, false), - JoinType::Left => (false, true), - JoinType::Right => (true, false), - JoinType::Full => (true, true), - JoinType::Cross => (true, true), - }; - let left_fields = self - .left - .output_columns() - .iter() - .map(|c| { - c.clone_with_nullable( - // if force nullable is false, use the original value - // to handle some original fields that are nullable - left_join_keys_force_nullable || c.nullable, - ) - }) - .collect::>(); - let right_fields = self - .right - .output_columns() - .iter() - .map(|c| { - c.clone_with_nullable( - // if force nullable is false, use the original value - // to handle some original fields that are nullable - right_join_keys_force_nullable || c.nullable, - ) - }) - .collect::>(); - - vec![left_fields, right_fields].concat() - } - - fn join_output_new_columns_internal(&self, base_table_id: String) -> Vec { + fn join_output_columns_internal(&self, base_table_id: String) -> Vec { let (left_join_keys_force_nullable, right_join_keys_force_nullable) = match self.join_type { JoinType::Inner => (false, false), JoinType::Left => (false, true), @@ -137,7 +102,7 @@ impl LogicalJoin { .collect::>(); let right_fields = self .right - .output_new_columns(base_table_id.clone()) + .output_new_columns(base_table_id) .iter() .map(|c| { c.clone_with_nullable( @@ -175,12 +140,8 @@ impl PlanNode for LogicalJoin { } } - fn output_columns(&self) -> Vec { - self.join_output_columns() - } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.join_output_new_columns_internal(base_table_id.clone()) + self.join_output_columns_internal(base_table_id) } fn get_based_table_id(&self) -> TableId { @@ -248,8 +209,9 @@ mod tests { let cond = build_join_condition_eq("t1", "b1", "t2", "b1"); let plan = LogicalJoin::new(t1.clone(), 
t2.clone(), JoinType::Inner, cond.clone()); + let based_table_id = plan.get_based_table_id(); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), build_columns_catalog("t2", vec!["a2", "b1", "c2"], false), @@ -259,7 +221,7 @@ mod tests { let plan = LogicalJoin::new(t1.clone(), t2.clone(), JoinType::Left, cond.clone()); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), build_columns_catalog("t2", vec!["a2", "b1", "c2"], true), @@ -269,7 +231,7 @@ mod tests { let plan = LogicalJoin::new(t1.clone(), t2.clone(), JoinType::Right, cond.clone()); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), build_columns_catalog("t2", vec!["a2", "b1", "c2"], false), @@ -279,7 +241,7 @@ mod tests { let plan = LogicalJoin::new(t1, t2, JoinType::Full, cond); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id), vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), build_columns_catalog("t2", vec!["a2", "b1", "c2"], true), @@ -326,8 +288,9 @@ mod tests { JoinType::Inner, cond2.clone(), ); + let based_table_id = plan.get_based_table_id(); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), @@ -351,7 +314,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), @@ -375,7 +338,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + 
plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -399,7 +362,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -424,7 +387,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), @@ -448,7 +411,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), @@ -472,7 +435,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -496,7 +459,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -521,7 +484,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -545,7 +508,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ -569,7 +532,7 @@ mod tests { cond2.clone(), ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id.clone()), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), @@ 
-588,7 +551,7 @@ mod tests { cond2, ); assert_eq!( - plan.join_output_columns_internal(), + plan.join_output_columns_internal(based_table_id), vec![ vec![ build_columns_catalog("t1", vec!["a1", "b1", "c1"], true), diff --git a/src/optimizer/plan_node/logical_limit.rs b/src/optimizer/plan_node/logical_limit.rs index 8a33879..a8cf969 100644 --- a/src/optimizer/plan_node/logical_limit.rs +++ b/src/optimizer/plan_node/logical_limit.rs @@ -39,10 +39,6 @@ impl PlanNode for LogicalLimit { vec![] } - fn output_columns(&self) -> Vec { - self.children()[0].output_columns() - } - fn output_new_columns(&self, base_table_id: String) -> Vec { self.children()[0].output_new_columns(base_table_id) } diff --git a/src/optimizer/plan_node/logical_order.rs b/src/optimizer/plan_node/logical_order.rs index a7ed8c4..1a725a1 100644 --- a/src/optimizer/plan_node/logical_order.rs +++ b/src/optimizer/plan_node/logical_order.rs @@ -33,10 +33,6 @@ impl PlanNode for LogicalOrder { .collect::>() } - fn output_columns(&self) -> Vec { - self.children()[0].output_columns() - } - fn output_new_columns(&self, base_table_id: String) -> Vec { self.children()[0].output_new_columns(base_table_id) } diff --git a/src/optimizer/plan_node/logical_project.rs b/src/optimizer/plan_node/logical_project.rs index d5caaf3..38a8ab6 100644 --- a/src/optimizer/plan_node/logical_project.rs +++ b/src/optimizer/plan_node/logical_project.rs @@ -29,10 +29,6 @@ impl LogicalProject { impl PlanNode for LogicalProject { fn referenced_columns(&self) -> Vec { - self.output_columns() - } - - fn output_columns(&self) -> Vec { self.exprs .iter() .flat_map(|e| e.get_referenced_column_catalog()) diff --git a/src/optimizer/plan_node/logical_table_scan.rs b/src/optimizer/plan_node/logical_table_scan.rs index c17fc37..33af2f7 100644 --- a/src/optimizer/plan_node/logical_table_scan.rs +++ b/src/optimizer/plan_node/logical_table_scan.rs @@ -59,10 +59,6 @@ impl LogicalTableScan { impl PlanNode for LogicalTableScan { fn 
referenced_columns(&self) -> Vec { - self.output_columns() - } - - fn output_columns(&self) -> Vec { self.columns() } diff --git a/src/optimizer/plan_node/mod.rs b/src/optimizer/plan_node/mod.rs index 07cfc73..09c3b5e 100644 --- a/src/optimizer/plan_node/mod.rs +++ b/src/optimizer/plan_node/mod.rs @@ -49,17 +49,14 @@ use crate::catalog::{ColumnCatalog, TableId}; pub trait PlanNode: WithPlanNodeType + PlanTreeNode + Downcast + Debug + Display + Send + Sync { - /// All columns that appears in BoundExprs from this plan node. - /// FIXME: should changed to returned BoundColumnRef which is more make sense. + /// Return column catalog that appears in BoundExprs which used in current PlanNode. fn referenced_columns(&self) -> Vec; - /// All columns that appears in output RecordBatch from this plan node. - /// FIXME: should changed to returned BoundColumnRef which is more make sense. - fn output_columns(&self) -> Vec; - /// Return output column catalog which converted from `BoundExpr`. fn output_new_columns(&self, base_table_id: String) -> Vec; + // Get this PlanNode based TableId which could be TableScan Id or Join left child based table + // id. 
fn get_based_table_id(&self) -> TableId; } impl_downcast!(PlanNode); diff --git a/src/optimizer/plan_node/physical_cross_join.rs b/src/optimizer/plan_node/physical_cross_join.rs index d342afe..6a10201 100644 --- a/src/optimizer/plan_node/physical_cross_join.rs +++ b/src/optimizer/plan_node/physical_cross_join.rs @@ -30,6 +30,10 @@ impl PhysicalCrossJoin { pub fn logical(&self) -> &LogicalJoin { &self.logical } + + pub fn join_output_columns(&self) -> Vec { + self.logical.join_output_columns() + } } impl PlanNode for PhysicalCrossJoin { @@ -37,12 +41,8 @@ impl PlanNode for PhysicalCrossJoin { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() - } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id.clone()) + self.logical().output_new_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_filter.rs b/src/optimizer/plan_node/physical_filter.rs index 56dd085..8d680ed 100644 --- a/src/optimizer/plan_node/physical_filter.rs +++ b/src/optimizer/plan_node/physical_filter.rs @@ -24,12 +24,8 @@ impl PlanNode for PhysicalFilter { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() - } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id.clone()) + self.logical().output_new_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_hash_agg.rs b/src/optimizer/plan_node/physical_hash_agg.rs index fb4f17b..4ab07c7 100644 --- a/src/optimizer/plan_node/physical_hash_agg.rs +++ b/src/optimizer/plan_node/physical_hash_agg.rs @@ -24,12 +24,8 @@ impl PlanNode for PhysicalHashAgg { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() - } - fn output_new_columns(&self, base_table_id: String) -> Vec { - 
self.logical().output_new_columns(base_table_id.clone()) + self.logical().output_new_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_hash_join.rs b/src/optimizer/plan_node/physical_hash_join.rs index 45938e5..9277144 100644 --- a/src/optimizer/plan_node/physical_hash_join.rs +++ b/src/optimizer/plan_node/physical_hash_join.rs @@ -34,6 +34,10 @@ impl PhysicalHashJoin { pub fn logical(&self) -> &LogicalJoin { &self.logical } + + pub fn join_output_columns(&self) -> Vec { + self.logical.join_output_columns() + } } impl PlanNode for PhysicalHashJoin { @@ -41,12 +45,8 @@ impl PlanNode for PhysicalHashJoin { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() - } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id.clone()) + self.logical().output_new_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_limit.rs b/src/optimizer/plan_node/physical_limit.rs index f24ef21..1785d10 100644 --- a/src/optimizer/plan_node/physical_limit.rs +++ b/src/optimizer/plan_node/physical_limit.rs @@ -24,12 +24,8 @@ impl PlanNode for PhysicalLimit { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() - } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id.clone()) + self.logical().output_new_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_order.rs b/src/optimizer/plan_node/physical_order.rs index b7d406b..37d2fa6 100644 --- a/src/optimizer/plan_node/physical_order.rs +++ b/src/optimizer/plan_node/physical_order.rs @@ -24,12 +24,8 @@ impl PlanNode for PhysicalOrder { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() - } - fn 
output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id.clone()) + self.logical().output_new_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_project.rs b/src/optimizer/plan_node/physical_project.rs index 6899f24..9564d31 100644 --- a/src/optimizer/plan_node/physical_project.rs +++ b/src/optimizer/plan_node/physical_project.rs @@ -24,12 +24,8 @@ impl PlanNode for PhysicalProject { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() - } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id.clone()) + self.logical().output_new_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_simple_agg.rs b/src/optimizer/plan_node/physical_simple_agg.rs index 839c43b..35e166f 100644 --- a/src/optimizer/plan_node/physical_simple_agg.rs +++ b/src/optimizer/plan_node/physical_simple_agg.rs @@ -24,12 +24,8 @@ impl PlanNode for PhysicalSimpleAgg { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() - } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id.clone()) + self.logical().output_new_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_table_scan.rs b/src/optimizer/plan_node/physical_table_scan.rs index d5d705f..b04584e 100644 --- a/src/optimizer/plan_node/physical_table_scan.rs +++ b/src/optimizer/plan_node/physical_table_scan.rs @@ -24,12 +24,8 @@ impl PlanNode for PhysicalTableScan { self.logical.referenced_columns() } - fn output_columns(&self) -> Vec { - self.logical.output_columns() - } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id.clone()) + 
self.logical().output_new_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/rules/column_pruning.rs b/src/optimizer/rules/column_pruning.rs index bd1174a..1a9b694 100644 --- a/src/optimizer/rules/column_pruning.rs +++ b/src/optimizer/rules/column_pruning.rs @@ -139,7 +139,7 @@ impl Rule for PushProjectThroughChild { c.output_new_columns( planner_context .find_subquery_alias(c) - .unwrap_or(c.get_based_table_id()), + .unwrap_or_else(|| c.get_based_table_id()), ) }) .collect::>(); @@ -164,8 +164,7 @@ impl Rule for PushProjectThroughChild { // `t2.v1` should be resolved in child_child_plan output_columns. let base_table_id = planner_context .find_subquery_alias(child_child_plan) - .unwrap_or(child_child_plan.get_based_table_id()) - .clone(); + .unwrap_or_else(|| child_child_plan.get_based_table_id()); let mut child_child_output_cols = child_child_plan.output_new_columns(base_table_id); // for child's child, filter corresponding required columns diff --git a/src/optimizer/rules/pushdown_predicates.rs b/src/optimizer/rules/pushdown_predicates.rs index 6bdcddf..02b4f12 100644 --- a/src/optimizer/rules/pushdown_predicates.rs +++ b/src/optimizer/rules/pushdown_predicates.rs @@ -113,8 +113,10 @@ impl Rule for PushPredicateThroughJoin { return; } - let left_output_cols = join_node.left().output_columns(); - let right_output_cols = join_node.right().output_columns(); + let left = join_node.left(); + let left_output_cols = left.output_new_columns(left.get_based_table_id()); + let right = join_node.right(); + let right_output_cols = right.output_new_columns(right.get_based_table_id()); let filter_opt_expr = opt_expr; let join_left_opt_expr = join_opt_expr.children[0].clone(); diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 6abc525..aa1aa3a 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -134,7 +134,12 @@ mod planner_test { assert!(node.is_ok()); let plan_ref = node.unwrap(); assert_eq!(plan_ref.node_type(), 
PlanNodeType::LogicalLimit); - assert_eq!(plan_ref.output_columns().len(), 1); + assert_eq!( + plan_ref + .output_new_columns(plan_ref.get_based_table_id()) + .len(), + 1 + ); dbg!(plan_ref); } From d0cb83fbf7b05df03f27da0fe4db4fef30fe9e53 Mon Sep 17 00:00:00 2001 From: Fedomn Date: Fri, 21 Oct 2022 22:04:13 +0800 Subject: [PATCH 4/4] refactor(optimizer): replace original output_columns method in PlanNode Signed-off-by: Fedomn --- src/optimizer/plan_node/dummy.rs | 2 +- src/optimizer/plan_node/logical_agg.rs | 2 +- src/optimizer/plan_node/logical_filter.rs | 4 ++-- src/optimizer/plan_node/logical_join.rs | 6 +++--- src/optimizer/plan_node/logical_limit.rs | 4 ++-- src/optimizer/plan_node/logical_order.rs | 4 ++-- src/optimizer/plan_node/logical_project.rs | 2 +- src/optimizer/plan_node/logical_table_scan.rs | 2 +- src/optimizer/plan_node/mod.rs | 2 +- src/optimizer/plan_node/physical_cross_join.rs | 4 ++-- src/optimizer/plan_node/physical_filter.rs | 4 ++-- src/optimizer/plan_node/physical_hash_agg.rs | 4 ++-- src/optimizer/plan_node/physical_hash_join.rs | 4 ++-- src/optimizer/plan_node/physical_limit.rs | 4 ++-- src/optimizer/plan_node/physical_order.rs | 4 ++-- src/optimizer/plan_node/physical_project.rs | 4 ++-- src/optimizer/plan_node/physical_simple_agg.rs | 4 ++-- src/optimizer/plan_node/physical_table_scan.rs | 4 ++-- src/optimizer/rules/column_pruning.rs | 4 ++-- src/optimizer/rules/pushdown_predicates.rs | 4 ++-- src/planner/mod.rs | 4 +--- 21 files changed, 37 insertions(+), 39 deletions(-) diff --git a/src/optimizer/plan_node/dummy.rs b/src/optimizer/plan_node/dummy.rs index e6f81da..cc46925 100644 --- a/src/optimizer/plan_node/dummy.rs +++ b/src/optimizer/plan_node/dummy.rs @@ -22,7 +22,7 @@ impl PlanNode for Dummy { vec![] } - fn output_new_columns(&self, _base_table_id: String) -> Vec { + fn output_columns(&self, _base_table_id: String) -> Vec { vec![] } diff --git a/src/optimizer/plan_node/logical_agg.rs b/src/optimizer/plan_node/logical_agg.rs 
index 01c91a4..e9ce1b5 100644 --- a/src/optimizer/plan_node/logical_agg.rs +++ b/src/optimizer/plan_node/logical_agg.rs @@ -43,7 +43,7 @@ impl PlanNode for LogicalAgg { .collect::>() } - fn output_new_columns(&self, base_table_id: String) -> Vec { + fn output_columns(&self, base_table_id: String) -> Vec { self.group_by .iter() .chain(self.agg_funcs.iter()) diff --git a/src/optimizer/plan_node/logical_filter.rs b/src/optimizer/plan_node/logical_filter.rs index 66c2b68..1d2cbf5 100644 --- a/src/optimizer/plan_node/logical_filter.rs +++ b/src/optimizer/plan_node/logical_filter.rs @@ -32,8 +32,8 @@ impl PlanNode for LogicalFilter { self.expr.get_referenced_column_catalog() } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.children()[0].output_new_columns(base_table_id) + fn output_columns(&self, base_table_id: String) -> Vec { + self.children()[0].output_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/logical_join.rs b/src/optimizer/plan_node/logical_join.rs index 079c3f5..b8c0b85 100644 --- a/src/optimizer/plan_node/logical_join.rs +++ b/src/optimizer/plan_node/logical_join.rs @@ -90,7 +90,7 @@ impl LogicalJoin { }; let left_fields = self .left - .output_new_columns(base_table_id.clone()) + .output_columns(base_table_id.clone()) .iter() .map(|c| { c.clone_with_nullable( @@ -102,7 +102,7 @@ impl LogicalJoin { .collect::>(); let right_fields = self .right - .output_new_columns(base_table_id) + .output_columns(base_table_id) .iter() .map(|c| { c.clone_with_nullable( @@ -140,7 +140,7 @@ impl PlanNode for LogicalJoin { } } - fn output_new_columns(&self, base_table_id: String) -> Vec { + fn output_columns(&self, base_table_id: String) -> Vec { self.join_output_columns_internal(base_table_id) } diff --git a/src/optimizer/plan_node/logical_limit.rs b/src/optimizer/plan_node/logical_limit.rs index a8cf969..46bfd3d 100644 --- a/src/optimizer/plan_node/logical_limit.rs +++ 
b/src/optimizer/plan_node/logical_limit.rs @@ -39,8 +39,8 @@ impl PlanNode for LogicalLimit { vec![] } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.children()[0].output_new_columns(base_table_id) + fn output_columns(&self, base_table_id: String) -> Vec { + self.children()[0].output_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/logical_order.rs b/src/optimizer/plan_node/logical_order.rs index 1a725a1..b9555d7 100644 --- a/src/optimizer/plan_node/logical_order.rs +++ b/src/optimizer/plan_node/logical_order.rs @@ -33,8 +33,8 @@ impl PlanNode for LogicalOrder { .collect::>() } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.children()[0].output_new_columns(base_table_id) + fn output_columns(&self, base_table_id: String) -> Vec { + self.children()[0].output_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/logical_project.rs b/src/optimizer/plan_node/logical_project.rs index 38a8ab6..3a56edc 100644 --- a/src/optimizer/plan_node/logical_project.rs +++ b/src/optimizer/plan_node/logical_project.rs @@ -35,7 +35,7 @@ impl PlanNode for LogicalProject { .collect::>() } - fn output_new_columns(&self, base_table_id: String) -> Vec { + fn output_columns(&self, base_table_id: String) -> Vec { self.exprs .iter() .map(|e| e.output_column_catalog_for_alias_table(base_table_id.clone())) diff --git a/src/optimizer/plan_node/logical_table_scan.rs b/src/optimizer/plan_node/logical_table_scan.rs index 33af2f7..6335d36 100644 --- a/src/optimizer/plan_node/logical_table_scan.rs +++ b/src/optimizer/plan_node/logical_table_scan.rs @@ -62,7 +62,7 @@ impl PlanNode for LogicalTableScan { self.columns() } - fn output_new_columns(&self, _: String) -> Vec { + fn output_columns(&self, _: String) -> Vec { if let Some(alias) = self.table_alias() { self.columns() .iter() diff --git a/src/optimizer/plan_node/mod.rs b/src/optimizer/plan_node/mod.rs 
index 09c3b5e..c6bda56 100644 --- a/src/optimizer/plan_node/mod.rs +++ b/src/optimizer/plan_node/mod.rs @@ -53,7 +53,7 @@ pub trait PlanNode: fn referenced_columns(&self) -> Vec; /// Return output column catalog which converted from `BoundExpr`. - fn output_new_columns(&self, base_table_id: String) -> Vec; + fn output_columns(&self, base_table_id: String) -> Vec; // Get this PlanNode based TableId which could be TableScan Id or Join left child based table // id. diff --git a/src/optimizer/plan_node/physical_cross_join.rs b/src/optimizer/plan_node/physical_cross_join.rs index 6a10201..137abff 100644 --- a/src/optimizer/plan_node/physical_cross_join.rs +++ b/src/optimizer/plan_node/physical_cross_join.rs @@ -41,8 +41,8 @@ impl PlanNode for PhysicalCrossJoin { self.logical.referenced_columns() } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id) + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_filter.rs b/src/optimizer/plan_node/physical_filter.rs index 8d680ed..bb27b81 100644 --- a/src/optimizer/plan_node/physical_filter.rs +++ b/src/optimizer/plan_node/physical_filter.rs @@ -24,8 +24,8 @@ impl PlanNode for PhysicalFilter { self.logical.referenced_columns() } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id) + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_hash_agg.rs b/src/optimizer/plan_node/physical_hash_agg.rs index 4ab07c7..977be58 100644 --- a/src/optimizer/plan_node/physical_hash_agg.rs +++ b/src/optimizer/plan_node/physical_hash_agg.rs @@ -24,8 +24,8 @@ impl PlanNode for PhysicalHashAgg { self.logical.referenced_columns() } - fn 
output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id) + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_hash_join.rs b/src/optimizer/plan_node/physical_hash_join.rs index 9277144..fe60023 100644 --- a/src/optimizer/plan_node/physical_hash_join.rs +++ b/src/optimizer/plan_node/physical_hash_join.rs @@ -45,8 +45,8 @@ impl PlanNode for PhysicalHashJoin { self.logical.referenced_columns() } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id) + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_limit.rs b/src/optimizer/plan_node/physical_limit.rs index 1785d10..5f285d6 100644 --- a/src/optimizer/plan_node/physical_limit.rs +++ b/src/optimizer/plan_node/physical_limit.rs @@ -24,8 +24,8 @@ impl PlanNode for PhysicalLimit { self.logical.referenced_columns() } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id) + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_order.rs b/src/optimizer/plan_node/physical_order.rs index 37d2fa6..e3298c4 100644 --- a/src/optimizer/plan_node/physical_order.rs +++ b/src/optimizer/plan_node/physical_order.rs @@ -24,8 +24,8 @@ impl PlanNode for PhysicalOrder { self.logical.referenced_columns() } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id) + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) } fn get_based_table_id(&self) -> 
TableId { diff --git a/src/optimizer/plan_node/physical_project.rs b/src/optimizer/plan_node/physical_project.rs index 9564d31..aa3e73b 100644 --- a/src/optimizer/plan_node/physical_project.rs +++ b/src/optimizer/plan_node/physical_project.rs @@ -24,8 +24,8 @@ impl PlanNode for PhysicalProject { self.logical.referenced_columns() } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id) + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_simple_agg.rs b/src/optimizer/plan_node/physical_simple_agg.rs index 35e166f..6b521d4 100644 --- a/src/optimizer/plan_node/physical_simple_agg.rs +++ b/src/optimizer/plan_node/physical_simple_agg.rs @@ -24,8 +24,8 @@ impl PlanNode for PhysicalSimpleAgg { self.logical.referenced_columns() } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id) + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/plan_node/physical_table_scan.rs b/src/optimizer/plan_node/physical_table_scan.rs index b04584e..1633380 100644 --- a/src/optimizer/plan_node/physical_table_scan.rs +++ b/src/optimizer/plan_node/physical_table_scan.rs @@ -24,8 +24,8 @@ impl PlanNode for PhysicalTableScan { self.logical.referenced_columns() } - fn output_new_columns(&self, base_table_id: String) -> Vec { - self.logical().output_new_columns(base_table_id) + fn output_columns(&self, base_table_id: String) -> Vec { + self.logical().output_columns(base_table_id) } fn get_based_table_id(&self) -> TableId { diff --git a/src/optimizer/rules/column_pruning.rs b/src/optimizer/rules/column_pruning.rs index 1a9b694..c761548 100644 --- a/src/optimizer/rules/column_pruning.rs +++ 
b/src/optimizer/rules/column_pruning.rs @@ -136,7 +136,7 @@ impl Rule for PushProjectThroughChild { .children() .iter() .flat_map(|c| { - c.output_new_columns( + c.output_columns( planner_context .find_subquery_alias(c) .unwrap_or_else(|| c.get_based_table_id()), @@ -166,7 +166,7 @@ impl Rule for PushProjectThroughChild { .find_subquery_alias(child_child_plan) .unwrap_or_else(|| child_child_plan.get_based_table_id()); let mut child_child_output_cols = - child_child_plan.output_new_columns(base_table_id); + child_child_plan.output_columns(base_table_id); // for child's child, filter corresponding required columns let mut required_cols_in_child_child = child_child_output_cols .clone() diff --git a/src/optimizer/rules/pushdown_predicates.rs b/src/optimizer/rules/pushdown_predicates.rs index 02b4f12..a294ea2 100644 --- a/src/optimizer/rules/pushdown_predicates.rs +++ b/src/optimizer/rules/pushdown_predicates.rs @@ -114,9 +114,9 @@ impl Rule for PushPredicateThroughJoin { } let left = join_node.left(); - let left_output_cols = left.output_new_columns(left.get_based_table_id()); + let left_output_cols = left.output_columns(left.get_based_table_id()); let right = join_node.right(); - let right_output_cols = right.output_new_columns(right.get_based_table_id()); + let right_output_cols = right.output_columns(right.get_based_table_id()); let filter_opt_expr = opt_expr; let join_left_opt_expr = join_opt_expr.children[0].clone(); diff --git a/src/planner/mod.rs b/src/planner/mod.rs index aa1aa3a..2348b1e 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -135,9 +135,7 @@ mod planner_test { let plan_ref = node.unwrap(); assert_eq!(plan_ref.node_type(), PlanNodeType::LogicalLimit); assert_eq!( - plan_ref - .output_new_columns(plan_ref.get_based_table_id()) - .len(), + plan_ref.output_columns(plan_ref.get_based_table_id()).len(), 1 ); dbg!(plan_ref);