diff --git a/src/binder/expression/mod.rs b/src/binder/expression/mod.rs index ce790df..d68102c 100644 --- a/src/binder/expression/mod.rs +++ b/src/binder/expression/mod.rs @@ -16,11 +16,11 @@ use crate::types::ScalarValue; pub enum BoundExpr { Constant(ScalarValue), ColumnRef(BoundColumnRef), - /// InputRef represents an index of the RecordBatch, which is resolved in optimizer. InputRef(BoundInputRef), BinaryOp(BoundBinaryOp), TypeCast(BoundTypeCast), AggFunc(BoundAggFunc), + Alias(BoundAlias), } impl BoundExpr { @@ -34,6 +34,7 @@ impl BoundExpr { BoundExpr::BinaryOp(binary_op) => binary_op.return_type.clone(), BoundExpr::TypeCast(tc) => Some(tc.cast_type.clone()), BoundExpr::AggFunc(agg) => Some(agg.return_type.clone()), + BoundExpr::Alias(alias) => alias.expr.return_type(), } } @@ -47,6 +48,7 @@ impl BoundExpr { } BoundExpr::TypeCast(tc) => tc.expr.contains_column_ref(), BoundExpr::AggFunc(agg) => agg.exprs.iter().any(|arg| arg.contains_column_ref()), + BoundExpr::Alias(alias) => alias.expr.contains_column_ref(), } } @@ -67,6 +69,7 @@ impl BoundExpr { .iter() .flat_map(|arg| arg.get_column_catalog()) .collect::>(), + BoundExpr::Alias(alias) => alias.expr.get_column_catalog(), } } } @@ -90,6 +93,12 @@ pub struct BoundTypeCast { pub cast_type: DataType, } +#[derive(Clone, PartialEq, Eq)] +pub struct BoundAlias { + pub expr: Box, + pub alias: String, +} + impl Binder { /// bind sqlparser Expr into BoundExpr pub fn bind_expr(&mut self, expr: &Expr) -> Result { @@ -150,6 +159,12 @@ impl Binder { got_column = Some(column_catalog); } } + // handle col alias + if got_column.is_none() { + if let Some(expr) = self.context.aliases.get(column_name) { + return Ok(expr.clone()); + } + } let column_catalog = got_column.ok_or_else(|| BindError::InvalidColumn(column_name.clone()))?; Ok(BoundExpr::ColumnRef(BoundColumnRef { column_catalog })) @@ -166,6 +181,7 @@ impl fmt::Debug for BoundExpr { BoundExpr::BinaryOp(binary_op) => write!(f, "{:?}", binary_op), BoundExpr::TypeCast(type_cast) => write!(f, "{:?}", type_cast), BoundExpr::AggFunc(agg_func) => write!(f, "{:?}", agg_func), + BoundExpr::Alias(alias) => write!(f, "{:?} as {}", alias.expr, alias.alias), } } } diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 85c5745..3cab609 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -21,6 +21,7 @@ struct BinderContext { /// table_name == table_id /// table_id -> table_catalog tables: HashMap, + aliases: HashMap, } impl Binder { diff --git a/src/binder/statement/mod.rs b/src/binder/statement/mod.rs index 59999c3..f655d4f 100644 --- a/src/binder/statement/mod.rs +++ b/src/binder/statement/mod.rs @@ -3,7 +3,7 @@ use sqlparser::ast::{Query, SelectItem}; use super::expression::BoundExpr; use super::table::BoundTableRef; -use super::{BindError, Binder, BoundColumnRef}; +use super::{BindError, Binder, BoundAlias, BoundColumnRef}; #[derive(Debug)] pub enum BoundStatement { @@ -50,7 +50,14 @@ impl Binder { let expr = self.bind_expr(expr)?; select_list.push(expr); } - SelectItem::ExprWithAlias { expr: _, alias: _ } => todo!(), + SelectItem::ExprWithAlias { expr, alias } => { + let expr = self.bind_expr(expr)?; + self.context.aliases.insert(alias.to_string(), expr.clone()); + select_list.push(BoundExpr::Alias(BoundAlias { + expr: Box::new(expr), + alias: alias.to_string().to_lowercase(), + })); + } SelectItem::QualifiedWildcard(object_name) => { let qualifier = format!("{}", object_name); select_list.extend_from_slice( diff --git a/src/executor/evaluator.rs b/src/executor/evaluator.rs index f32300f..2e34591 100644 --- a/src/executor/evaluator.rs +++ b/src/executor/evaluator.rs @@ -22,6 +22,7 @@ impl BoundExpr { BoundExpr::ColumnRef(_) => panic!("column ref should be resolved"), BoundExpr::TypeCast(tc) => Ok(cast(&tc.expr.eval_column(batch)?, &tc.cast_type)?), BoundExpr::AggFunc(_) => todo!(), + BoundExpr::Alias(alias) => alias.expr.eval_column(batch), } } @@ -52,6 +53,11 @@ impl BoundExpr { let new_name = format!("{}({})", agg.func, inner_name); Field::new(new_name.as_str(), agg.return_type.clone(), true) } + BoundExpr::Alias(alias) => { + let new_name = alias.alias.to_string(); + let data_type = alias.expr.return_type().unwrap(); + Field::new(new_name.as_str(), data_type, true) + } } } } diff --git a/src/optimizer/expr_rewriter.rs b/src/optimizer/expr_rewriter.rs index e11aecb..1892eb1 100644 --- a/src/optimizer/expr_rewriter.rs +++ b/src/optimizer/expr_rewriter.rs @@ -9,6 +9,7 @@ pub trait ExprRewriter { BoundExpr::BinaryOp(_) => self.rewrite_binary_op(expr), BoundExpr::TypeCast(_) => self.rewrite_type_cast(expr), BoundExpr::AggFunc(_) => self.rewrite_agg_func(expr), + BoundExpr::Alias(_) => self.rewrite_alias(expr), } } @@ -40,4 +41,13 @@ pub trait ExprRewriter { _ => unreachable!(), } } + + fn rewrite_alias(&self, expr: &mut BoundExpr) { + match expr { + BoundExpr::Alias(e) => { + self.rewrite_expr(&mut e.expr); + } + _ => unreachable!(), + } + } } diff --git a/src/optimizer/expr_visitor.rs b/src/optimizer/expr_visitor.rs index f0a8dbf..531d68e 100644 --- a/src/optimizer/expr_visitor.rs +++ b/src/optimizer/expr_visitor.rs @@ -1,5 +1,6 @@ use crate::binder::{ - BoundAggFunc, BoundBinaryOp, BoundColumnRef, BoundExpr, BoundInputRef, BoundTypeCast, + BoundAggFunc, BoundAlias, BoundBinaryOp, BoundColumnRef, BoundExpr, BoundInputRef, + BoundTypeCast, }; use crate::types::ScalarValue; @@ -15,6 +16,7 @@ pub trait ExprVisitor { BoundExpr::BinaryOp(expr) => self.visit_binary_op(expr), BoundExpr::TypeCast(expr) => self.visit_type_cast(expr), BoundExpr::AggFunc(expr) => self.visit_agg_func(expr), + BoundExpr::Alias(expr) => self.visit_alias(expr), } } @@ -38,4 +40,8 @@ pub trait ExprVisitor { self.visit_expr(arg); } } + + fn visit_alias(&mut self, expr: &BoundAlias) { + self.visit_expr(&expr.expr); + } } diff --git a/src/optimizer/input_ref_rewriter.rs b/src/optimizer/input_ref_rewriter.rs index dbb9284..0cd5940 100644 --- a/src/optimizer/input_ref_rewriter.rs +++ b/src/optimizer/input_ref_rewriter.rs @@ -38,6 +38,7 @@ impl InputRefRewriter { self.rewrite_expr(arg); } } + BoundExpr::Alias(e) => self.rewrite_expr(e.expr.as_mut()), _ => unreachable!( "unexpected expr type {:?} for InputRefRewriter, binding: {:?}", expr, self.bindings @@ -62,6 +63,10 @@ impl ExprRewriter for InputRefRewriter { fn rewrite_agg_func(&self, expr: &mut BoundExpr) { self.rewrite_internal(expr); } + + fn rewrite_alias(&self, expr: &mut BoundExpr) { + self.rewrite_internal(expr); + } } impl PlanRewriter for InputRefRewriter { diff --git a/src/optimizer/rules/column_pruning.rs b/src/optimizer/rules/column_pruning.rs index f38ae51..341a991 100644 --- a/src/optimizer/rules/column_pruning.rs +++ b/src/optimizer/rules/column_pruning.rs @@ -64,6 +64,15 @@ impl Rule for PushProjectIntoTableScan { .get_plan_ref() .as_logical_project() .unwrap(); + + // only push down the project when the exprs are all column refs + for expr in project_node.exprs().iter() { + match expr { + BoundExpr::ColumnRef(_) => {} + _ => return, + } + } + let table_scan_node = table_scan_opt_expr .root .get_plan_ref() @@ -234,11 +243,27 @@ mod tests { #[test] fn test_push_project_into_table_scan_rule() { - let tests = vec![RuleTest { - name: "push_project_into_table_scan_rule", - sql: "select a from t1", - expect: "LogicalTableScan: table: #t1, columns: [a]", - }]; + let tests = vec![ + RuleTest { + name: "push_project_into_table_scan_rule", + sql: "select a from t1", + expect: "LogicalTableScan: table: #t1, columns: [a]", + }, + RuleTest { + name: "should not push when project has alias", + sql: "select a as c1 from t1", + expect: r" +LogicalProject: exprs [t1.a:Nullable(Int32) as c1] + LogicalTableScan: table: #t1, columns: [a, b, c]", + }, + RuleTest { + name: "should not push when project expr is not column", + sql: "select a + 1 from t1", + expect: r" +LogicalProject: exprs [t1.a:Nullable(Int32) + 1] + LogicalTableScan: table: #t1, columns: [a, b, c]", + }, + ]; for t in tests { let logical_plan = build_plan(t.sql); diff --git a/tests/slt/alias.slt b/tests/slt/alias.slt new file mode 100644 index 0000000..4d0f46a --- /dev/null +++ b/tests/slt/alias.slt @@ -0,0 +1,16 @@ +query I +select a as c1 from t1 order by c1 desc limit 1; +---- +2 + +query I +select a as c1 from t1 where c1 = 1; +---- +1 + +query II +select sum(b) as c1, a as c2 from t1 group by c2 order by c1 desc; +---- +15 2 +5 1 +4 0