diff --git a/datafusion/core/tests/sqllogictests/test_files/select.slt b/datafusion/core/tests/sqllogictests/test_files/select.slt index 101d031885a8c..65cf24b6448ce 100644 --- a/datafusion/core/tests/sqllogictests/test_files/select.slt +++ b/datafusion/core/tests/sqllogictests/test_files/select.slt @@ -214,3 +214,19 @@ select * from (select 1 a union all select 2) b order by a limit null; query I select * from (select 1 a union all select 2) b order by a limit 0; ---- + +# select case when type coercion with case expression +query I +select CASE 10.5 WHEN 0 THEN 1 ELSE 2 END; +---- +2 + +# select case when type coercion without case expression +query I +select CASE + WHEN 10 = 5 THEN 1 + WHEN 'true' THEN 2 + ELSE 3 +END; +---- +2 diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 2806683ab87bf..13973bd349963 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -277,7 +277,7 @@ impl Display for BinaryExpr { } /// CASE expression -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Case { /// Optional base expression that can be compared to literal values in the "when" expressions pub expr: Option>, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index bfe464b232cd5..479b74ea08521 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -21,7 +21,7 @@ use crate::expr::{ }; use crate::field_util::get_indexed_field; use crate::type_coercion::binary::binary_operator_data_type; -use crate::type_coercion::other::get_coerce_type_for_case_when; +use crate::type_coercion::other::get_coerce_type_for_case_expression; use crate::{aggregate_function, function, window_function}; use arrow::compute::can_cast_types; use arrow::datatypes::DataType; @@ -81,13 +81,12 @@ impl ExprSchemable for Expr { None => Ok(None), Some(expr) => expr.get_type(schema).map(Some), }?; - get_coerce_type_for_case_when(&then_types, else_type.as_ref()).ok_or_else( - || { + get_coerce_type_for_case_expression(&then_types, else_type.as_ref()) + .ok_or_else(|| { DataFusionError::Internal(String::from( "Cannot infer type for CASE statement", )) - }, - ) + }) } Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), diff --git a/datafusion/expr/src/type_coercion/other.rs b/datafusion/expr/src/type_coercion/other.rs index 6ff1300f64e2a..c53054e82112f 100644 --- a/datafusion/expr/src/type_coercion/other.rs +++ b/datafusion/expr/src/type_coercion/other.rs @@ -34,20 +34,20 @@ pub fn get_coerce_type_for_list( }) } -/// Find a common coerceable type for all `then_types` as well -/// and the `else_type`, if specified. -/// Returns the common data type for `then_types` and `else_type` -pub fn get_coerce_type_for_case_when( - then_types: &[DataType], - else_type: Option<&DataType>, +/// Find a common coerceable type for all `when_or_then_types` as well +/// and the `case_or_else_type`, if specified. +/// Returns the common data type for `when_or_then_types` and `case_or_else_type` +pub fn get_coerce_type_for_case_expression( + when_or_then_types: &[DataType], + case_or_else_type: Option<&DataType>, ) -> Option { - let else_type = match else_type { - None => then_types[0].clone(), + let case_or_else_type = match case_or_else_type { + None => when_or_then_types[0].clone(), Some(data_type) => data_type.clone(), }; - then_types + when_or_then_types .iter() - .fold(Some(else_type), |left, right_type| match left { + .fold(Some(case_or_else_type), |left, right_type| match left { // failed to find a valid coercion in a previous iteration None => None, // TODO: now just use the `equal` coercion rule for case when. If find the issue, and diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 0be9c89b6ccb6..35730c0f427ce 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -32,7 +32,7 @@ use datafusion_expr::type_coercion::binary::{ }; use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::type_coercion::other::{ - get_coerce_type_for_case_when, get_coerce_type_for_list, + get_coerce_type_for_case_expression, get_coerce_type_for_list, }; use datafusion_expr::type_coercion::{ is_date, is_numeric, is_timestamp, is_utf8_or_large_utf8, @@ -330,40 +330,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } } Expr::Case(case) => { - // all the result of then and else should be convert to a common data type, - // if they can be coercible to a common data type, return error. - let then_types = case - .when_then_expr - .iter() - .map(|when_then| when_then.1.get_type(&self.schema)) - .collect::>>()?; - let else_type = match &case.else_expr { - None => Ok(None), - Some(expr) => expr.get_type(&self.schema).map(Some), - }?; - let case_when_coerce_type = - get_coerce_type_for_case_when(&then_types, else_type.as_ref()); - match case_when_coerce_type { - None => Err(DataFusionError::Internal(format!( - "Failed to coerce then ({then_types:?}) and else ({else_type:?}) to common types in CASE WHEN expression" - ))), - Some(data_type) => { - let left = case.when_then_expr - .into_iter() - .map(|(when, then)| { - let then = then.cast_to(&data_type, &self.schema)?; - Ok((when, Box::new(then))) - }) - .collect::>>()?; - let right = match &case.else_expr { - None => None, - Some(expr) => { - Some(Box::new(expr.clone().cast_to(&data_type, &self.schema)?)) - } - }; - Ok(Expr::Case(Case::new(case.expr,left,right))) - } - } + let case = coerce_case_expression(case, &self.schema)?; + Ok(Expr::Case(case)) } Expr::ScalarUDF { fun, args } => { let new_expr = coerce_arguments_for_signature( @@ -638,19 +606,130 @@ fn coerce_agg_exprs_for_signature( .collect::>>() } +fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { + // Given expressions like: + // + // CASE a1 + // WHEN a2 THEN b1 + // WHEN a3 THEN b2 + // ELSE b3 + // END + // + // or: + // + // CASE + // WHEN x1 THEN b1 + // WHEN x2 THEN b2 + // ELSE b3 + // END + // + // Then all aN (a1, a2, a3) must be converted to a common data type in the first example + // (case-when expression coercion) + // + // All xN (x1, x2) must be converted to a boolean data type in the second example + // (when-boolean expression coercion) + // + // And all bN (b1, b2, b3) must be converted to a common data type in both examples + // (then-else expression coercion) + // + // If any fail to find and cast to a common/specific data type, will return error + // + // Note that case-when and when-boolean expression coercions are mutually exclusive + // Only one or the other can occur for a case expression, whilst then-else expression coercion will always occur + + // prepare types + let case_type = case + .expr + .as_ref() + .map(|expr| expr.get_type(&schema)) + .transpose()?; + let then_types = case + .when_then_expr + .iter() + .map(|(_when, then)| then.get_type(&schema)) + .collect::>>()?; + let else_type = case + .else_expr + .as_ref() + .map(|expr| expr.get_type(&schema)) + .transpose()?; + + // find common coercible types + let case_when_coerce_type = case_type + .as_ref() + .map(|case_type| { + let when_types = case + .when_then_expr + .iter() + .map(|(when, _then)| when.get_type(&schema)) + .collect::>>()?; + let coerced_type = + get_coerce_type_for_case_expression(&when_types, Some(case_type)); + coerced_type.ok_or_else(|| { + DataFusionError::Plan(format!( + "Failed to coerce case ({case_type:?}) and when ({when_types:?}) \ + to common types in CASE WHEN expression" + )) + }) + }) + .transpose()?; + let then_else_coerce_type = + get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else( + || { + DataFusionError::Plan(format!( + "Failed to coerce then ({then_types:?}) and else ({else_type:?}) \ + to common types in CASE WHEN expression" + )) + }, + )?; + + // do cast if found common coercible types + let case_expr = case + .expr + .zip(case_when_coerce_type.as_ref()) + .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, &schema)) + .transpose()? + .map(Box::new); + let when_then = case + .when_then_expr + .into_iter() + .map(|(when, then)| { + let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean); + let when = when.cast_to(when_type, &schema).map_err(|e| { + DataFusionError::Context( + format!( + "WHEN expressions in CASE couldn't be \ + converted to common type ({when_type})" + ), + Box::new(e), + ) + })?; + let then = then.cast_to(&then_else_coerce_type, &schema)?; + Ok((Box::new(when), Box::new(then))) + }) + .collect::>>()?; + let else_expr = case + .else_expr + .map(|expr| expr.cast_to(&then_else_coerce_type, &schema)) + .transpose()? + .map(Box::new); + + Ok(Case::new(case_expr, when_then, else_expr)) +} + #[cfg(test)] mod test { use std::sync::Arc; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::tree_node::TreeNode; - use datafusion_common::{DFField, DFSchema, Result, ScalarValue}; + use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{self, Like}; use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF, - BuiltinScalarFunction, ColumnarValue, StateTypeFunction, + BuiltinScalarFunction, Case, ColumnarValue, ExprSchemable, StateTypeFunction, }; use datafusion_expr::{ lit, @@ -663,6 +742,8 @@ mod test { use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter}; use crate::{OptimizerContext, OptimizerRule}; + use super::coerce_case_expression; + fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { let rule = TypeCoercion::new(); let config = OptimizerContext::default(); @@ -1173,4 +1254,150 @@ mod test { assert_optimized_plan_eq(&plan, expected)?; Ok(()) } + + fn cast_if_not_same_type( + expr: Box, + data_type: &DataType, + schema: &DFSchemaRef, + ) -> Box { + if &expr.get_type(schema).unwrap() != data_type { + Box::new(cast(*expr, data_type.clone())) + } else { + expr + } + } + + fn cast_helper( + case: Case, + case_when_type: DataType, + then_else_type: DataType, + schema: &DFSchemaRef, + ) -> Case { + let expr = case + .expr + .map(|e| cast_if_not_same_type(e, &case_when_type, schema)); + let when_then_expr = case + .when_then_expr + .into_iter() + .map(|(when, then)| { + ( + cast_if_not_same_type(when, &case_when_type, schema), + cast_if_not_same_type(then, &then_else_type, schema), + ) + }) + .collect::>(); + let else_expr = case + .else_expr + .map(|e| cast_if_not_same_type(e, &then_else_type, schema)); + + Case { + expr, + when_then_expr, + else_expr, + } + } + + #[test] + fn test_case_expression_coercion() -> Result<()> { + let schema = Arc::new(DFSchema::new_with_metadata( + vec![ + DFField::new_unqualified("boolean", DataType::Boolean, true), + DFField::new_unqualified("integer", DataType::Int32, true), + DFField::new_unqualified("float", DataType::Float32, true), + DFField::new_unqualified( + "timestamp", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + DFField::new_unqualified("date", DataType::Date32, true), + DFField::new_unqualified( + "interval", + DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano), + true, + ), + DFField::new_unqualified("binary", DataType::Binary, true), + DFField::new_unqualified("string", DataType::Utf8, true), + DFField::new_unqualified("decimal", DataType::Decimal128(10, 10), true), + ], + std::collections::HashMap::new(), + )?); + + let case = Case { + expr: None, + when_then_expr: vec![ + (Box::new(col("boolean")), Box::new(col("integer"))), + (Box::new(col("integer")), Box::new(col("float"))), + (Box::new(col("string")), Box::new(col("string"))), + ], + else_expr: None, + }; + let case_when_common_type = DataType::Boolean; + let then_else_common_type = DataType::Utf8; + let expected = cast_helper( + case.clone(), + case_when_common_type, + then_else_common_type, + &schema, + ); + let actual = coerce_case_expression(case, &schema)?; + assert_eq!(expected, actual); + + let case = Case { + expr: Some(Box::new(col("string"))), + when_then_expr: vec![ + (Box::new(col("float")), Box::new(col("integer"))), + (Box::new(col("integer")), Box::new(col("float"))), + (Box::new(col("string")), Box::new(col("string"))), + ], + else_expr: Some(Box::new(col("string"))), + }; + let case_when_common_type = DataType::Utf8; + let then_else_common_type = DataType::Utf8; + let expected = cast_helper( + case.clone(), + case_when_common_type, + then_else_common_type, + &schema, + ); + let actual = coerce_case_expression(case, &schema)?; + assert_eq!(expected, actual); + + let case = Case { + expr: Some(Box::new(col("interval"))), + when_then_expr: vec![ + (Box::new(col("float")), Box::new(col("integer"))), + (Box::new(col("binary")), Box::new(col("float"))), + (Box::new(col("string")), Box::new(col("string"))), + ], + else_expr: Some(Box::new(col("string"))), + }; + let err = coerce_case_expression(case, &schema).unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: \ + Failed to coerce case (Interval(MonthDayNano)) and \ + when ([Float32, Binary, Utf8]) to common types in \ + CASE WHEN expression" + ); + + let case = Case { + expr: Some(Box::new(col("string"))), + when_then_expr: vec![ + (Box::new(col("float")), Box::new(col("date"))), + (Box::new(col("string")), Box::new(col("float"))), + (Box::new(col("string")), Box::new(col("binary"))), + ], + else_expr: Some(Box::new(col("timestamp"))), + }; + let err = coerce_case_expression(case, &schema).unwrap_err(); + assert_eq!( + err.to_string(), + "Error during planning: \ + Failed to coerce then ([Date32, Float32, Binary]) and \ + else (Some(Timestamp(Nanosecond, None))) to common types \ + in CASE WHEN expression" + ); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index b4a7b1c59b47d..2d97d57324fee 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -195,8 +195,12 @@ impl CaseExpr { _ => when_value, }; let when_value = when_value.into_array(batch.num_rows()); - let when_value = as_boolean_array(&when_value) - .expect("WHEN expression did not return a BooleanArray"); + let when_value = as_boolean_array(&when_value).map_err(|e| { + DataFusionError::Context( + "WHEN expression did not return a BooleanArray".to_string(), + Box::new(e), + ) + })?; let then_value = self.when_then_expr[i] .1