diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index c31b37b63b219..4ad6952166e1d 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -1019,7 +1019,7 @@ impl ScalarValue { Self::List(scalars, Box::new(Field::new("item", child_type, true))) } - // Create a zero value in the given type. + /// Create a zero value in the given type. pub fn new_zero(datatype: &DataType) -> Result { assert!(datatype.is_primitive()); Ok(match datatype { @@ -1042,6 +1042,24 @@ impl ScalarValue { }) } + /// Create a negative one value in the given type. + pub fn new_negative_one(datatype: &DataType) -> Result { + assert!(datatype.is_primitive()); + Ok(match datatype { + DataType::Int8 | DataType::UInt8 => ScalarValue::Int8(Some(-1)), + DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)), + DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)), + DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)), + DataType::Float32 => ScalarValue::Float32(Some(-1.0)), + DataType::Float64 => ScalarValue::Float64(Some(-1.0)), + _ => { + return Err(DataFusionError::NotImplemented(format!( + "Can't create a negative one scalar from data_type \"{datatype:?}\"" + ))); + } + }) + } + /// Getter for the `DataType` of the value pub fn get_datatype(&self) -> DataType { match self { diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs index 6e74fc0d9be85..f6b944b50448e 100644 --- a/datafusion/core/tests/simplification.rs +++ b/datafusion/core/tests/simplification.rs @@ -49,6 +49,13 @@ impl SimplifyInfo for MyInfo { fn execution_props(&self) -> &ExecutionProps { &self.execution_props } + + fn get_data_type(&self, expr: &Expr) -> Result { + match expr.get_type(&self.schema) { + Ok(expr_data_type) => Ok(expr_data_type), + Err(e) => Err(e), + } + } } impl From for MyInfo { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 8b6c39043b6c9..1530513e96371 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -639,6 +639,31 @@ impl Expr { binary_expr(self, Operator::Or, other) } + /// Return `self & other` + pub fn bitwise_and(self, other: Expr) -> Expr { + binary_expr(self, Operator::BitwiseAnd, other) + } + + /// Return `self | other` + pub fn bitwise_or(self, other: Expr) -> Expr { + binary_expr(self, Operator::BitwiseOr, other) + } + + /// Return `self ^ other` + pub fn bitwise_xor(self, other: Expr) -> Expr { + binary_expr(self, Operator::BitwiseXor, other) + } + + /// Return `self >> other` + pub fn bitwise_shift_right(self, other: Expr) -> Expr { + binary_expr(self, Operator::BitwiseShiftRight, other) + } + + /// Return `self << other` + pub fn bitwise_shift_left(self, other: Expr) -> Expr { + binary_expr(self, Operator::BitwiseShiftLeft, other) + } + /// Return `!self` #[allow(clippy::should_implement_trait)] pub fn not(self) -> Expr { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 325fac57c9589..6465ca80b867a 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -112,6 +112,51 @@ pub fn count(expr: Expr) -> Expr { )) } +/// Return a new expression with bitwise AND +pub fn bitwise_and(left: Expr, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + Operator::BitwiseAnd, + Box::new(right), + )) +} + +/// Return a new expression with bitwise OR +pub fn bitwise_or(left: Expr, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + Operator::BitwiseOr, + Box::new(right), + )) +} + +/// Return a new expression with bitwise XOR +pub fn bitwise_xor(left: Expr, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + Operator::BitwiseXor, + Box::new(right), + )) +} + +/// Return a new expression with bitwise SHIFT RIGHT +pub fn bitwise_shift_right(left: Expr, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + Operator::BitwiseShiftRight, + Box::new(right), + )) +} + +/// Return a new expression with bitwise SHIFT LEFT +pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + Operator::BitwiseShiftLeft, + Box::new(right), + )) +} + /// Create an expression to represent the count(distinct) aggregate function pub fn count_distinct(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( diff --git a/datafusion/optimizer/src/simplify_expressions/context.rs b/datafusion/optimizer/src/simplify_expressions/context.rs index 379a803f47a51..b6e3f07b29b33 100644 --- a/datafusion/optimizer/src/simplify_expressions/context.rs +++ b/datafusion/optimizer/src/simplify_expressions/context.rs @@ -38,6 +38,9 @@ pub trait SimplifyInfo { /// Returns details needed for partial expression evaluation fn execution_props(&self) -> &ExecutionProps; + + /// Returns data type of this expr needed for determining optimized int type of a value + fn get_data_type(&self, expr: &Expr) -> Result; } /// Provides simplification information based on DFSchema and @@ -123,6 +126,21 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> { }) } + /// Returns data type of this expr needed for determining optimized int type of a value + fn get_data_type(&self, expr: &Expr) -> Result { + if self.schemas.len() == 1 { + match expr.get_type(&self.schemas[0]) { + Ok(expr_data_type) => Ok(expr_data_type), + Err(e) => Err(e), + } + } else { + Err(DataFusionError::Internal( + "The expr has more than one schema, could not determine data type" + .to_string(), + )) + } + } + fn execution_props(&self) -> &ExecutionProps { self.props } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5421c43da1236..220d532b03b14 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -70,6 +70,7 @@ impl ExprSimplifier { /// `b > 2` /// /// ``` + /// use arrow::datatypes::DataType; /// use datafusion_expr::{col, lit, Expr}; /// use datafusion_common::Result; /// use datafusion_physical_expr::execution_props::ExecutionProps; @@ -92,6 +93,9 @@ impl ExprSimplifier { /// fn execution_props(&self) -> &ExecutionProps { /// &self.execution_props /// } + /// fn get_data_type(&self, expr: &Expr) -> Result { + /// Ok(DataType::Int32) + /// } /// } /// /// // Create the simplifier @@ -337,7 +341,8 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { /// rewrite the expression simplifying any constant expressions fn mutate(&mut self, expr: Expr) -> Result { use datafusion_expr::Operator::{ - And, Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch, + And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor, + Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch, RegexNotIMatch, RegexNotMatch, }; @@ -700,6 +705,298 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); } + // + // Rules for BitwiseAnd + // + + // A & null -> null + Expr::BinaryExpr(BinaryExpr { + left: _, + op: BitwiseAnd, + right, + }) if is_null(&right) => *right, + + // null & A -> null + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseAnd, + right: _, + }) if is_null(&left) => *left, + + // A & 0 -> 0 (if A not nullable) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseAnd, + right, + }) if !info.nullable(&left)? && is_zero(&right) => *right, + + // 0 & A -> 0 (if A not nullable) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseAnd, + right, + }) if !info.nullable(&right)? && is_zero(&left) => *left, + + // !A & A -> 0 (if A not nullable) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseAnd, + right, + }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { + Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + } + + // A & !A -> 0 (if A not nullable) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseAnd, + right, + }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { + Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + } + + // (..A..) & A --> (..A..) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseAnd, + right, + }) if expr_contains(&left, &right, BitwiseAnd) => *left, + + // A & (..A..) --> (..A..) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseAnd, + right, + }) if expr_contains(&right, &left, BitwiseAnd) => *right, + + // A & (A | B) --> A (if B not null) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseAnd, + right, + }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => { + *left + } + + // (A | B) & A --> A (if B not null) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseAnd, + right, + }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => { + *right + } + + // + // Rules for BitwiseOr + // + + // A | null -> null + Expr::BinaryExpr(BinaryExpr { + left: _, + op: BitwiseOr, + right, + }) if is_null(&right) => *right, + + // null | A -> null + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseOr, + right: _, + }) if is_null(&left) => *left, + + // A | 0 -> A (even if A is null) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseOr, + right, + }) if is_zero(&right) => *left, + + // 0 | A -> A (even if A is null) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseOr, + right, + }) if is_zero(&left) => *right, + + // !A | A -> -1 (if A not nullable) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseOr, + right, + }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { + Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + } + + // A | !A -> -1 (if A not nullable) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseOr, + right, + }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { + Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + } + + // (..A..) | A --> (..A..) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseOr, + right, + }) if expr_contains(&left, &right, BitwiseOr) => *left, + + // A | (..A..) --> (..A..) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseOr, + right, + }) if expr_contains(&right, &left, BitwiseOr) => *right, + + // A | (A & B) --> A (if B not null) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseOr, + right, + }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => { + *left + } + + // (A & B) | A --> A (if B not null) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseOr, + right, + }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => { + *right + } + + // + // Rules for BitwiseXor + // + + // A ^ null -> null + Expr::BinaryExpr(BinaryExpr { + left: _, + op: BitwiseXor, + right, + }) if is_null(&right) => *right, + + // null ^ A -> null + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseXor, + right: _, + }) if is_null(&left) => *left, + + // A ^ 0 -> A (if A not nullable) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseXor, + right, + }) if !info.nullable(&left)? && is_zero(&right) => *left, + + // 0 ^ A -> A (if A not nullable) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseXor, + right, + }) if !info.nullable(&right)? && is_zero(&left) => *right, + + // !A ^ A -> -1 (if A not nullable) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseXor, + right, + }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { + Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + } + + // A ^ !A -> -1 (if A not nullable) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseXor, + right, + }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { + Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + } + + // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseXor, + right, + }) if expr_contains(&left, &right, BitwiseXor) => { + let expr = delete_xor_in_complex_expr(&left, &right, false); + if expr == *right { + Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) + } else { + expr + } + } + + // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseXor, + right, + }) if expr_contains(&right, &left, BitwiseXor) => { + let expr = delete_xor_in_complex_expr(&right, &left, true); + if expr == *left { + Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + } else { + expr + } + } + + // + // Rules for BitwiseShiftRight + // + + // A >> null -> null + Expr::BinaryExpr(BinaryExpr { + left: _, + op: BitwiseShiftRight, + right, + }) if is_null(&right) => *right, + + // null >> A -> null + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseShiftRight, + right: _, + }) if is_null(&left) => *left, + + // A >> 0 -> A (even if A is null) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseShiftRight, + right, + }) if is_zero(&right) => *left, + + // + // Rules for BitwiseShiftRight + // + + // A << null -> null + Expr::BinaryExpr(BinaryExpr { + left: _, + op: BitwiseShiftLeft, + right, + }) if is_null(&right) => *right, + + // null << A -> null + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseShiftLeft, + right: _, + }) if is_null(&left) => *left, + + // A << 0 -> A (even if A is null) + Expr::BinaryExpr(BinaryExpr { + left, + op: BitwiseShiftLeft, + right, + }) if is_zero(&right) => *left, + // // Rules for Not // @@ -1346,6 +1643,522 @@ mod tests { assert_eq!(simplify(expr), expected); } + #[test] + fn test_simplify_bitwise_xor_by_null() { + let null = Expr::Literal(ScalarValue::Null); + // A ^ null --> null + { + let expr = binary_expr(col("c2"), Operator::BitwiseXor, null.clone()); + assert_eq!(simplify(expr), null); + } + // null ^ A --> null + { + let expr = binary_expr(null.clone(), Operator::BitwiseXor, col("c2")); + assert_eq!(simplify(expr), null); + } + } + + #[test] + fn test_simplify_bitwise_shift_right_by_null() { + let null = Expr::Literal(ScalarValue::Null); + // A >> null --> null + { + let expr = binary_expr(col("c2"), Operator::BitwiseShiftRight, null.clone()); + assert_eq!(simplify(expr), null); + } + // null >> A --> null + { + let expr = binary_expr(null.clone(), Operator::BitwiseShiftRight, col("c2")); + assert_eq!(simplify(expr), null); + } + } + + #[test] + fn test_simplify_bitwise_shift_left_by_null() { + let null = Expr::Literal(ScalarValue::Null); + // A << null --> null + { + let expr = binary_expr(col("c2"), Operator::BitwiseShiftLeft, null.clone()); + assert_eq!(simplify(expr), null); + } + // null << A --> null + { + let expr = binary_expr(null.clone(), Operator::BitwiseShiftLeft, col("c2")); + assert_eq!(simplify(expr), null); + } + } + + #[test] + fn test_simplify_bitwise_and_by_zero() { + // A & 0 --> 0 + { + let expr = binary_expr(col("c2_non_null"), Operator::BitwiseAnd, lit(0)); + assert_eq!(simplify(expr), lit(0)); + } + // 0 & A --> 0 + { + let expr = binary_expr(lit(0), Operator::BitwiseAnd, col("c2_non_null")); + assert_eq!(simplify(expr), lit(0)); + } + } + + #[test] + fn test_simplify_bitwise_or_by_zero() { + // A | 0 --> A + { + let expr = binary_expr(col("c2_non_null"), Operator::BitwiseOr, lit(0)); + assert_eq!(simplify(expr), col("c2_non_null")); + } + // 0 | A --> A + { + let expr = binary_expr(lit(0), Operator::BitwiseOr, col("c2_non_null")); + assert_eq!(simplify(expr), col("c2_non_null")); + } + } + + #[test] + fn test_simplify_bitwise_xor_by_zero() { + // A ^ 0 --> A + { + let expr = binary_expr(col("c2_non_null"), Operator::BitwiseXor, lit(0)); + assert_eq!(simplify(expr), col("c2_non_null")); + } + // 0 ^ A --> A + { + let expr = binary_expr(lit(0), Operator::BitwiseXor, col("c2_non_null")); + assert_eq!(simplify(expr), col("c2_non_null")); + } + } + + #[test] + fn test_simplify_bitwise_bitwise_shift_right_by_zero() { + // A >> 0 --> A + { + let expr = + binary_expr(col("c2_non_null"), Operator::BitwiseShiftRight, lit(0)); + assert_eq!(simplify(expr), col("c2_non_null")); + } + } + + #[test] + fn test_simplify_bitwise_bitwise_shift_left_by_zero() { + // A << 0 --> A + { + let expr = + binary_expr(col("c2_non_null"), Operator::BitwiseShiftLeft, lit(0)); + assert_eq!(simplify(expr), col("c2_non_null")); + } + } + + #[test] + fn test_simplify_bitwise_and_by_null() { + let null = Expr::Literal(ScalarValue::Null); + // A & null --> null + { + let expr = binary_expr(col("c2"), Operator::BitwiseAnd, null.clone()); + assert_eq!(simplify(expr), null); + } + // null & A --> null + { + let expr = binary_expr(null.clone(), Operator::BitwiseAnd, col("c2")); + assert_eq!(simplify(expr), null); + } + } + + #[test] + fn test_simplify_composed_bitwise_and() { + // ((c2 > 5) & (c1 < 6)) & (c2 > 5) --> (c2 > 5) & (c1 < 6) + + let expr = binary_expr( + binary_expr( + col("c2").gt(lit(5)), + Operator::BitwiseAnd, + col("c1").lt(lit(6)), + ), + Operator::BitwiseAnd, + col("c2").gt(lit(5)), + ); + let expected = binary_expr( + col("c2").gt(lit(5)), + Operator::BitwiseAnd, + col("c1").lt(lit(6)), + ); + + assert_eq!(simplify(expr), expected); + + // (c2 > 5) & ((c2 > 5) & (c1 < 6)) --> (c2 > 5) & (c1 < 6) + + let expr = binary_expr( + col("c2").gt(lit(5)), + Operator::BitwiseAnd, + binary_expr( + col("c2").gt(lit(5)), + Operator::BitwiseAnd, + col("c1").lt(lit(6)), + ), + ); + let expected = binary_expr( + col("c2").gt(lit(5)), + Operator::BitwiseAnd, + col("c1").lt(lit(6)), + ); + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_composed_bitwise_or() { + // ((c2 > 5) | (c1 < 6)) | (c2 > 5) --> (c2 > 5) | (c1 < 6) + + let expr = binary_expr( + binary_expr( + col("c2").gt(lit(5)), + Operator::BitwiseOr, + col("c1").lt(lit(6)), + ), + Operator::BitwiseOr, + col("c2").gt(lit(5)), + ); + let expected = binary_expr( + col("c2").gt(lit(5)), + Operator::BitwiseOr, + col("c1").lt(lit(6)), + ); + + assert_eq!(simplify(expr), expected); + + // (c2 > 5) | ((c2 > 5) | (c1 < 6)) --> (c2 > 5) | (c1 < 6) + + let expr = binary_expr( + col("c2").gt(lit(5)), + Operator::BitwiseOr, + binary_expr( + col("c2").gt(lit(5)), + Operator::BitwiseOr, + col("c1").lt(lit(6)), + ), + ); + let expected = binary_expr( + col("c2").gt(lit(5)), + Operator::BitwiseOr, + col("c1").lt(lit(6)), + ); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_composed_bitwise_xor() { + // with an even number of the column "c2" + // c2 ^ ((c2 ^ (c2 | c1)) ^ (c1 & c2)) --> (c2 | c1) ^ (c1 & c2) + + let expr = binary_expr( + col("c2"), + Operator::BitwiseXor, + binary_expr( + binary_expr( + col("c2"), + Operator::BitwiseXor, + binary_expr(col("c2"), Operator::BitwiseOr, col("c1")), + ), + Operator::BitwiseXor, + binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")), + ), + ); + + let expected = binary_expr( + binary_expr(col("c2"), Operator::BitwiseOr, col("c1")), + Operator::BitwiseXor, + binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")), + ); + + assert_eq!(simplify(expr), expected); + + // with an odd number of the column "c2" + // c2 ^ (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) --> c2 ^ ((c2 | c1) ^ (c1 & c2)) + + let expr = binary_expr( + col("c2"), + Operator::BitwiseXor, + binary_expr( + binary_expr( + col("c2"), + Operator::BitwiseXor, + binary_expr(col("c2"), Operator::BitwiseOr, col("c1")), + ), + Operator::BitwiseXor, + binary_expr( + binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")), + Operator::BitwiseXor, + col("c2"), + ), + ), + ); + + let expected = binary_expr( + col("c2"), + Operator::BitwiseXor, + binary_expr( + binary_expr(col("c2"), Operator::BitwiseOr, col("c1")), + Operator::BitwiseXor, + binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")), + ), + ); + + assert_eq!(simplify(expr), expected); + + // with an even number of the column "c2" + // ((c2 ^ (c2 | c1)) ^ (c1 & c2)) ^ c2 --> (c2 | c1) ^ (c1 & c2) + + let expr = binary_expr( + binary_expr( + binary_expr( + col("c2"), + Operator::BitwiseXor, + binary_expr(col("c2"), Operator::BitwiseOr, col("c1")), + ), + Operator::BitwiseXor, + binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")), + ), + Operator::BitwiseXor, + col("c2"), + ); + + let expected = binary_expr( + binary_expr(col("c2"), Operator::BitwiseOr, col("c1")), + Operator::BitwiseXor, + binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")), + ); + + assert_eq!(simplify(expr), expected); + + // with an odd number of the column "c2" + // (c2 ^ (c2 | c1)) ^ ((c1 & c2) ^ c2) ^ c2 --> ((c2 | c1) ^ (c1 & c2)) ^ c2 + + let expr = binary_expr( + binary_expr( + binary_expr( + col("c2"), + Operator::BitwiseXor, + binary_expr(col("c2"), Operator::BitwiseOr, col("c1")), + ), + Operator::BitwiseXor, + binary_expr( + binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")), + Operator::BitwiseXor, + col("c2"), + ), + ), + Operator::BitwiseXor, + col("c2"), + ); + + let expected = binary_expr( + binary_expr( + binary_expr(col("c2"), Operator::BitwiseOr, col("c1")), + Operator::BitwiseXor, + binary_expr(col("c1"), Operator::BitwiseAnd, col("c2")), + ), + Operator::BitwiseXor, + col("c2"), + ); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_negated_bitwise_and() { + // !c4 & c4 --> 0 + let expr = binary_expr( + Expr::Negative(Box::new(col("c4_non_null"))), + Operator::BitwiseAnd, + col("c4_non_null"), + ); + let expected = Expr::Literal(ScalarValue::UInt32(Some(0))); + + assert_eq!(simplify(expr), expected); + // c4 & !c4 --> 0 + let expr = binary_expr( + col("c4_non_null"), + Operator::BitwiseAnd, + Expr::Negative(Box::new(col("c4_non_null"))), + ); + let expected = Expr::Literal(ScalarValue::UInt32(Some(0))); + + assert_eq!(simplify(expr), expected); + + // !c3 & c3 --> 0 + let expr = binary_expr( + Expr::Negative(Box::new(col("c3_non_null"))), + Operator::BitwiseAnd, + col("c3_non_null"), + ); + let expected = Expr::Literal(ScalarValue::Int64(Some(0))); + + assert_eq!(simplify(expr), expected); + // c3 & !c3 --> 0 + let expr = binary_expr( + col("c3_non_null"), + Operator::BitwiseAnd, + Expr::Negative(Box::new(col("c3_non_null"))), + ); + let expected = Expr::Literal(ScalarValue::Int64(Some(0))); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_negated_bitwise_or() { + // !c4 | c4 --> -1 + let expr = binary_expr( + Expr::Negative(Box::new(col("c4_non_null"))), + Operator::BitwiseOr, + col("c4_non_null"), + ); + let expected = Expr::Literal(ScalarValue::Int32(Some(-1))); + + assert_eq!(simplify(expr), expected); + + // c4 | !c4 --> -1 + let expr = binary_expr( + col("c4_non_null"), + Operator::BitwiseOr, + Expr::Negative(Box::new(col("c4_non_null"))), + ); + let expected = Expr::Literal(ScalarValue::Int32(Some(-1))); + + assert_eq!(simplify(expr), expected); + + // !c3 | c3 --> -1 + let expr = binary_expr( + Expr::Negative(Box::new(col("c3_non_null"))), + Operator::BitwiseOr, + col("c3_non_null"), + ); + let expected = Expr::Literal(ScalarValue::Int64(Some(-1))); + + assert_eq!(simplify(expr), expected); + + // c3 | !c3 --> -1 + let expr = binary_expr( + col("c3_non_null"), + Operator::BitwiseOr, + Expr::Negative(Box::new(col("c3_non_null"))), + ); + let expected = Expr::Literal(ScalarValue::Int64(Some(-1))); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_negated_bitwise_xor() { + // !c4 ^ c4 --> -1 + let expr = binary_expr( + Expr::Negative(Box::new(col("c4_non_null"))), + Operator::BitwiseXor, + col("c4_non_null"), + ); + let expected = Expr::Literal(ScalarValue::Int32(Some(-1))); + + assert_eq!(simplify(expr), expected); + + // c4 ^ !c4 --> -1 + let expr = binary_expr( + col("c4_non_null"), + Operator::BitwiseXor, + Expr::Negative(Box::new(col("c4_non_null"))), + ); + let expected = Expr::Literal(ScalarValue::Int32(Some(-1))); + + assert_eq!(simplify(expr), expected); + + // !c3 ^ c3 --> -1 + let expr = binary_expr( + Expr::Negative(Box::new(col("c3_non_null"))), + Operator::BitwiseXor, + col("c3_non_null"), + ); + let expected = Expr::Literal(ScalarValue::Int64(Some(-1))); + + assert_eq!(simplify(expr), expected); + + // c3 ^ !c3 --> -1 + let expr = binary_expr( + col("c3_non_null"), + Operator::BitwiseXor, + Expr::Negative(Box::new(col("c3_non_null"))), + ); + let expected = Expr::Literal(ScalarValue::Int64(Some(-1))); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_bitwise_and_or() { + // (c2 < 3) & ((c2 < 3) | c1) -> (c2 < 3) + let expr = binary_expr( + col("c2_non_null").lt(lit(3)), + Operator::BitwiseAnd, + binary_expr( + col("c2_non_null").lt(lit(3)), + Operator::BitwiseOr, + col("c1_non_null"), + ), + ); + let expected = col("c2_non_null").lt(lit(3)); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_bitwise_or_and() { + // (c2 < 3) | ((c2 < 3) & c1) -> (c2 < 3) + let expr = binary_expr( + col("c2_non_null").lt(lit(3)), + Operator::BitwiseOr, + binary_expr( + col("c2_non_null").lt(lit(3)), + Operator::BitwiseAnd, + col("c1_non_null"), + ), + ); + let expected = col("c2_non_null").lt(lit(3)); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_simple_bitwise_and() { + // (c2 > 5) & (c2 > 5) -> (c2 > 5) + let expr = (col("c2").gt(lit(5))).bitwise_and(col("c2").gt(lit(5))); + let expected = col("c2").gt(lit(5)); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_simple_bitwise_or() { + // (c2 > 5) | (c2 > 5) -> (c2 > 5) + let expr = (col("c2").gt(lit(5))).bitwise_or(col("c2").gt(lit(5))); + let expected = col("c2").gt(lit(5)); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_simple_bitwise_xor() { + // c4 ^ c4 -> 0 + let expr = (col("c4")).bitwise_xor(col("c4")); + let expected = Expr::Literal(ScalarValue::UInt32(Some(0))); + + assert_eq!(simplify(expr), expected); + + // c3 ^ c3 -> 0 + let expr = col("c3").bitwise_xor(col("c3")); + let expected = Expr::Literal(ScalarValue::Int64(Some(0))); + + assert_eq!(simplify(expr), expected); + } + #[test] #[should_panic( expected = "called `Result::unwrap()` on an `Err` value: ArrowError(DivideByZero)" @@ -1357,7 +2170,7 @@ mod tests { #[test] fn test_simplify_simple_and() { - // (c > 5) AND (c > 5) + // (c2 > 5) AND (c2 > 5) -> (c2 > 5) let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5))); let expected = col("c2").gt(lit(5)); @@ -1366,7 +2179,7 @@ mod tests { #[test] fn test_simplify_composed_and() { - // ((c > 5) AND (c1 < 6)) AND (c > 5) + // ((c2 > 5) AND (c1 < 6)) AND (c2 > 5) let expr = binary_expr( binary_expr(col("c2").gt(lit(5)), Operator::And, col("c1").lt(lit(6))), Operator::And, @@ -1380,7 +2193,7 @@ mod tests { #[test] fn test_simplify_negated_and() { - // (c > 5) AND !(c > 5) -- > (c > 5) AND (c <= 5) + // (c2 > 5) AND !(c2 > 5) --> (c2 > 5) AND (c2 <= 5) let expr = binary_expr( col("c2").gt(lit(5)), Operator::And, @@ -1760,8 +2573,12 @@ mod tests { vec![ DFField::new(None, "c1", DataType::Utf8, true), DFField::new(None, "c2", DataType::Boolean, true), + DFField::new(None, "c3", DataType::Int64, true), + DFField::new(None, "c4", DataType::UInt32, true), DFField::new(None, "c1_non_null", DataType::Utf8, false), DFField::new(None, "c2_non_null", DataType::Boolean, false), + DFField::new(None, "c3_non_null", DataType::Int64, false), + DFField::new(None, "c4_non_null", DataType::UInt32, false), ], HashMap::new(), ) @@ -1808,7 +2625,7 @@ mod tests { let schema = expr_test_schema(); assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean); - // true = ture -> true + // true = true -> true assert_eq!(simplify(lit(true).eq(lit(true))), lit(true)); // true = false -> false diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 10e5c0e874304..352674c3a68e8 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -77,6 +77,61 @@ pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { } } +/// Deletes all 'needles' or remains one 'needle' that are found in a chain of xor +/// expressions. Such as: A ^ (A ^ (B ^ A)) +pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr { + /// Deletes recursively 'needles' in a chain of xor expressions + fn recursive_delete_xor_in_expr( + expr: &Expr, + needle: &Expr, + xor_counter: &mut i32, + ) -> Expr { + match expr { + Expr::BinaryExpr(BinaryExpr { left, op, right }) + if *op == Operator::BitwiseXor => + { + let left_expr = recursive_delete_xor_in_expr(left, needle, xor_counter); + let right_expr = recursive_delete_xor_in_expr(right, needle, xor_counter); + if left_expr == *needle { + *xor_counter += 1; + return right_expr; + } else if right_expr == *needle { + *xor_counter += 1; + return left_expr; + } + + Expr::BinaryExpr(BinaryExpr::new( + Box::new(left_expr), + *op, + Box::new(right_expr), + )) + } + _ => expr.clone(), + } + } + + let mut xor_counter: i32 = 0; + let result_expr = recursive_delete_xor_in_expr(expr, needle, &mut xor_counter); + if result_expr == *needle { + return needle.clone(); + } else if xor_counter % 2 == 0 { + if is_left { + return Expr::BinaryExpr(BinaryExpr::new( + Box::new(needle.clone()), + Operator::BitwiseXor, + Box::new(result_expr), + )); + } else { + return Expr::BinaryExpr(BinaryExpr::new( + Box::new(result_expr), + Operator::BitwiseXor, + Box::new(needle.clone()), + )); + } + } + result_expr +} + pub fn is_zero(s: &Expr) -> bool { match s { Expr::Literal(ScalarValue::Int8(Some(0))) @@ -154,11 +209,16 @@ pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool { matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref())) } -/// returns true if `not_expr` is !`expr` +/// returns true if `not_expr` is !`expr` (not) pub fn is_not_of(not_expr: &Expr, expr: &Expr) -> bool { matches!(not_expr, Expr::Not(inner) if expr == inner.as_ref()) } +/// returns true if `not_expr` is !`expr` (bitwise not) +pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool { + matches!(not_expr, Expr::Negative(inner) if expr == inner.as_ref()) +} + /// returns the contained boolean value in `expr` as /// `Expr::Literal(ScalarValue::Boolean(v))`. pub fn as_bool_lit(expr: Expr) -> Result> {