Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions rust/datafusion/examples/memory_table_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use datafusion::logicalplan::{Expr, ScalarValue};

/// This example demonstrates basic uses of the Table API on an in-memory table
fn main() {
fn main() -> Result<()> {
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Expand All @@ -44,31 +45,23 @@ fn main() {
Arc::new(StringArray::from(vec!["a", "b", "c", "d"])),
Arc::new(Int32Array::from(vec![1, 10, 10, 100])),
],
)
.unwrap();
)?;

// declare a new context. In the Spark API, this corresponds to a new Spark SQL session
let mut ctx = ExecutionContext::new();

// declare a table in memory. In the Spark API, this corresponds to createDataFrame(...).
let provider = MemTable::new(schema, vec![batch]).unwrap();
let provider = MemTable::new(schema, vec![batch])?;
ctx.register_table("t", Box::new(provider));
let t = ctx.table("t").unwrap();
let t = ctx.table("t")?;

// construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL
let filter = t
.col("b")
.unwrap()
.eq(&Expr::Literal(ScalarValue::Int32(10)));
let filter = t.col("b")?.eq(&Expr::Literal(ScalarValue::Int32(10)));

let t = t
.select_columns(vec!["a", "b"])
.unwrap()
.filter(filter)
.unwrap();
let t = t.select_columns(vec!["a", "b"])?.filter(filter)?;

// execute
let results = t.collect(&mut ctx, 10).unwrap();
let results = t.collect(&mut ctx, 10)?;

// print results
results.iter().for_each(|batch| {
Expand All @@ -94,4 +87,6 @@ fn main() {
println!("{}, {}", c1.value(i), c2.value(i),);
}
});

