From 9c6bea495036b8b10a22a72a7096915a6eee855c Mon Sep 17 00:00:00 2001
From: "Jorge C. Leitao"
Date: Sun, 12 Jul 2020 16:51:46 +0200
Subject: [PATCH 1/8] Major simplification to columns.

Columns are no longer identified by their index, but by their name. This is
induced by the following assumption: every table that we scan has unique
column names.

This greatly simplifies the code and the public API of both physical and
logical plans. It also greatly simplifies the projection push down and
deprecates the ResolveColumnsRule (a minimal sketch of the name-based scheme
is appended at the end of this series).
---
 rust/datafusion/src/execution/context.rs      |  33 +++--
 .../execution/physical_plan/expressions.rs    | 130 +++++++++++-------
 .../execution/physical_plan/hash_aggregate.rs |  16 +--
 .../physical_plan/math_expressions.rs         |  10 +-
 .../src/execution/physical_plan/mod.rs        |   2 +-
 .../src/execution/physical_plan/projection.rs |   9 +-
 .../src/execution/physical_plan/selection.rs  |  12 +-
 .../src/execution/physical_plan/sort.rs       |   6 +-
 rust/datafusion/src/execution/table_impl.rs   |  17 ++-
 rust/datafusion/src/logicalplan.rs            |  23 +---
 .../src/optimizer/projection_push_down.rs     | 126 ++++++-----------
 .../src/optimizer/resolve_columns.rs          |   3 +-
 .../datafusion/src/optimizer/type_coercion.rs |  22 ++-
 rust/datafusion/src/optimizer/utils.rs        |  67 ++++-----
 rust/datafusion/src/sql/planner.rs            | 113 ++++++---------
 15 files changed, 254 insertions(+), 335 deletions(-)

diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index 2500163e2e4..221de3e6356 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -459,8 +459,10 @@ impl ExecutionContext {
                 let expr = self.create_physical_expr(expr, input_schema)?;
                 Ok(Arc::new(Alias::new(expr, &name)))
             }
-            Expr::Column(i) => {
-                Ok(Arc::new(Column::new(*i, &input_schema.field(*i).name())))
+            Expr::Column(name) => {
+                // check that name exists
+                input_schema.field_with_name(&name)?;
+                Ok(Arc::new(Column::new(name)))
             }
             Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))),
             Expr::BinaryExpr { left, op, right } => Ok(Arc::new(BinaryExpr::new(
@@ -706,7 +708,7 @@ mod tests {
         let table = ctx.table("test")?;

         let logical_plan = LogicalPlanBuilder::from(&table.to_logical_plan())
-            .project(vec![Expr::UnresolvedColumn("c2".to_string())])?
+            .project(vec![col("c2")])?
             .build()?;

         let optimized_plan = ctx.optimize(&logical_plan)?;
@@ -725,7 +727,7 @@ mod tests {
             _ => assert!(false, "expect optimized_plan to be projection"),
         }

-        let expected = "Projection: #0\
+        let expected = "Projection: #c2\
         \n  TableScan: test projection=Some([1])";
         assert_eq!(format!("{:?}", optimized_plan), expected);

@@ -747,19 +749,19 @@ mod tests {
         let tmp_dir = TempDir::new("execute")?;
         let ctx = create_ctx(&tmp_dir, 1)?;

-        let schema = Arc::new(Schema::new(vec![Field::new(
-            "state",
-            DataType::Utf8,
-            false,
-        )]));
+        let schema = ctx.datasources.get("test").unwrap().schema();
+        assert_eq!(schema.field_with_name("c1")?.is_nullable(), false);

         let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)?
-            .project(vec![col("state")])?
+            .project(vec![col("c1")])?
.build()?; let plan = ctx.optimize(&plan)?; let physical_plan = ctx.create_physical_plan(&Arc::new(plan), 1024)?; - assert_eq!(physical_plan.schema().field(0).is_nullable(), false); + assert_eq!( + physical_plan.schema().field_with_name("c1")?.is_nullable(), + false + ); Ok(()) } @@ -783,7 +785,7 @@ mod tests { projection: None, projected_schema: Box::new(schema.clone()), }) - .project(vec![Expr::UnresolvedColumn("b".to_string())])? + .project(vec![col("b")])? .build()?; assert_fields_eq(&plan, vec!["b"]); @@ -804,7 +806,7 @@ mod tests { _ => assert!(false, "expect optimized_plan to be projection"), } - let expected = "Projection: #0\ + let expected = "Projection: #b\ \n InMemoryScan: projection=Some([1])"; assert_eq!(format!("{:?}", optimized_plan), expected); @@ -1004,7 +1006,10 @@ mod tests { vec![col("state")], vec![aggregate_expr("SUM", col("salary"), DataType::UInt32)], )? - .project(vec![col("state"), col_index(1).alias("total_salary")])? + .project(vec![ + col("state"), + col("SUM(SUM(salary))").alias("total_salary"), + ])? .build()?; let plan = ctx.optimize(&plan)?; diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index e3e1ea847fe..a680b6cb11c 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -83,15 +83,13 @@ impl PhysicalExpr for Alias { /// Represents the column at a given index in a RecordBatch pub struct Column { - index: usize, name: String, } impl Column { /// Create a new column expression - pub fn new(index: usize, name: &str) -> Self { + pub fn new(name: &str) -> Self { Self { - index, name: name.to_owned(), } } @@ -105,23 +103,26 @@ impl PhysicalExpr for Column { /// Get the data type of this expression, given the schema of the input fn data_type(&self, input_schema: &Schema) -> Result { - Ok(input_schema.field(self.index).data_type().clone()) + Ok(input_schema + .field_with_name(&self.name)? 
+            .data_type()
+            .clone())
     }

     /// Decide whether this expression is nullable, given the schema of the input
     fn nullable(&self, input_schema: &Schema) -> Result<bool> {
-        Ok(input_schema.field(self.index).is_nullable())
+        Ok(input_schema.field_with_name(&self.name)?.is_nullable())
     }

     /// Evaluate the expression
     fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
-        Ok(batch.column(self.index).clone())
+        Ok(batch.column(batch.schema().index_of(&self.name)?).clone())
     }
 }

 /// Create a column expression
-pub fn col(i: usize, schema: &Schema) -> Arc<dyn PhysicalExpr> {
-    Arc::new(Column::new(i, &schema.field(i).name()))
+pub fn col(name: &str) -> Arc<dyn PhysicalExpr> {
+    Arc::new(Column::new(name))
 }

 /// SUM aggregate expression
@@ -138,7 +139,7 @@ impl Sum {

 impl AggregateExpr for Sum {
     fn name(&self) -> String {
-        "SUM".to_string()
+        format!("SUM({})", self.expr.name())
     }

     fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
@@ -166,8 +167,8 @@ impl AggregateExpr for Sum {
         Rc::new(RefCell::new(SumAccumulator { sum: None }))
     }

-    fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
-        Arc::new(Sum::new(Arc::new(Column::new(column_index, &self.name()))))
+    fn create_reducer(&self) -> Arc<dyn AggregateExpr> {
+        Arc::new(Sum::new(col(&self.name())))
     }
 }

@@ -334,7 +335,7 @@ impl Avg {

 impl AggregateExpr for Avg {
     fn name(&self) -> String {
-        "AVG".to_string()
+        format!("AVG({})", self.expr.name())
     }

     fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
@@ -367,8 +368,8 @@ impl AggregateExpr for Avg {
         }))
     }

-    fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
-        Arc::new(Avg::new(Arc::new(Column::new(column_index, &self.name()))))
+    fn create_reducer(&self) -> Arc<dyn AggregateExpr> {
+        Arc::new(Avg::new(Arc::new(Column::new(&self.name()))))
     }
 }

@@ -452,7 +453,7 @@ impl Max {

 impl AggregateExpr for Max {
     fn name(&self) -> String {
-        "MAX".to_string()
+        format!("MAX({})", self.expr.name())
     }

     fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
@@ -480,8 +481,8 @@ impl AggregateExpr for Max {
         Rc::new(RefCell::new(MaxAccumulator { max: None }))
     }

-    fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
-        Arc::new(Max::new(Arc::new(Column::new(column_index, &self.name()))))
+    fn create_reducer(&self) -> Arc<dyn AggregateExpr> {
+        Arc::new(Max::new(Arc::new(Column::new(&self.name()))))
     }
 }

@@ -651,7 +652,7 @@ impl Min {

 impl AggregateExpr for Min {
     fn name(&self) -> String {
-        "MIN".to_string()
+        format!("MIN({})", self.expr.name())
     }

     fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
@@ -679,8 +680,8 @@ impl AggregateExpr for Min {
         Rc::new(RefCell::new(MinAccumulator { min: None }))
     }

-    fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
-        Arc::new(Min::new(Arc::new(Column::new(column_index, &self.name()))))
+    fn create_reducer(&self) -> Arc<dyn AggregateExpr> {
+        Arc::new(Min::new(Arc::new(Column::new(&self.name()))))
     }
 }

@@ -851,7 +852,7 @@ impl Count {

 impl AggregateExpr for Count {
     fn name(&self) -> String {
-        "COUNT".to_string()
+        format!("COUNT({})", self.expr.name())
     }

     fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
@@ -866,8 +867,8 @@ impl AggregateExpr for Count {
         Rc::new(RefCell::new(CountAccumulator { count: 0 }))
     }

-    fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
-        Arc::new(Sum::new(Arc::new(Column::new(column_index, &self.name()))))
+    fn create_reducer(&self) -> Arc<dyn AggregateExpr> {
+        Arc::new(Sum::new(Arc::new(Column::new(&self.name()))))
     }
 }

@@ -1336,7 +1337,7 @@ mod tests {
         )?;

         // expression: "a < b"
-        let lt = binary(col(0, &schema), Operator::Lt, col(1, &schema));
+        let lt = binary(col("a"), Operator::Lt, col("b"));
         let result = lt.evaluate(&batch)?;
         assert_eq!(result.len(), 5);

@@
-1367,9 +1368,9 @@ mod tests { // expression: "a < b OR a == b" let expr = binary( - binary(col(0, &schema), Operator::Lt, col(1, &schema)), + binary(col("a"), Operator::Lt, col("b")), Operator::Or, - binary(col(0, &schema), Operator::Eq, col(1, &schema)), + binary(col("a"), Operator::Eq, col("b")), ); let result = expr.evaluate(&batch)?; assert_eq!(result.len(), 5); @@ -1414,7 +1415,7 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let cast = CastExpr::try_new(col(0, &schema), &schema, DataType::UInt32)?; + let cast = CastExpr::try_new(col("a"), &schema, DataType::UInt32)?; let result = cast.evaluate(&batch)?; assert_eq!(result.len(), 5); @@ -1433,7 +1434,7 @@ mod tests { let a = Int32Array::from(vec![1, 2, 3, 4, 5]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let cast = CastExpr::try_new(col(0, &schema), &schema, DataType::Utf8)?; + let cast = CastExpr::try_new(col("a"), &schema, DataType::Utf8)?; let result = cast.evaluate(&batch)?; assert_eq!(result.len(), 5); @@ -1453,7 +1454,7 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; let cast = CastExpr::try_new( - col(0, &schema), + col("a"), &schema, DataType::Timestamp(TimeUnit::Nanosecond, None), )?; @@ -1472,7 +1473,7 @@ mod tests { #[test] fn invalid_cast() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); - match CastExpr::try_new(col(0, &schema), &schema, DataType::Int32) { + match CastExpr::try_new(col("a"), &schema, DataType::Int32) { Err(ExecutionError::General(ref str)) => { assert_eq!(str, "Invalid CAST from Utf8 to Int32"); Ok(()) @@ -1485,12 +1486,18 @@ mod tests { fn sum_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let sum = sum(col(0, &schema)); - assert_eq!("SUM".to_string(), sum.name()); + let sum = sum(col("a")); + assert_eq!("SUM(a)".to_string(), sum.name()); assert_eq!(DataType::Int64, sum.data_type(&schema)?); - let combiner = sum.create_reducer(0); - assert_eq!("SUM".to_string(), combiner.name()); + // after the aggr expression is applied, the schema changes to: + let schema = Schema::new(vec![ + schema.field(0).clone(), + Field::new(&sum.name(), sum.data_type(&schema)?, false), + ]); + + let combiner = sum.create_reducer(); + assert_eq!("SUM(SUM(a))".to_string(), combiner.name()); assert_eq!(DataType::Int64, combiner.data_type(&schema)?); Ok(()) @@ -1500,12 +1507,18 @@ mod tests { fn max_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let max = max(col(0, &schema)); - assert_eq!("MAX".to_string(), max.name()); + let max = max(col("a")); + assert_eq!("MAX(a)".to_string(), max.name()); assert_eq!(DataType::Int64, max.data_type(&schema)?); - let combiner = max.create_reducer(0); - assert_eq!("MAX".to_string(), combiner.name()); + // after the aggr expression is applied, the schema changes to: + let schema = Schema::new(vec![ + schema.field(0).clone(), + Field::new(&max.name(), max.data_type(&schema)?, false), + ]); + + let combiner = max.create_reducer(); + assert_eq!("MAX(MAX(a))".to_string(), combiner.name()); assert_eq!(DataType::Int64, combiner.data_type(&schema)?); Ok(()) @@ -1515,12 +1528,17 @@ mod tests { fn min_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let min = min(col(0, &schema)); - assert_eq!("MIN".to_string(), min.name()); + 
let min = min(col("a")); + assert_eq!("MIN(a)".to_string(), min.name()); assert_eq!(DataType::Int64, min.data_type(&schema)?); - let combiner = min.create_reducer(0); - assert_eq!("MIN".to_string(), combiner.name()); + // after the aggr expression is applied, the schema changes to: + let schema = Schema::new(vec![ + schema.field(0).clone(), + Field::new(&min.name(), min.data_type(&schema)?, false), + ]); + let combiner = min.create_reducer(); + assert_eq!("MIN(MIN(a))".to_string(), combiner.name()); assert_eq!(DataType::Int64, combiner.data_type(&schema)?); Ok(()) @@ -1529,12 +1547,18 @@ mod tests { fn avg_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let avg = avg(col(0, &schema)); - assert_eq!("AVG".to_string(), avg.name()); + let avg = avg(col("a")); + assert_eq!("AVG(a)".to_string(), avg.name()); assert_eq!(DataType::Float64, avg.data_type(&schema)?); - let combiner = avg.create_reducer(0); - assert_eq!("AVG".to_string(), combiner.name()); + // after the aggr expression is applied, the schema changes to: + let schema = Schema::new(vec![ + schema.field(0).clone(), + Field::new(&avg.name(), avg.data_type(&schema)?, false), + ]); + + let combiner = avg.create_reducer(); + assert_eq!("AVG(AVG(a))".to_string(), combiner.name()); assert_eq!(DataType::Float64, combiner.data_type(&schema)?); Ok(()) @@ -1865,7 +1889,7 @@ mod tests { } fn do_sum(batch: &RecordBatch) -> Result> { - let sum = sum(col(0, &batch.schema())); + let sum = sum(col("a")); let accum = sum.create_accumulator(); let input = sum.evaluate_input(batch)?; let mut accum = accum.borrow_mut(); @@ -1876,7 +1900,7 @@ mod tests { } fn do_max(batch: &RecordBatch) -> Result> { - let max = max(col(0, &batch.schema())); + let max = max(col("a")); let accum = max.create_accumulator(); let input = max.evaluate_input(batch)?; let mut accum = accum.borrow_mut(); @@ -1887,7 +1911,7 @@ mod tests { } fn do_min(batch: &RecordBatch) -> Result> { - let min = min(col(0, &batch.schema())); + let min = min(col("a")); let accum = min.create_accumulator(); let input = min.evaluate_input(batch)?; let mut accum = accum.borrow_mut(); @@ -1898,7 +1922,7 @@ mod tests { } fn do_count(batch: &RecordBatch) -> Result> { - let count = count(col(0, &batch.schema())); + let count = count(col("a")); let accum = count.create_accumulator(); let input = count.evaluate_input(batch)?; let mut accum = accum.borrow_mut(); @@ -1909,7 +1933,7 @@ mod tests { } fn do_avg(batch: &RecordBatch) -> Result> { - let avg = avg(col(0, &batch.schema())); + let avg = avg(col("a")); let accum = avg.create_accumulator(); let input = avg.evaluate_input(batch)?; let mut accum = accum.borrow_mut(); @@ -2009,7 +2033,7 @@ mod tests { op: Operator, expected: PrimitiveArray, ) -> Result<()> { - let arithmetic_op = binary(col(0, schema.as_ref()), op, col(1, schema.as_ref())); + let arithmetic_op = binary(col("a"), op, col("b")); let batch = RecordBatch::try_new(schema, data)?; let result = arithmetic_op.evaluate(&batch)?; @@ -2039,7 +2063,7 @@ mod tests { let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; // expression: "!a" - let lt = not(col(0, &schema)); + let lt = not(col("a")); let result = lt.evaluate(&batch)?; assert_eq!(result.len(), 2); diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs index 441681ad0bf..e935480154c 100644 --- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs +++ 
b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs @@ -39,7 +39,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::{RecordBatch, RecordBatchReader}; -use crate::execution::physical_plan::expressions::Column; +use crate::execution::physical_plan::expressions::col; use crate::logicalplan::ScalarValue; use fnv::FnvHashMap; @@ -85,17 +85,11 @@ impl HashAggregateExec { &self, ) -> (Vec>, Vec>) { let final_group: Vec> = (0..self.group_expr.len()) - .map(|i| { - Arc::new(Column::new(i, &self.group_expr[i].name())) - as Arc - }) + .map(|i| col(&self.group_expr[i].name()) as Arc) .collect(); let final_aggr: Vec> = (0..self.aggr_expr.len()) - .map(|i| { - let aggr = self.aggr_expr[i].create_reducer(i + self.group_expr.len()); - aggr as Arc - }) + .map(|i| self.aggr_expr[i].create_reducer()) .collect(); (final_group, final_aggr) @@ -772,9 +766,9 @@ mod tests { let csv = CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), None, 1024)?; - let group_expr: Vec> = vec![col(1, schema.as_ref())]; + let group_expr: Vec> = vec![col("c2")]; - let aggr_expr: Vec> = vec![sum(col(3, schema.as_ref()))]; + let aggr_expr: Vec> = vec![sum(col("c4"))]; let partition_aggregate = HashAggregateExec::try_new( group_expr.clone(), diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs index c9954053725..aa578b697da 100644 --- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs @@ -81,18 +81,18 @@ pub fn register_math_functions(ctx: &mut ExecutionContext) { mod tests { use super::*; use crate::error::Result; - use crate::logicalplan::{sqrt, Expr, LogicalPlanBuilder}; + use crate::logicalplan::{col, sqrt, LogicalPlanBuilder}; use arrow::datatypes::Schema; #[test] fn cast_i8_input() -> Result<()> { let schema = Schema::new(vec![Field::new("c0", DataType::Int8, true)]); let plan = LogicalPlanBuilder::scan("", "", &schema, None)? - .project(vec![sqrt(Expr::UnresolvedColumn("c0".to_owned()))])? + .project(vec![sqrt(col("c0"))])? .build()?; let ctx = ExecutionContext::new(); let plan = ctx.optimize(&plan)?; - let expected = "Projection: sqrt(CAST(#0 AS Float64))\ + let expected = "Projection: sqrt(CAST(#c0 AS Float64))\ \n TableScan: projection=Some([0])"; assert_eq!(format!("{:?}", plan), expected); Ok(()) @@ -102,11 +102,11 @@ mod tests { fn no_cast_f64_input() -> Result<()> { let schema = Schema::new(vec![Field::new("c0", DataType::Float64, true)]); let plan = LogicalPlanBuilder::scan("", "", &schema, None)? - .project(vec![sqrt(Expr::UnresolvedColumn("c0".to_owned()))])? + .project(vec![sqrt(col("c0"))])? .build()?; let ctx = ExecutionContext::new(); let plan = ctx.optimize(&plan)?; - let expected = "Projection: sqrt(#0)\ + let expected = "Projection: sqrt(#c0)\ \n TableScan: projection=Some([0])"; assert_eq!(format!("{:?}", plan), expected); Ok(()) diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index a3b32eb80cc..2828e5dd96f 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -74,7 +74,7 @@ pub trait AggregateExpr: Send + Sync { /// Create an aggregate expression for combining the results of accumulators from partitions. 
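 /// The reducer refers to the partial result by the name this expression generates
 /// (see `name()`): the reducer of a partial `SUM(a)`, for instance, is named `SUM(SUM(a))`.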
 /// For example, to combine the results of a parallel SUM we just need to do another SUM, but
 /// to combine the results of parallel COUNT we would also use SUM.
-    fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr>;
+    fn create_reducer(&self) -> Arc<dyn AggregateExpr>;
 }

 /// Aggregate accumulator
diff --git a/rust/datafusion/src/execution/physical_plan/projection.rs b/rust/datafusion/src/execution/physical_plan/projection.rs
index 7f39deda499..c10b7824205 100644
--- a/rust/datafusion/src/execution/physical_plan/projection.rs
+++ b/rust/datafusion/src/execution/physical_plan/projection.rs
@@ -141,7 +141,7 @@ mod tests {
     use super::*;
     use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions};
-    use crate::execution::physical_plan::expressions::Column;
+    use crate::execution::physical_plan::expressions::col;
     use crate::test;

     #[test]
@@ -154,12 +154,7 @@ mod tests {
         let csv =
             CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), None, 1024)?;

-        let projection = ProjectionExec::try_new(
-            vec![Arc::new(Column::new(0, &schema.as_ref().field(0).name()))],
-            Arc::new(csv),
-        )?;
-
-        assert_eq!("c1", projection.schema.field(0).name().as_str());
+        let projection = ProjectionExec::try_new(vec![col("c1")], Arc::new(csv))?;

         let mut partition_count = 0;
         let mut row_count = 0;
diff --git a/rust/datafusion/src/execution/physical_plan/selection.rs b/rust/datafusion/src/execution/physical_plan/selection.rs
index b8efe421d9f..4f021ea1c8f 100644
--- a/rust/datafusion/src/execution/physical_plan/selection.rs
+++ b/rust/datafusion/src/execution/physical_plan/selection.rs
@@ -166,17 +166,9 @@ mod tests {
             CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), None, 1024)?;

         let predicate: Arc<dyn PhysicalExpr> = binary(
-            binary(
-                col(1, schema.as_ref()),
-                Operator::Gt,
-                lit(ScalarValue::UInt32(1)),
-            ),
+            binary(col("c2"), Operator::Gt, lit(ScalarValue::UInt32(1))),
             Operator::And,
-            binary(
-                col(1, schema.as_ref()),
-                Operator::Lt,
-                lit(ScalarValue::UInt32(4)),
-            ),
+            binary(col("c2"), Operator::Lt, lit(ScalarValue::UInt32(4))),
         );

         let selection: Arc =
diff --git a/rust/datafusion/src/execution/physical_plan/sort.rs b/rust/datafusion/src/execution/physical_plan/sort.rs
index 7017e8a0363..c8b8dec4ee3 100644
--- a/rust/datafusion/src/execution/physical_plan/sort.rs
+++ b/rust/datafusion/src/execution/physical_plan/sort.rs
@@ -172,17 +172,17 @@ mod tests {
             vec![
                 // c1 string column
                 PhysicalSortExpr {
-                    expr: col(0, schema.as_ref()),
+                    expr: col("c1"),
                     options: SortOptions::default(),
                 },
                 // c2 uint32 column
                 PhysicalSortExpr {
-                    expr: col(1, schema.as_ref()),
+                    expr: col("c2"),
                     options: SortOptions::default(),
                 },
                 // c7 uint8 column
                 PhysicalSortExpr {
-                    expr: col(6, schema.as_ref()),
+                    expr: col("c7"),
                     options: SortOptions::default(),
                 },
             ],
diff --git a/rust/datafusion/src/execution/table_impl.rs b/rust/datafusion/src/execution/table_impl.rs
index 8f798bdf4df..401a1e3809d 100644
--- a/rust/datafusion/src/execution/table_impl.rs
+++ b/rust/datafusion/src/execution/table_impl.rs
@@ -23,8 +23,8 @@ use crate::arrow::datatypes::DataType;
 use crate::arrow::record_batch::RecordBatch;
 use crate::error::{ExecutionError, Result};
 use crate::execution::context::ExecutionContext;
-use crate::logicalplan::LogicalPlanBuilder;
-use crate::logicalplan::{Expr, LogicalPlan};
+use crate::logicalplan::{col, Expr, LogicalPlan};
+use crate::logicalplan::{LogicalPlanBuilder, ScalarValue};
 use crate::table::*;
 use arrow::datatypes::Schema;

@@ -48,8 +48,9 @@ impl Table for TableImpl {
             .map(|name| {
                 self.plan
                     .schema()
+                    // take the index to ensure
that the column exists in the schema .index_of(name.to_owned()) - .and_then(|i| Ok(Expr::Column(i))) + .and_then(|_| Ok(col(name))) .map_err(|e| e.into()) }) .collect::>>()?; @@ -90,7 +91,8 @@ impl Table for TableImpl { /// Return an expression representing a column within this table fn col(&self, name: &str) -> Result { - Ok(Expr::Column(self.plan.schema().index_of(name)?)) + self.plan.schema().index_of(name)?; // check that the column exists + Ok(col(name)) } /// Create an expression to represent the min() aggregate function @@ -141,7 +143,12 @@ impl TableImpl { /// Determine the data type for a given expression fn get_data_type(&self, expr: &Expr) -> Result { match expr { - Expr::Column(i) => Ok(self.plan.schema().field(*i).data_type().clone()), + Expr::Column(name) => Ok(self + .plan + .schema() + .field_with_name(name)? + .data_type() + .clone()), _ => Err(ExecutionError::General(format!( "Could not determine data type for expr {:?}", expr diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 032bfb906d8..6ded26872dc 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -184,10 +184,8 @@ impl ScalarValue { pub enum Expr { /// An aliased expression Alias(Box, String), - /// index into a value within the row or complex value - Column(usize), - /// Reference to column by name - UnresolvedColumn(String), + /// column of a table scan + Column(String), /// literal value Literal(ScalarValue), /// binary expression e.g. "age > 21" @@ -248,10 +246,7 @@ impl Expr { pub fn get_type(&self, schema: &Schema) -> Result { match self { Expr::Alias(expr, _) => expr.get_type(schema), - Expr::Column(n) => Ok(schema.field(*n).data_type().clone()), - Expr::UnresolvedColumn(name) => { - Ok(schema.field_with_name(&name)?.data_type().clone()) - } + Expr::Column(name) => Ok(schema.field_with_name(name)?.data_type().clone()), Expr::Literal(l) => Ok(l.get_datatype()), Expr::Cast { data_type, .. } => Ok(data_type.clone()), Expr::ScalarFunction { return_type, .. 
             } => Ok(return_type.clone()),
@@ -368,14 +363,9 @@ impl Expr {
     }
 }

-/// Create a column expression based on a column index
-pub fn col_index(index: usize) -> Expr {
-    Expr::Column(index)
-}
-
 /// Create a column expression based on a column name
 pub fn col(name: &str) -> Expr {
-    Expr::UnresolvedColumn(name.to_owned())
+    Expr::Column(name.to_owned())
 }

 /// Whether it can be represented as a literal expression
@@ -475,8 +465,7 @@ impl fmt::Debug for Expr {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
             Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias),
-            Expr::Column(i) => write!(f, "#{}", i),
-            Expr::UnresolvedColumn(name) => write!(f, "#{}", name),
+            Expr::Column(name) => write!(f, "#{}", name),
             Expr::Literal(v) => write!(f, "{:?}", v),
             Expr::Cast { expr, data_type } => {
                 write!(f, "CAST({:?} AS {:?})", expr, data_type)
@@ -925,7 +914,7 @@ impl LogicalPlanBuilder {
         (0..expr.len()).for_each(|i| match &expr[i] {
             Expr::Wildcard => {
                 (0..input_schema.fields().len())
-                    .for_each(|i| expr_vec.push(col_index(i).clone()));
+                    .for_each(|i| expr_vec.push(col(input_schema.field(i).name())));
             }
             _ => expr_vec.push(expr[i].clone()),
         });
diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs
index d1bba6edbdb..9efe8c58794 100644
--- a/rust/datafusion/src/optimizer/projection_push_down.rs
+++ b/rust/datafusion/src/optimizer/projection_push_down.rs
@@ -32,8 +32,10 @@ pub struct ProjectionPushDown {}

 impl OptimizerRule for ProjectionPushDown {
     fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
-        let mut accum: HashSet<usize> = HashSet::new();
-        let mut mapping: HashMap<usize, usize> = HashMap::new();
+        // set of all columns referred to by a scan.
+        let mut accum: HashSet<String> = HashSet::new();
+        // mapping from old to new column names
+        let mut mapping: HashMap<String, String> = HashMap::new();
         self.optimize_plan(plan, &mut accum, &mut mapping, false)
     }
 }

@@ -47,14 +49,14 @@ impl ProjectionPushDown {
     fn optimize_plan(
         &self,
         plan: &LogicalPlan,
-        accum: &mut HashSet<usize>,
-        mapping: &mut HashMap<usize, usize>,
+        accum: &mut HashSet<String>,
+        mapping: &mut HashMap<String, String>,
         has_projection: bool,
     ) -> Result<LogicalPlan> {
         match plan {
             LogicalPlan::Projection { expr, input, .. } => {
                 // collect all columns referenced by projection expressions
-                utils::exprlist_to_column_indices(&expr, accum)?;
+                utils::exprlist_to_column_names(&expr, accum)?;

                 LogicalPlanBuilder::from(
                     &self.optimize_plan(&input, accum, mapping, true)?,
                 )
                 .project(self.rewrite_expr_list(expr, mapping)?)?
                 .build()
             }
             LogicalPlan::Selection { expr, input } => {
                 // collect all columns referenced by filter expression
-                utils::expr_to_column_indices(expr, accum)?;
+                utils::expr_to_column_names(expr, accum)?;

                 LogicalPlanBuilder::from(&self.optimize_plan(
                     &input,
                     accum,
                     mapping,
                     has_projection,
                 )?)
                 .filter(self.rewrite_expr(expr, mapping)?)?
                 .build()
             }
             LogicalPlan::Aggregate {
                 input,
                 group_expr,
                 aggr_expr,
                 ..
             } => {
                 // collect all columns referenced by grouping and aggregate expressions
-                utils::exprlist_to_column_indices(&group_expr, accum)?;
-                utils::exprlist_to_column_indices(&aggr_expr, accum)?;
+                utils::exprlist_to_column_names(&group_expr, accum)?;
+                utils::exprlist_to_column_names(&aggr_expr, accum)?;

                 LogicalPlanBuilder::from(&self.optimize_plan(
                     &input,
                     accum,
                     mapping,
                     has_projection,
                 )?)
                 .aggregate(
                     self.rewrite_expr_list(group_expr, mapping)?,
                     self.rewrite_expr_list(aggr_expr, mapping)?,
                 )?
                 .build()
             }
             LogicalPlan::Sort { expr, input, ..
} => { // collect all columns referenced by sort expressions - utils::exprlist_to_column_indices(&expr, accum)?; + utils::exprlist_to_column_names(&expr, accum)?; LogicalPlanBuilder::from(&self.optimize_plan( &input, @@ -124,7 +126,6 @@ impl ProjectionPushDown { &table_schema, projection, accum, - mapping, has_projection, )?; @@ -143,13 +144,8 @@ impl ProjectionPushDown { projection, .. } => { - let (projection, projected_schema) = get_projected_schema( - &schema, - projection, - accum, - mapping, - has_projection, - )?; + let (projection, projected_schema) = + get_projected_schema(&schema, projection, accum, has_projection)?; Ok(LogicalPlan::InMemoryScan { data: data.clone(), schema: schema.clone(), @@ -165,13 +161,8 @@ impl ProjectionPushDown { projection, .. } => { - let (projection, projected_schema) = get_projected_schema( - &schema, - projection, - accum, - mapping, - has_projection, - )?; + let (projection, projected_schema) = + get_projected_schema(&schema, projection, accum, has_projection)?; Ok(LogicalPlan::CsvScan { path: path.to_owned(), @@ -188,13 +179,8 @@ impl ProjectionPushDown { projection, .. } => { - let (projection, projected_schema) = get_projected_schema( - &schema, - projection, - accum, - mapping, - has_projection, - )?; + let (projection, projected_schema) = + get_projected_schema(&schema, projection, accum, has_projection)?; Ok(LogicalPlan::ParquetScan { path: path.to_owned(), @@ -227,7 +213,7 @@ impl ProjectionPushDown { fn rewrite_expr_list( &self, expr: &[Expr], - mapping: &HashMap, + mapping: &HashMap, ) -> Result> { Ok(expr .iter() @@ -235,17 +221,17 @@ impl ProjectionPushDown { .collect::>>()?) } - fn rewrite_expr(&self, expr: &Expr, mapping: &HashMap) -> Result { + fn rewrite_expr( + &self, + expr: &Expr, + mapping: &HashMap, + ) -> Result { match expr { Expr::Alias(expr, name) => Ok(Expr::Alias( Box::new(self.rewrite_expr(expr, mapping)?), name.clone(), )), - Expr::Column(i) => Ok(Expr::Column(self.new_index(mapping, i)?)), - Expr::UnresolvedColumn(_) => Err(ExecutionError::ExecutionError( - "Columns need to be resolved before projection push down rule can run" - .to_owned(), - )), + Expr::Column(_) => Ok(expr.clone()), Expr::Literal(_) => Ok(expr.clone()), Expr::Not(e) => Ok(Expr::Not(Box::new(self.rewrite_expr(e, mapping)?))), Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(self.rewrite_expr(e, mapping)?))), @@ -293,22 +279,12 @@ impl ProjectionPushDown { )), } } - - fn new_index(&self, mapping: &HashMap, i: &usize) -> Result { - match mapping.get(i) { - Some(j) => Ok(*j), - _ => Err(ExecutionError::InternalError( - "Internal error computing new column index".to_string(), - )), - } - } } fn get_projected_schema( table_schema: &Schema, projection: &Option>, - accum: &HashSet, - mapping: &mut HashMap, + accum: &HashSet, has_projection: bool, ) -> Result<(Vec, Schema)> { if projection.is_some() { @@ -318,8 +294,11 @@ fn get_projected_schema( } // once we reach the table scan, we can use the accumulated set of column - // indexes as the projection in the table scan - let mut projection = accum.iter().map(|i| *i).collect::>(); + // names to construct the set of column indexes in the scan + let mut projection: Vec = accum + .iter() + .map(|name| table_schema.index_of(name).unwrap()) + .collect(); if projection.is_empty() { if has_projection { @@ -346,21 +325,6 @@ fn get_projected_schema( projected_fields.push(table_schema.fields()[*i].clone()); } - // now that the table scan is returning a different schema we need to - // create a mapping from the original column 
index to the - // new column index so that we can rewrite expressions as - // we walk back up the tree - - if mapping.len() != 0 { - return Err(ExecutionError::InternalError("illegal state".to_string())); - } - - for i in 0..table_schema.fields().len() { - if let Some(n) = projection.iter().position(|v| *v == i) { - mapping.insert(i, n); - } - } - Ok((projection, Schema::new(projected_fields))) } @@ -368,8 +332,8 @@ fn get_projected_schema( mod tests { use super::*; - use crate::logicalplan::lit; use crate::logicalplan::Expr::*; + use crate::logicalplan::{col, lit}; use crate::test::*; use arrow::datatypes::DataType; @@ -378,10 +342,10 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .aggregate(vec![], vec![max(Column(1))])? + .aggregate(vec![], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#0)]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\ \n TableScan: test projection=Some([1])"; assert_optimized_plan_eq(&plan, expected); @@ -394,10 +358,10 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .aggregate(vec![Column(2)], vec![max(Column(1))])? + .aggregate(vec![col("c")], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[#1]], aggr=[[MAX(#0)]]\ + let expected = "Aggregate: groupBy=[[#c]], aggr=[[MAX(#b)]]\ \n TableScan: test projection=Some([1, 2])"; assert_optimized_plan_eq(&plan, expected); @@ -410,12 +374,12 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(&table_scan) - .filter(Column(2))? - .aggregate(vec![], vec![max(Column(1))])? + .filter(col("c"))? + .aggregate(vec![], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#0)]]\ - \n Selection: #1\ + let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\ + \n Selection: #c\ \n TableScan: test projection=Some([1, 2])"; assert_optimized_plan_eq(&plan, expected); @@ -429,12 +393,12 @@ mod tests { let projection = LogicalPlanBuilder::from(&table_scan) .project(vec![Cast { - expr: Box::new(Column(2)), + expr: Box::new(col("c")), data_type: DataType::Float64, }])? .build()?; - let expected = "Projection: CAST(#0 AS Float64)\ + let expected = "Projection: CAST(#c AS Float64)\ \n TableScan: test projection=Some([2])"; assert_optimized_plan_eq(&projection, expected); @@ -449,12 +413,12 @@ mod tests { assert_fields_eq(&table_scan, vec!["a", "b", "c"]); let plan = LogicalPlanBuilder::from(&table_scan) - .project(vec![Column(0), Column(1)])? + .project(vec![col("a"), col("b")])? .build()?; assert_fields_eq(&plan, vec!["a", "b"]); - let expected = "Projection: #0, #1\ + let expected = "Projection: #a, #b\ \n TableScan: test projection=Some([0, 1])"; assert_optimized_plan_eq(&plan, expected); @@ -469,14 +433,14 @@ mod tests { assert_fields_eq(&table_scan, vec!["a", "b", "c"]); let plan = LogicalPlanBuilder::from(&table_scan) - .project(vec![Column(2), Column(0)])? + .project(vec![col("c"), col("a")])? .limit(5)? 
.build()?; assert_fields_eq(&plan, vec!["c", "a"]); let expected = "Limit: 5\ - \n Projection: #1, #0\ + \n Projection: #c, #a\ \n TableScan: test projection=Some([0, 2])"; assert_optimized_plan_eq(&plan, expected); diff --git a/rust/datafusion/src/optimizer/resolve_columns.rs b/rust/datafusion/src/optimizer/resolve_columns.rs index 61b2e817361..bdaa867b3b6 100644 --- a/rust/datafusion/src/optimizer/resolve_columns.rs +++ b/rust/datafusion/src/optimizer/resolve_columns.rs @@ -75,7 +75,6 @@ fn rewrite_expr_list(expr: &[Expr], schema: &Schema) -> Result> { fn rewrite_expr(expr: &Expr, schema: &Schema) -> Result { match expr { Expr::Alias(expr, alias) => Ok(rewrite_expr(&expr, schema)?.alias(&alias)), - Expr::UnresolvedColumn(name) => Ok(Expr::Column(schema.index_of(&name)?)), Expr::BinaryExpr { left, op, right } => Ok(Expr::BinaryExpr { left: Box::new(rewrite_expr(&left, schema)?), op: op.clone(), @@ -140,7 +139,7 @@ mod tests { assert_eq!(format!("{:?}", plan), expected); // optimized plan has resolved columns - let expected = "Aggregate: groupBy=[[#0]], aggr=[[MAX(#1)]]\n TableScan: test projection=None"; + let expected = "Aggregate: groupBy=[[#a]], aggr=[[MAX(#b)]]\n TableScan: test projection=None"; assert_optimized_plan_eq(&plan, expected); Ok(()) diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index 7639b5bc125..a03a92cdfe5 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -132,7 +132,6 @@ impl<'a> TypeCoercionRule<'a> { alias.to_owned(), )), Expr::Literal(_) => Ok(expr.clone()), - Expr::UnresolvedColumn(_) => Ok(expr.clone()), Expr::Not(_) => Ok(expr.clone()), Expr::Sort { .. } => Ok(expr.clone()), Expr::Wildcard { .. } => Err(ExecutionError::General( @@ -183,7 +182,6 @@ mod tests { use super::*; use crate::execution::context::ExecutionContext; use crate::execution::physical_plan::csv::CsvReadOptions; - use crate::logicalplan::Expr::*; use crate::logicalplan::{col, Operator}; use crate::test::arrow_testdata_path; use arrow::datatypes::{DataType, Field, Schema}; @@ -214,12 +212,12 @@ mod tests { binary_cast_test( DataType::Int32, DataType::Int64, - "CAST(#0 AS Int64) Plus #1", + "CAST(#c0 AS Int64) Plus #c1", ); binary_cast_test( DataType::Int64, DataType::Int32, - "#0 Plus CAST(#1 AS Int64)", + "#c0 Plus CAST(#c1 AS Int64)", ); } @@ -228,12 +226,12 @@ mod tests { binary_cast_test( DataType::Float32, DataType::Float64, - "CAST(#0 AS Float64) Plus #1", + "CAST(#c0 AS Float64) Plus #c1", ); binary_cast_test( DataType::Float64, DataType::Float32, - "#0 Plus CAST(#1 AS Float64)", + "#c0 Plus CAST(#c1 AS Float64)", ); } @@ -242,12 +240,12 @@ mod tests { binary_cast_test( DataType::Int32, DataType::Float32, - "CAST(#0 AS Float32) Plus #1", + "CAST(#c0 AS Float32) Plus #c1", ); binary_cast_test( DataType::Float32, DataType::Int32, - "#0 Plus CAST(#1 AS Float32)", + "#c0 Plus CAST(#c1 AS Float32)", ); } @@ -256,12 +254,12 @@ mod tests { binary_cast_test( DataType::UInt32, DataType::Int64, - "CAST(#0 AS Int64) Plus #1", + "CAST(#c0 AS Int64) Plus #c1", ); binary_cast_test( DataType::Int64, DataType::UInt32, - "#0 Plus CAST(#1 AS Int64)", + "#c0 Plus CAST(#c1 AS Int64)", ); } @@ -272,9 +270,9 @@ mod tests { ]); let expr = Expr::BinaryExpr { - left: Box::new(Column(0)), + left: Box::new(col("c0")), op: Operator::Plus, - right: Box::new(Column(1)), + right: Box::new(col("c1")), }; let ctx = ExecutionContext::new(); diff --git a/rust/datafusion/src/optimizer/utils.rs 
b/rust/datafusion/src/optimizer/utils.rs index cdbad9eae66..7f78f485c74 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -25,46 +25,42 @@ use crate::error::{ExecutionError, Result}; use crate::logicalplan::Expr; /// Recursively walk a list of expression trees, collecting the unique set of column -/// indexes referenced in the expression -pub fn exprlist_to_column_indices( +/// names referenced in the expression +pub fn exprlist_to_column_names( expr: &[Expr], - accum: &mut HashSet, + accum: &mut HashSet, ) -> Result<()> { for e in expr { - expr_to_column_indices(e, accum)?; + expr_to_column_names(e, accum)?; } Ok(()) } -/// Recursively walk an expression tree, collecting the unique set of column indexes +/// Recursively walk an expression tree, collecting the unique set of column names /// referenced in the expression -pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet) -> Result<()> { +pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet) -> Result<()> { match expr { - Expr::Alias(expr, _) => expr_to_column_indices(expr, accum), - Expr::Column(i) => { - accum.insert(*i); + Expr::Alias(expr, _) => expr_to_column_names(expr, accum), + Expr::Column(name) => { + accum.insert(name.clone()); Ok(()) } - Expr::UnresolvedColumn(_) => Err(ExecutionError::ExecutionError( - "Columns need to be resolved before column indexes resolution rule can run" - .to_owned(), - )), Expr::Literal(_) => { // not needed Ok(()) } - Expr::Not(e) => expr_to_column_indices(e, accum), - Expr::IsNull(e) => expr_to_column_indices(e, accum), - Expr::IsNotNull(e) => expr_to_column_indices(e, accum), + Expr::Not(e) => expr_to_column_names(e, accum), + Expr::IsNull(e) => expr_to_column_names(e, accum), + Expr::IsNotNull(e) => expr_to_column_names(e, accum), Expr::BinaryExpr { left, right, .. } => { - expr_to_column_indices(left, accum)?; - expr_to_column_indices(right, accum)?; + expr_to_column_names(left, accum)?; + expr_to_column_names(right, accum)?; Ok(()) } - Expr::Cast { expr, .. } => expr_to_column_indices(expr, accum), - Expr::Sort { expr, .. } => expr_to_column_indices(expr, accum), - Expr::AggregateFunction { args, .. } => exprlist_to_column_indices(args, accum), - Expr::ScalarFunction { args, .. } => exprlist_to_column_indices(args, accum), + Expr::Cast { expr, .. } => expr_to_column_names(expr, accum), + Expr::Sort { expr, .. } => expr_to_column_names(expr, accum), + Expr::AggregateFunction { args, .. } => exprlist_to_column_names(args, accum), + Expr::ScalarFunction { args, .. 
} => exprlist_to_column_names(args, accum), Expr::Wildcard => Err(ExecutionError::General( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), @@ -77,18 +73,7 @@ pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result { Expr::Alias(expr, name) => { Ok(Field::new(name, expr.get_type(input_schema)?, true)) } - Expr::UnresolvedColumn(name) => Ok(input_schema.field_with_name(&name)?.clone()), - Expr::Column(i) => { - let input_schema_field_count = input_schema.fields().len(); - if *i < input_schema_field_count { - Ok(input_schema.fields()[*i].clone()) - } else { - Err(ExecutionError::General(format!( - "Column index {} out of bounds for input schema with {} field(s)", - *i, input_schema_field_count - ))) - } - } + Expr::Column(name) => Ok(input_schema.field_with_name(name)?.clone()), Expr::Literal(ref lit) => Ok(Field::new("lit", lit.get_datatype(), true)), Expr::ScalarFunction { ref name, @@ -248,29 +233,29 @@ fn _get_supertype(l: &DataType, r: &DataType) -> Option { #[cfg(test)] mod tests { use super::*; - use crate::logicalplan::Expr; + use crate::logicalplan::col; use arrow::datatypes::DataType; use std::collections::HashSet; #[test] fn test_collect_expr() -> Result<()> { - let mut accum: HashSet = HashSet::new(); - expr_to_column_indices( + let mut accum: HashSet = HashSet::new(); + expr_to_column_names( &Expr::Cast { - expr: Box::new(Expr::Column(3)), + expr: Box::new(col("a")), data_type: DataType::Float64, }, &mut accum, )?; - expr_to_column_indices( + expr_to_column_names( &Expr::Cast { - expr: Box::new(Expr::Column(3)), + expr: Box::new(col("a")), data_type: DataType::Float64, }, &mut accum, )?; assert_eq!(1, accum.len()); - assert!(accum.contains(&3)); + assert!(accum.contains("a")); Ok(()) } } diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 8c74c5a8ad8..81aeab713c4 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -173,38 +173,7 @@ impl SqlToRel { .aggregate(group_expr, aggr_expr)? 
.build()?; - // wrap in projection to preserve final order of fields - let mut projected_fields = Vec::with_capacity(group_by_count + aggr_count); - let mut group_expr_index = 0; - let mut aggr_expr_index = 0; - for i in 0..projection_expr.len() { - if is_aggregate_expr(&projection_expr[i]) { - projected_fields.push(group_by_count + aggr_expr_index); - aggr_expr_index += 1; - } else { - projected_fields.push(group_expr_index); - group_expr_index += 1; - } - } - - // determine if projection is needed or not - // NOTE this would be better done later in a query optimizer rule - let mut projection_needed = false; - for i in 0..projected_fields.len() { - if projected_fields[i] != i { - projection_needed = true; - break; - } - } - - if projection_needed { - self.project( - &plan, - projected_fields.iter().map(|i| Expr::Column(*i)).collect(), - ) - } else { - Ok(plan) - } + Ok(plan) } /// Wrap a plan in a limit @@ -273,16 +242,14 @@ impl SqlToRel { alias.to_owned(), )), - ASTNode::SQLIdentifier(ref id) => { - match schema.fields().iter().position(|c| c.name().eq(id)) { - Some(index) => Ok(Expr::Column(index)), - None => Err(ExecutionError::ExecutionError(format!( - "Invalid identifier '{}' for schema {}", - id, - schema.to_string() - ))), - } - } + ASTNode::SQLIdentifier(ref id) => match schema.field_with_name(id) { + Ok(field) => Ok(Expr::Column(field.name().clone())), + Err(_) => Err(ExecutionError::ExecutionError(format!( + "Invalid identifier '{}' for schema {}", + id, + schema.to_string() + ))), + }, ASTNode::SQLWildcard => Ok(Expr::Wildcard), @@ -483,8 +450,8 @@ mod tests { fn select_simple_selection() { let sql = "SELECT id, first_name, last_name \ FROM person WHERE state = 'CO'"; - let expected = "Projection: #0, #1, #2\ - \n Selection: #4 Eq Utf8(\"CO\")\ + let expected = "Projection: #id, #first_name, #last_name\ + \n Selection: #state Eq Utf8(\"CO\")\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -493,8 +460,8 @@ mod tests { fn select_neg_selection() { let sql = "SELECT id, first_name, last_name \ FROM person WHERE NOT state"; - let expected = "Projection: #0, #1, #2\ - \n Selection: NOT #4\ + let expected = "Projection: #id, #first_name, #last_name\ + \n Selection: NOT #state\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -503,8 +470,8 @@ mod tests { fn select_compound_selection() { let sql = "SELECT id, first_name, last_name \ FROM person WHERE state = 'CO' AND age >= 21 AND age <= 65"; - let expected = "Projection: #0, #1, #2\ - \n Selection: #4 Eq Utf8(\"CO\") And #3 GtEq Int64(21) And #3 LtEq Int64(65)\ + let expected = "Projection: #id, #first_name, #last_name\ + \n Selection: #state Eq Utf8(\"CO\") And #age GtEq Int64(21) And #age LtEq Int64(65)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -513,8 +480,8 @@ mod tests { fn test_timestamp_selection() { let sql = "SELECT state FROM person WHERE birth_date < CAST (158412331400600000 as timestamp)"; - let expected = "Projection: #4\ - \n Selection: #6 Lt CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\ + let expected = "Projection: #state\ + \n Selection: #birth_date Lt CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -530,13 +497,13 @@ mod tests { AND age >= 21 \ AND age < 65 \ AND age <= 65"; - let expected = "Projection: #3, #1, #2\ - \n Selection: #3 Eq Int64(21) \ - And #3 NotEq Int64(21) \ - And #3 Gt Int64(21) \ - And #3 GtEq Int64(21) \ - And #3 Lt 
Int64(65) \ - And #3 LtEq Int64(65)\ + let expected = "Projection: #age, #first_name, #last_name\ + \n Selection: #age Eq Int64(21) \ + And #age NotEq Int64(21) \ + And #age Gt Int64(21) \ + And #age GtEq Int64(21) \ + And #age Lt Int64(65) \ + And #age LtEq Int64(65)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -545,7 +512,7 @@ mod tests { fn select_simple_aggregate() { quick_test( "SELECT MIN(age) FROM person", - "Aggregate: groupBy=[[]], aggr=[[MIN(#3)]]\ + "Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\ \n TableScan: person projection=None", ); } @@ -554,7 +521,7 @@ mod tests { fn test_sum_aggregate() { quick_test( "SELECT SUM(age) from person", - "Aggregate: groupBy=[[]], aggr=[[SUM(#3)]]\ + "Aggregate: groupBy=[[]], aggr=[[SUM(#age)]]\ \n TableScan: person projection=None", ); } @@ -563,7 +530,7 @@ mod tests { fn select_simple_aggregate_with_groupby() { quick_test( "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state", - "Aggregate: groupBy=[[#4]], aggr=[[MIN(#3), MAX(#3)]]\ + "Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\ \n TableScan: person projection=None", ); } @@ -572,7 +539,7 @@ mod tests { fn test_wildcard() { quick_test( "SELECT * from person", - "Projection: #0, #1, #2, #3, #4, #5, #6\ + "Projection: #id, #first_name, #last_name, #age, #state, #salary, #birth_date\ \n TableScan: person projection=None", ); } @@ -588,7 +555,7 @@ mod tests { #[test] fn select_count_column() { let sql = "SELECT COUNT(id) FROM person"; - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#0)]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#id)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -596,7 +563,7 @@ mod tests { #[test] fn select_scalar_func() { let sql = "SELECT sqrt(age) FROM person"; - let expected = "Projection: sqrt(CAST(#3 AS Float64))\ + let expected = "Projection: sqrt(CAST(#age AS Float64))\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -604,7 +571,7 @@ mod tests { #[test] fn select_aliased_scalar_func() { let sql = "SELECT sqrt(age) AS square_people FROM person"; - let expected = "Projection: sqrt(CAST(#3 AS Float64)) AS square_people\ + let expected = "Projection: sqrt(CAST(#age AS Float64)) AS square_people\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -612,8 +579,8 @@ mod tests { #[test] fn select_order_by() { let sql = "SELECT id FROM person ORDER BY id"; - let expected = "Sort: #0 ASC NULLS FIRST\ - \n Projection: #0\ + let expected = "Sort: #id ASC NULLS FIRST\ + \n Projection: #id\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -621,8 +588,8 @@ mod tests { #[test] fn select_order_by_desc() { let sql = "SELECT id FROM person ORDER BY id DESC"; - let expected = "Sort: #0 DESC NULLS FIRST\ - \n Projection: #0\ + let expected = "Sort: #id DESC NULLS FIRST\ + \n Projection: #id\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -631,15 +598,15 @@ mod tests { fn select_order_by_nulls_last() { quick_test( "SELECT id FROM person ORDER BY id DESC NULLS LAST", - "Sort: #0 DESC NULLS LAST\ - \n Projection: #0\ + "Sort: #id DESC NULLS LAST\ + \n Projection: #id\ \n TableScan: person projection=None", ); quick_test( "SELECT id FROM person ORDER BY id NULLS LAST", - "Sort: #0 ASC NULLS LAST\ - \n Projection: #0\ + "Sort: #id ASC NULLS LAST\ + \n Projection: #id\ \n TableScan: person projection=None", ); } @@ -647,7 +614,7 @@ mod tests { #[test] fn select_group_by() { let sql = "SELECT state FROM person 
GROUP BY state"; - let expected = "Aggregate: groupBy=[[#4]], aggr=[[]]\ + let expected = "Aggregate: groupBy=[[#state]], aggr=[[]]\ \n TableScan: person projection=None"; quick_test(sql, expected); From 5263bfbdb6e508f39cea1fe33ee26279c3762417 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 12 Jul 2020 16:55:43 +0200 Subject: [PATCH 2/8] Removed ResolveColumnsRule. This is no longer needed. --- rust/datafusion/src/execution/context.rs | 2 - rust/datafusion/src/optimizer/mod.rs | 1 - .../src/optimizer/resolve_columns.rs | 158 ------------------ 3 files changed, 161 deletions(-) delete mode 100644 rust/datafusion/src/optimizer/resolve_columns.rs diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 221de3e6356..9341982397f 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -54,7 +54,6 @@ use crate::execution::table_impl::TableImpl; use crate::logicalplan::*; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::projection_push_down::ProjectionPushDown; -use crate::optimizer::resolve_columns::ResolveColumnsRule; use crate::optimizer::type_coercion::TypeCoercionRule; use crate::sql::parser::{DFASTNode, DFParser, FileType}; use crate::sql::planner::{SchemaProvider, SqlToRel}; @@ -275,7 +274,6 @@ impl ExecutionContext { /// Optimize the logical plan by applying optimizer rules pub fn optimize(&self, plan: &LogicalPlan) -> Result { let rules: Vec> = vec![ - Box::new(ResolveColumnsRule::new()), Box::new(ProjectionPushDown::new()), Box::new(TypeCoercionRule::new(&self.scalar_functions)), ]; diff --git a/rust/datafusion/src/optimizer/mod.rs b/rust/datafusion/src/optimizer/mod.rs index 1ac97b1c30f..e60c7db824b 100644 --- a/rust/datafusion/src/optimizer/mod.rs +++ b/rust/datafusion/src/optimizer/mod.rs @@ -20,6 +20,5 @@ pub mod optimizer; pub mod projection_push_down; -pub mod resolve_columns; pub mod type_coercion; pub mod utils; diff --git a/rust/datafusion/src/optimizer/resolve_columns.rs b/rust/datafusion/src/optimizer/resolve_columns.rs deleted file mode 100644 index bdaa867b3b6..00000000000 --- a/rust/datafusion/src/optimizer/resolve_columns.rs +++ /dev/null @@ -1,158 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Optimizer rule to replace UnresolvedColumns with Columns - -use crate::error::Result; -use crate::logicalplan::LogicalPlan; -use crate::logicalplan::{Expr, LogicalPlanBuilder}; -use crate::optimizer::optimizer::OptimizerRule; -use arrow::datatypes::Schema; - -/// Replace UnresolvedColumns with Columns -pub struct ResolveColumnsRule {} - -impl ResolveColumnsRule { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl OptimizerRule for ResolveColumnsRule { - fn optimize(&mut self, plan: &LogicalPlan) -> Result { - match plan { - LogicalPlan::Projection { input, expr, .. } => { - Ok(LogicalPlanBuilder::from(&self.optimize(input.as_ref())?) - .project(rewrite_expr_list(expr, &input.schema())?)? - .build()?) - } - LogicalPlan::Selection { expr, input } => Ok(LogicalPlanBuilder::from(input) - .filter(rewrite_expr(expr, &input.schema())?)? - .build()?), - LogicalPlan::Aggregate { - input, - group_expr, - aggr_expr, - .. - } => Ok(LogicalPlanBuilder::from(input) - .aggregate( - rewrite_expr_list(group_expr, &input.schema())?, - rewrite_expr_list(aggr_expr, &input.schema())?, - )? - .build()?), - LogicalPlan::Sort { input, expr, .. } => { - Ok(LogicalPlanBuilder::from(&self.optimize(input)?) - .sort(rewrite_expr_list(expr, &input.schema())?)? - .build()?) - } - _ => Ok(plan.clone()), - } - } -} - -fn rewrite_expr_list(expr: &[Expr], schema: &Schema) -> Result> { - Ok(expr - .iter() - .map(|e| rewrite_expr(e, schema)) - .collect::>>()?) -} - -fn rewrite_expr(expr: &Expr, schema: &Schema) -> Result { - match expr { - Expr::Alias(expr, alias) => Ok(rewrite_expr(&expr, schema)?.alias(&alias)), - Expr::BinaryExpr { left, op, right } => Ok(Expr::BinaryExpr { - left: Box::new(rewrite_expr(&left, schema)?), - op: op.clone(), - right: Box::new(rewrite_expr(&right, schema)?), - }), - Expr::Not(expr) => Ok(Expr::Not(Box::new(rewrite_expr(&expr, schema)?))), - Expr::IsNotNull(expr) => { - Ok(Expr::IsNotNull(Box::new(rewrite_expr(&expr, schema)?))) - } - Expr::IsNull(expr) => Ok(Expr::IsNull(Box::new(rewrite_expr(&expr, schema)?))), - Expr::Cast { expr, data_type } => Ok(Expr::Cast { - expr: Box::new(rewrite_expr(&expr, schema)?), - data_type: data_type.clone(), - }), - Expr::Sort { - expr, - asc, - nulls_first, - } => Ok(Expr::Sort { - expr: Box::new(rewrite_expr(&expr, schema)?), - asc: asc.clone(), - nulls_first: nulls_first.clone(), - }), - Expr::ScalarFunction { - name, - args, - return_type, - } => Ok(Expr::ScalarFunction { - name: name.clone(), - args: rewrite_expr_list(args, schema)?, - return_type: return_type.clone(), - }), - Expr::AggregateFunction { - name, - args, - return_type, - } => Ok(Expr::AggregateFunction { - name: name.clone(), - args: rewrite_expr_list(args, schema)?, - return_type: return_type.clone(), - }), - _ => Ok(expr.clone()), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::logicalplan::col; - use crate::test::*; - - #[test] - fn aggregate_no_group_by() -> Result<()> { - let table_scan = test_table_scan()?; - - let plan = LogicalPlanBuilder::from(&table_scan) - .aggregate(vec![col("a")], vec![max(col("b"))])? 
-            .build()?;
-
-        // plan has unresolved columns
-        let expected = "Aggregate: groupBy=[[#a]], aggr=[[MAX(#b)]]\n  TableScan: test projection=None";
-        assert_eq!(format!("{:?}", plan), expected);
-
-        // optimized plan has resolved columns
-        let expected = "Aggregate: groupBy=[[#a]], aggr=[[MAX(#b)]]\n  TableScan: test projection=None";
-        assert_optimized_plan_eq(&plan, expected);
-
-        Ok(())
-    }
-
-    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
-        let optimized_plan = optimize(plan).expect("failed to optimize plan");
-        let formatted_plan = format!("{:?}", optimized_plan);
-        assert_eq!(formatted_plan, expected);
-    }
-
-    fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
-        let mut rule = ResolveColumnsRule::new();
-        rule.optimize(plan)
-    }
-}

From c3da7a9354641973c5f73fc0c600fd77bb0d0ca4 Mon Sep 17 00:00:00 2001
From: "Jorge C. Leitao"
Date: Sun, 12 Jul 2020 17:56:39 +0200
Subject: [PATCH 3/8] Simplified ProjectionPushDown.

---
 .../src/optimizer/projection_push_down.rs     | 174 +++++-------------
 1 file changed, 41 insertions(+), 133 deletions(-)

diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs
index 9efe8c58794..aacc2aef9ef 100644
--- a/rust/datafusion/src/optimizer/projection_push_down.rs
+++ b/rust/datafusion/src/optimizer/projection_push_down.rs
@@ -20,11 +20,10 @@

 use crate::error::{ExecutionError, Result};
 use crate::logicalplan::LogicalPlan;
-use crate::logicalplan::{Expr, LogicalPlanBuilder};
 use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::utils;
 use arrow::datatypes::{Field, Schema};
-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;

 /// Projection Push Down optimizer rule ensures that only referenced columns are
 /// loaded into memory
@@ -34,9 +33,7 @@ impl OptimizerRule for ProjectionPushDown {
     fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
         // set of all columns referred to by a scan.
         let mut accum: HashSet<String> = HashSet::new();
-        // mapping from old to new column names
-        let mut mapping: HashMap<String, String> = HashMap::new();
-        self.optimize_plan(plan, &mut accum, &mut mapping, false)
+        self.optimize_plan(plan, &mut accum, false)
     }
 }

@@ -50,71 +47,69 @@ impl ProjectionPushDown {
         &self,
         plan: &LogicalPlan,
         accum: &mut HashSet<String>,
-        mapping: &mut HashMap<String, String>,
         has_projection: bool,
     ) -> Result<LogicalPlan> {
         match plan {
-            LogicalPlan::Projection { expr, input, .. } => {
+            LogicalPlan::Projection {
+                expr,
+                input,
+                schema,
+            } => {
                 // collect all columns referenced by projection expressions
                 utils::exprlist_to_column_names(&expr, accum)?;

-                LogicalPlanBuilder::from(
-                    &self.optimize_plan(&input, accum, mapping, true)?,
-                )
-                .project(self.rewrite_expr_list(expr, mapping)?)?
-                .build()
+                Ok(LogicalPlan::Projection {
+                    expr: expr.clone(),
+                    input: Box::new(self.optimize_plan(&input, accum, true)?),
+                    schema: schema.clone(),
+                })
             }
             LogicalPlan::Selection { expr, input } => {
                 // collect all columns referenced by filter expression
                 utils::expr_to_column_names(expr, accum)?;

-                LogicalPlanBuilder::from(&self.optimize_plan(
-                    &input,
-                    accum,
-                    mapping,
-                    has_projection,
-                )?)
-                .filter(self.rewrite_expr(expr, mapping)?)?
-                .build()
+                Ok(LogicalPlan::Selection {
+                    expr: expr.clone(),
+                    input: Box::new(self.optimize_plan(&input, accum, has_projection)?),
+                })
             }
             LogicalPlan::Aggregate {
                 input,
                 group_expr,
                 aggr_expr,
-                ..
+ schema, } => { // collect all columns referenced by grouping and aggregate expressions utils::exprlist_to_column_names(&group_expr, accum)?; utils::exprlist_to_column_names(&aggr_expr, accum)?; - LogicalPlanBuilder::from(&self.optimize_plan( - &input, - accum, - mapping, - has_projection, - )?) - .aggregate( - self.rewrite_expr_list(group_expr, mapping)?, - self.rewrite_expr_list(aggr_expr, mapping)?, - )? - .build() + Ok(LogicalPlan::Aggregate { + input: Box::new(self.optimize_plan(&input, accum, has_projection)?), + group_expr: group_expr.clone(), + aggr_expr: aggr_expr.clone(), + schema: schema.clone(), + }) } - LogicalPlan::Sort { expr, input, .. } => { + LogicalPlan::Sort { + expr, + input, + schema, + } => { // collect all columns referenced by sort expressions utils::exprlist_to_column_names(&expr, accum)?; - LogicalPlanBuilder::from(&self.optimize_plan( - &input, - accum, - mapping, - has_projection, - )?) - .sort(self.rewrite_expr_list(expr, mapping)?)? - .build() + Ok(LogicalPlan::Sort { + expr: expr.clone(), + input: Box::new(self.optimize_plan(&input, accum, has_projection)?), + schema: schema.clone(), + }) } - LogicalPlan::EmptyRelation { schema } => Ok(LogicalPlan::EmptyRelation { + LogicalPlan::Limit { n, input, schema } => Ok(LogicalPlan::Limit { + n: n.clone(), + input: Box::new(self.optimize_plan(&input, accum, has_projection)?), schema: schema.clone(), }), + LogicalPlan::EmptyRelation { .. } => Ok(plan.clone()), LogicalPlan::TableScan { schema_name, table_name, @@ -134,8 +129,8 @@ impl ProjectionPushDown { schema_name: schema_name.to_string(), table_name: table_name.to_string(), table_schema: table_schema.clone(), - projected_schema: Box::new(projected_schema), projection: Some(projection), + projected_schema: Box::new(projected_schema), }) } LogicalPlan::InMemoryScan { @@ -189,94 +184,7 @@ impl ProjectionPushDown { projected_schema: Box::new(projected_schema), }) } - LogicalPlan::Limit { n, input, .. } => LogicalPlanBuilder::from( - &self.optimize_plan(&input, accum, mapping, has_projection)?, - ) - .limit(*n)? - .build(), - LogicalPlan::CreateExternalTable { - schema, - name, - location, - file_type, - has_header, - } => Ok(LogicalPlan::CreateExternalTable { - schema: schema.clone(), - name: name.to_string(), - location: location.to_string(), - file_type: file_type.clone(), - has_header: *has_header, - }), - } - } - - fn rewrite_expr_list( - &self, - expr: &[Expr], - mapping: &HashMap, - ) -> Result> { - Ok(expr - .iter() - .map(|e| self.rewrite_expr(e, mapping)) - .collect::>>()?) 
- } - - fn rewrite_expr( - &self, - expr: &Expr, - mapping: &HashMap, - ) -> Result { - match expr { - Expr::Alias(expr, name) => Ok(Expr::Alias( - Box::new(self.rewrite_expr(expr, mapping)?), - name.clone(), - )), - Expr::Column(_) => Ok(expr.clone()), - Expr::Literal(_) => Ok(expr.clone()), - Expr::Not(e) => Ok(Expr::Not(Box::new(self.rewrite_expr(e, mapping)?))), - Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(self.rewrite_expr(e, mapping)?))), - Expr::IsNotNull(e) => { - Ok(Expr::IsNotNull(Box::new(self.rewrite_expr(e, mapping)?))) - } - Expr::BinaryExpr { left, op, right } => Ok(Expr::BinaryExpr { - left: Box::new(self.rewrite_expr(left, mapping)?), - op: op.clone(), - right: Box::new(self.rewrite_expr(right, mapping)?), - }), - Expr::Cast { expr, data_type } => Ok(Expr::Cast { - expr: Box::new(self.rewrite_expr(expr, mapping)?), - data_type: data_type.clone(), - }), - Expr::Sort { - expr, - asc, - nulls_first, - } => Ok(Expr::Sort { - expr: Box::new(self.rewrite_expr(expr, mapping)?), - asc: *asc, - nulls_first: *nulls_first, - }), - Expr::AggregateFunction { - name, - args, - return_type, - } => Ok(Expr::AggregateFunction { - name: name.to_string(), - args: self.rewrite_expr_list(args, mapping)?, - return_type: return_type.clone(), - }), - Expr::ScalarFunction { - name, - args, - return_type, - } => Ok(Expr::ScalarFunction { - name: name.to_string(), - args: self.rewrite_expr_list(args, mapping)?, - return_type: return_type.clone(), - }), - Expr::Wildcard => Err(ExecutionError::General( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), + LogicalPlan::CreateExternalTable { .. } => Ok(plan.clone()), } } } @@ -332,8 +240,8 @@ fn get_projected_schema( mod tests { use super::*; - use crate::logicalplan::Expr::*; use crate::logicalplan::{col, lit}; + use crate::logicalplan::{Expr, LogicalPlanBuilder}; use crate::test::*; use arrow::datatypes::DataType; @@ -392,7 +300,7 @@ mod tests { let table_scan = test_table_scan()?; let projection = LogicalPlanBuilder::from(&table_scan) - .project(vec![Cast { + .project(vec![Expr::Cast { expr: Box::new(col("c")), data_type: DataType::Float64, }])? From fd961346a3de68731ab83c3aa9b41f1020395e53 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 17 Jul 2020 10:49:35 +0200 Subject: [PATCH 4/8] Fixed import. --- rust/datafusion/src/execution/table_impl.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rust/datafusion/src/execution/table_impl.rs b/rust/datafusion/src/execution/table_impl.rs index 401a1e3809d..7494ba51bd9 100644 --- a/rust/datafusion/src/execution/table_impl.rs +++ b/rust/datafusion/src/execution/table_impl.rs @@ -23,8 +23,7 @@ use crate::arrow::datatypes::DataType; use crate::arrow::record_batch::RecordBatch; use crate::error::{ExecutionError, Result}; use crate::execution::context::ExecutionContext; -use crate::logicalplan::{col, Expr, LogicalPlan}; -use crate::logicalplan::{LogicalPlanBuilder, ScalarValue}; +use crate::logicalplan::{col, Expr, LogicalPlan, LogicalPlanBuilder}; use crate::table::*; use arrow::datatypes::Schema; From e8878694ff4d3a95146a6a1e4e2ad808b148ba52 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 18 Jul 2020 07:58:39 +0200 Subject: [PATCH 5/8] Moved column naming to the logical plan. 
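
Output column names are now computed on the logical side: `Expr::name`
derives a name such as "SUM(c2)" from the expression and the input
schema, and physical operators receive explicit (expression, output
name) pairs, so `PhysicalExpr::name` and the physical `Alias`
expression can be removed. A rough sketch of the new contract
(`input_schema` here is a hypothetical schema with a UInt32 column
"c2"; `col` and `aggregate_expr` are the existing helpers in
`logicalplan`):

    // the logical expression knows the output name it produces
    let expr = aggregate_expr("SUM", col("c2"), DataType::UInt32);
    assert_eq!(expr.name(&input_schema)?, "SUM(c2)");

    // physical nodes take the name alongside the expression, e.g.
    // ProjectionExec::try_new(vec![(physical_expr, "SUM(c2)".to_string())], input)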
--- rust/datafusion/src/execution/context.rs | 90 +++++++++---- .../execution/physical_plan/expressions.rs | 119 +++--------------- .../execution/physical_plan/hash_aggregate.rs | 54 +++++--- .../src/execution/physical_plan/mod.rs | 17 +-- .../src/execution/physical_plan/projection.rs | 18 ++- .../src/execution/physical_plan/udf.rs | 7 -- rust/datafusion/src/logicalplan.rs | 99 +++++++++++++-- .../src/optimizer/projection_push_down.rs | 7 +- rust/datafusion/src/optimizer/utils.rs | 50 +------- 9 files changed, 231 insertions(+), 230 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 9341982397f..2433fda43ca 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -36,8 +36,7 @@ use crate::execution::physical_plan::common; use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions}; use crate::execution::physical_plan::datasource::DatasourceExec; use crate::execution::physical_plan::expressions::{ - Alias, Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, PhysicalSortExpr, - Sum, + Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, PhysicalSortExpr, Sum, }; use crate::execution::physical_plan::hash_aggregate::HashAggregateExec; use crate::execution::physical_plan::limit::LimitExec; @@ -51,7 +50,9 @@ use crate::execution::physical_plan::sort::{SortExec, SortOptions}; use crate::execution::physical_plan::udf::{ScalarFunction, ScalarFunctionExpr}; use crate::execution::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr}; use crate::execution::table_impl::TableImpl; -use crate::logicalplan::*; +use crate::logicalplan::{ + Expr, FunctionMeta, FunctionType, LogicalPlan, LogicalPlanBuilder, +}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::projection_push_down::ProjectionPushDown; use crate::optimizer::type_coercion::TypeCoercionRule; @@ -66,6 +67,15 @@ pub struct ExecutionContext { scalar_functions: HashMap>, } +fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { + match value { + (Ok(e), Ok(e1)) => Ok((e, e1)), + (Err(e), Ok(_)) => Err(e), + (Ok(_), Err(e1)) => Err(e1), + (Err(e), Err(_)) => Err(e), + } +} + impl ExecutionContext { /// Create a new execution context for in-memory queries pub fn new() -> Self { @@ -354,7 +364,12 @@ impl ExecutionContext { let input_schema = input.as_ref().schema().clone(); let runtime_expr = expr .iter() - .map(|e| self.create_physical_expr(e, &input_schema)) + .map(|e| { + tuple_err(( + self.create_physical_expr(e, &input_schema), + e.name(&input_schema), + )) + }) .collect::>>()?; Ok(Arc::new(ProjectionExec::try_new(runtime_expr, input)?)) } @@ -368,17 +383,30 @@ impl ExecutionContext { let input = self.create_physical_plan(input, batch_size)?; let input_schema = input.as_ref().schema().clone(); - let group_expr = group_expr + let groups = group_expr .iter() - .map(|e| self.create_physical_expr(e, &input_schema)) + .map(|e| { + tuple_err(( + self.create_physical_expr(e, &input_schema), + e.name(&input_schema), + )) + }) .collect::>>()?; - let aggr_expr = aggr_expr + let aggregates = aggr_expr .iter() - .map(|e| self.create_aggregate_expr(e, &input_schema)) + .map(|e| { + tuple_err(( + self.create_aggregate_expr(e, &input_schema), + e.name(&input_schema), + )) + }) .collect::>>()?; - let initial_aggr = - HashAggregateExec::try_new(group_expr, aggr_expr, input)?; + let initial_aggr = HashAggregateExec::try_new( + groups.clone(), + aggregates.clone(), + input, + )?; let schema = 
initial_aggr.schema(); let partitions = initial_aggr.partitions()?; @@ -387,13 +415,27 @@ impl ExecutionContext { return Ok(Arc::new(initial_aggr)); } - let (final_group, final_aggr) = initial_aggr.make_final_expr(); - let merge = Arc::new(MergeExec::new(schema.clone(), partitions)); + // construct the expressions for the final aggregation + let (final_group, final_aggr) = initial_aggr.make_final_expr( + groups.iter().map(|x| x.1.clone()).collect(), + aggregates.iter().map(|x| x.1.clone()).collect(), + ); + + // construct a second aggregation, keeping the final column name equal to the first aggregation + // and the expressions corresponding to the respective aggregate Ok(Arc::new(HashAggregateExec::try_new( - final_group, - final_aggr, + final_group + .iter() + .enumerate() + .map(|(i, expr)| (expr.clone(), groups[i].1.clone())) + .collect(), + final_aggr + .iter() + .enumerate() + .map(|(i, expr)| (expr.clone(), aggregates[i].1.clone())) + .collect(), merge, )?)) } @@ -453,10 +495,7 @@ impl ExecutionContext { input_schema: &Schema, ) -> Result> { match e { - Expr::Alias(expr, name) => { - let expr = self.create_physical_expr(expr, input_schema)?; - Ok(Arc::new(Alias::new(expr, &name))) - } + Expr::Alias(expr, ..) => Ok(self.create_physical_expr(expr, input_schema)?), Expr::Column(name) => { // check that name exists input_schema.field_with_name(&name)?; @@ -484,7 +523,6 @@ impl ExecutionContext { physical_args.push(self.create_physical_expr(e, input_schema)?); } Ok(Arc::new(ScalarFunctionExpr::new( - name, Box::new(f.fun.clone()), physical_args, return_type, @@ -650,6 +688,7 @@ mod tests { use super::*; use crate::datasource::MemTable; use crate::execution::physical_plan::udf::ScalarUdf; + use crate::logicalplan::{aggregate_expr, col, scalar_function}; use crate::test; use arrow::array::{ArrayRef, Int32Array}; use arrow::compute::add; @@ -995,19 +1034,16 @@ mod tests { let ctx = create_ctx(&tmp_dir, 1)?; let schema = Arc::new(Schema::new(vec![ - Field::new("state", DataType::Utf8, false), - Field::new("salary", DataType::UInt32, false), + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), ])); let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)? .aggregate( - vec![col("state")], - vec![aggregate_expr("SUM", col("salary"), DataType::UInt32)], + vec![col("c1")], + vec![aggregate_expr("SUM", col("c2"), DataType::UInt32)], )? - .project(vec![ - col("state"), - col("SUM(SUM(salary))").alias("total_salary"), - ])? + .project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])? 
.build()?; let plan = ctx.optimize(&plan)?; diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index a680b6cb11c..194f35669f6 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -47,40 +47,6 @@ use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::{DataType, Schema, TimeUnit}; use arrow::record_batch::RecordBatch; -/// Represents an aliased expression -pub struct Alias { - expr: Arc, - alias: String, -} - -impl Alias { - /// Create a new aliased expression - pub fn new(expr: Arc, alias: &str) -> Self { - Self { - expr: expr.clone(), - alias: alias.to_owned(), - } - } -} - -impl PhysicalExpr for Alias { - fn name(&self) -> String { - self.alias.clone() - } - - fn data_type(&self, input_schema: &Schema) -> Result { - self.expr.data_type(input_schema) - } - - fn nullable(&self, input_schema: &Schema) -> Result { - self.expr.nullable(input_schema) - } - - fn evaluate(&self, batch: &RecordBatch) -> Result { - self.expr.evaluate(batch) - } -} - /// Represents the column at a given index in a RecordBatch pub struct Column { name: String, @@ -96,11 +62,6 @@ impl Column { } impl PhysicalExpr for Column { - /// Get the name to use in a schema to represent the result of this expression - fn name(&self) -> String { - self.name.clone() - } - /// Get the data type of this expression, given the schema of the input fn data_type(&self, input_schema: &Schema) -> Result { Ok(input_schema @@ -138,10 +99,6 @@ impl Sum { } impl AggregateExpr for Sum { - fn name(&self) -> String { - format!("SUM({})", self.expr.name()) - } - fn data_type(&self, input_schema: &Schema) -> Result { match self.expr.data_type(input_schema)? { DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { @@ -167,8 +124,8 @@ impl AggregateExpr for Sum { Rc::new(RefCell::new(SumAccumulator { sum: None })) } - fn create_reducer(&self) -> Arc { - Arc::new(Sum::new(col(&self.name()))) + fn create_reducer(&self, column_name: &str) -> Arc { + Arc::new(Sum::new(Arc::new(Column::new(column_name)))) } } @@ -334,10 +291,6 @@ impl Avg { } impl AggregateExpr for Avg { - fn name(&self) -> String { - format!("AVG({})", self.expr.name()) - } - fn data_type(&self, input_schema: &Schema) -> Result { match self.expr.data_type(input_schema)? { DataType::Int8 @@ -368,8 +321,8 @@ impl AggregateExpr for Avg { })) } - fn create_reducer(&self) -> Arc { - Arc::new(Avg::new(Arc::new(Column::new(&self.name())))) + fn create_reducer(&self, column_name: &str) -> Arc { + Arc::new(Avg::new(Arc::new(Column::new(column_name)))) } } @@ -452,10 +405,6 @@ impl Max { } impl AggregateExpr for Max { - fn name(&self) -> String { - format!("MAX({})", self.expr.name()) - } - fn data_type(&self, input_schema: &Schema) -> Result { match self.expr.data_type(input_schema)? 
{ DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { @@ -481,8 +430,8 @@ impl AggregateExpr for Max { Rc::new(RefCell::new(MaxAccumulator { max: None })) } - fn create_reducer(&self) -> Arc { - Arc::new(Max::new(Arc::new(Column::new(&self.name())))) + fn create_reducer(&self, column_name: &str) -> Arc { + Arc::new(Max::new(Arc::new(Column::new(column_name)))) } } @@ -651,10 +600,6 @@ impl Min { } impl AggregateExpr for Min { - fn name(&self) -> String { - format!("MIN({})", self.expr.name()) - } - fn data_type(&self, input_schema: &Schema) -> Result { match self.expr.data_type(input_schema)? { DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { @@ -680,8 +625,8 @@ impl AggregateExpr for Min { Rc::new(RefCell::new(MinAccumulator { min: None })) } - fn create_reducer(&self) -> Arc { - Arc::new(Min::new(Arc::new(Column::new(&self.name())))) + fn create_reducer(&self, column_name: &str) -> Arc { + Arc::new(Min::new(Arc::new(Column::new(column_name)))) } } @@ -851,10 +796,6 @@ impl Count { } impl AggregateExpr for Count { - fn name(&self) -> String { - format!("COUNT({})", self.expr.name()) - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(DataType::UInt64) } @@ -867,8 +808,8 @@ impl AggregateExpr for Count { Rc::new(RefCell::new(CountAccumulator { count: 0 })) } - fn create_reducer(&self) -> Arc { - Arc::new(Sum::new(Arc::new(Column::new(&self.name())))) + fn create_reducer(&self, column_name: &str) -> Arc { + Arc::new(Sum::new(Arc::new(Column::new(column_name)))) } } @@ -1025,10 +966,6 @@ impl BinaryExpr { } impl PhysicalExpr for BinaryExpr { - fn name(&self) -> String { - format!("{:?}", self.op) - } - fn data_type(&self, input_schema: &Schema) -> Result { self.left.data_type(input_schema) } @@ -1113,10 +1050,6 @@ impl NotExpr { } impl PhysicalExpr for NotExpr { - fn name(&self) -> String { - "NOT".to_string() - } - fn data_type(&self, _input_schema: &Schema) -> Result { return Ok(DataType::Boolean); } @@ -1194,10 +1127,6 @@ impl CastExpr { } impl PhysicalExpr for CastExpr { - fn name(&self) -> String { - "CAST".to_string() - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(self.cast_type.clone()) } @@ -1237,10 +1166,6 @@ macro_rules! 
build_literal_array { } impl PhysicalExpr for Literal { - fn name(&self) -> String { - "lit".to_string() - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(self.value.get_datatype()) } @@ -1487,17 +1412,15 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let sum = sum(col("a")); - assert_eq!("SUM(a)".to_string(), sum.name()); assert_eq!(DataType::Int64, sum.data_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new(&sum.name(), sum.data_type(&schema)?, false), + Field::new("SUM(a)", sum.data_type(&schema)?, false), ]); - let combiner = sum.create_reducer(); - assert_eq!("SUM(SUM(a))".to_string(), combiner.name()); + let combiner = sum.create_reducer("SUM(a)"); assert_eq!(DataType::Int64, combiner.data_type(&schema)?); Ok(()) @@ -1508,17 +1431,15 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let max = max(col("a")); - assert_eq!("MAX(a)".to_string(), max.name()); assert_eq!(DataType::Int64, max.data_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new(&max.name(), max.data_type(&schema)?, false), + Field::new("Max(a)", max.data_type(&schema)?, false), ]); - let combiner = max.create_reducer(); - assert_eq!("MAX(MAX(a))".to_string(), combiner.name()); + let combiner = max.create_reducer("Max(a)"); assert_eq!(DataType::Int64, combiner.data_type(&schema)?); Ok(()) @@ -1529,16 +1450,14 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let min = min(col("a")); - assert_eq!("MIN(a)".to_string(), min.name()); assert_eq!(DataType::Int64, min.data_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new(&min.name(), min.data_type(&schema)?, false), + Field::new("MIN(a)", min.data_type(&schema)?, false), ]); - let combiner = min.create_reducer(); - assert_eq!("MIN(MIN(a))".to_string(), combiner.name()); + let combiner = min.create_reducer("MIN(a)"); assert_eq!(DataType::Int64, combiner.data_type(&schema)?); Ok(()) @@ -1548,17 +1467,15 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let avg = avg(col("a")); - assert_eq!("AVG(a)".to_string(), avg.name()); assert_eq!(DataType::Float64, avg.data_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new(&avg.name(), avg.data_type(&schema)?, false), + Field::new("SUM(a)", avg.data_type(&schema)?, false), ]); - let combiner = avg.create_reducer(); - assert_eq!("AVG(AVG(a))".to_string(), combiner.name()); + let combiner = avg.create_reducer("SUM(a)"); assert_eq!(DataType::Float64, combiner.data_type(&schema)?); Ok(()) diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs index e935480154c..19836fd864d 100644 --- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs @@ -54,26 +54,24 @@ pub struct HashAggregateExec { impl HashAggregateExec { /// Create a new hash aggregate execution plan pub fn try_new( - group_expr: Vec>, - aggr_expr: Vec>, + group_expr: Vec<(Arc, String)>, + aggr_expr: Vec<(Arc, String)>, input: Arc, ) -> Result { let input_schema = 
input.schema(); let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len()); - for expr in &group_expr { - let name = expr.name(); - fields.push(Field::new(&name, expr.data_type(&input_schema)?, true)) + for (expr, name) in &group_expr { + fields.push(Field::new(name, expr.data_type(&input_schema)?, true)) } - for expr in &aggr_expr { - let name = expr.name(); + for (expr, name) in &aggr_expr { fields.push(Field::new(&name, expr.data_type(&input_schema)?, true)) } let schema = Arc::new(Schema::new(fields)); Ok(HashAggregateExec { - group_expr, - aggr_expr, + group_expr: group_expr.iter().map(|x| x.0.clone()).collect(), + aggr_expr: aggr_expr.iter().map(|x| x.0.clone()).collect(), input, schema, }) @@ -83,13 +81,15 @@ impl HashAggregateExec { /// expressions pub fn make_final_expr( &self, + group_names: Vec, + agg_names: Vec, ) -> (Vec>, Vec>) { let final_group: Vec> = (0..self.group_expr.len()) - .map(|i| col(&self.group_expr[i].name()) as Arc) + .map(|i| col(&group_names[i]) as Arc) .collect(); let final_aggr: Vec> = (0..self.aggr_expr.len()) - .map(|i| self.aggr_expr[i].create_reducer()) + .map(|i| self.aggr_expr[i].create_reducer(&agg_names[i])) .collect(); (final_group, final_aggr) @@ -766,24 +766,42 @@ mod tests { let csv = CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), None, 1024)?; - let group_expr: Vec> = vec![col("c2")]; + let groups: Vec<(Arc, String)> = + vec![(col("c2"), "c2".to_string())]; - let aggr_expr: Vec> = vec![sum(col("c4"))]; + let aggregates: Vec<(Arc, String)> = + vec![(sum(col("c4")), "SUM(c4)".to_string())]; let partition_aggregate = HashAggregateExec::try_new( - group_expr.clone(), - aggr_expr.clone(), + groups.clone(), + aggregates.clone(), Arc::new(csv), )?; let schema = partition_aggregate.schema(); let partitions = partition_aggregate.partitions()?; - let (final_group, final_aggr) = partition_aggregate.make_final_expr(); + + // construct the expressions for the final aggregation + let (final_group, final_aggr) = partition_aggregate.make_final_expr( + groups.iter().map(|x| x.1.clone()).collect(), + aggregates.iter().map(|x| x.1.clone()).collect(), + ); let merge = Arc::new(MergeExec::new(schema.clone(), partitions)); - let merged_aggregate = - HashAggregateExec::try_new(final_group, final_aggr, merge)?; + let merged_aggregate = HashAggregateExec::try_new( + final_group + .iter() + .enumerate() + .map(|(i, expr)| (expr.clone(), groups[i].1.clone())) + .collect(), + final_aggr + .iter() + .enumerate() + .map(|(i, expr)| (expr.clone(), aggregates[i].1.clone())) + .collect(), + merge, + )?; let result = test::execute(&merged_aggregate)?; assert_eq!(result.len(), 1); diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index 2828e5dd96f..2e191784678 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -24,7 +24,7 @@ use std::sync::{Arc, Mutex}; use crate::error::Result; use crate::logicalplan::ScalarValue; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::record_batch::{RecordBatch, RecordBatchReader}; /// Partition-aware execution plan for a relation @@ -42,29 +42,18 @@ pub trait Partition: Send + Sync { } /// Expression that can be evaluated against a RecordBatch +/// A Physical expression knows its type, nullability and how to evaluate itself. 
pub trait PhysicalExpr: Send + Sync {
- /// Get the name to use in a schema to represent the result of this expression
- fn name(&self) -> String;
 /// Get the data type of this expression, given the schema of the input
 fn data_type(&self, input_schema: &Schema) -> Result<DataType>;
 /// Decide whether this expression is nullable, given the schema of the input
 fn nullable(&self, input_schema: &Schema) -> Result<bool>;
 /// Evaluate an expression against a RecordBatch
 fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
- /// Generate schema Field type for this expression
- fn to_schema_field(&self, input_schema: &Schema) -> Result<Field> {
- Ok(Field::new(
- &self.name(),
- self.data_type(input_schema)?,
- self.nullable(input_schema)?,
- ))
- }
 }

 /// Aggregate expression that can be evaluated against a RecordBatch
 pub trait AggregateExpr: Send + Sync {
- /// Get the name to use in a schema to represent the result of this expression
- fn name(&self) -> String;
 /// Get the data type of this expression, given the schema of the input
 fn data_type(&self, input_schema: &Schema) -> Result<DataType>;
 /// Evaluate the expression being aggregated
@@ -74,7 +63,7 @@ pub trait AggregateExpr: Send + Sync {
 /// Create an aggregate expression for combining the results of accumulators from partitions.
 /// For example, to combine the results of a parallel SUM we just need to do another SUM, but
 /// to combine the results of parallel COUNT we would also use SUM.
- fn create_reducer(&self) -> Arc<dyn AggregateExpr>;
+ fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr>;
 }

 /// Aggregate accumulator
diff --git a/rust/datafusion/src/execution/physical_plan/projection.rs b/rust/datafusion/src/execution/physical_plan/projection.rs
index c10b7824205..2c2bcb02659 100644
--- a/rust/datafusion/src/execution/physical_plan/projection.rs
+++ b/rust/datafusion/src/execution/physical_plan/projection.rs
@@ -24,7 +24,7 @@ use std::sync::{Arc, Mutex};

 use crate::error::{ExecutionError, Result};
 use crate::execution::physical_plan::{ExecutionPlan, Partition, PhysicalExpr};
-use arrow::datatypes::{Schema, SchemaRef};
+use arrow::datatypes::{Field, Schema, SchemaRef};
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::{RecordBatch, RecordBatchReader};
@@ -41,20 +41,26 @@ pub struct ProjectionExec {
 impl ProjectionExec {
 /// Create a projection on an input
 pub fn try_new(
- expr: Vec<Arc<dyn PhysicalExpr>>,
+ expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
 input: Arc<dyn ExecutionPlan>,
 ) -> Result<Self> {
 let input_schema = input.schema();

 let fields: Result<Vec<Field>> = expr
 .iter()
- .map(|e| e.to_schema_field(&input_schema))
+ .map(|(e, name)| {
+ Ok(Field::new(
+ name,
+ e.data_type(&input_schema)?,
+ e.nullable(&input_schema)?,
+ ))
+ })
 .collect();

 let schema = Arc::new(Schema::new(fields?));

 Ok(Self {
- expr: expr.clone(),
+ expr: expr.iter().map(|x| x.0.clone()).collect(),
 schema,
 input: input.clone(),
 })
@@ -154,7 +160,9 @@ mod tests {
 let csv = CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), None, 1024)?;

- let projection = ProjectionExec::try_new(vec![col("c1")], Arc::new(csv))?;
+ // pick column c1 and name it column c1 in the output schema
+ let projection =
+ ProjectionExec::try_new(vec![(col("c1"), "c1".to_string())], Arc::new(csv))?;

 let mut partition_count = 0;
 let mut row_count = 0;
diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs
index df970dda717..a4480bc0bee 100644
--- a/rust/datafusion/src/execution/physical_plan/udf.rs
+++ b/rust/datafusion/src/execution/physical_plan/udf.rs
@@ -61,7 +61,6 @@ impl ScalarFunction {

 /// Scalar UDF Physical Expression
pub struct ScalarFunctionExpr {
- name: String,
 fun: Box<ScalarUdf>,
 args: Vec<Arc<dyn PhysicalExpr>>,
 return_type: DataType,
@@ -70,13 +69,11 @@ pub struct ScalarFunctionExpr {
 impl ScalarFunctionExpr {
 /// Create a new Scalar function
 pub fn new(
- name: &str,
 fun: Box<ScalarUdf>,
 args: Vec<Arc<dyn PhysicalExpr>>,
 return_type: &DataType,
 ) -> Self {
 Self {
- name: name.to_owned(),
 fun,
 args,
 return_type: return_type.clone(),
@@ -85,10 +82,6 @@ impl ScalarFunctionExpr {
 }

 impl PhysicalExpr for ScalarFunctionExpr {
- fn name(&self) -> String {
- self.name.clone()
- }
-
 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
 Ok(self.return_type.clone())
 }
diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs
index 6ded26872dc..b372597adc3 100644
--- a/rust/datafusion/src/logicalplan.rs
+++ b/rust/datafusion/src/logicalplan.rs
@@ -179,6 +179,85 @@ impl ScalarValue {
 }
 }

+/// Returns a readable name of an expression based on the input schema.
+/// This function recursively traverses the expression for names such as "CAST(a > 2)".
+fn create_name(e: &Expr, input_schema: &Schema) -> Result<String> {
+ match e {
+ Expr::Alias(_, name) => Ok(name.clone()),
+ Expr::Column(name) => Ok(name.clone()),
+ Expr::Literal(value) => Ok(format!("{:?}", value)),
+ Expr::BinaryExpr { left, op, right } => {
+ let left = create_name(left, input_schema)?;
+ let right = create_name(right, input_schema)?;
+ Ok(format!("{} {:?} {}", left, op, right))
+ }
+ Expr::Cast { expr, data_type } => {
+ let expr = create_name(expr, input_schema)?;
+ Ok(format!("CAST({} as {:?})", expr, data_type))
+ }
+ Expr::ScalarFunction { name, args, .. } => {
+ let mut names = Vec::with_capacity(args.len());
+ for e in args {
+ names.push(create_name(e, input_schema)?);
+ }
+ Ok(format!("{}({})", name, names.join(",")))
+ }
+ Expr::AggregateFunction { name, args, .. } => {
+ let mut names = Vec::with_capacity(args.len());
+ for e in args {
+ names.push(create_name(e, input_schema)?);
+ }
+ Ok(format!("{}({})", name, names.join(",")))
+ }
+ other => Err(ExecutionError::NotImplemented(format!(
+ "Physical plan does not support logical expression {:?}",
+ other
+ ))),
+ }
+}
+
+/// Returns the datatype of the expression given the input schema
+// note: the physical plan derived from an expression must produce data of the datatype returned by this function.
+fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
+ let data_type = match e {
+ Expr::Alias(expr, ..) => expr.get_type(input_schema),
+ Expr::Column(name) => Ok(input_schema.field_with_name(name)?.data_type().clone()),
+ Expr::Literal(ref lit) => Ok(lit.get_datatype()),
+ Expr::ScalarFunction {
+ ref return_type, ..
+ } => Ok(return_type.clone()),
+ Expr::AggregateFunction {
+ ref return_type, ..
+ } => Ok(return_type.clone()),
+ Expr::Cast { ref data_type, .. } => Ok(data_type.clone()),
+ Expr::BinaryExpr {
+ ref left,
+ ref right,
+ ..
+ } => {
+ let left_type = left.get_type(input_schema)?;
+ let right_type = right.get_type(input_schema)?;
+ Ok(utils::get_supertype(&left_type, &right_type).unwrap())
+ }
+ _ => Err(ExecutionError::NotImplemented(format!(
+ "Cannot determine schema type for expression {:?}",
+ e
+ ))),
+ };
+
+ match data_type {
+ Ok(d) => Ok(Field::new(&e.name(input_schema)?, d, true)),
+ Err(e) => Err(e),
+ }
+}
+
+/// Create field meta-data from an expression, for use in a result set schema
+fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) -> Result<Vec<Field>> {
+ expr.iter()
+ .map(|e| expr_to_field(e, input_schema))
+ .collect()
+}
+
 /// Relation expression
 #[derive(Clone, PartialEq)]
 pub enum Expr {
@@ -278,6 +357,13 @@ impl Expr {
 }
 }

+ /// Return the name of this expression
+ ///
+ /// This represents how a column with this expression is named when no alias is chosen
+ pub fn name(&self, input_schema: &Schema) -> Result<String> {
+ create_name(self, input_schema)
+ }
+
 /// Perform a type cast on the expression value.
 ///
 /// Will `Err` if the type cast cannot be performed.
@@ -923,10 +1009,8 @@ impl LogicalPlanBuilder {
 expr.clone()
 };

- let schema = Schema::new(utils::exprlist_to_fields(
- &projected_expr,
- input_schema.as_ref(),
- )?);
+ let schema =
+ Schema::new(exprlist_to_fields(&projected_expr, input_schema.as_ref())?);

 Ok(Self::from(&LogicalPlan::Projection {
 expr: projected_expr,
@@ -963,11 +1047,10 @@ impl LogicalPlanBuilder {

 /// Apply an aggregate
 pub fn aggregate(&self, group_expr: Vec<Expr>, aggr_expr: Vec<Expr>) -> Result<Self> {
- let mut all_fields: Vec<Expr> = group_expr.clone();
- aggr_expr.iter().for_each(|x| all_fields.push(x.clone()));
+ let mut all_expr: Vec<Expr> = group_expr.clone();
+ aggr_expr.iter().for_each(|x| all_expr.push(x.clone()));

- let aggr_schema =
- Schema::new(utils::exprlist_to_fields(&all_fields, self.plan.schema())?);
+ let aggr_schema = Schema::new(exprlist_to_fields(&all_expr, self.plan.schema())?);

 Ok(Self::from(&LogicalPlan::Aggregate {
 input: Box::new(self.plan.clone()),
diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs
index aacc2aef9ef..e99a996d7c9 100644
--- a/rust/datafusion/src/optimizer/projection_push_down.rs
+++ b/rust/datafusion/src/optimizer/projection_push_down.rs
@@ -23,6 +23,7 @@ use crate::logicalplan::LogicalPlan;
 use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::utils;
 use arrow::datatypes::{Field, Schema};
+use arrow::error::Result as ArrowResult;
 use std::collections::HashSet;

 /// Projection Push Down optimizer rule ensures that only referenced columns are
@@ -203,9 +204,13 @@ fn get_projected_schema(

 // once we reach the table scan, we can use the accumulated set of column
 // names to construct the set of column indexes in the scan
+ //
+ // we discard non-existing columns because some column names are not part of the schema,
+ // e.g.
when the column derives from an aggregation let mut projection: Vec = accum .iter() - .map(|name| table_schema.index_of(name).unwrap()) + .map(|name| table_schema.index_of(name)) + .filter_map(ArrowResult::ok) .collect(); if projection.is_empty() { diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 7f78f485c74..5c59803ec7b 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -19,7 +19,7 @@ use std::collections::HashSet; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::DataType; use crate::error::{ExecutionError, Result}; use crate::logicalplan::Expr; @@ -67,54 +67,6 @@ pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet) -> Result< } } -/// Create field meta-data from an expression, for use in a result set schema -pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result { - match e { - Expr::Alias(expr, name) => { - Ok(Field::new(name, expr.get_type(input_schema)?, true)) - } - Expr::Column(name) => Ok(input_schema.field_with_name(name)?.clone()), - Expr::Literal(ref lit) => Ok(Field::new("lit", lit.get_datatype(), true)), - Expr::ScalarFunction { - ref name, - ref return_type, - .. - } => Ok(Field::new(&name, return_type.clone(), true)), - Expr::AggregateFunction { - ref name, - ref return_type, - .. - } => Ok(Field::new(&name, return_type.clone(), true)), - Expr::Cast { ref data_type, .. } => { - Ok(Field::new("cast", data_type.clone(), true)) - } - Expr::BinaryExpr { - ref left, - ref right, - .. - } => { - let left_type = left.get_type(input_schema)?; - let right_type = right.get_type(input_schema)?; - Ok(Field::new( - "binary_expr", - get_supertype(&left_type, &right_type).unwrap(), - true, - )) - } - _ => Err(ExecutionError::NotImplemented(format!( - "Cannot determine schema type for expression {:?}", - e - ))), - } -} - -/// Create field meta-data from an expression, for use in a result set schema -pub fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) -> Result> { - expr.iter() - .map(|e| expr_to_field(e, input_schema)) - .collect() -} - /// Given two datatypes, determine the supertype that both types can safely be cast to pub fn get_supertype(l: &DataType, r: &DataType) -> Result { match _get_supertype(l, r) { From 308393df90b70fd5b060da13f66812f1df5ddb50 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 18 Jul 2020 09:10:54 +0200 Subject: [PATCH 6/8] Added testing of expression naming. 
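
The assertions pin down the names produced by `create_name`:
aggregates render as "SUM(c2)" or "COUNT(c1)", scalar UDFs as
"my_add(a,b)", and an explicit alias always wins. A typical check, as
added below (`field_names` is a small test helper, also added here,
that collects the field names of a batch's schema):

    assert_eq!(field_names(batch), vec!["c1", "SUM(c2)"]);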
--- rust/datafusion/src/execution/context.rs | 47 +++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 2433fda43ca..cec0f1531e0 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -705,10 +705,12 @@ mod tests { // there should be one batch per partition assert_eq!(results.len(), partition_count); - // each batch should contain 2 columns and 10 rows + // each batch should contain 2 columns and 10 rows with correct field names for batch in &results { assert_eq!(batch.num_columns(), 2); assert_eq!(batch.num_rows(), 10); + + assert_eq!(field_names(batch), vec!["c1", "c2"]); } Ok(()) @@ -883,6 +885,9 @@ mod tests { assert_eq!(results.len(), 1); let batch = &results[0]; + + assert_eq!(field_names(batch), vec!["SUM(c1)", "SUM(c2)"]); + let expected: Vec<&str> = vec!["60,220"]; let mut rows = test::format_batch(&batch); rows.sort(); @@ -897,6 +902,9 @@ mod tests { assert_eq!(results.len(), 1); let batch = &results[0]; + + assert_eq!(field_names(batch), vec!["AVG(c1)", "AVG(c2)"]); + let expected: Vec<&str> = vec!["1.5,5.5"]; let mut rows = test::format_batch(&batch); rows.sort(); @@ -911,6 +919,9 @@ mod tests { assert_eq!(results.len(), 1); let batch = &results[0]; + + assert_eq!(field_names(batch), vec!["MAX(c1)", "MAX(c2)"]); + let expected: Vec<&str> = vec!["3,10"]; let mut rows = test::format_batch(&batch); rows.sort(); @@ -925,6 +936,9 @@ mod tests { assert_eq!(results.len(), 1); let batch = &results[0]; + + assert_eq!(field_names(batch), vec!["MIN(c1)", "MIN(c2)"]); + let expected: Vec<&str> = vec!["0,1"]; let mut rows = test::format_batch(&batch); rows.sort(); @@ -939,6 +953,9 @@ mod tests { assert_eq!(results.len(), 1); let batch = &results[0]; + + assert_eq!(field_names(batch), vec!["c1", "SUM(c2)"]); + let expected: Vec<&str> = vec!["0,55", "1,55", "2,55", "3,55"]; let mut rows = test::format_batch(&batch); rows.sort(); @@ -953,6 +970,9 @@ mod tests { assert_eq!(results.len(), 1); let batch = &results[0]; + + assert_eq!(field_names(batch), vec!["c1", "AVG(c2)"]); + let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"]; let mut rows = test::format_batch(&batch); rows.sort(); @@ -967,6 +987,9 @@ mod tests { assert_eq!(results.len(), 1); let batch = &results[0]; + + assert_eq!(field_names(batch), vec!["c1", "MAX(c2)"]); + let expected: Vec<&str> = vec!["0,10", "1,10", "2,10", "3,10"]; let mut rows = test::format_batch(&batch); rows.sort(); @@ -981,6 +1004,9 @@ mod tests { assert_eq!(results.len(), 1); let batch = &results[0]; + + assert_eq!(field_names(batch), vec!["c1", "MIN(c2)"]); + let expected: Vec<&str> = vec!["0,1", "1,1", "2,1", "3,1"]; let mut rows = test::format_batch(&batch); rows.sort(); @@ -995,6 +1021,9 @@ mod tests { assert_eq!(results.len(), 1); let batch = &results[0]; + + assert_eq!(field_names(batch), vec!["COUNT(c1)", "COUNT(c2)"]); + let expected: Vec<&str> = vec!["10,10"]; let mut rows = test::format_batch(&batch); rows.sort(); @@ -1008,6 +1037,9 @@ mod tests { assert_eq!(results.len(), 1); let batch = &results[0]; + + assert_eq!(field_names(batch), vec!["COUNT(c1)", "COUNT(c2)"]); + let expected: Vec<&str> = vec!["40,40"]; let mut rows = test::format_batch(&batch); rows.sort(); @@ -1021,6 +1053,9 @@ mod tests { assert_eq!(results.len(), 1); let batch = &results[0]; + + assert_eq!(field_names(batch), vec!["c1", "COUNT(c2)"]); + let expected = vec!["0,10", "1,10", "2,10", "3,10"]; let 
mut rows = test::format_batch(&batch); rows.sort(); @@ -1170,6 +1205,7 @@ mod tests { let batch = &result[0]; assert_eq!(3, batch.num_columns()); assert_eq!(4, batch.num_rows()); + assert_eq!(field_names(batch), vec!["a", "b", "my_add(a,b)"]); let a = batch .column(0) @@ -1205,6 +1241,15 @@ mod tests { ctx.collect(physical_plan.as_ref()) } + fn field_names(result: &RecordBatch) -> Vec { + result + .schema() + .fields() + .iter() + .map(|x| x.name().clone()) + .collect::>() + } + /// Execute SQL and return results fn execute(sql: &str, partition_count: usize) -> Result> { let tmp_dir = TempDir::new("execute")?; From 4a5b62a89c5eb39935a3b8ea6453583206e9490e Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 24 Jul 2020 08:26:33 +0200 Subject: [PATCH 7/8] Fixed error in edge case of order in group by. --- rust/datafusion/src/sql/planner.rs | 34 +++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 81aeab713c4..21cf870ba34 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -173,7 +173,28 @@ impl SqlToRel { .aggregate(group_expr, aggr_expr)? .build()?; - Ok(plan) + // optionally wrap in projection to preserve final order of fields + let expected_columns: Vec = projection_expr + .iter() + .map(|e| e.name(input.schema())) + .collect::>>()?; + let columns: Vec = plan + .schema() + .fields() + .iter() + .map(|f| f.name().clone()) + .collect::>(); + if expected_columns != columns { + self.project( + &plan, + expected_columns + .iter() + .map(|c| Expr::Column(c.clone())) + .collect(), + ) + } else { + Ok(plan) + } } /// Wrap a plan in a limit @@ -620,6 +641,17 @@ mod tests { quick_test(sql, expected); } + #[test] + fn select_group_by_needs_projection() { + let sql = "SELECT COUNT(state), state FROM person GROUP BY state"; + let expected = "\ + Projection: #COUNT(state), #state\ + \n Aggregate: groupBy=[[#state]], aggr=[[COUNT(#state)]]\ + \n TableScan: person projection=None"; + + quick_test(sql, expected); + } + #[test] fn select_7480_1() { let sql = "SELECT c1, MIN(c12) FROM aggregate_test_100 GROUP BY c1, c13"; From a5558e9af48e6b3d7b9c23c5cdd9247aa1a9115c Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 24 Jul 2020 08:30:47 +0200 Subject: [PATCH 8/8] Fixed error in SQL statement of test. --- rust/datafusion/tests/sql.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 70635e9eaf0..30a55f7be55 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -279,7 +279,7 @@ fn csv_query_group_by_int_count() -> Result<()> { fn csv_query_group_by_string_min_max() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx)?; - let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c1"; + let sql = "SELECT c1, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c1"; let mut actual = execute(&mut ctx, sql); actual.sort(); let expected =