diff --git a/Makefile b/Makefile index 67a31a8..245af80 100644 --- a/Makefile +++ b/Makefile @@ -31,3 +31,6 @@ clean: run: cargo run --release + +debug: + RUST_BACKTRACE=1 cargo run diff --git a/src/binder/expression/agg_func.rs b/src/binder/expression/agg_func.rs index ed70783..97a3516 100644 --- a/src/binder/expression/agg_func.rs +++ b/src/binder/expression/agg_func.rs @@ -6,7 +6,7 @@ use sqlparser::ast::{Function, FunctionArg, FunctionArgExpr}; use super::BoundExpr; use crate::binder::{BindError, Binder}; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum AggFunc { Count, Sum, @@ -25,7 +25,7 @@ impl fmt::Display for AggFunc { } } -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct BoundAggFunc { pub func: AggFunc, pub exprs: Vec, diff --git a/src/binder/expression/binary_op.rs b/src/binder/expression/binary_op.rs index c306a3f..9d5d2fa 100644 --- a/src/binder/expression/binary_op.rs +++ b/src/binder/expression/binary_op.rs @@ -6,7 +6,7 @@ use sqlparser::ast::{BinaryOperator, Expr}; use super::BoundExpr; use crate::binder::{BindError, Binder, BoundTypeCast}; -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct BoundBinaryOp { pub op: BinaryOperator, pub left: Box, diff --git a/src/binder/expression/mod.rs b/src/binder/expression/mod.rs index 9dc9741..f67046f 100644 --- a/src/binder/expression/mod.rs +++ b/src/binder/expression/mod.rs @@ -9,10 +9,10 @@ use itertools::Itertools; use sqlparser::ast::{Expr, Ident}; use super::{BindError, Binder}; -use crate::catalog::ColumnCatalog; +use crate::catalog::{ColumnCatalog, ColumnId, TableId}; use crate::types::ScalarValue; -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Hash)] pub enum BoundExpr { Constant(ScalarValue), ColumnRef(BoundColumnRef), @@ -24,6 +24,18 @@ pub enum BoundExpr { } impl BoundExpr { + pub fn nullable(&self) -> bool { + match self { + BoundExpr::Constant(_) => false, + 
BoundExpr::ColumnRef(e) => e.column_catalog.nullable, + BoundExpr::InputRef(_) => unreachable!(), + BoundExpr::BinaryOp(e) => e.left.nullable() && e.right.nullable(), + BoundExpr::TypeCast(e) => e.expr.nullable(), + BoundExpr::AggFunc(e) => e.exprs[0].nullable(), + BoundExpr::Alias(e) => e.expr.nullable(), + } + } + pub fn return_type(&self) -> Option { match self { BoundExpr::Constant(value) => Some(value.data_type()), @@ -52,51 +64,96 @@ impl BoundExpr { } } - pub fn get_column_catalog(&self) -> Vec { + pub fn get_referenced_column_catalog(&self) -> Vec { match self { BoundExpr::Constant(_) => vec![], BoundExpr::InputRef(_) => vec![], BoundExpr::ColumnRef(column_ref) => vec![column_ref.column_catalog.clone()], BoundExpr::BinaryOp(binary_op) => binary_op .left - .get_column_catalog() + .get_referenced_column_catalog() .into_iter() - .chain(binary_op.right.get_column_catalog().into_iter()) + .chain(binary_op.right.get_referenced_column_catalog().into_iter()) .collect::>(), - BoundExpr::TypeCast(tc) => tc.expr.get_column_catalog(), + BoundExpr::TypeCast(tc) => tc.expr.get_referenced_column_catalog(), BoundExpr::AggFunc(agg) => agg .exprs .iter() - .flat_map(|arg| arg.get_column_catalog()) + .flat_map(|arg| arg.get_referenced_column_catalog()) .collect::>(), - BoundExpr::Alias(alias) => alias.expr.get_column_catalog(), + BoundExpr::Alias(alias) => alias.expr.get_referenced_column_catalog(), } } + + /// Generate a new column catalog in table alias or subquery for outside referenced. + /// Such as `t.v` in subquery: select t.v from (select a as v from t1) t. 
+ pub fn output_column_catalog_for_alias_table(&self, alias_table_id: String) -> ColumnCatalog { + let (column_id, data_type) = match self { + BoundExpr::Constant(e) => (e.to_string(), e.data_type()), + BoundExpr::ColumnRef(e) => ( + e.column_catalog.column_id.clone(), + e.column_catalog.desc.data_type.clone(), + ), + BoundExpr::InputRef(_) => unreachable!(), + BoundExpr::BinaryOp(e) => { + let l = e + .left + .output_column_catalog_for_alias_table(alias_table_id.clone()); + let r = e + .right + .output_column_catalog_for_alias_table(alias_table_id.clone()); + let column_id = format!("{}{}{}", l.column_id, e.op, r.column_id); + let data_type = e.return_type.clone().unwrap(); + (column_id, data_type) + } + BoundExpr::TypeCast(e) => { + let c = e + .expr + .output_column_catalog_for_alias_table(alias_table_id.clone()); + let column_id = format!("{}({})", e.cast_type, c.column_id); + let data_type = e.cast_type.clone(); + (column_id, data_type) + } + BoundExpr::AggFunc(agg) => { + let c = agg.exprs[0].output_column_catalog_for_alias_table(alias_table_id.clone()); + let column_id = format!("{}({})", agg.func, c.column_id); + let data_type = agg.return_type.clone(); + (column_id, data_type) + } + BoundExpr::Alias(e) => { + let column_id = e.column_id.to_string(); + let data_type = e.expr.return_type().unwrap(); + (column_id, data_type) + } + }; + ColumnCatalog::new(alias_table_id, column_id, self.nullable(), data_type) + } } -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct BoundColumnRef { pub column_catalog: ColumnCatalog, } -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct BoundInputRef { /// column index in data chunk pub index: usize, pub return_type: DataType, } -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct BoundTypeCast { /// original expression pub expr: Box, pub cast_type: DataType, } -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Hash)] pub 
struct BoundAlias { pub expr: Box, - pub alias: String, + pub column_id: ColumnId, + pub table_id: TableId, } impl Binder { @@ -185,7 +242,13 @@ impl fmt::Debug for BoundExpr { BoundExpr::BinaryOp(binary_op) => write!(f, "{:?}", binary_op), BoundExpr::TypeCast(type_cast) => write!(f, "{:?}", type_cast), BoundExpr::AggFunc(agg_func) => write!(f, "{:?}", agg_func), - BoundExpr::Alias(alias) => write!(f, "{:?} as {}", alias.expr, alias.alias), + BoundExpr::Alias(alias) => { + write!( + f, + "({:?}) as {}.{}", + alias.expr, alias.table_id, alias.column_id + ) + } } } } diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 3cab609..2c6d36b 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -22,6 +22,7 @@ struct BinderContext { /// table_id -> table_catalog tables: HashMap, aliases: HashMap, + subquery_base_index: usize, } impl Binder { @@ -57,6 +58,8 @@ pub enum BindError { AmbiguousColumn(String), #[error("binary operator types mismatch: {0} != {1}")] BinaryOpTypeMismatch(String, String), + #[error("subquery in FROM must have an alias")] + SubqueryMustHaveAlias, } #[cfg(test)] @@ -108,8 +111,8 @@ mod binder_test { BoundStatement::Select(select) => { assert_eq!(select.select_list.len(), 2); assert!(select.from_table.is_some()); - if let BoundTableRef::Table(table_catalog) = select.from_table.unwrap() { - assert_eq!(table_catalog.id, "t1"); + if let BoundTableRef::Table(table) = select.from_table.unwrap() { + assert_eq!(table.catalog.id, "t1"); } } } @@ -352,7 +355,10 @@ pub mod test_util { } pub fn build_table_ref(table_id: &str, columns: Vec<&str>) -> BoundTableRef { - BoundTableRef::Table(build_table_catalog(table_id, columns)) + BoundTableRef::Table(BoundSimpleTable::new( + build_table_catalog(table_id, columns), + None, + )) } pub fn build_table_ref_box(table_id: &str, columns: Vec<&str>) -> Box { diff --git a/src/binder/statement/mod.rs b/src/binder/statement/mod.rs index eca1a9a..66fcc1c 100644 --- a/src/binder/statement/mod.rs +++ 
b/src/binder/statement/mod.rs @@ -3,7 +3,7 @@ use sqlparser::ast::{Join, JoinOperator, Query, SelectItem, TableWithJoins}; use super::expression::BoundExpr; use super::table::BoundTableRef; -use super::{BindError, Binder, BoundAlias, BoundColumnRef}; +use super::{BindError, Binder, BoundAlias, BoundColumnRef, EMPTY_DATABASE_ID}; #[derive(Debug)] pub enum BoundStatement { @@ -57,6 +57,10 @@ impl Binder { } else { Some(self.bind_table_with_joins(&select.from[0])?) }; + let bound_table_id = from_table + .clone() + .map(|t| t.bound_table_id()) + .unwrap_or_else(|| EMPTY_DATABASE_ID.to_string()); // bind select list let mut select_list = vec![]; @@ -71,7 +75,8 @@ impl Binder { self.context.aliases.insert(alias.to_string(), expr.clone()); select_list.push(BoundExpr::Alias(BoundAlias { expr: Box::new(expr), - alias: alias.to_string().to_lowercase(), + column_id: alias.to_string().to_lowercase(), + table_id: bound_table_id.clone(), })); } SelectItem::QualifiedWildcard(object_name) => { diff --git a/src/binder/table/mod.rs b/src/binder/table/mod.rs index 8128f6f..f2b86a3 100644 --- a/src/binder/table/mod.rs +++ b/src/binder/table/mod.rs @@ -1,30 +1,72 @@ mod join; +mod subquery; + +use std::fmt::{self}; pub use join::*; use sqlparser::ast::{TableFactor, TableWithJoins}; +pub use subquery::*; -use super::{BindError, Binder, BoundSelect}; -use crate::binder::BoundExpr::ColumnRef; +use super::{BindError, Binder}; use crate::catalog::{ColumnCatalog, ColumnId, TableCatalog, TableId}; pub static DEFAULT_DATABASE_NAME: &str = "postgres"; pub static DEFAULT_SCHEMA_NAME: &str = "postgres"; +pub static EMPTY_DATABASE_ID: &str = "empty-database-id"; #[derive(Debug, Clone, PartialEq)] pub enum BoundTableRef { - Table(TableCatalog), + Table(BoundSimpleTable), Join(Join), - Subquery(Box), + Subquery(BoundSubquery), +} + +#[derive(Clone, PartialEq, Eq)] +pub struct BoundSimpleTable { + pub catalog: TableCatalog, + pub alias: Option, +} + +impl BoundSimpleTable { + pub fn new(catalog: 
TableCatalog, alias: Option) -> Self { + Self { catalog, alias } + } + + pub fn table_id(&self) -> TableId { + self.alias + .clone() + .unwrap_or_else(|| self.catalog.id.clone()) + } + + pub fn schema(&self) -> TableSchema { + let table_id = self.table_id(); + let columns = self + .catalog + .get_all_columns() + .into_iter() + .map(|c| (table_id.clone(), c.column_id)) + .collect(); + TableSchema { columns } + } } impl BoundTableRef { pub fn schema(&self) -> TableSchema { match self { - BoundTableRef::Table(catalog) => TableSchema::new(catalog.clone()), + BoundTableRef::Table(table) => table.schema(), BoundTableRef::Join(join) => { TableSchema::new_from_join(&join.left.schema(), &join.right.schema()) } - BoundTableRef::Subquery(subquery) => subquery.from_table.clone().unwrap().schema(), + BoundTableRef::Subquery(subquery) => subquery.schema(), + } + } + + /// Bound table id, if table alias exists, use alias as id + pub fn bound_table_id(&self) -> TableId { + match self { + BoundTableRef::Table(table) => table.table_id(), + BoundTableRef::Join(join) => join.left.bound_table_id(), + BoundTableRef::Subquery(subquery) => subquery.alias.clone(), } } } @@ -36,10 +78,9 @@ pub struct TableSchema { } impl TableSchema { - pub fn new(table_catalog: TableCatalog) -> Self { + pub fn new_from_columns(columns: Vec) -> Self { Self { - columns: table_catalog - .get_all_columns() + columns: columns .into_iter() .map(|c| (c.table_id, c.column_id)) .collect(), @@ -111,49 +152,67 @@ impl Binder { }; let table_name = table.to_string(); - let table_catalog = self + let mut table_catalog = self .catalog .get_table_by_name(table) .ok_or_else(|| BindError::InvalidTable(table_name.clone()))?; + let mut table_alias = None; if let Some(alias) = alias { - let table_alias = alias.to_string().to_lowercase(); + // add alias table in table catalog for later column binding + // such as: select sum(t.a) as c1 from t1 as t + let table_alias_str = alias.to_string().to_lowercase(); + // we only change 
column's table_id to table_alias, keep original real table_id + // for storage layer lookup corresponding file + table_catalog = + table_catalog.clone_with_new_column_table_id(table_alias_str.clone()); self.context .tables - .insert(table_alias, table_catalog.clone()); + .insert(table_alias_str.clone(), table_catalog.clone()); + table_alias = Some(table_alias_str); } else { self.context .tables .insert(table_name, table_catalog.clone()); } - Ok(BoundTableRef::Table(table_catalog)) + Ok(BoundTableRef::Table(BoundSimpleTable::new( + table_catalog, + table_alias, + ))) } TableFactor::Derived { lateral: _, subquery, alias, } => { - // handle subquery table - let table = self.bind_select(subquery)?; - if let Some(alias) = alias { - // add subquery into context - let columns = table - .select_list - .iter() - .map(|expr| match expr { - ColumnRef(col) => col.column_catalog.clone(), - _ => { - unreachable!("subquery select list should only contains column ref") - } - }) - .collect::>(); - let table_alias = alias.to_string().to_lowercase(); - let table_catalog = - TableCatalog::new_from_columns(table_alias.clone(), columns); - self.context.tables.insert(table_alias, table_catalog); - } - Ok(BoundTableRef::Subquery(Box::new(table))) + // handle subquery as source + let query = self.bind_select(subquery)?; + let alias = alias + .clone() + .map(|a| a.to_string().to_lowercase()) + .ok_or(BindError::SubqueryMustHaveAlias)?; + let mut subquery = BoundSubquery::new(Box::new(query), alias.clone()); + + // add subquery output columns into context + let subquery_catalog = subquery.gen_table_catalog_for_outside_reference(); + self.context.tables.insert(alias, subquery_catalog); + + // add BoundAlias for all subquery columns + subquery.bind_alias_to_all_columns(); + + Ok(BoundTableRef::Subquery(subquery)) } _other => panic!("unsupported table factor: {:?}", _other), } } } + +impl fmt::Debug for BoundSimpleTable { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let alias = if 
let Some(alias) = &self.alias { + format!(" as {}", alias) + } else { + "".to_string() + }; + write!(f, r#"{:?}{}"#, self.catalog, alias) + } +} diff --git a/src/binder/table/subquery.rs b/src/binder/table/subquery.rs new file mode 100644 index 0000000..935b529 --- /dev/null +++ b/src/binder/table/subquery.rs @@ -0,0 +1,60 @@ +use crate::binder::{Binder, BoundAlias, BoundExpr, BoundSelect, TableSchema}; +use crate::catalog::{ColumnCatalog, TableCatalog, TableId}; + +#[derive(Clone, Debug, PartialEq)] +pub struct BoundSubquery { + pub query: Box, + /// subquery always has a alias, if not, we will generate a alias number + pub alias: TableId, +} + +impl BoundSubquery { + pub fn new(query: Box, alias: TableId) -> Self { + Self { query, alias } + } + + fn get_output_columns(&self) -> Vec { + self.query + .select_list + .iter() + .map(|expr| expr.output_column_catalog_for_alias_table(self.alias.clone())) + .collect::>() + } + + pub fn gen_table_catalog_for_outside_reference(&self) -> TableCatalog { + let subquery_output_columns = self.get_output_columns(); + TableCatalog::new_from_columns(self.alias.clone(), subquery_output_columns) + } + + pub fn schema(&self) -> TableSchema { + TableSchema::new_from_columns(self.get_output_columns()) + } + + pub fn bind_alias_to_all_columns(&mut self) { + let table_catalog = self.gen_table_catalog_for_outside_reference(); + let column_catalog = table_catalog.get_all_columns(); + let new_subquery_select_list_with_alias = self + .query + .select_list + .iter() + .enumerate() + .map(|(idx, expr)| { + let column_catalog = column_catalog[idx].clone(); + BoundExpr::Alias(BoundAlias { + expr: Box::new(expr.clone()), + column_id: column_catalog.column_id, + table_id: column_catalog.table_id, + }) + }) + .collect::>(); + self.query.select_list = new_subquery_select_list_with_alias; + } +} + +impl Binder { + pub fn gen_subquery_table_id(&mut self) -> String { + let id = format!("subquery_{}", self.context.subquery_base_index); + 
self.context.subquery_base_index += 1; + id + } +} diff --git a/src/catalog/mod.rs b/src/catalog/mod.rs index 1258a45..08f72dd 100644 --- a/src/catalog/mod.rs +++ b/src/catalog/mod.rs @@ -1,5 +1,6 @@ use std::collections::{BTreeMap, HashMap}; use std::fmt; +use std::hash::Hash; use std::sync::Arc; use arrow::datatypes::{DataType, Field}; @@ -67,6 +68,20 @@ impl TableCatalog { columns: columns_tree, } } + + /// Only change column catalog table id to alias, keep original id + pub fn clone_with_new_column_table_id(&self, table_id: String) -> Self { + let mut columns_tree = BTreeMap::new(); + for c in self.get_all_columns() { + columns_tree.insert(c.column_id.clone(), c.clone_with_table_id(table_id.clone())); + } + TableCatalog { + id: self.id.clone(), + name: self.name.clone(), + column_ids: self.column_ids.clone(), + columns: columns_tree, + } + } } /// use column name as id for simplicity @@ -81,6 +96,32 @@ pub struct ColumnCatalog { } impl ColumnCatalog { + pub fn new( + table_id: TableId, + column_id: ColumnId, + nullable: bool, + data_type: DataType, + ) -> Self { + Self { + table_id, + column_id: column_id.clone(), + nullable, + desc: ColumnDesc { + name: column_id, + data_type, + }, + } + } + + pub fn clone_with_table_id(&self, table_id: TableId) -> Self { + Self { + table_id, + column_id: self.column_id.clone(), + nullable: self.nullable, + desc: self.desc.clone(), + } + } + pub fn clone_with_nullable(&self, nullable: bool) -> ColumnCatalog { let mut c = self.clone(); c.nullable = nullable; @@ -103,7 +144,14 @@ impl PartialEq for ColumnCatalog { } } -#[derive(Debug, Clone, PartialEq, Eq)] +impl Hash for ColumnCatalog { + fn hash(&self, state: &mut H) { + self.table_id.hash(state); + self.column_id.hash(state); + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ColumnDesc { pub name: String, pub data_type: DataType, @@ -124,11 +172,9 @@ impl fmt::Debug for TableCatalog { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - r#"{} {{ 
- columns: {:?} -}}"#, + r#"{} {{ columns: {:?} }}"#, self.id, - self.get_all_columns() + self.get_all_columns(), ) } } diff --git a/src/db.rs b/src/db.rs index 4cb1e21..aa9f17a 100644 --- a/src/db.rs +++ b/src/db.rs @@ -120,7 +120,7 @@ impl Database { let catalog = storage.get_catalog(); let mut binder = Binder::new(Arc::new(catalog)); let bound_stmt = binder.bind(&stats[0])?; - println!("bound_stmt:\n{:?}\n", bound_stmt); + println!("bound_stmt:\n{:#?}\n", bound_stmt); // 3. convert bound stmts to logical plan let planner = Planner {}; diff --git a/src/executor/evaluator.rs b/src/executor/evaluator.rs index 2e34591..30f5dc0 100644 --- a/src/executor/evaluator.rs +++ b/src/executor/evaluator.rs @@ -54,7 +54,7 @@ impl BoundExpr { Field::new(new_name.as_str(), agg.return_type.clone(), true) } BoundExpr::Alias(alias) => { - let new_name = alias.alias.to_string(); + let new_name = alias.column_id.to_string(); let data_type = alias.expr.return_type().unwrap(); Field::new(new_name.as_str(), data_type, true) } diff --git a/src/optimizer/heuristic/graph.rs b/src/optimizer/heuristic/graph.rs index 74437cb..21743d9 100644 --- a/src/optimizer/heuristic/graph.rs +++ b/src/optimizer/heuristic/graph.rs @@ -201,6 +201,7 @@ mod tests { fn build_logical_table_scan(table_id: &str) -> LogicalTableScan { LogicalTableScan::new( table_id.to_string(), + None, vec![ build_column_catalog(table_id, "c1"), build_column_catalog(table_id, "c2"), diff --git a/src/optimizer/heuristic/matcher.rs b/src/optimizer/heuristic/matcher.rs index 16a20c6..64470b9 100644 --- a/src/optimizer/heuristic/matcher.rs +++ b/src/optimizer/heuristic/matcher.rs @@ -94,6 +94,7 @@ mod tests { fn build_logical_table_scan(table_id: &str) -> LogicalTableScan { LogicalTableScan::new( table_id.to_string(), + None, vec![ build_column_catalog(table_id, "c1"), build_column_catalog(table_id, "c2"), diff --git a/src/optimizer/heuristic/optimizer.rs b/src/optimizer/heuristic/optimizer.rs index 1ac5e1e..dbd4e57 100644 --- 
a/src/optimizer/heuristic/optimizer.rs +++ b/src/optimizer/heuristic/optimizer.rs @@ -136,6 +136,7 @@ mod tests { fn build_logical_table_scan(table_id: &str) -> LogicalTableScan { LogicalTableScan::new( table_id.to_string(), + None, vec![ build_column_catalog(table_id, "c1"), build_column_catalog(table_id, "c2"), diff --git a/src/optimizer/input_ref_rewriter.rs b/src/optimizer/input_ref_rewriter.rs index 714a400..0509a4c 100644 --- a/src/optimizer/input_ref_rewriter.rs +++ b/src/optimizer/input_ref_rewriter.rs @@ -27,6 +27,25 @@ impl InputRefRewriter { return; } + // Find alias expr in bindings. + if let Some(idx) = self.bindings.iter().position(|e| { + if let BoundExpr::Alias(alias) = e { + let column_catalog = + e.output_column_catalog_for_alias_table(alias.table_id.clone()); + let alias_expr = &BoundExpr::ColumnRef(BoundColumnRef { column_catalog }); + if expr == alias_expr { + return true; + } + } + false + }) { + *expr = BoundExpr::InputRef(BoundInputRef { + index: idx, + return_type: expr.return_type().unwrap(), + }); + return; + } + // If not found in bindings, expand nested expr and then continuity rewrite_expr. 
match expr { BoundExpr::BinaryOp(e) => { @@ -314,6 +333,7 @@ mod input_ref_rewriter_test { fn build_logical_table_scan(table_id: &str) -> LogicalTableScan { LogicalTableScan::new( table_id.to_string(), + None, vec![ build_column_catalog(table_id, "c1"), build_column_catalog(table_id, "c2"), diff --git a/src/optimizer/physical_rewriter.rs b/src/optimizer/physical_rewriter.rs index 5259baa..dcca9bb 100644 --- a/src/optimizer/physical_rewriter.rs +++ b/src/optimizer/physical_rewriter.rs @@ -111,7 +111,7 @@ mod physical_rewriter_test { ] .to_vec(); let mut plan: PlanRef; - plan = Arc::new(LogicalTableScan::new(table_id, columns, None, None)); + plan = Arc::new(LogicalTableScan::new(table_id, None, columns, None, None)); let filter_expr = BoundExpr::BinaryOp(BoundBinaryOp { op: BinaryOperator::Eq, left: Box::new(BoundExpr::ColumnRef(BoundColumnRef { diff --git a/src/optimizer/plan_node/logical_agg.rs b/src/optimizer/plan_node/logical_agg.rs index 3d36e13..cfd4d11 100644 --- a/src/optimizer/plan_node/logical_agg.rs +++ b/src/optimizer/plan_node/logical_agg.rs @@ -43,7 +43,7 @@ impl PlanNode for LogicalAgg { self.group_by .iter() .chain(self.agg_funcs.iter()) - .flat_map(|e| e.get_column_catalog()) + .flat_map(|e| e.get_referenced_column_catalog()) .collect::>() } } diff --git a/src/optimizer/plan_node/logical_filter.rs b/src/optimizer/plan_node/logical_filter.rs index 0c41612..8b9604b 100644 --- a/src/optimizer/plan_node/logical_filter.rs +++ b/src/optimizer/plan_node/logical_filter.rs @@ -29,7 +29,7 @@ impl LogicalFilter { impl PlanNode for LogicalFilter { fn referenced_columns(&self) -> Vec { - self.expr.get_column_catalog() + self.expr.get_referenced_column_catalog() } fn output_columns(&self) -> Vec { diff --git a/src/optimizer/plan_node/logical_join.rs b/src/optimizer/plan_node/logical_join.rs index 45ff931..340df5c 100644 --- a/src/optimizer/plan_node/logical_join.rs +++ b/src/optimizer/plan_node/logical_join.rs @@ -122,9 +122,17 @@ impl PlanNode for LogicalJoin { 
JoinCondition::On { on, filter } => { let on_cols = on .iter() - .flat_map(|e| [e.0.get_column_catalog(), e.1.get_column_catalog()].concat()) + .flat_map(|e| { + [ + e.0.get_referenced_column_catalog(), + e.1.get_referenced_column_catalog(), + ] + .concat() + }) .collect::>(); - let filter_cols = filter.map(|f| f.get_column_catalog()).unwrap_or_default(); + let filter_cols = filter + .map(|f| f.get_referenced_column_catalog()) + .unwrap_or_default(); [on_cols, filter_cols].concat() } JoinCondition::None => vec![], @@ -181,12 +189,14 @@ mod tests { fn test_join_output_schema_when_two_tables() { let t1 = Arc::new(LogicalTableScan::new( "t1".to_string(), + None, build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), None, None, )); let t2 = Arc::new(LogicalTableScan::new( "t2".to_string(), + None, build_columns_catalog("t2", vec!["a2", "b1", "c2"], false), None, None, @@ -238,18 +248,21 @@ mod tests { fn test_join_output_schema_when_three_tables() { let t1 = Arc::new(LogicalTableScan::new( "t1".to_string(), + None, build_columns_catalog("t1", vec!["a1", "b1", "c1"], false), None, None, )); let t2 = Arc::new(LogicalTableScan::new( "t2".to_string(), + None, build_columns_catalog("t2", vec!["a2", "b1", "c2"], false), None, None, )); let t3 = Arc::new(LogicalTableScan::new( "t3".to_string(), + None, build_columns_catalog("t3", vec!["a3", "b3", "c1"], false), None, None, diff --git a/src/optimizer/plan_node/logical_order.rs b/src/optimizer/plan_node/logical_order.rs index 358fab4..f11f8cc 100644 --- a/src/optimizer/plan_node/logical_order.rs +++ b/src/optimizer/plan_node/logical_order.rs @@ -29,7 +29,7 @@ impl PlanNode for LogicalOrder { fn referenced_columns(&self) -> Vec { self.order_by .iter() - .flat_map(|e| e.expr.get_column_catalog()) + .flat_map(|e| e.expr.get_referenced_column_catalog()) .collect::>() } diff --git a/src/optimizer/plan_node/logical_project.rs b/src/optimizer/plan_node/logical_project.rs index a5ec869..d53d393 100644 --- 
a/src/optimizer/plan_node/logical_project.rs +++ b/src/optimizer/plan_node/logical_project.rs @@ -35,7 +35,7 @@ impl PlanNode for LogicalProject { fn output_columns(&self) -> Vec { self.exprs .iter() - .flat_map(|e| e.get_column_catalog()) + .flat_map(|e| e.get_referenced_column_catalog()) .collect::>() } } diff --git a/src/optimizer/plan_node/logical_table_scan.rs b/src/optimizer/plan_node/logical_table_scan.rs index 2ecc252..88efe95 100644 --- a/src/optimizer/plan_node/logical_table_scan.rs +++ b/src/optimizer/plan_node/logical_table_scan.rs @@ -7,6 +7,7 @@ use crate::catalog::{ColumnCatalog, TableId}; #[derive(Debug, Clone)] pub struct LogicalTableScan { table_id: TableId, + table_alias: Option, columns: Vec, /// optional bounds of the reader, of the form (offset, limit). bounds: Option<(usize, usize)>, @@ -17,12 +18,14 @@ pub struct LogicalTableScan { impl LogicalTableScan { pub fn new( table_id: TableId, + table_alias: Option, columns: Vec, bounds: Option<(usize, usize)>, projections: Option>, ) -> Self { Self { table_id, + table_alias, columns, bounds, projections, @@ -33,6 +36,10 @@ impl LogicalTableScan { self.table_id.clone() } + pub fn table_alias(&self) -> Option { + self.table_alias.clone() + } + pub fn column_ids(&self) -> Vec { self.columns.iter().map(|c| c.column_id.clone()).collect() } @@ -77,10 +84,15 @@ impl fmt::Display for LogicalTableScan { .bounds() .map(|b| format!(", bounds: (offset:{},limit:{})", b.0, b.1)) .unwrap_or_else(|| "".into()); + let alias = self + .table_alias() + .map(|alias| format!(" as {}", alias)) + .unwrap_or_else(|| "".into()); writeln!( f, - "LogicalTableScan: table: #{}, columns: [{}]{}", + "LogicalTableScan: table: #{}{}, columns: [{}]{}", self.table_id(), + alias, self.column_ids().join(", "), bounds_str, ) diff --git a/src/optimizer/plan_node/physical_table_scan.rs b/src/optimizer/plan_node/physical_table_scan.rs index 81b9ca1..5521a62 100644 --- a/src/optimizer/plan_node/physical_table_scan.rs +++ 
b/src/optimizer/plan_node/physical_table_scan.rs @@ -47,10 +47,16 @@ impl fmt::Display for PhysicalTableScan { .bounds() .map(|b| format!(", bounds: (offset:{},limit:{})", b.0, b.1)) .unwrap_or_else(|| "".into()); + let alias = self + .logical() + .table_alias() + .map(|alias| format!(" as {}", alias)) + .unwrap_or_else(|| "".into()); writeln!( f, - "PhysicalTableScan: table: #{}, columns: [{}]{}", + "PhysicalTableScan: table: #{}{}, columns: [{}]{}", self.logical().table_id(), + alias, self.logical().column_ids().join(", "), bounds_str, ) diff --git a/src/optimizer/plan_node/plan_node_traits.rs b/src/optimizer/plan_node/plan_node_traits.rs index 94d54c8..b5e3ec2 100644 --- a/src/optimizer/plan_node/plan_node_traits.rs +++ b/src/optimizer/plan_node/plan_node_traits.rs @@ -100,6 +100,7 @@ mod tests { fn build_logical_table_scan(table_id: &str) -> LogicalTableScan { LogicalTableScan::new( table_id.to_string(), + None, vec![ build_column_catalog(table_id, "c1"), build_column_catalog(table_id, "c2"), diff --git a/src/optimizer/rules/column_pruning.rs b/src/optimizer/rules/column_pruning.rs index 341a991..f2375b7 100644 --- a/src/optimizer/rules/column_pruning.rs +++ b/src/optimizer/rules/column_pruning.rs @@ -82,7 +82,7 @@ impl Rule for PushProjectIntoTableScan { let columns = project_node .exprs() .iter() - .flat_map(|e| e.get_column_catalog()) + .flat_map(|e| e.get_referenced_column_catalog()) .collect::>(); let original_columns = table_scan_node.columns(); let projections = columns @@ -92,6 +92,7 @@ impl Rule for PushProjectIntoTableScan { let new_table_scan_node = LogicalTableScan::new( table_scan_node.table_id(), + table_scan_node.table_alias(), columns, table_scan_node.bounds(), Some(projections), @@ -253,7 +254,7 @@ mod tests { name: "should not push when project has alias", sql: "select a as c1 from t1", expect: r" -LogicalProject: exprs [t1.a:Nullable(Int32) as c1] +LogicalProject: exprs [(t1.a:Nullable(Int32)) as t1.c1] LogicalTableScan: table: #t1, columns: 
[a, b, c]", }, RuleTest { diff --git a/src/optimizer/rules/pushdown_limit.rs b/src/optimizer/rules/pushdown_limit.rs index 6642143..077f2e6 100644 --- a/src/optimizer/rules/pushdown_limit.rs +++ b/src/optimizer/rules/pushdown_limit.rs @@ -262,6 +262,7 @@ impl Rule for PushLimitIntoTableScan { let new_table_scan_node = LogicalTableScan::new( table_scan_node.table_id(), + table_scan_node.table_alias(), table_scan_node.columns(), Some(bounds), table_scan_node.projections(), diff --git a/src/optimizer/rules/pushdown_predicates.rs b/src/optimizer/rules/pushdown_predicates.rs index 5ceeb93..de81882 100644 --- a/src/optimizer/rules/pushdown_predicates.rs +++ b/src/optimizer/rules/pushdown_predicates.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::sync::Arc; use std::vec; @@ -6,8 +7,10 @@ use sqlparser::ast::BinaryOperator; use super::util::is_subset_cols; use super::RuleImpl; -use crate::binder::{BoundBinaryOp, BoundExpr, JoinType}; +use crate::binder::{BoundBinaryOp, BoundColumnRef, BoundExpr, JoinType}; +use crate::catalog::ColumnCatalog; use crate::optimizer::core::*; +use crate::optimizer::expr_rewriter::ExprRewriter; use crate::optimizer::{Dummy, LogicalFilter, LogicalJoin, PlanNodeType}; lazy_static! 
{ @@ -126,10 +129,10 @@ impl Rule for PushPredicateThroughJoin { let filter_exprs = self.split_conjunctive_predicates(&filter_expr); let (left_filters, rest): (Vec<_>, Vec<_>) = filter_exprs .into_iter() - .partition(|f| is_subset_cols(&f.get_column_catalog(), &left_output_cols)); + .partition(|f| is_subset_cols(&f.get_referenced_column_catalog(), &left_output_cols)); let (right_filters, common_filters): (Vec<_>, Vec<_>) = rest .into_iter() - .partition(|f| is_subset_cols(&f.get_column_catalog(), &right_output_cols)); + .partition(|f| is_subset_cols(&f.get_referenced_column_catalog(), &right_output_cols)); match join_node.join_type() { JoinType::Inner => { @@ -207,14 +210,55 @@ impl Rule for PushPredicateThroughNonJoin { match child_node.node_type() { PlanNodeType::LogicalProject => { - // TODO: handle column alias let project_opt_expr = child_opt_expr; + let project_node = project_opt_expr + .root + .get_plan_ref() + .as_logical_project() + .unwrap(); + // handle column alias. + // such as: select t.a from (select * from t1 where a > 1) t where t.b > 7; + let mut alias_map = HashMap::new(); + let project_exprs = project_node.exprs(); + for expr in project_exprs.iter() { + if let BoundExpr::Alias(e) = expr { + let column_catalog = ColumnCatalog::new( + e.table_id.clone(), + e.column_id.clone(), + expr.nullable(), + expr.return_type().unwrap(), + ); + let a = BoundExpr::ColumnRef(BoundColumnRef { column_catalog }); + alias_map.insert(a, e.expr.clone()); + } + } + + let mut filter_expr = filter_opt_expr + .root + .get_plan_ref() + .as_logical_filter() + .unwrap() + .expr(); + + // rewrite alias column to real expr + struct AliasRewriter(HashMap>); + impl ExprRewriter for AliasRewriter { + fn rewrite_column_ref(&self, e: &mut BoundExpr) { + if self.0.contains_key(e) { + *e = *self.0.get(e).unwrap().clone(); + } + } + } + AliasRewriter(alias_map).rewrite_expr(&mut filter_expr); + + let new_filter_opt_expr = OptExprNode::PlanRef(Arc::new(LogicalFilter::new( + 
filter_expr, + Dummy::new_ref(), + ))); + let res = OptExpr::new( project_opt_expr.root, - vec![OptExpr::new( - filter_opt_expr.root, - project_opt_expr.children, - )], + vec![OptExpr::new(new_filter_opt_expr, project_opt_expr.children)], ); result.opt_exprs.push(res); } diff --git a/src/planner/select.rs b/src/planner/select.rs index 61f5993..5cd2870 100644 --- a/src/planner/select.rs +++ b/src/planner/select.rs @@ -50,9 +50,10 @@ impl Planner { fn plan_table_ref(&self, table_ref: &BoundTableRef) -> Result { match table_ref { - BoundTableRef::Table(table_catalog) => Ok(Arc::new(LogicalTableScan::new( - table_catalog.id.clone(), - table_catalog.get_all_columns(), + BoundTableRef::Table(table) => Ok(Arc::new(LogicalTableScan::new( + table.catalog.id.clone(), + table.alias.clone(), + table.catalog.get_all_columns(), None, None, ))), @@ -72,7 +73,7 @@ impl Planner { } BoundTableRef::Subquery(subquery) => { let subquery = subquery.clone(); - self.plan_select(*subquery) + self.plan_select(*subquery.query) } } } diff --git a/tests/planner/combine-operators.planner.sql b/tests/planner/combine-operators.planner.sql index edf3c0d..bed0da5 100644 --- a/tests/planner/combine-operators.planner.sql +++ b/tests/planner/combine-operators.planner.sql @@ -1,20 +1,22 @@ -- CollapseProject & CombineFilter: combine adjacent projects and filters into one -select * from (select * from (select * from t1 where c < 2) where a > 1) where b > 7; +select t_2.* from (select t_1.* from (select * from t1 where c < 2) t_1 where t_1.a > 1) t_2 where t_2.b > 7; /* original plan: -LogicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64] - LogicalFilter: expr t1.b:Int64 > Cast(7 as Int64) - LogicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64] - LogicalFilter: expr t1.a:Int64 > Cast(1 as Int64) - LogicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64] +LogicalProject: exprs [t_2.a:Int64, t_2.b:Int64, t_2.c:Int64] + LogicalFilter: expr t_2.b:Int64 > Cast(7 as Int64) + LogicalProject: exprs 
[(t_1.a:Int64) as t_2.a, (t_1.b:Int64) as t_2.b, (t_1.c:Int64) as t_2.c] + LogicalFilter: expr t_1.a:Int64 > Cast(1 as Int64) + LogicalProject: exprs [(t1.a:Int64) as t_1.a, (t1.b:Int64) as t_1.b, (t1.c:Int64) as t_1.c] LogicalFilter: expr t1.c:Int64 < Cast(2 as Int64) LogicalTableScan: table: #t1, columns: [a, b, c] optimized plan: -PhysicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64] - PhysicalFilter: expr t1.b:Int64 > 7 AND t1.a:Int64 > 1 AND t1.c:Int64 < 2 - PhysicalTableScan: table: #t1, columns: [a, b, c] +PhysicalProject: exprs [t_2.a:Int64, t_2.b:Int64, t_2.c:Int64] + PhysicalProject: exprs [(t_1.a:Int64) as t_2.a, (t_1.b:Int64) as t_2.b, (t_1.c:Int64) as t_2.c] + PhysicalProject: exprs [(t1.a:Int64) as t_1.a, (t1.b:Int64) as t_1.b, (t1.c:Int64) as t_1.c] + PhysicalFilter: expr t1.b:Int64 > 7 AND t1.a:Int64 > 1 AND t1.c:Int64 < 2 + PhysicalTableScan: table: #t1, columns: [a, b, c] */ diff --git a/tests/planner/combine-operators.yml b/tests/planner/combine-operators.yml index 7424bfb..64cdd82 100644 --- a/tests/planner/combine-operators.yml +++ b/tests/planner/combine-operators.yml @@ -1,4 +1,4 @@ - sql: | - select * from (select * from (select * from t1 where c < 2) where a > 1) where b > 7; + select t_2.* from (select t_1.* from (select * from t1 where c < 2) t_1 where t_1.a > 1) t_2 where t_2.b > 7; desc: | CollapseProject & CombineFilter: combine adjacent projects and filters into one diff --git a/tests/planner/predicate-pushdown.planner.sql b/tests/planner/predicate-pushdown.planner.sql index 99334fb..246af31 100644 --- a/tests/planner/predicate-pushdown.planner.sql +++ b/tests/planner/predicate-pushdown.planner.sql @@ -128,3 +128,22 @@ PhysicalProject: exprs [t1.a:Int64, t1.b:Int64, t1.c:Int64] PhysicalTableScan: table: #t2, columns: [a, b] */ +-- PushPredicateThroughNonJoin: pushdown filter with column alias + +select t.a from (select * from t1 where a > 1) t where t.b > 7 + +/* +original plan: +LogicalProject: exprs [t.a:Int64] + 
LogicalFilter: expr t.b:Int64 > Cast(7 as Int64) + LogicalProject: exprs [(t1.a:Int64) as t.a, (t1.b:Int64) as t.b, (t1.c:Int64) as t.c] + LogicalFilter: expr t1.a:Int64 > Cast(1 as Int64) + LogicalTableScan: table: #t1, columns: [a, b, c] + +optimized plan: +PhysicalProject: exprs [t.a:Int64] + PhysicalProject: exprs [(t1.a:Int64) as t.a, (t1.b:Int64) as t.b, (t1.c:Int64) as t.c] + PhysicalFilter: expr t1.b:Int64 > 7 AND t1.a:Int64 > 1 + PhysicalTableScan: table: #t1, columns: [a, b, c] +*/ + diff --git a/tests/planner/predicate-pushdown.yml b/tests/planner/predicate-pushdown.yml index a18e8a9..f19e8fb 100644 --- a/tests/planner/predicate-pushdown.yml +++ b/tests/planner/predicate-pushdown.yml @@ -28,6 +28,11 @@ desc: | PushPredicateThroughJoin: don't pushdown filters for right outer join +- sql: | + select t.a from (select * from t1 where a > 1) t where t.b > 7 + desc: | + PushPredicateThroughNonJoin: pushdown filter with column alias + diff --git a/tests/slt/alias.slt b/tests/slt/alias.slt index e087982..a8901fb 100644 --- a/tests/slt/alias.slt +++ b/tests/slt/alias.slt @@ -51,3 +51,10 @@ query III select t.* from (select * from t1 where a > 1) t where t.b > 7; ---- 2 8 1 + + +query I +select t.v1 + 1 from (select a + 1 as v1 from t1 where a > 1) t; +---- +4 +4 diff --git a/tests/slt/subquery.slt b/tests/slt/subquery.slt index fe7af17..fce8229 100644 --- a/tests/slt/subquery.slt +++ b/tests/slt/subquery.slt @@ -1,14 +1,21 @@ -query III +# subquery in FROM must have an alias. 
same behavior as Postgres +statement error select * from (select * from t1 where a > 1) where b > 7; + +# TODO: handle multi-layer binder context to resolve all columns in the current context +# select * from (select * from t1 where a > 1) t where t.b > 7; + +query III +select t.* from (select * from t1 where a > 1) t where t.b > 7; ---- 2 8 1 query II -select b from (select a, b from t1 where a > 1) where b > 7; +select t.b from (select a, b from t1 where a > 1) t where t.b > 7; ---- 8 query III -select * from (select * from (select * from t1 where c < 2) where a > 1) where b > 7; +select t_2.* from (select t_1.* from (select * from t1 where c < 2) t_1 where t_1.a > 1) t_2 where t_2.b > 7; ---- 2 8 1