Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions datafusion/core/tests/sql/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,25 +152,25 @@ async fn case_expr_with_null() -> Result<()> {
let actual = execute_to_batches(&ctx, sql).await;

let expected = vec![
"+------------------------------------------------+",
"| CASE WHEN #a.b IS NULL THEN NULL ELSE #a.b END |",
"+------------------------------------------------+",
"| |",
"| 3 |",
"+------------------------------------------------+",
"+----------------------------------------------+",
"| CASE WHEN a.b IS NULL THEN NULL ELSE a.b END |",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks like an improvement but I don't understand the change

"+----------------------------------------------+",
"| |",
"| 3 |",
"+----------------------------------------------+",
];
assert_batches_eq!(expected, &actual);

let sql = "select case when b is null then null else b end from (select a,b from (values (1,1),(2,3)) as t (a,b)) a;";
let actual = execute_to_batches(&ctx, sql).await;

let expected = vec![
"+------------------------------------------------+",
"| CASE WHEN #a.b IS NULL THEN NULL ELSE #a.b END |",
"+------------------------------------------------+",
"| 1 |",
"| 3 |",
"+------------------------------------------------+",
"+----------------------------------------------+",
"| CASE WHEN a.b IS NULL THEN NULL ELSE a.b END |",
"+----------------------------------------------+",
"| 1 |",
"| 3 |",
"+----------------------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand All @@ -184,27 +184,27 @@ async fn case_expr_with_nulls() -> Result<()> {
let actual = execute_to_batches(&ctx, sql).await;

let expected = vec![
"+--------------------------------------------------------------------------------------------------------------------------+",
"| CASE WHEN #a.b IS NULL THEN NULL WHEN #a.b < Int64(3) THEN NULL WHEN #a.b >= Int64(3) THEN #a.b + Int64(1) ELSE #a.b END |",
"+--------------------------------------------------------------------------------------------------------------------------+",
"| |",
"| |",
"| 4 |",
"+--------------------------------------------------------------------------------------------------------------------------+"
"+---------------------------------------------------------------------------------------------------------------------+",
"| CASE WHEN a.b IS NULL THEN NULL WHEN a.b < Int64(3) THEN NULL WHEN a.b >= Int64(3) THEN a.b + Int64(1) ELSE a.b END |",
"+---------------------------------------------------------------------------------------------------------------------+",
"| |",
"| |",
"| 4 |",
"+---------------------------------------------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);

let sql = "select case b when 1 then null when 2 then null when 3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a;";
let actual = execute_to_batches(&ctx, sql).await;

let expected = vec![
"+------------------------------------------------------------------------------------------------------------+",
"| CASE #a.b WHEN Int64(1) THEN NULL WHEN Int64(2) THEN NULL WHEN Int64(3) THEN #a.b + Int64(1) ELSE #a.b END |",
"+------------------------------------------------------------------------------------------------------------+",
"| |",
"| |",
"| 4 |",
"+------------------------------------------------------------------------------------------------------------+",
"+---------------------------------------------------------------------------------------------------------+",
"| CASE a.b WHEN Int64(1) THEN NULL WHEN Int64(2) THEN NULL WHEN Int64(3) THEN a.b + Int64(1) ELSE a.b END |",
"+---------------------------------------------------------------------------------------------------------+",
"| |",
"| |",
"| 4 |",
"+---------------------------------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);

Expand Down
14 changes: 7 additions & 7 deletions datafusion/core/tests/sql/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,13 @@ async fn project_cast_dictionary() {
let actual = collect(physical_plan, ctx.task_ctx()).await.unwrap();

let expected = vec![
"+------------------------------------------------------------------------------------+",
"| CASE WHEN #cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE #cpu_load_short.host END |",
"+------------------------------------------------------------------------------------+",
"| host1 |",
"| |",
"| host2 |",
"+------------------------------------------------------------------------------------+",
"+----------------------------------------------------------------------------------+",
"| CASE WHEN cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE cpu_load_short.host END |",
"+----------------------------------------------------------------------------------+",
"| host1 |",
"| |",
"| host2 |",
"+----------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
}
Expand Down
66 changes: 66 additions & 0 deletions datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,50 @@ impl ExprRewriter for TypeCoercionRewriter {
}
}
}
Expr::Case {
expr,
when_then_expr,
else_expr,
} => {
// all the results of `then` and `else` should be converted to a common data type;
// if they cannot be coerced to a common data type, return an error.
let then_types = when_then_expr
.iter()
.map(|when_then| when_then.1.get_type(&self.schema))
.collect::<Result<Vec<_>>>()?;
let else_type = match &else_expr {
None => Ok(None),
Some(expr) => expr.get_type(&self.schema).map(Some),
}?;
let case_when_coerce_type =
get_coerce_type_for_case_when(&then_types, &else_type);
match case_when_coerce_type {
None => Err(DataFusionError::Internal(format!(
"Failed to coerce then ({:?}) and else ({:?}) to common types in CASE WHEN expression",
then_types, else_type
))),
Some(data_type) => {
let left = when_then_expr
.into_iter()
.map(|(when, then)| {
let then = then.cast_to(&data_type, &self.schema)?;
Ok((when, Box::new(then)))
})
.collect::<Result<Vec<_>>>()?;
let right = match else_expr {
None => None,
Some(expr) => {
Some(Box::new(expr.cast_to(&data_type, &self.schema)?))
}
};
Ok(Expr::Case {
expr,
when_then_expr: left,
else_expr: right,
})
}
}
}
expr => Ok(expr),
}
}
Expand Down Expand Up @@ -410,6 +454,28 @@ fn coerce_arguments_for_signature(
.collect::<Result<Vec<_>>>()
}

/// Finds a single data type that every `THEN` result and the optional `ELSE`
/// result of a `CASE` expression can be coerced to.
///
/// The search is seeded with `else_type` when it is present, otherwise with the
/// first entry of `then_types`; the seed is then combined with every entry of
/// `then_types` in order through `comparison_coercion`.
///
/// Returns `None` when some step has no common coercion, or when there is
/// nothing to coerce at all (`then_types` is empty and `else_type` is `None`).
fn get_coerce_type_for_case_when(
    then_types: &[DataType],
    else_type: &Option<DataType>,
) -> Option<DataType> {
    // `first()?` instead of `[0]`: an empty `then_types` with no `ELSE` yields
    // `None` rather than panicking on an out-of-bounds index.
    let seed = match else_type {
        Some(data_type) => data_type.clone(),
        None => then_types.first()?.clone(),
    };
    // TODO: for now this reuses the comparison (`equal`) coercion rule for
    // CASE WHEN; revisit and refactor if that turns out to be a problem.
    //
    // `try_fold` stops at the first pair that cannot be coerced, which is the
    // same result as threading `None` through every remaining iteration.
    then_types
        .iter()
        .try_fold(seed, |common, then_type| comparison_coercion(&common, then_type))
}

#[cfg(test)]
mod test {
use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter};
Expand Down
Loading