Ok(())
}
7 changes: 5 additions & 2 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,16 @@ impl ExecutionContext {
pub fn table(&mut self, table_name: &str) -> Result<Arc<dyn Table>> {
match self.datasources.get(table_name) {
Some(provider) => {
Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::TableScan {
let table_scan = LogicalPlan::TableScan {
schema_name: "".to_string(),
table_name: table_name.to_string(),
table_schema: provider.schema().clone(),
projected_schema: provider.schema().clone(),
projection: None,
}))))
};
Ok(Arc::new(TableImpl::new(
&LogicalPlanBuilder::from(&table_scan).build()?,
)))
}
_ => Err(ExecutionError::General(format!(
"No table named '{}'",
Expand Down
122 changes: 33 additions & 89 deletions rust/datafusion/src/execution/table_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,72 +19,54 @@

use std::sync::Arc;

use crate::arrow::datatypes::{DataType, Field, Schema};
use crate::arrow::datatypes::DataType;
use crate::arrow::record_batch::RecordBatch;
use crate::error::{ExecutionError, Result};
use crate::execution::context::ExecutionContext;
use crate::logicalplan::Expr::Literal;
use crate::logicalplan::ScalarValue;
use crate::logicalplan::{Expr, LogicalPlan};
use crate::logicalplan::{LogicalPlanBuilder, ScalarValue};
use crate::table::*;

/// Implementation of the Table API: a thin wrapper around one logical plan.
///
/// NOTE(review): this span is rendered diff output, so the `plan` field is
/// listed twice. The `Arc<LogicalPlan>` line appears to be the pre-change
/// declaration and the by-value `LogicalPlan` line the post-change one —
/// confirm against the real file.
pub struct TableImpl {
// Plan held through a shared `Arc` handle.
plan: Arc<LogicalPlan>,
// Plan held by value.
plan: LogicalPlan,
}

impl TableImpl {
/// Create a new Table based on an existing logical plan.
///
/// NOTE(review): two signatures are shown because this span is rendered diff
/// output. The first takes an `Arc<LogicalPlan>` and stores it unchanged; the
/// second borrows a `LogicalPlan` and stores `plan.clone()`, so each call
/// clones the plan it is given.
pub fn new(plan: Arc<LogicalPlan>) -> Self {
Self { plan }
pub fn new(plan: &LogicalPlan) -> Self {
// The argument is only borrowed, so a clone is needed to own the plan.
Self { plan: plan.clone() }
}
}

impl Table for TableImpl {
/// Apply a projection based on a list of column names.
///
/// Each name is resolved to its index in this table's plan schema and wrapped
/// in `Expr::Column`; the resulting expressions are handed to `select`.
///
/// # Errors
///
/// Returns the first name-lookup error, or any error produced by `select`.
///
/// NOTE(review): this span is rendered diff output and contains two bodies
/// back to back (a loop form followed by an iterator form) — confirm which one
/// the real file keeps.
fn select_columns(&self, columns: Vec<&str>) -> Result<Arc<dyn Table>> {
// Loop form: resolve every name through `column_index`, bailing out with
// `?` on the first name that cannot be resolved.
let mut expr: Vec<Expr> = Vec::with_capacity(columns.len());
for column_name in columns {
let i = self.column_index(column_name)?;
expr.push(Expr::Column(i));
}
self.select(expr)
// Iterator form: look every name up with the schema's `index_of`, then
// collect into one `Result` so the first failure short-circuits.
let exprs = columns
.iter()
.map(|name| {
self.plan
.schema()
// `name` is `&&str` here; `to_owned()` just copies the inner `&str`.
.index_of(name.to_owned())
// NOTE(review): `and_then(|i| Ok(..))` is equivalent to
// `map(Expr::Column)`; the whole closure also repeats what `col` does.
.and_then(|i| Ok(Expr::Column(i)))
// Convert the lookup error into this function's error type.
.map_err(|e| e.into())
})
.collect::<Result<Vec<_>>>()?;
self.select(exprs)
}

/// Create a projection based on arbitrary expressions.
///
/// Returns a new table whose plan projects `expr_list` over this table's plan.
///
/// NOTE(review): this span is rendered diff output and contains two bodies
/// back to back: one that assembles `LogicalPlan::Projection` by hand and one
/// that delegates to `LogicalPlanBuilder` — confirm which one the real file
/// keeps.
fn select(&self, expr_list: Vec<Expr>) -> Result<Arc<dyn Table>> {
// Hand-built form: derive the output schema from the referenced input fields.
let schema = self.plan.schema();
let mut field: Vec<Field> = Vec::with_capacity(expr_list.len());

for expr in &expr_list {
match expr {
// A column reference contributes the input field at that index.
Expr::Column(i) => {
field.push(schema.field(*i).clone());
}
// Every other expression kind is rejected by this form.
other => {
return Err(ExecutionError::NotImplemented(format!(
"Expr {:?} is not currently supported in this context",
other
)))
}
}
}

Ok(Arc::new(TableImpl::new(Arc::new(
LogicalPlan::Projection {
expr: expr_list.clone(),
input: self.plan.clone(),
schema: Arc::new(Schema::new(field)),
},
))))
// Builder form: start from the current plan, add the projection and build;
// errors from `project` or `build` are propagated with `?`.
let plan = LogicalPlanBuilder::from(&self.plan)
.project(expr_list)?
.build()?;
Ok(Arc::new(TableImpl::new(&plan)))
}

/// Create a selection based on a filter expression.
///
/// Returns a new table whose plan applies `expr` as a filter on top of this
/// table's plan.
///
/// NOTE(review): this span is rendered diff output and shows two bodies: one
/// constructing `LogicalPlan::Selection` directly and one going through
/// `LogicalPlanBuilder` — confirm which one the real file keeps.
fn filter(&self, expr: Expr) -> Result<Arc<dyn Table>> {
// Direct form: wrap the current plan in a `Selection` node.
Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::Selection {
expr,
input: self.plan.clone(),
}))))
// Builder form: errors from `filter` or `build` are propagated with `?`.
let plan = LogicalPlanBuilder::from(&self.plan).filter(expr)?.build()?;
Ok(Arc::new(TableImpl::new(&plan)))
}

/// Perform an aggregate query
Expand All @@ -93,38 +75,23 @@ impl Table for TableImpl {
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<Arc<dyn Table>> {
Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::Aggregate {
input: self.plan.clone(),
group_expr,
aggr_expr,
schema: Arc::new(Schema::new(vec![])),
}))))
let plan = LogicalPlanBuilder::from(&self.plan)
.aggregate(group_expr, aggr_expr)?
.build()?;
Ok(Arc::new(TableImpl::new(&plan)))
}

