diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 7d09d94834b1..fe51aedc8c95 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -767,6 +767,8 @@ async fn test_physical_plan_display_indent_multi_children() { #[tokio::test] #[cfg_attr(tarpaulin, ignore)] async fn csv_explain() { + // TODO: https://github.com/apache/arrow-datafusion/issues/3622 refactor the `PreCastLitInComparisonExpressions` + // This test uses the execute function that create full plan cycle: logical, optimized logical, and physical, // then execute the physical plan and return the final explain results let ctx = SessionContext::new(); @@ -777,6 +779,23 @@ async fn csv_explain() { // Note can't use `assert_batches_eq` as the plan needs to be // normalized for filenames and number of cores + let expected = vec![ + vec![ + "logical_plan", + "Projection: #aggregate_test_100.c1\ + \n Filter: CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)\ + \n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)]" + ], + vec!["physical_plan", + "ProjectionExec: expr=[c1@0 as c1]\ + \n CoalesceBatchesExec: target_batch_size=4096\ + \n FilterExec: CAST(c2@1 AS Int32) > 10\ + \n RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\ + \n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\ + \n" + ]]; + assert_eq!(expected, actual); + let expected = vec![ vec![ "logical_plan", @@ -792,7 +811,6 @@ async fn csv_explain() { \n CsvExec: files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, limit=None, projection=[c1, c2]\ \n" ]]; - assert_eq!(expected, actual); let sql = "explain SELECT c1 FROM aggregate_test_100 where c2 > 10"; let actual = execute(&ctx, sql).await; diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 879658c408ad..bfb5634364d2 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -35,9 +35,9 @@ pub mod subquery_filter_to_join; pub mod type_coercion; pub mod utils; +pub mod pre_cast_lit_in_comparison; pub mod rewrite_disjunctive_predicate; #[cfg(test)] pub mod test; -pub mod unwrap_cast_in_comparison; pub use optimizer::{OptimizerConfig, OptimizerRule}; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 5ef5cfdd5975..b59fdae59b07 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -25,6 +25,7 @@ use crate::eliminate_limit::EliminateLimit; use crate::filter_null_join_keys::FilterNullJoinKeys; use crate::filter_push_down::FilterPushDown; use crate::limit_push_down::LimitPushDown; +use crate::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions; use crate::projection_push_down::ProjectionPushDown; use crate::reduce_cross_join::ReduceCrossJoin; use crate::reduce_outer_join::ReduceOuterJoin; @@ -34,7 +35,6 @@ use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::subquery_filter_to_join::SubqueryFilterToJoin; use crate::type_coercion::TypeCoercion; -use crate::unwrap_cast_in_comparison::UnwrapCastInComparison; use chrono::{DateTime, Utc}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; @@ -137,9 +137,9 @@ impl Optimizer { /// Create a new optimizer using the recommended list of rules pub fn new(config: &OptimizerConfig) -> Self { let mut rules: Vec> = vec![ + Arc::new(PreCastLitInComparisonExpressions::new()), Arc::new(TypeCoercion::new()), Arc::new(SimplifyExpressions::new()), - Arc::new(UnwrapCastInComparison::new()), Arc::new(DecorrelateWhereExists::new()), Arc::new(DecorrelateWhereIn::new()), Arc::new(ScalarSubqueryToJoin::new()), diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs similarity index 64% rename from datafusion/optimizer/src/unwrap_cast_in_comparison.rs rename to datafusion/optimizer/src/pre_cast_lit_in_comparison.rs index 0f7238d33cd0..382f5bfb2206 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! Unwrap-cast binary comparison rule can be used to the binary/inlist comparison expr now, and other type -//! of expr can be added if needed. -//! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. +//! Pre-cast literal binary comparison rule can be only used to the binary comparison expr. +//! It can reduce adding the `Expr::Cast` to the expr instead of adding the `Expr::Cast` to literal expr. use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, @@ -29,14 +28,14 @@ use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; -/// The rule can be used to the numeric binary comparison with literal expr, like below pattern: -/// `cast(left_expr as data_type) comparison_op literal_expr` or `literal_expr comparison_op cast(right_expr as data_type)`. -/// The data type of two sides must be equal, and must be signed numeric type now, and will support more data type later. +/// The rule can be only used to the numeric binary comparison with literal expr, like below pattern: +/// `left_expr comparison_op literal_expr` or `literal_expr comparison_op right_expr`. +/// The data type of two sides must be signed numeric type now, and will support more data type later. /// /// If the binary comparison expr match above rules, the optimizer will check if the value of `literal` /// is in within range(min,max) which is the range(min,max) of the data type for `left_expr` or `right_expr`. /// -/// If this is true, the literal expr will be casted to the data type of expr on the other side, and the result of +/// If this true, the literal expr will be casted to the data type of expr on the other side, and the result of /// binary comparison will be `left_expr comparison_op cast(literal_expr, left_data_type)` or /// `cast(literal_expr, right_data_type) comparison_op right_expr`. For better optimization, /// the expr of `cast(literal_expr, target_type)` will be precomputed and converted to the new expr `new_literal_expr` @@ -46,19 +45,19 @@ use datafusion_expr::{ /// This is inspired by the optimizer rule `UnwrapCastInBinaryComparison` of Spark. /// # Example /// -/// `Filter: cast(c1 as INT64) > INT64(10)` will be optimized to `Filter: c1 > CAST(INT64(10) AS INT32), +/// `Filter: c1 > INT64(10)` will be optimized to `Filter: c1 > CAST(INT64(10) AS INT32), /// and continue to be converted to `Filter: c1 > INT32(10)`, if the DataType of c1 is INT32. /// #[derive(Default)] -pub struct UnwrapCastInComparison {} +pub struct PreCastLitInComparisonExpressions {} -impl UnwrapCastInComparison { +impl PreCastLitInComparisonExpressions { pub fn new() -> Self { Self::default() } } -impl OptimizerRule for UnwrapCastInComparison { +impl OptimizerRule for PreCastLitInComparisonExpressions { fn optimize( &self, plan: &LogicalPlan, @@ -68,7 +67,7 @@ impl OptimizerRule for UnwrapCastInComparison { } fn name(&self) -> &str { - "unwrap_cast_in_comparison" + "pre_cast_lit_in_comparison" } } @@ -81,7 +80,7 @@ fn optimize(plan: &LogicalPlan) -> Result { let schema = plan.schema(); - let mut expr_rewriter = UnwrapCastExprRewriter { + let mut expr_rewriter = PreCastLitExprRewriter { schema: schema.clone(), }; @@ -94,20 +93,17 @@ fn optimize(plan: &LogicalPlan) -> Result { from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice()) } -struct UnwrapCastExprRewriter { +struct PreCastLitExprRewriter { schema: DFSchemaRef, } -impl ExprRewriter for UnwrapCastExprRewriter { +impl ExprRewriter for PreCastLitExprRewriter { fn pre_visit(&mut self, _expr: &Expr) -> Result { Ok(RewriteRecursion::Continue) } fn mutate(&mut self, expr: Expr) -> Result { match &expr { - // For case: - // try_cast/cast(expr as data_type) op literal - // literal op try_cast/cast(expr as data_type) Expr::BinaryExpr { left, op, right } => { let left = left.as_ref().clone(); let right = right.as_ref().clone(); @@ -117,48 +113,29 @@ impl ExprRewriter for UnwrapCastExprRewriter { if left_type.is_err() || right_type.is_err() { return Ok(expr.clone()); } - // Because the plan has been done the type coercion, the left and right must be equal let left_type = left_type?; let right_type = right_type?; - if is_support_data_type(&left_type) + if !left_type.eq(&right_type) + && is_support_data_type(&left_type) && is_support_data_type(&right_type) && is_comparison_op(op) { match (&left, &right) { - ( - Expr::Literal(left_lit_value), - Expr::TryCast { expr, .. } | Expr::Cast { expr, .. }, - ) => { - // if the left_lit_value can be casted to the type of expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let expr_type = expr.get_type(&self.schema)?; + (Expr::Literal(_), Expr::Literal(_)) => { + // do nothing + } + (Expr::Literal(left_lit_value), _) => { let casted_scalar_value = - try_cast_literal_to_type(left_lit_value, &expr_type)?; + try_cast_literal_to_type(left_lit_value, &right_type)?; if let Some(value) = casted_scalar_value { - // unwrap the cast/try_cast for the right expr - return Ok(binary_expr( - lit(value), - *op, - expr.as_ref().clone(), - )); + return Ok(binary_expr(lit(value), *op, right)); } } - ( - Expr::TryCast { expr, .. } | Expr::Cast { expr, .. }, - Expr::Literal(right_lit_value), - ) => { - // if the right_lit_value can be casted to the type of expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let expr_type = expr.get_type(&self.schema)?; + (_, Expr::Literal(right_lit_value)) => { let casted_scalar_value = - try_cast_literal_to_type(right_lit_value, &expr_type)?; + try_cast_literal_to_type(right_lit_value, &left_type)?; if let Some(value) = casted_scalar_value { - // unwrap the cast/try_cast for the left expr - return Ok(binary_expr( - expr.as_ref().clone(), - *op, - lit(value), - )); + return Ok(binary_expr(left, *op, lit(value))); } } (_, _) => { @@ -169,75 +146,55 @@ impl ExprRewriter for UnwrapCastExprRewriter { // return the new binary op Ok(binary_expr(left, *op, right)) } - // For case: - // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) Expr::InList { expr: left_expr, list, negated, } => { - if let Some( - Expr::TryCast { - expr: internal_left_expr, - .. - } - | Expr::Cast { - expr: internal_left_expr, - .. - }, - ) = Some(left_expr.as_ref()) - { - let internal_left = internal_left_expr.as_ref().clone(); - let internal_left_type = internal_left.get_type(&self.schema); - if internal_left_type.is_err() { - // error data type - return Ok(expr); - } - let internal_left_type = internal_left_type?; - if !is_support_data_type(&internal_left_type) { - // not supported data type - return Ok(expr); - } - let right_exprs = list - .iter() - .map(|right| { - let right_type = right.get_type(&self.schema)?; - if !is_support_data_type(&right_type) { - return Err(DataFusionError::Internal(format!( - "The type of list expr {} not support", - &right_type - ))); - } - match right { - Expr::Literal(right_lit_value) => { - // if the right_lit_value can be casted to the type of internal_left_expr - // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let casted_scalar_value = - try_cast_literal_to_type(right_lit_value, &internal_left_type)?; - if let Some(value) = casted_scalar_value { - Ok(lit(value)) - } else { - Err(DataFusionError::Internal(format!( - "Can't cast the list expr {:?} to type {:?}", - right_lit_value, &internal_left_type - ))) - } + let left = left_expr.as_ref().clone(); + let left_type = left.get_type(&self.schema); + if left_type.is_err() { + // error data type + return Ok(expr); + } + let left_type = left_type?; + if !is_support_data_type(&left_type) { + // not supported data type + return Ok(expr); + } + let right_exprs = list + .iter() + .map(|right| { + let right_type = right.get_type(&self.schema)?; + if !is_support_data_type(&right_type) { + return Err(DataFusionError::Internal(format!( + "The type of list expr {} not support", + &right_type + ))); + } + match right { + Expr::Literal(right_lit_value) => { + let casted_scalar_value = + try_cast_literal_to_type(right_lit_value, &left_type)?; + if let Some(value) = casted_scalar_value { + Ok(lit(value)) + } else { + Err(DataFusionError::Internal(format!( + "Can't cast the list expr {:?} to type {:?}", + right_lit_value, &left_type + ))) } - other_expr => Err(DataFusionError::Internal(format!( - "Only support literal expr to optimize, but the expr is {:?}", - &other_expr - ))), } - }) - .collect::>>(); - match right_exprs { - Ok(right_exprs) => { - Ok(in_list(internal_left, right_exprs, *negated)) + other_expr => Err(DataFusionError::Internal(format!( + "Only support literal expr to optimize, but the expr is {:?}", + &other_expr + ))), } - Err(_) => Ok(expr), - } - } else { - Ok(expr) + }) + .collect::>>(); + match right_exprs { + Ok(right_exprs) => Ok(in_list(left, right_exprs, *negated)), + Err(_) => Ok(expr), } } // TODO: handle other expr type and dfs visit them @@ -369,19 +326,23 @@ fn try_cast_literal_to_type( #[cfg(test)] mod tests { - use crate::unwrap_cast_in_comparison::UnwrapCastExprRewriter; + use crate::pre_cast_lit_in_comparison::PreCastLitExprRewriter; use arrow::datatypes::DataType; use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue}; use datafusion_expr::expr_rewriter::ExprRewritable; - use datafusion_expr::{cast, col, lit, try_cast, Expr}; + use datafusion_expr::{cast, col, lit, Expr}; use std::collections::HashMap; use std::sync::Arc; #[test] - fn test_not_unwrap_cast_comparison() { + fn test_not_cast_lit_comparison() { let schema = expr_test_schema(); - // cast(INT32(c1), INT64) > INT64(c2) - let c1_gt_c2 = cast(col("c1"), DataType::Int64).gt(col("c2")); + // INT8(NULL) < INT32(12) + let lit_lt_lit = + lit(ScalarValue::Int8(None)).lt(lit(ScalarValue::Int32(Some(12)))); + assert_eq!(optimize_test(lit_lt_lit.clone(), &schema), lit_lt_lit); + // INT32(c1) > INT64(c2) + let c1_gt_c2 = col("c1").gt(col("c2")); assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2); // INT32(c1) < INT32(16), the type is same @@ -394,35 +355,27 @@ mod tests { } #[test] - fn test_unwrap_cast_comparison() { + fn test_pre_cast_lit_comparison() { let schema = expr_test_schema(); - // cast(c1, INT64) < INT64(16) -> INT32(c1) < cast(INT32(16)) + // c1 < INT64(16) -> c1 < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)); - let expected = col("c1").lt(lit(16i32)); - assert_eq!(optimize_test(expr_lt, &schema), expected); - let expr_lt = try_cast(col("c1"), DataType::Int64).lt(lit(16i64)); + let expr_lt = col("c1").lt(lit(16i64)); let expected = col("c1").lt(lit(16i32)); assert_eq!(optimize_test(expr_lt, &schema), expected); - // cast(c2, INT32) = INT32(16) => INT64(c2) = INT64(16) - let c2_eq_lit = cast(col("c2"), DataType::Int32).eq(lit(16i32)); + // // INT64(c2) = INT32(16) => INT64(c2) = INT64(16) + let c2_eq_lit = col("c2").eq(lit(16i32)); let expected = col("c2").eq(lit(16i64)); assert_eq!(optimize_test(c2_eq_lit, &schema), expected); - // cast(c1, INT64) < INT64(NULL) => INT32(c1) < INT32(NULL) - let c1_lt_lit_null = cast(col("c1"), DataType::Int64).lt(null_i64()); + // INT32(c1) < INT64(NULL) => INT32(c1) < INT32(NULL) + let c1_lt_lit_null = col("c1").lt(null_i64()); let expected = col("c1").lt(null_i32()); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); - - // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) - let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32)); - let expected = null_i8().lt(lit(12i8)); - assert_eq!(optimize_test(lit_lt_lit, &schema), expected); } #[test] - fn test_not_unwrap_cast_with_decimal_comparison() { + fn test_not_cast_with_decimal_lit_comparison() { let schema = expr_test_schema(); // integer to decimal: value is out of the bounds of the decimal // cast(c3, INT64) = INT64(100000000000000000) @@ -452,89 +405,77 @@ mod tests { } #[test] - fn test_unwrap_cast_with_decimal_lit_comparison() { + fn test_pre_cast_with_decimal_lit_comparison() { let schema = expr_test_schema(); // integer to decimal // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2)); - let expr_lt = try_cast(col("c3"), DataType::Int64).lt(lit(16i64)); + let expr_lt = col("c3").lt(lit(16i64)); let expected = col("c3").lt(lit_decimal(1600, 18, 2)); assert_eq!(optimize_test(expr_lt, &schema), expected); // c3 < INT64(NULL) - let c1_lt_lit_null = cast(col("c3"), DataType::Int64).lt(null_i64()); + let c1_lt_lit_null = col("c3").lt(null_i64()); let expected = col("c3").lt(null_decimal(18, 2)); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); // decimal to decimal // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2) - let expr_lt = - cast(col("c3"), DataType::Decimal128(10, 0)).lt(lit_decimal(123, 10, 0)); + let expr_lt = col("c3").lt(lit_decimal(123, 10, 0)); let expected = col("c3").lt(lit_decimal(12300, 18, 2)); assert_eq!(optimize_test(expr_lt, &schema), expected); // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2) - let expr_lt = - cast(col("c3"), DataType::Decimal128(10, 3)).lt(lit_decimal(1230, 10, 3)); + let expr_lt = col("c3").lt(lit_decimal(1230, 10, 3)); let expected = col("c3").lt(lit_decimal(123, 18, 2)); assert_eq!(optimize_test(expr_lt, &schema), expected); // decimal to integer // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123) - let expr_lt = - cast(col("c1"), DataType::Decimal128(10, 2)).lt(lit_decimal(12300, 10, 2)); + let expr_lt = col("c1").lt(lit_decimal(12300, 10, 2)); let expected = col("c1").lt(lit(123i32)); assert_eq!(optimize_test(expr_lt, &schema), expected); } #[test] - fn test_not_unwrap_list_cast_lit_comparison() { + fn test_not_list_cast_lit_comparison() { let schema = expr_test_schema(); - // internal left type is not supported + // left type is not supported // FLOAT32(C5) in ... - let expr_lt = - cast(col("c5"), DataType::Int64).in_list(vec![lit(12i64), lit(12i64)], false); + let expr_lt = col("c5").in_list(vec![lit(12i64), lit(12i32)], false); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); - // cast(INT32(C1), Float32) in (FLOAT32(1.23), Float32(12), Float32(12)) - let expr_lt = cast(col("c1"), DataType::Float32) - .in_list(vec![lit(12.0f32), lit(12.0f32), lit(1.23f32)], false); + // INT32(C1) in (FLOAT32(1.23), INT32(12), INT64(12)) + let expr_lt = + col("c1").in_list(vec![lit(12.0_f32), lit(12_i32), lit(12_i64)], false); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // INT32(C1) in (INT64(99999999999), INT64(12)) - let expr_lt = cast(col("c1"), DataType::Int64) - .in_list(vec![lit(12i32), lit(99999999999i64)], false); + let expr_lt = col("c1").in_list(vec![lit(12i32), lit(99999999999i64)], false); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3)) - let expr_lt = cast(col("c3"), DataType::Decimal128(12, 3)).in_list( - vec![ - lit_decimal(12, 12, 3), - lit_decimal(12, 12, 3), - lit_decimal(128, 12, 3), - ], + let expr_lt = col("c3").in_list( + vec![lit(12_i64), lit(12_i32), lit_decimal(128, 12, 3)], false, ); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); } #[test] - fn test_unwrap_list_cast_comparison() { + fn test_pre_list_cast_lit_comparison() { let schema = expr_test_schema(); // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) - let expr_lt = - cast(col("c1"), DataType::Int64).in_list(vec![lit(12i64), lit(24i64)], false); + let expr_lt = col("c1").in_list(vec![lit(12i64), lit(24i64)], false); let expected = col("c1").in_list(vec![lit(12i32), lit(24i32)], false); assert_eq!(optimize_test(expr_lt, &schema), expected); // INT32(C2) IN (INT64(NULL),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24)) - let expr_lt = - cast(col("c2"), DataType::Int32).in_list(vec![null_i32(), lit(14i32)], false); + let expr_lt = col("c2").in_list(vec![null_i32(), lit(14i32)], false); let expected = col("c2").in_list(vec![null_i64(), lit(14i64)], false); assert_eq!(optimize_test(expr_lt, &schema), expected); // decimal test case - // c3 is decimal(18,2) - let expr_lt = cast(col("c3"), DataType::Decimal128(19, 3)).in_list( + let expr_lt = col("c3").in_list( vec![ lit_decimal(12000, 19, 3), lit_decimal(24000, 19, 3), @@ -555,8 +496,7 @@ mod tests { assert_eq!(optimize_test(expr_lt, &schema), expected); // cast(INT32(12), INT64) IN (.....) - let expr_lt = cast(lit(12i32), DataType::Int64) - .in_list(vec![lit(13i64), lit(12i64)], false); + let expr_lt = lit(12i32).in_list(vec![lit(13i64), lit(12i64)], false); let expected = lit(12i32).in_list(vec![lit(13i32), lit(12i32)], false); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -566,7 +506,7 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) -> c1 < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).alias("x"); + let expr_lt = col("c1").lt(lit(16i64)).alias("x"); let expected = col("c1").lt(lit(16i32)).alias("x"); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -576,17 +516,13 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32) // the 16 and 32 are within the range of MAX(int32) and MIN(int32), we can cast them to int32 - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).or(cast( - col("c1"), - DataType::Int64, - ) - .gt(lit(32i64))); + let expr_lt = col("c1").lt(lit(16i64)).or(col("c1").gt(lit(32i64))); let expected = col("c1").lt(lit(16i32)).or(col("c1").gt(lit(32i32))); assert_eq!(optimize_test(expr_lt, &schema), expected); } fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { - let mut expr_rewriter = UnwrapCastExprRewriter { + let mut expr_rewriter = PreCastLitExprRewriter { schema: schema.clone(), }; expr.rewrite(&mut expr_rewriter).unwrap() @@ -608,10 +544,6 @@ mod tests { ) } - fn null_i8() -> Expr { - lit(ScalarValue::Int8(None)) - } - fn null_i32() -> Expr { lit(ScalarValue::Int32(None)) } diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 86f55e698505..71efe86bd3d6 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -29,6 +29,29 @@ use std::any::Any; use std::collections::HashMap; use std::sync::Arc; +#[test] +fn case_when() -> Result<()> { + // regression test for https://github.com/apache/arrow-datafusion/issues/3690 + let sql = "SELECT CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END FROM test"; + let plan = test_sql(sql)?; + let expected = "Projection: CASE WHEN CAST(#test.col_int32 AS Int64) > Int64(0) THEN Int64(1) ELSE Int64(0) END\ + \n TableScan: test projection=[col_int32]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + +#[test] +fn unsigned_target_type() -> Result<()> { + // regression test for https://github.com/apache/arrow-datafusion/issues/3690 + let sql = "SELECT * FROM test WHERE col_uint32 > 0"; + let plan = test_sql(sql)?; + let expected = "Projection: #test.col_int32, #test.col_uint32, #test.col_utf8, #test.col_date32, #test.col_date64\ + \n Filter: CAST(#test.col_uint32 AS Int64) > Int64(0)\ + \n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + #[test] fn distribute_by() -> Result<()> { // regression test for https://github.com/apache/arrow-datafusion/issues/3234 @@ -114,6 +137,7 @@ impl ContextProvider for MySchemaProvider { let schema = Schema::new_with_metadata( vec![ Field::new("col_int32", DataType::Int32, true), + Field::new("col_uint32", DataType::UInt32, true), Field::new("col_utf8", DataType::Utf8, true), Field::new("col_date32", DataType::Date32, true), Field::new("col_date64", DataType::Date64, true),