diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 9b2a5596827d0..b288706a54c9d 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -27,7 +27,7 @@ use arrow::datatypes::{ DataType, Field, Fields, Schema, SchemaBuilder, SchemaRef, TimeUnit, }; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::tree_node::TransformedResult; use datafusion_common::{plan_err, DFSchema, Result, ScalarValue, TableReference}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{ @@ -37,7 +37,6 @@ use datafusion_expr::{ use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::simplify_expressions::GuaranteeRewriter; use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; @@ -45,6 +44,7 @@ use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use chrono::DateTime; +use datafusion_expr::expr_rewriter::rewrite_with_guarantees; use datafusion_functions::datetime; #[cfg(test)] @@ -304,8 +304,6 @@ fn test_inequalities_non_null_bounded() { ), ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - // (original_expr, expected_simplification) let simplified_cases = &[ (col("x").lt(lit(0)), false), @@ -337,7 +335,7 @@ fn test_inequalities_non_null_bounded() { ), ]; - validate_simplified_cases(&mut rewriter, simplified_cases); + validate_simplified_cases(&guarantees, simplified_cases); let unchanged_cases = &[ col("x").gt(lit(2)), @@ -348,16 +346,20 @@ fn test_inequalities_non_null_bounded() { col("x").not_between(lit(3), lit(10)), ]; - validate_unchanged_cases(&mut rewriter, unchanged_cases); + validate_unchanged_cases(&guarantees, unchanged_cases); } -fn validate_simplified_cases(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)]) -where +fn validate_simplified_cases( + guarantees: &[(Expr, NullableInterval)], + cases: &[(Expr, T)], +) where ScalarValue: From, T: Clone, { for (expr, expected_value) in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees) + .data() + .unwrap(); let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, @@ -365,9 +367,11 @@ where ); } } -fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { +fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Expr]) { for expr in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); + let output = rewrite_with_guarantees(expr.clone(), guarantees) + .data() + .unwrap(); assert_eq!( &output, expr, "{expr} was simplified to {output}, but expected it to be unchanged" diff --git a/datafusion/expr/src/expr_rewriter/guarantees.rs b/datafusion/expr/src/expr_rewriter/guarantees.rs new file mode 100644 index 0000000000000..b8589a17df3e4 --- /dev/null +++ b/datafusion/expr/src/expr_rewriter/guarantees.rs @@ -0,0 +1,668 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Rewrite expressions based on external expression value range guarantees. + +use crate::{expr::InList, lit, Between, BinaryExpr, Expr}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::{DataFusionError, HashMap, Result, ScalarValue}; +use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval}; +use std::borrow::Cow; + +/// Rewrite expressions to incorporate guarantees. +/// +/// See [`rewrite_with_guarantees`] for more information +pub struct GuaranteeRewriter<'a> { + guarantees: HashMap<&'a Expr, &'a NullableInterval>, +} + +impl<'a> GuaranteeRewriter<'a> { + pub fn new( + guarantees: impl IntoIterator, + ) -> Self { + Self { + guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), + } + } +} + +/// Rewrite expressions to incorporate guarantees. +/// +/// Guarantees are a mapping from an expression (which currently is always a +/// column reference) to a [NullableInterval] that represents the known possible +/// values of the expression. +/// +/// Rewriting expressions using this type of guarantee can make the work of other expression +/// simplifications, like const evaluation, easier. +/// +/// For example, if we know that a column is not null and has values in the +/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`. +/// +/// If the set of guarantees will be used to rewrite more than one expression, consider using +/// [rewrite_with_guarantees_map] instead. +/// +/// A full example of using this rewrite rule can be found in +/// [`ExprSimplifier::with_guarantees()`](https://docs.rs/datafusion/latest/datafusion/optimizer/simplify_expressions/struct.ExprSimplifier.html#method.with_guarantees). +pub fn rewrite_with_guarantees<'a>( + expr: Expr, + guarantees: impl IntoIterator, +) -> Result> { + let guarantees_map: HashMap<&Expr, &NullableInterval> = + guarantees.into_iter().map(|(k, v)| (k, v)).collect(); + rewrite_with_guarantees_map(expr, &guarantees_map) +} + +/// Rewrite expressions to incorporate guarantees. +/// +/// Guarantees are a mapping from an expression (which currently is always a +/// column reference) to a [NullableInterval]. The interval represents the known +/// possible values of the column. +/// +/// For example, if we know that a column is not null and has values in the +/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`. +pub fn rewrite_with_guarantees_map<'a>( + expr: Expr, + guarantees: &'a HashMap<&'a Expr, &'a NullableInterval>, +) -> Result> { + if guarantees.is_empty() { + return Ok(Transformed::no(expr)); + } + + expr.transform_up(|e| rewrite_expr(e, guarantees)) +} + +impl TreeNodeRewriter for GuaranteeRewriter<'_> { + type Node = Expr; + + fn f_up(&mut self, expr: Expr) -> Result> { + if self.guarantees.is_empty() { + return Ok(Transformed::no(expr)); + } + + rewrite_expr(expr, &self.guarantees) + } +} + +fn rewrite_expr( + expr: Expr, + guarantees: &HashMap<&Expr, &NullableInterval>, +) -> Result> { + // If an expression collapses to a single value, replace it with a literal + if let Some(interval) = guarantees.get(&expr) { + if let Some(value) = interval.single_value() { + return Ok(Transformed::yes(lit(value))); + } + } + + let result = match expr { + Expr::IsNull(inner) => match guarantees.get(inner.as_ref()) { + Some(NullableInterval::Null { .. }) => Transformed::yes(lit(true)), + Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(false)), + _ => Transformed::no(Expr::IsNull(inner)), + }, + Expr::IsNotNull(inner) => match guarantees.get(inner.as_ref()) { + Some(NullableInterval::Null { .. }) => Transformed::yes(lit(false)), + Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(true)), + _ => Transformed::no(Expr::IsNotNull(inner)), + }, + Expr::Between(b) => rewrite_between(b, guarantees)?, + Expr::BinaryExpr(b) => rewrite_binary_expr(b, guarantees)?, + Expr::InList(i) => rewrite_inlist(i, guarantees)?, + expr => Transformed::no(expr), + }; + Ok(result) +} + +fn rewrite_between( + between: Between, + guarantees: &HashMap<&Expr, &NullableInterval>, +) -> Result> { + let (Some(expr_interval), Expr::Literal(low, _), Expr::Literal(high, _)) = ( + guarantees.get(between.expr.as_ref()), + between.low.as_ref(), + between.high.as_ref(), + ) else { + return Ok(Transformed::no(Expr::Between(between))); + }; + + // Ensure that, if low or high are null, their type matches the other bound + let low = ensure_typed_null(low, high)?; + let high = ensure_typed_null(high, &low)?; + + let Ok(between_interval) = Interval::try_new(low, high) else { + // If we can't create an interval from the literals, be conservative and simply leave + // the expression unmodified. + return Ok(Transformed::no(Expr::Between(between))); + }; + + if between_interval.lower().is_null() && between_interval.upper().is_null() { + return Ok(Transformed::yes(lit(between_interval.lower().clone()))); + } + + let expr_interval = match expr_interval { + NullableInterval::Null { datatype } => { + // Value is guaranteed to be null, so we can simplify to null. + return Ok(Transformed::yes(lit( + ScalarValue::try_new_null(datatype).unwrap_or(ScalarValue::Null) + ))); + } + NullableInterval::MaybeNull { .. } => { + // Value may or may not be null, so we can't simplify the expression. + return Ok(Transformed::no(Expr::Between(between))); + } + NullableInterval::NotNull { values } => values, + }; + + let result = if between_interval.lower().is_null() { + // (NOT) BETWEEN NULL AND + let upper_bound = Interval::from(between_interval.upper().clone()); + if expr_interval.gt(&upper_bound)?.eq(&Interval::TRUE) { + // if > high, then certainly false + Transformed::yes(lit(between.negated)) + } else if expr_interval.lt_eq(&upper_bound)?.eq(&Interval::TRUE) { + // if <= high, then certainly null + Transformed::yes(lit(ScalarValue::try_new_null(&expr_interval.data_type()) + .unwrap_or(ScalarValue::Null))) + } else { + // otherwise unknown + Transformed::no(Expr::Between(between)) + } + } else if between_interval.upper().is_null() { + // (NOT) BETWEEN AND NULL + let lower_bound = Interval::from(between_interval.lower().clone()); + if expr_interval.lt(&lower_bound)?.eq(&Interval::TRUE) { + // if < low, then certainly false + Transformed::yes(lit(between.negated)) + } else if expr_interval.gt_eq(&lower_bound)?.eq(&Interval::TRUE) { + // if >= low, then certainly null + Transformed::yes(lit(ScalarValue::try_new_null(&expr_interval.data_type()) + .unwrap_or(ScalarValue::Null))) + } else { + // otherwise unknown + Transformed::no(Expr::Between(between)) + } + } else { + let contains = between_interval.contains(expr_interval)?; + if contains.eq(&Interval::TRUE) { + Transformed::yes(lit(!between.negated)) + } else if contains.eq(&Interval::FALSE) { + Transformed::yes(lit(between.negated)) + } else { + Transformed::no(Expr::Between(between)) + } + }; + Ok(result) +} + +fn ensure_typed_null( + value: &ScalarValue, + other: &ScalarValue, +) -> Result { + Ok( + if value.data_type().is_null() && !other.data_type().is_null() { + ScalarValue::try_new_null(&other.data_type())? + } else { + value.clone() + }, + ) +} + +fn rewrite_binary_expr( + binary: BinaryExpr, + guarantees: &HashMap<&Expr, &NullableInterval>, +) -> Result, DataFusionError> { + // The left or right side of expression might either have a guarantee + // or be a literal. Either way, we can resolve them to a NullableInterval. + let left_interval = guarantees + .get(binary.left.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value, _) = binary.left.as_ref() { + Some(Cow::Owned(value.clone().into())) + } else { + None + } + }); + let right_interval = guarantees + .get(binary.right.as_ref()) + .map(|interval| Cow::Borrowed(*interval)) + .or_else(|| { + if let Expr::Literal(value, _) = binary.right.as_ref() { + Some(Cow::Owned(value.clone().into())) + } else { + None + } + }); + + if let (Some(left_interval), Some(right_interval)) = (left_interval, right_interval) { + let result = left_interval.apply_operator(&binary.op, right_interval.as_ref())?; + if result.is_certainly_true() { + return Ok(Transformed::yes(lit(true))); + } else if result.is_certainly_false() { + return Ok(Transformed::yes(lit(false))); + } + } + Ok(Transformed::no(Expr::BinaryExpr(binary))) +} + +fn rewrite_inlist( + inlist: InList, + guarantees: &HashMap<&Expr, &NullableInterval>, +) -> Result, DataFusionError> { + let Some(interval) = guarantees.get(inlist.expr.as_ref()) else { + return Ok(Transformed::no(Expr::InList(inlist))); + }; + + let InList { + expr, + list, + negated, + } = inlist; + + // Can remove items from the list that don't match the guarantee + let list: Vec = list + .into_iter() + .filter_map(|expr| { + if let Expr::Literal(item, _) = &expr { + match interval.contains(NullableInterval::from(item.clone())) { + // If we know for certain the value isn't in the column's interval, + // we can skip checking it. + Ok(interval) if interval.is_certainly_false() => None, + Ok(_) => Some(Ok(expr)), + Err(e) => Some(Err(e)), + } + } else { + Some(Ok(expr)) + } + }) + .collect::>()?; + + Ok(Transformed::yes(Expr::InList(InList { + expr, + list, + negated, + }))) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::{col, Operator}; + use datafusion_common::tree_node::TransformedResult; + use datafusion_common::ScalarValue; + + #[test] + fn test_not_null_guarantee() { + // IsNull / IsNotNull can be rewritten to true / false + let guarantees = [ + // Note: AlwaysNull case handled by test_column_single_value test, + // since it's a special case of a column with a single value. + ( + col("x"), + NullableInterval::NotNull { + values: Interval::make(Some(1), Some(3)).unwrap(), + }, + ), + ]; + + let is_null_cases = vec![ + // x IS NULL => guaranteed false + (col("x").is_null(), Some(lit(false))), + // x IS NOT NULL => guaranteed true + (col("x").is_not_null(), Some(lit(true))), + // [1, 3] BETWEEN 0 AND 10 => guaranteed true + (col("x").between(lit(0), lit(10)), Some(lit(true))), + // x BETWEEN 1 AND -2 => unknown (actually guaranteed false) + (col("x").between(lit(1), lit(-2)), None), + // [1, 3] BETWEEN NULL AND 0 => guaranteed false + ( + col("x").between(lit(ScalarValue::Null), lit(0)), + Some(lit(false)), + ), + // [1, 3] BETWEEN NULL AND 1 => unknown + (col("x").between(lit(ScalarValue::Null), lit(1)), None), + // [1, 3] BETWEEN NULL AND 2 => unknown + (col("x").between(lit(ScalarValue::Null), lit(2)), None), + // [1, 3] BETWEEN NULL AND 3 => guaranteed NULL + ( + col("x").between(lit(ScalarValue::Null), lit(3)), + Some(lit(ScalarValue::Int32(None))), + ), + // [1, 3] BETWEEN NULL AND 4 => guaranteed NULL + ( + col("x").between(lit(ScalarValue::Null), lit(4)), + Some(lit(ScalarValue::Int32(None))), + ), + // [1, 3] BETWEEN 1 AND NULL => guaranteed NULL + ( + col("x").between(lit(0), lit(ScalarValue::Null)), + Some(lit(ScalarValue::Int32(None))), + ), + // [1, 3] BETWEEN 1 AND NULL => guaranteed NULL + ( + col("x").between(lit(1), lit(ScalarValue::Null)), + Some(lit(ScalarValue::Int32(None))), + ), + // [1, 3] BETWEEN 2 AND NULL => unknown + (col("x").between(lit(2), lit(ScalarValue::Null)), None), + // [1, 3] BETWEEN 3 AND NULL => unknown + (col("x").between(lit(3), lit(ScalarValue::Null)), None), + // [1, 3] BETWEEN 4 AND NULL => guaranteed false + ( + col("x").between(lit(4), lit(ScalarValue::Null)), + Some(lit(false)), + ), + // [1, 3] NOT BETWEEN NULL AND 0 => guaranteed false + ( + col("x").not_between(lit(ScalarValue::Null), lit(0)), + Some(lit(true)), + ), + // [1, 3] NOT BETWEEN NULL AND 1 => unknown + (col("x").not_between(lit(ScalarValue::Null), lit(1)), None), + // [1, 3] NOT BETWEEN NULL AND 2 => unknown + (col("x").not_between(lit(ScalarValue::Null), lit(2)), None), + // [1, 3] NOT BETWEEN NULL AND 3 => guaranteed NULL + ( + col("x").not_between(lit(ScalarValue::Null), lit(3)), + Some(lit(ScalarValue::Int32(None))), + ), + // [1, 3] NOT BETWEEN NULL AND 4 => guaranteed NULL + ( + col("x").not_between(lit(ScalarValue::Null), lit(4)), + Some(lit(ScalarValue::Int32(None))), + ), + // [1, 3] NOT BETWEEN 1 AND NULL => guaranteed NULL + ( + col("x").not_between(lit(0), lit(ScalarValue::Null)), + Some(lit(ScalarValue::Int32(None))), + ), + // [1, 3] NOT BETWEEN 1 AND NULL => guaranteed NULL + ( + col("x").not_between(lit(1), lit(ScalarValue::Null)), + Some(lit(ScalarValue::Int32(None))), + ), + // [1, 3] NOT BETWEEN 2 AND NULL => unknown + (col("x").not_between(lit(2), lit(ScalarValue::Null)), None), + // [1, 3] NOT BETWEEN 3 AND NULL => unknown + (col("x").not_between(lit(3), lit(ScalarValue::Null)), None), + // [1, 3] NOT BETWEEN 4 AND NULL => guaranteed false + ( + col("x").not_between(lit(4), lit(ScalarValue::Null)), + Some(lit(true)), + ), + ]; + + for case in is_null_cases { + let output = rewrite_with_guarantees(case.0.clone(), guarantees.iter()) + .data() + .unwrap(); + let expected = match case.1 { + None => case.0.clone(), + Some(expected) => expected, + }; + + assert_eq!(output, expected, "Failed for {}", case.0); + } + } + + fn validate_simplified_cases( + guarantees: &[(Expr, NullableInterval)], + cases: &[(Expr, T)], + ) where + ScalarValue: From, + T: Clone, + { + for (expr, expected_value) in cases { + let output = rewrite_with_guarantees(expr.clone(), guarantees.iter()) + .data() + .unwrap(); + let expected = lit(ScalarValue::from(expected_value.clone())); + assert_eq!( + output, expected, + "{expr} simplified to {output}, but expected {expected}" + ); + } + } + + fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Expr]) { + for expr in cases { + let output = rewrite_with_guarantees(expr.clone(), guarantees.iter()) + .data() + .unwrap(); + assert_eq!( + &output, expr, + "{expr} was simplified to {output}, but expected it to be unchanged" + ); + } + } + + #[test] + fn test_inequalities_non_null_unbounded() { + let guarantees = [ + // y ∈ [2021-01-01, ∞) (not null) + ( + col("x"), + NullableInterval::NotNull { + values: Interval::try_new( + ScalarValue::Date32(Some(18628)), + ScalarValue::Date32(None), + ) + .unwrap(), + }, + ), + ]; + + // (original_expr, expected_simplification) + let simplified_cases = &[ + (col("x").lt(lit(ScalarValue::Date32(Some(18628)))), false), + (col("x").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false), + (col("x").gt(lit(ScalarValue::Date32(Some(18627)))), true), + (col("x").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true), + (col("x").eq(lit(ScalarValue::Date32(Some(17000)))), false), + (col("x").not_eq(lit(ScalarValue::Date32(Some(17000)))), true), + ( + col("x").between( + lit(ScalarValue::Date32(Some(16000))), + lit(ScalarValue::Date32(Some(17000))), + ), + false, + ), + ( + col("x").not_between( + lit(ScalarValue::Date32(Some(16000))), + lit(ScalarValue::Date32(Some(17000))), + ), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Null)), + }), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Date32(Some(17000)))), + }), + true, + ), + ]; + + validate_simplified_cases(&guarantees, simplified_cases); + + let unchanged_cases = &[ + col("x").lt(lit(ScalarValue::Date32(Some(19000)))), + col("x").lt_eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").gt(lit(ScalarValue::Date32(Some(19000)))), + col("x").gt_eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").not_eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").between( + lit(ScalarValue::Date32(Some(18000))), + lit(ScalarValue::Date32(Some(19000))), + ), + col("x").not_between( + lit(ScalarValue::Date32(Some(18000))), + lit(ScalarValue::Date32(Some(19000))), + ), + ]; + + validate_unchanged_cases(&guarantees, unchanged_cases); + } + + #[test] + fn test_inequalities_maybe_null() { + let guarantees = [ + // x ∈ ("abc", "def"]? (maybe null) + ( + col("x"), + NullableInterval::MaybeNull { + values: Interval::try_new( + ScalarValue::from("abc"), + ScalarValue::from("def"), + ) + .unwrap(), + }, + ), + ]; + + // (original_expr, expected_simplification) + let simplified_cases = &[ + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit("z")), + }), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsNotDistinctFrom, + right: Box::new(lit("z")), + }), + false, + ), + ]; + + validate_simplified_cases(&guarantees, simplified_cases); + + let unchanged_cases = &[ + col("x").lt(lit("z")), + col("x").lt_eq(lit("z")), + col("x").gt(lit("a")), + col("x").gt_eq(lit("a")), + col("x").eq(lit("abc")), + col("x").not_eq(lit("a")), + col("x").between(lit("a"), lit("z")), + col("x").not_between(lit("a"), lit("z")), + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Null)), + }), + ]; + + validate_unchanged_cases(&guarantees, unchanged_cases); + } + + #[test] + fn test_column_single_value() { + let scalars = [ + ScalarValue::Null, + ScalarValue::Int32(Some(1)), + ScalarValue::Boolean(Some(true)), + ScalarValue::Boolean(None), + ScalarValue::from("abc"), + ScalarValue::LargeUtf8(Some("def".to_string())), + ScalarValue::Date32(Some(18628)), + ScalarValue::Date32(None), + ScalarValue::Decimal128(Some(1000), 19, 2), + ]; + + for scalar in scalars { + let guarantees = [(col("x"), NullableInterval::from(scalar.clone()))]; + + let output = rewrite_with_guarantees(col("x"), guarantees.iter()) + .data() + .unwrap(); + assert_eq!(output, Expr::Literal(scalar.clone(), None)); + } + } + + #[test] + fn test_in_list() { + let guarantees = [ + // x ∈ [1, 10] (not null) + ( + col("x"), + NullableInterval::NotNull { + values: Interval::try_new( + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(10)), + ) + .unwrap(), + }, + ), + ]; + + // These cases should be simplified so the list doesn't contain any + // values the guarantee says are outside the range. + // (column_name, starting_list, negated, expected_list) + let cases = &[ + // x IN (9, 11) => x IN (9) + ("x", vec![9, 11], false, vec![9]), + // x IN (10, 2) => x IN (10, 2) + ("x", vec![10, 2], false, vec![10, 2]), + // x NOT IN (9, 11) => x NOT IN (9) + ("x", vec![9, 11], true, vec![9]), + // x NOT IN (0, 22) => x NOT IN () + ("x", vec![0, 22], true, vec![]), + ]; + + for (column_name, starting_list, negated, expected_list) in cases { + let expr = col(*column_name).in_list( + starting_list + .iter() + .map(|v| lit(ScalarValue::Int32(Some(*v)))) + .collect(), + *negated, + ); + let output = rewrite_with_guarantees(expr.clone(), guarantees.iter()) + .data() + .unwrap(); + let expected_list = expected_list + .iter() + .map(|v| lit(ScalarValue::Int32(Some(*v)))) + .collect(); + assert_eq!( + output, + Expr::InList(InList { + expr: Box::new(col(*column_name)), + list: expected_list, + negated: *negated, + }) + ); + } + } +} diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 9c3c5df7007ff..31759f1cc9cfe 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -31,7 +31,12 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::TableReference; use datafusion_common::{Column, DFSchema, Result}; +mod guarantees; +pub use guarantees::rewrite_with_guarantees; +pub use guarantees::rewrite_with_guarantees_map; +pub use guarantees::GuaranteeRewriter; mod order_by; + pub use order_by::rewrite_sort_cols_by_aggs; /// Trait for rewriting [`Expr`]s into function calls. diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c7912bbf70b05..366c99ce8f28b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -31,6 +31,7 @@ use datafusion_common::{ cast::{as_large_list_array, as_list_array}, metadata::FieldMetadata, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, + HashMap, }; use datafusion_common::{ exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, @@ -50,7 +51,6 @@ use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionP use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; use crate::analyzer::type_coercion::TypeCoercionRewriter; -use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::unwrap_cast::{ is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary, @@ -58,6 +58,7 @@ use crate::simplify_expressions::unwrap_cast::{ unwrap_cast_in_comparison_for_binary, }; use crate::simplify_expressions::SimplifyInfo; +use datafusion_expr::expr_rewriter::rewrite_with_guarantees_map; use datafusion_expr_common::casts::try_cast_literal_to_type; use indexmap::IndexSet; use regex::Regex; @@ -226,7 +227,8 @@ impl ExprSimplifier { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); - let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); + let guarantees_map: HashMap<&Expr, &NullableInterval> = + self.guarantees.iter().map(|(k, v)| (k, v)).collect(); if self.canonicalize { expr = expr.rewrite(&mut Canonicalizer::new()).data()? @@ -243,7 +245,9 @@ impl ExprSimplifier { } = expr .rewrite(&mut const_evaluator)? .transform_data(|expr| expr.rewrite(&mut simplifier))? - .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?; + .transform_data(|expr| { + rewrite_with_guarantees_map(expr, &guarantees_map) + })?; expr = data; num_cycles += 1; // Track if any transformation occurred diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs deleted file mode 100644 index 515fd29003af9..0000000000000 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ /dev/null @@ -1,476 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Simplifier implementation for [`ExprSimplifier::with_guarantees()`] -//! -//! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees - -use std::{borrow::Cow, collections::HashMap}; - -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; -use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; - -/// Rewrite expressions to incorporate guarantees. -/// -/// Guarantees are a mapping from an expression (which currently is always a -/// column reference) to a [NullableInterval]. The interval represents the known -/// possible values of the column. Using these known values, expressions are -/// rewritten so they can be simplified using `ConstEvaluator` and `Simplifier`. -/// -/// For example, if we know that a column is not null and has values in the -/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`. -/// -/// See a full example in [`ExprSimplifier::with_guarantees()`]. -/// -/// [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees -pub struct GuaranteeRewriter<'a> { - guarantees: HashMap<&'a Expr, &'a NullableInterval>, -} - -impl<'a> GuaranteeRewriter<'a> { - pub fn new( - guarantees: impl IntoIterator, - ) -> Self { - Self { - // TODO: Clippy wants the "map" call removed, but doing so generates - // a compilation error. Remove the clippy directive once this - // issue is fixed. - #[allow(clippy::map_identity)] - guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), - } - } -} - -impl TreeNodeRewriter for GuaranteeRewriter<'_> { - type Node = Expr; - - fn f_up(&mut self, expr: Expr) -> Result> { - if self.guarantees.is_empty() { - return Ok(Transformed::no(expr)); - } - - match &expr { - Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(true))), - Some(NullableInterval::NotNull { .. }) => { - Ok(Transformed::yes(lit(false))) - } - _ => Ok(Transformed::no(expr)), - }, - Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(false))), - Some(NullableInterval::NotNull { .. }) => Ok(Transformed::yes(lit(true))), - _ => Ok(Transformed::no(expr)), - }, - Expr::Between(Between { - expr: inner, - negated, - low, - high, - }) => { - if let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = ( - self.guarantees.get(inner.as_ref()), - low.as_ref(), - high.as_ref(), - ) { - let expr_interval = NullableInterval::NotNull { - values: Interval::try_new(low.clone(), high.clone())?, - }; - - let contains = expr_interval.contains(*interval)?; - - if contains.is_certainly_true() { - Ok(Transformed::yes(lit(!negated))) - } else if contains.is_certainly_false() { - Ok(Transformed::yes(lit(*negated))) - } else { - Ok(Transformed::no(expr)) - } - } else { - Ok(Transformed::no(expr)) - } - } - - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - // The left or right side of expression might either have a guarantee - // or be a literal. Either way, we can resolve them to a NullableInterval. - let left_interval = self - .guarantees - .get(left.as_ref()) - .map(|interval| Cow::Borrowed(*interval)) - .or_else(|| { - if let Expr::Literal(value, _) = left.as_ref() { - Some(Cow::Owned(value.clone().into())) - } else { - None - } - }); - let right_interval = self - .guarantees - .get(right.as_ref()) - .map(|interval| Cow::Borrowed(*interval)) - .or_else(|| { - if let Expr::Literal(value, _) = right.as_ref() { - Some(Cow::Owned(value.clone().into())) - } else { - None - } - }); - - match (left_interval, right_interval) { - (Some(left_interval), Some(right_interval)) => { - let result = - left_interval.apply_operator(op, right_interval.as_ref())?; - if result.is_certainly_true() { - Ok(Transformed::yes(lit(true))) - } else if result.is_certainly_false() { - Ok(Transformed::yes(lit(false))) - } else { - Ok(Transformed::no(expr)) - } - } - _ => Ok(Transformed::no(expr)), - } - } - - // Columns (if interval is collapsed to a single value) - Expr::Column(_) => { - if let Some(interval) = self.guarantees.get(&expr) { - Ok(Transformed::yes(interval.single_value().map_or(expr, lit))) - } else { - Ok(Transformed::no(expr)) - } - } - - Expr::InList(InList { - expr: inner, - list, - negated, - }) => { - if let Some(interval) = self.guarantees.get(inner.as_ref()) { - // Can remove items from the list that don't match the guarantee - let new_list: Vec = list - .iter() - .filter_map(|expr| { - if let Expr::Literal(item, _) = expr { - match interval - .contains(NullableInterval::from(item.clone())) - { - // If we know for certain the value isn't in the column's interval, - // we can skip checking it. - Ok(interval) if interval.is_certainly_false() => None, - Ok(_) => Some(Ok(expr.clone())), - Err(e) => Some(Err(e)), - } - } else { - Some(Ok(expr.clone())) - } - }) - .collect::>()?; - - Ok(Transformed::yes(Expr::InList(InList { - expr: inner.clone(), - list: new_list, - negated: *negated, - }))) - } else { - Ok(Transformed::no(expr)) - } - } - - _ => Ok(Transformed::no(expr)), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use arrow::datatypes::DataType; - use datafusion_common::tree_node::{TransformedResult, TreeNode}; - use datafusion_common::ScalarValue; - use datafusion_expr::{col, Operator}; - - #[test] - fn test_null_handling() { - // IsNull / IsNotNull can be rewritten to true / false - let guarantees = [ - // Note: AlwaysNull case handled by test_column_single_value test, - // since it's a special case of a column with a single value. - ( - col("x"), - NullableInterval::NotNull { - values: Interval::make_unbounded(&DataType::Boolean).unwrap(), - }, - ), - ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - - // x IS NULL => guaranteed false - let expr = col("x").is_null(); - let output = expr.rewrite(&mut rewriter).data().unwrap(); - assert_eq!(output, lit(false)); - - // x IS NOT NULL => guaranteed true - let expr = col("x").is_not_null(); - let output = expr.rewrite(&mut rewriter).data().unwrap(); - assert_eq!(output, lit(true)); - } - - fn validate_simplified_cases(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)]) - where - ScalarValue: From, - T: Clone, - { - for (expr, expected_value) in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); - let expected = lit(ScalarValue::from(expected_value.clone())); - assert_eq!( - output, expected, - "{expr} simplified to {output}, but expected {expected}" - ); - } - } - - fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { - for expr in cases { - let output = expr.clone().rewrite(rewriter).data().unwrap(); - assert_eq!( - &output, expr, - "{expr} was simplified to {output}, but expected it to be unchanged" - ); - } - } - - #[test] - fn test_inequalities_non_null_unbounded() { - let guarantees = [ - // y ∈ [2021-01-01, ∞) (not null) - ( - col("x"), - NullableInterval::NotNull { - values: Interval::try_new( - ScalarValue::Date32(Some(18628)), - ScalarValue::Date32(None), - ) - .unwrap(), - }, - ), - ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - - // (original_expr, expected_simplification) - let simplified_cases = &[ - (col("x").lt(lit(ScalarValue::Date32(Some(18628)))), false), - (col("x").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false), - (col("x").gt(lit(ScalarValue::Date32(Some(18627)))), true), - (col("x").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true), - (col("x").eq(lit(ScalarValue::Date32(Some(17000)))), false), - (col("x").not_eq(lit(ScalarValue::Date32(Some(17000)))), true), - ( - col("x").between( - lit(ScalarValue::Date32(Some(16000))), - lit(ScalarValue::Date32(Some(17000))), - ), - false, - ), - ( - col("x").not_between( - lit(ScalarValue::Date32(Some(16000))), - lit(ScalarValue::Date32(Some(17000))), - ), - true, - ), - ( - Expr::BinaryExpr(BinaryExpr { - left: Box::new(col("x")), - op: Operator::IsDistinctFrom, - right: Box::new(lit(ScalarValue::Null)), - }), - true, - ), - ( - Expr::BinaryExpr(BinaryExpr { - left: Box::new(col("x")), - op: Operator::IsDistinctFrom, - right: Box::new(lit(ScalarValue::Date32(Some(17000)))), - }), - true, - ), - ]; - - validate_simplified_cases(&mut rewriter, simplified_cases); - - let unchanged_cases = &[ - col("x").lt(lit(ScalarValue::Date32(Some(19000)))), - col("x").lt_eq(lit(ScalarValue::Date32(Some(19000)))), - col("x").gt(lit(ScalarValue::Date32(Some(19000)))), - col("x").gt_eq(lit(ScalarValue::Date32(Some(19000)))), - col("x").eq(lit(ScalarValue::Date32(Some(19000)))), - col("x").not_eq(lit(ScalarValue::Date32(Some(19000)))), - col("x").between( - lit(ScalarValue::Date32(Some(18000))), - lit(ScalarValue::Date32(Some(19000))), - ), - col("x").not_between( - lit(ScalarValue::Date32(Some(18000))), - lit(ScalarValue::Date32(Some(19000))), - ), - ]; - - validate_unchanged_cases(&mut rewriter, unchanged_cases); - } - - #[test] - fn test_inequalities_maybe_null() { - let guarantees = [ - // x ∈ ("abc", "def"]? (maybe null) - ( - col("x"), - NullableInterval::MaybeNull { - values: Interval::try_new( - ScalarValue::from("abc"), - ScalarValue::from("def"), - ) - .unwrap(), - }, - ), - ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - - // (original_expr, expected_simplification) - let simplified_cases = &[ - ( - Expr::BinaryExpr(BinaryExpr { - left: Box::new(col("x")), - op: Operator::IsDistinctFrom, - right: Box::new(lit("z")), - }), - true, - ), - ( - Expr::BinaryExpr(BinaryExpr { - left: Box::new(col("x")), - op: Operator::IsNotDistinctFrom, - right: Box::new(lit("z")), - }), - false, - ), - ]; - - validate_simplified_cases(&mut rewriter, simplified_cases); - - let unchanged_cases = &[ - col("x").lt(lit("z")), - col("x").lt_eq(lit("z")), - col("x").gt(lit("a")), - col("x").gt_eq(lit("a")), - col("x").eq(lit("abc")), - col("x").not_eq(lit("a")), - col("x").between(lit("a"), lit("z")), - col("x").not_between(lit("a"), lit("z")), - Expr::BinaryExpr(BinaryExpr { - left: Box::new(col("x")), - op: Operator::IsDistinctFrom, - right: Box::new(lit(ScalarValue::Null)), - }), - ]; - - validate_unchanged_cases(&mut rewriter, unchanged_cases); - } - - #[test] - fn test_column_single_value() { - let scalars = [ - ScalarValue::Null, - ScalarValue::Int32(Some(1)), - ScalarValue::Boolean(Some(true)), - ScalarValue::Boolean(None), - ScalarValue::from("abc"), - ScalarValue::LargeUtf8(Some("def".to_string())), - ScalarValue::Date32(Some(18628)), - ScalarValue::Date32(None), - ScalarValue::Decimal128(Some(1000), 19, 2), - ]; - - for scalar in scalars { - let guarantees = [(col("x"), NullableInterval::from(scalar.clone()))]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - - let output = col("x").rewrite(&mut rewriter).data().unwrap(); - assert_eq!(output, Expr::Literal(scalar.clone(), None)); - } - } - - #[test] - fn test_in_list() { - let guarantees = [ - // x ∈ [1, 10] (not null) - ( - col("x"), - NullableInterval::NotNull { - values: Interval::try_new( - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(10)), - ) - .unwrap(), - }, - ), - ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - - // These cases should be simplified so the list doesn't contain any - // values the guarantee says are outside the range. - // (column_name, starting_list, negated, expected_list) - let cases = &[ - // x IN (9, 11) => x IN (9) - ("x", vec![9, 11], false, vec![9]), - // x IN (10, 2) => x IN (10, 2) - ("x", vec![10, 2], false, vec![10, 2]), - // x NOT IN (9, 11) => x NOT IN (9) - ("x", vec![9, 11], true, vec![9]), - // x NOT IN (0, 22) => x NOT IN () - ("x", vec![0, 22], true, vec![]), - ]; - - for (column_name, starting_list, negated, expected_list) in cases { - let expr = col(*column_name).in_list( - starting_list - .iter() - .map(|v| lit(ScalarValue::Int32(Some(*v)))) - .collect(), - *negated, - ); - let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); - let expected_list = expected_list - .iter() - .map(|v| lit(ScalarValue::Int32(Some(*v)))) - .collect(); - assert_eq!( - output, - Expr::InList(InList { - expr: Box::new(col(*column_name)), - list: expected_list, - negated: *negated, - }) - ); - } - } -} diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 7ae38eec9a3ad..e238fca32689d 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -19,7 +19,6 @@ //! [`ExprSimplifier`] simplifies individual `Expr`s. pub mod expr_simplifier; -mod guarantees; mod inlist_simplifier; mod regex; pub mod simplify_exprs; @@ -35,4 +34,4 @@ pub use simplify_exprs::*; pub use simplify_predicates::simplify_predicates; // Export for test in datafusion/core/tests/optimizer_integration.rs -pub use guarantees::GuaranteeRewriter; +pub use datafusion_expr::expr_rewriter::GuaranteeRewriter;