From 05b0b473104f216d824c444e76174b25949f1e0c Mon Sep 17 00:00:00 2001 From: Alexander Spies Date: Wed, 5 Oct 2022 12:57:44 +0200 Subject: [PATCH 1/4] CommonSubexprEliminate: Fix additional col schema --- .../optimizer/src/common_subexpr_eliminate.rs | 92 ++++++++++++++----- 1 file changed, 69 insertions(+), 23 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index d2b30e77f871f..89c034ffca799 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -19,7 +19,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; -use datafusion_common::{DFField, DFSchema, DataFusionError, Result}; +use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::{ col, expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}, @@ -112,22 +112,10 @@ fn optimize( )?)) } LogicalPlan::Filter(Filter { predicate, input }) => { - let schema = plan.schema().as_ref().clone(); - let data_type = if let Ok(data_type) = predicate.get_type(&schema) { - data_type - } else { - // predicate type could not be resolved in schema, fall back to all schemas - let schemas = plan.all_schemas(); - let all_schema = - schemas.into_iter().fold(DFSchema::empty(), |mut lhs, rhs| { - lhs.merge(rhs); - lhs - }); - predicate.get_type(&all_schema)? 
- }; + let schema = Arc::clone(input.schema()); let mut id_array = vec![]; - expr_to_identifier(predicate, &mut expr_set, &mut id_array, data_type)?; + expr_to_identifier(predicate, &mut expr_set, &mut id_array, schema)?; let (mut new_expr, new_input) = rewrite_expr( &[&[predicate.clone()]], @@ -260,9 +248,9 @@ fn to_arrays( ) -> Result<Vec<Vec<(usize, Identifier)>>> { expr.iter() .map(|e| { - let data_type = e.get_type(input.schema())?; + let schema = Arc::clone(input.schema()); let mut id_array = vec![]; - expr_to_identifier(e, expr_set, &mut id_array, data_type)?; + expr_to_identifier(e, expr_set, &mut id_array, schema)?; Ok(id_array) }) @@ -370,7 +358,8 @@ struct ExprIdentifierVisitor<'a> { expr_set: &'a mut ExprSet, /// series number (usize) and identifier. id_array: &'a mut Vec<(usize, Identifier)>, - data_type: DataType, + schema: DFSchemaRef, + //todo: also look in all schemas // inner states visit_stack: Vec<VisitRecord>, @@ -448,7 +437,7 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> { self.id_array[idx] = (self.series_number, desc.clone()); self.visit_stack.push(VisitRecord::ExprItem(desc.clone())); - let data_type = self.data_type.clone(); + let data_type = expr.get_type(&self.schema)?; self.expr_set .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) @@ -462,12 +451,12 @@ fn expr_to_identifier( expr: &Expr, expr_set: &mut ExprSet, id_array: &mut Vec<(usize, Identifier)>, - data_type: DataType, + schema: DFSchemaRef, ) -> Result<()> { expr.accept(ExprIdentifierVisitor { expr_set, id_array, - data_type, + schema, visit_stack: vec![], node_count: 0, series_number: 0, @@ -577,7 +566,8 @@ fn replace_common_expr( mod test { use super::*; use crate::test::*; - use datafusion_expr::logical_plan::JoinType; + use arrow::datatypes::{Field, Schema}; + use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ avg, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, sum, Operator, @@ -605,8 +595,13 @@ mod test { lit(2), ); + let schema = 
Arc::new(DFSchema::new_with_metadata( + vec![DFField::new(None, "a", DataType::Int64, false)], + Default::default(), + )?); + let mut id_array = vec![]; - expr_to_identifier(&expr, &mut HashMap::new(), &mut id_array, DataType::Int64)?; + expr_to_identifier(&expr, &mut HashMap::new(), &mut id_array, schema)?; let expected = vec![ (9, "SUM(a + Utf8(\"1\")) - AVG(c) * Int32(2)Int32(2)SUM(a + Utf8(\"1\")) - AVG(c)AVG(c)cSUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"), @@ -796,4 +791,55 @@ mod test { assert!(field_set.insert(field.qualified_name())); } } + + #[test] + fn eliminated_subexpr_datatype() { + use datafusion_expr::cast; + + let schema = Schema::new(vec![ + Field::new("a", DataType::UInt64, false), + Field::new("b", DataType::UInt64, false), + Field::new("c", DataType::UInt64, false), + ]); + + let plan = table_scan(Some("table"), &schema, None) + .unwrap() + .filter( + cast(col("a"), DataType::Int64) + .lt(lit(1_i64)) + .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))), + ) + .unwrap() + .build() + .unwrap(); + let rule = CommonSubexprEliminate {}; + let optimized_plan = rule.optimize(&plan, &mut OptimizerConfig::new()).unwrap(); + + let schema = optimized_plan.schema(); + let fields_with_datatypes: Vec<_> = schema + .fields() + .iter() + .map(|field| (field.name(), field.data_type())) + .collect(); + let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}"); + let expected = r###"[ + ( + "CAST(table.a AS Int64)table.a", + Int64, + ), + ( + "a", + UInt64, + ), + ( + "b", + UInt64, + ), + ( + "c", + UInt64, + ), +]"###; + assert_eq!(expected, formatted_fields_with_datatype); + } } From f930e25ad40264cebe39ccd69d97594786b65bca Mon Sep 17 00:00:00 2001 From: Alexander Spies Date: Wed, 5 Oct 2022 17:25:17 +0200 Subject: [PATCH 2/4] Use correct types in test id_array_visitor --- .../optimizer/src/common_subexpr_eliminate.rs | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git 
a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 89c034ffca799..7706ec2146bd6 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -358,7 +358,8 @@ struct ExprIdentifierVisitor<'a> { expr_set: &'a mut ExprSet, /// series number (usize) and identifier. id_array: &'a mut Vec<(usize, Identifier)>, - schema: DFSchemaRef, + /// input schema so we can determine the correct datatype for each subexpression + input_schema: DFSchemaRef, //todo: also look in all schemas // inner states @@ -437,7 +438,7 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> { self.id_array[idx] = (self.series_number, desc.clone()); self.visit_stack.push(VisitRecord::ExprItem(desc.clone())); - let data_type = expr.get_type(&self.schema)?; + let data_type = expr.get_type(&self.input_schema)?; self.expr_set .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) @@ -456,7 +457,7 @@ fn expr_to_identifier( expr.accept(ExprIdentifierVisitor { expr_set, id_array, - schema, + input_schema: schema, visit_stack: vec![], node_count: 0, series_number: 0, @@ -587,7 +588,7 @@ mod test { fn id_array_visitor() -> Result<()> { let expr = binary_expr( binary_expr( - sum(binary_expr(col("a"), Operator::Plus, lit("1"))), + sum(binary_expr(col("a"), Operator::Plus, lit(1))), Operator::Minus, avg(col("c")), ), @@ -596,7 +597,10 @@ mod test { ); let schema = Arc::new(DFSchema::new_with_metadata( - vec![DFField::new(None, "a", DataType::Int64, false)], + vec![ + DFField::new(None, "a", DataType::Int64, false), + DFField::new(None, "c", DataType::Int64, false), + ], Default::default(), )?); @@ -604,10 +608,10 @@ mod test { expr_to_identifier(&expr, &mut HashMap::new(), &mut id_array, schema)?; let expected = vec![ - (9, "SUM(a + Utf8(\"1\")) - AVG(c) * Int32(2)Int32(2)SUM(a + Utf8(\"1\")) - AVG(c)AVG(c)cSUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"), - (7, "SUM(a + 
Utf8(\"1\")) - AVG(c)AVG(c)cSUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"), - (4, "SUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"), - (3, "a + Utf8(\"1\")Utf8(\"1\")a"), + (9, "SUM(a + Int32(1)) - AVG(c) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), + (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), + (4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"), + (3, "a + Int32(1)Int32(1)a"), (1, ""), (2, ""), (6, "AVG(c)c"), From f1909cf140206d8349fa86cd97443f1dfa9197bf Mon Sep 17 00:00:00 2001 From: Alexander Spies Date: Wed, 5 Oct 2022 18:19:45 +0200 Subject: [PATCH 3/4] Re-enable fall back schema for datatype resolution Fall back to the merged schema from the whole logical plan if the input schema was not sufficient to resolve the datatype of a sub-expression. This re-enables the fallback logic added in 3860cd3 (#1925). --- .../optimizer/src/common_subexpr_eliminate.rs | 90 +++++++++++++++---- 1 file changed, 74 insertions(+), 16 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 7706ec2146bd6..d336bc21220c9 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -94,7 +94,10 @@ fn optimize( schema, alias, }) => { - let arrays = to_arrays(expr, input, &mut expr_set)?; + let input_schema = Arc::clone(input.schema()); + let all_schemas: Vec<DFSchemaRef> = + plan.all_schemas().into_iter().cloned().collect(); + let arrays = to_arrays(expr, input_schema, all_schemas, &mut expr_set)?; let (mut new_expr, new_input) = rewrite_expr( &[expr], @@ -112,10 +115,18 @@ fn optimize( )?)) } LogicalPlan::Filter(Filter { predicate, input }) => { - let schema = Arc::clone(input.schema()); + let input_schema = Arc::clone(input.schema()); + let all_schemas: Vec<DFSchemaRef> = + plan.all_schemas().into_iter().cloned().collect(); let mut id_array = vec![]; - 
expr_to_identifier(predicate, &mut expr_set, &mut id_array, schema)?; + expr_to_identifier( + predicate, + &mut expr_set, + &mut id_array, + input_schema, + all_schemas, + )?; let (mut new_expr, new_input) = rewrite_expr( &[&[predicate.clone()]], @@ -141,7 +152,11 @@ fn optimize( window_expr, schema, }) => { - let arrays = to_arrays(window_expr, input, &mut expr_set)?; + let input_schema = Arc::clone(input.schema()); + let all_schemas: Vec<DFSchemaRef> = + plan.all_schemas().into_iter().cloned().collect(); + let arrays = + to_arrays(window_expr, input_schema, all_schemas, &mut expr_set)?; let (mut new_expr, new_input) = rewrite_expr( &[window_expr], @@ -163,8 +178,17 @@ fn optimize( input, schema, }) => { - let group_arrays = to_arrays(group_expr, input, &mut expr_set)?; - let aggr_arrays = to_arrays(aggr_expr, input, &mut expr_set)?; + let input_schema = Arc::clone(input.schema()); + let all_schemas: Vec<DFSchemaRef> = + plan.all_schemas().into_iter().cloned().collect(); + let group_arrays = to_arrays( + group_expr, + Arc::clone(&input_schema), + all_schemas.clone(), + &mut expr_set, + )?; + let aggr_arrays = + to_arrays(aggr_expr, input_schema, all_schemas, &mut expr_set)?; let (mut new_expr, new_input) = rewrite_expr( &[group_expr, aggr_expr], @@ -185,7 +209,10 @@ fn optimize( )?)) } LogicalPlan::Sort(Sort { expr, input, fetch }) => { - let arrays = to_arrays(expr, input, &mut expr_set)?; + let input_schema = Arc::clone(input.schema()); + let all_schemas: Vec<DFSchemaRef> = + plan.all_schemas().into_iter().cloned().collect(); + let arrays = to_arrays(expr, input_schema, all_schemas, &mut expr_set)?; let (mut new_expr, new_input) = rewrite_expr( &[expr], @@ -243,14 +270,20 @@ fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) -> Result<Vec<Expr>> { fn to_arrays( expr: &[Expr], - input: &LogicalPlan, + input_schema: DFSchemaRef, + all_schemas: Vec<DFSchemaRef>, expr_set: &mut ExprSet, ) -> Result<Vec<Vec<(usize, Identifier)>>> { expr.iter() .map(|e| { - let schema = Arc::clone(input.schema()); let mut id_array = vec![]; - expr_to_identifier(e, expr_set, &mut id_array, 
schema)?; + expr_to_identifier( + e, + expr_set, + &mut id_array, + Arc::clone(&input_schema), + all_schemas.clone(), + )?; Ok(id_array) }) @@ -358,9 +391,12 @@ struct ExprIdentifierVisitor<'a> { expr_set: &'a mut ExprSet, /// series number (usize) and identifier. id_array: &'a mut Vec<(usize, Identifier)>, - /// input schema so we can determine the correct datatype for each subexpression + /// input schema for the node that we're optimizing, so we can determine the correct datatype + /// for each subexpression input_schema: DFSchemaRef, - //todo: also look in all schemas + /// all schemas in the logical plan, as a fall back if we cannot resolve an expression type + /// from the input schema alone + all_schemas: Vec<DFSchemaRef>, // inner states visit_stack: Vec<VisitRecord>, @@ -438,7 +474,21 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> { self.id_array[idx] = (self.series_number, desc.clone()); self.visit_stack.push(VisitRecord::ExprItem(desc.clone())); - let data_type = expr.get_type(&self.input_schema)?; + + let data_type = if let Ok(data_type) = expr.get_type(&self.input_schema) { + data_type + } else { + // expression type could not be resolved in schema, fall back to all schemas + let merged_schema = + self.all_schemas + .iter() + .fold(DFSchema::empty(), |mut lhs, rhs| { + lhs.merge(rhs); + lhs + }); + expr.get_type(&merged_schema)? 
+ }; + self.expr_set .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) @@ -452,12 +502,14 @@ fn expr_to_identifier( expr: &Expr, expr_set: &mut ExprSet, id_array: &mut Vec<(usize, Identifier)>, - schema: DFSchemaRef, + input_schema: DFSchemaRef, + all_schemas: Vec<DFSchemaRef>, ) -> Result<()> { expr.accept(ExprIdentifierVisitor { expr_set, id_array, - input_schema: schema, + input_schema, + all_schemas, visit_stack: vec![], node_count: 0, series_number: 0, @@ -605,7 +657,13 @@ mod test { )?); let mut id_array = vec![]; - expr_to_identifier(&expr, &mut HashMap::new(), &mut id_array, schema)?; + expr_to_identifier( + &expr, + &mut HashMap::new(), + &mut id_array, + Arc::clone(&schema), + vec![schema], + )?; let expected = vec![ (9, "SUM(a + Int32(1)) - AVG(c) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"), From ea26e7a5e238dfa00627c0329b25279da2936c63 Mon Sep 17 00:00:00 2001 From: Alexander Spies Date: Thu, 6 Oct 2022 11:36:25 +0200 Subject: [PATCH 4/4] Add comment on fall-back logic using all schemas Point out that it can likely be removed. --- datafusion/optimizer/src/common_subexpr_eliminate.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index d336bc21220c9..07ec321f8cdaa 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -396,6 +396,9 @@ struct ExprIdentifierVisitor<'a> { input_schema: DFSchemaRef, /// all schemas in the logical plan, as a fall back if we cannot resolve an expression type /// from the input schema alone + // This fallback should never be necessary as the expression datatype should always be + // resolvable from the input schema of the node that's being optimized. + // todo: This can likely be removed if we are sure it's safe to do so. 
all_schemas: Vec<DFSchemaRef>, // inner states @@ -478,7 +481,11 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> { let data_type = if let Ok(data_type) = expr.get_type(&self.input_schema) { data_type } else { - // expression type could not be resolved in schema, fall back to all schemas + // Expression type could not be resolved in schema, fall back to all schemas. + // + // This fallback should never be necessary as the expression datatype should always be + // resolvable from the input schema of the node that's being optimized. + // todo: This else-branch can likely be removed if we are sure it's safe to do so. let merged_schema = self.all_schemas .iter()