diff --git a/rust/arrow/src/compute/kernels/length.rs b/rust/arrow/src/compute/kernels/length.rs index 740bb2b68c8..8c8fea8b89e 100644 --- a/rust/arrow/src/compute/kernels/length.rs +++ b/rust/arrow/src/compute/kernels/length.rs @@ -24,43 +24,77 @@ use crate::{ }; use std::sync::Arc; -#[allow(clippy::unnecessary_wraps)] -fn length_string(array: &Array, data_type: DataType) -> Result -where - OffsetSize: OffsetSizeTrait, -{ - // note: offsets are stored as u8, but they can be interpreted as OffsetSize - let offsets = &array.data_ref().buffers()[0]; - // this is a 30% improvement over iterating over u8s and building OffsetSize, which - // justifies the usage of `unsafe`. - let slice: &[OffsetSize] = - &unsafe { offsets.typed_data::() }[array.offset()..]; +fn clone_null_buffer(array: &impl Array) -> Option { + array + .data_ref() + .null_bitmap() + .as_ref() + .map(|b| b.bits.clone()) +} - let lengths = slice.windows(2).map(|offset| offset[1] - offset[0]); +fn length_from_offsets(offsets: &[T]) -> Buffer { + let lengths = offsets.windows(2).map(|offset| offset[1] - offset[0]); // JUSTIFICATION // Benefit // ~60% speedup // Soundness - // `values` is an iterator with a known size. - let buffer = unsafe { Buffer::from_trusted_len_iter(lengths) }; + // `lengths` is `TrustedLen` + unsafe { Buffer::from_trusted_len_iter(lengths) } +} - let null_bit_buffer = array - .data_ref() - .null_bitmap() - .as_ref() - .map(|b| b.bits.clone()); +fn length_string( + array: &GenericStringArray, + data_type: DataType, +) -> ArrayRef +where + OffsetSize: StringOffsetSizeTrait, +{ + make_array(Arc::new(ArrayData::new( + data_type, + array.len(), + None, + clone_null_buffer(array), + 0, + vec![length_from_offsets(array.value_offsets())], + vec![], + ))) +} - let data = ArrayData::new( +fn length_list( + array: &GenericListArray, + data_type: DataType, +) -> ArrayRef +where + OffsetSize: OffsetSizeTrait, +{ + make_array(Arc::new(ArrayData::new( data_type, array.len(), None, - null_bit_buffer, + clone_null_buffer(array), 0, - vec![buffer], + vec![length_from_offsets(array.value_offsets())], vec![], - ); - Ok(make_array(Arc::new(data))) + ))) +} + +fn length_binary( + array: &GenericBinaryArray, + data_type: DataType, +) -> ArrayRef +where + OffsetSize: BinaryOffsetSizeTrait, +{ + make_array(Arc::new(ArrayData::new( + data_type, + array.len(), + None, + clone_null_buffer(array), + 0, + vec![length_from_offsets(array.value_offsets())], + vec![], + ))) } /// Returns an array of Int32/Int64 denoting the number of characters in each string in the array. @@ -70,8 +104,30 @@ where /// * length is in number of bytes pub fn length(array: &Array) -> Result { match array.data_type() { - DataType::Utf8 => length_string::(array, DataType::Int32), - DataType::LargeUtf8 => length_string::(array, DataType::Int64), + DataType::Binary => { + let array = array.as_any().downcast_ref::().unwrap(); + Ok(length_binary(array, DataType::Int32)) + } + DataType::LargeBinary => { + let array = array.as_any().downcast_ref::().unwrap(); + Ok(length_binary(array, DataType::Int64)) + } + DataType::List(_) => { + let array = array.as_any().downcast_ref::().unwrap(); + Ok(length_list(array, DataType::Int32)) + } + DataType::LargeList(_) => { + let array = array.as_any().downcast_ref::().unwrap(); + Ok(length_list(array, DataType::Int64)) + } + DataType::Utf8 => { + let array = array.as_any().downcast_ref::().unwrap(); + Ok(length_string(array, DataType::Int32)) + } + DataType::LargeUtf8 => { + let array = array.as_any().downcast_ref::().unwrap(); + Ok(length_string(array, DataType::Int64)) + } _ => Err(ArrowError::ComputeError(format!( "length not supported for {:?}", array.data_type() @@ -203,4 +259,16 @@ mod tests { Ok(()) } + + #[test] + fn test_binary() -> Result<()> { + let data: Vec<&[u8]> = vec![b"hello", b" ", b"world"]; + let a = BinaryArray::from(data); + let result = length(&a)?; + + let expected: &Array = &Int32Array::from(vec![5, 1, 5]); + assert_eq!(expected, result.as_ref()); + + Ok(()) + } } diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index c5cd01f93c5..b705ee0b84e 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -233,7 +233,11 @@ pub fn return_type( match fun { BuiltinScalarFunction::Length => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::Int64, + DataType::LargeBinary => DataType::Int64, + DataType::LargeList(_) => DataType::Int64, DataType::Utf8 => DataType::Int32, + DataType::Binary => DataType::Int32, + DataType::List(_) => DataType::Int32, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( @@ -437,9 +441,20 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { // for now, the list is small, as we do not have many built-in functions. match fun { BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]), + BuiltinScalarFunction::Length => { + // todo: add support for non-constant DataType's (e.g. `DataType::List(_)`) + Signature::Uniform( + 1, + vec![ + DataType::Binary, + DataType::Utf8, + DataType::LargeBinary, + DataType::LargeUtf8, + ], + ) + } BuiltinScalarFunction::Upper | BuiltinScalarFunction::Lower - | BuiltinScalarFunction::Length | BuiltinScalarFunction::Trim | BuiltinScalarFunction::Ltrim | BuiltinScalarFunction::Rtrim