diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 98fec3f7c928b..8aebae18c1ae9 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -2478,6 +2478,37 @@ mod tests { regex_not_match(col("c1"), lit("^foo$")), col("c1").not_eq(lit("foo")), ); + + // regular expressions that match exact captured literals + assert_change( + regex_match(col("c1"), lit("^(foo|bar)$")), + col("c1").eq(lit("foo")).or(col("c1").eq(lit("bar"))), + ); + assert_change( + regex_not_match(col("c1"), lit("^(foo|bar)$")), + col("c1") + .not_eq(lit("foo")) + .and(col("c1").not_eq(lit("bar"))), + ); + assert_change( + regex_match(col("c1"), lit("^(foo)$")), + col("c1").eq(lit("foo")), + ); + assert_change( + regex_match(col("c1"), lit("^(foo|bar|baz)$")), + ((col("c1").eq(lit("foo"))).or(col("c1").eq(lit("bar")))) + .or(col("c1").eq(lit("baz"))), + ); + assert_change( + regex_match(col("c1"), lit("^(foo|bar|baz|qux)$")), + col("c1") + .in_list(vec![lit("foo"), lit("bar"), lit("baz"), lit("qux")], false), + ); + + // regular expressions that mismatch captured literals + assert_no_change(regex_match(col("c1"), lit("(foo|bar)"))); + assert_no_change(regex_match(col("c1"), lit("(foo|bar)*"))); + assert_no_change(regex_match(col("c1"), lit("^(foo|bar)*"))); assert_no_change(regex_match(col("c1"), lit("^foo|bar$"))); assert_no_change(regex_match(col("c1"), lit("^(foo)(bar)$"))); assert_no_change(regex_match(col("c1"), lit("^"))); diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index ca298abcfa00d..108f1774b42c0 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -17,7 +17,7 @@ use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{lit, BinaryExpr, Expr, Like, Operator}; -use regex_syntax::hir::{Hir, HirKind, Literal, Look}; +use regex_syntax::hir::{Capture, Hir, HirKind, Literal, Look}; /// Maximum number of regex alternations (`foo|bar|...`) that will be expanded into multiple `LIKE` expressions. const MAX_REGEX_ALTERNATIONS_EXPANSION: usize = 4; @@ -33,7 +33,6 @@ pub fn simplify_regex_expr( match regex_syntax::Parser::new().parse(pattern) { Ok(hir) => { let kind = hir.kind(); - if let HirKind::Alternation(alts) = kind { if alts.len() <= MAX_REGEX_ALTERNATIONS_EXPANSION { if let Some(expr) = lower_alt(&mode, &left, alts) { @@ -166,6 +165,33 @@ fn is_anchored_literal(v: &[Hir]) -> bool { .all(|h| matches!(h.kind(), HirKind::Literal(_))) } +/// returns true if the elements in a `Concat` pattern are: +/// - `[Look::Start, Capture(Alternation(Literals...)), Look::End]` +fn is_anchored_capture(v: &[Hir]) -> bool { + if v.len() != 3 + || !matches!( + (v.first().unwrap().kind(), v.last().unwrap().kind()), + (&HirKind::Look(Look::Start), &HirKind::Look(Look::End)) + ) + { + return false; + } + + if let HirKind::Capture(cap, ..) = v[1].kind() { + let Capture { sub, .. } = cap; + if let HirKind::Alternation(alters) = sub.kind() { + let has_non_literal = alters + .iter() + .any(|v| !matches!(v.kind(), &HirKind::Literal(_))); + if has_non_literal { + return false; + } + } + } + + true +} + /// extracts a string literal expression assuming that [`is_anchored_literal`] /// returned true. fn anchored_literal_to_expr(v: &[Hir]) -> Option { @@ -179,6 +205,40 @@ fn anchored_literal_to_expr(v: &[Hir]) -> Option { } } +fn anchored_alternation_to_exprs(v: &[Hir]) -> Option> { + if 3 != v.len() { + return None; + } + + if let HirKind::Capture(cap, ..) = v[1].kind() { + let Capture { sub, .. } = cap; + if let HirKind::Alternation(alters) = sub.kind() { + let mut literals = Vec::with_capacity(alters.len()); + for hir in alters { + let mut is_safe = false; + if let HirKind::Literal(l) = hir.kind() { + if let Some(safe_literal) = str_from_literal(l).map(lit) { + literals.push(safe_literal); + is_safe = true; + } + } + + if !is_safe { + return None; + } + } + + return Some(literals); + } else if let HirKind::Literal(l) = sub.kind() { + if let Some(safe_literal) = str_from_literal(l).map(lit) { + return Some(vec![safe_literal]); + } + return None; + } + } + None +} + fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { match hir.kind() { HirKind::Empty => { @@ -189,10 +249,13 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { return Some(mode.expr(Box::new(left.clone()), format!("%{s}%"))); } HirKind::Concat(inner) if is_anchored_literal(inner) => { - let right = anchored_literal_to_expr(inner)?; - return Some( - mode.expr_matches_literal(Box::new(left.clone()), Box::new(right)), - ); + return anchored_literal_to_expr(inner).map(|right| { + mode.expr_matches_literal(Box::new(left.clone()), Box::new(right)) + }); + } + HirKind::Concat(inner) if is_anchored_capture(inner) => { + return anchored_alternation_to_exprs(inner) + .map(|right| left.clone().in_list(right, mode.not)); } HirKind::Concat(inner) => { if let Some(pattern) = collect_concat_to_like_string(inner) { @@ -201,7 +264,6 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { } _ => {} } - None }