diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index b1e8227e06ea9..93e453101d357 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -152,12 +152,12 @@ async fn case_expr_with_null() -> Result<()> { let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------------------------------------------------+", - "| CASE WHEN #a.b IS NULL THEN NULL ELSE #a.b END |", - "+------------------------------------------------+", - "| |", - "| 3 |", - "+------------------------------------------------+", + "+----------------------------------------------+", + "| CASE WHEN a.b IS NULL THEN NULL ELSE a.b END |", + "+----------------------------------------------+", + "| |", + "| 3 |", + "+----------------------------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -165,12 +165,12 @@ async fn case_expr_with_null() -> Result<()> { let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------------------------------------------------+", - "| CASE WHEN #a.b IS NULL THEN NULL ELSE #a.b END |", - "+------------------------------------------------+", - "| 1 |", - "| 3 |", - "+------------------------------------------------+", + "+----------------------------------------------+", + "| CASE WHEN a.b IS NULL THEN NULL ELSE a.b END |", + "+----------------------------------------------+", + "| 1 |", + "| 3 |", + "+----------------------------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -184,13 +184,13 @@ async fn case_expr_with_nulls() -> Result<()> { let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+--------------------------------------------------------------------------------------------------------------------------+", - "| CASE WHEN #a.b IS NULL THEN NULL WHEN #a.b < Int64(3) THEN NULL WHEN #a.b >= Int64(3) THEN #a.b + Int64(1) ELSE #a.b END |", - "+--------------------------------------------------------------------------------------------------------------------------+", - "| |", - "| |", - "| 4 |", - "+--------------------------------------------------------------------------------------------------------------------------+" + "+---------------------------------------------------------------------------------------------------------------------+", + "| CASE WHEN a.b IS NULL THEN NULL WHEN a.b < Int64(3) THEN NULL WHEN a.b >= Int64(3) THEN a.b + Int64(1) ELSE a.b END |", + "+---------------------------------------------------------------------------------------------------------------------+", + "| |", + "| |", + "| 4 |", + "+---------------------------------------------------------------------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); @@ -198,13 +198,13 @@ async fn case_expr_with_nulls() -> Result<()> { let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------------------------------------------------------------------------------------------------------------+", - "| CASE #a.b WHEN Int64(1) THEN NULL WHEN Int64(2) THEN NULL WHEN Int64(3) THEN #a.b + Int64(1) ELSE #a.b END |", - "+------------------------------------------------------------------------------------------------------------+", - "| |", - "| |", - "| 4 |", - "+------------------------------------------------------------------------------------------------------------+", + "+---------------------------------------------------------------------------------------------------------+", + "| CASE a.b WHEN Int64(1) THEN NULL WHEN Int64(2) THEN NULL WHEN Int64(3) THEN a.b + Int64(1) ELSE a.b END |", + "+---------------------------------------------------------------------------------------------------------+", + "| |", + "| |", + "| 4 |", + "+---------------------------------------------------------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); diff --git a/datafusion/core/tests/sql/projection.rs b/datafusion/core/tests/sql/projection.rs index 97c6dcf8aa7fe..cee8e706c6bbf 100644 --- a/datafusion/core/tests/sql/projection.rs +++ b/datafusion/core/tests/sql/projection.rs @@ -252,13 +252,13 @@ async fn project_cast_dictionary() { let actual = collect(physical_plan, ctx.task_ctx()).await.unwrap(); let expected = vec![ - "+------------------------------------------------------------------------------------+", - "| CASE WHEN #cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE #cpu_load_short.host END |", - "+------------------------------------------------------------------------------------+", - "| host1 |", - "| |", - "| host2 |", - "+------------------------------------------------------------------------------------+", + "+----------------------------------------------------------------------------------+", + "| CASE WHEN cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE cpu_load_short.host END |", + "+----------------------------------------------------------------------------------+", + "| host1 |", + "| |", + "| host2 |", + "+----------------------------------------------------------------------------------+", ]; assert_batches_eq!(expected, &actual); } diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 3d45c504150b4..a2bdbb8e1a714 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -354,6 +354,50 @@ impl ExprRewriter for TypeCoercionRewriter { } } } + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + // all the result of then and else should be convert to a common data type, + // if they can be coercible to a common data type, return error. + let then_types = when_then_expr + .iter() + .map(|when_then| when_then.1.get_type(&self.schema)) + .collect::>>()?; + let else_type = match &else_expr { + None => Ok(None), + Some(expr) => expr.get_type(&self.schema).map(Some), + }?; + let case_when_coerce_type = + get_coerce_type_for_case_when(&then_types, &else_type); + match case_when_coerce_type { + None => Err(DataFusionError::Internal(format!( + "Failed to coerce then ({:?}) and else ({:?}) to common types in CASE WHEN expression", + then_types, else_type + ))), + Some(data_type) => { + let left = when_then_expr + .into_iter() + .map(|(when, then)| { + let then = then.cast_to(&data_type, &self.schema)?; + Ok((when, Box::new(then))) + }) + .collect::>>()?; + let right = match else_expr { + None => None, + Some(expr) => { + Some(Box::new(expr.cast_to(&data_type, &self.schema)?)) + } + }; + Ok(Expr::Case { + expr, + when_then_expr: left, + else_expr: right, + }) + } + } + } expr => Ok(expr), } } @@ -410,6 +454,28 @@ fn coerce_arguments_for_signature( .collect::>>() } +/// Find a common coerceable type for all `then_types` as well +/// and the `else_type`, if specified. +/// Returns the common data type for `then_types` and `else_type` +fn get_coerce_type_for_case_when( + then_types: &[DataType], + else_type: &Option, +) -> Option { + let else_type = match else_type { + None => then_types[0].clone(), + Some(data_type) => data_type.clone(), + }; + then_types + .iter() + .fold(Some(else_type), |left, right_type| match left { + // failed to find a valid coercion in a previous iteration + None => None, + // TODO: now just use the `equal` coercion rule for case when. If find the issue, and + // refactor again. + Some(left_type) => comparison_coercion(&left_type, right_type), + }) +} + #[cfg(test)] mod test { use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter}; diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index cf4f7defe7f36..b1bd0a6044648 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -25,7 +25,6 @@ use arrow::compute::{and, eq_dyn, is_null, not, or, or_kleene}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::binary_rule::comparison_coercion; use datafusion_expr::ColumnarValue; type WhenThen = (Arc, Arc); @@ -294,66 +293,10 @@ pub fn case( expr: Option>, when_thens: Vec, else_expr: Option>, - input_schema: &Schema, ) -> Result> { - // all the result of then and else should be convert to a common data type, - // if they can be coercible to a common data type, return error. - let coerce_type = get_case_common_type(&when_thens, else_expr.clone(), input_schema); - let (when_thens, else_expr) = match coerce_type { - None => Err(DataFusionError::Plan(format!( - "Can't get a common type for then {:?} and else {:?} expression", - when_thens, else_expr - ))), - Some(data_type) => { - // cast then expr - let left = when_thens - .into_iter() - .map(|(when, then)| { - let then = try_cast(then, input_schema, data_type.clone())?; - Ok((when, then)) - }) - .collect::>>()?; - let right = match else_expr { - None => None, - Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?), - }; - - Ok((left, right)) - } - }?; - Ok(Arc::new(CaseExpr::try_new(expr, when_thens, else_expr)?)) } -fn get_case_common_type( - when_thens: &[WhenThen], - else_expr: Option>, - input_schema: &Schema, -) -> Option { - let thens_type = when_thens - .iter() - .map(|when_then| { - let data_type = &when_then.1.data_type(input_schema).unwrap(); - data_type.clone() - }) - .collect::>(); - let else_type = match else_expr { - None => { - // case when then exprs must have one then value - thens_type[0].clone() - } - Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(), - }; - thens_type - .iter() - .fold(Some(else_type), |left, right_type| match left { - None => None, - // TODO: now just use the `equal` coercion rule for case when. If find the issue, and - // refactor again. - Some(left_type) => comparison_coercion(&left_type, right_type), - }) -} - #[cfg(test)] mod tests { use super::*; @@ -365,6 +308,7 @@ mod tests { use arrow::datatypes::DataType::Float64; use arrow::datatypes::*; use datafusion_common::ScalarValue; + use datafusion_expr::binary_rule::comparison_coercion; use datafusion_expr::Operator; #[test] @@ -378,7 +322,7 @@ mod tests { let when2 = lit("bar"); let then2 = lit(456i32); - let expr = case( + let expr = generate_case_when_with_type_coercion( Some(col("a", &schema)?), vec![(when1, then1), (when2, then2)], None, @@ -409,7 +353,7 @@ mod tests { let then2 = lit(456i32); let else_value = lit(999i32); - let expr = case( + let expr = generate_case_when_with_type_coercion( Some(col("a", &schema)?), vec![(when1, then1), (when2, then2)], Some(else_value), @@ -444,7 +388,7 @@ mod tests { &batch.schema(), )?; - let expr = case( + let expr = generate_case_when_with_type_coercion( Some(col("a", &schema)?), vec![(when1, then1)], Some(else_value), @@ -484,7 +428,7 @@ mod tests { )?; let then2 = lit(456i32); - let expr = case( + let expr = generate_case_when_with_type_coercion( None, vec![(when1, then1), (when2, then2)], None, @@ -518,7 +462,12 @@ mod tests { )?; let x = lit(ScalarValue::Float64(None)); - let expr = case(None, vec![(when1, then1)], Some(x), schema.as_ref())?; + let expr = generate_case_when_with_type_coercion( + None, + vec![(when1, then1)], + Some(x), + schema.as_ref(), + )?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); let result = result .as_any() @@ -561,7 +510,7 @@ mod tests { let then2 = lit(456i32); let else_value = lit(999i32); - let expr = case( + let expr = generate_case_when_with_type_coercion( None, vec![(when1, then1), (when2, then2)], Some(else_value), @@ -596,7 +545,12 @@ mod tests { let then = lit(123.3f64); let else_value = lit(999i32); - let expr = case(None, vec![(when, then)], Some(else_value), schema.as_ref())?; + let expr = generate_case_when_with_type_coercion( + None, + vec![(when, then)], + Some(else_value), + schema.as_ref(), + )?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); let result = result .as_any() @@ -625,7 +579,12 @@ mod tests { )?; let then = col("load4", &schema)?; - let expr = case(None, vec![(when, then)], None, schema.as_ref())?; + let expr = generate_case_when_with_type_coercion( + None, + vec![(when, then)], + None, + schema.as_ref(), + )?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); let result = result .as_any() @@ -650,7 +609,12 @@ mod tests { let when = lit(1.77f64); let then = col("load4", &schema)?; - let expr = case(Some(expr), vec![(when, then)], None, schema.as_ref())?; + let expr = generate_case_when_with_type_coercion( + Some(expr), + vec![(when, then)], + None, + schema.as_ref(), + )?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); let result = result .as_any() @@ -724,7 +688,7 @@ mod tests { )?; let then2 = lit(true); - let expr = case( + let expr = generate_case_when_with_type_coercion( None, vec![(when1, then1), (when2, then2)], None, @@ -752,7 +716,7 @@ mod tests { let then2 = lit(456i64); let else_expr = lit(1.23f64); - let expr = case( + let expr = generate_case_when_with_type_coercion( None, vec![(when1, then1), (when2, then2)], Some(else_expr), @@ -763,4 +727,66 @@ mod tests { assert_eq!(DataType::Float64, result_type); Ok(()) } + + fn generate_case_when_with_type_coercion( + expr: Option>, + when_thens: Vec, + else_expr: Option>, + input_schema: &Schema, + ) -> Result> { + let coerce_type = + get_case_common_type(&when_thens, else_expr.clone(), input_schema); + let (when_thens, else_expr) = match coerce_type { + None => Err(DataFusionError::Plan(format!( + "Can't get a common type for then {:?} and else {:?} expression", + when_thens, else_expr + ))), + Some(data_type) => { + // cast then expr + let left = when_thens + .into_iter() + .map(|(when, then)| { + let then = try_cast(then, input_schema, data_type.clone())?; + Ok((when, then)) + }) + .collect::>>()?; + let right = match else_expr { + None => None, + Some(expr) => Some(try_cast(expr, input_schema, data_type.clone())?), + }; + + Ok((left, right)) + } + }?; + case(expr, when_thens, else_expr) + } + + fn get_case_common_type( + when_thens: &[WhenThen], + else_expr: Option>, + input_schema: &Schema, + ) -> Option { + let thens_type = when_thens + .iter() + .map(|when_then| { + let data_type = &when_then.1.data_type(input_schema).unwrap(); + data_type.clone() + }) + .collect::>(); + let else_type = match else_expr { + None => { + // case when then exprs must have one then value + thens_type[0].clone() + } + Some(else_phy_expr) => else_phy_expr.data_type(input_schema).unwrap(), + }; + thens_type + .iter() + .fold(Some(else_type), |left, right_type| match left { + None => None, + // TODO: now just use the `equal` coercion rule for case when. If find the issue, and + // refactor again. + Some(left_type) => comparison_coercion(&left_type, right_type), + }) + } } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index ba9664ef653d2..0964d64805b8a 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -275,12 +275,7 @@ pub fn create_physical_expr( } else { None }; - Ok(expressions::case( - expr, - when_then_expr, - else_expr, - input_schema, - )?) + Ok(expressions::case(expr, when_then_expr, else_expr)?) } Expr::Cast { expr, data_type } => expressions::cast( create_physical_expr(expr, input_dfschema, input_schema, execution_props)?,