diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index b491a3529f353..95af4c6bb86b1 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1436,33 +1436,49 @@ impl TreeNodeRewriter for Simplifier<'_, S> { // CASE WHEN true THEN A ... END --> A // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END + // CASE WHEN false THEN A END --> NULL + // CASE WHEN false THEN A ELSE B END --> B + // CASE WHEN X THEN A WHEN false THEN B END --> CASE WHEN X THEN A END Expr::Case(Case { expr: None, - mut when_then_expr, - else_expr: _, - // if let guard is not stabilized so we can't use it yet: https://github.com/rust-lang/rust/issues/51114 - // Once it's supported we can avoid searching through when_then_expr twice in the below .any() and .position() calls - // }) if let Some(i) = when_then_expr.iter().position(|(when, _)| is_true(when.as_ref())) => { + when_then_expr, + mut else_expr, }) if when_then_expr .iter() - .any(|(when, _)| is_true(when.as_ref())) => + .any(|(when, _)| is_true(when.as_ref()) || is_false(when.as_ref())) => { - let i = when_then_expr - .iter() - .position(|(when, _)| is_true(when.as_ref())) - .unwrap(); - let (_, then_) = when_then_expr.swap_remove(i); - // CASE WHEN true THEN A ... END --> A - if i == 0 { - return Ok(Transformed::yes(*then_)); + let out_type = info.get_data_type(&when_then_expr[0].1)?; + let mut new_when_then_expr = Vec::with_capacity(when_then_expr.len()); + + for (when, then) in when_then_expr.into_iter() { + if is_true(when.as_ref()) { + // Skip adding the rest of the when-then expressions after WHEN true + // CASE WHEN X THEN A WHEN TRUE THEN B ...
END --> CASE WHEN X THEN A ELSE B END + else_expr = Some(then); + break; + } else if !is_false(when.as_ref()) { + new_when_then_expr.push((when, then)); + } + // else: skip WHEN false cases + } + + // Exclude CASE statement altogether if there are no when-then expressions left + if new_when_then_expr.is_empty() { + // CASE WHEN false THEN A ELSE B END --> B + if let Some(else_expr) = else_expr { + return Ok(Transformed::yes(*else_expr)); + // CASE WHEN false THEN A END --> NULL + } else { + let null = + Expr::Literal(ScalarValue::try_new_null(&out_type)?, None); + return Ok(Transformed::yes(null)); + } } - // CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END - when_then_expr.truncate(i); Transformed::yes(Expr::Case(Case { expr: None, - when_then_expr, - else_expr: Some(then_), + when_then_expr: new_when_then_expr, + else_expr, })) } @@ -3810,53 +3826,53 @@ mod tests { #[test] fn simplify_expr_case_when_first_true() { - // CASE WHEN true THEN 1 ELSE x END --> 1 + // CASE WHEN true THEN 1 ELSE c1 END --> 1 assert_eq!( simplify(Expr::Case(Case::new( None, vec![(Box::new(lit(true)), Box::new(lit(1)),)], - Some(Box::new(col("x"))), + Some(Box::new(col("c1"))), ))), lit(1) ); - // CASE WHEN true THEN col("a") ELSE col("b") END --> col("a") + // CASE WHEN true THEN 'a' ELSE 'b' END --> 'a' assert_eq!( simplify(Expr::Case(Case::new( None, - vec![(Box::new(lit(true)), Box::new(col("a")),)], - Some(Box::new(col("b"))), + vec![(Box::new(lit(true)), Box::new(lit("a")),)], + Some(Box::new(lit("b"))), ))), - col("a") + lit("a") ); - // CASE WHEN true THEN col("a") WHEN col("x") > 5 THEN col("b") ELSE col("c") END --> col("a") + // CASE WHEN true THEN 'a' WHEN 'x' > 5 THEN 'b' ELSE 'c' END --> 'a' assert_eq!( simplify(Expr::Case(Case::new( None, vec![ - (Box::new(lit(true)), Box::new(col("a"))), - (Box::new(col("x").gt(lit(5))), Box::new(col("b"))), + (Box::new(lit(true)), Box::new(lit("a"))), +
(Box::new(lit("x").gt(lit(5))), Box::new(lit("b"))), ], - Some(Box::new(col("c"))), + Some(Box::new(lit("c"))), ))), - col("a") + lit("a") ); - // CASE WHEN true THEN col("a") END --> col("a") (no else clause) + // CASE WHEN true THEN 'a' END --> 'a' (no else clause) assert_eq!( simplify(Expr::Case(Case::new( None, - vec![(Box::new(lit(true)), Box::new(col("a")),)], + vec![(Box::new(lit(true)), Box::new(lit("a")),)], None, ))), - col("a") + lit("a") ); - // Negative test: CASE WHEN a THEN 1 ELSE 2 END should not be simplified + // Negative test: CASE WHEN c2 THEN 1 ELSE 2 END should not be simplified let expr = Expr::Case(Case::new( None, - vec![(Box::new(col("a")), Box::new(lit(1)))], + vec![(Box::new(col("c2")), Box::new(lit(1)))], Some(Box::new(lit(2))), )); assert_eq!(simplify(expr.clone()), expr); @@ -3869,10 +3885,10 @@ mod tests { )); assert_ne!(simplify(expr), lit(1)); - // Negative test: CASE WHEN col("x") > 5 THEN 1 ELSE 2 END should not be simplified + // Negative test: CASE WHEN col('c1') > 5 THEN 1 ELSE 2 END should not be simplified let expr = Expr::Case(Case::new( None, - vec![(Box::new(col("x").gt(lit(5))), Box::new(lit(1)))], + vec![(Box::new(col("c1").gt(lit(5))), Box::new(lit(1)))], Some(Box::new(lit(2))), )); assert_eq!(simplify(expr.clone()), expr); @@ -3880,76 +3896,124 @@ mod tests { #[test] fn simplify_expr_case_when_any_true() { - // CASE WHEN x > 0 THEN a WHEN true THEN b ELSE c END --> CASE WHEN x > 0 THEN a ELSE b END + // CASE WHEN c3 > 0 THEN 'a' WHEN true THEN 'b' ELSE 'c' END --> CASE WHEN c3 > 0 THEN 'a' ELSE 'b' END assert_eq!( simplify(Expr::Case(Case::new( None, vec![ - (Box::new(col("x").gt(lit(0))), Box::new(col("a"))), - (Box::new(lit(true)), Box::new(col("b"))), + (Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))), + (Box::new(lit(true)), Box::new(lit("b"))), ], - Some(Box::new(col("c"))), + Some(Box::new(lit("c"))), ))), Expr::Case(Case::new( None, - vec![(Box::new(col("x").gt(lit(0))), Box::new(col("a")))], -
Some(Box::new(col("b"))), + vec![(Box::new(col("c3").gt(lit(0))), Box::new(lit("a")))], + Some(Box::new(lit("b"))), )) ); - // CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c WHEN z = 0 THEN d ELSE e END - // --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END + // CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' WHEN true THEN 'c' WHEN c3 = 0 THEN 'd' ELSE 'e' END + // --> CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' ELSE 'c' END assert_eq!( simplify(Expr::Case(Case::new( None, vec![ - (Box::new(col("x").gt(lit(0))), Box::new(col("a"))), - (Box::new(col("y").lt(lit(0))), Box::new(col("b"))), - (Box::new(lit(true)), Box::new(col("c"))), - (Box::new(col("z").eq(lit(0))), Box::new(col("d"))), + (Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))), + (Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))), + (Box::new(lit(true)), Box::new(lit("c"))), + (Box::new(col("c3").eq(lit(0))), Box::new(lit("d"))), ], - Some(Box::new(col("e"))), + Some(Box::new(lit("e"))), ))), Expr::Case(Case::new( None, vec![ - (Box::new(col("x").gt(lit(0))), Box::new(col("a"))), - (Box::new(col("y").lt(lit(0))), Box::new(col("b"))), + (Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))), + (Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))), ], - Some(Box::new(col("c"))), + Some(Box::new(lit("c"))), )) ); - // CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c END (no else) - // --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END + // CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 WHEN true THEN 3 END (no else) + // --> CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 ELSE 3 END assert_eq!( simplify(Expr::Case(Case::new( None, vec![ - (Box::new(col("x").gt(lit(0))), Box::new(col("a"))), - (Box::new(col("y").lt(lit(0))), Box::new(col("b"))), - (Box::new(lit(true)), Box::new(col("c"))), + (Box::new(col("c3").gt(lit(0))), Box::new(lit(1))), + (Box::new(col("c4").lt(lit(0))), Box::new(lit(2))), + (Box::new(lit(true)), Box::new(lit(3))), ], None, ))), Expr::Case(Case::new( None, 
vec![ - (Box::new(col("x").gt(lit(0))), Box::new(col("a"))), - (Box::new(col("y").lt(lit(0))), Box::new(col("b"))), + (Box::new(col("c3").gt(lit(0))), Box::new(lit(1))), + (Box::new(col("c4").lt(lit(0))), Box::new(lit(2))), ], - Some(Box::new(col("c"))), + Some(Box::new(lit(3))), )) ); - // Negative test: CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END should not be simplified + // Negative test: CASE WHEN c3 > 0 THEN c3 WHEN c4 < 0 THEN 2 ELSE 3 END should not be simplified let expr = Expr::Case(Case::new( None, vec![ - (Box::new(col("x").gt(lit(0))), Box::new(col("a"))), - (Box::new(col("y").lt(lit(0))), Box::new(col("b"))), + (Box::new(col("c3").gt(lit(0))), Box::new(col("c3"))), + (Box::new(col("c4").lt(lit(0))), Box::new(lit(2))), ], - Some(Box::new(col("c"))), + Some(Box::new(lit(3))), + )); + assert_eq!(simplify(expr.clone()), expr); + } + + #[test] + fn simplify_expr_case_when_any_false() { + // CASE WHEN false THEN 'a' END --> NULL + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![(Box::new(lit(false)), Box::new(lit("a")))], + None, + ))), + Expr::Literal(ScalarValue::Utf8(None), None) + ); + + // CASE WHEN false THEN 2 ELSE 1 END --> 1 + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![(Box::new(lit(false)), Box::new(lit(2)))], + Some(Box::new(lit(1))), + ))), + lit(1), + ); + + // CASE WHEN c3 < 10 THEN 'b' WHEN false THEN c3 ELSE c4 END --> CASE WHEN c3 < 10 THEN 'b' ELSE c4 END + assert_eq!( + simplify(Expr::Case(Case::new( + None, + vec![ + (Box::new(col("c3").lt(lit(10))), Box::new(lit("b"))), + (Box::new(lit(false)), Box::new(col("c3"))), + ], + Some(Box::new(col("c4"))), + ))), + Expr::Case(Case::new( + None, + vec![(Box::new(col("c3").lt(lit(10))), Box::new(lit("b")))], + Some(Box::new(col("c4"))), + )) + ); + + // Negative test: CASE WHEN c3 = 4 THEN 1 ELSE 2 END should not be simplified + let expr = Expr::Case(Case::new( + None, + vec![(Box::new(col("c3").eq(lit(4))), Box::new(lit(1)))], + Some(Box::new(lit(2))), ));
assert_eq!(simplify(expr.clone()), expr); }