Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use crate::execution::table_impl::TableImpl;
use crate::logicalplan::*;
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::projection_push_down::ProjectionPushDown;
use crate::optimizer::resolve_columns::ResolveColumnsRule;
use crate::optimizer::type_coercion::TypeCoercionRule;
use crate::sql::parser::{DFASTNode, DFParser, FileType};
use crate::sql::planner::{SchemaProvider, SqlToRel};
Expand Down Expand Up @@ -231,12 +232,13 @@ impl ExecutionContext {
}

/// Optimize the logical plan by applying optimizer rules
pub fn optimize(&self, plan: &LogicalPlan) -> Result<Arc<LogicalPlan>> {
pub fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
let rules: Vec<Box<dyn OptimizerRule>> = vec![
Box::new(ResolveColumnsRule::new()),
Box::new(ProjectionPushDown::new()),
Box::new(TypeCoercionRule::new()),
];
let mut plan = Arc::new(plan.clone());
let mut plan = plan.clone();
for mut rule in rules {
plan = rule.optimize(&plan)?;
}
Expand All @@ -246,10 +248,10 @@ impl ExecutionContext {
/// Create a physical plan from a logical plan
pub fn create_physical_plan(
&mut self,
logical_plan: &Arc<LogicalPlan>,
logical_plan: &LogicalPlan,
batch_size: usize,
) -> Result<Arc<dyn ExecutionPlan>> {
match logical_plan.as_ref() {
match logical_plan {
LogicalPlan::TableScan {
table_name,
projection,
Expand Down Expand Up @@ -435,9 +437,10 @@ impl ExecutionContext {
))),
}
}
_ => Err(ExecutionError::NotImplemented(
"Unsupported aggregate expression".to_string(),
)),
other => Err(ExecutionError::General(format!(
"Invalid aggregate expression '{:?}'",
other
))),
}
}

Expand Down Expand Up @@ -731,22 +734,19 @@ mod tests {
let mut ctx = create_ctx(&tmp_dir, 1)?;

let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::UInt32, false),
Field::new("c2", DataType::UInt64, false),
Field::new("state", DataType::Utf8, false),
Field::new("salary", DataType::UInt32, false),
]));

let plan = LogicalPlanBuilder::scan(
"default",
"test",
schema.as_ref(),
Some(vec![0, 1]),
)?
.aggregate(
vec![col(0)],
vec![aggregate_expr("SUM", col(1), DataType::Int32)],
)?
.project(vec![col(0), col(1).alias("total_salary")])?
.build()?;
let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)?
.aggregate(
vec![col("state")],
vec![aggregate_expr("SUM", col("salary"), DataType::UInt32)],
)?
.project(vec![col("state"), col_index(1).alias("total_salary")])?
.build()?;

let plan = ctx.optimize(&plan)?;