/// Limit the number of rows.
///
/// `n` is passed to the plan as a `ScalarValue::UInt32` literal limit
/// expression.
///
/// NOTE(review): this span is rendered diff output and shows two signatures
/// (`n: usize` and `n: u32`) with their bodies — confirm which one the real
/// file keeps.
fn limit(&self, n: usize) -> Result<Arc<dyn Table>> {
Ok(Arc::new(TableImpl::new(Arc::new(LogicalPlan::Limit {
// `as u32` silently truncates values of `n` that do not fit in 32 bits.
expr: Literal(ScalarValue::UInt32(n as u32)),
input: self.plan.clone(),
// This form reuses the input plan's schema for the limit node.
schema: self.plan.schema().clone(),
}))))
fn limit(&self, n: u32) -> Result<Arc<dyn Table>> {
// With a `u32` parameter no narrowing cast is needed for the literal.
let plan = LogicalPlanBuilder::from(&self.plan)
.limit(Expr::Literal(ScalarValue::UInt32(n)))?
.build()?;
Ok(Arc::new(TableImpl::new(&plan)))
}

/// Return an expression representing a column within this table.
///
/// The name is resolved to its index in the plan's schema and returned as
/// `Expr::Column(index)`.
///
/// NOTE(review): this span is rendered diff output. It interleaves a `col`
/// that calls `column_index`, the `column_index` helper itself, and a `col`
/// body that calls the schema's `index_of` directly — confirm which lines the
/// real file keeps. The two forms report an unknown name differently: the
/// helper returns `ExecutionError::InvalidColumn`, while the `index_of` form
/// returns whatever error `index_of` yields, converted by `?`.
fn col(&self, name: &str) -> Result<Expr> {
Ok(Expr::Column(self.column_index(name)?))
}

/// Return the index of a column within this table's schema.
///
/// Fails with `ExecutionError::InvalidColumn` when no column has that name.
fn column_index(&self, name: &str) -> Result<usize> {
let schema = self.plan.schema();
match schema.column_with_name(name) {
// Only the position is needed; the second tuple element is ignored.
Some((i, _)) => Ok(i),
_ => Err(ExecutionError::InvalidColumn(format!(
"No column named '{}'",
name
))),
}
// Direct form: look the name up in the schema and wrap the index.
Ok(Expr::Column(self.plan.schema().index_of(name)?))
}

/// Create an expression to represent the min() aggregate function
Expand Down Expand Up @@ -153,7 +120,7 @@ impl Table for TableImpl {
}

/// Convert to logical plan.
///
/// Returns a clone of the plan held by this table.
///
/// NOTE(review): this span is rendered diff output and shows two signatures.
/// With the `Arc<LogicalPlan>` return type `clone()` copies the `Arc` handle;
/// with the by-value return type it clones the `LogicalPlan` itself — confirm
/// which one the real file keeps.
fn to_logical_plan(&self) -> Arc<LogicalPlan> {
fn to_logical_plan(&self) -> LogicalPlan {
self.plan.clone()
}

Expand Down Expand Up @@ -195,14 +162,6 @@ mod tests {
use crate::execution::context::ExecutionContext;
use crate::test;

// Checks that `column_index` maps column names of the test table to their
// zero-based positions (c1 -> 0, c2 -> 1, c13 -> 12).
// NOTE(review): the trait method this test calls appears to be deleted by the
// same diff, so the test is presumably removed with it — confirm.
#[test]
fn column_index() {
let t = test_table();
assert_eq!(0, t.column_index("c1").unwrap());
assert_eq!(1, t.column_index("c2").unwrap());
assert_eq!(12, t.column_index("c13").unwrap());
}

#[test]
fn select_columns() -> Result<()> {
// build plan using Table API
Expand Down Expand Up @@ -235,21 +194,6 @@ mod tests {
Ok(())
}

// Looks up a column name that does not exist via `col` and asserts the exact
// `Debug` rendering of the resulting error.
// NOTE(review): the expected text is the `InvalidColumn` message built by
// `column_index`; a `col` that resolves names through `index_of` instead would
// presumably not produce this string — confirm.
#[test]
fn select_invalid_column() -> Result<()> {
let t = test_table();

match t.col("invalid_column_name") {
// A successful lookup of a missing column is a test failure.
Ok(_) => panic!(),
Err(e) => assert_eq!(
"InvalidColumn(\"No column named \\\'invalid_column_name\\\'\")",
format!("{:?}", e)
),
}

Ok(())
}

#[test]
fn aggregate() -> Result<()> {
// build plan using Table API
Expand Down
7 changes: 2 additions & 5 deletions rust/datafusion/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ pub trait Table {
) -> Result<Arc<dyn Table>>;

/// Limit the number of rows.
///
/// NOTE(review): this span is rendered diff output and lists the declaration
/// twice, once with `n: usize` and once with `n: u32` — confirm which one the
/// real file keeps.
fn limit(&self, n: usize) -> Result<Arc<dyn Table>>;
fn limit(&self, n: u32) -> Result<Arc<dyn Table>>;

/// Return the logical plan.
///
/// NOTE(review): this span is rendered diff output and lists the declaration
/// twice, once returning `Arc<LogicalPlan>` and once returning `LogicalPlan`
/// by value — confirm which one the real file keeps.
fn to_logical_plan(&self) -> Arc<LogicalPlan>;
fn to_logical_plan(&self) -> LogicalPlan;

/// Return an expression representing a column within this table
fn col(&self, name: &str) -> Result<Expr>;
Expand All @@ -66,9 +66,6 @@ pub trait Table {
/// Create an expression to represent the count() aggregate function
fn count(&self, expr: &Expr) -> Result<Expr>;

/// Return the index of a column within this table's schema.
///
/// NOTE(review): this declaration appears to be among the lines deleted by the
/// diff this span was rendered from — confirm before relying on it.
fn column_index(&self, name: &str) -> Result<usize>;

/// Collects the result as a vector of RecordBatch.
fn collect(
&self,
Expand Down