From 18683d67dd4fcdce7fc172cda762cd0e1341f84e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 11 May 2022 09:22:41 -0400 Subject: [PATCH] Remove binary_array_op_dyn_scalar! --- .../physical-expr/src/expressions/binary.rs | 81 ++++--------------- .../physical-expr/src/expressions/nullif.rs | 18 ++--- 2 files changed, 22 insertions(+), 77 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index b3de2461cdd0c..9c6eedad1ac76 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -656,23 +656,6 @@ macro_rules! compute_utf8_op_dyn_scalar { }}; } -/// Invoke a compute kernel on a boolean data array and a scalar value -macro_rules! compute_bool_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - use std::convert::TryInto; - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast array"); - // generate the scalar function name, such as lt_scalar, from the $OP parameter - // (which could have a value of lt) and the suffix _scalar - Ok(Arc::new(paste::expr! {[<$OP _bool_scalar>]}( - &ll, - $RIGHT.try_into()?, - )?)) - }}; -} - /// Invoke a compute kernel on a boolean data array and a scalar value macro_rules! compute_bool_op_dyn_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ @@ -852,52 +835,6 @@ macro_rules! binary_primitive_array_op_scalar { }}; } -/// The binary_array_op_scalar macro includes types that extend beyond the primitive, -/// such as Utf8 strings. -#[macro_export] -macro_rules! binary_array_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - let result: Result> = match $LEFT.data_type() { - DataType::Decimal(_,_) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray), - DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array), - DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array), - DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array), - DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array), - DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array), - DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array), - DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array), - DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array), - DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array), - DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array), - DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray), - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampNanosecondArray) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMicrosecondArray) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMillisecondArray) - } - DataType::Timestamp(TimeUnit::Second, _) => { - compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampSecondArray) - } - DataType::Date32 => { - compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array) - } - DataType::Date64 => { - compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array) - } - DataType::Boolean => compute_bool_op_scalar!($LEFT, $RIGHT, $OP, BooleanArray), - other => Err(DataFusionError::Internal(format!( - "Data type {:?} not supported for scalar operation '{}' on dyn array", - other, stringify!($OP) - ))), - }; - Some(result) - }}; -} - /// The binary_array_op macro includes types that extend beyond the primitive, /// such as Utf8 strings. #[macro_export] @@ -1134,6 +1071,20 @@ macro_rules! binary_array_op_dyn_scalar { }} } +/// Compares the array with the scalar value for equality, sometimes +/// used in other kernels +pub(crate) fn array_eq_scalar(lhs: &dyn Array, rhs: &ScalarValue) -> Result { + binary_array_op_dyn_scalar!(lhs, rhs.clone(), eq, &DataType::Boolean).ok_or_else( + || { + DataFusionError::Internal(format!( + "Data type {:?} and scalar {:?} not supported for array_eq_scalar", + lhs.data_type(), + rhs.get_datatype() + )) + }, + )? +} + impl BinaryExpr { /// Evaluate the expression of the left input is an array and /// right is literal - use scalar operations @@ -1366,10 +1317,6 @@ fn is_not_distinct_from_null( make_boolean_array(length, true) } -pub fn eq_null(left: &NullArray, _right: &NullArray) -> Result { - Ok((0..left.len()).into_iter().map(|_| None).collect()) -} - fn make_boolean_array(length: usize, value: bool) -> Result { Ok((0..length).into_iter().map(|_| Some(value)).collect()) } diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/physical-expr/src/expressions/nullif.rs index 2d1f3654d241d..2ef40272a8ee2 100644 --- a/datafusion/physical-expr/src/expressions/nullif.rs +++ b/datafusion/physical-expr/src/expressions/nullif.rs @@ -17,18 +17,16 @@ use std::sync::Arc; -use crate::expressions::binary::{eq_decimal, eq_decimal_scalar, eq_null}; use arrow::array::Array; use arrow::array::*; +use arrow::compute::eq_dyn; use arrow::compute::kernels::boolean::nullif; -use arrow::compute::kernels::comparison::{ - eq, eq_bool, eq_bool_scalar, eq_scalar, eq_utf8, eq_utf8_scalar, -}; -use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::ScalarValue; +use arrow::datatypes::DataType; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; +use super::binary::array_eq_scalar; + /// Invoke a compute kernel on a primitive array and a Boolean Array macro_rules! compute_bool_array_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -82,7 +80,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result { match (lhs, rhs) { (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { - let cond_array = binary_array_op_scalar!(lhs, rhs.clone(), eq).unwrap()?; + let cond_array = array_eq_scalar(lhs, rhs)?; let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?; @@ -90,10 +88,10 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result { } (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => { // Get args0 == args1 evaluated and produce a boolean array - let cond_array = binary_array_op!(lhs, rhs, eq)?; + let cond_array = eq_dyn(lhs, rhs)?; // Now, invoke nullif on the result - let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?; + let array = primitive_bool_array_op!(lhs, cond_array, nullif)?; Ok(ColumnarValue::Array(array)) } _ => Err(DataFusionError::NotImplemented( @@ -105,7 +103,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result { #[cfg(test)] mod tests { use super::*; - use datafusion_common::Result; + use datafusion_common::{Result, ScalarValue}; #[test] fn nullif_int32() -> Result<()> {