diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index d2b30e77f871f..07ec321f8cdaa 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -19,7 +19,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; -use datafusion_common::{DFField, DFSchema, DataFusionError, Result}; +use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::{ col, expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}, @@ -94,7 +94,10 @@ fn optimize( schema, alias, }) => { - let arrays = to_arrays(expr, input, &mut expr_set)?; + let input_schema = Arc::clone(input.schema()); + let all_schemas: Vec = + plan.all_schemas().into_iter().cloned().collect(); + let arrays = to_arrays(expr, input_schema, all_schemas, &mut expr_set)?; let (mut new_expr, new_input) = rewrite_expr( &[expr], @@ -112,22 +115,18 @@ fn optimize( )?)) } LogicalPlan::Filter(Filter { predicate, input }) => { - let schema = plan.schema().as_ref().clone(); - let data_type = if let Ok(data_type) = predicate.get_type(&schema) { - data_type - } else { - // predicate type could not be resolved in schema, fall back to all schemas - let schemas = plan.all_schemas(); - let all_schema = - schemas.into_iter().fold(DFSchema::empty(), |mut lhs, rhs| { - lhs.merge(rhs); - lhs - }); - predicate.get_type(&all_schema)? - }; + let input_schema = Arc::clone(input.schema()); + let all_schemas: Vec = + plan.all_schemas().into_iter().cloned().collect(); let mut id_array = vec![]; - expr_to_identifier(predicate, &mut expr_set, &mut id_array, data_type)?; + expr_to_identifier( + predicate, + &mut expr_set, + &mut id_array, + input_schema, + all_schemas, + )?; let (mut new_expr, new_input) = rewrite_expr( &[&[predicate.clone()]], @@ -153,7 +152,11 @@ fn optimize( window_expr, schema, }) => { - let arrays = to_arrays(window_expr, input, &mut expr_set)?; + let input_schema = Arc::clone(input.schema()); + let all_schemas: Vec = + plan.all_schemas().into_iter().cloned().collect(); + let arrays = + to_arrays(window_expr, input_schema, all_schemas, &mut expr_set)?; let (mut new_expr, new_input) = rewrite_expr( &[window_expr], @@ -175,8 +178,17 @@ fn optimize( input, schema, }) => { - let group_arrays = to_arrays(group_expr, input, &mut expr_set)?; - let aggr_arrays = to_arrays(aggr_expr, input, &mut expr_set)?; + let input_schema = Arc::clone(input.schema()); + let all_schemas: Vec = + plan.all_schemas().into_iter().cloned().collect(); + let group_arrays = to_arrays( + group_expr, + Arc::clone(&input_schema), + all_schemas.clone(), + &mut expr_set, + )?; + let aggr_arrays = + to_arrays(aggr_expr, input_schema, all_schemas, &mut expr_set)?; let (mut new_expr, new_input) = rewrite_expr( &[group_expr, aggr_expr], @@ -197,7 +209,10 @@ fn optimize( )?)) } LogicalPlan::Sort(Sort { expr, input, fetch }) => { - let arrays = to_arrays(expr, input, &mut expr_set)?; + let input_schema = Arc::clone(input.schema()); + let all_schemas: Vec = + plan.all_schemas().into_iter().cloned().collect(); + let arrays = to_arrays(expr, input_schema, all_schemas, &mut expr_set)?; let (mut new_expr, new_input) = rewrite_expr( &[expr], @@ -255,14 +270,20 @@ fn pop_expr(new_expr: &mut Vec>) -> Result> { fn to_arrays( expr: &[Expr], - input: &LogicalPlan, + input_schema: DFSchemaRef, + all_schemas: Vec, expr_set: &mut ExprSet, ) -> Result>> { expr.iter() .map(|e| { - let data_type = e.get_type(input.schema())?; let mut id_array = vec![]; - expr_to_identifier(e, expr_set, &mut id_array, data_type)?; + expr_to_identifier( + e, + expr_set, + &mut id_array, + Arc::clone(&input_schema), + all_schemas.clone(), + )?; Ok(id_array) }) @@ -370,7 +391,15 @@ struct ExprIdentifierVisitor<'a> { expr_set: &'a mut ExprSet, /// series number (usize) and identifier. id_array: &'a mut Vec<(usize, Identifier)>, - data_type: DataType, + /// input schema for the node that we're optimizing, so we can determine the correct datatype + /// for each subexpression + input_schema: DFSchemaRef, + /// all schemas in the logical plan, as a fall back if we cannot resolve an expression type + /// from the input schema alone + // This fallback should never be necessary as the expression datatype should always be + // resolvable from the input schema of the node that's being optimized. + // todo: This can likely be removed if we are sure it's safe to do so. + all_schemas: Vec, // inner states visit_stack: Vec, @@ -448,7 +477,25 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> { self.id_array[idx] = (self.series_number, desc.clone()); self.visit_stack.push(VisitRecord::ExprItem(desc.clone())); - let data_type = self.data_type.clone(); + + let data_type = if let Ok(data_type) = expr.get_type(&self.input_schema) { + data_type + } else { + // Expression type could not be resolved in schema, fall back to all schemas. + // + // This fallback should never be necessary as the expression datatype should always be + // resolvable from the input schema of the node that's being optimized. + // todo: This else-branch can likely be removed if we are sure it's safe to do so. + let merged_schema = + self.all_schemas + .iter() + .fold(DFSchema::empty(), |mut lhs, rhs| { + lhs.merge(rhs); + lhs + }); + expr.get_type(&merged_schema)? + }; + self.expr_set .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) @@ -462,12 +509,14 @@ fn expr_to_identifier( expr: &Expr, expr_set: &mut ExprSet, id_array: &mut Vec<(usize, Identifier)>, - data_type: DataType, + input_schema: DFSchemaRef, + all_schemas: Vec, ) -> Result<()> { expr.accept(ExprIdentifierVisitor { expr_set, id_array, - data_type, + input_schema, + all_schemas, visit_stack: vec![], node_count: 0, series_number: 0, @@ -577,7 +626,8 @@ fn replace_common_expr( mod test { use super::*; use crate::test::*; - use datafusion_expr::logical_plan::JoinType; + use arrow::datatypes::{Field, Schema}; + use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ avg, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, sum, Operator, @@ -597,7 +647,7 @@ mod test { fn id_array_visitor() -> Result<()> { let expr = binary_expr( binary_expr( - sum(binary_expr(col("a"), Operator::Plus, lit("1"))), + sum(binary_expr(col("a"), Operator::Plus, lit(1))), Operator::Minus, avg(col("c")), ), @@ -605,14 +655,28 @@ mod test { lit(2), ); + let schema = Arc::new(DFSchema::new_with_metadata( + vec![ + DFField::new(None, "a", DataType::Int64, false), + DFField::new(None, "c", DataType::Int64, false), + ], + Default::default(), + )?); + let mut id_array = vec![]; - expr_to_identifier(&expr, &mut HashMap::new(), &mut id_array, DataType::Int64)?; + expr_to_identifier( + &expr, + &mut HashMap::new(), + &mut id_array, + Arc::clone(&schema), + vec![schema], + )?; let expected = vec![ - (9, "SUM(a + Utf8(\"1\")) - AVG(c) * Int32(2)Int32(2)SUM(a + Utf8(\"1\")) - AVG(c)AVG(c)cSUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"), - (7, "SUM(a + Utf8(\"1\")) - AVG(c)AVG(c)cSUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"), - (4, "SUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"), - (3, "a + Utf8(\"1\")Utf8(\"1\")a"), + (9, "SUM(a + Int32(1)) - AVG(c) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), + (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), + (4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"), + (3, "a + Int32(1)Int32(1)a"), (1, ""), (2, ""), (6, "AVG(c)c"), @@ -796,4 +860,55 @@ mod test { assert!(field_set.insert(field.qualified_name())); } } + + #[test] + fn eliminated_subexpr_datatype() { + use datafusion_expr::cast; + + let schema = Schema::new(vec![ + Field::new("a", DataType::UInt64, false), + Field::new("b", DataType::UInt64, false), + Field::new("c", DataType::UInt64, false), + ]); + + let plan = table_scan(Some("table"), &schema, None) + .unwrap() + .filter( + cast(col("a"), DataType::Int64) + .lt(lit(1_i64)) + .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))), + ) + .unwrap() + .build() + .unwrap(); + let rule = CommonSubexprEliminate {}; + let optimized_plan = rule.optimize(&plan, &mut OptimizerConfig::new()).unwrap(); + + let schema = optimized_plan.schema(); + let fields_with_datatypes: Vec<_> = schema + .fields() + .iter() + .map(|field| (field.name(), field.data_type())) + .collect(); + let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}"); + let expected = r###"[ + ( + "CAST(table.a AS Int64)table.a", + Int64, + ), + ( + "a", + UInt64, + ), + ( + "b", + UInt64, + ), + ( + "c", + UInt64, + ), +]"###; + assert_eq!(expected, formatted_fields_with_datatype); + } }