61 changes: 41 additions & 20 deletions rust/datafusion/src/logicalplan.rs
@@ -232,45 +232,50 @@ pub enum Expr {
/// The `DataType` the expression will yield
return_type: DataType,
},
/// Wildcard
Wildcard,
}

impl Expr {
/// Find the `DataType` for the expression
pub fn get_type(&self, schema: &Schema) -> DataType {
pub fn get_type(&self, schema: &Schema) -> Result<DataType> {
match self {
Expr::Alias(expr, _) => expr.get_type(schema),
Expr::Column(n) => schema.field(*n).data_type().clone(),
Expr::Literal(l) => l.get_datatype(),
Expr::Cast { data_type, .. } => data_type.clone(),
Expr::ScalarFunction { return_type, .. } => return_type.clone(),
Expr::AggregateFunction { return_type, .. } => return_type.clone(),
Expr::Not(_) => DataType::Boolean,
Expr::IsNull(_) => DataType::Boolean,
Expr::IsNotNull(_) => DataType::Boolean,
Expr::Column(n) => Ok(schema.field(*n).data_type().clone()),
Expr::Literal(l) => Ok(l.get_datatype()),
Expr::Cast { data_type, .. } => Ok(data_type.clone()),
Expr::ScalarFunction { return_type, .. } => Ok(return_type.clone()),
Expr::AggregateFunction { return_type, .. } => Ok(return_type.clone()),
Expr::Not(_) => Ok(DataType::Boolean),
Expr::IsNull(_) => Ok(DataType::Boolean),
Expr::IsNotNull(_) => Ok(DataType::Boolean),
Expr::BinaryExpr {
ref left,
ref right,
ref op,
} => match op {
Operator::Eq | Operator::NotEq => DataType::Boolean,
Operator::Lt | Operator::LtEq => DataType::Boolean,
Operator::Gt | Operator::GtEq => DataType::Boolean,
Operator::And | Operator::Or => DataType::Boolean,
Operator::Eq | Operator::NotEq => Ok(DataType::Boolean),
Operator::Lt | Operator::LtEq => Ok(DataType::Boolean),
Operator::Gt | Operator::GtEq => Ok(DataType::Boolean),
Operator::And | Operator::Or => Ok(DataType::Boolean),
_ => {
let left_type = left.get_type(schema);
let right_type = right.get_type(schema);
utils::get_supertype(&left_type, &right_type).unwrap()
let left_type = left.get_type(schema)?;
let right_type = right.get_type(schema)?;
utils::get_supertype(&left_type, &right_type)
}
},
Expr::Sort { ref expr, .. } => expr.get_type(schema),
Expr::Wildcard => Err(ExecutionError::General(
"Wildcard expressions are not valid in a logical query plan".to_owned(),
)),
}
}

/// Perform a type cast on the expression value.
///
/// Will `Err` if the type cast cannot be performed.
pub fn cast_to(&self, cast_to_type: &DataType, schema: &Schema) -> Result<Expr> {
let this_type = self.get_type(schema);
let this_type = self.get_type(schema)?;
if this_type == *cast_to_type {
Ok(self.clone())
} else if can_coerce_from(cast_to_type, &this_type) {
@@ -414,6 +419,7 @@ impl fmt::Debug for Expr {

write!(f, ")")
}
Expr::Wildcard => write!(f, "*"),
}
}
}
@@ -698,12 +704,27 @@ impl LogicalPlanBuilder {
/// Apply a projection
pub fn project(&self, expr: &Vec<Expr>) -> Result<Self> {
let input_schema = self.plan.schema();
let projected_expr = if expr.contains(&Expr::Wildcard) {
let mut expr_vec = vec![];
(0..expr.len()).for_each(|i| match &expr[i] {
Expr::Wildcard => {
(0..input_schema.fields().len())
.for_each(|i| expr_vec.push(col(i).clone()));
}
_ => expr_vec.push(expr[i].clone()),
});
expr_vec
} else {
expr.clone()
};

let schema =
Schema::new(utils::exprlist_to_fields(&expr, input_schema.as_ref())?);
let schema = Schema::new(utils::exprlist_to_fields(
&projected_expr,
input_schema.as_ref(),
)?);

Ok(Self::from(&LogicalPlan::Projection {
expr: expr.clone(),
expr: projected_expr,
input: Arc::new(self.plan.clone()),
schema: Arc::new(schema),
}))
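A minimal standalone sketch of the wildcard-expansion rule that `project` applies above, assuming a hypothetical `SketchExpr` type and `expand_wildcard` helper in place of DataFusion's `Expr` and builder API: a `Wildcard` entry in the projection list is replaced by one `Column(i)` per input field, and every other expression passes through unchanged.

// Hypothetical stand-in for DataFusion's `Expr`, used only for illustration.
#[derive(Clone, Debug, PartialEq)]
enum SketchExpr {
    Column(usize),
    Wildcard,
}

// Expand any `Wildcard` entries into one `Column(i)` per input field,
// mirroring the expansion loop added to `LogicalPlanBuilder::project`.
fn expand_wildcard(exprs: &[SketchExpr], input_field_count: usize) -> Vec<SketchExpr> {
    let mut expanded = vec![];
    for e in exprs {
        match e {
            SketchExpr::Wildcard => {
                (0..input_field_count).for_each(|i| expanded.push(SketchExpr::Column(i)))
            }
            other => expanded.push(other.clone()),
        }
    }
    expanded
}

fn main() {
    // A `SELECT *` projection over a three-column input expands to #0, #1, #2.
    let projection = vec![SketchExpr::Wildcard];
    let expanded = expand_wildcard(&projection, 3);
    assert_eq!(
        expanded,
        vec![
            SketchExpr::Column(0),
            SketchExpr::Column(1),
            SketchExpr::Column(2)
        ]
    );
    println!("{:?}", expanded);
}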
13 changes: 8 additions & 5 deletions rust/datafusion/src/optimizer/projection_push_down.rs
@@ -58,7 +58,7 @@ impl ProjectionPushDown {
schema,
} => {
// collect all columns referenced by projection expressions
utils::exprlist_to_column_indices(&expr, accum);
utils::exprlist_to_column_indices(&expr, accum)?;

// push projection down
let input = self.optimize_plan(&input, accum, mapping)?;
@@ -74,7 +74,7 @@ impl ProjectionPushDown {
}
LogicalPlan::Selection { expr, input } => {
// collect all columns referenced by filter expression
utils::expr_to_column_indices(expr, accum);
utils::expr_to_column_indices(expr, accum)?;

// push projection down
let input = self.optimize_plan(&input, accum, mapping)?;
@@ -94,8 +94,8 @@ impl ProjectionPushDown {
schema,
} => {
// collect all columns referenced by grouping and aggregate expressions
utils::exprlist_to_column_indices(&group_expr, accum);
utils::exprlist_to_column_indices(&aggr_expr, accum);
utils::exprlist_to_column_indices(&group_expr, accum)?;
utils::exprlist_to_column_indices(&aggr_expr, accum)?;

// push projection down
let input = self.optimize_plan(&input, accum, mapping)?;
@@ -117,7 +117,7 @@ impl ProjectionPushDown {
schema,
} => {
// collect all columns referenced by sort expressions
utils::exprlist_to_column_indices(&expr, accum);
utils::exprlist_to_column_indices(&expr, accum)?;

// push projection down
let input = self.optimize_plan(&input, accum, mapping)?;
@@ -271,6 +271,9 @@ impl ProjectionPushDown {
args: self.rewrite_exprs(args, mapping)?,
return_type: return_type.clone(),
}),
Expr::Wildcard => Err(ExecutionError::General(
"Wildcard expressions are not valid in a logical query plan".to_owned(),
)),
}
}

4 changes: 2 additions & 2 deletions rust/datafusion/src/optimizer/type_coercion.rs
@@ -96,8 +96,8 @@ fn rewrite_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
Expr::BinaryExpr { left, op, right } => {
let left = rewrite_expr(left, schema)?;
let right = rewrite_expr(right, schema)?;
let left_type = left.get_type(schema);
let right_type = right.get_type(schema);
let left_type = left.get_type(schema)?;
let right_type = right.get_type(schema)?;
if left_type == right_type {
Ok(Expr::BinaryExpr {
left: Arc::new(left),
39 changes: 27 additions & 12 deletions rust/datafusion/src/optimizer/utils.rs
@@ -26,38 +26,52 @@ use crate::logicalplan::Expr;

/// Recursively walk a list of expression trees, collecting the unique set of column
/// indexes referenced in the expression
pub fn exprlist_to_column_indices(expr: &Vec<Expr>, accum: &mut HashSet<usize>) {
expr.iter().for_each(|e| expr_to_column_indices(e, accum));
pub fn exprlist_to_column_indices(
expr: &Vec<Expr>,
accum: &mut HashSet<usize>,
) -> Result<()> {
for e in expr {
expr_to_column_indices(e, accum)?;
}
Ok(())
}

/// Recursively walk an expression tree, collecting the unique set of column indexes
/// referenced in the expression
pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet<usize>) {
pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet<usize>) -> Result<()> {
match expr {
Expr::Alias(expr, _) => expr_to_column_indices(expr, accum),
Expr::Column(i) => {
accum.insert(*i);
Ok(())
}
Expr::Literal(_) => {
// not needed
Ok(())
}
Expr::Literal(_) => { /* not needed */ }
Expr::Not(e) => expr_to_column_indices(e, accum),
Expr::IsNull(e) => expr_to_column_indices(e, accum),
Expr::IsNotNull(e) => expr_to_column_indices(e, accum),
Expr::BinaryExpr { left, right, .. } => {
expr_to_column_indices(left, accum);
expr_to_column_indices(right, accum);
expr_to_column_indices(left, accum)?;
expr_to_column_indices(right, accum)?;
Ok(())
}
Expr::Cast { expr, .. } => expr_to_column_indices(expr, accum),
Expr::Sort { expr, .. } => expr_to_column_indices(expr, accum),
Expr::AggregateFunction { args, .. } => exprlist_to_column_indices(args, accum),
Expr::ScalarFunction { args, .. } => exprlist_to_column_indices(args, accum),
Expr::Wildcard => Err(ExecutionError::General(
"Wildcard expressions are not valid in a logical query plan".to_owned(),
)),
}
}

/// Create field meta-data from an expression, for use in a result set schema
pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
match e {
Expr::Alias(expr, name) => {
Ok(Field::new(name, expr.get_type(input_schema), true))
Ok(Field::new(name, expr.get_type(input_schema)?, true))
}
Expr::Column(i) => {
let input_schema_field_count = input_schema.fields().len();
@@ -89,8 +103,8 @@ pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
ref right,
..
} => {
let left_type = left.get_type(input_schema);
let right_type = right.get_type(input_schema);
let left_type = left.get_type(input_schema)?;
let right_type = right.get_type(input_schema)?;
Ok(Field::new(
"binary_expr",
get_supertype(&left_type, &right_type).unwrap(),
@@ -235,23 +249,24 @@ mod tests {
use std::sync::Arc;

#[test]
fn test_collect_expr() {
fn test_collect_expr() -> Result<()> {
let mut accum: HashSet<usize> = HashSet::new();
expr_to_column_indices(
&Expr::Cast {
expr: Arc::new(Expr::Column(3)),
data_type: DataType::Float64,
},
&mut accum,
);
)?;
expr_to_column_indices(
&Expr::Cast {
expr: Arc::new(Expr::Column(3)),
data_type: DataType::Float64,
},
&mut accum,
);
)?;
assert_eq!(1, accum.len());
assert!(accum.contains(&3));
Ok(())
}
}
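A minimal standalone sketch of why the column-index walkers above now return `Result`: a wildcard resolves to no concrete column index, so the walk surfaces an error rather than silently producing an incomplete set. The `SketchExpr` type and plain `String` error are hypothetical stand-ins for DataFusion's `Expr` and `ExecutionError`.

use std::collections::HashSet;

// Hypothetical stand-in for DataFusion's `Expr`, used only for illustration.
#[derive(Debug)]
enum SketchExpr {
    Column(usize),
    Wildcard,
}

// Collect referenced column indexes; reject wildcards, which cannot be
// resolved to concrete indexes in a logical plan.
fn collect_columns(expr: &SketchExpr, accum: &mut HashSet<usize>) -> Result<(), String> {
    match expr {
        SketchExpr::Column(i) => {
            accum.insert(*i);
            Ok(())
        }
        SketchExpr::Wildcard => Err(
            "Wildcard expressions are not valid in a logical query plan".to_owned(),
        ),
    }
}

fn main() {
    let mut accum = HashSet::new();
    collect_columns(&SketchExpr::Column(3), &mut accum).unwrap();
    assert!(accum.contains(&3));
    // The wildcard case is reported as an error instead of being skipped.
    assert!(collect_columns(&SketchExpr::Wildcard, &mut accum).is_err());
}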
33 changes: 20 additions & 13 deletions rust/datafusion/src/sql/planner.rs
@@ -282,9 +282,7 @@ impl<S: SchemaProvider> SqlToRel<S> {
}
}

ASTNode::SQLWildcard => {
Err(ExecutionError::NotImplemented("SQL wildcard operator is not supported in projection - please use explicit column names".to_string()))
}
ASTNode::SQLWildcard => Ok(Expr::Wildcard),

ASTNode::SQLCast {
ref expr,
@@ -302,17 +300,17 @@ impl<S: SchemaProvider> SqlToRel<S> {
Ok(Expr::IsNotNull(Arc::new(self.sql_to_rex(expr, schema)?)))
}

ASTNode::SQLUnary{
ASTNode::SQLUnary {
ref operator,
ref expr,
} => {
match *operator {
SQLOperator::Not => Ok(Expr::Not(Arc::new(self.sql_to_rex(expr, schema)?))),
_ => Err(ExecutionError::InternalError(format!(
"SQL binary operator cannot be interpreted as a unary operator"
))),
} => match *operator {
SQLOperator::Not => {
Ok(Expr::Not(Arc::new(self.sql_to_rex(expr, schema)?)))
}
}
_ => Err(ExecutionError::InternalError(format!(
"SQL binary operator cannot be interpreted as a unary operator"
))),
},

ASTNode::SQLBinaryExpr {
ref left,
@@ -365,7 +363,7 @@ impl<S: SchemaProvider> SqlToRel<S> {

// return type is same as the argument type for these aggregate
// functions
let return_type = rex_args[0].get_type(schema).clone();
let return_type = rex_args[0].get_type(schema)?.clone();

Ok(Expr::AggregateFunction {
name: id.clone(),
@@ -382,7 +380,7 @@ impl<S: SchemaProvider> SqlToRel<S> {
}
ASTNode::SQLWildcard => {
Ok(Expr::Literal(ScalarValue::UInt8(1)))
},
}
_ => self.sql_to_rex(a, schema),
})
.collect::<Result<Vec<Expr>>>()?;
@@ -570,6 +568,15 @@ mod tests {
);
}

#[test]
fn test_wildcard() {
quick_test(
"SELECT * from person",
"Projection: #0, #1, #2, #3, #4, #5, #6\
\n TableScan: person projection=None",
);
}

#[test]
fn select_count_one() {
let sql = "SELECT COUNT(1) FROM person";