let physical_plan = ctx.create_physical_plan(&Arc::new(plan), 1024)?;
assert_eq!("c1", physical_plan.schema().field(0).name().as_str());
Expand Down
44 changes: 28 additions & 16 deletions rust/datafusion/src/logicalplan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ pub enum Expr {
Alias(Arc<Expr>, String),
/// index into a value within the row or complex value
Column(usize),
/// Reference to column by name
UnresolvedColumn(String),
/// literal value
Literal(ScalarValue),
/// binary expression e.g. "age > 21"
Expand Down Expand Up @@ -242,6 +244,9 @@ impl Expr {
match self {
Expr::Alias(expr, _) => expr.get_type(schema),
Expr::Column(n) => Ok(schema.field(*n).data_type().clone()),
Expr::UnresolvedColumn(name) => {
Ok(schema.field_with_name(&name)?.data_type().clone())
}
Expr::Literal(l) => Ok(l.get_datatype()),
Expr::Cast { data_type, .. } => Ok(data_type.clone()),
Expr::ScalarFunction { return_type, .. } => Ok(return_type.clone()),
Expand Down Expand Up @@ -356,11 +361,16 @@ impl Expr {
}
}

/// Create a column expression
pub fn col(index: usize) -> Expr {
/// Create a column expression based on a column index
pub fn col_index(index: usize) -> Expr {
Expr::Column(index)
}

/// Create a column expression based on a column name
pub fn col(name: &str) -> Expr {
Expr::UnresolvedColumn(name.to_owned())
}

/// Create a literal string expression
pub fn lit_str(str: &str) -> Expr {
Expr::Literal(ScalarValue::Utf8(str.to_owned()))
Expand All @@ -380,6 +390,7 @@ impl fmt::Debug for Expr {
match self {
Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias),
Expr::Column(i) => write!(f, "#{}", i),
Expr::UnresolvedColumn(name) => write!(f, "#{}", name),
Expr::Literal(v) => write!(f, "{:?}", v),
Expr::Cast { expr, data_type } => {
write!(f, "CAST({:?} AS {:?})", expr, data_type)
Expand Down Expand Up @@ -709,7 +720,7 @@ impl LogicalPlanBuilder {
(0..expr.len()).for_each(|i| match &expr[i] {
Expr::Wildcard => {
(0..input_schema.fields().len())
.for_each(|i| expr_vec.push(col(i).clone()));
.for_each(|i| expr_vec.push(col_index(i).clone()));
}
_ => expr_vec.push(expr[i].clone()),
});
Expand Down Expand Up @@ -791,8 +802,8 @@ mod tests {
&employee_schema(),
Some(vec![0, 3]),
)?
.filter(col(1).eq(&lit_str("CO")))?
.project(vec![col(0)])?
.filter(col("id").eq(&lit_str("CO")))?
.project(vec![col("id")])?
.build()?;

// prove that a plan can be passed to a thread
Expand All @@ -812,13 +823,13 @@ mod tests {
&employee_schema(),
Some(vec![0, 3]),
)?
.filter(col(1).eq(&lit_str("CO")))?
.project(vec![col(0)])?
.filter(col("state").eq(&lit_str("CO")))?
.project(vec![col("id")])?
.build()?;

let expected = "Projection: #0\n \
Selection: #1 Eq Utf8(\"CO\")\n \
TableScan: employee.csv projection=Some([0, 3])";
let expected = "Projection: #id\
\n Selection: #state Eq Utf8(\"CO\")\
\n TableScan: employee.csv projection=Some([0, 3])";

assert_eq!(expected, format!("{:?}", plan));

Expand All @@ -834,15 +845,16 @@ mod tests {
Some(vec![3, 4]),
)?
.aggregate(
vec![col(0)],
vec![aggregate_expr("SUM", col(1), DataType::Int32)],
vec![col("state")],
vec![aggregate_expr("SUM", col("salary"), DataType::Int32)
.alias("total_salary")],
)?
.project(vec![col(0), col(1).alias("total_salary")])?
.project(vec![col("state"), col("total_salary")])?
.build()?;

let expected = "Projection: #0, #1 AS total_salary\
\n Aggregate: groupBy=[[#0]], aggr=[[SUM(#1)]]\
\n TableScan: employee.csv projection=Some([3, 4])";
let expected = "Projection: #state, #total_salary\
\n Aggregate: groupBy=[[#state]], aggr=[[SUM(#salary) AS total_salary]]\
\n TableScan: employee.csv projection=Some([3, 4])";

assert_eq!(expected, format!("{:?}", plan));

Expand Down
1 change: 1 addition & 0 deletions rust/datafusion/src/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@

pub mod optimizer;
pub mod projection_push_down;
pub mod resolve_columns;
pub mod type_coercion;
pub mod utils;
3 changes: 1 addition & 2 deletions rust/datafusion/src/optimizer/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@

use crate::error::Result;
use crate::logicalplan::LogicalPlan;
use std::sync::Arc;

/// An optimizer rules performs a transformation on a logical plan to produce an optimized
/// logical plan.
pub trait OptimizerRule {
/// Perform optimizations on the plan
fn optimize(&mut self, plan: &LogicalPlan) -> Result<Arc<LogicalPlan>>;
fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan>;
}
Loading