From 41991adcd56037fb92d0029e959209c2e6caae2e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Mar 2020 11:24:05 -0600 Subject: [PATCH] Table API now uses LogicalPlanBuilder --- rust/datafusion/examples/memory_table_api.rs | 25 ++-- rust/datafusion/src/execution/context.rs | 7 +- rust/datafusion/src/execution/table_impl.rs | 122 +++++-------------- rust/datafusion/src/table.rs | 7 +- 4 files changed, 50 insertions(+), 111 deletions(-) diff --git a/rust/datafusion/examples/memory_table_api.rs b/rust/datafusion/examples/memory_table_api.rs index 9fa218c11c0..cf42264fd45 100644 --- a/rust/datafusion/examples/memory_table_api.rs +++ b/rust/datafusion/examples/memory_table_api.rs @@ -26,11 +26,12 @@ use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion::datasource::MemTable; +use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; use datafusion::logicalplan::{Expr, ScalarValue}; /// This example demonstrates basic uses of the Table API on an in-memory table -fn main() { +fn main() -> Result<()> { // define a schema. let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), @@ -44,31 +45,23 @@ fn main() { Arc::new(StringArray::from(vec!["a", "b", "c", "d"])), Arc::new(Int32Array::from(vec![1, 10, 10, 100])), ], - ) - .unwrap(); + )?; // declare a new context. In spark API, this corresponds to a new spark SQLsession let mut ctx = ExecutionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). - let provider = MemTable::new(schema, vec![batch]).unwrap(); + let provider = MemTable::new(schema, vec![batch])?; ctx.register_table("t", Box::new(provider)); - let t = ctx.table("t").unwrap(); + let t = ctx.table("t")?; // construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL - let filter = t - .col("b") - .unwrap() - .eq(&Expr::Literal(ScalarValue::Int32(10))); + let filter = t.col("b")?.eq(&Expr::Literal(ScalarValue::Int32(10))); - let t = t - .select_columns(vec!["a", "b"]) - .unwrap() - .filter(filter) - .unwrap(); + let t = t.select_columns(vec!["a", "b"])?.filter(filter)?; // execute - let results = t.collect(&mut ctx, 10).unwrap(); + let results = t.collect(&mut ctx, 10)?; // print results results.iter().for_each(|batch| { @@ -94,4 +87,6 @@ fn main() { println!("{}, {}", c1.value(i), c2.value(i),); } }); + + Ok(()) } diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 47608e36537..063f5bb3ec9 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -216,13 +216,16 @@ impl ExecutionContext { pub fn table(&mut self, table_name: &str) -> Result> { match self.datasources.get(table_name) { Some(provider) => { - Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::TableScan { + let table_scan = LogicalPlan::TableScan { schema_name: "".to_string(), table_name: table_name.to_string(), table_schema: provider.schema().clone(), projected_schema: provider.schema().clone(), projection: None, - })))) + }; + Ok(Arc::new(TableImpl::new( + &LogicalPlanBuilder::from(&table_scan).build()?, + ))) } _ => Err(ExecutionError::General(format!( "No table named '{}'", diff --git a/rust/datafusion/src/execution/table_impl.rs b/rust/datafusion/src/execution/table_impl.rs index e56cd76383e..10d65a43b56 100644 --- a/rust/datafusion/src/execution/table_impl.rs +++ b/rust/datafusion/src/execution/table_impl.rs @@ -19,72 +19,54 @@ use std::sync::Arc; -use crate::arrow::datatypes::{DataType, Field, Schema}; +use crate::arrow::datatypes::DataType; use crate::arrow::record_batch::RecordBatch; use crate::error::{ExecutionError, Result}; use crate::execution::context::ExecutionContext; -use crate::logicalplan::Expr::Literal; -use crate::logicalplan::ScalarValue; use crate::logicalplan::{Expr, LogicalPlan}; +use crate::logicalplan::{LogicalPlanBuilder, ScalarValue}; use crate::table::*; /// Implementation of Table API pub struct TableImpl { - plan: Arc, + plan: LogicalPlan, } impl TableImpl { /// Create a new Table based on an existing logical plan - pub fn new(plan: Arc) -> Self { - Self { plan } + pub fn new(plan: &LogicalPlan) -> Self { + Self { plan: plan.clone() } } } impl Table for TableImpl { /// Apply a projection based on a list of column names fn select_columns(&self, columns: Vec<&str>) -> Result> { - let mut expr: Vec = Vec::with_capacity(columns.len()); - for column_name in columns { - let i = self.column_index(column_name)?; - expr.push(Expr::Column(i)); - } - self.select(expr) + let exprs = columns + .iter() + .map(|name| { + self.plan + .schema() + .index_of(name.to_owned()) + .and_then(|i| Ok(Expr::Column(i))) + .map_err(|e| e.into()) + }) + .collect::>>()?; + self.select(exprs) } /// Create a projection based on arbitrary expressions fn select(&self, expr_list: Vec) -> Result> { - let schema = self.plan.schema(); - let mut field: Vec = Vec::with_capacity(expr_list.len()); - - for expr in &expr_list { - match expr { - Expr::Column(i) => { - field.push(schema.field(*i).clone()); - } - other => { - return Err(ExecutionError::NotImplemented(format!( - "Expr {:?} is not currently supported in this context", - other - ))) - } - } - } - - Ok(Arc::new(TableImpl::new(Arc::new( - LogicalPlan::Projection { - expr: expr_list.clone(), - input: self.plan.clone(), - schema: Arc::new(Schema::new(field)), - }, - )))) + let plan = LogicalPlanBuilder::from(&self.plan) + .project(expr_list)? + .build()?; + Ok(Arc::new(TableImpl::new(&plan))) } /// Create a selection based on a filter expression fn filter(&self, expr: Expr) -> Result> { - Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::Selection { - expr, - input: self.plan.clone(), - })))) + let plan = LogicalPlanBuilder::from(&self.plan).filter(expr)?.build()?; + Ok(Arc::new(TableImpl::new(&plan))) } /// Perform an aggregate query @@ -93,38 +75,23 @@ impl Table for TableImpl { group_expr: Vec, aggr_expr: Vec, ) -> Result> { - Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::Aggregate { - input: self.plan.clone(), - group_expr, - aggr_expr, - schema: Arc::new(Schema::new(vec![])), - })))) + let plan = LogicalPlanBuilder::from(&self.plan) + .aggregate(group_expr, aggr_expr)? + .build()?; + Ok(Arc::new(TableImpl::new(&plan))) } /// Limit the number of rows - fn limit(&self, n: usize) -> Result> { - Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::Limit { - expr: Literal(ScalarValue::UInt32(n as u32)), - input: self.plan.clone(), - schema: self.plan.schema().clone(), - })))) + fn limit(&self, n: u32) -> Result> { + let plan = LogicalPlanBuilder::from(&self.plan) + .limit(Expr::Literal(ScalarValue::UInt32(n)))? + .build()?; + Ok(Arc::new(TableImpl::new(&plan))) } /// Return an expression representing a column within this table fn col(&self, name: &str) -> Result { - Ok(Expr::Column(self.column_index(name)?)) - } - - /// Return the index of a column within this table's schema - fn column_index(&self, name: &str) -> Result { - let schema = self.plan.schema(); - match schema.column_with_name(name) { - Some((i, _)) => Ok(i), - _ => Err(ExecutionError::InvalidColumn(format!( - "No column named '{}'", - name - ))), - } + Ok(Expr::Column(self.plan.schema().index_of(name)?)) } /// Create an expression to represent the min() aggregate function @@ -153,7 +120,7 @@ impl Table for TableImpl { } /// Convert to logical plan - fn to_logical_plan(&self) -> Arc { + fn to_logical_plan(&self) -> LogicalPlan { self.plan.clone() } @@ -195,14 +162,6 @@ mod tests { use crate::execution::context::ExecutionContext; use crate::test; - #[test] - fn column_index() { - let t = test_table(); - assert_eq!(0, t.column_index("c1").unwrap()); - assert_eq!(1, t.column_index("c2").unwrap()); - assert_eq!(12, t.column_index("c13").unwrap()); - } - #[test] fn select_columns() -> Result<()> { // build plan using Table API @@ -235,21 +194,6 @@ mod tests { Ok(()) } - #[test] - fn select_invalid_column() -> Result<()> { - let t = test_table(); - - match t.col("invalid_column_name") { - Ok(_) => panic!(), - Err(e) => assert_eq!( - "InvalidColumn(\"No column named \\\'invalid_column_name\\\'\")", - format!("{:?}", e) - ), - } - - Ok(()) - } - #[test] fn aggregate() -> Result<()> { // build plan using Table API diff --git a/rust/datafusion/src/table.rs b/rust/datafusion/src/table.rs index 37c86f6de56..c60c1dfa03f 100644 --- a/rust/datafusion/src/table.rs +++ b/rust/datafusion/src/table.rs @@ -43,10 +43,10 @@ pub trait Table { ) -> Result>; /// limit the number of rows - fn limit(&self, n: usize) -> Result>; + fn limit(&self, n: u32) -> Result>; /// Return the logical plan - fn to_logical_plan(&self) -> Arc; + fn to_logical_plan(&self) -> LogicalPlan; /// Return an expression representing a column within this table fn col(&self, name: &str) -> Result; @@ -66,9 +66,6 @@ pub trait Table { /// Create an expression to represent the count() aggregate function fn count(&self, expr: &Expr) -> Result; - /// Return the index of a column within this table's schema - fn column_index(&self, name: &str) -> Result; - /// Collects the result as a vector of RecordBatch. fn collect( &self,