Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 94 additions & 26 deletions rust/arrow/src/compute/kernels/length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,43 +24,77 @@ use crate::{
};
use std::sync::Arc;

#[allow(clippy::unnecessary_wraps)]
fn length_string<OffsetSize>(array: &Array, data_type: DataType) -> Result<ArrayRef>
where
OffsetSize: OffsetSizeTrait,
{
// note: offsets are stored as u8, but they can be interpreted as OffsetSize
let offsets = &array.data_ref().buffers()[0];
// this is a 30% improvement over iterating over u8s and building OffsetSize, which
// justifies the usage of `unsafe`.
let slice: &[OffsetSize] =
&unsafe { offsets.typed_data::<OffsetSize>() }[array.offset()..];
fn clone_null_buffer(array: &impl Array) -> Option<Buffer> {
array
.data_ref()
.null_bitmap()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this chain be replaced by `.null_buffer().cloned()`?

.as_ref()
.map(|b| b.bits.clone())
}

let lengths = slice.windows(2).map(|offset| offset[1] - offset[0]);
/// Builds a `Buffer` holding the difference `offset[1] - offset[0]` for every
/// pair of consecutive entries in `offsets`.
///
/// `offsets.windows(2)` yields one pair per adjacent couple of offsets, so the
/// iterator produces one fewer item than `offsets` has entries (and nothing
/// when `offsets` has fewer than two entries).
fn length_from_offsets<T: OffsetSizeTrait>(offsets: &[T]) -> Buffer {
let lengths = offsets.windows(2).map(|offset| offset[1] - offset[0]);

// NOTE(review): the `let buffer = ...` statement below and the comment line
// directly above it look like the removed side of the diff; the trailing
// `unsafe { ... }` expression is the added side — confirm against the merged file.
// JUSTIFICATION
// Benefit
// ~60% speedup
// Soundness
// `values` is an iterator with a known size.
let buffer = unsafe { Buffer::from_trusted_len_iter(lengths) };
// `lengths` is `TrustedLen`
unsafe { Buffer::from_trusted_len_iter(lengths) }
}

let null_bit_buffer = array
.data_ref()
.null_bitmap()
.as_ref()
.map(|b| b.bits.clone());
/// Computes the length of every value in a string array.
///
/// Each output slot is the difference between two consecutive entries of
/// `array.value_offsets()`, and the null buffer of `array` is cloned onto the
/// result. `data_type` becomes the data type of the returned array.
fn length_string<OffsetSize>(
    array: &GenericStringArray<OffsetSize>,
    data_type: DataType,
) -> ArrayRef
where
    OffsetSize: StringOffsetSizeTrait,
{
    let num_slots = array.len();
    let null_buffer = clone_null_buffer(array);
    let lengths = length_from_offsets(array.value_offsets());

    let data = ArrayData::new(
        data_type,
        num_slots,
        None,
        null_buffer,
        0,
        vec![lengths],
        vec![],
    );
    make_array(Arc::new(data))
}

let data = ArrayData::new(
/// Builds the length array for a list array.
///
/// `length_from_offsets` is applied to `array.value_offsets()` and the result
/// is wrapped, together with a clone of the input's null buffer, into an array
/// whose data type is `data_type`.
// NOTE(review): `null_bit_buffer,`, `vec![buffer],`, `);` and
// `Ok(make_array(Arc::new(data)))` below look like removed-side diff lines from
// the previous implementation interleaved with the new body — confirm against
// the merged file before relying on this span.
fn length_list<OffsetSize>(
array: &GenericListArray<OffsetSize>,
data_type: DataType,
) -> ArrayRef
where
OffsetSize: OffsetSizeTrait,
{
make_array(Arc::new(ArrayData::new(
data_type,
array.len(),
None,
null_bit_buffer,
clone_null_buffer(array),
0,
vec![buffer],
vec![length_from_offsets(array.value_offsets())],
vec![],
);
Ok(make_array(Arc::new(data)))
)))
}

/// Computes the length of every value in a binary array.
///
/// Each output slot is the difference between two consecutive entries of
/// `array.value_offsets()`, and the null buffer of `array` is cloned onto the
/// result. `data_type` becomes the data type of the returned array.
fn length_binary<OffsetSize>(
    array: &GenericBinaryArray<OffsetSize>,
    data_type: DataType,
) -> ArrayRef
where
    OffsetSize: BinaryOffsetSizeTrait,
{
    let num_slots = array.len();
    let null_buffer = clone_null_buffer(array);
    let lengths = length_from_offsets(array.value_offsets());

    let data = ArrayData::new(
        data_type,
        num_slots,
        None,
        null_buffer,
        0,
        vec![lengths],
        vec![],
    );
    make_array(Arc::new(data))
}

/// Returns an array of Int32/Int64 denoting the length of each element in the array: the number of bytes for string and binary arrays, and the number of child elements for list arrays.
Expand All @@ -70,8 +104,30 @@ where
/// * length is in number of bytes
pub fn length(array: &Array) -> Result<ArrayRef> {
match array.data_type() {
DataType::Utf8 => length_string::<i32>(array, DataType::Int32),
DataType::LargeUtf8 => length_string::<i64>(array, DataType::Int64),
DataType::Binary => {
let array = array.as_any().downcast_ref::<BinaryArray>().unwrap();
Ok(length_binary(array, DataType::Int32))
}
DataType::LargeBinary => {
let array = array.as_any().downcast_ref::<LargeBinaryArray>().unwrap();
Ok(length_binary(array, DataType::Int64))
}
DataType::List(_) => {
let array = array.as_any().downcast_ref::<ListArray>().unwrap();
Ok(length_list(array, DataType::Int32))
}
DataType::LargeList(_) => {
let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
Ok(length_list(array, DataType::Int64))
}
DataType::Utf8 => {
let array = array.as_any().downcast_ref::<StringArray>().unwrap();
Ok(length_string(array, DataType::Int32))
}
DataType::LargeUtf8 => {
let array = array.as_any().downcast_ref::<LargeStringArray>().unwrap();
Ok(length_string(array, DataType::Int64))
}
_ => Err(ArrowError::ComputeError(format!(
"length not supported for {:?}",
array.data_type()
Expand Down Expand Up @@ -203,4 +259,16 @@ mod tests {

Ok(())
}

/// Checks that `length` returns the expected `Int32` lengths for a `BinaryArray`.
#[test]
fn test_binary() -> Result<()> {
    let values: Vec<&[u8]> = vec![b"hello", b" ", b"world"];
    let input = BinaryArray::from(values);

    let actual = length(&input)?;
    let expected: &Array = &Int32Array::from(vec![5, 1, 5]);

    assert_eq!(expected, actual.as_ref());
    Ok(())
}
}
17 changes: 16 additions & 1 deletion rust/datafusion/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,11 @@ pub fn return_type(
match fun {
BuiltinScalarFunction::Length => Ok(match arg_types[0] {
DataType::LargeUtf8 => DataType::Int64,
DataType::LargeBinary => DataType::Int64,
DataType::LargeList(_) => DataType::Int64,
DataType::Utf8 => DataType::Int32,
DataType::Binary => DataType::Int32,
DataType::List(_) => DataType::Int32,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
Expand Down Expand Up @@ -437,9 +441,20 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
// for now, the list is small, as we do not have many built-in functions.
match fun {
BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]),
BuiltinScalarFunction::Length => {
// todo: add support for non-constant `DataType`s (e.g. `DataType::List(_)`)
Signature::Uniform(
1,
vec![
DataType::Binary,
DataType::Utf8,
DataType::LargeBinary,
DataType::LargeUtf8,
],
)
}
BuiltinScalarFunction::Upper
| BuiltinScalarFunction::Lower
| BuiltinScalarFunction::Length
| BuiltinScalarFunction::Trim
| BuiltinScalarFunction::Ltrim
| BuiltinScalarFunction::Rtrim
Expand Down