diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 39d3a36139d..47608e36537 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -47,6 +47,7 @@ use crate::execution::table_impl::TableImpl; use crate::logicalplan::*; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::projection_push_down::ProjectionPushDown; +use crate::optimizer::resolve_columns::ResolveColumnsRule; use crate::optimizer::type_coercion::TypeCoercionRule; use crate::sql::parser::{DFASTNode, DFParser, FileType}; use crate::sql::planner::{SchemaProvider, SqlToRel}; @@ -231,12 +232,13 @@ impl ExecutionContext { } /// Optimize the logical plan by applying optimizer rules - pub fn optimize(&self, plan: &LogicalPlan) -> Result> { + pub fn optimize(&self, plan: &LogicalPlan) -> Result { let rules: Vec> = vec![ + Box::new(ResolveColumnsRule::new()), Box::new(ProjectionPushDown::new()), Box::new(TypeCoercionRule::new()), ]; - let mut plan = Arc::new(plan.clone()); + let mut plan = plan.clone(); for mut rule in rules { plan = rule.optimize(&plan)?; } @@ -246,10 +248,10 @@ impl ExecutionContext { /// Create a physical plan from a logical plan pub fn create_physical_plan( &mut self, - logical_plan: &Arc, + logical_plan: &LogicalPlan, batch_size: usize, ) -> Result> { - match logical_plan.as_ref() { + match logical_plan { LogicalPlan::TableScan { table_name, projection, @@ -435,9 +437,10 @@ impl ExecutionContext { ))), } } - _ => Err(ExecutionError::NotImplemented( - "Unsupported aggregate expression".to_string(), - )), + other => Err(ExecutionError::General(format!( + "Invalid aggregate expression '{:?}'", + other + ))), } } @@ -731,22 +734,19 @@ mod tests { let mut ctx = create_ctx(&tmp_dir, 1)?; let schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::UInt32, false), - Field::new("c2", DataType::UInt64, false), + Field::new("state", DataType::Utf8, false), + Field::new("salary", DataType::UInt32, false), ])); - let plan = LogicalPlanBuilder::scan( - "default", - "test", - schema.as_ref(), - Some(vec![0, 1]), - )? - .aggregate( - vec![col(0)], - vec![aggregate_expr("SUM", col(1), DataType::Int32)], - )? - .project(vec![col(0), col(1).alias("total_salary")])? - .build()?; + let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)? + .aggregate( + vec![col("state")], + vec![aggregate_expr("SUM", col("salary"), DataType::UInt32)], + )? + .project(vec![col("state"), col_index(1).alias("total_salary")])? + .build()?; + + let plan = ctx.optimize(&plan)?; let physical_plan = ctx.create_physical_plan(&Arc::new(plan), 1024)?; assert_eq!("c1", physical_plan.schema().field(0).name().as_str()); diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index be1ca2838ab..0b9464c1588 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -183,6 +183,8 @@ pub enum Expr { Alias(Arc, String), /// index into a value within the row or complex value Column(usize), + /// Reference to column by name + UnresolvedColumn(String), /// literal value Literal(ScalarValue), /// binary expression e.g. "age > 21" @@ -242,6 +244,9 @@ impl Expr { match self { Expr::Alias(expr, _) => expr.get_type(schema), Expr::Column(n) => Ok(schema.field(*n).data_type().clone()), + Expr::UnresolvedColumn(name) => { + Ok(schema.field_with_name(&name)?.data_type().clone()) + } Expr::Literal(l) => Ok(l.get_datatype()), Expr::Cast { data_type, .. } => Ok(data_type.clone()), Expr::ScalarFunction { return_type, .. } => Ok(return_type.clone()), @@ -356,11 +361,16 @@ impl Expr { } } -/// Create a column expression -pub fn col(index: usize) -> Expr { +/// Create a column expression based on a column index +pub fn col_index(index: usize) -> Expr { Expr::Column(index) } +/// Create a column expression based on a column name +pub fn col(name: &str) -> Expr { + Expr::UnresolvedColumn(name.to_owned()) +} + /// Create a literal string expression pub fn lit_str(str: &str) -> Expr { Expr::Literal(ScalarValue::Utf8(str.to_owned())) @@ -380,6 +390,7 @@ impl fmt::Debug for Expr { match self { Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias), Expr::Column(i) => write!(f, "#{}", i), + Expr::UnresolvedColumn(name) => write!(f, "#{}", name), Expr::Literal(v) => write!(f, "{:?}", v), Expr::Cast { expr, data_type } => { write!(f, "CAST({:?} AS {:?})", expr, data_type) @@ -709,7 +720,7 @@ impl LogicalPlanBuilder { (0..expr.len()).for_each(|i| match &expr[i] { Expr::Wildcard => { (0..input_schema.fields().len()) - .for_each(|i| expr_vec.push(col(i).clone())); + .for_each(|i| expr_vec.push(col_index(i).clone())); } _ => expr_vec.push(expr[i].clone()), }); @@ -791,8 +802,8 @@ mod tests { &employee_schema(), Some(vec![0, 3]), )? - .filter(col(1).eq(&lit_str("CO")))? - .project(vec![col(0)])? + .filter(col("id").eq(&lit_str("CO")))? + .project(vec![col("id")])? .build()?; // prove that a plan can be passed to a thread @@ -812,13 +823,13 @@ mod tests { &employee_schema(), Some(vec![0, 3]), )? - .filter(col(1).eq(&lit_str("CO")))? - .project(vec![col(0)])? + .filter(col("state").eq(&lit_str("CO")))? + .project(vec![col("id")])? .build()?; - let expected = "Projection: #0\n \ - Selection: #1 Eq Utf8(\"CO\")\n \ - TableScan: employee.csv projection=Some([0, 3])"; + let expected = "Projection: #id\ + \n Selection: #state Eq Utf8(\"CO\")\ + \n TableScan: employee.csv projection=Some([0, 3])"; assert_eq!(expected, format!("{:?}", plan)); @@ -834,15 +845,16 @@ mod tests { Some(vec![3, 4]), )? .aggregate( - vec![col(0)], - vec![aggregate_expr("SUM", col(1), DataType::Int32)], + vec![col("state")], + vec![aggregate_expr("SUM", col("salary"), DataType::Int32) + .alias("total_salary")], )? - .project(vec![col(0), col(1).alias("total_salary")])? + .project(vec![col("state"), col("total_salary")])? .build()?; - let expected = "Projection: #0, #1 AS total_salary\ - \n Aggregate: groupBy=[[#0]], aggr=[[SUM(#1)]]\ - \n TableScan: employee.csv projection=Some([3, 4])"; + let expected = "Projection: #state, #total_salary\ + \n Aggregate: groupBy=[[#state]], aggr=[[SUM(#salary) AS total_salary]]\ + \n TableScan: employee.csv projection=Some([3, 4])"; assert_eq!(expected, format!("{:?}", plan)); diff --git a/rust/datafusion/src/optimizer/mod.rs b/rust/datafusion/src/optimizer/mod.rs index e60c7db824b..1ac97b1c30f 100644 --- a/rust/datafusion/src/optimizer/mod.rs +++ b/rust/datafusion/src/optimizer/mod.rs @@ -20,5 +20,6 @@ pub mod optimizer; pub mod projection_push_down; +pub mod resolve_columns; pub mod type_coercion; pub mod utils; diff --git a/rust/datafusion/src/optimizer/optimizer.rs b/rust/datafusion/src/optimizer/optimizer.rs index 9626bfdcf51..e041ce3f5e5 100644 --- a/rust/datafusion/src/optimizer/optimizer.rs +++ b/rust/datafusion/src/optimizer/optimizer.rs @@ -19,11 +19,10 @@ use crate::error::Result; use crate::logicalplan::LogicalPlan; -use std::sync::Arc; /// An optimizer rules performs a transformation on a logical plan to produce an optimized /// logical plan. pub trait OptimizerRule { /// Perform optimizations on the plan - fn optimize(&mut self, plan: &LogicalPlan) -> Result>; + fn optimize(&mut self, plan: &LogicalPlan) -> Result; } diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs index a05f8656d02..6017b907e40 100644 --- a/rust/datafusion/src/optimizer/projection_push_down.rs +++ b/rust/datafusion/src/optimizer/projection_push_down.rs @@ -19,8 +19,8 @@ //! loaded into memory use crate::error::{ExecutionError, Result}; -use crate::logicalplan::Expr; use crate::logicalplan::LogicalPlan; +use crate::logicalplan::{Expr, LogicalPlanBuilder}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use arrow::datatypes::{Field, Schema}; @@ -32,7 +32,7 @@ use std::sync::Arc; pub struct ProjectionPushDown {} impl OptimizerRule for ProjectionPushDown { - fn optimize(&mut self, plan: &LogicalPlan) -> Result> { + fn optimize(&mut self, plan: &LogicalPlan) -> Result { let mut accum: HashSet = HashSet::new(); let mut mapping: HashMap = HashMap::new(); self.optimize_plan(plan, &mut accum, &mut mapping) @@ -50,92 +50,52 @@ impl ProjectionPushDown { plan: &LogicalPlan, accum: &mut HashSet, mapping: &mut HashMap, - ) -> Result> { + ) -> Result { match plan { - LogicalPlan::Projection { - expr, - input, - schema, - } => { + LogicalPlan::Projection { expr, input, .. } => { // collect all columns referenced by projection expressions utils::exprlist_to_column_indices(&expr, accum)?; - // push projection down - let input = self.optimize_plan(&input, accum, mapping)?; - - // rewrite projection expressions to use new column indexes - let new_expr = self.rewrite_exprs(expr, mapping)?; - - Ok(Arc::new(LogicalPlan::Projection { - expr: new_expr, - input, - schema: schema.clone(), - })) + LogicalPlanBuilder::from(&self.optimize_plan(&input, accum, mapping)?) + .project(self.rewrite_expr_list(expr, mapping)?)? + .build() } LogicalPlan::Selection { expr, input } => { // collect all columns referenced by filter expression utils::expr_to_column_indices(expr, accum)?; - // push projection down - let input = self.optimize_plan(&input, accum, mapping)?; - - // rewrite filter expression to use new column indexes - let new_expr = self.rewrite_expr(expr, mapping)?; - - Ok(Arc::new(LogicalPlan::Selection { - expr: new_expr, - input, - })) + LogicalPlanBuilder::from(&self.optimize_plan(&input, accum, mapping)?) + .filter(self.rewrite_expr(expr, mapping)?)? + .build() } LogicalPlan::Aggregate { input, group_expr, aggr_expr, - schema, + .. } => { // collect all columns referenced by grouping and aggregate expressions utils::exprlist_to_column_indices(&group_expr, accum)?; utils::exprlist_to_column_indices(&aggr_expr, accum)?; - // push projection down - let input = self.optimize_plan(&input, accum, mapping)?; - - // rewrite expressions to use new column indexes - let new_group_expr = self.rewrite_exprs(group_expr, mapping)?; - let new_aggr_expr = self.rewrite_exprs(aggr_expr, mapping)?; - - Ok(Arc::new(LogicalPlan::Aggregate { - input, - group_expr: new_group_expr, - aggr_expr: new_aggr_expr, - schema: schema.clone(), - })) + LogicalPlanBuilder::from(&self.optimize_plan(&input, accum, mapping)?) + .aggregate( + self.rewrite_expr_list(group_expr, mapping)?, + self.rewrite_expr_list(aggr_expr, mapping)?, + )? + .build() } - LogicalPlan::Sort { - expr, - input, - schema, - } => { + LogicalPlan::Sort { expr, input, .. } => { // collect all columns referenced by sort expressions utils::exprlist_to_column_indices(&expr, accum)?; - // push projection down - let input = self.optimize_plan(&input, accum, mapping)?; - - // rewrite sort expressions to use new column indexes - let new_expr = self.rewrite_exprs(expr, mapping)?; - - Ok(Arc::new(LogicalPlan::Sort { - expr: new_expr, - input, - schema: schema.clone(), - })) - } - LogicalPlan::EmptyRelation { schema } => { - Ok(Arc::new(LogicalPlan::EmptyRelation { - schema: schema.clone(), - })) + LogicalPlanBuilder::from(&self.optimize_plan(&input, accum, mapping)?) + .sort(self.rewrite_expr_list(expr, mapping)?)? + .build() } + LogicalPlan::EmptyRelation { schema } => Ok(LogicalPlan::EmptyRelation { + schema: schema.clone(), + }), LogicalPlan::TableScan { schema_name, table_name, @@ -183,40 +143,40 @@ impl ProjectionPushDown { } // return the table scan with projection - Ok(Arc::new(LogicalPlan::TableScan { + Ok(LogicalPlan::TableScan { schema_name: schema_name.to_string(), table_name: table_name.to_string(), table_schema: table_schema.clone(), projected_schema: Arc::new(projected_schema), projection: Some(projection), - })) + }) } LogicalPlan::Limit { expr, input, schema, - } => Ok(Arc::new(LogicalPlan::Limit { + } => Ok(LogicalPlan::Limit { expr: expr.clone(), input: input.clone(), schema: schema.clone(), - })), + }), LogicalPlan::CreateExternalTable { schema, name, location, file_type, header_row, - } => Ok(Arc::new(LogicalPlan::CreateExternalTable { + } => Ok(LogicalPlan::CreateExternalTable { schema: schema.clone(), name: name.to_string(), location: location.to_string(), file_type: file_type.clone(), header_row: *header_row, - })), + }), } } - fn rewrite_exprs( + fn rewrite_expr_list( &self, expr: &Vec, mapping: &HashMap, @@ -234,6 +194,9 @@ impl ProjectionPushDown { name.clone(), )), Expr::Column(i) => Ok(Expr::Column(self.new_index(mapping, i)?)), + Expr::UnresolvedColumn(_) => Err(ExecutionError::ExecutionError( + "Columns need to be resolved before this rule can run".to_owned(), + )), Expr::Literal(_) => Ok(expr.clone()), Expr::Not(e) => Ok(Expr::Not(Arc::new(self.rewrite_expr(e, mapping)?))), Expr::IsNull(e) => Ok(Expr::IsNull(Arc::new(self.rewrite_expr(e, mapping)?))), @@ -259,7 +222,7 @@ impl ProjectionPushDown { return_type, } => Ok(Expr::AggregateFunction { name: name.to_string(), - args: self.rewrite_exprs(args, mapping)?, + args: self.rewrite_expr_list(args, mapping)?, return_type: return_type.clone(), }), Expr::ScalarFunction { @@ -268,7 +231,7 @@ impl ProjectionPushDown { return_type, } => Ok(Expr::ScalarFunction { name: name.to_string(), - args: self.rewrite_exprs(args, mapping)?, + args: self.rewrite_expr_list(args, mapping)?, return_type: return_type.clone(), }), Expr::Wildcard => Err(ExecutionError::General( @@ -292,147 +255,107 @@ mod tests { use super::*; use crate::logicalplan::Expr::*; - use crate::logicalplan::LogicalPlan::*; - use arrow::datatypes::{DataType, Field, Schema}; - use std::borrow::Borrow; + use crate::test::*; + use arrow::datatypes::DataType; use std::sync::Arc; #[test] - fn aggregate_no_group_by() { - let table_scan = test_table_scan(); - - let aggregate = Aggregate { - group_expr: vec![], - aggr_expr: vec![Column(1)], - schema: Arc::new(Schema::new(vec![Field::new( - "MAX(b)", - DataType::UInt32, - false, - )])), - input: Arc::new(table_scan), - }; - - assert_optimized_plan_eq(&aggregate, "Aggregate: groupBy=[[]], aggr=[[#0]]\n TableScan: test projection=Some([1])"); + fn aggregate_no_group_by() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(&table_scan) + .aggregate(vec![], vec![max(Column(1))])? + .build()?; + + let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#0)]]\ + \n TableScan: test projection=Some([1])"; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) } #[test] - fn aggregate_group_by() { - let table_scan = test_table_scan(); - - let aggregate = Aggregate { - group_expr: vec![Column(2)], - aggr_expr: vec![Column(1)], - schema: Arc::new(Schema::new(vec![ - Field::new("c", DataType::UInt32, false), - Field::new("MAX(b)", DataType::UInt32, false), - ])), - input: Arc::new(table_scan), - }; - - assert_optimized_plan_eq(&aggregate, "Aggregate: groupBy=[[#1]], aggr=[[#0]]\n TableScan: test projection=Some([1, 2])"); + fn aggregate_group_by() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(&table_scan) + .aggregate(vec![Column(2)], vec![max(Column(1))])? + .build()?; + + let expected = "Aggregate: groupBy=[[#1]], aggr=[[MAX(#0)]]\ + \n TableScan: test projection=Some([1, 2])"; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) } #[test] - fn aggregate_no_group_by_with_selection() { - let table_scan = test_table_scan(); - - let selection = Selection { - expr: Column(2), - input: Arc::new(table_scan), - }; - - let aggregate = Aggregate { - group_expr: vec![], - aggr_expr: vec![Column(1)], - schema: Arc::new(Schema::new(vec![Field::new( - "MAX(b)", - DataType::UInt32, - false, - )])), - input: Arc::new(selection), - }; - - assert_optimized_plan_eq(&aggregate, "Aggregate: groupBy=[[]], aggr=[[#0]]\n Selection: #1\n TableScan: test projection=Some([1, 2])"); + fn aggregate_no_group_by_with_selection() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(&table_scan) + .filter(Column(2))? + .aggregate(vec![], vec![max(Column(1))])? + .build()?; + + let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#0)]]\ + \n Selection: #1\ + \n TableScan: test projection=Some([1, 2])"; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) } #[test] - fn cast() { - let table_scan = test_table_scan(); + fn cast() -> Result<()> { + let table_scan = test_table_scan()?; - let projection = Projection { - expr: vec![Cast { + let projection = LogicalPlanBuilder::from(&table_scan) + .project(vec![Cast { expr: Arc::new(Column(2)), data_type: DataType::Float64, - }], - input: Arc::new(table_scan), - schema: Arc::new(Schema::new(vec![Field::new( - "CAST(c AS float)", - DataType::Float64, - false, - )])), - }; - - assert_optimized_plan_eq( - &projection, - "Projection: CAST(#0 AS Float64)\n TableScan: test projection=Some([2])", - ); + }])? + .build()?; + + let expected = "Projection: CAST(#0 AS Float64)\ + \n TableScan: test projection=Some([2])"; + + assert_optimized_plan_eq(&projection, expected); + + Ok(()) } #[test] - fn table_scan_projected_schema() { - let table_scan = test_table_scan(); + fn table_scan_projected_schema() -> Result<()> { + let table_scan = test_table_scan()?; assert_eq!(3, table_scan.schema().fields().len()); + assert_fields_eq(&table_scan, vec!["a", "b", "c"]); - let projection = Projection { - expr: vec![Column(0), Column(1)], - input: Arc::new(table_scan), - schema: Arc::new(Schema::new(vec![ - Field::new("a", DataType::UInt32, false), - Field::new("b", DataType::UInt32, false), - ])), - }; - - let optimized_plan = optimize(&projection); - - // check that table scan schema now contains 2 columns - match optimized_plan.as_ref().borrow() { - LogicalPlan::Projection { input, .. } => match input.as_ref().borrow() { - LogicalPlan::TableScan { - ref projected_schema, - .. - } => { - assert_eq!(2, projected_schema.fields().len()); - } - _ => assert!(false), - }, - _ => assert!(false), - } + let plan = LogicalPlanBuilder::from(&table_scan) + .project(vec![Column(0), Column(1)])? + .build()?; + + assert_fields_eq(&plan, vec!["a", "b"]); + + let expected = "Projection: #0, #1\ + \n TableScan: test projection=Some([0, 1])"; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) } fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let optimized_plan = optimize(plan); + let optimized_plan = optimize(plan).expect("failed to optimize plan"); let formatted_plan = format!("{:?}", optimized_plan); assert_eq!(formatted_plan, expected); } - fn optimize(plan: &LogicalPlan) -> Arc { + fn optimize(plan: &LogicalPlan) -> Result { let mut rule = ProjectionPushDown::new(); - rule.optimize(plan).unwrap() - } - - /// all tests share a common table - fn test_table_scan() -> LogicalPlan { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::UInt32, false), - Field::new("b", DataType::UInt32, false), - Field::new("c", DataType::UInt32, false), - ])); - TableScan { - schema_name: "default".to_string(), - table_name: "test".to_string(), - table_schema: schema.clone(), - projected_schema: schema, - projection: None, - } + rule.optimize(plan) } } diff --git a/rust/datafusion/src/optimizer/resolve_columns.rs b/rust/datafusion/src/optimizer/resolve_columns.rs new file mode 100644 index 00000000000..7469e4dbcec --- /dev/null +++ b/rust/datafusion/src/optimizer/resolve_columns.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule to replace UnresolvedColumns with Columns + +use crate::error::Result; +use crate::logicalplan::LogicalPlan; +use crate::logicalplan::{Expr, LogicalPlanBuilder}; +use crate::optimizer::optimizer::OptimizerRule; +use arrow::datatypes::Schema; +use std::sync::Arc; + +/// Replace UnresolvedColumns with Columns +pub struct ResolveColumnsRule {} + +impl ResolveColumnsRule { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for ResolveColumnsRule { + fn optimize(&mut self, plan: &LogicalPlan) -> Result { + match plan { + LogicalPlan::Projection { input, expr, .. } => { + Ok(LogicalPlanBuilder::from(&self.optimize(input.as_ref())?) + .project(rewrite_expr_list(expr, &input.schema())?)? + .build()?) + } + LogicalPlan::Selection { expr, input } => Ok(LogicalPlanBuilder::from(input) + .filter(rewrite_expr(expr, &input.schema())?)? + .build()?), + LogicalPlan::Aggregate { + input, + group_expr, + aggr_expr, + .. + } => Ok(LogicalPlanBuilder::from(input) + .aggregate( + rewrite_expr_list(group_expr, &input.schema())?, + rewrite_expr_list(aggr_expr, &input.schema())?, + )? + .build()?), + LogicalPlan::Sort { input, expr, .. } => Ok(LogicalPlanBuilder::from(input) + .sort(rewrite_expr_list(expr, &input.schema())?)? + .build()?), + _ => Ok(plan.clone()), + } + } +} +fn rewrite_expr_list(expr: &Vec, schema: &Schema) -> Result> { + Ok(expr + .iter() + .map(|e| rewrite_expr(e, schema)) + .collect::>>()?) +} + +fn rewrite_expr(expr: &Expr, schema: &Schema) -> Result { + match expr { + Expr::Alias(expr, alias) => Ok(rewrite_expr(&expr, schema)?.alias(&alias)), + Expr::UnresolvedColumn(name) => Ok(Expr::Column(schema.index_of(&name)?)), + Expr::BinaryExpr { left, op, right } => Ok(Expr::BinaryExpr { + left: Arc::new(rewrite_expr(&left, schema)?), + op: op.clone(), + right: Arc::new(rewrite_expr(&right, schema)?), + }), + Expr::Not(expr) => Ok(Expr::Not(Arc::new(rewrite_expr(&expr, schema)?))), + Expr::IsNotNull(expr) => { + Ok(Expr::IsNotNull(Arc::new(rewrite_expr(&expr, schema)?))) + } + Expr::IsNull(expr) => Ok(Expr::IsNull(Arc::new(rewrite_expr(&expr, schema)?))), + Expr::Cast { expr, data_type } => Ok(Expr::Cast { + expr: Arc::new(rewrite_expr(&expr, schema)?), + data_type: data_type.clone(), + }), + Expr::Sort { expr, asc } => Ok(Expr::Sort { + expr: Arc::new(rewrite_expr(&expr, schema)?), + asc: asc.clone(), + }), + Expr::ScalarFunction { + name, + args, + return_type, + } => Ok(Expr::ScalarFunction { + name: name.clone(), + args: rewrite_expr_list(args, schema)?, + return_type: return_type.clone(), + }), + Expr::AggregateFunction { + name, + args, + return_type, + } => Ok(Expr::AggregateFunction { + name: name.clone(), + args: rewrite_expr_list(args, schema)?, + return_type: return_type.clone(), + }), + _ => Ok(expr.clone()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logicalplan::col; + use crate::test::*; + + #[test] + fn aggregate_no_group_by() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(&table_scan) + .aggregate(vec![col("a")], vec![max(col("b"))])? + .build()?; + + // plan has unresolve columns + let expected = "Aggregate: groupBy=[[#a]], aggr=[[MAX(#b)]]\n TableScan: test projection=None"; + assert_eq!(format!("{:?}", plan), expected); + + // optimized plan has resolved columns + let expected = "Aggregate: groupBy=[[#0]], aggr=[[MAX(#1)]]\n TableScan: test projection=None"; + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { + let optimized_plan = optimize(plan).expect("failed to optimize plan"); + let formatted_plan = format!("{:?}", optimized_plan); + assert_eq!(formatted_plan, expected); + } + + fn optimize(plan: &LogicalPlan) -> Result { + let mut rule = ResolveColumnsRule::new(); + rule.optimize(plan) + } +} diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index e93d01e640d..b8d90a90910 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -25,8 +25,8 @@ use std::sync::Arc; use arrow::datatypes::Schema; use crate::error::{ExecutionError, Result}; -use crate::logicalplan::Expr; use crate::logicalplan::LogicalPlan; +use crate::logicalplan::{Expr, LogicalPlanBuilder}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; @@ -34,47 +34,33 @@ use crate::optimizer::utils; pub struct TypeCoercionRule {} impl OptimizerRule for TypeCoercionRule { - fn optimize(&mut self, plan: &LogicalPlan) -> Result> { + fn optimize(&mut self, plan: &LogicalPlan) -> Result { match plan { - LogicalPlan::Projection { - expr, - input, - schema, - } => Ok(Arc::new(LogicalPlan::Projection { - expr: expr - .iter() - .map(|e| rewrite_expr(e, &schema)) - .collect::>>()?, - input: self.optimize(input)?, - schema: schema.clone(), - })), - LogicalPlan::Selection { expr, input } => { - Ok(Arc::new(LogicalPlan::Selection { - expr: rewrite_expr(expr, input.schema())?, - input: self.optimize(input)?, - })) + LogicalPlan::Projection { expr, input, .. } => { + LogicalPlanBuilder::from(&self.optimize(input)?) + .project(rewrite_expr_list(expr, input.schema())?)? + .build() + } + LogicalPlan::Selection { expr, input, .. } => { + LogicalPlanBuilder::from(&self.optimize(input)?) + .filter(rewrite_expr(expr, input.schema())?)? + .build() } LogicalPlan::Aggregate { input, group_expr, aggr_expr, - schema, - } => Ok(Arc::new(LogicalPlan::Aggregate { - group_expr: group_expr - .iter() - .map(|e| rewrite_expr(e, &schema)) - .collect::>>()?, - aggr_expr: aggr_expr - .iter() - .map(|e| rewrite_expr(e, &schema)) - .collect::>>()?, - input: self.optimize(input)?, - schema: schema.clone(), - })), - LogicalPlan::TableScan { .. } => Ok(Arc::new(plan.clone())), - LogicalPlan::EmptyRelation { .. } => Ok(Arc::new(plan.clone())), - LogicalPlan::Limit { .. } => Ok(Arc::new(plan.clone())), - LogicalPlan::CreateExternalTable { .. } => Ok(Arc::new(plan.clone())), + .. + } => LogicalPlanBuilder::from(&self.optimize(input)?) + .aggregate( + rewrite_expr_list(group_expr, input.schema())?, + rewrite_expr_list(aggr_expr, input.schema())?, + )? + .build(), + LogicalPlan::TableScan { .. } => Ok(plan.clone()), + LogicalPlan::EmptyRelation { .. } => Ok(plan.clone()), + LogicalPlan::Limit { .. } => Ok(plan.clone()), + LogicalPlan::CreateExternalTable { .. } => Ok(plan.clone()), other => Err(ExecutionError::NotImplemented(format!( "Type coercion optimizer rule does not support relation: {:?}", other @@ -90,6 +76,13 @@ impl TypeCoercionRule { } } +fn rewrite_expr_list(expr: &Vec, schema: &Schema) -> Result> { + Ok(expr + .iter() + .map(|e| rewrite_expr(e, schema)) + .collect::>>()?) +} + /// Rewrite an expression to include explicit CAST operations when required fn rewrite_expr(expr: &Expr, schema: &Schema) -> Result { match expr { @@ -141,11 +134,17 @@ fn rewrite_expr(expr: &Expr, schema: &Schema) -> Result { }), Expr::Cast { .. } => Ok(expr.clone()), Expr::Column(_) => Ok(expr.clone()), + Expr::Alias(expr, alias) => Ok(Expr::Alias( + Arc::new(rewrite_expr(expr, schema)?), + alias.to_owned(), + )), Expr::Literal(_) => Ok(expr.clone()), - other => Err(ExecutionError::NotImplemented(format!( - "Type coercion optimizer rule does not support expression: {:?}", - other - ))), + Expr::UnresolvedColumn(_) => Ok(expr.clone()), + Expr::Not(_) => Ok(expr.clone()), + Expr::Sort { .. } => Ok(expr.clone()), + Expr::Wildcard { .. } => Err(ExecutionError::General( + "Wildcard expressions are not valid in a logical query plan".to_owned(), + )), } } diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 1755ac61198..0aa79469197 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -45,6 +45,9 @@ pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet) -> Result accum.insert(*i); Ok(()) } + Expr::UnresolvedColumn(_) => Err(ExecutionError::ExecutionError( + "Columns need to be resolved before this rule can run".to_owned(), + )), Expr::Literal(_) => { // not needed Ok(()) @@ -73,6 +76,7 @@ pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result { Expr::Alias(expr, name) => { Ok(Field::new(name, expr.get_type(input_schema)?, true)) } + Expr::UnresolvedColumn(name) => Ok(input_schema.field_with_name(&name)?.clone()), Expr::Column(i) => { let input_schema_field_count = input_schema.fields().len(); if *i < input_schema_field_count { diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs index bc4dc1c80c7..ecda07dda85 100644 --- a/rust/datafusion/src/test/mod.rs +++ b/rust/datafusion/src/test/mod.rs @@ -20,6 +20,7 @@ use crate::error::Result; use crate::execution::context::ExecutionContext; use crate::execution::physical_plan::ExecutionPlan; +use crate::logicalplan::{Expr, LogicalPlan, LogicalPlanBuilder}; use arrow::array; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -202,3 +203,31 @@ pub fn format_batch(batch: &RecordBatch) -> Vec { } rows } + +/// all tests share a common table +pub fn test_table_scan() -> Result { + let schema = Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::UInt32, false), + Field::new("c", DataType::UInt32, false), + ]); + LogicalPlanBuilder::scan("default", "test", &schema, None)?.build() +} + +pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { + let actual: Vec = plan + .schema() + .fields() + .iter() + .map(|f| f.name().clone()) + .collect(); + assert_eq!(actual, expected); +} + +pub fn max(expr: Expr) -> Expr { + Expr::AggregateFunction { + name: "MAX".to_owned(), + args: vec![expr], + return_type: DataType::Float64, + } +} diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index f626659d9aa..c45af09a840 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -64,7 +64,7 @@ fn nyc() -> Result<()> { let optimized_plan = ctx.optimize(&logical_plan)?; - match optimized_plan.as_ref() { + match &optimized_plan { LogicalPlan::Aggregate { input, .. } => match input.as_ref() { LogicalPlan::TableScan { ref projected_schema,