diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index f9deada5389b..c57e27095c23 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -1624,7 +1624,7 @@ where mod tests { use super::*; use crate::array::Int32Array; - use crate::compute::{try_unary_mut, unary_mut}; + use crate::compute::{binary_mut, try_binary_mut, try_unary_mut, unary_mut}; use crate::datatypes::{Date64Type, Int32Type, Int8Type}; use arrow_buffer::i256; use chrono::NaiveDate; @@ -3100,6 +3100,35 @@ mod tests { assert_eq!(result.null_count(), 13); } + #[test] + fn test_primitive_array_add_mut_by_binary_mut() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + + let c = binary_mut(a, &b, |a, b| a.add_wrapping(b)) + .unwrap() + .unwrap(); + let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]); + assert_eq!(c, expected); + } + + #[test] + fn test_primitive_add_mut_wrapping_overflow_by_try_binary_mut() { + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + let b = Int32Array::from(vec![1, 1]); + + let wrapped = binary_mut(a, &b, |a, b| a.add_wrapping(b)) + .unwrap() + .unwrap(); + let expected = Int32Array::from(vec![-2147483648, -2147483647]); + assert_eq!(expected, wrapped); + + let a = Int32Array::from(vec![i32::MAX, i32::MIN]); + let b = Int32Array::from(vec![1, 1]); + let overflow = try_binary_mut(a, &b, |a, b| a.add_checked(b)); + let _ = overflow.unwrap().expect_err("overflow should be detected"); + } + #[test] fn test_primitive_add_scalar_by_unary_mut() { let a = Int32Array::from(vec![15, 14, 9, 8, 1]); diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 946d15e9e984..d0f18cf5866d 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -232,6 +232,75 @@ where Ok(unsafe { build_primitive_array(len, buffer, null_count, null_buffer) }) } +/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, mutating +/// the mutable [`PrimitiveArray`] `a`. If any index is null in either `a` or `b`, the +/// corresponding index in the result will also be null. +/// +/// Mutable primitive array means that the buffer is not shared with other arrays. +/// As a result, this mutates the buffer directly without allocating new buffer. +/// +/// Like [`unary`] the provided function is evaluated for every index, ignoring validity. This +/// is beneficial when the cost of the operation is low compared to the cost of branching, and +/// especially when the operation can be vectorised, however, requires `op` to be infallible +/// for all possible values of its inputs +/// +/// # Error +/// +/// This function gives error if the arrays have different lengths. +/// This function gives error of original [`PrimitiveArray`] `a` if it is not a mutable +/// primitive array. +pub fn binary_mut( + a: PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> std::result::Result< + std::result::Result, ArrowError>, + PrimitiveArray, +> +where + T: ArrowPrimitiveType, + F: Fn(T::Native, T::Native) -> T::Native, +{ + if a.len() != b.len() { + return Ok(Err(ArrowError::ComputeError( + "Cannot perform binary operation on arrays of different length".to_string(), + ))); + } + + if a.is_empty() { + return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty( + &T::DATA_TYPE, + )))); + } + + let len = a.len(); + + let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); + let null_count = null_buffer + .as_ref() + .map(|x| len - x.count_set_bits_offset(0, len)) + .unwrap_or_default(); + + let mut builder = a.into_builder()?; + + builder + .values_slice_mut() + .iter_mut() + .zip(b.values()) + .for_each(|(l, r)| *l = op(*l, *r)); + + let array_builder = builder + .finish() + .data() + .clone() + .into_builder() + .null_bit_buffer(null_buffer) + .null_count(null_count); + + let array_data = unsafe { array_builder.build_unchecked() }; + Ok(Ok(PrimitiveArray::::from(array_data))) +} + /// Applies the provided fallible binary operation across `a` and `b`, returning any error, /// and collecting the results into a [`PrimitiveArray`]. If any index is null in either `a` /// or `b`, the corresponding index in the result will also be null @@ -289,6 +358,83 @@ where } } +/// Applies the provided fallible binary operation across `a` and `b` by mutating the mutable +/// [`PrimitiveArray`] `a` with the results, returning any error. If any index is null in +/// either `a` or `b`, the corresponding index in the result will also be null +/// +/// Like [`try_unary`] the function is only evaluated for non-null indices +/// +/// Mutable primitive array means that the buffer is not shared with other arrays. +/// As a result, this mutates the buffer directly without allocating new buffer. +/// +/// # Error +/// +/// Return an error if the arrays have different lengths or +/// the operation is under erroneous. +/// This function gives error of original [`PrimitiveArray`] `a` if it is not a mutable +/// primitive array. +pub fn try_binary_mut( + a: PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> std::result::Result< + std::result::Result, ArrowError>, + PrimitiveArray, +> +where + T: ArrowPrimitiveType, + F: Fn(T::Native, T::Native) -> Result, +{ + if a.len() != b.len() { + return Ok(Err(ArrowError::ComputeError( + "Cannot perform binary operation on arrays of different length".to_string(), + ))); + } + let len = a.len(); + + if a.is_empty() { + return Ok(Ok(PrimitiveArray::from(ArrayData::new_empty( + &T::DATA_TYPE, + )))); + } + + if a.null_count() == 0 && b.null_count() == 0 { + try_binary_no_nulls_mut(len, a, b, op) + } else { + let null_buffer = combine_option_bitmap(&[a.data(), b.data()], len).unwrap(); + let null_count = null_buffer + .as_ref() + .map(|x| len - x.count_set_bits_offset(0, len)) + .unwrap_or_default(); + + let mut builder = a.into_builder()?; + + let slice = builder.values_slice_mut(); + + match try_for_each_valid_idx(len, 0, null_count, null_buffer.as_deref(), |idx| { + unsafe { + *slice.get_unchecked_mut(idx) = + op(*slice.get_unchecked(idx), b.value_unchecked(idx))? + }; + Ok::<_, ArrowError>(()) + }) { + Ok(_) => {} + Err(err) => return Ok(Err(err)), + }; + + let array_builder = builder + .finish() + .data() + .clone() + .into_builder() + .null_bit_buffer(null_buffer) + .null_count(null_count); + + let array_data = unsafe { array_builder.build_unchecked() }; + Ok(Ok(PrimitiveArray::::from(array_data))) + } +} + /// This intentional inline(never) attribute helps LLVM optimize the loop. #[inline(never)] fn try_binary_no_nulls( @@ -310,6 +456,35 @@ where Ok(unsafe { build_primitive_array(len, buffer.into(), 0, None) }) } +/// This intentional inline(never) attribute helps LLVM optimize the loop. +#[inline(never)] +fn try_binary_no_nulls_mut( + len: usize, + a: PrimitiveArray, + b: &PrimitiveArray, + op: F, +) -> std::result::Result< + std::result::Result, ArrowError>, + PrimitiveArray, +> +where + T: ArrowPrimitiveType, + F: Fn(T::Native, T::Native) -> Result, +{ + let mut builder = a.into_builder()?; + let slice = builder.values_slice_mut(); + + for idx in 0..len { + unsafe { + match op(*slice.get_unchecked(idx), b.value_unchecked(idx)) { + Ok(value) => *slice.get_unchecked_mut(idx) = value, + Err(err) => return Ok(Err(err)), + }; + }; + } + Ok(Ok(builder.finish())) +} + #[inline(never)] fn try_binary_opt_no_nulls( len: usize, @@ -385,6 +560,7 @@ mod tests { use super::*; use crate::array::{as_primitive_array, Float64Array, PrimitiveDictionaryBuilder}; use crate::datatypes::{Float64Type, Int32Type, Int8Type}; + use arrow_array::Int32Array; #[test] fn test_unary_f64_slice() { @@ -444,4 +620,44 @@ mod tests { &expected ); } + + #[test] + fn test_binary_mut() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + let c = binary_mut(a, &b, |l, r| l + r).unwrap().unwrap(); + + let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]); + assert_eq!(c, expected); + } + + #[test] + fn test_try_binary_mut() { + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap(); + + let expected = Int32Array::from(vec![Some(16), None, Some(12), None, Some(6)]); + assert_eq!(c, expected); + + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![1, 2, 3, 4, 5]); + let c = try_binary_mut(a, &b, |l, r| Ok(l + r)).unwrap().unwrap(); + let expected = Int32Array::from(vec![16, 16, 12, 12, 6]); + assert_eq!(c, expected); + + let a = Int32Array::from(vec![15, 14, 9, 8, 1]); + let b = Int32Array::from(vec![Some(1), None, Some(3), None, Some(5)]); + let _ = try_binary_mut(a, &b, |l, r| { + if l == 1 { + Err(ArrowError::InvalidArgumentError( + "got error".parse().unwrap(), + )) + } else { + Ok(l + r) + } + }) + .unwrap() + .expect_err("should got error"); + } }