From 44d1b48a3924480c8ce8af06f8ff66d7ef1fc26b Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sun, 3 Sep 2023 14:53:42 -0700 Subject: [PATCH 01/15] feat: add guarantees to simplifcation --- .../simplify_expressions/expr_simplifier.rs | 18 + .../src/simplify_expressions/guarantees.rs | 336 ++++++++++++++++++ .../optimizer/src/simplify_expressions/mod.rs | 1 + .../physical-expr/src/intervals/cp_solver.rs | 202 +++++++---- 4 files changed, 480 insertions(+), 77 deletions(-) create mode 100644 datafusion/optimizer/src/simplify_expressions/guarantees.rs diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 76073728b0181..fc39aafe53005 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -43,6 +43,8 @@ use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionP use crate::simplify_expressions::SimplifyInfo; +use crate::simplify_expressions::guarantees::{Guarantee, GuaranteeRewriter}; + /// This structure handles API for expression simplification pub struct ExprSimplifier { info: S, @@ -149,6 +151,22 @@ impl ExprSimplifier { expr.rewrite(&mut expr_rewrite) } + + /// Add guarantees + pub fn simplify_with_gurantee<'a>( + &self, + expr: Expr, + guarantees: impl IntoIterator, + ) -> Result { + // Do a simplification pass in case it reveals places where a guarantee + // could be applied. + let expr = self.simplify(expr)?; + let mut rewriter = GuaranteeRewriter::new(guarantees); + let expr = expr.rewrite(&mut rewriter)?; + // Simplify after guarantees are applied, since constant folding should + // now be able to fold more expressions. + self.simplify(expr) + } } #[allow(rustdoc::private_intra_doc_links)] diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs new file mode 100644 index 0000000000000..a3a28b83ca560 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -0,0 +1,336 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logic to inject guarantees with expressions. +//! +use datafusion_common::{tree_node::TreeNodeRewriter, Result, ScalarValue}; +use datafusion_expr::Expr; +use std::collections::HashMap; + +/// A bound on the value of an expression. +pub struct GuaranteeBound { + /// The value of the bound. + pub bound: ScalarValue, + /// If true, the bound is exclusive. If false, the bound is inclusive. + /// In terms of inequalities, this means the bound is `<` or `>` rather than + /// `<=` or `>=`. + pub open: bool, +} + +impl GuaranteeBound { + /// Create a new bound. + pub fn new(bound: ScalarValue, open: bool) -> Self { + Self { bound, open } + } +} + +impl Default for GuaranteeBound { + fn default() -> Self { + Self { + bound: ScalarValue::Null, + open: false, + } + } +} + +/// The null status of an expression. +/// +/// This might be populated by null count statistics, for example. A null count +/// of zero would mean `NeverNull`, while a null count equal to row count would +/// mean `AlwaysNull`. +pub enum NullStatus { + /// The expression is guaranteed to be non-null. + NeverNull, + /// The expression is guaranteed to be null. + AlwaysNull, + /// The expression isn't guaranteed to never be null or always be null. + MaybeNull, +} + +/// A set of constraints on the value of an expression. +/// +/// This is similar to [datafusion_physical_expr::intervals::Interval], except +/// that this is designed for working with logical expressions and also handles +/// nulls. +pub struct Guarantee { + /// The min values that the expression can take on. If `min.bound` is + pub min: GuaranteeBound, + /// The max values that the expression can take on. + pub max: GuaranteeBound, + /// Whether the expression is expected to be either always null or never null. + pub null_status: NullStatus, +} + +impl Guarantee { + /// Create a new guarantee. + pub fn new( + min: Option, + max: Option, + null_status: NullStatus, + ) -> Self { + Self { + min: min.unwrap_or_default(), + max: max.unwrap_or_default(), + null_status, + } + } +} + +impl From<&ScalarValue> for Guarantee { + fn from(value: &ScalarValue) -> Self { + Self { + min: GuaranteeBound { + bound: value.clone(), + open: false, + }, + max: GuaranteeBound { + bound: value.clone(), + open: false, + }, + null_status: match value { + ScalarValue::Null => NullStatus::AlwaysNull, + _ => NullStatus::NeverNull, + }, + } + } +} + +/// Rewrite expressions to incorporate guarantees. +/// +/// +pub(crate) struct GuaranteeRewriter<'a> { + guarantees: HashMap<&'a Expr, &'a Guarantee>, +} + +impl<'a> GuaranteeRewriter<'a> { + pub fn new(guarantees: impl IntoIterator) -> Self { + Self { + guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), + } + } +} + +impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { + type N = Expr; + + fn mutate(&mut self, expr: Expr) -> Result { + // IS NUll / NOT NUll + + // Inequality expressions + + // Columns (if bounds are equal and closed and column is not nullable) + + // In list + + Ok(expr) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use datafusion_common::tree_node::TreeNode; + use datafusion_expr::{col, lit}; + + #[test] + fn test_null_handling() { + // IsNull / IsNotNull can be rewritten to true / false + let guarantees = vec![ + (col("x"), Guarantee::new(None, None, NullStatus::AlwaysNull)), + (col("y"), Guarantee::new(None, None, NullStatus::NeverNull)), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + let cases = &[ + (col("x").is_null(), true), + (col("x").is_not_null(), false), + (col("y").is_null(), false), + (col("y").is_not_null(), true), + ]; + + for (expr, expected_value) in cases { + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!( + output, + Expr::Literal(ScalarValue::Boolean(Some(*expected_value))) + ); + } + } + + #[test] + fn test_inequalities() { + let guarantees = vec![ + // 1 < x <= 3 + ( + col("x"), + Guarantee::new( + Some(GuaranteeBound::new(ScalarValue::Int32(Some(1)), true)), + Some(GuaranteeBound::new(ScalarValue::Int32(Some(3)), false)), + NullStatus::NeverNull, + ), + ), + // 2021-01-01 <= y + ( + col("y"), + Guarantee::new( + Some(GuaranteeBound::new(ScalarValue::Date32(Some(18628)), false)), + None, + NullStatus::NeverNull, + ), + ), + // "abc" < z <= "def" + ( + col("z"), + Guarantee::new( + Some(GuaranteeBound::new( + ScalarValue::Utf8(Some("abc".to_string())), + true, + )), + Some(GuaranteeBound::new( + ScalarValue::Utf8(Some("def".to_string())), + false, + )), + NullStatus::MaybeNull, + ), + ), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // These cases should be simplified + let cases = &[ + (col("x").lt_eq(lit(1)), false), + (col("x").gt(lit(3)), false), + (col("y").gt_eq(lit(18628)), true), + (col("y").gt(lit(19000)), true), + (col("y").lt_eq(lit(17000)), false), + ]; + + for (expr, expected_value) in cases { + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!( + output, + Expr::Literal(ScalarValue::Boolean(Some(*expected_value))) + ); + } + + // These cases should be left as-is + let cases = &[ + col("x").gt(lit(2)), + col("x").lt_eq(lit(3)), + col("y").gt_eq(lit(17000)), + ]; + + for expr in cases { + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!(&output, expr); + } + } + + #[test] + fn test_column_single_value() { + let guarantees = vec![ + // x = 2 + (col("x"), Guarantee::from(&ScalarValue::Int32(Some(2)))), + // y is Null + (col("y"), Guarantee::from(&ScalarValue::Null)), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // These cases should be simplified + let cases = &[ + (col("x").lt_eq(lit(1)), false), + (col("x").gt(lit(3)), false), + (col("x").eq(lit(1)), false), + (col("x").eq(lit(2)), true), + (col("x").gt(lit(1)), true), + (col("x").lt_eq(lit(2)), true), + (col("x").is_not_null(), true), + (col("x").is_null(), false), + (col("y").is_null(), true), + (col("y").is_not_null(), false), + (col("y").lt_eq(lit(17000)), false), + ]; + + for (expr, expected_value) in cases { + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!( + output, + Expr::Literal(ScalarValue::Boolean(Some(*expected_value))) + ); + } + } + + #[test] + fn test_in_list() { + let guarantees = vec![ + // x = 2 + (col("x"), Guarantee::from(&ScalarValue::Int32(Some(2)))), + // 1 <= y < 10 + ( + col("y"), + Guarantee::new( + Some(GuaranteeBound::new(ScalarValue::Int32(Some(1)), false)), + Some(GuaranteeBound::new(ScalarValue::Int32(Some(10)), true)), + NullStatus::NeverNull, + ), + ), + // z is null + (col("z"), Guarantee::from(&ScalarValue::Null)), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // These cases should be simplified + let cases = &[ + // x IN () + (col("x").in_list(vec![], false), false), + // x IN (10, 11) + (col("x").in_list(vec![lit(10), lit(11)], false), false), + // x IN (10, 2) + (col("x").in_list(vec![lit(10), lit(2)], false), true), + // x NOT IN (10, 2) + (col("x").in_list(vec![lit(10), lit(2)], true), false), + // y IN (10, 11) + (col("y").in_list(vec![lit(10), lit(11)], false), false), + // y NOT IN (0, 22) + (col("y").in_list(vec![lit(0), lit(22)], true), true), + // z IN (10, 11) + (col("z").in_list(vec![lit(10), lit(11)], false), false), + ]; + + for (expr, expected_value) in cases { + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!( + output, + Expr::Literal(ScalarValue::Boolean(Some(*expected_value))) + ); + } + + // These cases should be left as-is + let cases = &[ + // y IN (10, 2) + col("y").in_list(vec![lit(10), lit(2)], false), + // y NOT IN (10, 2) + col("y").in_list(vec![lit(10), lit(2)], true), + ]; + + for expr in cases { + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!(&output, expr); + } + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index dfa0fe70433ba..b030793e67ce8 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -17,6 +17,7 @@ pub mod context; pub mod expr_simplifier; +pub mod guarantees; mod or_in_list_simplifier; mod regex; pub mod simplify_exprs; diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index edf1507c705a4..71691017186e4 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -16,6 +16,99 @@ // under the License. //! Constraint propagator/solver for custom PhysicalExpr graphs. +//! +//! Interval arithmetic provides a way to perform mathematical operations on +//! intervals, which represent a range of possible values rather than a single +//! point value. This allows for the propagation of ranges through mathematical +//! operations, and can be used to compute bounds for a complicated expression. +//! The key idea is that by breaking down a complicated expression into simpler +//! terms, and then combining the bounds for those simpler terms, one can +//! obtain bounds for the overall expression. +//! +//! For example, consider a mathematical expression such as x^2 + y = 4. Since +//! it would be a binary tree in [PhysicalExpr] notation, this type of an +//! hierarchical computation is well-suited for a graph based implementation. +//! In such an implementation, an equation system f(x) = 0 is represented by a +//! directed acyclic expression graph (DAEG). +//! +//! In order to use interval arithmetic to compute bounds for this expression, +//! one would first determine intervals that represent the possible values of x +//! and y. Let's say that the interval for x is [1, 2] and the interval for y +//! is [-3, 1]. In the chart below, you can see how the computation takes place. +//! +//! This way of using interval arithmetic to compute bounds for a complex +//! expression by combining the bounds for the constituent terms within the +//! original expression allows us to reason about the range of possible values +//! of the expression. This information later can be used in range pruning of +//! the provably unnecessary parts of `RecordBatch`es. +//! +//! References +//! 1 - Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval +//! Arithmetic Based Approach, Chapter 4. Stanford University, 2015. +//! 2 - Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966. +//! 3 - F. Messine, "Deterministic global optimization using interval constraint +//! propagation techniques," RAIRO-Operations Research, vol. 38, no. 04, +//! pp. 277{293, 2004. +//! +//! ``` text +//! Computing bounds for an expression using interval arithmetic. Constraint propagation through a top-down evaluation of the expression +//! graph using inverse semantics. +//! +//! [-2, 5] ∩ [4, 4] = [4, 4] [4, 4] +//! +-----+ +-----+ +-----+ +-----+ +//! +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ +//! | | | | | | | | | | | | | | | | +//! | +-----+ | | +-----+ | | +-----+ | | +-----+ | +//! | | | | | | | | +//! +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +//! | 2 | | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | [0, 1]* +//! |[.] | | | |[.] | | | |[.] | | | |[.] | | | +//! +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +//! | | | [-3, 1] | +//! | | | | +//! +---+ +---+ +---+ +---+ +//! | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [1, 2] +//! +---+ +---+ +---+ +---+ +//! +//! (a) Bottom-up evaluation: Step1 (b) Bottom up evaluation: Step2 (a) Top-down propagation: Step1 (b) Top-down propagation: Step2 +//! +//! [1 - 3, 4 + 1] = [-2, 5] [1 - 3, 4 + 1] = [-2, 5] +//! +-----+ +-----+ +-----+ +-----+ +//! +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ +//! | | | | | | | | | | | | | | | | +//! | +-----+ | | +-----+ | | +-----+ | | +-----+ | +//! | | | | | | | | +//! +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +//! | 2 |[1, 4] | y | | 2 |[1, 4] | y | | 2 |[3, 4]** | y | | 2 |[1, 4] | y | +//! |[.] | | | |[.] | | | |[.] | | | |[.] | | | +//! +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +//! | [-3, 1] | [-3, 1] | [0, 1] | [-3, 1] +//! | | | | +//! +---+ +---+ +---+ +---+ +//! | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [sqrt(3), 2]*** +//! +---+ +---+ +---+ +---+ +//! +//! (c) Bottom-up evaluation: Step3 (d) Bottom-up evaluation: Step4 (c) Top-down propagation: Step3 (d) Top-down propagation: Step4 +//! +//! * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1] +//! ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4] +//! *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2] +//! ``` +//! +//! # Examples +//! +//! ``` +//! # Expression: (x + 4) - y +//! +//! +//! # x: [0, 4), y: (1, 3] +//! +//! # Result: +//! ``` +//! +//! # Null handling +//! +//! use std::collections::HashSet; use std::fmt::{Display, Formatter}; @@ -39,83 +132,7 @@ use crate::PhysicalExpr; use super::IntervalBound; -// Interval arithmetic provides a way to perform mathematical operations on -// intervals, which represent a range of possible values rather than a single -// point value. This allows for the propagation of ranges through mathematical -// operations, and can be used to compute bounds for a complicated expression. -// The key idea is that by breaking down a complicated expression into simpler -// terms, and then combining the bounds for those simpler terms, one can -// obtain bounds for the overall expression. -// -// For example, consider a mathematical expression such as x^2 + y = 4. Since -// it would be a binary tree in [PhysicalExpr] notation, this type of an -// hierarchical computation is well-suited for a graph based implementation. -// In such an implementation, an equation system f(x) = 0 is represented by a -// directed acyclic expression graph (DAEG). -// -// In order to use interval arithmetic to compute bounds for this expression, -// one would first determine intervals that represent the possible values of x -// and y. Let's say that the interval for x is [1, 2] and the interval for y -// is [-3, 1]. In the chart below, you can see how the computation takes place. -// -// This way of using interval arithmetic to compute bounds for a complex -// expression by combining the bounds for the constituent terms within the -// original expression allows us to reason about the range of possible values -// of the expression. This information later can be used in range pruning of -// the provably unnecessary parts of `RecordBatch`es. -// -// References -// 1 - Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval -// Arithmetic Based Approach, Chapter 4. Stanford University, 2015. -// 2 - Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966. -// 3 - F. Messine, "Deterministic global optimization using interval constraint -// propagation techniques," RAIRO-Operations Research, vol. 38, no. 04, -// pp. 277{293, 2004. -// -// ``` text -// Computing bounds for an expression using interval arithmetic. Constraint propagation through a top-down evaluation of the expression -// graph using inverse semantics. -// -// [-2, 5] ∩ [4, 4] = [4, 4] [4, 4] -// +-----+ +-----+ +-----+ +-----+ -// +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ -// | | | | | | | | | | | | | | | | -// | +-----+ | | +-----+ | | +-----+ | | +-----+ | -// | | | | | | | | -// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -// | 2 | | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | [0, 1]* -// |[.] | | | |[.] | | | |[.] | | | |[.] | | | -// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -// | | | [-3, 1] | -// | | | | -// +---+ +---+ +---+ +---+ -// | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [1, 2] -// +---+ +---+ +---+ +---+ -// -// (a) Bottom-up evaluation: Step1 (b) Bottom up evaluation: Step2 (a) Top-down propagation: Step1 (b) Top-down propagation: Step2 -// -// [1 - 3, 4 + 1] = [-2, 5] [1 - 3, 4 + 1] = [-2, 5] -// +-----+ +-----+ +-----+ +-----+ -// +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ -// | | | | | | | | | | | | | | | | -// | +-----+ | | +-----+ | | +-----+ | | +-----+ | -// | | | | | | | | -// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -// | 2 |[1, 4] | y | | 2 |[1, 4] | y | | 2 |[3, 4]** | y | | 2 |[1, 4] | y | -// |[.] | | | |[.] | | | |[.] | | | |[.] | | | -// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -// | [-3, 1] | [-3, 1] | [0, 1] | [-3, 1] -// | | | | -// +---+ +---+ +---+ +---+ -// | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [sqrt(3), 2]*** -// +---+ +---+ +---+ +---+ -// -// (c) Bottom-up evaluation: Step3 (d) Bottom-up evaluation: Step4 (c) Top-down propagation: Step3 (d) Top-down propagation: Step4 -// -// * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1] -// ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4] -// *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2] -// ``` + /// This object implements a directed acyclic expression graph (DAEG) that /// is used to compute ranges for expressions through interval arithmetic. @@ -561,6 +578,7 @@ pub fn check_support(expr: &Arc) -> bool { #[cfg(test)] mod tests { use super::*; + use datafusion_expr::expr; use itertools::Itertools; use crate::expressions::{BinaryExpr, Column}; @@ -1123,6 +1141,36 @@ mod tests { let final_node_count = graph.node_count(); // Assert that the final node count is equal the previous node count (i.e., no node was pruned). assert_eq!(prev_node_count, final_node_count); + Ok(()) + } + + #[test] + fn my_test() -> Result<()> { + let expr_x = Arc::new(Column::new("x", 0)); + let expr_y = Arc::new(Column::new("y", 1)); + let expression = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + expr_x.clone(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(4)))), + )), + Operator::Minus, + expr_y.clone(), + )); + + let mut graph = ExprIntervalGraph::try_new(expression.clone()).unwrap(); + // Pass in expr_x and expr_y so we can input their intervals + let mapping = graph.gather_node_indices(&[expr_x, expr_y]); + + let interval_x = Interval::make(Some(0), Some(4), (false, true)); + let interval_y = Interval::make(Some(1), Some(3), (true, false)); + graph.assign_intervals(&[ + (mapping[0].1, interval_x.clone()), + (mapping[1].1, interval_y.clone()), + ]); + + + Ok(()) } } From 4c1c3a96e378779bcd062bd1422ec98f046e64ad Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sun, 3 Sep 2023 21:50:50 -0700 Subject: [PATCH 02/15] null and comparison support --- .../src/simplify_expressions/guarantees.rs | 218 ++++++++++++++++-- 1 file changed, 203 insertions(+), 15 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index a3a28b83ca560..3bca3345ae4c3 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -18,10 +18,11 @@ //! Logic to inject guarantees with expressions. //! use datafusion_common::{tree_node::TreeNodeRewriter, Result, ScalarValue}; -use datafusion_expr::Expr; +use datafusion_expr::{lit, Between, BinaryExpr, Expr, Operator}; use std::collections::HashMap; /// A bound on the value of an expression. +#[derive(Debug, Clone, PartialEq)] pub struct GuaranteeBound { /// The value of the bound. pub bound: ScalarValue, @@ -52,6 +53,7 @@ impl Default for GuaranteeBound { /// This might be populated by null count statistics, for example. A null count /// of zero would mean `NeverNull`, while a null count equal to row count would /// mean `AlwaysNull`. +#[derive(Debug, Clone, PartialEq)] pub enum NullStatus { /// The expression is guaranteed to be non-null. NeverNull, @@ -66,6 +68,7 @@ pub enum NullStatus { /// This is similar to [datafusion_physical_expr::intervals::Interval], except /// that this is designed for working with logical expressions and also handles /// nulls. +#[derive(Debug, Clone, PartialEq)] pub struct Guarantee { /// The min values that the expression can take on. If `min.bound` is pub min: GuaranteeBound, @@ -88,6 +91,23 @@ impl Guarantee { null_status, } } + + /// Whether values are guaranteed to be greater than the given value. + fn greater_than(&self, value: &ScalarValue) -> bool { + self.min.bound > *value || (self.min.bound == *value && self.min.open) + } + + fn greater_than_or_eq(&self, value: &ScalarValue) -> bool { + self.min.bound >= *value + } + + fn less_than(&self, value: &ScalarValue) -> bool { + self.max.bound < *value || (self.max.bound == *value && self.max.open) + } + + fn less_than_or_eq(&self, value: &ScalarValue) -> bool { + self.max.bound <= *value + } } impl From<&ScalarValue> for Guarantee { @@ -128,15 +148,180 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { type N = Expr; fn mutate(&mut self, expr: Expr) -> Result { - // IS NUll / NOT NUll - - // Inequality expressions - - // Columns (if bounds are equal and closed and column is not nullable) - - // In list - - Ok(expr) + match &expr { + // IS NUll / NOT NUll + Expr::IsNull(inner) => { + if let Some(guarantee) = self.guarantees.get(inner.as_ref()) { + match guarantee.null_status { + NullStatus::AlwaysNull => Ok(lit(true)), + NullStatus::NeverNull => Ok(lit(false)), + NullStatus::MaybeNull => Ok(expr), + } + } else { + Ok(expr) + } + } + Expr::IsNotNull(inner) => { + if let Some(guarantee) = self.guarantees.get(inner.as_ref()) { + match guarantee.null_status { + NullStatus::AlwaysNull => Ok(lit(false)), + NullStatus::NeverNull => Ok(lit(true)), + NullStatus::MaybeNull => Ok(expr), + } + } else { + Ok(expr) + } + } + // Inequality expressions + Expr::Between(Between { + expr: inner, + negated, + low, + high, + }) => { + if let Some(guarantee) = self.guarantees.get(inner.as_ref()) { + match (low.as_ref(), high.as_ref()) { + (Expr::Literal(low), Expr::Literal(high)) => { + if guarantee.greater_than_or_eq(low) + && guarantee.less_than_or_eq(high) + { + // All values are between the bounds + Ok(lit(!negated)) + } else if guarantee.greater_than(high) + || guarantee.less_than(low) + { + // All values are outside the bounds + Ok(lit(*negated)) + } else { + Ok(expr) + } + } + (Expr::Literal(low), _) + if !guarantee.less_than(low) && !negated => + { + // All values are below the lower bound + Ok(lit(false)) + } + (_, Expr::Literal(high)) + if !guarantee.greater_than(high) && !negated => + { + // All values are above the upper bound + Ok(lit(false)) + } + _ => Ok(expr), + } + } else { + Ok(expr) + } + } + + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + // Check if this is a comparison + match op { + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq => {} + _ => return Ok(expr), + }; + + // Check if this is a comparison between a column and literal + let (col, op, value) = match (left.as_ref(), right.as_ref()) { + (Expr::Column(_), Expr::Literal(value)) => (left, *op, value), + (Expr::Literal(value), Expr::Column(_)) => { + (right, op.swap().unwrap(), value) + } + _ => return Ok(expr), + }; + + if let Some(guarantee) = self.guarantees.get(col.as_ref()) { + match op { + Operator::Eq => { + if guarantee.greater_than(value) || guarantee.less_than(value) + { + // All values are outside the bounds + Ok(lit(false)) + } else if guarantee.greater_than_or_eq(value) + && guarantee.less_than_or_eq(value) + { + // All values are equal to the bound + Ok(lit(true)) + } else { + Ok(expr) + } + } + Operator::NotEq => { + if guarantee.greater_than(value) || guarantee.less_than(value) + { + // All values are outside the bounds + Ok(lit(true)) + } else if guarantee.greater_than_or_eq(value) + && guarantee.less_than_or_eq(value) + { + // All values are equal to the bound + Ok(lit(false)) + } else { + Ok(expr) + } + } + Operator::Gt => { + if guarantee.less_than_or_eq(value) { + // All values are less than or equal to the bound + Ok(lit(false)) + } else if guarantee.greater_than(value) { + // All values are greater than the bound + Ok(lit(true)) + } else { + Ok(expr) + } + } + Operator::GtEq => { + if guarantee.less_than(value) { + // All values are less than the bound + Ok(lit(false)) + } else if guarantee.greater_than_or_eq(value) { + // All values are greater than or equal to the bound + Ok(lit(true)) + } else { + Ok(expr) + } + } + Operator::Lt => { + if guarantee.greater_than_or_eq(value) { + // All values are greater than or equal to the bound + Ok(lit(false)) + } else if guarantee.less_than(value) { + // All values are less than the bound + Ok(lit(true)) + } else { + Ok(expr) + } + } + Operator::LtEq => { + if guarantee.greater_than(value) { + // All values are greater than the bound + Ok(lit(false)) + } else if guarantee.less_than_or_eq(value) { + // All values are less than or equal to the bound + Ok(lit(true)) + } else { + Ok(expr) + } + } + _ => Ok(expr), + } + } else { + Ok(expr) + } + } + + // Columns (if bounds are equal and closed and column is not nullable) + + // In list + _ => Ok(expr), + } } } @@ -214,10 +399,11 @@ mod tests { // These cases should be simplified let cases = &[ (col("x").lt_eq(lit(1)), false), + (col("x").lt_eq(lit(3)), true), (col("x").gt(lit(3)), false), - (col("y").gt_eq(lit(18628)), true), - (col("y").gt(lit(19000)), true), - (col("y").lt_eq(lit(17000)), false), + (col("y").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true), + (col("y").gt_eq(lit(ScalarValue::Date32(Some(17000)))), true), + (col("y").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false), ]; for (expr, expected_value) in cases { @@ -231,8 +417,10 @@ mod tests { // These cases should be left as-is let cases = &[ col("x").gt(lit(2)), - col("x").lt_eq(lit(3)), - col("y").gt_eq(lit(17000)), + col("x").lt_eq(lit(2)), + col("x").between(lit(2), lit(5)), + col("x").not_between(lit(3), lit(10)), + col("y").gt(lit(ScalarValue::Date32(Some(19000)))), ]; for expr in cases { From 2134f2f18fbf0937ba8d0cc5abd8d0919d9b4d3b Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 4 Sep 2023 10:22:06 -0700 Subject: [PATCH 03/15] add support for literal expressions --- .../src/simplify_expressions/guarantees.rs | 93 ++++++++++--------- 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 3bca3345ae4c3..4e142ef2803ab 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -121,9 +121,10 @@ impl From<&ScalarValue> for Guarantee { bound: value.clone(), open: false, }, - null_status: match value { - ScalarValue::Null => NullStatus::AlwaysNull, - _ => NullStatus::NeverNull, + null_status: if value.is_null() { + NullStatus::AlwaysNull + } else { + NullStatus::NeverNull }, } } @@ -318,6 +319,25 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } // Columns (if bounds are equal and closed and column is not nullable) + Expr::Column(_) => { + if let Some(guarantee) = self.guarantees.get(&expr) { + if guarantee.min == guarantee.max + // Case where column has a single valid value + && ((!guarantee.min.open + && !guarantee.min.bound.is_null() + && guarantee.null_status == NullStatus::NeverNull) + // Case where column is always null + || (guarantee.min.bound.is_null() + && guarantee.null_status == NullStatus::AlwaysNull)) + { + Ok(lit(guarantee.min.bound.clone())) + } else { + Ok(expr) + } + } else { + Ok(expr) + } + } // In list _ => Ok(expr), @@ -336,25 +356,21 @@ mod tests { fn test_null_handling() { // IsNull / IsNotNull can be rewritten to true / false let guarantees = vec![ - (col("x"), Guarantee::new(None, None, NullStatus::AlwaysNull)), - (col("y"), Guarantee::new(None, None, NullStatus::NeverNull)), + // Note: AlwaysNull case handled by test_column_single_value test, + // since it's a special case of a column with a single value. + (col("x"), Guarantee::new(None, None, NullStatus::NeverNull)), ]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - let cases = &[ - (col("x").is_null(), true), - (col("x").is_not_null(), false), - (col("y").is_null(), false), - (col("y").is_not_null(), true), - ]; + // x IS NULL => guaranteed false + let expr = col("x").is_null(); + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!(output, lit(false)); - for (expr, expected_value) in cases { - let output = expr.clone().rewrite(&mut rewriter).unwrap(); - assert_eq!( - output, - Expr::Literal(ScalarValue::Boolean(Some(*expected_value))) - ); - } + // x IS NOT NULL => guaranteed true + let expr = col("x").is_not_null(); + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!(output, lit(true)); } #[test] @@ -431,35 +447,24 @@ mod tests { #[test] fn test_column_single_value() { - let guarantees = vec![ - // x = 2 - (col("x"), Guarantee::from(&ScalarValue::Int32(Some(2)))), - // y is Null - (col("y"), Guarantee::from(&ScalarValue::Null)), + let scalars = [ + ScalarValue::Null, + ScalarValue::Int32(Some(1)), + ScalarValue::Boolean(Some(true)), + ScalarValue::Boolean(None), + ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::LargeUtf8(Some("def".to_string())), + ScalarValue::Date32(Some(18628)), + ScalarValue::Date32(None), + ScalarValue::Decimal128(Some(1000), 19, 2), ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - // These cases should be simplified - let cases = &[ - (col("x").lt_eq(lit(1)), false), - (col("x").gt(lit(3)), false), - (col("x").eq(lit(1)), false), - (col("x").eq(lit(2)), true), - (col("x").gt(lit(1)), true), - (col("x").lt_eq(lit(2)), true), - (col("x").is_not_null(), true), - (col("x").is_null(), false), - (col("y").is_null(), true), - (col("y").is_not_null(), false), - (col("y").lt_eq(lit(17000)), false), - ]; + for scalar in &scalars { + let guarantees = vec![(col("x"), Guarantee::from(scalar))]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - for (expr, expected_value) in cases { - let output = expr.clone().rewrite(&mut rewriter).unwrap(); - assert_eq!( - output, - Expr::Literal(ScalarValue::Boolean(Some(*expected_value))) - ); + let output = col("x").rewrite(&mut rewriter).unwrap(); + assert_eq!(output, Expr::Literal(scalar.clone())); } } From caa738f591470295bd6a4af026b1ba9d292f86bb Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 4 Sep 2023 10:48:11 -0700 Subject: [PATCH 04/15] implement inlist guarantee use --- .../src/simplify_expressions/guarantees.rs | 104 +++++++++++------- 1 file changed, 66 insertions(+), 38 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 4e142ef2803ab..0772eaab50f3e 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -18,7 +18,7 @@ //! Logic to inject guarantees with expressions. //! use datafusion_common::{tree_node::TreeNodeRewriter, Result, ScalarValue}; -use datafusion_expr::{lit, Between, BinaryExpr, Expr, Operator}; +use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr, Operator}; use std::collections::HashMap; /// A bound on the value of an expression. @@ -108,6 +108,11 @@ impl Guarantee { fn less_than_or_eq(&self, value: &ScalarValue) -> bool { self.max.bound <= *value } + + /// Whether the guarantee could contain the given value. + fn contains(&self, value: &ScalarValue) -> bool { + !self.less_than(value) && !self.greater_than(value) + } } impl From<&ScalarValue> for Guarantee { @@ -237,6 +242,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { _ => return Ok(expr), }; + // TODO: can this be simplified? if let Some(guarantee) = self.guarantees.get(col.as_ref()) { match op { Operator::Eq => { @@ -339,7 +345,35 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } } - // In list + Expr::InList(InList { + expr: inner, + list, + negated, + }) => { + if let Some(guarantee) = self.guarantees.get(inner.as_ref()) { + // Can remove items from the list that don't match the guarantee + let new_list: Vec = list + .iter() + .filter(|item| { + if let Expr::Literal(item) = item { + guarantee.contains(item) + } else { + true + } + }) + .cloned() + .collect(); + + Ok(Expr::InList(InList { + expr: inner.clone(), + list: new_list, + negated: *negated, + })) + } else { + Ok(expr) + } + } + _ => Ok(expr), } } @@ -471,59 +505,53 @@ mod tests { #[test] fn test_in_list() { let guarantees = vec![ - // x = 2 - (col("x"), Guarantee::from(&ScalarValue::Int32(Some(2)))), - // 1 <= y < 10 + // 1 <= x < 10 ( - col("y"), + col("x"), Guarantee::new( Some(GuaranteeBound::new(ScalarValue::Int32(Some(1)), false)), Some(GuaranteeBound::new(ScalarValue::Int32(Some(10)), true)), NullStatus::NeverNull, ), ), - // z is null - (col("z"), Guarantee::from(&ScalarValue::Null)), ]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - // These cases should be simplified + // These cases should be simplified so the list doesn't contain any + // values the guarantee says are outside the range. + // (column_name, starting_list, negated, expected_list) let cases = &[ - // x IN () - (col("x").in_list(vec![], false), false), - // x IN (10, 11) - (col("x").in_list(vec![lit(10), lit(11)], false), false), - // x IN (10, 2) - (col("x").in_list(vec![lit(10), lit(2)], false), true), - // x NOT IN (10, 2) - (col("x").in_list(vec![lit(10), lit(2)], true), false), - // y IN (10, 11) - (col("y").in_list(vec![lit(10), lit(11)], false), false), - // y NOT IN (0, 22) - (col("y").in_list(vec![lit(0), lit(22)], true), true), - // z IN (10, 11) - (col("z").in_list(vec![lit(10), lit(11)], false), false), + // x IN (9, 11) => x IN (9) + ("x", vec![9, 11], false, vec![9]), + // x IN (10, 2) => x IN (2) + ("x", vec![10, 2], false, vec![2]), + // x NOT IN (9, 11) => x NOT IN (9) + ("x", vec![9, 11], true, vec![9]), + // x NOT IN (0, 22) => x NOT IN () + ("x", vec![0, 22], true, vec![]), ]; - for (expr, expected_value) in cases { + for (column_name, starting_list, negated, expected_list) in cases { + let expr = col(*column_name).in_list( + starting_list + .iter() + .map(|v| lit(ScalarValue::Int32(Some(*v)))) + .collect(), + *negated, + ); let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let expected_list = expected_list + .iter() + .map(|v| lit(ScalarValue::Int32(Some(*v)))) + .collect(); assert_eq!( output, - Expr::Literal(ScalarValue::Boolean(Some(*expected_value))) + Expr::InList(InList { + expr: Box::new(col(*column_name)), + list: expected_list, + negated: *negated, + }) ); } - - // These cases should be left as-is - let cases = &[ - // y IN (10, 2) - col("y").in_list(vec![lit(10), lit(2)], false), - // y NOT IN (10, 2) - col("y").in_list(vec![lit(10), lit(2)], true), - ]; - - for expr in cases { - let output = expr.clone().rewrite(&mut rewriter).unwrap(); - assert_eq!(&output, expr); - } } } From ff7ed70260d9aa5512ab1bc494bd885497d39479 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 4 Sep 2023 11:34:01 -0700 Subject: [PATCH 05/15] test the outer function --- .../simplify_expressions/expr_simplifier.rs | 66 ++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index fc39aafe53005..78a8d70635b4d 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -153,7 +153,7 @@ impl ExprSimplifier { } /// Add guarantees - pub fn simplify_with_gurantee<'a>( + pub fn simplify_with_guarantees<'a>( &self, expr: Expr, guarantees: impl IntoIterator, @@ -1215,6 +1215,7 @@ mod tests { }; use crate::simplify_expressions::{ + guarantees::{GuaranteeBound, NullStatus}, utils::for_test::{cast_to_int64_expr, now_expr, to_timestamp_expr}, SimplifyContext, }; @@ -2693,6 +2694,17 @@ mod tests { try_simplify(expr).unwrap() } + fn simplify_with_guarantee(expr: Expr, guarantees: &[(Expr, Guarantee)]) -> Expr { + let schema = expr_test_schema(); + let execution_props = ExecutionProps::new(); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&execution_props).with_schema(schema), + ); + simplifier + .simplify_with_guarantees(expr, guarantees) + .unwrap() + } + fn expr_test_schema() -> DFSchemaRef { Arc::new( DFSchema::new_with_metadata( @@ -3156,4 +3168,56 @@ mod tests { let expr = not_ilike(null, "%"); assert_eq!(simplify(expr), lit_bool_null()); } + + #[test] + fn test_simplify_with_guarantee() { + // (x >= 3) AND (y + 2 < 10 OR (z NOT IN ("a", "b"))) + let expr_x = col("c3").gt(lit(3_i64)); + let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32)); + let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true); + let expr = expr_x.clone().and(expr_y.or(expr_z)); + + // All guaranteed null + let guarantees = vec![ + (col("c3"), Guarantee::from(&ScalarValue::Int64(None))), + (col("c4"), Guarantee::from(&ScalarValue::UInt32(None))), + (col("c1"), Guarantee::from(&ScalarValue::Utf8(None))), + ]; + + let output = simplify_with_guarantee(expr.clone(), &guarantees); + assert_eq!(output, lit_bool_null()); + + // All guaranteed false + let guarantees = vec![ + ( + col("c3"), + Guarantee::new( + Some(GuaranteeBound::new(ScalarValue::Int64(Some(0)), false)), + Some(GuaranteeBound::new(ScalarValue::Int64(Some(2)), false)), + NullStatus::NeverNull, + ), + ), + (col("c4"), Guarantee::from(&ScalarValue::UInt32(Some(9)))), + ( + col("c1"), + Guarantee::from(&ScalarValue::Utf8(Some("a".to_string()))), + ), + ]; + let output = simplify_with_guarantee(expr.clone(), &guarantees); + assert_eq!(output, lit(false)); + + // Sufficient true guarantees + let guarantees = vec![ + (col("c3"), Guarantee::from(&ScalarValue::Int64(Some(9)))), + (col("c4"), Guarantee::from(&ScalarValue::UInt32(Some(3)))), + ]; + let output = simplify_with_guarantee(expr.clone(), &guarantees); + assert_eq!(output, lit(true)); + + // Only partially simplify + let guarantees = + vec![(col("c4"), Guarantee::from(&ScalarValue::UInt32(Some(3))))]; + let output = simplify_with_guarantee(expr.clone(), &guarantees); + assert_eq!(&output, &expr_x); + } } From a6b57e38eb00da2a6c5396dca0b5f1772578ac78 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 4 Sep 2023 12:19:26 -0700 Subject: [PATCH 06/15] docs --- .../simplify_expressions/expr_simplifier.rs | 53 ++++- .../src/simplify_expressions/guarantees.rs | 37 +++- .../physical-expr/src/intervals/cp_solver.rs | 202 +++++++----------- 3 files changed, 160 insertions(+), 132 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 78a8d70635b4d..7f1c692238056 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -152,7 +152,58 @@ impl ExprSimplifier { expr.rewrite(&mut expr_rewrite) } - /// Add guarantees + /// Input guarantees and simplify the expression. + /// + /// The guarantees can simplify expressions. For example, if a column is + /// guaranteed to always be a certain value, it's references in the expression + /// can be replaced with that literal. + /// + /// ```rust + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use datafusion_expr::{col, lit, Expr}; + /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; + /// use datafusion_physical_expr::execution_props::ExecutionProps; + /// use datafusion_optimizer::simplify_expressions::{ + /// ExprSimplifier, SimplifyContext, + /// guarantees::{Guarantee, GuaranteeBound, NullStatus}}; + /// + /// let schema = Schema::new(vec![ + /// Field::new("x", DataType::Int64, false), + /// Field::new("y", DataType::UInt32, false), + /// Field::new("z", DataType::Int64, false), + /// ]) + /// .to_dfschema_ref().unwrap(); + /// + /// // Create the simplifier + /// let props = ExecutionProps::new(); + /// let context = SimplifyContext::new(&props) + /// .with_schema(schema); + /// let simplifier = ExprSimplifier::new(context); + /// + /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5) + /// let expr_x = col("x").gt_eq(lit(3_i64)); + /// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32)); + /// let expr_z = col("z").gt(lit(5_i64)); + /// let expr = expr_x.and(expr_y).and(expr_z.clone()); + /// + /// let guarantees = vec![ + /// // x is guaranteed to be between 3 and 5 + /// ( + /// col("x"), + /// Guarantee::new( + /// Some(GuaranteeBound::new(ScalarValue::Int64(Some(3)), false)), + /// Some(GuaranteeBound::new(ScalarValue::Int64(Some(5)), false)), + /// NullStatus::NeverNull, + /// ) + /// ), + /// // y is guaranteed to be 3 + /// (col("y"), Guarantee::from(&ScalarValue::UInt32(Some(3)))), + /// ]; + /// let output = simplifier.simplify_with_guarantees(expr, &guarantees).unwrap(); + /// // Expression becomes: true AND true AND (z > 5), which simplifies to + /// // z > 5. + /// assert_eq!(output, expr_z); + /// ``` pub fn simplify_with_guarantees<'a>( &self, expr: Expr, diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 0772eaab50f3e..9b62acd44d103 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -15,8 +15,27 @@ // specific language governing permissions and limitations // under the License. -//! Logic to inject guarantees with expressions. +//! Guarantees which can be used with [ExprSimplifier::simplify_with_guarantees()][crate::simplify_expressions::expr_simplifier::ExprSimplifier::simplify_with_guarantees]. //! +//! Guarantees can represent single values or possible ranges of values. +//! +//! ``` +//! use datafusion_common::scalar::ScalarValue; +//! use datafusion_optimizer::simplify_expressions::guarantees::{ +//! Guarantee, GuaranteeBound, NullStatus}; +//! +//! // Guarantee that value is always 1_i32 +//! Guarantee::from(&ScalarValue::Int32(Some(1))); +//! // Guarantee that value is always NULL +//! Guarantee::from(&ScalarValue::Null); +//! // Guarantee that value is always between 1_i32 and 10_i32 (inclusive) +//! // and never null. +//! Guarantee::new( +//! Some(GuaranteeBound::new(ScalarValue::Int32(Some(1)), false)), +//! Some(GuaranteeBound::new(ScalarValue::Int32(Some(10)), false)), +//! NullStatus::NeverNull, +//! ); +//! ``` use datafusion_common::{tree_node::TreeNodeRewriter, Result, ScalarValue}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr, Operator}; use std::collections::HashMap; @@ -24,7 +43,7 @@ use std::collections::HashMap; /// A bound on the value of an expression. #[derive(Debug, Clone, PartialEq)] pub struct GuaranteeBound { - /// The value of the bound. + /// The value of the bound. If the bound is null, then there is no bound. pub bound: ScalarValue, /// If true, the bound is exclusive. If false, the bound is inclusive. /// In terms of inequalities, this means the bound is `<` or `>` rather than @@ -40,6 +59,7 @@ impl GuaranteeBound { } impl Default for GuaranteeBound { + /// Default value is a closed bound at null. fn default() -> Self { Self { bound: ScalarValue::Null, @@ -70,9 +90,11 @@ pub enum NullStatus { /// nulls. #[derive(Debug, Clone, PartialEq)] pub struct Guarantee { - /// The min values that the expression can take on. If `min.bound` is + /// The min values that the expression can take on. If the min is null, then + /// there is no known min. pub min: GuaranteeBound, - /// The max values that the expression can take on. + /// The max values that the expression can take on. If the max is null, + /// then there is no known max. pub max: GuaranteeBound, /// Whether the expression is expected to be either always null or never null. pub null_status: NullStatus, @@ -97,14 +119,19 @@ impl Guarantee { self.min.bound > *value || (self.min.bound == *value && self.min.open) } + /// Whether values are guaranteed to be greater than or equal to the given + /// value. fn greater_than_or_eq(&self, value: &ScalarValue) -> bool { self.min.bound >= *value } + /// Whether values are guaranteed to be less than the given value. fn less_than(&self, value: &ScalarValue) -> bool { self.max.bound < *value || (self.max.bound == *value && self.max.open) } + /// Whether values are guaranteed to be less than or equal to the given + /// value. fn less_than_or_eq(&self, value: &ScalarValue) -> bool { self.max.bound <= *value } @@ -136,8 +163,6 @@ impl From<&ScalarValue> for Guarantee { } /// Rewrite expressions to incorporate guarantees. -/// -/// pub(crate) struct GuaranteeRewriter<'a> { guarantees: HashMap<&'a Expr, &'a Guarantee>, } diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 71691017186e4..edf1507c705a4 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -16,99 +16,6 @@ // under the License. //! Constraint propagator/solver for custom PhysicalExpr graphs. -//! -//! Interval arithmetic provides a way to perform mathematical operations on -//! intervals, which represent a range of possible values rather than a single -//! point value. This allows for the propagation of ranges through mathematical -//! operations, and can be used to compute bounds for a complicated expression. -//! The key idea is that by breaking down a complicated expression into simpler -//! terms, and then combining the bounds for those simpler terms, one can -//! obtain bounds for the overall expression. -//! -//! For example, consider a mathematical expression such as x^2 + y = 4. Since -//! it would be a binary tree in [PhysicalExpr] notation, this type of an -//! hierarchical computation is well-suited for a graph based implementation. -//! In such an implementation, an equation system f(x) = 0 is represented by a -//! directed acyclic expression graph (DAEG). -//! -//! In order to use interval arithmetic to compute bounds for this expression, -//! one would first determine intervals that represent the possible values of x -//! and y. Let's say that the interval for x is [1, 2] and the interval for y -//! is [-3, 1]. In the chart below, you can see how the computation takes place. -//! -//! This way of using interval arithmetic to compute bounds for a complex -//! expression by combining the bounds for the constituent terms within the -//! original expression allows us to reason about the range of possible values -//! of the expression. This information later can be used in range pruning of -//! the provably unnecessary parts of `RecordBatch`es. -//! -//! References -//! 1 - Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval -//! Arithmetic Based Approach, Chapter 4. Stanford University, 2015. -//! 2 - Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966. -//! 3 - F. Messine, "Deterministic global optimization using interval constraint -//! propagation techniques," RAIRO-Operations Research, vol. 38, no. 04, -//! pp. 277{293, 2004. -//! -//! ``` text -//! Computing bounds for an expression using interval arithmetic. Constraint propagation through a top-down evaluation of the expression -//! graph using inverse semantics. -//! -//! [-2, 5] ∩ [4, 4] = [4, 4] [4, 4] -//! +-----+ +-----+ +-----+ +-----+ -//! +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ -//! | | | | | | | | | | | | | | | | -//! | +-----+ | | +-----+ | | +-----+ | | +-----+ | -//! | | | | | | | | -//! +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -//! | 2 | | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | [0, 1]* -//! |[.] | | | |[.] | | | |[.] | | | |[.] | | | -//! +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -//! | | | [-3, 1] | -//! | | | | -//! +---+ +---+ +---+ +---+ -//! | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [1, 2] -//! +---+ +---+ +---+ +---+ -//! -//! (a) Bottom-up evaluation: Step1 (b) Bottom up evaluation: Step2 (a) Top-down propagation: Step1 (b) Top-down propagation: Step2 -//! -//! [1 - 3, 4 + 1] = [-2, 5] [1 - 3, 4 + 1] = [-2, 5] -//! +-----+ +-----+ +-----+ +-----+ -//! +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ -//! | | | | | | | | | | | | | | | | -//! | +-----+ | | +-----+ | | +-----+ | | +-----+ | -//! | | | | | | | | -//! +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -//! | 2 |[1, 4] | y | | 2 |[1, 4] | y | | 2 |[3, 4]** | y | | 2 |[1, 4] | y | -//! |[.] | | | |[.] | | | |[.] | | | |[.] | | | -//! +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ -//! | [-3, 1] | [-3, 1] | [0, 1] | [-3, 1] -//! | | | | -//! +---+ +---+ +---+ +---+ -//! | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [sqrt(3), 2]*** -//! +---+ +---+ +---+ +---+ -//! -//! (c) Bottom-up evaluation: Step3 (d) Bottom-up evaluation: Step4 (c) Top-down propagation: Step3 (d) Top-down propagation: Step4 -//! -//! * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1] -//! ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4] -//! *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2] -//! ``` -//! -//! # Examples -//! -//! ``` -//! # Expression: (x + 4) - y -//! -//! -//! # x: [0, 4), y: (1, 3] -//! -//! # Result: -//! ``` -//! -//! # Null handling -//! -//! use std::collections::HashSet; use std::fmt::{Display, Formatter}; @@ -132,7 +39,83 @@ use crate::PhysicalExpr; use super::IntervalBound; - +// Interval arithmetic provides a way to perform mathematical operations on +// intervals, which represent a range of possible values rather than a single +// point value. This allows for the propagation of ranges through mathematical +// operations, and can be used to compute bounds for a complicated expression. +// The key idea is that by breaking down a complicated expression into simpler +// terms, and then combining the bounds for those simpler terms, one can +// obtain bounds for the overall expression. +// +// For example, consider a mathematical expression such as x^2 + y = 4. Since +// it would be a binary tree in [PhysicalExpr] notation, this type of an +// hierarchical computation is well-suited for a graph based implementation. +// In such an implementation, an equation system f(x) = 0 is represented by a +// directed acyclic expression graph (DAEG). +// +// In order to use interval arithmetic to compute bounds for this expression, +// one would first determine intervals that represent the possible values of x +// and y. Let's say that the interval for x is [1, 2] and the interval for y +// is [-3, 1]. In the chart below, you can see how the computation takes place. +// +// This way of using interval arithmetic to compute bounds for a complex +// expression by combining the bounds for the constituent terms within the +// original expression allows us to reason about the range of possible values +// of the expression. This information later can be used in range pruning of +// the provably unnecessary parts of `RecordBatch`es. +// +// References +// 1 - Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval +// Arithmetic Based Approach, Chapter 4. Stanford University, 2015. +// 2 - Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966. +// 3 - F. Messine, "Deterministic global optimization using interval constraint +// propagation techniques," RAIRO-Operations Research, vol. 38, no. 04, +// pp. 277{293, 2004. +// +// ``` text +// Computing bounds for an expression using interval arithmetic. Constraint propagation through a top-down evaluation of the expression +// graph using inverse semantics. +// +// [-2, 5] ∩ [4, 4] = [4, 4] [4, 4] +// +-----+ +-----+ +-----+ +-----+ +// +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ +// | | | | | | | | | | | | | | | | +// | +-----+ | | +-----+ | | +-----+ | | +-----+ | +// | | | | | | | | +// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +// | 2 | | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | [0, 1]* +// |[.] | | | |[.] | | | |[.] | | | |[.] | | | +// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +// | | | [-3, 1] | +// | | | | +// +---+ +---+ +---+ +---+ +// | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [1, 2] +// +---+ +---+ +---+ +---+ +// +// (a) Bottom-up evaluation: Step1 (b) Bottom up evaluation: Step2 (a) Top-down propagation: Step1 (b) Top-down propagation: Step2 +// +// [1 - 3, 4 + 1] = [-2, 5] [1 - 3, 4 + 1] = [-2, 5] +// +-----+ +-----+ +-----+ +-----+ +// +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ +// | | | | | | | | | | | | | | | | +// | +-----+ | | +-----+ | | +-----+ | | +-----+ | +// | | | | | | | | +// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +// | 2 |[1, 4] | y | | 2 |[1, 4] | y | | 2 |[3, 4]** | y | | 2 |[1, 4] | y | +// |[.] | | | |[.] | | | |[.] | | | |[.] | | | +// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +// | [-3, 1] | [-3, 1] | [0, 1] | [-3, 1] +// | | | | +// +---+ +---+ +---+ +---+ +// | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [sqrt(3), 2]*** +// +---+ +---+ +---+ +---+ +// +// (c) Bottom-up evaluation: Step3 (d) Bottom-up evaluation: Step4 (c) Top-down propagation: Step3 (d) Top-down propagation: Step4 +// +// * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1] +// ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4] +// *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2] +// ``` /// This object implements a directed acyclic expression graph (DAEG) that /// is used to compute ranges for expressions through interval arithmetic. @@ -578,7 +561,6 @@ pub fn check_support(expr: &Arc) -> bool { #[cfg(test)] mod tests { use super::*; - use datafusion_expr::expr; use itertools::Itertools; use crate::expressions::{BinaryExpr, Column}; @@ -1141,36 +1123,6 @@ mod tests { let final_node_count = graph.node_count(); // Assert that the final node count is equal the previous node count (i.e., no node was pruned). assert_eq!(prev_node_count, final_node_count); - Ok(()) - } - - #[test] - fn my_test() -> Result<()> { - let expr_x = Arc::new(Column::new("x", 0)); - let expr_y = Arc::new(Column::new("y", 1)); - let expression = Arc::new(BinaryExpr::new( - Arc::new(BinaryExpr::new( - expr_x.clone(), - Operator::Plus, - Arc::new(Literal::new(ScalarValue::Int32(Some(4)))), - )), - Operator::Minus, - expr_y.clone(), - )); - - let mut graph = ExprIntervalGraph::try_new(expression.clone()).unwrap(); - // Pass in expr_x and expr_y so we can input their intervals - let mapping = graph.gather_node_indices(&[expr_x, expr_y]); - - let interval_x = Interval::make(Some(0), Some(4), (false, true)); - let interval_y = Interval::make(Some(1), Some(3), (true, false)); - graph.assign_intervals(&[ - (mapping[0].1, interval_x.clone()), - (mapping[1].1, interval_y.clone()), - ]); - - - Ok(()) } } From a78f8373f9609996711ce22d97e7f53173663d5d Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sun, 10 Sep 2023 16:40:58 -0700 Subject: [PATCH 07/15] refactor to use intervals --- .../simplify_expressions/expr_simplifier.rs | 84 ++- .../src/simplify_expressions/guarantees.rs | 649 ++++++++---------- .../src/intervals/interval_aritmetic.rs | 250 +++++++ 3 files changed, 605 insertions(+), 378 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 7f1c692238056..c9db9377a30dc 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -39,11 +39,13 @@ use datafusion_expr::{ and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like, Volatility, }; -use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; +use datafusion_physical_expr::{ + create_physical_expr, execution_props::ExecutionProps, intervals::NullableInterval, +}; use crate::simplify_expressions::SimplifyInfo; -use crate::simplify_expressions::guarantees::{Guarantee, GuaranteeRewriter}; +use crate::simplify_expressions::guarantees::GuaranteeRewriter; /// This structure handles API for expression simplification pub struct ExprSimplifier { @@ -154,18 +156,18 @@ impl ExprSimplifier { /// Input guarantees and simplify the expression. /// - /// The guarantees can simplify expressions. For example, if a column is - /// guaranteed to always be a certain value, it's references in the expression - /// can be replaced with that literal. + /// The guarantees can simplify expressions. For example, if a column `x` is + /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the + /// literal `true`. /// /// ```rust /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_expr::{col, lit, Expr}; /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; /// use datafusion_physical_expr::execution_props::ExecutionProps; + /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; /// use datafusion_optimizer::simplify_expressions::{ - /// ExprSimplifier, SimplifyContext, - /// guarantees::{Guarantee, GuaranteeBound, NullStatus}}; + /// ExprSimplifier, SimplifyContext}; /// /// let schema = Schema::new(vec![ /// Field::new("x", DataType::Int64, false), @@ -187,17 +189,16 @@ impl ExprSimplifier { /// let expr = expr_x.and(expr_y).and(expr_z.clone()); /// /// let guarantees = vec![ - /// // x is guaranteed to be between 3 and 5 + /// // x ∈ [3, 5] /// ( /// col("x"), - /// Guarantee::new( - /// Some(GuaranteeBound::new(ScalarValue::Int64(Some(3)), false)), - /// Some(GuaranteeBound::new(ScalarValue::Int64(Some(5)), false)), - /// NullStatus::NeverNull, - /// ) + /// NullableInterval { + /// values: Interval::make(Some(3_i64), Some(5_i64), (false, false)), + /// is_valid: Interval::CERTAINLY_TRUE, + /// } /// ), - /// // y is guaranteed to be 3 - /// (col("y"), Guarantee::from(&ScalarValue::UInt32(Some(3)))), + /// // y = 3 + /// (col("y"), NullableInterval::from(&ScalarValue::UInt32(Some(3)))), /// ]; /// let output = simplifier.simplify_with_guarantees(expr, &guarantees).unwrap(); /// // Expression becomes: true AND true AND (z > 5), which simplifies to @@ -207,7 +208,7 @@ impl ExprSimplifier { pub fn simplify_with_guarantees<'a>( &self, expr: Expr, - guarantees: impl IntoIterator, + guarantees: impl IntoIterator, ) -> Result { // Do a simplification pass in case it reveals places where a guarantee // could be applied. @@ -1266,7 +1267,6 @@ mod tests { }; use crate::simplify_expressions::{ - guarantees::{GuaranteeBound, NullStatus}, utils::for_test::{cast_to_int64_expr, now_expr, to_timestamp_expr}, SimplifyContext, }; @@ -1281,7 +1281,9 @@ mod tests { use datafusion_common::{assert_contains, cast::as_int32_array, DFField, ToDFSchema}; use datafusion_expr::*; use datafusion_physical_expr::{ - execution_props::ExecutionProps, functions::make_scalar_function, + execution_props::ExecutionProps, + functions::make_scalar_function, + intervals::{Interval, NullableInterval}, }; // ------------------------------ @@ -2745,7 +2747,10 @@ mod tests { try_simplify(expr).unwrap() } - fn simplify_with_guarantee(expr: Expr, guarantees: &[(Expr, Guarantee)]) -> Expr { + fn simplify_with_guarantee( + expr: Expr, + guarantees: &[(Expr, NullableInterval)], + ) -> Expr { let schema = expr_test_schema(); let execution_props = ExecutionProps::new(); let simplifier = ExprSimplifier::new( @@ -3230,9 +3235,12 @@ mod tests { // All guaranteed null let guarantees = vec![ - (col("c3"), Guarantee::from(&ScalarValue::Int64(None))), - (col("c4"), Guarantee::from(&ScalarValue::UInt32(None))), - (col("c1"), Guarantee::from(&ScalarValue::Utf8(None))), + (col("c3"), NullableInterval::from(&ScalarValue::Int64(None))), + ( + col("c4"), + NullableInterval::from(&ScalarValue::UInt32(None)), + ), + (col("c1"), NullableInterval::from(&ScalarValue::Utf8(None))), ]; let output = simplify_with_guarantee(expr.clone(), &guarantees); @@ -3242,16 +3250,18 @@ mod tests { let guarantees = vec![ ( col("c3"), - Guarantee::new( - Some(GuaranteeBound::new(ScalarValue::Int64(Some(0)), false)), - Some(GuaranteeBound::new(ScalarValue::Int64(Some(2)), false)), - NullStatus::NeverNull, - ), + NullableInterval { + values: Interval::make(Some(0_i64), Some(2_i64), (false, false)), + is_valid: Interval::CERTAINLY_TRUE, + }, + ), + ( + col("c4"), + NullableInterval::from(&ScalarValue::UInt32(Some(9))), ), - (col("c4"), Guarantee::from(&ScalarValue::UInt32(Some(9)))), ( col("c1"), - Guarantee::from(&ScalarValue::Utf8(Some("a".to_string()))), + NullableInterval::from(&ScalarValue::Utf8(Some("a".to_string()))), ), ]; let output = simplify_with_guarantee(expr.clone(), &guarantees); @@ -3259,15 +3269,23 @@ mod tests { // Sufficient true guarantees let guarantees = vec![ - (col("c3"), Guarantee::from(&ScalarValue::Int64(Some(9)))), - (col("c4"), Guarantee::from(&ScalarValue::UInt32(Some(3)))), + ( + col("c3"), + NullableInterval::from(&ScalarValue::Int64(Some(9))), + ), + ( + col("c4"), + NullableInterval::from(&ScalarValue::UInt32(Some(3))), + ), ]; let output = simplify_with_guarantee(expr.clone(), &guarantees); assert_eq!(output, lit(true)); // Only partially simplify - let guarantees = - vec![(col("c4"), Guarantee::from(&ScalarValue::UInt32(Some(3))))]; + let guarantees = vec![( + col("c4"), + NullableInterval::from(&ScalarValue::UInt32(Some(3))), + )]; let output = simplify_with_guarantee(expr.clone(), &guarantees); assert_eq!(&output, &expr_x); } diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 9b62acd44d103..d81f98af10e4d 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -15,162 +15,24 @@ // specific language governing permissions and limitations // under the License. -//! Guarantees which can be used with [ExprSimplifier::simplify_with_guarantees()][crate::simplify_expressions::expr_simplifier::ExprSimplifier::simplify_with_guarantees]. -//! -//! Guarantees can represent single values or possible ranges of values. -//! -//! ``` -//! use datafusion_common::scalar::ScalarValue; -//! use datafusion_optimizer::simplify_expressions::guarantees::{ -//! Guarantee, GuaranteeBound, NullStatus}; -//! -//! // Guarantee that value is always 1_i32 -//! Guarantee::from(&ScalarValue::Int32(Some(1))); -//! // Guarantee that value is always NULL -//! Guarantee::from(&ScalarValue::Null); -//! // Guarantee that value is always between 1_i32 and 10_i32 (inclusive) -//! // and never null. -//! Guarantee::new( -//! Some(GuaranteeBound::new(ScalarValue::Int32(Some(1)), false)), -//! Some(GuaranteeBound::new(ScalarValue::Int32(Some(10)), false)), -//! NullStatus::NeverNull, -//! ); -//! ``` -use datafusion_common::{tree_node::TreeNodeRewriter, Result, ScalarValue}; +//! Simplifier implementation for [ExprSimplifier::simplify_with_guarantees()][crate::simplify_expressions::expr_simplifier::ExprSimplifier::simplify_with_guarantees]. +use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr, Operator}; use std::collections::HashMap; -/// A bound on the value of an expression. -#[derive(Debug, Clone, PartialEq)] -pub struct GuaranteeBound { - /// The value of the bound. If the bound is null, then there is no bound. - pub bound: ScalarValue, - /// If true, the bound is exclusive. If false, the bound is inclusive. - /// In terms of inequalities, this means the bound is `<` or `>` rather than - /// `<=` or `>=`. - pub open: bool, -} - -impl GuaranteeBound { - /// Create a new bound. - pub fn new(bound: ScalarValue, open: bool) -> Self { - Self { bound, open } - } -} - -impl Default for GuaranteeBound { - /// Default value is a closed bound at null. - fn default() -> Self { - Self { - bound: ScalarValue::Null, - open: false, - } - } -} - -/// The null status of an expression. -/// -/// This might be populated by null count statistics, for example. A null count -/// of zero would mean `NeverNull`, while a null count equal to row count would -/// mean `AlwaysNull`. -#[derive(Debug, Clone, PartialEq)] -pub enum NullStatus { - /// The expression is guaranteed to be non-null. - NeverNull, - /// The expression is guaranteed to be null. - AlwaysNull, - /// The expression isn't guaranteed to never be null or always be null. - MaybeNull, -} - -/// A set of constraints on the value of an expression. -/// -/// This is similar to [datafusion_physical_expr::intervals::Interval], except -/// that this is designed for working with logical expressions and also handles -/// nulls. -#[derive(Debug, Clone, PartialEq)] -pub struct Guarantee { - /// The min values that the expression can take on. If the min is null, then - /// there is no known min. - pub min: GuaranteeBound, - /// The max values that the expression can take on. If the max is null, - /// then there is no known max. - pub max: GuaranteeBound, - /// Whether the expression is expected to be either always null or never null. - pub null_status: NullStatus, -} - -impl Guarantee { - /// Create a new guarantee. - pub fn new( - min: Option, - max: Option, - null_status: NullStatus, - ) -> Self { - Self { - min: min.unwrap_or_default(), - max: max.unwrap_or_default(), - null_status, - } - } - - /// Whether values are guaranteed to be greater than the given value. - fn greater_than(&self, value: &ScalarValue) -> bool { - self.min.bound > *value || (self.min.bound == *value && self.min.open) - } - - /// Whether values are guaranteed to be greater than or equal to the given - /// value. - fn greater_than_or_eq(&self, value: &ScalarValue) -> bool { - self.min.bound >= *value - } - - /// Whether values are guaranteed to be less than the given value. - fn less_than(&self, value: &ScalarValue) -> bool { - self.max.bound < *value || (self.max.bound == *value && self.max.open) - } - - /// Whether values are guaranteed to be less than or equal to the given - /// value. - fn less_than_or_eq(&self, value: &ScalarValue) -> bool { - self.max.bound <= *value - } - - /// Whether the guarantee could contain the given value. - fn contains(&self, value: &ScalarValue) -> bool { - !self.less_than(value) && !self.greater_than(value) - } -} - -impl From<&ScalarValue> for Guarantee { - fn from(value: &ScalarValue) -> Self { - Self { - min: GuaranteeBound { - bound: value.clone(), - open: false, - }, - max: GuaranteeBound { - bound: value.clone(), - open: false, - }, - null_status: if value.is_null() { - NullStatus::AlwaysNull - } else { - NullStatus::NeverNull - }, - } - } -} +use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInterval}; /// Rewrite expressions to incorporate guarantees. pub(crate) struct GuaranteeRewriter<'a> { - guarantees: HashMap<&'a Expr, &'a Guarantee>, + intervals: HashMap<&'a Expr, &'a NullableInterval>, } impl<'a> GuaranteeRewriter<'a> { - pub fn new(guarantees: impl IntoIterator) -> Self { + pub fn new( + guarantees: impl IntoIterator, + ) -> Self { Self { - guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), + intervals: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), } } } @@ -180,66 +42,63 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { fn mutate(&mut self, expr: Expr) -> Result { match &expr { - // IS NUll / NOT NUll Expr::IsNull(inner) => { - if let Some(guarantee) = self.guarantees.get(inner.as_ref()) { - match guarantee.null_status { - NullStatus::AlwaysNull => Ok(lit(true)), - NullStatus::NeverNull => Ok(lit(false)), - NullStatus::MaybeNull => Ok(expr), + if let Some(interval) = self.intervals.get(inner.as_ref()) { + if interval.is_valid == Interval::CERTAINLY_FALSE { + Ok(lit(true)) + } else if interval.is_valid == Interval::CERTAINLY_TRUE { + Ok(lit(false)) + } else { + Ok(expr) } } else { Ok(expr) } } Expr::IsNotNull(inner) => { - if let Some(guarantee) = self.guarantees.get(inner.as_ref()) { - match guarantee.null_status { - NullStatus::AlwaysNull => Ok(lit(false)), - NullStatus::NeverNull => Ok(lit(true)), - NullStatus::MaybeNull => Ok(expr), + if let Some(interval) = self.intervals.get(inner.as_ref()) { + if interval.is_valid == Interval::CERTAINLY_FALSE { + Ok(lit(false)) + } else if interval.is_valid == Interval::CERTAINLY_TRUE { + Ok(lit(true)) + } else { + Ok(expr) } } else { Ok(expr) } } - // Inequality expressions Expr::Between(Between { expr: inner, negated, low, high, }) => { - if let Some(guarantee) = self.guarantees.get(inner.as_ref()) { - match (low.as_ref(), high.as_ref()) { - (Expr::Literal(low), Expr::Literal(high)) => { - if guarantee.greater_than_or_eq(low) - && guarantee.less_than_or_eq(high) - { - // All values are between the bounds - Ok(lit(!negated)) - } else if guarantee.greater_than(high) - || guarantee.less_than(low) - { - // All values are outside the bounds - Ok(lit(*negated)) - } else { - Ok(expr) - } - } - (Expr::Literal(low), _) - if !guarantee.less_than(low) && !negated => - { - // All values are below the lower bound - Ok(lit(false)) - } - (_, Expr::Literal(high)) - if !guarantee.greater_than(high) && !negated => - { - // All values are above the upper bound - Ok(lit(false)) - } - _ => Ok(expr), + if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = ( + self.intervals.get(inner.as_ref()), + low.as_ref(), + high.as_ref(), + ) { + let expr_interval = NullableInterval { + values: Interval::new( + IntervalBound::new(low.clone(), false), + IntervalBound::new(high.clone(), false), + ), + is_valid: Interval::CERTAINLY_TRUE, + }; + + let contains = expr_interval.contains(*interval)?; + + if contains.is_valid == Interval::CERTAINLY_TRUE + && contains.values == Interval::CERTAINLY_TRUE + { + Ok(lit(!negated)) + } else if contains.is_valid == Interval::CERTAINLY_TRUE + && contains.values == Interval::CERTAINLY_FALSE + { + Ok(lit(*negated)) + } else { + Ok(expr) } } else { Ok(expr) @@ -254,7 +113,9 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { | Operator::Lt | Operator::LtEq | Operator::Gt - | Operator::GtEq => {} + | Operator::GtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom => {} _ => return Ok(expr), }; @@ -262,106 +123,39 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { let (col, op, value) = match (left.as_ref(), right.as_ref()) { (Expr::Column(_), Expr::Literal(value)) => (left, *op, value), (Expr::Literal(value), Expr::Column(_)) => { - (right, op.swap().unwrap(), value) + // If we can swap the op, we can simplify the expression + if let Some(op) = op.swap() { + (right, op, value) + } else { + return Ok(expr); + } } _ => return Ok(expr), }; - // TODO: can this be simplified? - if let Some(guarantee) = self.guarantees.get(col.as_ref()) { - match op { - Operator::Eq => { - if guarantee.greater_than(value) || guarantee.less_than(value) - { - // All values are outside the bounds - Ok(lit(false)) - } else if guarantee.greater_than_or_eq(value) - && guarantee.less_than_or_eq(value) - { - // All values are equal to the bound - Ok(lit(true)) - } else { - Ok(expr) - } - } - Operator::NotEq => { - if guarantee.greater_than(value) || guarantee.less_than(value) - { - // All values are outside the bounds - Ok(lit(true)) - } else if guarantee.greater_than_or_eq(value) - && guarantee.less_than_or_eq(value) - { - // All values are equal to the bound - Ok(lit(false)) - } else { - Ok(expr) - } - } - Operator::Gt => { - if guarantee.less_than_or_eq(value) { - // All values are less than or equal to the bound - Ok(lit(false)) - } else if guarantee.greater_than(value) { - // All values are greater than the bound - Ok(lit(true)) - } else { - Ok(expr) - } - } - Operator::GtEq => { - if guarantee.less_than(value) { - // All values are less than the bound - Ok(lit(false)) - } else if guarantee.greater_than_or_eq(value) { - // All values are greater than or equal to the bound - Ok(lit(true)) - } else { - Ok(expr) - } - } - Operator::Lt => { - if guarantee.greater_than_or_eq(value) { - // All values are greater than or equal to the bound - Ok(lit(false)) - } else if guarantee.less_than(value) { - // All values are less than the bound - Ok(lit(true)) - } else { - Ok(expr) - } - } - Operator::LtEq => { - if guarantee.greater_than(value) { - // All values are greater than the bound - Ok(lit(false)) - } else if guarantee.less_than_or_eq(value) { - // All values are less than or equal to the bound - Ok(lit(true)) - } else { - Ok(expr) - } - } - _ => Ok(expr), + if let Some(col_interval) = self.intervals.get(col.as_ref()) { + let result = col_interval.apply_operator(&op, &value.into())?; + if result.is_valid == Interval::CERTAINLY_TRUE + && result.values == Interval::CERTAINLY_TRUE + { + Ok(lit(true)) + } else if result.is_valid == Interval::CERTAINLY_TRUE + && result.values == Interval::CERTAINLY_FALSE + { + Ok(lit(false)) + } else { + Ok(expr) } } else { Ok(expr) } } - // Columns (if bounds are equal and closed and column is not nullable) + // Columns (if interval is collapsed to a single value) Expr::Column(_) => { - if let Some(guarantee) = self.guarantees.get(&expr) { - if guarantee.min == guarantee.max - // Case where column has a single valid value - && ((!guarantee.min.open - && !guarantee.min.bound.is_null() - && guarantee.null_status == NullStatus::NeverNull) - // Case where column is always null - || (guarantee.min.bound.is_null() - && guarantee.null_status == NullStatus::AlwaysNull)) - { - Ok(lit(guarantee.min.bound.clone())) + if let Some(col_interval) = self.intervals.get(&expr) { + if let Some(value) = col_interval.single_value() { + Ok(lit(value)) } else { Ok(expr) } @@ -375,19 +169,29 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { list, negated, }) => { - if let Some(guarantee) = self.guarantees.get(inner.as_ref()) { + if let Some(interval) = self.intervals.get(inner.as_ref()) { // Can remove items from the list that don't match the guarantee let new_list: Vec = list .iter() - .filter(|item| { - if let Expr::Literal(item) = item { - guarantee.contains(item) + .filter_map(|expr| { + if let Expr::Literal(item) = expr { + match interval.contains(&NullableInterval::from(item)) { + // If we know for certain the value isn't in the column's interval, + // we can skip checking it. + Ok(result_interval) + if result_interval.values + == Interval::CERTAINLY_FALSE => + { + None + } + Err(err) => Some(Err(err)), + _ => Some(Ok(expr.clone())), + } } else { - true + Some(Ok(expr.clone())) } }) - .cloned() - .collect(); + .collect::>()?; Ok(Expr::InList(InList { expr: inner.clone(), @@ -408,7 +212,8 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { mod tests { use super::*; - use datafusion_common::tree_node::TreeNode; + use arrow::datatypes::DataType; + use datafusion_common::{tree_node::TreeNode, ScalarValue}; use datafusion_expr::{col, lit}; #[test] @@ -417,7 +222,13 @@ mod tests { let guarantees = vec![ // Note: AlwaysNull case handled by test_column_single_value test, // since it's a special case of a column with a single value. - (col("x"), Guarantee::new(None, None, NullStatus::NeverNull)), + ( + col("x"), + NullableInterval { + is_valid: Interval::CERTAINLY_TRUE, + ..Default::default() + }, + ), ]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); @@ -432,76 +243,224 @@ mod tests { assert_eq!(output, lit(true)); } + fn validate_simplified_cases(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)]) + where + ScalarValue: From, + T: Clone, + { + for (expr, expected_value) in cases { + let output = expr.clone().rewrite(rewriter).unwrap(); + let expected = lit(ScalarValue::from(expected_value.clone())); + assert_eq!( + output, expected, + "{} simplified to {}, but expected {}", + expr, output, expected + ); + } + } + + fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { + for expr in cases { + let output = expr.clone().rewrite(rewriter).unwrap(); + assert_eq!( + &output, expr, + "{} was simplified to {}, but expected it to be unchanged", + expr, output + ); + } + } + #[test] - fn test_inequalities() { + fn test_inequalities_non_null_bounded() { let guarantees = vec![ - // 1 < x <= 3 + // x ∈ (1, 3] (not null) ( col("x"), - Guarantee::new( - Some(GuaranteeBound::new(ScalarValue::Int32(Some(1)), true)), - Some(GuaranteeBound::new(ScalarValue::Int32(Some(3)), false)), - NullStatus::NeverNull, - ), - ), - // 2021-01-01 <= y - ( - col("y"), - Guarantee::new( - Some(GuaranteeBound::new(ScalarValue::Date32(Some(18628)), false)), - None, - NullStatus::NeverNull, - ), - ), - // "abc" < z <= "def" - ( - col("z"), - Guarantee::new( - Some(GuaranteeBound::new( - ScalarValue::Utf8(Some("abc".to_string())), - true, - )), - Some(GuaranteeBound::new( - ScalarValue::Utf8(Some("def".to_string())), - false, - )), - NullStatus::MaybeNull, - ), + NullableInterval { + values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), + is_valid: Interval::CERTAINLY_TRUE, + }, ), ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - // These cases should be simplified - let cases = &[ + // (original_expr, expected_simplification) + let simplified_cases = &[ (col("x").lt_eq(lit(1)), false), (col("x").lt_eq(lit(3)), true), (col("x").gt(lit(3)), false), - (col("y").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true), - (col("y").gt_eq(lit(ScalarValue::Date32(Some(17000)))), true), - (col("y").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false), + (col("x").gt(lit(1)), true), + (col("x").eq(lit(0)), false), + (col("x").not_eq(lit(0)), true), + (col("x").between(lit(2), lit(5)), true), + (col("x").between(lit(5), lit(10)), false), + (col("x").not_between(lit(2), lit(5)), false), + (col("x").not_between(lit(5), lit(10)), true), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Null)), + }), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(5)), + }), + true, + ), ]; - for (expr, expected_value) in cases { - let output = expr.clone().rewrite(&mut rewriter).unwrap(); - assert_eq!( - output, - Expr::Literal(ScalarValue::Boolean(Some(*expected_value))) - ); - } + validate_simplified_cases(&mut rewriter, simplified_cases); - // These cases should be left as-is - let cases = &[ + let unchanged_cases = &[ col("x").gt(lit(2)), col("x").lt_eq(lit(2)), - col("x").between(lit(2), lit(5)), + col("x").eq(lit(2)), + col("x").not_eq(lit(2)), + col("x").between(lit(3), lit(5)), col("x").not_between(lit(3), lit(10)), - col("y").gt(lit(ScalarValue::Date32(Some(19000)))), ]; - for expr in cases { - let output = expr.clone().rewrite(&mut rewriter).unwrap(); - assert_eq!(&output, expr); - } + validate_unchanged_cases(&mut rewriter, unchanged_cases); + } + + #[test] + fn test_inequalities_non_null_unbounded() { + let guarantees = vec![ + // y ∈ [2021-01-01, ∞) (not null) + ( + col("x"), + NullableInterval { + values: Interval::new( + IntervalBound::new(ScalarValue::Date32(Some(18628)), false), + IntervalBound::make_unbounded(DataType::Date32).unwrap(), + ), + is_valid: Interval::CERTAINLY_TRUE, + }, + ), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // (original_expr, expected_simplification) + let simplified_cases = &[ + (col("x").lt(lit(ScalarValue::Date32(Some(18628)))), false), + (col("x").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false), + (col("x").gt(lit(ScalarValue::Date32(Some(18627)))), true), + (col("x").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true), + (col("x").eq(lit(ScalarValue::Date32(Some(17000)))), false), + (col("x").not_eq(lit(ScalarValue::Date32(Some(17000)))), true), + ( + col("x").between( + lit(ScalarValue::Date32(Some(16000))), + lit(ScalarValue::Date32(Some(17000))), + ), + false, + ), + ( + col("x").not_between( + lit(ScalarValue::Date32(Some(16000))), + lit(ScalarValue::Date32(Some(17000))), + ), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Null)), + }), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Date32(Some(17000)))), + }), + true, + ), + ]; + + validate_simplified_cases(&mut rewriter, simplified_cases); + + let unchanged_cases = &[ + col("x").lt(lit(ScalarValue::Date32(Some(19000)))), + col("x").lt_eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").gt(lit(ScalarValue::Date32(Some(19000)))), + col("x").gt_eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").not_eq(lit(ScalarValue::Date32(Some(19000)))), + col("x").between( + lit(ScalarValue::Date32(Some(18000))), + lit(ScalarValue::Date32(Some(19000))), + ), + col("x").not_between( + lit(ScalarValue::Date32(Some(18000))), + lit(ScalarValue::Date32(Some(19000))), + ), + ]; + + validate_unchanged_cases(&mut rewriter, unchanged_cases); + } + + #[test] + fn test_inequalities_maybe_null() { + let guarantees = vec![ + // x ∈ ("abc", "def"]? (maybe null) + ( + col("x"), + NullableInterval { + values: Interval::make(Some("abc"), Some("def"), (true, false)), + is_valid: Interval::UNCERTAIN, + }, + ), + ]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // (original_expr, expected_simplification) + let simplified_cases = &[ + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit("z")), + }), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsNotDistinctFrom, + right: Box::new(lit("z")), + }), + false, + ), + ]; + + validate_simplified_cases(&mut rewriter, simplified_cases); + + let unchanged_cases = &[ + col("x").lt(lit("z")), + col("x").lt_eq(lit("z")), + col("x").gt(lit("a")), + col("x").gt_eq(lit("a")), + col("x").eq(lit("abc")), + col("x").not_eq(lit("a")), + col("x").between(lit("a"), lit("z")), + col("x").not_between(lit("a"), lit("z")), + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Null)), + }), + ]; + + validate_unchanged_cases(&mut rewriter, unchanged_cases); } #[test] @@ -519,7 +478,8 @@ mod tests { ]; for scalar in &scalars { - let guarantees = vec![(col("x"), Guarantee::from(scalar))]; + let guarantees = vec![(col("x"), NullableInterval::from(scalar))]; + dbg!(&guarantees); let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); let output = col("x").rewrite(&mut rewriter).unwrap(); @@ -530,14 +490,13 @@ mod tests { #[test] fn test_in_list() { let guarantees = vec![ - // 1 <= x < 10 + // x ∈ [1, 10) (not null) ( col("x"), - Guarantee::new( - Some(GuaranteeBound::new(ScalarValue::Int32(Some(1)), false)), - Some(GuaranteeBound::new(ScalarValue::Int32(Some(10)), true)), - NullStatus::NeverNull, - ), + NullableInterval { + values: Interval::make(Some(1_i32), Some(10_i32), (false, true)), + is_valid: Interval::CERTAINLY_TRUE, + }, ), ]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs index 659a47da1504d..8316a36ec8d00 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -383,6 +383,17 @@ impl Interval { } } + /// Compute the logical negation of this (boolean) interval. + pub(crate) fn not(&self) -> Self { + if self == &Interval::CERTAINLY_TRUE { + Interval::CERTAINLY_FALSE + } else if self == &Interval::CERTAINLY_FALSE { + Interval::CERTAINLY_TRUE + } else { + Interval::UNCERTAIN + } + } + /// Compute the intersection of the interval with the given interval. /// If the intersection is empty, return None. pub(crate) fn intersect>( @@ -413,6 +424,23 @@ impl Interval { Ok(non_empty.then_some(Interval::new(lower, upper))) } + /// Decide if this interval is certainly contains, possibly contains, + /// or can't can't `other` by returning [true, true], + /// [false, true] or [false, false] respectively. + pub fn contains>(&self, other: T) -> Result { + match self.intersect(other.borrow())? { + Some(intersection) => { + // Need to compare with same bounds close-ness. + if intersection.close_bounds() == other.borrow().clone().close_bounds() { + Ok(Interval::CERTAINLY_TRUE) + } else { + Ok(Interval::UNCERTAIN) + } + } + None => Ok(Interval::CERTAINLY_FALSE), + } + } + /// Add the given interval (`other`) to this interval. Say we have /// intervals [a1, b1] and [a2, b2], then their sum is [a1 + a2, b1 + b2]. /// Note that this represents all possible values the sum can take if @@ -652,6 +680,7 @@ pub fn is_datatype_supported(data_type: &DataType) -> bool { pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { match *op { Operator::Eq => Ok(lhs.equal(rhs)), + Operator::NotEq => Ok(lhs.equal(rhs).not()), Operator::Gt => Ok(lhs.gt(rhs)), Operator::GtEq => Ok(lhs.gt_eq(rhs)), Operator::Lt => Ok(lhs.lt(rhs)), @@ -686,6 +715,227 @@ fn calculate_cardinality_based_on_bounds( } } +/// The null status of an [NullableInterval]. +/// +/// This is an internal convenience that can be used in match statements +/// (unlike Interval). +#[derive(Debug, Clone, PartialEq, Eq)] +enum NullStatus { + /// The interval is guaranteed to be non-null. + Never, + /// The interval is guaranteed to be null. + Always, + /// The interval isn't guaranteed to never be null or always be null. + Maybe, +} + +/// An [Interval] that also tracks null status using a boolean interval. +/// +/// This represents values that may be in a particular range or be null. +/// +/// # Examples +/// +/// ``` +/// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; +/// use datafusion_common::ScalarValue; +/// +/// // [1, 2) U {NULL} +/// NullableInterval { +/// values: Interval::make(Some(1), Some(2), (false, true)), +/// is_valid: Interval::UNCERTAIN, +/// } +/// +/// // (0, ∞) +/// NullableInterval { +/// values: Interval::make(Some(0), None, (true, true)), +/// is_valid: Interval::CERTAINLY_TRUE, +/// } +/// +/// // {NULL} +/// NullableInterval::from(ScalarValue::Int32(None)) +/// +/// // {4} +/// NullableInterval::from(ScalarValue::Int32(4)) +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct NullableInterval { + /// The interval for the values + pub values: Interval, + /// A boolean interval representing whether the value is certainly valid + /// (not null), certainly null, or has an unknown validity. This takes + /// precedence over the values in the interval: if this field is equal to + /// [Interval::CERTAINLY_FALSE], then the interval is certainly null + /// regardless of what `values` is. + pub is_valid: Interval, +} + +impl Display for NullableInterval { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if self.is_valid == Interval::CERTAINLY_FALSE { + write!(f, "NullableInterval: {{NULL}}") + } else if Interval::CERTAINLY_TRUE == self.is_valid { + write!(f, "NullableInterval: {}", self.values) + } else { + write!(f, "NullableInterval: {} U {{NULL}}", self.values) + } + } +} + +impl From<&ScalarValue> for NullableInterval { + /// Create an interval that represents a single value. + fn from(value: &ScalarValue) -> Self { + Self { + values: Interval::new( + IntervalBound::new(value.clone(), false), + IntervalBound::new(value.clone(), false), + ), + is_valid: if value.is_null() { + Interval::CERTAINLY_FALSE + } else { + Interval::CERTAINLY_TRUE + }, + } + } +} + +impl NullableInterval { + fn null_status(&self) -> NullStatus { + if self.is_valid == Interval::CERTAINLY_FALSE { + NullStatus::Always + } else if self.is_valid == Interval::CERTAINLY_TRUE { + NullStatus::Never + } else { + NullStatus::Maybe + } + } + + /// Apply the given operator to this interval and the given interval. + /// + /// # Examples + /// + /// ``` + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::Operator; + /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; + /// + /// // 4 > 3 -> true + /// let lhs = NullableInterval::from(&ScalarValue::Int32(Some(4))); + /// let rhs = NullableInterval::from(&ScalarValue::Int32(Some(3))); + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result, NullableInterval::from(&ScalarValue::Boolean(Some(true)))); + /// + /// // [1, 3) > NULL -> NULL + /// let lhs = NullableInterval { + /// values: Interval::make(Some(1), Some(3), (false, true)), + /// is_valid: Interval::CERTAINLY_TRUE, + /// }; + /// let rhs = NullableInterval::from(&ScalarValue::Int32(None)); + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None))); + /// + /// // [1, 3] > [2, 4] -> [false, true] + /// let lhs = NullableInterval { + /// values: Interval::make(Some(1), Some(3), (false, false)), + /// is_valid: Interval::CERTAINLY_TRUE, + /// }; + /// let rhs = NullableInterval { + /// values: Interval::make(Some(2), Some(4), (false, false)), + /// is_valid: Interval::CERTAINLY_TRUE, + /// }; + /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); + /// assert_eq!(result, NullableInterval { + /// // Uncertain whether inequality is true or false + /// values: Interval::UNCERTAIN, + /// // Both inputs are valid (non-null), so result must be non-null + /// is_valid: Interval::CERTAINLY_TRUE, + /// }); + /// + /// ``` + pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result { + match op { + Operator::IsDistinctFrom => { + let values = match (self.null_status(), rhs.null_status()) { + // NULL is distinct from NULL -> False + (NullStatus::Always, NullStatus::Always) => { + Interval::CERTAINLY_FALSE + } + // NULL is distinct from non-NULL -> True + // non-NULL is distinct from NULL -> True + (NullStatus::Always, NullStatus::Never) | + (NullStatus::Never, NullStatus::Always) => { + Interval::CERTAINLY_TRUE + } + // x is distinct from y -> x != y, + // if at least one of them is never null. + (NullStatus::Never, _) | + (_, NullStatus::Never) => { + self.values.equal(&rhs.values).not() + } + _ => { + Interval::UNCERTAIN + } + }; + // IsDistinctFrom never returns null. + Ok(Self { values, is_valid: Interval::CERTAINLY_TRUE }) + }, + Operator::IsNotDistinctFrom => { + self.apply_operator(&Operator::IsDistinctFrom, rhs) + .map(|i| NullableInterval { values: i.values.not(), is_valid: i.is_valid }) + }, + _ => { + let values = apply_operator(op, &self.values, &rhs.values)?; + let is_valid = self.is_valid.and(&rhs.is_valid)?; + Ok(Self { values, is_valid }) + } + } + } + + /// Determine if this interval contains the given interval. Returns a boolean + /// interval that is [true, true] if this interval is a superset of the + /// given interval, [false, false] if this interval is disjoint from the + /// given interval, and [false, true] otherwise. + pub fn contains>(&self, other: T) -> Result { + let rhs = other.borrow(); + let values = self.values.contains(&rhs.values)?; + let is_valid = self.is_valid.and(&rhs.is_valid)?; + Ok(Self { values, is_valid }) + } + + /// If the interval has collapsed to a single value, return that value. + /// + /// Otherwise returns None. + /// + /// # Examples + /// + /// ``` + /// use datafusion_common::ScalarValue; + /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; + /// + /// let interval = NullableInterval::from(&ScalarValue::Int32(Some(4))); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); + /// + /// let interval = NullableInterval::from(&ScalarValue::Int32(None)); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None))); + /// + /// let interval = NullableInterval { + /// values: Interval::make(Some(1), Some(4), (false, true)), + /// is_valid: Interval::UNCERTAIN, + /// }; + /// assert_eq!(interval.single_value(), None); + /// ``` + pub fn single_value(&self) -> Option { + if self.is_valid == Interval::CERTAINLY_FALSE { + Some(self.values.get_datatype().and_then(ScalarValue::try_from).unwrap_or(ScalarValue::Null)) + } else if self.is_valid == Interval::CERTAINLY_TRUE && + self.values.lower.value == self.values.upper.value && + !self.values.lower.value.is_null() { + Some(self.values.lower.value.clone()) + } else { + None + } + } +} + #[cfg(test)] mod tests { use super::next_value; From 011f1768b24b0216bb40f5673ae3e91fdfc85962 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sun, 10 Sep 2023 16:59:27 -0700 Subject: [PATCH 08/15] add high-level test --- .../simplify_expressions/expr_simplifier.rs | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c9db9377a30dc..2841de1a7fa17 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3227,11 +3227,11 @@ mod tests { #[test] fn test_simplify_with_guarantee() { - // (x >= 3) AND (y + 2 < 10 OR (z NOT IN ("a", "b"))) + // (c3 >= 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b"))) let expr_x = col("c3").gt(lit(3_i64)); let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32)); let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true); - let expr = expr_x.clone().and(expr_y.or(expr_z)); + let expr = expr_x.clone().and(expr_y.clone().or(expr_z)); // All guaranteed null let guarantees = vec![ @@ -3267,6 +3267,30 @@ mod tests { let output = simplify_with_guarantee(expr.clone(), &guarantees); assert_eq!(output, lit(false)); + // Guaranteed true or null -> no change. + let guarantees = vec![ + ( + col("c3"), + NullableInterval { + values: Interval::make(Some(0_i64), Some(2_i64), (false, false)), + is_valid: Interval::UNCERTAIN, + }, + ), + ( + col("c4"), + NullableInterval { + values: Interval::make(Some(0_u32), Some(5_u32), (false, false)), + is_valid: Interval::UNCERTAIN, + }, + ), + ( + col("c1"), + NullableInterval::from(&ScalarValue::Utf8(Some("a".to_string()))), + ), + ]; + let output = simplify_with_guarantee(expr.clone(), &guarantees); + assert_eq!(output, expr_x.clone().and(expr_y.clone())); + // Sufficient true guarantees let guarantees = vec![ ( From 4bd9b60eda3749234a84d1c99b36bad20fa944a2 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sun, 10 Sep 2023 17:06:04 -0700 Subject: [PATCH 09/15] cleanup --- .../simplify_expressions/expr_simplifier.rs | 4 + .../optimizer/src/simplify_expressions/mod.rs | 2 +- .../src/intervals/interval_aritmetic.rs | 86 ++++++++++--------- 3 files changed, 50 insertions(+), 42 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 2841de1a7fa17..edea3a7b9561a 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -160,6 +160,10 @@ impl ExprSimplifier { /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the /// literal `true`. /// + /// The guarantees are provided as an iterator of `(Expr, NullableInterval)` + /// pairs, where the [Expr] is a column reference and the [NullableInterval] + /// is an interval representing the known possible values of that column. + /// /// ```rust /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_expr::{col, lit, Expr}; diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index b030793e67ce8..2cf6ed166cdde 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -17,7 +17,7 @@ pub mod context; pub mod expr_simplifier; -pub mod guarantees; +mod guarantees; mod or_in_list_simplifier; mod regex; pub mod simplify_exprs; diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs index 8316a36ec8d00..7dcf331751644 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -730,30 +730,30 @@ enum NullStatus { } /// An [Interval] that also tracks null status using a boolean interval. -/// +/// /// This represents values that may be in a particular range or be null. -/// +/// /// # Examples -/// +/// /// ``` /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; /// use datafusion_common::ScalarValue; -/// +/// /// // [1, 2) U {NULL} /// NullableInterval { /// values: Interval::make(Some(1), Some(2), (false, true)), /// is_valid: Interval::UNCERTAIN, /// } -/// +/// /// // (0, ∞) /// NullableInterval { /// values: Interval::make(Some(0), None, (true, true)), /// is_valid: Interval::CERTAINLY_TRUE, /// } -/// +/// /// // {NULL} /// NullableInterval::from(ScalarValue::Int32(None)) -/// +/// /// // {4} /// NullableInterval::from(ScalarValue::Int32(4)) /// ``` @@ -810,20 +810,20 @@ impl NullableInterval { } /// Apply the given operator to this interval and the given interval. - /// + /// /// # Examples - /// + /// /// ``` /// use datafusion_common::ScalarValue; /// use datafusion_expr::Operator; /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; - /// + /// /// // 4 > 3 -> true /// let lhs = NullableInterval::from(&ScalarValue::Int32(Some(4))); /// let rhs = NullableInterval::from(&ScalarValue::Int32(Some(3))); /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); /// assert_eq!(result, NullableInterval::from(&ScalarValue::Boolean(Some(true)))); - /// + /// /// // [1, 3) > NULL -> NULL /// let lhs = NullableInterval { /// values: Interval::make(Some(1), Some(3), (false, true)), @@ -832,7 +832,7 @@ impl NullableInterval { /// let rhs = NullableInterval::from(&ScalarValue::Int32(None)); /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None))); - /// + /// /// // [1, 3] > [2, 4] -> [false, true] /// let lhs = NullableInterval { /// values: Interval::make(Some(1), Some(3), (false, false)), @@ -849,39 +849,37 @@ impl NullableInterval { /// // Both inputs are valid (non-null), so result must be non-null /// is_valid: Interval::CERTAINLY_TRUE, /// }); - /// + /// /// ``` pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result { match op { Operator::IsDistinctFrom => { let values = match (self.null_status(), rhs.null_status()) { // NULL is distinct from NULL -> False - (NullStatus::Always, NullStatus::Always) => { - Interval::CERTAINLY_FALSE - } + (NullStatus::Always, NullStatus::Always) => Interval::CERTAINLY_FALSE, // NULL is distinct from non-NULL -> True // non-NULL is distinct from NULL -> True - (NullStatus::Always, NullStatus::Never) | - (NullStatus::Never, NullStatus::Always) => { - Interval::CERTAINLY_TRUE - } + (NullStatus::Always, NullStatus::Never) + | (NullStatus::Never, NullStatus::Always) => Interval::CERTAINLY_TRUE, // x is distinct from y -> x != y, // if at least one of them is never null. - (NullStatus::Never, _) | - (_, NullStatus::Never) => { + (NullStatus::Never, _) | (_, NullStatus::Never) => { self.values.equal(&rhs.values).not() } - _ => { - Interval::UNCERTAIN - } + _ => Interval::UNCERTAIN, }; // IsDistinctFrom never returns null. - Ok(Self { values, is_valid: Interval::CERTAINLY_TRUE }) - }, - Operator::IsNotDistinctFrom => { - self.apply_operator(&Operator::IsDistinctFrom, rhs) - .map(|i| NullableInterval { values: i.values.not(), is_valid: i.is_valid }) - }, + Ok(Self { + values, + is_valid: Interval::CERTAINLY_TRUE, + }) + } + Operator::IsNotDistinctFrom => self + .apply_operator(&Operator::IsDistinctFrom, rhs) + .map(|i| NullableInterval { + values: i.values.not(), + is_valid: i.is_valid, + }), _ => { let values = apply_operator(op, &self.values, &rhs.values)?; let is_valid = self.is_valid.and(&rhs.is_valid)?; @@ -902,21 +900,21 @@ impl NullableInterval { } /// If the interval has collapsed to a single value, return that value. - /// + /// /// Otherwise returns None. - /// + /// /// # Examples - /// + /// /// ``` /// use datafusion_common::ScalarValue; /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; - /// + /// /// let interval = NullableInterval::from(&ScalarValue::Int32(Some(4))); /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); - /// + /// /// let interval = NullableInterval::from(&ScalarValue::Int32(None)); /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None))); - /// + /// /// let interval = NullableInterval { /// values: Interval::make(Some(1), Some(4), (false, true)), /// is_valid: Interval::UNCERTAIN, @@ -925,10 +923,16 @@ impl NullableInterval { /// ``` pub fn single_value(&self) -> Option { if self.is_valid == Interval::CERTAINLY_FALSE { - Some(self.values.get_datatype().and_then(ScalarValue::try_from).unwrap_or(ScalarValue::Null)) - } else if self.is_valid == Interval::CERTAINLY_TRUE && - self.values.lower.value == self.values.upper.value && - !self.values.lower.value.is_null() { + Some( + self.values + .get_datatype() + .and_then(ScalarValue::try_from) + .unwrap_or(ScalarValue::Null), + ) + } else if self.is_valid == Interval::CERTAINLY_TRUE + && self.values.lower.value == self.values.upper.value + && !self.values.lower.value.is_null() + { Some(self.values.lower.value.clone()) } else { None From 16d78c64d15946ed59c40f92e609094c50cdd3d5 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Sun, 10 Sep 2023 17:22:02 -0700 Subject: [PATCH 10/15] fix test to be false or null, not true --- .../optimizer/src/simplify_expressions/expr_simplifier.rs | 4 ++-- .../physical-expr/src/intervals/interval_aritmetic.rs | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index edea3a7b9561a..9b3f710ea8749 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3271,7 +3271,7 @@ mod tests { let output = simplify_with_guarantee(expr.clone(), &guarantees); assert_eq!(output, lit(false)); - // Guaranteed true or null -> no change. + // Guaranteed false or null -> no change. let guarantees = vec![ ( col("c3"), @@ -3283,7 +3283,7 @@ mod tests { ( col("c4"), NullableInterval { - values: Interval::make(Some(0_u32), Some(5_u32), (false, false)), + values: Interval::make(Some(9_u32), Some(9_u32), (false, false)), is_valid: Interval::UNCERTAIN, }, ), diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs index 7dcf331751644..caf15dac7d9c9 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -743,19 +743,19 @@ enum NullStatus { /// NullableInterval { /// values: Interval::make(Some(1), Some(2), (false, true)), /// is_valid: Interval::UNCERTAIN, -/// } +/// }; /// /// // (0, ∞) /// NullableInterval { /// values: Interval::make(Some(0), None, (true, true)), /// is_valid: Interval::CERTAINLY_TRUE, -/// } +/// }; /// /// // {NULL} -/// NullableInterval::from(ScalarValue::Int32(None)) +/// NullableInterval::from(&ScalarValue::Int32(None)); /// /// // {4} -/// NullableInterval::from(ScalarValue::Int32(4)) +/// NullableInterval::from(&ScalarValue::Int32(Some(4))); /// ``` #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct NullableInterval { From bffb137002a5f92375b67b038d396ea7b988ca93 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 11 Sep 2023 22:38:55 -0700 Subject: [PATCH 11/15] refactor: change NullableInterval to an enum --- .../simplify_expressions/expr_simplifier.rs | 18 +- .../src/simplify_expressions/guarantees.rs | 77 ++---- .../src/intervals/interval_aritmetic.rs | 258 +++++++++++------- 3 files changed, 183 insertions(+), 170 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 9b3f710ea8749..078e3cb5d635a 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -196,9 +196,8 @@ impl ExprSimplifier { /// // x ∈ [3, 5] /// ( /// col("x"), - /// NullableInterval { + /// NullableInterval::NotNull { /// values: Interval::make(Some(3_i64), Some(5_i64), (false, false)), - /// is_valid: Interval::CERTAINLY_TRUE, /// } /// ), /// // y = 3 @@ -3254,9 +3253,8 @@ mod tests { let guarantees = vec![ ( col("c3"), - NullableInterval { + NullableInterval::NotNull { values: Interval::make(Some(0_i64), Some(2_i64), (false, false)), - is_valid: Interval::CERTAINLY_TRUE, }, ), ( @@ -3275,25 +3273,25 @@ mod tests { let guarantees = vec![ ( col("c3"), - NullableInterval { + NullableInterval::MaybeNull { values: Interval::make(Some(0_i64), Some(2_i64), (false, false)), - is_valid: Interval::UNCERTAIN, }, ), ( col("c4"), - NullableInterval { + NullableInterval::MaybeNull { values: Interval::make(Some(9_u32), Some(9_u32), (false, false)), - is_valid: Interval::UNCERTAIN, }, ), ( col("c1"), - NullableInterval::from(&ScalarValue::Utf8(Some("a".to_string()))), + NullableInterval::NotNull { + values: Interval::make(Some("d"), Some("f"), (false, false)), + }, ), ]; let output = simplify_with_guarantee(expr.clone(), &guarantees); - assert_eq!(output, expr_x.clone().and(expr_y.clone())); + assert_eq!(&output, &expr_x); // Sufficient true guarantees let guarantees = vec![ diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index d81f98af10e4d..a7a29f358e5fd 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -42,32 +42,16 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { fn mutate(&mut self, expr: Expr) -> Result { match &expr { - Expr::IsNull(inner) => { - if let Some(interval) = self.intervals.get(inner.as_ref()) { - if interval.is_valid == Interval::CERTAINLY_FALSE { - Ok(lit(true)) - } else if interval.is_valid == Interval::CERTAINLY_TRUE { - Ok(lit(false)) - } else { - Ok(expr) - } - } else { - Ok(expr) - } - } - Expr::IsNotNull(inner) => { - if let Some(interval) = self.intervals.get(inner.as_ref()) { - if interval.is_valid == Interval::CERTAINLY_FALSE { - Ok(lit(false)) - } else if interval.is_valid == Interval::CERTAINLY_TRUE { - Ok(lit(true)) - } else { - Ok(expr) - } - } else { - Ok(expr) - } - } + Expr::IsNull(inner) => match self.intervals.get(inner.as_ref()) { + Some(NullableInterval::Null { .. }) => Ok(lit(true)), + Some(NullableInterval::NotNull { .. }) => Ok(lit(false)), + _ => Ok(expr), + }, + Expr::IsNotNull(inner) => match self.intervals.get(inner.as_ref()) { + Some(NullableInterval::Null { .. }) => Ok(lit(false)), + Some(NullableInterval::NotNull { .. }) => Ok(lit(true)), + _ => Ok(expr), + }, Expr::Between(Between { expr: inner, negated, @@ -79,23 +63,18 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { low.as_ref(), high.as_ref(), ) { - let expr_interval = NullableInterval { + let expr_interval = NullableInterval::NotNull { values: Interval::new( IntervalBound::new(low.clone(), false), IntervalBound::new(high.clone(), false), ), - is_valid: Interval::CERTAINLY_TRUE, }; let contains = expr_interval.contains(*interval)?; - if contains.is_valid == Interval::CERTAINLY_TRUE - && contains.values == Interval::CERTAINLY_TRUE - { + if contains.is_certainly_true() { Ok(lit(!negated)) - } else if contains.is_valid == Interval::CERTAINLY_TRUE - && contains.values == Interval::CERTAINLY_FALSE - { + } else if contains.is_certainly_false() { Ok(lit(*negated)) } else { Ok(expr) @@ -135,13 +114,9 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { if let Some(col_interval) = self.intervals.get(col.as_ref()) { let result = col_interval.apply_operator(&op, &value.into())?; - if result.is_valid == Interval::CERTAINLY_TRUE - && result.values == Interval::CERTAINLY_TRUE - { + if result.is_certainly_true() { Ok(lit(true)) - } else if result.is_valid == Interval::CERTAINLY_TRUE - && result.values == Interval::CERTAINLY_FALSE - { + } else if result.is_certainly_false() { Ok(lit(false)) } else { Ok(expr) @@ -178,9 +153,8 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { match interval.contains(&NullableInterval::from(item)) { // If we know for certain the value isn't in the column's interval, // we can skip checking it. - Ok(result_interval) - if result_interval.values - == Interval::CERTAINLY_FALSE => + Ok(NullableInterval::NotNull { values }) + if values == Interval::CERTAINLY_FALSE => { None } @@ -224,9 +198,8 @@ mod tests { // since it's a special case of a column with a single value. ( col("x"), - NullableInterval { - is_valid: Interval::CERTAINLY_TRUE, - ..Default::default() + NullableInterval::NotNull { + values: Default::default(), }, ), ]; @@ -276,9 +249,8 @@ mod tests { // x ∈ (1, 3] (not null) ( col("x"), - NullableInterval { + NullableInterval::NotNull { values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), - is_valid: Interval::CERTAINLY_TRUE, }, ), ]; @@ -335,12 +307,11 @@ mod tests { // y ∈ [2021-01-01, ∞) (not null) ( col("x"), - NullableInterval { + NullableInterval::NotNull { values: Interval::new( IntervalBound::new(ScalarValue::Date32(Some(18628)), false), IntervalBound::make_unbounded(DataType::Date32).unwrap(), ), - is_valid: Interval::CERTAINLY_TRUE, }, ), ]; @@ -414,9 +385,8 @@ mod tests { // x ∈ ("abc", "def"]? (maybe null) ( col("x"), - NullableInterval { + NullableInterval::MaybeNull { values: Interval::make(Some("abc"), Some("def"), (true, false)), - is_valid: Interval::UNCERTAIN, }, ), ]; @@ -493,9 +463,8 @@ mod tests { // x ∈ [1, 10) (not null) ( col("x"), - NullableInterval { + NullableInterval::NotNull { values: Interval::make(Some(1_i32), Some(10_i32), (false, true)), - is_valid: Interval::CERTAINLY_TRUE, }, ), ]; diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs index caf15dac7d9c9..f7888c9517895 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -384,13 +384,18 @@ impl Interval { } /// Compute the logical negation of this (boolean) interval. - pub(crate) fn not(&self) -> Self { + pub(crate) fn not(&self) -> Result { + if !matches!(self.get_datatype()?, DataType::Boolean) { + return internal_err!( + "Cannot apply logical negation to non-boolean interval" + ); + } if self == &Interval::CERTAINLY_TRUE { - Interval::CERTAINLY_FALSE + Ok(Interval::CERTAINLY_FALSE) } else if self == &Interval::CERTAINLY_FALSE { - Interval::CERTAINLY_TRUE + Ok(Interval::CERTAINLY_TRUE) } else { - Interval::UNCERTAIN + Ok(Interval::UNCERTAIN) } } @@ -680,7 +685,7 @@ pub fn is_datatype_supported(data_type: &DataType) -> bool { pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { match *op { Operator::Eq => Ok(lhs.equal(rhs)), - Operator::NotEq => Ok(lhs.equal(rhs).not()), + Operator::NotEq => Ok(lhs.equal(rhs).not()?), Operator::Gt => Ok(lhs.gt(rhs)), Operator::GtEq => Ok(lhs.gt_eq(rhs)), Operator::Lt => Ok(lhs.lt(rhs)), @@ -715,20 +720,6 @@ fn calculate_cardinality_based_on_bounds( } } -/// The null status of an [NullableInterval]. -/// -/// This is an internal convenience that can be used in match statements -/// (unlike Interval). -#[derive(Debug, Clone, PartialEq, Eq)] -enum NullStatus { - /// The interval is guaranteed to be non-null. - Never, - /// The interval is guaranteed to be null. - Always, - /// The interval isn't guaranteed to never be null or always be null. - Maybe, -} - /// An [Interval] that also tracks null status using a boolean interval. /// /// This represents values that may be in a particular range or be null. @@ -740,43 +731,51 @@ enum NullStatus { /// use datafusion_common::ScalarValue; /// /// // [1, 2) U {NULL} -/// NullableInterval { +/// NullableInterval::MaybeNull { /// values: Interval::make(Some(1), Some(2), (false, true)), -/// is_valid: Interval::UNCERTAIN, /// }; /// /// // (0, ∞) -/// NullableInterval { +/// NullableInterval::NotNull { /// values: Interval::make(Some(0), None, (true, true)), -/// is_valid: Interval::CERTAINLY_TRUE, /// }; /// /// // {NULL} -/// NullableInterval::from(&ScalarValue::Int32(None)); +/// NullableInterval::Null; /// /// // {4} /// NullableInterval::from(&ScalarValue::Int32(Some(4))); /// ``` -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct NullableInterval { - /// The interval for the values - pub values: Interval, - /// A boolean interval representing whether the value is certainly valid - /// (not null), certainly null, or has an unknown validity. This takes - /// precedence over the values in the interval: if this field is equal to - /// [Interval::CERTAINLY_FALSE], then the interval is certainly null - /// regardless of what `values` is. - pub is_valid: Interval, +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum NullableInterval { + /// The value is always null in this interval + /// + /// This is typed so it can be used in physical expressions, which don't do + /// type coercion. + Null { datatype: DataType }, + /// The value may or may not be null in this interval. If it is non null its value is within + /// the specified values interval + MaybeNull { values: Interval }, + /// The value is definitely not null in this interval and is within values + NotNull { values: Interval }, +} + +impl Default for NullableInterval { + fn default() -> Self { + NullableInterval::MaybeNull { + values: Interval::default(), + } + } } impl Display for NullableInterval { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - if self.is_valid == Interval::CERTAINLY_FALSE { - write!(f, "NullableInterval: {{NULL}}") - } else if Interval::CERTAINLY_TRUE == self.is_valid { - write!(f, "NullableInterval: {}", self.values) - } else { - write!(f, "NullableInterval: {} U {{NULL}}", self.values) + match self { + Self::Null { .. } => write!(f, "NullableInterval: {{NULL}}"), + Self::MaybeNull { values } => { + write!(f, "NullableInterval: {} U {{NULL}}", values) + } + Self::NotNull { values } => write!(f, "NullableInterval: {}", values), } } } @@ -784,28 +783,65 @@ impl Display for NullableInterval { impl From<&ScalarValue> for NullableInterval { /// Create an interval that represents a single value. fn from(value: &ScalarValue) -> Self { - Self { - values: Interval::new( - IntervalBound::new(value.clone(), false), - IntervalBound::new(value.clone(), false), - ), - is_valid: if value.is_null() { - Interval::CERTAINLY_FALSE - } else { - Interval::CERTAINLY_TRUE - }, + if value.is_null() { + Self::Null { datatype: value.get_datatype() } + } else { + Self::NotNull { + values: Interval::new( + IntervalBound::new(value.clone(), false), + IntervalBound::new(value.clone(), false), + ), + } } } } impl NullableInterval { - fn null_status(&self) -> NullStatus { - if self.is_valid == Interval::CERTAINLY_FALSE { - NullStatus::Always - } else if self.is_valid == Interval::CERTAINLY_TRUE { - NullStatus::Never - } else { - NullStatus::Maybe + /// Get the values interval, or None if this interval is definitely null. + pub fn values(&self) -> Option<&Interval> { + match self { + Self::Null { .. } => None, + Self::MaybeNull { values } | Self::NotNull { values } => Some(values), + } + } + + /// Get the data type + pub fn get_datatype(&self) -> Result { + match self { + Self::Null { datatype } => Ok(datatype.clone()), + Self::MaybeNull { values } | Self::NotNull { values } => { + values.get_datatype() + } + } + } + + /// Return true if the value is definitely true (and not null). + pub fn is_certainly_true(&self) -> bool { + match self { + Self::Null { .. } | Self::MaybeNull { .. } => false, + Self::NotNull { values } => values == &Interval::CERTAINLY_TRUE, + } + } + + /// Return true if the value is definitely false (and not null). + pub fn is_certainly_false(&self) -> bool { + match self { + Self::Null { .. } => false, + Self::MaybeNull { .. } => false, + Self::NotNull { values } => values == &Interval::CERTAINLY_FALSE, + } + } + + /// Perform logical negation on a boolean nullable interval. + fn not(&self) -> Result { + match self { + Self::Null { datatype } => Ok(Self::Null { datatype: datatype.clone() }), + Self::MaybeNull { values } => Ok(Self::MaybeNull { + values: values.not()?, + }), + Self::NotNull { values } => Ok(Self::NotNull { + values: values.not()?, + }), } } @@ -825,65 +861,73 @@ impl NullableInterval { /// assert_eq!(result, NullableInterval::from(&ScalarValue::Boolean(Some(true)))); /// /// // [1, 3) > NULL -> NULL - /// let lhs = NullableInterval { + /// let lhs = NullableInterval::NotNull { /// values: Interval::make(Some(1), Some(3), (false, true)), - /// is_valid: Interval::CERTAINLY_TRUE, /// }; /// let rhs = NullableInterval::from(&ScalarValue::Int32(None)); /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None))); + /// assert_eq!(result.single_value(), Some(ScalarValue::Null)); /// /// // [1, 3] > [2, 4] -> [false, true] - /// let lhs = NullableInterval { + /// let lhs = NullableInterval::NotNull { /// values: Interval::make(Some(1), Some(3), (false, false)), - /// is_valid: Interval::CERTAINLY_TRUE, /// }; - /// let rhs = NullableInterval { + /// let rhs = NullableInterval::NotNull { /// values: Interval::make(Some(2), Some(4), (false, false)), - /// is_valid: Interval::CERTAINLY_TRUE, /// }; /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// assert_eq!(result, NullableInterval { + /// // Both inputs are valid (non-null), so result must be non-null + /// assert_eq!(result, NullableInterval::NotNull { /// // Uncertain whether inequality is true or false /// values: Interval::UNCERTAIN, - /// // Both inputs are valid (non-null), so result must be non-null - /// is_valid: Interval::CERTAINLY_TRUE, /// }); /// /// ``` pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result { match op { Operator::IsDistinctFrom => { - let values = match (self.null_status(), rhs.null_status()) { + let values = match (self, rhs) { // NULL is distinct from NULL -> False - (NullStatus::Always, NullStatus::Always) => Interval::CERTAINLY_FALSE, - // NULL is distinct from non-NULL -> True - // non-NULL is distinct from NULL -> True - (NullStatus::Always, NullStatus::Never) - | (NullStatus::Never, NullStatus::Always) => Interval::CERTAINLY_TRUE, + (Self::Null { .. }, Self::Null { .. }) => Interval::CERTAINLY_FALSE, // x is distinct from y -> x != y, // if at least one of them is never null. - (NullStatus::Never, _) | (_, NullStatus::Never) => { - self.values.equal(&rhs.values).not() + (Self::NotNull { .. }, _) | (_, Self::NotNull { .. }) => { + let lhs_values = self.values(); + let rhs_values = rhs.values(); + match (lhs_values, rhs_values) { + (Some(lhs_values), Some(rhs_values)) => { + lhs_values.equal(rhs_values).not()? + } + (Some(_), None) | (None, Some(_)) => Interval::CERTAINLY_TRUE, + (None, None) => unreachable!("Null case handled above"), + } } _ => Interval::UNCERTAIN, }; // IsDistinctFrom never returns null. - Ok(Self { - values, - is_valid: Interval::CERTAINLY_TRUE, - }) + Ok(Self::NotNull { values }) } Operator::IsNotDistinctFrom => self .apply_operator(&Operator::IsDistinctFrom, rhs) - .map(|i| NullableInterval { - values: i.values.not(), - is_valid: i.is_valid, - }), + .map(|i| i.not())?, _ => { - let values = apply_operator(op, &self.values, &rhs.values)?; - let is_valid = self.is_valid.and(&rhs.is_valid)?; - Ok(Self { values, is_valid }) + if let (Some(left_values), Some(right_values)) = + (self.values(), rhs.values()) + { + let values = apply_operator(op, left_values, right_values)?; + match (self, rhs) { + (Self::NotNull { .. }, Self::NotNull { .. }) => { + Ok(Self::NotNull { values }) + } + _ => Ok(Self::MaybeNull { values }), + } + } else { + if op.is_comparison_operator() { + Ok(Self::Null { datatype: DataType::Boolean}) + } else { + Ok(Self::Null { datatype: self.get_datatype()? }) + } + } } } } @@ -894,9 +938,17 @@ impl NullableInterval { /// given interval, and [false, true] otherwise. pub fn contains>(&self, other: T) -> Result { let rhs = other.borrow(); - let values = self.values.contains(&rhs.values)?; - let is_valid = self.is_valid.and(&rhs.is_valid)?; - Ok(Self { values, is_valid }) + if let (Some(left_values), Some(right_values)) = (self.values(), rhs.values()) { + let values = left_values.contains(right_values)?; + match (self, rhs) { + (Self::NotNull { .. }, Self::NotNull { .. }) => { + Ok(Self::NotNull { values }) + } + _ => Ok(Self::MaybeNull { values }), + } + } else { + Ok(Self::Null { datatype: DataType::Boolean }) + } } /// If the interval has collapsed to a single value, return that value. @@ -913,29 +965,23 @@ impl NullableInterval { /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); /// /// let interval = NullableInterval::from(&ScalarValue::Int32(None)); - /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None))); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Null)); /// - /// let interval = NullableInterval { + /// let interval = NullableInterval::MaybeNull { /// values: Interval::make(Some(1), Some(4), (false, true)), - /// is_valid: Interval::UNCERTAIN, /// }; /// assert_eq!(interval.single_value(), None); /// ``` pub fn single_value(&self) -> Option { - if self.is_valid == Interval::CERTAINLY_FALSE { - Some( - self.values - .get_datatype() - .and_then(ScalarValue::try_from) - .unwrap_or(ScalarValue::Null), - ) - } else if self.is_valid == Interval::CERTAINLY_TRUE - && self.values.lower.value == self.values.upper.value - && !self.values.lower.value.is_null() - { - Some(self.values.lower.value.clone()) - } else { - None + match self { + Self::Null { datatype } => Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null)), + Self::MaybeNull { values } | Self::NotNull { values } + if values.lower.value == values.upper.value && + !values.lower.is_unbounded() => + { + Some(values.lower.value.clone()) + } + _ => None, } } } From e4427a37e4ba60f5eb18aec8cdf29d3b0754d044 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 11 Sep 2023 22:59:41 -0700 Subject: [PATCH 12/15] refactor: use a builder-like API --- .../simplify_expressions/expr_simplifier.rs | 56 +++++++++---------- .../src/simplify_expressions/guarantees.rs | 20 ++++--- .../src/intervals/interval_aritmetic.rs | 34 +++++++---- 3 files changed, 60 insertions(+), 50 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 078e3cb5d635a..584b37adfb6cf 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -50,6 +50,8 @@ use crate::simplify_expressions::guarantees::GuaranteeRewriter; /// This structure handles API for expression simplification pub struct ExprSimplifier { info: S, + /// + guarantees: Vec<(Expr, NullableInterval)>, } pub const THRESHOLD_INLINE_INLIST: usize = 3; @@ -61,7 +63,10 @@ impl ExprSimplifier { /// /// [`SimplifyContext`]: crate::simplify_expressions::context::SimplifyContext pub fn new(info: S) -> Self { - Self { info } + Self { + info, + guarantees: vec![], + } } /// Simplifies this [`Expr`]`s as much as possible, evaluating @@ -125,6 +130,7 @@ impl ExprSimplifier { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut or_in_list_simplifier = OrInListSimplifier::new(); + let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); // TODO iterate until no changes are made during rewrite // (evaluating constants can enable new simplifications and @@ -133,6 +139,7 @@ impl ExprSimplifier { expr.rewrite(&mut const_evaluator)? .rewrite(&mut simplifier)? .rewrite(&mut or_in_list_simplifier)? + .rewrite(&mut guarantee_rewriter)? // run both passes twice to try an minimize simplifications that we missed .rewrite(&mut const_evaluator)? .rewrite(&mut simplifier) @@ -154,14 +161,14 @@ impl ExprSimplifier { expr.rewrite(&mut expr_rewrite) } - /// Input guarantees and simplify the expression. + /// Input guarantees about the values of columns. /// /// The guarantees can simplify expressions. For example, if a column `x` is /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the /// literal `true`. /// - /// The guarantees are provided as an iterator of `(Expr, NullableInterval)` - /// pairs, where the [Expr] is a column reference and the [NullableInterval] + /// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`, + /// where the [Expr] is a column reference and the [NullableInterval] /// is an interval representing the known possible values of that column. /// /// ```rust @@ -184,7 +191,6 @@ impl ExprSimplifier { /// let props = ExecutionProps::new(); /// let context = SimplifyContext::new(&props) /// .with_schema(schema); - /// let simplifier = ExprSimplifier::new(context); /// /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5) /// let expr_x = col("x").gt_eq(lit(3_i64)); @@ -203,24 +209,15 @@ impl ExprSimplifier { /// // y = 3 /// (col("y"), NullableInterval::from(&ScalarValue::UInt32(Some(3)))), /// ]; - /// let output = simplifier.simplify_with_guarantees(expr, &guarantees).unwrap(); + /// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees); + /// let output = simplifier.simplify(expr).unwrap(); /// // Expression becomes: true AND true AND (z > 5), which simplifies to /// // z > 5. /// assert_eq!(output, expr_z); /// ``` - pub fn simplify_with_guarantees<'a>( - &self, - expr: Expr, - guarantees: impl IntoIterator, - ) -> Result { - // Do a simplification pass in case it reveals places where a guarantee - // could be applied. - let expr = self.simplify(expr)?; - let mut rewriter = GuaranteeRewriter::new(guarantees); - let expr = expr.rewrite(&mut rewriter)?; - // Simplify after guarantees are applied, since constant folding should - // now be able to fold more expressions. - self.simplify(expr) + pub fn with_guarantees(mut self, guarantees: Vec<(Expr, NullableInterval)>) -> Self { + self.guarantees = guarantees; + self } } @@ -2752,16 +2749,15 @@ mod tests { fn simplify_with_guarantee( expr: Expr, - guarantees: &[(Expr, NullableInterval)], + guarantees: Vec<(Expr, NullableInterval)>, ) -> Expr { let schema = expr_test_schema(); let execution_props = ExecutionProps::new(); - let simplifier = ExprSimplifier::new( + let mut simplifier = ExprSimplifier::new( SimplifyContext::new(&execution_props).with_schema(schema), - ); - simplifier - .simplify_with_guarantees(expr, guarantees) - .unwrap() + ) + .with_guarantees(guarantees); + simplifier.simplify(expr).unwrap() } fn expr_test_schema() -> DFSchemaRef { @@ -3246,7 +3242,7 @@ mod tests { (col("c1"), NullableInterval::from(&ScalarValue::Utf8(None))), ]; - let output = simplify_with_guarantee(expr.clone(), &guarantees); + let output = simplify_with_guarantee(expr.clone(), guarantees); assert_eq!(output, lit_bool_null()); // All guaranteed false @@ -3266,7 +3262,7 @@ mod tests { NullableInterval::from(&ScalarValue::Utf8(Some("a".to_string()))), ), ]; - let output = simplify_with_guarantee(expr.clone(), &guarantees); + let output = simplify_with_guarantee(expr.clone(), guarantees); assert_eq!(output, lit(false)); // Guaranteed false or null -> no change. @@ -3290,7 +3286,7 @@ mod tests { }, ), ]; - let output = simplify_with_guarantee(expr.clone(), &guarantees); + let output = simplify_with_guarantee(expr.clone(), guarantees); assert_eq!(&output, &expr_x); // Sufficient true guarantees @@ -3304,7 +3300,7 @@ mod tests { NullableInterval::from(&ScalarValue::UInt32(Some(3))), ), ]; - let output = simplify_with_guarantee(expr.clone(), &guarantees); + let output = simplify_with_guarantee(expr.clone(), guarantees); assert_eq!(output, lit(true)); // Only partially simplify @@ -3312,7 +3308,7 @@ mod tests { col("c4"), NullableInterval::from(&ScalarValue::UInt32(Some(3))), )]; - let output = simplify_with_guarantee(expr.clone(), &guarantees); + let output = simplify_with_guarantee(expr.clone(), guarantees); assert_eq!(&output, &expr_x); } } diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index a7a29f358e5fd..7ad305b0a3205 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -24,7 +24,7 @@ use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInter /// Rewrite expressions to incorporate guarantees. pub(crate) struct GuaranteeRewriter<'a> { - intervals: HashMap<&'a Expr, &'a NullableInterval>, + guarantees: HashMap<&'a Expr, &'a NullableInterval>, } impl<'a> GuaranteeRewriter<'a> { @@ -32,7 +32,7 @@ impl<'a> GuaranteeRewriter<'a> { guarantees: impl IntoIterator, ) -> Self { Self { - intervals: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), + guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(), } } } @@ -41,13 +41,17 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { type N = Expr; fn mutate(&mut self, expr: Expr) -> Result { + if self.guarantees.is_empty() { + return Ok(expr); + } + match &expr { - Expr::IsNull(inner) => match self.intervals.get(inner.as_ref()) { + Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) { Some(NullableInterval::Null { .. }) => Ok(lit(true)), Some(NullableInterval::NotNull { .. }) => Ok(lit(false)), _ => Ok(expr), }, - Expr::IsNotNull(inner) => match self.intervals.get(inner.as_ref()) { + Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) { Some(NullableInterval::Null { .. }) => Ok(lit(false)), Some(NullableInterval::NotNull { .. }) => Ok(lit(true)), _ => Ok(expr), @@ -59,7 +63,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { high, }) => { if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = ( - self.intervals.get(inner.as_ref()), + self.guarantees.get(inner.as_ref()), low.as_ref(), high.as_ref(), ) { @@ -112,7 +116,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { _ => return Ok(expr), }; - if let Some(col_interval) = self.intervals.get(col.as_ref()) { + if let Some(col_interval) = self.guarantees.get(col.as_ref()) { let result = col_interval.apply_operator(&op, &value.into())?; if result.is_certainly_true() { Ok(lit(true)) @@ -128,7 +132,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { // Columns (if interval is collapsed to a single value) Expr::Column(_) => { - if let Some(col_interval) = self.intervals.get(&expr) { + if let Some(col_interval) = self.guarantees.get(&expr) { if let Some(value) = col_interval.single_value() { Ok(lit(value)) } else { @@ -144,7 +148,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { list, negated, }) => { - if let Some(interval) = self.intervals.get(inner.as_ref()) { + if let Some(interval) = self.guarantees.get(inner.as_ref()) { // Can remove items from the list that don't match the guarantee let new_list: Vec = list .iter() diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs index f7888c9517895..81ece2d94068e 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -749,7 +749,7 @@ fn calculate_cardinality_based_on_bounds( #[derive(Debug, Clone, PartialEq, Eq)] pub enum NullableInterval { /// The value is always null in this interval - /// + /// /// This is typed so it can be used in physical expressions, which don't do /// type coercion. Null { datatype: DataType }, @@ -784,7 +784,9 @@ impl From<&ScalarValue> for NullableInterval { /// Create an interval that represents a single value. fn from(value: &ScalarValue) -> Self { if value.is_null() { - Self::Null { datatype: value.get_datatype() } + Self::Null { + datatype: value.get_datatype(), + } } else { Self::NotNull { values: Interval::new( @@ -835,7 +837,9 @@ impl NullableInterval { /// Perform logical negation on a boolean nullable interval. fn not(&self) -> Result { match self { - Self::Null { datatype } => Ok(Self::Null { datatype: datatype.clone() }), + Self::Null { datatype } => Ok(Self::Null { + datatype: datatype.clone(), + }), Self::MaybeNull { values } => Ok(Self::MaybeNull { values: values.not()?, }), @@ -921,12 +925,14 @@ impl NullableInterval { } _ => Ok(Self::MaybeNull { values }), } + } else if op.is_comparison_operator() { + Ok(Self::Null { + datatype: DataType::Boolean, + }) } else { - if op.is_comparison_operator() { - Ok(Self::Null { datatype: DataType::Boolean}) - } else { - Ok(Self::Null { datatype: self.get_datatype()? }) - } + Ok(Self::Null { + datatype: self.get_datatype()?, + }) } } } @@ -947,7 +953,9 @@ impl NullableInterval { _ => Ok(Self::MaybeNull { values }), } } else { - Ok(Self::Null { datatype: DataType::Boolean }) + Ok(Self::Null { + datatype: DataType::Boolean, + }) } } @@ -974,10 +982,12 @@ impl NullableInterval { /// ``` pub fn single_value(&self) -> Option { match self { - Self::Null { datatype } => Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null)), + Self::Null { datatype } => { + Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null)) + } Self::MaybeNull { values } | Self::NotNull { values } - if values.lower.value == values.upper.value && - !values.lower.is_unbounded() => + if values.lower.value == values.upper.value + && !values.lower.is_unbounded() => { Some(values.lower.value.clone()) } From f4e868052d065b1e128ac0cac3c914cb5ff96b2e Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 12 Sep 2023 15:11:26 -0700 Subject: [PATCH 13/15] pr feedback --- .../simplify_expressions/expr_simplifier.rs | 26 +++++----- .../src/simplify_expressions/guarantees.rs | 52 ++++++++++--------- .../src/intervals/interval_aritmetic.rs | 27 +++++----- 3 files changed, 53 insertions(+), 52 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 584b37adfb6cf..4e1bc5fe27fbd 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -50,7 +50,8 @@ use crate::simplify_expressions::guarantees::GuaranteeRewriter; /// This structure handles API for expression simplification pub struct ExprSimplifier { info: S, - /// + /// Guarantees about the values of columns. This is provided by the user + /// in [ExprSimplifier::with_guarantees()]. guarantees: Vec<(Expr, NullableInterval)>, } @@ -207,7 +208,7 @@ impl ExprSimplifier { /// } /// ), /// // y = 3 - /// (col("y"), NullableInterval::from(&ScalarValue::UInt32(Some(3)))), + /// (col("y"), NullableInterval::from(ScalarValue::UInt32(Some(3)))), /// ]; /// let simplifier = ExprSimplifier::new(context).with_guarantees(guarantees); /// let output = simplifier.simplify(expr).unwrap(); @@ -2753,7 +2754,7 @@ mod tests { ) -> Expr { let schema = expr_test_schema(); let execution_props = ExecutionProps::new(); - let mut simplifier = ExprSimplifier::new( + let simplifier = ExprSimplifier::new( SimplifyContext::new(&execution_props).with_schema(schema), ) .with_guarantees(guarantees); @@ -3234,12 +3235,9 @@ mod tests { // All guaranteed null let guarantees = vec![ - (col("c3"), NullableInterval::from(&ScalarValue::Int64(None))), - ( - col("c4"), - NullableInterval::from(&ScalarValue::UInt32(None)), - ), - (col("c1"), NullableInterval::from(&ScalarValue::Utf8(None))), + (col("c3"), NullableInterval::from(ScalarValue::Int64(None))), + (col("c4"), NullableInterval::from(ScalarValue::UInt32(None))), + (col("c1"), NullableInterval::from(ScalarValue::Utf8(None))), ]; let output = simplify_with_guarantee(expr.clone(), guarantees); @@ -3255,11 +3253,11 @@ mod tests { ), ( col("c4"), - NullableInterval::from(&ScalarValue::UInt32(Some(9))), + NullableInterval::from(ScalarValue::UInt32(Some(9))), ), ( col("c1"), - NullableInterval::from(&ScalarValue::Utf8(Some("a".to_string()))), + NullableInterval::from(ScalarValue::Utf8(Some("a".to_string()))), ), ]; let output = simplify_with_guarantee(expr.clone(), guarantees); @@ -3293,11 +3291,11 @@ mod tests { let guarantees = vec![ ( col("c3"), - NullableInterval::from(&ScalarValue::Int64(Some(9))), + NullableInterval::from(ScalarValue::Int64(Some(9))), ), ( col("c4"), - NullableInterval::from(&ScalarValue::UInt32(Some(3))), + NullableInterval::from(ScalarValue::UInt32(Some(3))), ), ]; let output = simplify_with_guarantee(expr.clone(), guarantees); @@ -3306,7 +3304,7 @@ mod tests { // Only partially simplify let guarantees = vec![( col("c4"), - NullableInterval::from(&ScalarValue::UInt32(Some(3))), + NullableInterval::from(ScalarValue::UInt32(Some(3))), )]; let output = simplify_with_guarantee(expr.clone(), guarantees); assert_eq!(&output, &expr_x); diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 7ad305b0a3205..fbddca2a60966 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -17,12 +17,22 @@ //! Simplifier implementation for [ExprSimplifier::simplify_with_guarantees()][crate::simplify_expressions::expr_simplifier::ExprSimplifier::simplify_with_guarantees]. use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; -use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr, Operator}; +use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; use std::collections::HashMap; use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInterval}; /// Rewrite expressions to incorporate guarantees. +/// +/// Guarantees are a mapping from an expression (which currently is always a +/// column reference) to a [NullableInterval]. The interval represents the known +/// possible values of the column. Using these known values, expressions are +/// rewritten so they can be simplified using [ConstEvaluator] and [Simplifier]. +/// +/// For example, if we know that a column is not null and has values in the +/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`. +/// +/// See a full example in [ExprSimplifier::with_guarantees()]. pub(crate) struct GuaranteeRewriter<'a> { guarantees: HashMap<&'a Expr, &'a NullableInterval>, } @@ -89,17 +99,9 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - // Check if this is a comparison - match op { - Operator::Eq - | Operator::NotEq - | Operator::Lt - | Operator::LtEq - | Operator::Gt - | Operator::GtEq - | Operator::IsDistinctFrom - | Operator::IsNotDistinctFrom => {} - _ => return Ok(expr), + // We only support comparisons for now + if !op.is_comparison_operator() { + return Ok(expr); }; // Check if this is a comparison between a column and literal @@ -117,7 +119,8 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { }; if let Some(col_interval) = self.guarantees.get(col.as_ref()) { - let result = col_interval.apply_operator(&op, &value.into())?; + let result = + col_interval.apply_operator(&op, &value.clone().into())?; if result.is_certainly_true() { Ok(lit(true)) } else if result.is_certainly_false() { @@ -154,16 +157,14 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { .iter() .filter_map(|expr| { if let Expr::Literal(item) = expr { - match interval.contains(&NullableInterval::from(item)) { + match interval + .contains(&NullableInterval::from(item.clone())) + { // If we know for certain the value isn't in the column's interval, // we can skip checking it. - Ok(NullableInterval::NotNull { values }) - if values == Interval::CERTAINLY_FALSE => - { - None - } - Err(err) => Some(Err(err)), - _ => Some(Ok(expr.clone())), + Ok(interval) if interval.is_certainly_false() => None, + Ok(_) => Some(Ok(expr.clone())), + Err(e) => Some(Err(e)), } } else { Some(Ok(expr.clone())) @@ -192,7 +193,7 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::{tree_node::TreeNode, ScalarValue}; - use datafusion_expr::{col, lit}; + use datafusion_expr::{col, lit, Operator}; #[test] fn test_null_handling() { @@ -270,8 +271,10 @@ mod tests { (col("x").eq(lit(0)), false), (col("x").not_eq(lit(0)), true), (col("x").between(lit(2), lit(5)), true), + (col("x").between(lit(2), lit(3)), true), (col("x").between(lit(5), lit(10)), false), (col("x").not_between(lit(2), lit(5)), false), + (col("x").not_between(lit(2), lit(3)), false), (col("x").not_between(lit(5), lit(10)), true), ( Expr::BinaryExpr(BinaryExpr { @@ -451,9 +454,8 @@ mod tests { ScalarValue::Decimal128(Some(1000), 19, 2), ]; - for scalar in &scalars { - let guarantees = vec![(col("x"), NullableInterval::from(scalar))]; - dbg!(&guarantees); + for scalar in scalars { + let guarantees = vec![(col("x"), NullableInterval::from(scalar.clone()))]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); let output = col("x").rewrite(&mut rewriter).unwrap(); diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs index 81ece2d94068e..f94706cafd9f5 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -727,6 +727,7 @@ fn calculate_cardinality_based_on_bounds( /// # Examples /// /// ``` +/// use arrow::datatypes::DataType; /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; /// use datafusion_common::ScalarValue; /// @@ -741,10 +742,10 @@ fn calculate_cardinality_based_on_bounds( /// }; /// /// // {NULL} -/// NullableInterval::Null; +/// NullableInterval::Null { datatype: DataType::Int32 }; /// /// // {4} -/// NullableInterval::from(&ScalarValue::Int32(Some(4))); +/// NullableInterval::from(ScalarValue::Int32(Some(4))); /// ``` #[derive(Debug, Clone, PartialEq, Eq)] pub enum NullableInterval { @@ -780,9 +781,9 @@ impl Display for NullableInterval { } } -impl From<&ScalarValue> for NullableInterval { +impl From for NullableInterval { /// Create an interval that represents a single value. - fn from(value: &ScalarValue) -> Self { + fn from(value: ScalarValue) -> Self { if value.is_null() { Self::Null { datatype: value.get_datatype(), @@ -791,7 +792,7 @@ impl From<&ScalarValue> for NullableInterval { Self::NotNull { values: Interval::new( IntervalBound::new(value.clone(), false), - IntervalBound::new(value.clone(), false), + IntervalBound::new(value, false), ), } } @@ -859,18 +860,18 @@ impl NullableInterval { /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; /// /// // 4 > 3 -> true - /// let lhs = NullableInterval::from(&ScalarValue::Int32(Some(4))); - /// let rhs = NullableInterval::from(&ScalarValue::Int32(Some(3))); + /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4))); + /// let rhs = NullableInterval::from(ScalarValue::Int32(Some(3))); /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// assert_eq!(result, NullableInterval::from(&ScalarValue::Boolean(Some(true)))); + /// assert_eq!(result, NullableInterval::from(ScalarValue::Boolean(Some(true)))); /// /// // [1, 3) > NULL -> NULL /// let lhs = NullableInterval::NotNull { /// values: Interval::make(Some(1), Some(3), (false, true)), /// }; - /// let rhs = NullableInterval::from(&ScalarValue::Int32(None)); + /// let rhs = NullableInterval::from(ScalarValue::Int32(None)); /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap(); - /// assert_eq!(result.single_value(), Some(ScalarValue::Null)); + /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None))); /// /// // [1, 3] > [2, 4] -> [false, true] /// let lhs = NullableInterval::NotNull { @@ -969,11 +970,11 @@ impl NullableInterval { /// use datafusion_common::ScalarValue; /// use datafusion_physical_expr::intervals::{Interval, NullableInterval}; /// - /// let interval = NullableInterval::from(&ScalarValue::Int32(Some(4))); + /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4))); /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); /// - /// let interval = NullableInterval::from(&ScalarValue::Int32(None)); - /// assert_eq!(interval.single_value(), Some(ScalarValue::Null)); + /// let interval = NullableInterval::from(ScalarValue::Int32(None)); + /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None))); /// /// let interval = NullableInterval::MaybeNull { /// values: Interval::make(Some(1), Some(4), (false, true)), From b50df8064c4fa7bc79336a0f48e5dece391ea417 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 13 Sep 2023 12:25:32 -0400 Subject: [PATCH 14/15] Fix clippy --- datafusion/physical-expr/src/intervals/interval_aritmetic.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs index 084eec9d9d9dc..5501c8cae090b 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -767,7 +767,7 @@ impl From for NullableInterval { fn from(value: ScalarValue) -> Self { if value.is_null() { Self::Null { - datatype: value.get_datatype(), + datatype: value.data_type(), } } else { Self::NotNull { From 24529579542970d46afc0d6531f8a8ac7598365a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 13 Sep 2023 12:42:22 -0400 Subject: [PATCH 15/15] fix doc links --- .../optimizer/src/simplify_expressions/guarantees.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index fbddca2a60966..5504d7d76e359 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! Simplifier implementation for [ExprSimplifier::simplify_with_guarantees()][crate::simplify_expressions::expr_simplifier::ExprSimplifier::simplify_with_guarantees]. +//! Simplifier implementation for [`ExprSimplifier::with_guarantees()`] +//! +//! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; use std::collections::HashMap; @@ -27,12 +29,14 @@ use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInter /// Guarantees are a mapping from an expression (which currently is always a /// column reference) to a [NullableInterval]. The interval represents the known /// possible values of the column. Using these known values, expressions are -/// rewritten so they can be simplified using [ConstEvaluator] and [Simplifier]. +/// rewritten so they can be simplified using `ConstEvaluator` and `Simplifier`. /// /// For example, if we know that a column is not null and has values in the /// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`. /// -/// See a full example in [ExprSimplifier::with_guarantees()]. +/// See a full example in [`ExprSimplifier::with_guarantees()`]. +/// +/// [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees pub(crate) struct GuaranteeRewriter<'a> { guarantees: HashMap<&'a Expr, &'a NullableInterval>, }