diff --git a/rust/arrow/src/compute/kernels/comparison.rs b/rust/arrow/src/compute/kernels/comparison.rs index 1cd4633a1ed..39138ef1c21 100644 --- a/rust/arrow/src/compute/kernels/comparison.rs +++ b/rust/arrow/src/compute/kernels/comparison.rs @@ -109,7 +109,10 @@ where compare_op_scalar!(left, right, op) } -pub fn like_utf8(left: &StringArray, right: &StringArray) -> Result { +pub fn like_utf8( + left: &GenericStringArray, + right: &GenericStringArray, +) -> Result { let mut map = HashMap::new(); if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -158,7 +161,10 @@ fn is_like_pattern(c: char) -> bool { c == '%' || c == '_' } -pub fn like_utf8_scalar(left: &StringArray, right: &str) -> Result { +pub fn like_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { let null_bit_buffer = left.data().null_buffer().cloned(); let bytes = bit_util::ceil(left.len(), 8); let mut bool_buf = MutableBuffer::from_len_zeroed(bytes); @@ -217,7 +223,10 @@ pub fn like_utf8_scalar(left: &StringArray, right: &str) -> Result Ok(BooleanArray::from(Arc::new(data))) } -pub fn nlike_utf8(left: &StringArray, right: &StringArray) -> Result { +pub fn nlike_utf8( + left: &GenericStringArray, + right: &GenericStringArray, +) -> Result { let mut map = HashMap::new(); if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -262,7 +271,10 @@ pub fn nlike_utf8(left: &StringArray, right: &StringArray) -> Result Result { +pub fn nlike_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { let null_bit_buffer = left.data().null_buffer().cloned(); let mut result = BooleanBufferBuilder::new(left.len()); @@ -308,51 +320,87 @@ pub fn nlike_utf8_scalar(left: &StringArray, right: &str) -> Result Result { +pub fn eq_utf8( + left: &GenericStringArray, + right: &GenericStringArray, +) -> Result { compare_op!(left, right, |a, b| a == b) } -pub fn eq_utf8_scalar(left: &StringArray, right: &str) -> Result { +pub fn eq_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { compare_op_scalar!(left, right, |a, b| a == b) } -pub fn neq_utf8(left: &StringArray, right: &StringArray) -> Result { +pub fn neq_utf8( + left: &GenericStringArray, + right: &GenericStringArray, +) -> Result { compare_op!(left, right, |a, b| a != b) } -pub fn neq_utf8_scalar(left: &StringArray, right: &str) -> Result { +pub fn neq_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { compare_op_scalar!(left, right, |a, b| a != b) } -pub fn lt_utf8(left: &StringArray, right: &StringArray) -> Result { +pub fn lt_utf8( + left: &GenericStringArray, + right: &GenericStringArray, +) -> Result { compare_op!(left, right, |a, b| a < b) } -pub fn lt_utf8_scalar(left: &StringArray, right: &str) -> Result { +pub fn lt_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { compare_op_scalar!(left, right, |a, b| a < b) } -pub fn lt_eq_utf8(left: &StringArray, right: &StringArray) -> Result { +pub fn lt_eq_utf8( + left: &GenericStringArray, + right: &GenericStringArray, +) -> Result { compare_op!(left, right, |a, b| a <= b) } -pub fn lt_eq_utf8_scalar(left: &StringArray, right: &str) -> Result { +pub fn lt_eq_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { compare_op_scalar!(left, right, |a, b| a <= b) } -pub fn gt_utf8(left: &StringArray, right: &StringArray) -> Result { +pub fn gt_utf8( + left: &GenericStringArray, + right: &GenericStringArray, +) -> Result { compare_op!(left, right, |a, b| a > b) } -pub fn gt_utf8_scalar(left: &StringArray, right: &str) -> Result { +pub fn gt_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { compare_op_scalar!(left, right, |a, b| a > b) } -pub fn gt_eq_utf8(left: &StringArray, right: &StringArray) -> Result { +pub fn gt_eq_utf8( + left: &GenericStringArray, + right: &GenericStringArray, +) -> Result { compare_op!(left, right, |a, b| a >= b) } -pub fn gt_eq_utf8_scalar(left: &StringArray, right: &str) -> Result { +pub fn gt_eq_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { compare_op_scalar!(left, right, |a, b| a >= b) } @@ -1227,6 +1275,22 @@ mod tests { $right ); } + + let left = LargeStringArray::from($left); + let res = $op(&left, $right).unwrap(); + let expected = $expected; + assert_eq!(expected.len(), res.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!( + v, + expected[i], + "unexpected result when comparing {} at position {} to {} ", + left.value(i), + i, + $right + ); + } } }; }