diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs
index 48fbae96ef7..fd23ff15dec 100644
--- a/rust/datafusion/src/logicalplan.rs
+++ b/rust/datafusion/src/logicalplan.rs
@@ -232,37 +232,42 @@ pub enum Expr {
         /// The `DataType` the expression will yield
         return_type: DataType,
     },
+    /// Wildcard
+    Wildcard,
 }
 
 impl Expr {
     /// Find the `DataType` for the expression
-    pub fn get_type(&self, schema: &Schema) -> DataType {
+    pub fn get_type(&self, schema: &Schema) -> Result<DataType> {
         match self {
             Expr::Alias(expr, _) => expr.get_type(schema),
-            Expr::Column(n) => schema.field(*n).data_type().clone(),
-            Expr::Literal(l) => l.get_datatype(),
-            Expr::Cast { data_type, .. } => data_type.clone(),
-            Expr::ScalarFunction { return_type, .. } => return_type.clone(),
-            Expr::AggregateFunction { return_type, .. } => return_type.clone(),
-            Expr::Not(_) => DataType::Boolean,
-            Expr::IsNull(_) => DataType::Boolean,
-            Expr::IsNotNull(_) => DataType::Boolean,
+            Expr::Column(n) => Ok(schema.field(*n).data_type().clone()),
+            Expr::Literal(l) => Ok(l.get_datatype()),
+            Expr::Cast { data_type, .. } => Ok(data_type.clone()),
+            Expr::ScalarFunction { return_type, .. } => Ok(return_type.clone()),
+            Expr::AggregateFunction { return_type, .. } => Ok(return_type.clone()),
+            Expr::Not(_) => Ok(DataType::Boolean),
+            Expr::IsNull(_) => Ok(DataType::Boolean),
+            Expr::IsNotNull(_) => Ok(DataType::Boolean),
             Expr::BinaryExpr {
                 ref left,
                 ref right,
                 ref op,
             } => match op {
-                Operator::Eq | Operator::NotEq => DataType::Boolean,
-                Operator::Lt | Operator::LtEq => DataType::Boolean,
-                Operator::Gt | Operator::GtEq => DataType::Boolean,
-                Operator::And | Operator::Or => DataType::Boolean,
+                Operator::Eq | Operator::NotEq => Ok(DataType::Boolean),
+                Operator::Lt | Operator::LtEq => Ok(DataType::Boolean),
+                Operator::Gt | Operator::GtEq => Ok(DataType::Boolean),
+                Operator::And | Operator::Or => Ok(DataType::Boolean),
                 _ => {
-                    let left_type = left.get_type(schema);
-                    let right_type = right.get_type(schema);
-                    utils::get_supertype(&left_type, &right_type).unwrap()
+                    let left_type = left.get_type(schema)?;
+                    let right_type = right.get_type(schema)?;
+                    utils::get_supertype(&left_type, &right_type)
                 }
             },
             Expr::Sort { ref expr, .. } => expr.get_type(schema),
+            Expr::Wildcard => Err(ExecutionError::General(
+                "Wildcard expressions are not valid in a logical query plan".to_owned(),
+            )),
         }
     }
@@ -270,7 +275,7 @@ impl Expr {
     ///
     /// Will `Err` if the type cast cannot be performed.
     pub fn cast_to(&self, cast_to_type: &DataType, schema: &Schema) -> Result<Expr> {
-        let this_type = self.get_type(schema);
+        let this_type = self.get_type(schema)?;
         if this_type == *cast_to_type {
             Ok(self.clone())
         } else if can_coerce_from(cast_to_type, &this_type) {
@@ -414,6 +419,7 @@ impl fmt::Debug for Expr {
 
                 write!(f, ")")
             }
+            Expr::Wildcard => write!(f, "*"),
         }
     }
 }
@@ -698,12 +704,27 @@ impl LogicalPlanBuilder {
     /// Apply a projection
     pub fn project(&self, expr: &Vec<Expr>) -> Result<Self> {
         let input_schema = self.plan.schema();
+        let projected_expr = if expr.contains(&Expr::Wildcard) {
+            let mut expr_vec = vec![];
+            (0..expr.len()).for_each(|i| match &expr[i] {
+                Expr::Wildcard => {
+                    (0..input_schema.fields().len())
+                        .for_each(|i| expr_vec.push(col(i).clone()));
+                }
+                _ => expr_vec.push(expr[i].clone()),
+            });
+            expr_vec
+        } else {
+            expr.clone()
+        };
 
-        let schema =
-            Schema::new(utils::exprlist_to_fields(&expr, input_schema.as_ref())?);
+        let schema = Schema::new(utils::exprlist_to_fields(
+            &projected_expr,
+            input_schema.as_ref(),
+        )?);
 
         Ok(Self::from(&LogicalPlan::Projection {
-            expr: expr.clone(),
+            expr: projected_expr,
             input: Arc::new(self.plan.clone()),
             schema: Arc::new(schema),
         }))
diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs
index 5a541c9d4de..a05f8656d02 100644
--- a/rust/datafusion/src/optimizer/projection_push_down.rs
+++ b/rust/datafusion/src/optimizer/projection_push_down.rs
@@ -58,7 +58,7 @@ impl ProjectionPushDown {
                 schema,
             } => {
                 // collect all columns referenced by projection expressions
-                utils::exprlist_to_column_indices(&expr, accum);
+                utils::exprlist_to_column_indices(&expr, accum)?;
 
                 // push projection down
                 let input = self.optimize_plan(&input, accum, mapping)?;
@@ -74,7 +74,7 @@
             }
             LogicalPlan::Selection { expr, input } => {
                 // collect all columns referenced by filter expression
-                utils::expr_to_column_indices(expr, accum);
+                utils::expr_to_column_indices(expr, accum)?;
 
                 // push projection down
                 let input = self.optimize_plan(&input, accum, mapping)?;
@@ -94,8 +94,8 @@
                 schema,
             } => {
                 // collect all columns referenced by grouping and aggregate expressions
-                utils::exprlist_to_column_indices(&group_expr, accum);
-                utils::exprlist_to_column_indices(&aggr_expr, accum);
+                utils::exprlist_to_column_indices(&group_expr, accum)?;
+                utils::exprlist_to_column_indices(&aggr_expr, accum)?;
 
                 // push projection down
                 let input = self.optimize_plan(&input, accum, mapping)?;
@@ -117,7 +117,7 @@
                 schema,
             } => {
                 // collect all columns referenced by sort expressions
-                utils::exprlist_to_column_indices(&expr, accum);
+                utils::exprlist_to_column_indices(&expr, accum)?;
 
                 // push projection down
                 let input = self.optimize_plan(&input, accum, mapping)?;
@@ -271,6 +271,9 @@ impl ProjectionPushDown {
                 args: self.rewrite_exprs(args, mapping)?,
                 return_type: return_type.clone(),
             }),
+            Expr::Wildcard => Err(ExecutionError::General(
+                "Wildcard expressions are not valid in a logical query plan".to_owned(),
+            )),
         }
     }
 
diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs
index bfd63b4d5e1..e93d01e640d 100644
--- a/rust/datafusion/src/optimizer/type_coercion.rs
+++ b/rust/datafusion/src/optimizer/type_coercion.rs
@@ -96,8 +96,8 @@ fn rewrite_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
         Expr::BinaryExpr { left, op, right } => {
             let left = rewrite_expr(left, schema)?;
             let right = rewrite_expr(right, schema)?;
-            let left_type = left.get_type(schema);
-            let right_type = right.get_type(schema);
+            let left_type = left.get_type(schema)?;
+            let right_type = right.get_type(schema)?;
             if left_type == right_type {
                 Ok(Expr::BinaryExpr {
                     left: Arc::new(left),
diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs
index b4d9889db9d..1755ac61198 100644
--- a/rust/datafusion/src/optimizer/utils.rs
+++ b/rust/datafusion/src/optimizer/utils.rs
@@ -26,30 +26,44 @@ use crate::logicalplan::Expr;
 
 /// Recursively walk a list of expression trees, collecting the unique set of column
 /// indexes referenced in the expression
-pub fn exprlist_to_column_indices(expr: &Vec<Expr>, accum: &mut HashSet<usize>) {
-    expr.iter().for_each(|e| expr_to_column_indices(e, accum));
+pub fn exprlist_to_column_indices(
+    expr: &Vec<Expr>,
+    accum: &mut HashSet<usize>,
+) -> Result<()> {
+    for e in expr {
+        expr_to_column_indices(e, accum)?;
+    }
+    Ok(())
 }
 
 /// Recursively walk an expression tree, collecting the unique set of column indexes
 /// referenced in the expression
-pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet<usize>) {
+pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet<usize>) -> Result<()> {
     match expr {
         Expr::Alias(expr, _) => expr_to_column_indices(expr, accum),
         Expr::Column(i) => {
             accum.insert(*i);
+            Ok(())
+        }
+        Expr::Literal(_) => {
+            // not needed
+            Ok(())
         }
-        Expr::Literal(_) => { /* not needed */ }
         Expr::Not(e) => expr_to_column_indices(e, accum),
         Expr::IsNull(e) => expr_to_column_indices(e, accum),
         Expr::IsNotNull(e) => expr_to_column_indices(e, accum),
         Expr::BinaryExpr { left, right, .. } => {
-            expr_to_column_indices(left, accum);
-            expr_to_column_indices(right, accum);
+            expr_to_column_indices(left, accum)?;
+            expr_to_column_indices(right, accum)?;
+            Ok(())
         }
         Expr::Cast { expr, .. } => expr_to_column_indices(expr, accum),
         Expr::Sort { expr, .. } => expr_to_column_indices(expr, accum),
         Expr::AggregateFunction { args, .. } => exprlist_to_column_indices(args, accum),
         Expr::ScalarFunction { args, .. } => exprlist_to_column_indices(args, accum),
+        Expr::Wildcard => Err(ExecutionError::General(
+            "Wildcard expressions are not valid in a logical query plan".to_owned(),
+        )),
     }
 }
@@ -57,7 +71,7 @@ pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet<usize>) {
 pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
     match e {
         Expr::Alias(expr, name) => {
-            Ok(Field::new(name, expr.get_type(input_schema), true))
+            Ok(Field::new(name, expr.get_type(input_schema)?, true))
         }
         Expr::Column(i) => {
             let input_schema_field_count = input_schema.fields().len();
@@ -89,8 +103,8 @@ pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
             ref right,
             ..
         } => {
-            let left_type = left.get_type(input_schema);
-            let right_type = right.get_type(input_schema);
+            let left_type = left.get_type(input_schema)?;
+            let right_type = right.get_type(input_schema)?;
             Ok(Field::new(
                 "binary_expr",
                 get_supertype(&left_type, &right_type).unwrap(),
@@ -235,7 +249,7 @@ mod tests {
     use std::sync::Arc;
 
     #[test]
-    fn test_collect_expr() {
+    fn test_collect_expr() -> Result<()> {
         let mut accum: HashSet<usize> = HashSet::new();
         expr_to_column_indices(
             &Expr::Cast {
                 expr: Arc::new(Expr::Column(3)),
                 data_type: DataType::Float64,
             },
             &mut accum,
-        );
+        )?;
         expr_to_column_indices(
             &Expr::Cast {
                 expr: Arc::new(Expr::Column(3)),
                 data_type: DataType::Float64,
             },
             &mut accum,
-        );
+        )?;
         assert_eq!(1, accum.len());
         assert!(accum.contains(&3));
+        Ok(())
     }
 }
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index 6abebf91419..98f4f774e9d 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -282,9 +282,7 @@ impl SqlToRel {
                 }
             }
 
-            ASTNode::SQLWildcard => {
-                Err(ExecutionError::NotImplemented("SQL wildcard operator is not supported in projection - please use explicit column names".to_string()))
-            }
+            ASTNode::SQLWildcard => Ok(Expr::Wildcard),
 
             ASTNode::SQLCast {
                 ref expr,
@@ -302,17 +300,17 @@
                 Ok(Expr::IsNotNull(Arc::new(self.sql_to_rex(expr, schema)?)))
             }
 
-            ASTNode::SQLUnary{
+            ASTNode::SQLUnary {
                 ref operator,
                 ref expr,
-            } => {
-                match *operator {
-                    SQLOperator::Not => Ok(Expr::Not(Arc::new(self.sql_to_rex(expr, schema)?))),
-                    _ => Err(ExecutionError::InternalError(format!(
-                        "SQL binary operator cannot be interpreted as a unary operator"
-                    ))),
+            } => match *operator {
+                SQLOperator::Not => {
+                    Ok(Expr::Not(Arc::new(self.sql_to_rex(expr, schema)?)))
                 }
-            }
+                _ => Err(ExecutionError::InternalError(format!(
+                    "SQL binary operator cannot be interpreted as a unary operator"
+                ))),
+            },
 
             ASTNode::SQLBinaryExpr {
                 ref left,
@@ -365,7 +363,7 @@
 
                     // return type is same as the argument type for these aggregate
                     // functions
-                    let return_type = rex_args[0].get_type(schema).clone();
+                    let return_type = rex_args[0].get_type(schema)?.clone();
 
                     Ok(Expr::AggregateFunction {
                         name: id.clone(),
@@ -382,7 +380,7 @@
                             }
                             ASTNode::SQLWildcard => {
                                 Ok(Expr::Literal(ScalarValue::UInt8(1)))
-                            },
+                            }
                             _ => self.sql_to_rex(a, schema),
                         })
                         .collect::<Result<Vec<Expr>>>()?;
@@ -570,6 +568,15 @@
         );
     }
 
+    #[test]
+    fn test_wildcard() {
+        quick_test(
+            "SELECT * from person",
+            "Projection: #0, #1, #2, #3, #4, #5, #6\
+             \n  TableScan: person projection=None",
+        );
+    }
+
     #[test]
     fn select_count_one() {
         let sql = "SELECT COUNT(1) FROM person";