diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 97cc23264bda1..d2ac5ce2f3837 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -30,6 +30,7 @@ use crate::optimizer::utils; use crate::physical_plan::functions::BuiltinScalarFunction; use crate::scalar::ScalarValue; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; +use arrow::compute::{kernels, DEFAULT_CAST_OPTIONS}; /// Optimizer that simplifies comparison expressions involving boolean literals. /// @@ -247,6 +248,25 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> { } } } + Expr::Cast { + expr: inner, + data_type, + } => match inner.as_ref() { + Expr::Literal(val) => { + let scalar_array = val.to_array(); + let cast_array = kernels::cast::cast_with_options( + &scalar_array, + &data_type, + &DEFAULT_CAST_OPTIONS, + )?; + let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; + Expr::Literal(cast_scalar) + } + _ => Expr::Cast { + expr: inner, + data_type, + }, + }, expr => { // no rewrite possible expr @@ -724,6 +744,44 @@ mod tests { assert_eq!(expected, actual); } + #[test] + fn cast_expr() { + let table_scan = test_table_scan().unwrap(); + let proj = vec![Expr::Cast { + expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some("0".to_string())))), + data_type: DataType::Int32, + }]; + let plan = LogicalPlanBuilder::from(&table_scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + + let expected = "Projection: Int32(0)\ + \n TableScan: test projection=None"; + let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); + assert_eq!(expected, actual); + } + + #[test] + fn cast_expr_wrong_arg() { + let table_scan = test_table_scan().unwrap(); + let proj = vec![Expr::Cast { + expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some("".to_string())))), + data_type: DataType::Int32, + }]; + let plan = LogicalPlanBuilder::from(&table_scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + + let expected = "Projection: Int32(NULL)\ + \n TableScan: test projection=None"; + let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); + assert_eq!(expected, actual); + } + #[test] fn single_now_expr() { let table_scan = test_table_scan().unwrap();