diff --git a/rust/arrow/benches/comparison_kernels.rs b/rust/arrow/benches/comparison_kernels.rs index a1e43d2f314..216023879f6 100644 --- a/rust/arrow/benches/comparison_kernels.rs +++ b/rust/arrow/benches/comparison_kernels.rs @@ -23,6 +23,7 @@ extern crate arrow; use arrow::array::*; use arrow::compute::*; +use arrow::datatypes::ArrowNumericType; fn create_array(size: usize) -> Float32Array { let mut builder = Float32Builder::new(size); @@ -36,91 +37,300 @@ fn create_array(size: usize) -> Float32Array { builder.finish() } -pub fn eq_no_simd(size: usize) { - let arr_a = create_array(size); - let arr_b = create_array(size); - criterion::black_box(no_simd_compare_op(&arr_a, &arr_b, |a, b| a == b).unwrap()); +pub fn eq_no_simd(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) +where + T: ArrowNumericType, +{ + no_simd_compare_op( + criterion::black_box(arr_a), + criterion::black_box(arr_b), + |a, b| a == b, + ) + .unwrap(); } -pub fn neq_no_simd(size: usize) { - let arr_a = create_array(size); - let arr_b = create_array(size); - criterion::black_box(no_simd_compare_op(&arr_a, &arr_b, |a, b| a != b).unwrap()); +pub fn eq_no_simd_scalar(arr_a: &PrimitiveArray, value_b: T::Native) +where + T: ArrowNumericType, +{ + no_simd_compare_op_scalar( + criterion::black_box(arr_a), + criterion::black_box(value_b), + |a, b| a == b, + ) + .unwrap(); } -pub fn lt_no_simd(size: usize) { - let arr_a = create_array(size); - let arr_b = create_array(size); - criterion::black_box(no_simd_compare_op(&arr_a, &arr_b, |a, b| a < b).unwrap()); +pub fn neq_no_simd(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) +where + T: ArrowNumericType, +{ + no_simd_compare_op( + criterion::black_box(arr_a), + criterion::black_box(arr_b), + |a, b| a != b, + ) + .unwrap(); } -fn lt_eq_no_simd(size: usize) { - let arr_a = create_array(size); - let arr_b = create_array(size); - criterion::black_box(no_simd_compare_op(&arr_a, &arr_b, |a, b| a <= b).unwrap()); +pub fn neq_no_simd_scalar(arr_a: &PrimitiveArray, value_b: T::Native) +where + T: ArrowNumericType, +{ + no_simd_compare_op_scalar( + criterion::black_box(arr_a), + criterion::black_box(value_b), + |a, b| a != b, + ) + .unwrap(); } -pub fn gt_no_simd(size: usize) { - let arr_a = create_array(size); - let arr_b = create_array(size); - criterion::black_box(no_simd_compare_op(&arr_a, &arr_b, |a, b| a > b).unwrap()); +pub fn lt_no_simd(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) +where + T: ArrowNumericType, +{ + no_simd_compare_op( + criterion::black_box(arr_a), + criterion::black_box(arr_b), + |a, b| a < b, + ) + .unwrap(); } -fn gt_eq_no_simd(size: usize) { - let arr_a = create_array(size); - let arr_b = create_array(size); - criterion::black_box(no_simd_compare_op(&arr_a, &arr_b, |a, b| a >= b).unwrap()); +pub fn lt_no_simd_scalar(arr_a: &PrimitiveArray, value_b: T::Native) +where + T: ArrowNumericType, +{ + no_simd_compare_op_scalar( + criterion::black_box(arr_a), + criterion::black_box(value_b), + |a, b| a < b, + ) + .unwrap(); } -fn eq_simd(size: usize) { - let arr_a = create_array(size); - let arr_b = create_array(size); - criterion::black_box(eq(&arr_a, &arr_b).unwrap()); +fn lt_eq_no_simd(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) +where + T: ArrowNumericType, +{ + no_simd_compare_op( + criterion::black_box(arr_a), + criterion::black_box(arr_b), + |a, b| a <= b, + ) + .unwrap(); } -fn neq_simd(size: usize) { - let arr_a = create_array(size); - let arr_b = create_array(size); - criterion::black_box(neq(&arr_a, &arr_b).unwrap()); +fn lt_eq_no_simd_scalar(arr_a: &PrimitiveArray, value_b: T::Native) +where + T: ArrowNumericType, +{ + no_simd_compare_op_scalar( + criterion::black_box(arr_a), + criterion::black_box(value_b), + |a, b| a <= b, + ) + .unwrap(); } -fn lt_simd(size: usize) { - let arr_a = create_array(size); - let arr_b = create_array(size); - criterion::black_box(lt(&arr_a, &arr_b).unwrap()); +pub fn gt_no_simd(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) +where + T: ArrowNumericType, +{ + no_simd_compare_op( + criterion::black_box(arr_a), + criterion::black_box(arr_b), + |a, b| a > b, + ) + .unwrap(); } -fn lt_eq_simd(size: usize) { - let arr_a = create_array(size); - let arr_b = create_array(size); - criterion::black_box(lt_eq(&arr_a, &arr_b).unwrap()); +pub fn gt_no_simd_scalar(arr_a: &PrimitiveArray, value_b: T::Native) +where + T: ArrowNumericType, +{ + no_simd_compare_op_scalar( + criterion::black_box(arr_a), + criterion::black_box(value_b), + |a, b| a > b, + ) + .unwrap(); } -fn gt_simd(size: usize) { - let arr_a = create_array(size); - let arr_b = create_array(size); - criterion::black_box(gt(&arr_a, &arr_b).unwrap()); +fn gt_eq_no_simd(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) +where + T: ArrowNumericType, +{ + no_simd_compare_op( + criterion::black_box(arr_a), + criterion::black_box(arr_b), + |a, b| a >= b, + ) + .unwrap(); } -fn gt_eq_simd(size: usize) { - let arr_a = create_array(size); - let arr_b = create_array(size); - criterion::black_box(gt_eq(&arr_a, &arr_b).unwrap()); +fn gt_eq_no_simd_scalar(arr_a: &PrimitiveArray, value_b: T::Native) +where + T: ArrowNumericType, +{ + no_simd_compare_op_scalar( + criterion::black_box(arr_a), + criterion::black_box(value_b), + |a, b| a >= b, + ) + .unwrap(); +} + +fn eq_simd(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) +where + T: ArrowNumericType, +{ + eq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); +} + +fn eq_simd_scalar(arr_a: &PrimitiveArray, value_b: T::Native) +where + T: ArrowNumericType, +{ + eq_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap(); +} + +fn neq_simd(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) +where + T: ArrowNumericType, +{ + neq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); +} + +fn neq_simd_scalar(arr_a: &PrimitiveArray, value_b: T::Native) +where + T: ArrowNumericType, +{ + neq_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap(); +} + +fn lt_simd(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) +where + T: ArrowNumericType, +{ + lt(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); +} + +fn lt_simd_scalar(arr_a: &PrimitiveArray, value_b: T::Native) +where + T: ArrowNumericType, +{ + lt_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap(); +} + +fn lt_eq_simd(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) +where + T: ArrowNumericType, +{ + lt_eq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); +} + +fn lt_eq_simd_scalar(arr_a: &PrimitiveArray, value_b: T::Native) +where + T: ArrowNumericType, +{ + lt_eq_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap(); +} + +fn gt_simd(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) +where + T: ArrowNumericType, +{ + gt(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); +} + +fn gt_simd_scalar(arr_a: &PrimitiveArray, value_b: T::Native) +where + T: ArrowNumericType, +{ + gt_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap(); +} + +fn gt_eq_simd(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) +where + T: ArrowNumericType, +{ + gt_eq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); +} + +fn gt_eq_simd_scalar(arr_a: &PrimitiveArray, value_b: T::Native) +where + T: ArrowNumericType, +{ + gt_eq_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap(); } fn add_benchmark(c: &mut Criterion) { - c.bench_function("eq 512", |b| b.iter(|| eq_no_simd(512))); - c.bench_function("eq 512 simd", |b| b.iter(|| eq_simd(512))); - c.bench_function("neq 512", |b| b.iter(|| neq_no_simd(512))); - c.bench_function("neq 512 simd", |b| b.iter(|| neq_simd(512))); - c.bench_function("lt 512", |b| b.iter(|| lt_no_simd(512))); - c.bench_function("lt 512 simd", |b| b.iter(|| lt_simd(512))); - c.bench_function("lt_eq 512", |b| b.iter(|| lt_eq_no_simd(512))); - c.bench_function("lt_eq 512 simd", |b| b.iter(|| lt_eq_simd(512))); - c.bench_function("gt 512", |b| b.iter(|| gt_no_simd(512))); - c.bench_function("gt 512 simd", |b| b.iter(|| gt_simd(512))); - c.bench_function("gt_eq 512", |b| b.iter(|| gt_eq_no_simd(512))); - c.bench_function("gt_eq 512 simd", |b| b.iter(|| gt_eq_simd(512))); + let size = 65536; + let arr_a = create_array(size); + let arr_b = create_array(size); + + c.bench_function("eq Float32", |b| b.iter(|| eq_no_simd(&arr_a, &arr_b))); + c.bench_function("eq scalar Float32", |b| { + b.iter(|| eq_no_simd_scalar(&arr_a, 1.0)) + }); + c.bench_function("eq Float32 simd", |b| b.iter(|| eq_simd(&arr_a, &arr_b))); + c.bench_function("eq scalar Float32 simd", |b| { + b.iter(|| eq_simd_scalar(&arr_a, 1.0)) + }); + + c.bench_function("neq Float32", |b| b.iter(|| neq_no_simd(&arr_a, &arr_b))); + c.bench_function("neq scalar Float32", |b| { + b.iter(|| neq_no_simd_scalar(&arr_a, 1.0)) + }); + c.bench_function("neq Float32 simd", |b| b.iter(|| neq_simd(&arr_a, &arr_b))); + c.bench_function("neq scalar Float32 simd", |b| { + b.iter(|| neq_simd_scalar(&arr_a, 1.0)) + }); + + c.bench_function("lt Float32", |b| b.iter(|| lt_no_simd(&arr_a, &arr_b))); + c.bench_function("lt scalar Float32", |b| { + b.iter(|| lt_no_simd_scalar(&arr_a, 1.0)) + }); + c.bench_function("lt Float32 simd", |b| b.iter(|| lt_simd(&arr_a, &arr_b))); + c.bench_function("lt scalar Float32 simd", |b| { + b.iter(|| lt_simd_scalar(&arr_a, 1.0)) + }); + + c.bench_function("lt_eq Float32", |b| { + b.iter(|| lt_eq_no_simd(&arr_a, &arr_b)) + }); + c.bench_function("lt_eq scalar Float32", |b| { + b.iter(|| lt_eq_no_simd_scalar(&arr_a, 1.0)) + }); + c.bench_function("lt_eq Float32 simd", |b| { + b.iter(|| lt_eq_simd(&arr_a, &arr_b)) + }); + c.bench_function("lt_eq scalar Float32 simd", |b| { + b.iter(|| lt_eq_simd_scalar(&arr_a, 1.0)) + }); + + c.bench_function("gt Float32", |b| b.iter(|| gt_no_simd(&arr_a, &arr_b))); + c.bench_function("gt scalar Float32", |b| { + b.iter(|| gt_no_simd_scalar(&arr_a, 1.0)) + }); + c.bench_function("gt Float32 simd", |b| b.iter(|| gt_simd(&arr_a, &arr_b))); + c.bench_function("gt scalar Float32 simd", |b| { + b.iter(|| gt_simd_scalar(&arr_a, 1.0)) + }); + + c.bench_function("gt_eq Float32", |b| { + b.iter(|| gt_eq_no_simd(&arr_a, &arr_b)) + }); + c.bench_function("gt_eq scalar Float32", |b| { + b.iter(|| gt_eq_no_simd_scalar(&arr_a, 1.0)) + }); + c.bench_function("gt_eq Float32 simd", |b| { + b.iter(|| gt_eq_simd(&arr_a, &arr_b)) + }); + c.bench_function("gt_eq scalar Float32 simd", |b| { + b.iter(|| gt_eq_simd_scalar(&arr_a, 1.0)) + }); } criterion_group!(benches, add_benchmark); diff --git a/rust/arrow/src/compute/kernels/comparison.rs b/rust/arrow/src/compute/kernels/comparison.rs index 9d184c0ea67..60c8c86d5e9 100644 --- a/rust/arrow/src/compute/kernels/comparison.rs +++ b/rust/arrow/src/compute/kernels/comparison.rs @@ -66,6 +66,27 @@ macro_rules! compare_op { }}; } +macro_rules! compare_op_scalar { + ($left: expr, $right:expr, $op:expr) => {{ + let null_bit_buffer = $left.data().null_buffer().cloned(); + let mut result = BooleanBufferBuilder::new($left.len()); + for i in 0..$left.len() { + result.append($op($left.value(i), $right))?; + } + + let data = ArrayData::new( + DataType::Boolean, + $left.len(), + None, + null_bit_buffer, + $left.offset(), + vec![result.finish()], + vec![], + ); + Ok(PrimitiveArray::::from(Arc::new(data))) + }}; +} + pub fn no_simd_compare_op( left: &PrimitiveArray, right: &PrimitiveArray, @@ -78,6 +99,18 @@ where compare_op!(left, right, op) } +pub fn no_simd_compare_op_scalar( + left: &PrimitiveArray, + right: T::Native, + op: F, +) -> Result +where + T: ArrowNumericType, + F: Fn(T::Native, T::Native) -> bool, +{ + compare_op_scalar!(left, right, op) +} + pub fn like_utf8(left: &StringArray, right: &StringArray) -> Result { let mut map = HashMap::new(); if left.len() != right.len() { @@ -178,26 +211,50 @@ pub fn eq_utf8(left: &StringArray, right: &StringArray) -> Result compare_op!(left, right, |a, b| a == b) } +pub fn eq_utf8_scalar(left: &StringArray, right: &str) -> Result { + compare_op_scalar!(left, right, |a, b| a == b) +} + pub fn neq_utf8(left: &StringArray, right: &StringArray) -> Result { compare_op!(left, right, |a, b| a != b) } +pub fn neq_utf8_scalar(left: &StringArray, right: &str) -> Result { + compare_op_scalar!(left, right, |a, b| a != b) +} + pub fn lt_utf8(left: &StringArray, right: &StringArray) -> Result { compare_op!(left, right, |a, b| a < b) } +pub fn lt_utf8_scalar(left: &StringArray, right: &str) -> Result { + compare_op_scalar!(left, right, |a, b| a < b) +} + pub fn lt_eq_utf8(left: &StringArray, right: &StringArray) -> Result { compare_op!(left, right, |a, b| a <= b) } +pub fn lt_eq_utf8_scalar(left: &StringArray, right: &str) -> Result { + compare_op_scalar!(left, right, |a, b| a <= b) +} + pub fn gt_utf8(left: &StringArray, right: &StringArray) -> Result { compare_op!(left, right, |a, b| a > b) } +pub fn gt_utf8_scalar(left: &StringArray, right: &str) -> Result { + compare_op_scalar!(left, right, |a, b| a > b) +} + pub fn gt_eq_utf8(left: &StringArray, right: &StringArray) -> Result { compare_op!(left, right, |a, b| a >= b) } +pub fn gt_eq_utf8_scalar(left: &StringArray, right: &str) -> Result { + compare_op_scalar!(left, right, |a, b| a >= b) +} + /// Helper function to perform boolean lambda function on values from two arrays using /// SIMD. #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))] @@ -264,6 +321,59 @@ where Ok(PrimitiveArray::::from(Arc::new(data))) } +/// Helper function to perform boolean lambda function on values from an array and a scalar value using +/// SIMD. +#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))] +fn simd_compare_op_scalar( + left: &PrimitiveArray, + right: T::Native, + op: F, +) -> Result +where + T: ArrowNumericType, + F: Fn(T::Simd, T::Simd) -> T::SimdMask, +{ + use crate::buffer::MutableBuffer; + use std::io::Write; + use std::mem; + + let len = left.len(); + let null_bit_buffer = left.data().null_buffer().cloned(); + let lanes = T::lanes(); + let mut result = MutableBuffer::new(left.len() * mem::size_of::()); + let simd_right = T::init(right); + + let rem = len % lanes; + + for i in (0..len - rem).step_by(lanes) { + let simd_left = T::load(left.value_slice(i, lanes)); + let simd_result = op(simd_left, simd_right); + T::bitmask(&simd_result, |b| { + result.write(b).unwrap(); + }); + } + + if rem > 0 { + let simd_left = T::load(left.value_slice(len - rem, lanes)); + let simd_result = op(simd_left, simd_right); + let rem_buffer_size = (rem as f32 / 8f32).ceil() as usize; + T::bitmask(&simd_result, |b| { + result.write(&b[0..rem_buffer_size]).unwrap(); + }); + } + + let data = ArrayData::new( + DataType::Boolean, + left.len(), + None, + null_bit_buffer, + left.offset(), + vec![result.freeze()], + vec![], + ); + Ok(PrimitiveArray::::from(Arc::new(data))) +} + /// Perform `left == right` operation on two arrays. pub fn eq(left: &PrimitiveArray, right: &PrimitiveArray) -> Result where @@ -279,6 +389,21 @@ where compare_op!(left, right, |a, b| a == b) } +/// Perform `left == right` operation on an array and a scalar value. +pub fn eq_scalar(left: &PrimitiveArray, right: T::Native) -> Result +where + T: ArrowNumericType, +{ + #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))] + return simd_compare_op_scalar(left, right, T::eq); + + #[cfg(any( + not(any(target_arch = "x86", target_arch = "x86_64")), + not(feature = "simd") + ))] + compare_op_scalar!(left, right, |a, b| a == b) +} + /// Perform `left != right` operation on two arrays. pub fn neq(left: &PrimitiveArray, right: &PrimitiveArray) -> Result where @@ -294,6 +419,21 @@ where compare_op!(left, right, |a, b| a != b) } +/// Perform `left != right` operation on an array and a scalar value. +pub fn neq_scalar(left: &PrimitiveArray, right: T::Native) -> Result +where + T: ArrowNumericType, +{ + #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))] + return simd_compare_op_scalar(left, right, T::ne); + + #[cfg(any( + not(any(target_arch = "x86", target_arch = "x86_64")), + not(feature = "simd") + ))] + compare_op_scalar!(left, right, |a, b| a != b) +} + /// Perform `left < right` operation on two arrays. Null values are less than non-null /// values. pub fn lt(left: &PrimitiveArray, right: &PrimitiveArray) -> Result @@ -310,6 +450,22 @@ where compare_op!(left, right, |a, b| a < b) } +/// Perform `left < right` operation on an array and a scalar value. +/// Null values are less than non-null values. +pub fn lt_scalar(left: &PrimitiveArray, right: T::Native) -> Result +where + T: ArrowNumericType, +{ + #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))] + return simd_compare_op_scalar(left, right, T::lt); + + #[cfg(any( + not(any(target_arch = "x86", target_arch = "x86_64")), + not(feature = "simd") + ))] + compare_op_scalar!(left, right, |a, b| a < b) +} + /// Perform `left <= right` operation on two arrays. Null values are less than non-null /// values. pub fn lt_eq( @@ -329,6 +485,22 @@ where compare_op!(left, right, |a, b| a <= b) } +/// Perform `left <= right` operation on an array and a scalar value. +/// Null values are less than non-null values. +pub fn lt_eq_scalar(left: &PrimitiveArray, right: T::Native) -> Result +where + T: ArrowNumericType, +{ + #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))] + return simd_compare_op_scalar(left, right, T::le); + + #[cfg(any( + not(any(target_arch = "x86", target_arch = "x86_64")), + not(feature = "simd") + ))] + compare_op_scalar!(left, right, |a, b| a <= b) +} + /// Perform `left > right` operation on two arrays. Non-null values are greater than null /// values. pub fn gt(left: &PrimitiveArray, right: &PrimitiveArray) -> Result @@ -345,6 +517,22 @@ where compare_op!(left, right, |a, b| a > b) } +/// Perform `left > right` operation on an array and a scalar value. +/// Non-null values are greater than null values. +pub fn gt_scalar(left: &PrimitiveArray, right: T::Native) -> Result +where + T: ArrowNumericType, +{ + #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))] + return simd_compare_op_scalar(left, right, T::gt); + + #[cfg(any( + not(any(target_arch = "x86", target_arch = "x86_64")), + not(feature = "simd") + ))] + compare_op_scalar!(left, right, |a, b| a > b) +} + /// Perform `left >= right` operation on two arrays. Non-null values are greater than null /// values. pub fn gt_eq( @@ -364,6 +552,22 @@ where compare_op!(left, right, |a, b| a >= b) } +/// Perform `left >= right` operation on an array and a scalar value. +/// Non-null values are greater than null values. +pub fn gt_eq_scalar(left: &PrimitiveArray, right: T::Native) -> Result +where + T: ArrowNumericType, +{ + #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))] + return simd_compare_op_scalar(left, right, T::ge); + + #[cfg(any( + not(any(target_arch = "x86", target_arch = "x86_64")), + not(feature = "simd") + ))] + compare_op_scalar!(left, right, |a, b| a >= b) +} + #[cfg(test)] mod tests { use super::*; @@ -382,6 +586,17 @@ mod tests { assert_eq!(false, c.value(4)); } + #[test] + fn test_primitive_array_eq_scalar() { + let a = Int32Array::from(vec![6, 7, 8, 9, 10]); + let c = eq_scalar(&a, 8).unwrap(); + assert_eq!(false, c.value(0)); + assert_eq!(false, c.value(1)); + assert_eq!(true, c.value(2)); + assert_eq!(false, c.value(3)); + assert_eq!(false, c.value(4)); + } + #[test] fn test_primitive_array_neq() { let a = Int32Array::from(vec![8, 8, 8, 8, 8]); @@ -394,6 +609,17 @@ mod tests { assert_eq!(true, c.value(4)); } + #[test] + fn test_primitive_array_neq_scalar() { + let a = Int32Array::from(vec![6, 7, 8, 9, 10]); + let c = neq_scalar(&a, 8).unwrap(); + assert_eq!(true, c.value(0)); + assert_eq!(true, c.value(1)); + assert_eq!(false, c.value(2)); + assert_eq!(true, c.value(3)); + assert_eq!(true, c.value(4)); + } + #[test] fn test_primitive_array_lt() { let a = Int32Array::from(vec![8, 8, 8, 8, 8]); @@ -406,6 +632,17 @@ mod tests { assert_eq!(true, c.value(4)); } + #[test] + fn test_primitive_array_lt_scalar() { + let a = Int32Array::from(vec![6, 7, 8, 9, 10]); + let c = lt_scalar(&a, 8).unwrap(); + assert_eq!(true, c.value(0)); + assert_eq!(true, c.value(1)); + assert_eq!(false, c.value(2)); + assert_eq!(false, c.value(3)); + assert_eq!(false, c.value(4)); + } + #[test] fn test_primitive_array_lt_nulls() { let a = Int32Array::from(vec![None, None, Some(1)]); @@ -416,6 +653,15 @@ mod tests { assert_eq!(false, c.value(2)); } + #[test] + fn test_primitive_array_lt_scalar_nulls() { + let a = Int32Array::from(vec![None, Some(1), Some(2)]); + let c = lt_scalar(&a, 2).unwrap(); + assert_eq!(true, c.value(0)); + assert_eq!(true, c.value(1)); + assert_eq!(false, c.value(2)); + } + #[test] fn test_primitive_array_lt_eq() { let a = Int32Array::from(vec![8, 8, 8, 8, 8]); @@ -428,6 +674,17 @@ mod tests { assert_eq!(true, c.value(4)); } + #[test] + fn test_primitive_array_lt_eq_scalar() { + let a = Int32Array::from(vec![6, 7, 8, 9, 10]); + let c = lt_eq_scalar(&a, 8).unwrap(); + assert_eq!(true, c.value(0)); + assert_eq!(true, c.value(1)); + assert_eq!(true, c.value(2)); + assert_eq!(false, c.value(3)); + assert_eq!(false, c.value(4)); + } + #[test] fn test_primitive_array_lt_eq_nulls() { let a = Int32Array::from(vec![None, None, Some(1)]); @@ -438,6 +695,15 @@ mod tests { assert_eq!(false, c.value(2)); } + #[test] + fn test_primitive_array_lt_eq_scalar_nulls() { + let a = Int32Array::from(vec![None, Some(1), Some(2)]); + let c = lt_eq_scalar(&a, 1).unwrap(); + assert_eq!(true, c.value(0)); + assert_eq!(true, c.value(1)); + assert_eq!(false, c.value(2)); + } + #[test] fn test_primitive_array_gt() { let a = Int32Array::from(vec![8, 8, 8, 8, 8]); @@ -450,6 +716,17 @@ mod tests { assert_eq!(false, c.value(4)); } + #[test] + fn test_primitive_array_gt_scalar() { + let a = Int32Array::from(vec![6, 7, 8, 9, 10]); + let c = gt_scalar(&a, 8).unwrap(); + assert_eq!(false, c.value(0)); + assert_eq!(false, c.value(1)); + assert_eq!(false, c.value(2)); + assert_eq!(true, c.value(3)); + assert_eq!(true, c.value(4)); + } + #[test] fn test_primitive_array_gt_nulls() { let a = Int32Array::from(vec![None, None, Some(1)]); @@ -460,6 +737,15 @@ mod tests { assert_eq!(true, c.value(2)); } + #[test] + fn test_primitive_array_gt_scalar_nulls() { + let a = Int32Array::from(vec![None, Some(1), Some(2)]); + let c = gt_scalar(&a, 1).unwrap(); + assert_eq!(false, c.value(0)); + assert_eq!(false, c.value(1)); + assert_eq!(true, c.value(2)); + } + #[test] fn test_primitive_array_gt_eq() { let a = Int32Array::from(vec![8, 8, 8, 8, 8]); @@ -472,6 +758,17 @@ mod tests { assert_eq!(false, c.value(4)); } + #[test] + fn test_primitive_array_gt_eq_scalar() { + let a = Int32Array::from(vec![6, 7, 8, 9, 10]); + let c = gt_eq_scalar(&a, 8).unwrap(); + assert_eq!(false, c.value(0)); + assert_eq!(false, c.value(1)); + assert_eq!(true, c.value(2)); + assert_eq!(true, c.value(3)); + assert_eq!(true, c.value(4)); + } + #[test] fn test_primitive_array_gt_eq_nulls() { let a = Int32Array::from(vec![None, None, Some(1)]); @@ -482,6 +779,15 @@ mod tests { assert_eq!(true, c.value(2)); } + #[test] + fn test_primitive_array_gt_eq_scalar_nulls() { + let a = Int32Array::from(vec![None, Some(1), Some(2)]); + let c = gt_eq_scalar(&a, 1).unwrap(); + assert_eq!(false, c.value(0)); + assert_eq!(true, c.value(1)); + assert_eq!(true, c.value(2)); + } + #[test] fn test_length_of_result_buffer() { // `item_count` is chosen to not be a multiple of the number of SIMD lanes for this @@ -517,6 +823,29 @@ mod tests { }; } + macro_rules! test_utf8_scalar { + ($test_name:ident, $left:expr, $right:expr, $op:expr, $expected:expr) => { + #[test] + fn $test_name() { + let left = StringArray::from($left); + let res = $op(&left, $right).unwrap(); + let expected = $expected; + assert_eq!(expected.len(), res.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!( + v, + expected[i], + "unexpected result when comparing {} at position {} to {} ", + left.value(i), + i, + $right + ); + } + } + }; + } + test_utf8!( test_utf8_array_like, vec!["arrow", "arrow", "arrow", "arrow"], @@ -531,6 +860,7 @@ mod tests { nlike_utf8, vec![false, false, false, true] ); + test_utf8!( test_utf8_array_eq, vec!["arrow", "arrow", "arrow", "arrow"], @@ -538,13 +868,29 @@ mod tests { eq_utf8, vec![true, false, false, false] ); + test_utf8_scalar!( + test_utf8_array_eq_scalar, + vec!["arrow", "parquet", "datafusion", "flight"], + "arrow", + eq_utf8_scalar, + vec![true, false, false, false] + ); + test_utf8!( - test_utf8_array_new, + test_utf8_array_neq, vec!["arrow", "arrow", "arrow", "arrow"], vec!["arrow", "parquet", "datafusion", "flight"], neq_utf8, vec![false, true, true, true] ); + test_utf8_scalar!( + test_utf8_array_neq_scalar, + vec!["arrow", "parquet", "datafusion", "flight"], + "arrow", + neq_utf8_scalar, + vec![false, true, true, true] + ); + test_utf8!( test_utf8_array_lt, vec!["arrow", "datafusion", "flight", "parquet"], @@ -552,6 +898,14 @@ mod tests { lt_utf8, vec![true, true, false, false] ); + test_utf8_scalar!( + test_utf8_array_lt_scalar, + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + lt_utf8_scalar, + vec![true, true, false, false] + ); + test_utf8!( test_utf8_array_lt_eq, vec!["arrow", "datafusion", "flight", "parquet"], @@ -559,6 +913,14 @@ mod tests { lt_eq_utf8, vec![true, true, true, false] ); + test_utf8_scalar!( + test_utf8_array_lt_eq_scalar, + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + lt_eq_utf8_scalar, + vec![true, true, true, false] + ); + test_utf8!( test_utf8_array_gt, vec!["arrow", "datafusion", "flight", "parquet"], @@ -566,6 +928,14 @@ mod tests { gt_utf8, vec![false, false, false, true] ); + test_utf8_scalar!( + test_utf8_array_gt_scalar, + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + gt_utf8_scalar, + vec![false, false, false, true] + ); + test_utf8!( test_utf8_array_gt_eq, vec!["arrow", "datafusion", "flight", "parquet"], @@ -573,4 +943,11 @@ mod tests { gt_eq_utf8, vec![false, false, true, true] ); + test_utf8_scalar!( + test_utf8_array_gt_eq_scalar, + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + gt_eq_utf8_scalar, + vec![false, false, true, true] + ); } diff --git a/rust/arrow/src/datatypes.rs b/rust/arrow/src/datatypes.rs index d3e129c04f8..d639c5289d9 100644 --- a/rust/arrow/src/datatypes.rs +++ b/rust/arrow/src/datatypes.rs @@ -501,7 +501,8 @@ where Self::Simd: Add + Sub + Mul - + Div, + + Div + + Copy, { /// Defines the SIMD type that should be used for this numeric type type Simd;