Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 14 additions & 67 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -656,23 +656,6 @@ macro_rules! compute_utf8_op_dyn_scalar {
}};
}

/// Applies a boolean comparison kernel to a boolean array and a scalar value.
///
/// `$LEFT` is downcast to the concrete array type `$DT`, `$RIGHT` is converted
/// with `try_into`, and the kernel named `<$OP>_bool_scalar` is called on the
/// pair. The expansion evaluates to `Ok(Arc::new(..))` wrapping the kernel
/// output; both `?` operators return early from the enclosing function when
/// the conversion or the kernel fails.
///
/// # Panics
///
/// Panics if `$LEFT` cannot be downcast to `$DT`.
macro_rules! compute_bool_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        use std::convert::TryInto;
        let typed_array = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast array");
        // `paste` builds the kernel name by appending `_bool_scalar` to `$OP`
        // (for example, `eq` becomes `eq_bool_scalar`).
        let kernel_output =
            paste::expr! {[<$OP _bool_scalar>]}(&typed_array, $RIGHT.try_into()?)?;
        Ok(Arc::new(kernel_output))
    }};
}

/// Invoke a compute kernel on a boolean data array and a scalar value
macro_rules! compute_bool_op_dyn_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{
Expand Down Expand Up @@ -852,52 +835,6 @@ macro_rules! binary_primitive_array_op_scalar {
}};
}

/// Dispatches the scalar kernel `$OP` on the runtime data type of `$LEFT`.
///
/// Beyond the primitive numeric types this also covers decimals, Utf8 strings,
/// timestamps of every time unit, dates and booleans. The expansion always
/// evaluates to `Some(Result<Arc<dyn Array>>)`; an unsupported data type is
/// reported as a `DataFusionError::Internal` inside that `Some`.
#[macro_export]
macro_rules! binary_array_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let outcome: Result<Arc<dyn Array>> = match $LEFT.data_type() {
            DataType::Decimal(_, _) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray),
            // Signed integers
            DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array),
            DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array),
            DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array),
            DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array),
            // Unsigned integers
            DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array),
            DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array),
            DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array),
            DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array),
            // Floating point
            DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array),
            DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array),
            // Strings
            DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray),
            // Timestamps, one arm per time unit
            DataType::Timestamp(TimeUnit::Nanosecond, _) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampNanosecondArray),
            DataType::Timestamp(TimeUnit::Microsecond, _) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMicrosecondArray),
            DataType::Timestamp(TimeUnit::Millisecond, _) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMillisecondArray),
            DataType::Timestamp(TimeUnit::Second, _) => compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampSecondArray),
            // Dates
            DataType::Date32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array),
            DataType::Date64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array),
            // Booleans use their own kernel family
            DataType::Boolean => compute_bool_op_scalar!($LEFT, $RIGHT, $OP, BooleanArray),
            unsupported => Err(DataFusionError::Internal(format!(
                "Data type {:?} not supported for scalar operation '{}' on dyn array",
                unsupported,
                stringify!($OP)
            ))),
        };
        Some(outcome)
    }};
}

/// The binary_array_op macro includes types that extend beyond the primitive,
/// such as Utf8 strings.
#[macro_export]
Expand Down Expand Up @@ -1134,6 +1071,20 @@ macro_rules! binary_array_op_dyn_scalar {
}}
}

/// Compares the array with the scalar value for equality, sometimes
/// used in other kernels
pub(crate) fn array_eq_scalar(lhs: &dyn Array, rhs: &ScalarValue) -> Result<ArrayRef> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a function that wraps binary_array_op_dyn_scalar! so that nullif calls the function rather than invoking the macro directly, which cleans that code up.

binary_array_op_dyn_scalar!(lhs, rhs.clone(), eq, &DataType::Boolean).ok_or_else(
|| {
DataFusionError::Internal(format!(
"Data type {:?} and scalar {:?} not supported for array_eq_scalar",
lhs.data_type(),
rhs.get_datatype()
))
},
)?
}

impl BinaryExpr {
/// Evaluate the expression if the left input is an array and
/// right is literal - use scalar operations
Expand Down Expand Up @@ -1366,10 +1317,6 @@ fn is_not_distinct_from_null(
make_boolean_array(length, true)
}

/// Equality kernel for two `NullArray`s.
///
/// Produces a `BooleanArray` with one null entry per element of `left`.
/// `_right` is accepted only to match the two-array kernel signature; it is
/// never read. This function always returns `Ok`.
pub fn eq_null(left: &NullArray, _right: &NullArray) -> Result<BooleanArray> {
    // A range is already an iterator, so no `.into_iter()` conversion is needed.
    Ok((0..left.len()).map(|_| None).collect())
}

/// Builds a `BooleanArray` of `length` entries, each set to `Some(value)`.
///
/// This function always returns `Ok`.
fn make_boolean_array(length: usize, value: bool) -> Result<BooleanArray> {
    // A range is already an iterator, so no `.into_iter()` conversion is needed.
    Ok((0..length).map(|_| Some(value)).collect())
}
Expand Down
18 changes: 8 additions & 10 deletions datafusion/physical-expr/src/expressions/nullif.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@

use std::sync::Arc;

use crate::expressions::binary::{eq_decimal, eq_decimal_scalar, eq_null};
use arrow::array::Array;
use arrow::array::*;
use arrow::compute::eq_dyn;
use arrow::compute::kernels::boolean::nullif;
use arrow::compute::kernels::comparison::{
eq, eq_bool, eq_bool_scalar, eq_scalar, eq_utf8, eq_utf8_scalar,
};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::ScalarValue;
use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;

use super::binary::array_eq_scalar;

/// Invoke a compute kernel on a primitive array and a Boolean Array
macro_rules! compute_bool_array_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
Expand Down Expand Up @@ -82,18 +80,18 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {

match (lhs, rhs) {
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
let cond_array = binary_array_op_scalar!(lhs, rhs.clone(), eq).unwrap()?;
let cond_array = array_eq_scalar(lhs, rhs)?;

let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?;

Ok(ColumnarValue::Array(array))
}
(ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => {
// Get args0 == args1 evaluated and produce a boolean array
let cond_array = binary_array_op!(lhs, rhs, eq)?;
let cond_array = eq_dyn(lhs, rhs)?;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hooray for eq_dyn thanks again @jimexist


// Now, invoke nullif on the result
let array = primitive_bool_array_op!(lhs, *cond_array, nullif)?;
let array = primitive_bool_array_op!(lhs, cond_array, nullif)?;
Ok(ColumnarValue::Array(array))
}
_ => Err(DataFusionError::NotImplemented(
Expand All @@ -105,7 +103,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
#[cfg(test)]
mod tests {
use super::*;
use datafusion_common::Result;
use datafusion_common::{Result, ScalarValue};

#[test]
fn nullif_int32() -> Result<()> {
Expand Down