From bc3e96749a1c35eb4d10c84de9a2f28e5f144cef Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 25 Aug 2022 08:12:08 -0600 Subject: [PATCH 1/2] Use ExprRewriter in pre_cast_lit_in_comparison.rs --- .../src/pre_cast_lit_in_comparison.rs | 159 +++++++++++------- 1 file changed, 94 insertions(+), 65 deletions(-) diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs index 0c16f7921c328..ea6e1351daf62 100644 --- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs +++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs @@ -20,6 +20,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; +use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use datafusion_expr::utils::from_plan; use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator}; @@ -74,79 +75,93 @@ fn optimize(plan: &LogicalPlan) -> Result { .collect::>>()?; let schema = plan.schema(); + + let mut expr_rewriter = PreCastLitExprRewriter { + schema: schema.clone(), + }; + let new_exprs = plan .expressions() .into_iter() - .map(|expr| visit_expr(expr, schema)) + .map(|expr| expr.rewrite(&mut expr_rewriter)) .collect::>>()?; from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice()) } -// Visit all type of expr, if the current has child expr, the child expr needed to visit first. -fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result { - // traverse the expr by dfs - match &expr { - Expr::BinaryExpr { left, op, right } => { - // dfs visit the left and right expr - let left = visit_expr(*left.clone(), schema)?; - let right = visit_expr(*right.clone(), schema)?; - let left_type = left.get_type(schema); - let right_type = right.get_type(schema); - // can't get the data type, just return the expr - if left_type.is_err() || right_type.is_err() { - return Ok(expr.clone()); - } - let left_type = left_type.unwrap(); - let right_type = right_type.unwrap(); - if !left_type.eq(&right_type) - && is_support_data_type(&left_type) - && is_support_data_type(&right_type) - && is_comparison_op(op) - { - match (&left, &right) { - (Expr::Literal(_), Expr::Literal(_)) => { - // do nothing - } - (Expr::Literal(left_lit_value), _) - if can_integer_literal_cast_to_type( - left_lit_value, - &right_type, - )? => - { - // cast the left literal to the right type - return Ok(binary_expr( - cast_to_other_scalar_expr(left_lit_value, &right_type)?, - *op, - right, - )); - } - (_, Expr::Literal(right_lit_value)) - if can_integer_literal_cast_to_type( - right_lit_value, - &left_type, - ) - .unwrap() => - { - // cast the right literal to the left type - return Ok(binary_expr( - left, - *op, - cast_to_other_scalar_expr(right_lit_value, &left_type)?, - )); - } - (_, _) => { - // do nothing - } - }; +struct PreCastLitExprRewriter { + schema: DFSchemaRef, +} + +impl ExprRewriter for PreCastLitExprRewriter { + fn pre_visit(&mut self, _expr: &Expr) -> Result { + Ok(RewriteRecursion::Continue) + } + + fn mutate(&mut self, expr: Expr) -> Result { + // traverse the expr by dfs + match &expr { + Expr::BinaryExpr { left, op, right } => { + // dfs visit the left and right expr + let left = self.mutate(*left.clone())?; + let right = self.mutate(*right.clone())?; + let left_type = left.get_type(&self.schema); + let right_type = right.get_type(&self.schema); + // can't get the data type, just return the expr + if left_type.is_err() || right_type.is_err() { + return Ok(expr.clone()); + } + let left_type = left_type?; + let right_type = right_type?; + if !left_type.eq(&right_type) + && is_support_data_type(&left_type) + && is_support_data_type(&right_type) + && is_comparison_op(op) + { + match (&left, &right) { + (Expr::Literal(_), Expr::Literal(_)) => { + // do nothing + } + (Expr::Literal(left_lit_value), _) + if can_integer_literal_cast_to_type( + left_lit_value, + &right_type, + )? => + { + // cast the left literal to the right type + return Ok(binary_expr( + cast_to_other_scalar_expr(left_lit_value, &right_type)?, + *op, + right, + )); + } + (_, Expr::Literal(right_lit_value)) + if can_integer_literal_cast_to_type( + right_lit_value, + &left_type, + ) + .unwrap() => + { + // cast the right literal to the left type + return Ok(binary_expr( + left, + *op, + cast_to_other_scalar_expr(right_lit_value, &left_type)?, + )); + } + (_, _) => { + // do nothing + } + }; + } + // return the new binary op + Ok(binary_expr(left, *op, right)) } - // return the new binary op - Ok(binary_expr(left, *op, right)) + // TODO: optimize in list + // Expr::InList { .. } => {} + // TODO: handle other expr type and dfs visit them + _ => Ok(expr), } - // TODO: optimize in list - // Expr::InList { .. } => {} - // TODO: handle other expr type and dfs visit them - _ => Ok(expr), } } @@ -245,9 +260,10 @@ fn can_integer_literal_cast_to_type( #[cfg(test)] mod tests { - use crate::pre_cast_lit_in_comparison::visit_expr; + use crate::pre_cast_lit_in_comparison::PreCastLitExprRewriter; use arrow::datatypes::DataType; use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue}; + use datafusion_expr::expr_rewriter::ExprRewritable; use datafusion_expr::{col, lit, Expr}; use std::collections::HashMap; use std::sync::Arc; @@ -292,8 +308,21 @@ mod tests { assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); } + #[test] + fn nested() { + let schema = expr_test_schema(); + // c1 < INT64(16) -> c1 < cast(INT32(16)) + // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) + let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16)))).alias("x"); + let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16)))).alias("x"); + assert_eq!(optimize_test(expr_lt, &schema), expected); + } + fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { - visit_expr(expr, schema).unwrap() + let mut expr_rewriter = PreCastLitExprRewriter { + schema: schema.clone(), + }; + expr.rewrite(&mut expr_rewriter).unwrap() } fn expr_test_schema() -> DFSchemaRef { From 6e0f926a23c1179bffdca186b8cfaa2f497acb78 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 25 Aug 2022 08:32:51 -0600 Subject: [PATCH 2/2] remove manual recursion and add a nested test case --- .../src/pre_cast_lit_in_comparison.rs | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs index ea6e1351daf62..68c738ca8739a 100644 --- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs +++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs @@ -102,9 +102,8 @@ impl ExprRewriter for PreCastLitExprRewriter { // traverse the expr by dfs match &expr { Expr::BinaryExpr { left, op, right } => { - // dfs visit the left and right expr - let left = self.mutate(*left.clone())?; - let right = self.mutate(*right.clone())?; + let left = left.as_ref().clone(); + let right = right.as_ref().clone(); let left_type = left.get_type(&self.schema); let right_type = right.get_type(&self.schema); // can't get the data type, just return the expr @@ -309,7 +308,7 @@ mod tests { } #[test] - fn nested() { + fn aliased() { let schema = expr_test_schema(); // c1 < INT64(16) -> c1 < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) @@ -318,6 +317,20 @@ mod tests { assert_eq!(optimize_test(expr_lt, &schema), expected); } + #[test] + fn nested() { + let schema = expr_test_schema(); + // c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32) + // the 16 and 32 are within the range of MAX(int32) and MIN(int32), we can cast them to int32 + let expr_lt = col("c1") + .lt(lit(ScalarValue::Int64(Some(16)))) + .or(col("c1").gt(lit(ScalarValue::Int64(Some(32))))); + let expected = col("c1") + .lt(lit(ScalarValue::Int32(Some(16)))) + .or(col("c1").gt(lit(ScalarValue::Int32(Some(32))))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + } + fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { let mut expr_rewriter = PreCastLitExprRewriter { schema: schema.clone(),