diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index ff0ccf835249d..2a805a5fc0e8b 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -110,10 +110,10 @@ use datafusion_expr::{TableSource, TableType}; use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists; use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn; use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; -use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions; use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin; use datafusion_optimizer::type_coercion::TypeCoercion; +use datafusion_optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison; use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, SqlToRel}, @@ -1466,9 +1466,9 @@ impl SessionState { } let mut rules: Vec> = vec![ - Arc::new(PreCastLitInComparisonExpressions::new()), Arc::new(TypeCoercion::new()), Arc::new(SimplifyExpressions::new()), + Arc::new(UnwrapCastInComparison::new()), Arc::new(DecorrelateWhereExists::new()), Arc::new(DecorrelateWhereIn::new()), Arc::new(ScalarSubqueryToJoin::new()), diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index fe51aedc8c954..7d09d94834b13 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -767,8 +767,6 @@ async fn test_physical_plan_display_indent_multi_children() { #[tokio::test] #[cfg_attr(tarpaulin, ignore)] async fn csv_explain() { - // TODO: https://github.com/apache/arrow-datafusion/issues/3622 refactor the `PreCastLitInComparisonExpressions` - // This test uses the execute function that create full plan cycle: logical, optimized logical, and physical, // then execute the physical plan and return the final explain results let ctx = SessionContext::new(); @@ -779,23 +777,6 @@ async fn csv_explain() { // Note can't use `assert_batches_eq` as the plan needs to be // normalized for filenames and number of cores - let expected = vec![ - vec![ - "logical_plan", - "Projection: #aggregate_test_100.c1\ - \n Filter: CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)\ - \n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)]" - ], - vec!["physical_plan", - "ProjectionExec: expr=[c1@0 as c1]\ - \n CoalesceBatchesExec: target_batch_size=4096\ - \n FilterExec: CAST(c2@1 AS Int32) > 10\ - \n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\ - \n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\ - \n" - ]]; - assert_eq!(expected, actual); - let expected = vec![ vec![ "logical_plan", @@ -811,6 +792,7 @@ async fn csv_explain() { \n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\ \n" ]]; + assert_eq!(expected, actual); let sql = "explain SELECT c1 FROM aggregate_test_100 where c2 > 10"; let actual = execute(&ctx, sql).await; diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index bfb5634364d2d..879658c408ad6 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -35,9 +35,9 @@ pub mod subquery_filter_to_join; pub mod type_coercion; pub mod utils; -pub mod pre_cast_lit_in_comparison; pub mod rewrite_disjunctive_predicate; #[cfg(test)] pub mod test; +pub mod unwrap_cast_in_comparison; pub use optimizer::{OptimizerConfig, OptimizerRule}; diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs similarity index 60% rename from datafusion/optimizer/src/pre_cast_lit_in_comparison.rs rename to datafusion/optimizer/src/unwrap_cast_in_comparison.rs index a6d915cf0161e..0d5665f29e427 100644 --- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! Pre-cast literal binary comparison rule can be only used to the binary comparison expr. -//! It can reduce adding the `Expr::Cast` to the expr instead of adding the `Expr::Cast` to literal expr. +//! Unwrap-cast binary comparison rule can be used to the binary/inlist comparison expr now, and other type +//! of expr can be added if needed. +//! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, @@ -28,14 +29,14 @@ use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; -/// The rule can be only used to the numeric binary comparison with literal expr, like below pattern: -/// `left_expr comparison_op literal_expr` or `literal_expr comparison_op right_expr`. -/// The data type of two sides must be signed numeric type now, and will support more data type later. +/// The rule can be used to the numeric binary comparison with literal expr, like below pattern: +/// `cast(left_expr as data_type) comparison_op literal_expr` or `literal_expr comparison_op cast(right_expr as data_type)`. +/// The data type of two sides must be equal, and must be signed numeric type now, and will support more data type later. /// /// If the binary comparison expr match above rules, the optimizer will check if the value of `literal` /// is in within range(min,max) which is the range(min,max) of the data type for `left_expr` or `right_expr`. /// -/// If this true, the literal expr will be casted to the data type of expr on the other side, and the result of +/// If this is true, the literal expr will be casted to the data type of expr on the other side, and the result of /// binary comparison will be `left_expr comparison_op cast(literal_expr, left_data_type)` or /// `cast(literal_expr, right_data_type) comparison_op right_expr`. For better optimization, /// the expr of `cast(literal_expr, target_type)` will be precomputed and converted to the new expr `new_literal_expr` @@ -45,19 +46,19 @@ use datafusion_expr::{ /// This is inspired by the optimizer rule `UnwrapCastInBinaryComparison` of Spark. /// # Example /// -/// `Filter: c1 > INT64(10)` will be optimized to `Filter: c1 > CAST(INT64(10) AS INT32), +/// `Filter: cast(c1 as INT64) > INT64(10)` will be optimized to `Filter: c1 > CAST(INT64(10) AS INT32), /// and continue to be converted to `Filter: c1 > INT32(10)`, if the DataType of c1 is INT32. /// #[derive(Default)] -pub struct PreCastLitInComparisonExpressions {} +pub struct UnwrapCastInComparison {} -impl PreCastLitInComparisonExpressions { +impl UnwrapCastInComparison { pub fn new() -> Self { Self::default() } } -impl OptimizerRule for PreCastLitInComparisonExpressions { +impl OptimizerRule for UnwrapCastInComparison { fn optimize( &self, plan: &LogicalPlan, @@ -67,7 +68,7 @@ impl OptimizerRule for PreCastLitInComparisonExpressions { } fn name(&self) -> &str { - "pre_cast_lit_in_comparison" + "unwrap_cast_in_comparison" } } @@ -80,7 +81,7 @@ fn optimize(plan: &LogicalPlan) -> Result { let schema = plan.schema(); - let mut expr_rewriter = PreCastLitExprRewriter { + let mut expr_rewriter = UnwrapCastExprRewriter { schema: schema.clone(), }; @@ -93,17 +94,20 @@ fn optimize(plan: &LogicalPlan) -> Result { from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice()) } -struct PreCastLitExprRewriter { +struct UnwrapCastExprRewriter { schema: DFSchemaRef, } -impl ExprRewriter for PreCastLitExprRewriter { +impl ExprRewriter for UnwrapCastExprRewriter { fn pre_visit(&mut self, _expr: &Expr) -> Result { Ok(RewriteRecursion::Continue) } fn mutate(&mut self, expr: Expr) -> Result { match &expr { + // For case: + // try_cast/cast(expr as data_type) op literal + // literal op try_cast/cast(expr as data_type) Expr::BinaryExpr { left, op, right } => { let left = left.as_ref().clone(); let right = right.as_ref().clone(); @@ -113,29 +117,48 @@ impl ExprRewriter for PreCastLitExprRewriter { if left_type.is_err() || right_type.is_err() { return Ok(expr.clone()); } + // Because the plan has been done the type coercion, the left and right must be equal let left_type = left_type?; let right_type = right_type?; - if !left_type.eq(&right_type) - && is_support_data_type(&left_type) + if is_support_data_type(&left_type) && is_support_data_type(&right_type) && is_comparison_op(op) { match (&left, &right) { - (Expr::Literal(_), Expr::Literal(_)) => { - // do nothing - } - (Expr::Literal(left_lit_value), _) => { + ( + Expr::Literal(left_lit_value), + Expr::TryCast { expr, .. } | Expr::Cast { expr, .. }, + ) => { + // if the left_lit_value can be casted to the type of expr + // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal + let expr_type = expr.get_type(&self.schema)?; let casted_scalar_value = - try_cast_literal_to_type(left_lit_value, &right_type)?; + try_cast_literal_to_type(left_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { - return Ok(binary_expr(lit(value), *op, right)); + // unwrap the cast/try_cast for the right expr + return Ok(binary_expr( + lit(value), + *op, + expr.as_ref().clone(), + )); } } - (_, Expr::Literal(right_lit_value)) => { + ( + Expr::TryCast { expr, .. } | Expr::Cast { expr, .. }, + Expr::Literal(right_lit_value), + ) => { + // if the right_lit_value can be casted to the type of expr + // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal + let expr_type = expr.get_type(&self.schema)?; let casted_scalar_value = - try_cast_literal_to_type(right_lit_value, &left_type)?; + try_cast_literal_to_type(right_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { - return Ok(binary_expr(left, *op, lit(value))); + // unwrap the cast/try_cast for the left expr + return Ok(binary_expr( + expr.as_ref().clone(), + *op, + lit(value), + )); } } (_, _) => { @@ -146,55 +169,75 @@ impl ExprRewriter for PreCastLitExprRewriter { // return the new binary op Ok(binary_expr(left, *op, right)) } + // For case: + // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) Expr::InList { expr: left_expr, list, negated, } => { - let left = left_expr.as_ref().clone(); - let left_type = left.get_type(&self.schema); - if left_type.is_err() { - // error data type - return Ok(expr); - } - let left_type = left_type?; - if !is_support_data_type(&left_type) { - // not supported data type - return Ok(expr); - } - let right_exprs = list - .iter() - .map(|right| { - let right_type = right.get_type(&self.schema)?; - if !is_support_data_type(&right_type) { - return Err(DataFusionError::Internal(format!( - "The type of list expr {} not support", - &right_type - ))); - } - match right { - Expr::Literal(right_lit_value) => { - let casted_scalar_value = - try_cast_literal_to_type(right_lit_value, &left_type)?; - if let Some(value) = casted_scalar_value { - Ok(lit(value)) - } else { - Err(DataFusionError::Internal(format!( - "Can't cast the list expr {:?} to type {:?}", - right_lit_value, &left_type - ))) + if let Some( + Expr::TryCast { + expr: internal_left_expr, + .. + } + | Expr::Cast { + expr: internal_left_expr, + .. + }, + ) = Some(left_expr.as_ref()) + { + let internal_left = internal_left_expr.as_ref().clone(); + let internal_left_type = internal_left.get_type(&self.schema); + if internal_left_type.is_err() { + // error data type + return Ok(expr); + } + let internal_left_type = internal_left_type?; + if !is_support_data_type(&internal_left_type) { + // not supported data type + return Ok(expr); + } + let right_exprs = list + .iter() + .map(|right| { + let right_type = right.get_type(&self.schema)?; + if !is_support_data_type(&right_type) { + return Err(DataFusionError::Internal(format!( + "The type of list expr {} not support", + &right_type + ))); + } + match right { + Expr::Literal(right_lit_value) => { + // if the right_lit_value can be casted to the type of internal_left_expr + // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal + let casted_scalar_value = + try_cast_literal_to_type(right_lit_value, &internal_left_type)?; + if let Some(value) = casted_scalar_value { + Ok(lit(value)) + } else { + Err(DataFusionError::Internal(format!( + "Can't cast the list expr {:?} to type {:?}", + right_lit_value, &internal_left_type + ))) + } } + other_expr => Err(DataFusionError::Internal(format!( + "Only support literal expr to optimize, but the expr is {:?}", + &other_expr + ))), } - other_expr => Err(DataFusionError::Internal(format!( - "Only support literal expr to optimize, but the expr is {:?}", - &other_expr - ))), + }) + .collect::>>(); + match right_exprs { + Ok(right_exprs) => { + Ok(in_list(internal_left, right_exprs, *negated)) } - }) - .collect::>>(); - match right_exprs { - Ok(right_exprs) => Ok(in_list(left, right_exprs, *negated)), - Err(_) => Ok(expr), + Err(_) => Ok(expr), + } + } else { + Ok(expr) } } // TODO: handle other expr type and dfs visit them @@ -326,23 +369,19 @@ fn try_cast_literal_to_type( #[cfg(test)] mod tests { - use crate::pre_cast_lit_in_comparison::PreCastLitExprRewriter; + use crate::unwrap_cast_in_comparison::UnwrapCastExprRewriter; use arrow::datatypes::DataType; use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue}; use datafusion_expr::expr_rewriter::ExprRewritable; - use datafusion_expr::{col, lit, Expr}; + use datafusion_expr::{cast, col, lit, try_cast, Expr}; use std::collections::HashMap; use std::sync::Arc; #[test] - fn test_not_cast_lit_comparison() { + fn test_not_unwrap_cast_comparison() { let schema = expr_test_schema(); - // INT8(NULL) < INT32(12) - let lit_lt_lit = - lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int32(Some(12)))); - assert_eq!(optimize_test(lit_lt_lit.clone(), &schema), lit_lt_lit); - // INT32(c1) > INT64(c2) - let c1_gt_c2 = col("c1").gt(col("c2")); + // cast(INT32(c1), INT64) > INT64(c2) + let c1_gt_c2 = cast(col("c1"), DataType::Int64).gt(col("c2")); assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2); // INT32(c1) < INT32(16), the type is same @@ -350,110 +389,132 @@ mod tests { assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type - let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(99999999999)))); + let expr_lt = cast(col("c1"), DataType::Int64) + .lt(lit(ScalarValue::Int64(Some(99999999999)))); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); } #[test] - fn test_pre_cast_lit_comparison() { + fn test_unwrap_cast_comparison() { let schema = expr_test_schema(); - // c1 < INT64(16) -> c1 < cast(INT32(16)) + // cast(c1, INT64) < INT64(16) -> INT32(c1) < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16)))); + let expr_lt = + cast(col("c1"), DataType::Int64).lt(lit(ScalarValue::Int64(Some(16)))); + let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16)))); + assert_eq!(optimize_test(expr_lt, &schema), expected); + let expr_lt = + try_cast(col("c1"), DataType::Int64).lt(lit(ScalarValue::Int64(Some(16)))); let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16)))); assert_eq!(optimize_test(expr_lt, &schema), expected); - // INT64(c2) = INT32(16) => INT64(c2) = INT64(16) - let c2_eq_lit = col("c2").eq(lit(ScalarValue::Int32(Some(16)))); + // cast(c2, INT32) = INT32(16) => INT64(c2) = INT64(16) + let c2_eq_lit = + cast(col("c2"), DataType::Int32).eq(lit(ScalarValue::Int32(Some(16)))); let expected = col("c2").eq(lit(ScalarValue::Int64(Some(16)))); assert_eq!(optimize_test(c2_eq_lit, &schema), expected); - // INT32(c1) < INT64(NULL) => INT32(c1) < INT32(NULL) - let c1_lt_lit_null = col("c1").lt(lit(ScalarValue::Int64(None))); + // cast(c1, INT64) < INT64(NULL) => INT32(c1) < INT32(NULL) + let c1_lt_lit_null = + cast(col("c1"), DataType::Int64).lt(lit(ScalarValue::Int64(None))); let expected = col("c1").lt(lit(ScalarValue::Int32(None))); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); + + // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) + let lit_lt_lit = cast(lit(ScalarValue::Int8(None)), DataType::Int32) + .lt(lit(ScalarValue::Int32(Some(12)))); + let expected = lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int8(Some(12)))); + assert_eq!(optimize_test(lit_lt_lit, &schema), expected); } #[test] - fn test_not_cast_with_decimal_lit_comparison() { + fn test_not_unwrap_cast_with_decimal_comparison() { let schema = expr_test_schema(); // integer to decimal: value is out of the bounds of the decimal - // c3 = INT64(100000000000000000) - let expr_eq = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000)))); - let expected = col("c3").eq(lit(ScalarValue::Int64(Some(100000000000000000)))); - assert_eq!(optimize_test(expr_eq, &schema), expected); - // c4 = INT64(1000) will overflow the i128 - let expr_eq = col("c4").eq(lit(ScalarValue::Int64(Some(1000)))); - let expected = col("c4").eq(lit(ScalarValue::Int64(Some(1000)))); - assert_eq!(optimize_test(expr_eq, &schema), expected); + // cast(c3, INT64) = INT64(100000000000000000) + let expr_eq = cast(col("c3"), DataType::Int64) + .eq(lit(ScalarValue::Int64(Some(100000000000000000)))); + assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); + + // cast(c4, INT64) = INT64(1000) will overflow the i128 + let expr_eq = + cast(col("c4"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1000)))); + assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); // decimal to decimal: value will lose the scale when convert to the target data type // c3 = DECIMAL(12340,20,4) - let expr_eq = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); - let expected = col("c3").eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); - assert_eq!(optimize_test(expr_eq, &schema), expected); + let expr_eq = cast(col("c3"), DataType::Decimal128(20, 4)) + .eq(lit(ScalarValue::Decimal128(Some(12340), 20, 4))); + assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); // decimal to integer // c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type - let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); - let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); - assert_eq!(optimize_test(expr_eq, &schema), expected); + let expr_eq = cast(col("c1"), DataType::Decimal128(10, 1)) + .eq(lit(ScalarValue::Decimal128(Some(123), 10, 1))); + assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); + // c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type - let expr_eq = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); - let expected = col("c1").eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); - assert_eq!(optimize_test(expr_eq, &schema), expected); + let expr_eq = cast(col("c1"), DataType::Decimal128(10, 2)) + .eq(lit(ScalarValue::Decimal128(Some(1230), 10, 2))); + assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); } #[test] - fn test_pre_cast_with_decimal_lit_comparison() { + fn test_unwrap_cast_with_decimal_lit_comparison() { let schema = expr_test_schema(); // integer to decimal // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2)); - let expr_lt = col("c3").lt(lit(ScalarValue::Int64(Some(16)))); + let expr_lt = + try_cast(col("c3"), DataType::Int64).lt(lit(ScalarValue::Int64(Some(16)))); let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(1600), 18, 2))); assert_eq!(optimize_test(expr_lt, &schema), expected); // c3 < INT64(NULL) - let c1_lt_lit_null = col("c3").lt(lit(ScalarValue::Int64(None))); + let c1_lt_lit_null = + cast(col("c3"), DataType::Int64).lt(lit(ScalarValue::Int64(None))); let expected = col("c3").lt(lit(ScalarValue::Decimal128(None, 18, 2))); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); // decimal to decimal // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2) - let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 10, 0))); + let expr_lt = cast(col("c3"), DataType::Decimal128(10, 0)) + .lt(lit(ScalarValue::Decimal128(Some(123), 10, 0))); let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(12300), 18, 2))); assert_eq!(optimize_test(expr_lt, &schema), expected); + // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2) - let expr_lt = col("c3").lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3))); + let expr_lt = cast(col("c3"), DataType::Decimal128(10, 3)) + .lt(lit(ScalarValue::Decimal128(Some(1230), 10, 3))); let expected = col("c3").lt(lit(ScalarValue::Decimal128(Some(123), 18, 2))); assert_eq!(optimize_test(expr_lt, &schema), expected); // decimal to integer // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123) - let expr_lt = col("c1").lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2))); + let expr_lt = cast(col("c1"), DataType::Decimal128(10, 2)) + .lt(lit(ScalarValue::Decimal128(Some(12300), 10, 2))); let expected = col("c1").lt(lit(ScalarValue::Int32(Some(123)))); assert_eq!(optimize_test(expr_lt, &schema), expected); } #[test] - fn test_not_list_cast_lit_comparison() { + fn test_not_unwrap_list_cast_lit_comparison() { let schema = expr_test_schema(); - // left type is not supported + // internal left type is not supported // FLOAT32(C5) in ... - let expr_lt = col("c5").in_list( + let expr_lt = cast(col("c5"), DataType::Int64).in_list( vec![ lit(ScalarValue::Int64(Some(12))), - lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int64(Some(12))), ], false, ); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - // INT32(C1) in (FLOAT32(1.23), INT32(12), INT64(12)) - let expr_lt = col("c1").in_list( + // cast(INT32(C1), Float32) in (FLOAT32(1.23), Float32(12), Float32(12)) + let expr_lt = cast(col("c1"), DataType::Float32).in_list( vec![ - lit(ScalarValue::Int32(Some(12))), - lit(ScalarValue::Int64(Some(12))), + lit(ScalarValue::Float32(Some(12.0))), + lit(ScalarValue::Float32(Some(12.0))), lit(ScalarValue::Float32(Some(1.23))), ], false, @@ -461,7 +522,7 @@ mod tests { assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // INT32(C1) in (INT64(99999999999), INT64(12)) - let expr_lt = col("c1").in_list( + let expr_lt = cast(col("c1"), DataType::Int64).in_list( vec![ lit(ScalarValue::Int32(Some(12))), lit(ScalarValue::Int64(Some(99999999999))), @@ -471,10 +532,10 @@ mod tests { assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3)) - let expr_lt = col("c3").in_list( + let expr_lt = cast(col("c3"), DataType::Decimal128(12, 3)).in_list( vec![ - lit(ScalarValue::Int32(Some(12))), - lit(ScalarValue::Int64(Some(12))), + lit(ScalarValue::Decimal128(Some(12), 12, 3)), + lit(ScalarValue::Decimal128(Some(12), 12, 3)), lit(ScalarValue::Decimal128(Some(128), 12, 3)), ], false, @@ -483,12 +544,12 @@ mod tests { } #[test] - fn test_pre_list_cast_lit_comparison() { + fn test_unwrap_list_cast_comparison() { let schema = expr_test_schema(); // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) - let expr_lt = col("c1").in_list( + let expr_lt = cast(col("c1"), DataType::Int64).in_list( vec![ - lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int64(Some(12))), lit(ScalarValue::Int64(Some(24))), ], false, @@ -502,9 +563,9 @@ mod tests { ); assert_eq!(optimize_test(expr_lt, &schema), expected); // INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) - let expr_lt = col("c2").in_list( + let expr_lt = cast(col("c2"), DataType::Int32).in_list( vec![ - lit(ScalarValue::Int64(None)), + lit(ScalarValue::Int32(None)), lit(ScalarValue::Int32(Some(14))), ], false, @@ -520,12 +581,13 @@ mod tests { assert_eq!(optimize_test(expr_lt, &schema), expected); // decimal test case - let expr_lt = col("c3").in_list( + // c3 is decimal(18,2) + let expr_lt = cast(col("c3"), DataType::Decimal128(19, 3)).in_list( vec![ - lit(ScalarValue::Int32(Some(12))), - lit(ScalarValue::Int64(Some(24))), - lit(ScalarValue::Decimal128(Some(128), 10, 2)), - lit(ScalarValue::Decimal128(Some(1280), 10, 3)), + lit(ScalarValue::Decimal128(Some(12000), 19, 3)), + lit(ScalarValue::Decimal128(Some(24000), 19, 3)), + lit(ScalarValue::Decimal128(Some(1280), 19, 3)), + lit(ScalarValue::Decimal128(Some(1240), 19, 3)), ], false, ); @@ -534,23 +596,23 @@ mod tests { lit(ScalarValue::Decimal128(Some(1200), 18, 2)), lit(ScalarValue::Decimal128(Some(2400), 18, 2)), lit(ScalarValue::Decimal128(Some(128), 18, 2)), - lit(ScalarValue::Decimal128(Some(128), 18, 2)), + lit(ScalarValue::Decimal128(Some(124), 18, 2)), ], false, ); assert_eq!(optimize_test(expr_lt, &schema), expected); - // INT32(12) IN (.....) - let expr_lt = lit(ScalarValue::Int32(Some(12))).in_list( + // cast(INT32(12), INT64) IN (.....) + let expr_lt = cast(lit(ScalarValue::Int32(Some(12))), DataType::Int64).in_list( vec![ - lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int64(Some(13))), lit(ScalarValue::Int64(Some(12))), ], false, ); let expected = lit(ScalarValue::Int32(Some(12))).in_list( vec![ - lit(ScalarValue::Int32(Some(12))), + lit(ScalarValue::Int32(Some(13))), lit(ScalarValue::Int32(Some(12))), ], false, @@ -563,7 +625,9 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) -> c1 < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16)))).alias("x"); + let expr_lt = cast(col("c1"), DataType::Int64) + .lt(lit(ScalarValue::Int64(Some(16)))) + .alias("x"); let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16)))).alias("x"); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -573,9 +637,9 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32) // the 16 and 32 are within the range of MAX(int32) and MIN(int32), we can cast them to int32 - let expr_lt = col("c1") + let expr_lt = cast(col("c1"), DataType::Int64) .lt(lit(ScalarValue::Int64(Some(16)))) - .or(col("c1").gt(lit(ScalarValue::Int64(Some(32))))); + .or(cast(col("c1"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(32))))); let expected = col("c1") .lt(lit(ScalarValue::Int32(Some(16)))) .or(col("c1").gt(lit(ScalarValue::Int32(Some(32))))); @@ -583,7 +647,7 @@ mod tests { } fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { - let mut expr_rewriter = PreCastLitExprRewriter { + let mut expr_rewriter = UnwrapCastExprRewriter { schema: schema.clone(), }; expr.rewrite(&mut expr_rewriter).unwrap() diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 5f27603167d54..ab5caca99f535 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -27,7 +27,6 @@ use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; use datafusion_optimizer::filter_push_down::FilterPushDown; use datafusion_optimizer::limit_push_down::LimitPushDown; use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions; use datafusion_optimizer::projection_push_down::ProjectionPushDown; use datafusion_optimizer::reduce_cross_join::ReduceCrossJoin; use datafusion_optimizer::reduce_outer_join::ReduceOuterJoin; @@ -37,6 +36,7 @@ use datafusion_optimizer::simplify_expressions::SimplifyExpressions; use datafusion_optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; use datafusion_optimizer::subquery_filter_to_join::SubqueryFilterToJoin; use datafusion_optimizer::type_coercion::TypeCoercion; +use datafusion_optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison; use datafusion_optimizer::{OptimizerConfig, OptimizerRule}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; @@ -109,9 +109,9 @@ fn test_sql(sql: &str) -> Result { // TODO should make align with rules in the context // https://github.com/apache/arrow-datafusion/issues/3524 let rules: Vec> = vec![ - Arc::new(PreCastLitInComparisonExpressions::new()), Arc::new(TypeCoercion::new()), Arc::new(SimplifyExpressions::new()), + Arc::new(UnwrapCastInComparison::new()), Arc::new(DecorrelateWhereExists::new()), Arc::new(DecorrelateWhereIn::new()), Arc::new(ScalarSubqueryToJoin::new()),