diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index f629eaf95b5f4..7c9e976d1d0bb 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -251,6 +251,25 @@ fn simplify(expr: &Expr) -> Expr { op: *op, right: Box::new(simplify(right)), }, + Expr::InList { + expr, + list, + negated, + } if list.len() >= 1 && list.len() <= 2 => { + if *negated { + list.iter() + .skip(1) + .fold((**expr).clone().not_eq(list[0].clone()), |acc, e| { + acc.and((**expr).clone().not_eq(e.clone())) + }) + } else { + list.iter() + .skip(1) + .fold((**expr).clone().eq(list[0].clone()), |acc, e| { + acc.or((**expr).clone().eq(e.clone())) + }) + } + } _ => expr.clone(), } } @@ -293,7 +312,9 @@ impl SimplifyExpressions { #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::{and, binary_expr, col, lit, Expr, LogicalPlanBuilder}; + use crate::logical_plan::{ + and, binary_expr, col, in_list, lit, Expr, LogicalPlanBuilder, + }; use crate::test::*; fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { @@ -396,6 +417,33 @@ mod tests { Ok(()) } + #[test] + fn simplify_inlist() -> Result<()> { + let expr = in_list(col("c"), vec![lit(1), lit(2)], false); + let expected = col("c").eq(lit(1)).or(col("c").eq(lit(2))); + + assert_eq!(simplify(&expr), expected); + Ok(()) + } + + #[test] + fn simplify_inlist_negated() -> Result<()> { + let expr = in_list(col("c"), vec![lit(1), lit(2)], true); + let expected = col("c").not_eq(lit(1)).and(col("c").not_eq(lit(2))); + + assert_eq!(simplify(&expr), expected); + Ok(()) + } + + #[test] + fn simplify_inlist_single() -> Result<()> { + let expr = in_list(col("c"), vec![lit(1)], false); + let expected = col("c").eq(lit(1)); + + assert_eq!(simplify(&expr), expected); + Ok(()) + } + #[test] fn test_simplify_simple_and() -> Result<()> { // (c > 5) AND (c > 5)