diff --git a/rust/datafusion/src/physical_plan/expressions/binary.rs b/rust/datafusion/src/physical_plan/expressions/binary.rs index 0d503508d63..9e048c9d4fd 100644 --- a/rust/datafusion/src/physical_plan/expressions/binary.rs +++ b/rust/datafusion/src/physical_plan/expressions/binary.rs @@ -18,7 +18,9 @@ use std::{any::Any, sync::Arc}; use arrow::array::*; -use arrow::compute::kernels::arithmetic::{add, divide, multiply, subtract}; +use arrow::compute::kernels::arithmetic::{ + add, divide, divide_scalar, multiply, subtract, +}; use arrow::compute::kernels::boolean::{and, or}; use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow::compute::kernels::comparison::{ @@ -162,10 +164,10 @@ macro_rules! compute_op { macro_rules! binary_string_array_op_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - let result = match $LEFT.data_type() { + let result: Result> = match $LEFT.data_type() { DataType::Utf8 => compute_utf8_op_scalar!($LEFT, $RIGHT, $OP, StringArray), other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?}", + "Data type {:?} not supported for scalar operation on string array", other ))), }; @@ -178,7 +180,7 @@ macro_rules! binary_string_array_op { match $LEFT.data_type() { DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray), other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?}", + "Data type {:?} not supported for binary operation on string arrays", other ))), } @@ -202,19 +204,44 @@ macro_rules! binary_primitive_array_op { DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array), DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array), other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?}", + "Data type {:?} not supported for binary operation on primitive arrays", other ))), } }}; } +/// Invoke a compute kernel on an array and a scalar +/// The binary_primitive_array_op_scalar macro only evaluates for primitive +/// types like integers and floats. +macro_rules! binary_primitive_array_op_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + let result: Result> = match $LEFT.data_type() { + DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array), + DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array), + DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array), + DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array), + DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array), + DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array), + DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array), + DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array), + DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array), + DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array), + other => Err(DataFusionError::Internal(format!( + "Data type {:?} not supported for scalar operation on primitive array", + other + ))), + }; + Some(result) + }}; +} + /// The binary_array_op_scalar macro includes types that extend beyond the primitive, /// such as Utf8 strings. #[macro_export] macro_rules! binary_array_op_scalar { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - let result = match $LEFT.data_type() { + let result: Result> = match $LEFT.data_type() { DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array), DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array), DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array), @@ -233,7 +260,7 @@ macro_rules! binary_array_op_scalar { compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array) } other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?}", + "Data type {:?} not supported for scalar operation on dyn array", other ))), }; @@ -268,7 +295,7 @@ macro_rules! binary_array_op { compute_op!($LEFT, $RIGHT, $OP, Date64Array) } other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?}", + "Data type {:?} not supported for binary operation on dyn arrays", other ))), } @@ -424,6 +451,9 @@ impl PhysicalExpr for BinaryExpr { Operator::NotLike => { binary_string_array_op_scalar!(array, scalar.clone(), nlike) } + Operator::Divide => { + binary_primitive_array_op_scalar!(array, scalar.clone(), divide) + } // if scalar operation is not supported - fallback to array implementation _ => None, }