diff --git a/rust/arrow/Cargo.toml b/rust/arrow/Cargo.toml index 2943069b985..0ceb135643f 100644 --- a/rust/arrow/Cargo.toml +++ b/rust/arrow/Cargo.toml @@ -77,6 +77,10 @@ harness = false name = "comparison_kernels" harness = false +[[bench]] +name = "take_kernels" +harness = false + [[bench]] name = "csv_writer" harness = false diff --git a/rust/arrow/benches/take_kernels.rs b/rust/arrow/benches/take_kernels.rs new file mode 100644 index 00000000000..ee420808348 --- /dev/null +++ b/rust/arrow/benches/take_kernels.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +use criterion::Criterion; +use rand::distributions::{Distribution, Standard}; +use rand::prelude::random; +use rand::Rng; + +use std::sync::Arc; + +extern crate arrow; + +use arrow::array::*; +use arrow::compute::{cast, take}; +use arrow::datatypes::*; + +// cast array from specified primitive array type to desired data type +fn create_numeric(size: usize) -> ArrayRef +where + T: ArrowNumericType, + Standard: Distribution, + PrimitiveArray: std::convert::From>, +{ + Arc::new(PrimitiveArray::::from(vec![random::(); size])) as ArrayRef +} + +fn create_random_index(size: usize) -> UInt32Array { + let mut rng = rand::thread_rng(); + let ints = Int32Array::from(vec![rng.gen_range(-24i32, size as i32); size]); + // cast to u32, conveniently marking negative values as nulls + UInt32Array::from( + cast(&(Arc::new(ints) as ArrayRef), &DataType::UInt32) + .unwrap() + .data(), + ) +} + +fn take_numeric(size: usize, index_len: usize) -> () +where + T: ArrowNumericType, + Standard: Distribution, + PrimitiveArray: std::convert::From>, + T::Native: num::NumCast, +{ + let array = create_numeric::(size); + let index = create_random_index(index_len); + criterion::black_box(take(&array, &index, None).unwrap()); +} + +fn take_boolean(size: usize, index_len: usize) -> () { + let array = Arc::new(BooleanArray::from(vec![random::(); size])) as ArrayRef; + let index = create_random_index(index_len); + criterion::black_box(take(&array, &index, None).unwrap()); +} + +fn add_benchmark(c: &mut Criterion) { + c.bench_function("take u8 256", |b| { + b.iter(|| take_numeric::(256, 256)) + }); + c.bench_function("take u8 512", |b| { + b.iter(|| take_numeric::(512, 512)) + }); + c.bench_function("take u8 1024", |b| { + b.iter(|| take_numeric::(1024, 1024)) + }); + c.bench_function("take i32 256", |b| { + b.iter(|| take_numeric::(256, 256)) + }); + c.bench_function("take i32 512", |b| { + b.iter(|| take_numeric::(512, 512)) + }); + c.bench_function("take i32 1024", |b| { + b.iter(|| take_numeric::(1024, 1024)) + }); + c.bench_function("take bool 256", |b| b.iter(|| take_boolean(256, 256))); + c.bench_function("take bool 512", |b| b.iter(|| take_boolean(512, 512))); + c.bench_function("take bool 1024", |b| b.iter(|| take_boolean(1024, 1024))); +} + +criterion_group!(benches, add_benchmark); +criterion_main!(benches); diff --git a/rust/arrow/src/array/array.rs b/rust/arrow/src/array/array.rs index 2c353d578f0..e4e55d06650 100644 --- a/rust/arrow/src/array/array.rs +++ b/rust/arrow/src/array/array.rs @@ -994,6 +994,11 @@ impl StructArray { pub fn num_columns(&self) -> usize { self.boxed_fields.len() } + + /// Returns the fields of the struct array + pub fn columns(&self) -> Vec<&ArrayRef> { + self.boxed_fields.iter().collect() + } } impl From for StructArray { diff --git a/rust/arrow/src/array/builder.rs b/rust/arrow/src/array/builder.rs index b0b97c22107..da0357b4924 100644 --- a/rust/arrow/src/array/builder.rs +++ b/rust/arrow/src/array/builder.rs @@ -467,11 +467,21 @@ impl BinaryBuilder { /// /// Note, when appending individual byte values you must call `append` to delimit each /// distinct list value. - pub fn append_value(&mut self, value: u8) -> Result<()> { + pub fn append_byte(&mut self, value: u8) -> Result<()> { self.builder.values().append_value(value)?; Ok(()) } + /// Appends a byte slice into the builder. + /// + /// Automatically calls the `append` method to delimit the slice appended in as a + /// distinct array element. + pub fn append_value(&mut self, value: &[u8]) -> Result<()> { + self.builder.values().append_slice(value)?; + self.builder.append(true)?; + Ok(()) + } + /// Appends a `&String` or `&str` into the builder. /// /// Automatically calls the `append` method to delimit the string appended in as a @@ -1156,18 +1166,18 @@ mod tests { fn test_binary_array_builder() { let mut builder = BinaryBuilder::new(20); - builder.append_value(b'h').unwrap(); - builder.append_value(b'e').unwrap(); - builder.append_value(b'l').unwrap(); - builder.append_value(b'l').unwrap(); - builder.append_value(b'o').unwrap(); + builder.append_byte(b'h').unwrap(); + builder.append_byte(b'e').unwrap(); + builder.append_byte(b'l').unwrap(); + builder.append_byte(b'l').unwrap(); + builder.append_byte(b'o').unwrap(); builder.append(true).unwrap(); builder.append(true).unwrap(); - builder.append_value(b'w').unwrap(); - builder.append_value(b'o').unwrap(); - builder.append_value(b'r').unwrap(); - builder.append_value(b'l').unwrap(); - builder.append_value(b'd').unwrap(); + builder.append_byte(b'w').unwrap(); + builder.append_byte(b'o').unwrap(); + builder.append_byte(b'r').unwrap(); + builder.append_byte(b'l').unwrap(); + builder.append_byte(b'd').unwrap(); builder.append(true).unwrap(); let array = builder.finish(); diff --git a/rust/arrow/src/compute/kernels/mod.rs b/rust/arrow/src/compute/kernels/mod.rs index 2483f519b97..ae1ab0cc45d 100644 --- a/rust/arrow/src/compute/kernels/mod.rs +++ b/rust/arrow/src/compute/kernels/mod.rs @@ -21,4 +21,5 @@ pub mod arithmetic; pub mod boolean; pub mod cast; pub mod comparison; +pub mod take; pub mod temporal; diff --git a/rust/arrow/src/compute/kernels/take.rs b/rust/arrow/src/compute/kernels/take.rs new file mode 100644 index 00000000000..6cce7fb47d9 --- /dev/null +++ b/rust/arrow/src/compute/kernels/take.rs @@ -0,0 +1,595 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines take kernel for `ArrayRef` + +use std::sync::Arc; + +use crate::array::*; +use crate::buffer::{Buffer, MutableBuffer}; +use crate::compute::util::take_value_indices_from_list; +use crate::datatypes::*; +use crate::error::{ArrowError, Result}; +use crate::util::bit_util; + +use TimeUnit::*; + +/// Take elements from `ArrayRef` by supplying an array of indices. +/// +/// Supports: +/// * null indices, returning a null value for the index +/// * checking for overflowing indices +pub fn take( + values: &ArrayRef, + indices: &UInt32Array, + options: Option, +) -> Result { + let options = options.unwrap_or(Default::default()); + if options.check_bounds { + let len = values.len(); + for i in 0..indices.len() { + if indices.is_valid(i) { + let ix = indices.value(i) as usize; + if ix >= len { + return Err(ArrowError::ComputeError( + format!("Array index out of bounds, cannot get item at index {} from {} entries", ix, len)) + ); + } + } + } + } + match values.data_type() { + DataType::Boolean => take_primitive::(values, indices), + DataType::Int8 => take_primitive::(values, indices), + DataType::Int16 => take_primitive::(values, indices), + DataType::Int32 => take_primitive::(values, indices), + DataType::Int64 => take_primitive::(values, indices), + DataType::UInt8 => take_primitive::(values, indices), + DataType::UInt16 => take_primitive::(values, indices), + DataType::UInt32 => take_primitive::(values, indices), + DataType::UInt64 => take_primitive::(values, indices), + DataType::Float32 => take_primitive::(values, indices), + DataType::Float64 => take_primitive::(values, indices), + DataType::Date32(_) => take_primitive::(values, indices), + DataType::Date64(_) => take_primitive::(values, indices), + DataType::Time32(Second) => take_primitive::(values, indices), + DataType::Time32(Millisecond) => { + take_primitive::(values, indices) + } + DataType::Time64(Microsecond) => { + take_primitive::(values, indices) + } + DataType::Time64(Nanosecond) => { + take_primitive::(values, indices) + } + DataType::Timestamp(Second) => { + take_primitive::(values, indices) + } + DataType::Timestamp(Millisecond) => { + take_primitive::(values, indices) + } + DataType::Timestamp(Microsecond) => { + take_primitive::(values, indices) + } + DataType::Timestamp(Nanosecond) => { + take_primitive::(values, indices) + } + DataType::Utf8 => take_binary(values, indices), + DataType::List(_) => take_list(values, indices), + DataType::Struct(fields) => { + let struct_: &StructArray = + values.as_any().downcast_ref::().unwrap(); + let arrays: Result> = struct_ + .columns() + .iter() + .map(|a| take(a, indices, Some(options.clone()))) + .collect(); + let arrays = arrays?; + let pairs: Vec<(Field, ArrayRef)> = + fields.clone().into_iter().zip(arrays).collect(); + Ok(Arc::new(StructArray::from(pairs)) as ArrayRef) + } + t @ _ => unimplemented!("Take not supported for data type {:?}", t), + } +} + +/// Options that define how `take` should behave +#[derive(Clone)] +pub struct TakeOptions { + /// Perform bounds check before taking indices from values. + /// If enabled, an `ArrowError` is returned if the indices are out of bounds. + /// If not enabled, and indices exceed bounds, the kernel will panic. + pub check_bounds: bool, +} + +impl Default for TakeOptions { + fn default() -> Self { + Self { + check_bounds: false, + } + } +} + +/// `take` implementation for primitive arrays +/// +/// This checks if an `indices` slot is populated, and gets the value from `values` +/// as the populated index. +/// If the `indices` slot is null, a null value is returned. +/// For example, given: +/// values: [1, 2, 3, null, 5] +/// indices: [0, null, 4, 3] +/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)] +fn take_primitive(values: &ArrayRef, indices: &UInt32Array) -> Result +where + T: ArrowPrimitiveType, +{ + let mut builder = PrimitiveBuilder::::new(indices.len()); + let a = values.as_any().downcast_ref::>().unwrap(); + for i in 0..indices.len() { + if indices.is_null(i) { + // populate with null if index is null + builder.append_null()?; + } else { + // get index value to use in looking up the value from `values` + let ix = indices.value(i) as usize; + if a.is_valid(ix) { + builder.append_value(a.value(ix))?; + } else { + builder.append_null()?; + } + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +/// `take` implementation for binary arrays +fn take_binary(values: &ArrayRef, indices: &UInt32Array) -> Result { + let mut builder = BinaryBuilder::new(indices.len()); + let a = values.as_any().downcast_ref::().unwrap(); + for i in 0..indices.len() { + if indices.is_null(i) { + builder.append(false)?; + } else { + let ix = indices.value(i) as usize; + if a.is_null(ix) { + builder.append(false)?; + } else { + builder.append_value(a.value(ix))?; + } + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +/// `take` implementation for list arrays +/// +/// Calculates the index and indexed offset for the inner array, +/// applying `take` on the inner array, then reconstructing a list array +/// with the indexed offsets +fn take_list(values: &ArrayRef, indices: &UInt32Array) -> Result { + // TODO: Some optimizations can be done here such as if it is + // taking the whole list or a contiguous sublist + let list: &ListArray = values.as_any().downcast_ref::().unwrap(); + let (list_indices, offsets) = take_value_indices_from_list(values, indices); + let taken = take(&list.values(), &list_indices, None)?; + // determine null count and null buffer, which are a function of `values` and `indices` + let mut null_count = 0; + let num_bytes = bit_util::ceil(indices.len(), 8); + let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); + { + let null_slice = null_buf.data_mut(); + &offsets[..] + .windows(2) + .enumerate() + .for_each(|(i, window): (usize, &[i32])| { + if window[0] != window[1] { + // offsets are unequal, slot is not null + bit_util::set_bit(null_slice, i); + } else { + null_count += 1; + } + }); + } + let value_offsets = Buffer::from(offsets[..].to_byte_slice()); + // create a new list with taken data and computed null information + let list_data = ArrayDataBuilder::new(list.data_type().clone()) + .len(indices.len()) + .null_count(null_count) + .null_bit_buffer(null_buf.freeze()) + .offset(0) + .add_child_data(taken.data()) + .add_buffer(value_offsets) + .build(); + let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef; + Ok(list_array) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_take_primitive_arrays<'a, T>( + data: Vec>, + index: &UInt32Array, + options: Option, + expected_data: Vec>, + ) where + T: ArrowPrimitiveType, + PrimitiveArray: From>> + ArrayEqual, + { + let output = PrimitiveArray::::from(data); + let expected = PrimitiveArray::::from(expected_data); + let output = take(&(Arc::new(output) as ArrayRef), index, options).unwrap(); + let output = output.as_any().downcast_ref::>().unwrap(); + assert!(output.equals(&expected)) + } + + // create a simple struct for testing purposes + fn create_test_struct() -> ArrayRef { + let boolean_data = BooleanArray::from(vec![true, false, false, true]).data(); + let int_data = Int32Array::from(vec![42, 28, 19, 31]).data(); + let mut field_types = vec![]; + field_types.push(Field::new("a", DataType::Boolean, true)); + field_types.push(Field::new("b", DataType::Int32, true)); + let struct_array_data = ArrayData::builder(DataType::Struct(field_types)) + .len(4) + .null_count(0) + .add_child_data(boolean_data) + .add_child_data(int_data) + .build(); + let struct_array = StructArray::from(struct_array_data); + Arc::new(struct_array) as ArrayRef + } + + #[test] + fn test_take_primitive() { + let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); + + // uint8 + test_take_primitive_arrays::( + vec![Some(0), None, Some(2), Some(3), None], + &index, + None, + vec![Some(3), None, None, Some(3), Some(2)], + ); + + // uint16 + test_take_primitive_arrays::( + vec![Some(0), None, Some(2), Some(3), None], + &index, + None, + vec![Some(3), None, None, Some(3), Some(2)], + ); + + // uint32 + test_take_primitive_arrays::( + vec![Some(0), None, Some(2), Some(3), None], + &index, + None, + vec![Some(3), None, None, Some(3), Some(2)], + ); + + // int64 + test_take_primitive_arrays::( + vec![Some(0), None, Some(2), Some(-15), None], + &index, + None, + vec![Some(-15), None, None, Some(-15), Some(2)], + ); + + // float32 + test_take_primitive_arrays::( + vec![Some(0.0), None, Some(2.21), Some(-3.1), None], + &index, + None, + vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)], + ); + + // float64 + test_take_primitive_arrays::( + vec![Some(0.0), None, Some(2.21), Some(-3.1), None], + &index, + None, + vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)], + ); + + // boolean + // float32 + test_take_primitive_arrays::( + vec![Some(false), None, Some(true), Some(false), None], + &index, + None, + vec![Some(false), None, None, Some(false), Some(true)], + ); + } + + #[test] + fn test_take_binary() { + let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]); + let mut builder: BinaryBuilder = BinaryBuilder::new(6); + builder.append_string("one").unwrap(); + builder.append_null().unwrap(); + builder.append_string("three").unwrap(); + builder.append_string("four").unwrap(); + builder.append_string("five").unwrap(); + let array = Arc::new(builder.finish()) as ArrayRef; + let a = take(&array, &index, None).unwrap(); + assert_eq!(a.len(), index.len()); + builder.append_string("four").unwrap(); + builder.append_null().unwrap(); + builder.append_null().unwrap(); + builder.append_string("four").unwrap(); + builder.append_string("five").unwrap(); + let b = builder.finish(); + assert_eq!(a.data(), b.data()); + } + + #[test] + fn test_take_list() { + // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]] + let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).data(); + // Construct offsets + let value_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice()); + // Construct a list array from the above two + let list_data_type = DataType::List(Box::new(DataType::Int32)); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data) + .build(); + let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef; + + // index returns: [[2,3], null, [-1,-2,-1], [2,3], [0,0,0]] + let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(2), Some(0)]); + + let a = take(&list_array, &index, None).unwrap(); + let a: &ListArray = a.as_any().downcast_ref::().unwrap(); + + // construct a value aray with expected results: + // [[2,3], null, [-1,-2,-1], [2,3], [0,0,0]] + let expected_data = Int32Array::from(vec![ + Some(2), + Some(3), + Some(-1), + Some(-2), + Some(-1), + Some(2), + Some(3), + Some(0), + Some(0), + Some(0), + ]) + .data(); + // construct offsets + let expected_offsets = Buffer::from(&[0, 2, 2, 5, 7, 10].to_byte_slice()); + // construct list array from the two + let expected_list_data = ArrayData::builder(list_data_type.clone()) + .len(5) + .null_count(1) + // null buffer remains the same as only the indices have nulls + .null_bit_buffer(index.data().null_bitmap().as_ref().unwrap().bits.clone()) + .add_buffer(expected_offsets) + .add_child_data(expected_data) + .build(); + let expected_list_array = ListArray::from(expected_list_data); + + assert!(a.equals(&expected_list_array)); + } + + #[test] + fn test_take_list_with_value_nulls() { + // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]] + let value_data = Int32Array::from(vec![ + Some(0), + None, + Some(0), + Some(-1), + Some(-2), + Some(3), + None, + Some(5), + None, + ]) + .data(); + // Construct offsets + let value_offsets = Buffer::from(&[0, 3, 6, 7, 9].to_byte_slice()); + // Construct a list array from the above two + let list_data_type = DataType::List(Box::new(DataType::Int32)); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(4) + .add_buffer(value_offsets) + .null_count(0) + .null_bit_buffer(Buffer::from([0b10111101, 0b00000000])) + .add_child_data(value_data) + .build(); + let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef; + + // index returns: [[null], null, [-1,-2,3], [2,null], [0,null,0]] + let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]); + + let a = take(&list_array, &index, None).unwrap(); + let a: &ListArray = a.as_any().downcast_ref::().unwrap(); + + // construct a value aray with expected results: + // [[null], null, [-1,-2,3], [5,null], [0,null,0]] + let expected_data = Int32Array::from(vec![ + None, + Some(-1), + Some(-2), + Some(3), + Some(5), + None, + Some(0), + None, + Some(0), + ]) + .data(); + // construct offsets + let expected_offsets = Buffer::from(&[0, 1, 1, 4, 6, 9].to_byte_slice()); + // construct list array from the two + let expected_list_data = ArrayData::builder(list_data_type.clone()) + .len(5) + .null_count(1) + // null buffer remains the same as only the indices have nulls + .null_bit_buffer(index.data().null_bitmap().as_ref().unwrap().bits.clone()) + .add_buffer(expected_offsets) + .add_child_data(expected_data) + .build(); + let expected_list_array = ListArray::from(expected_list_data); + + assert!(a.equals(&expected_list_array)); + } + + #[test] + fn test_take_list_with_list_nulls() { + // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]] + let value_data = Int32Array::from(vec![ + Some(0), + None, + Some(0), + Some(-1), + Some(-2), + Some(3), + Some(5), + None, + ]) + .data(); + // Construct offsets + let value_offsets = Buffer::from(&[0, 3, 6, 6, 8].to_byte_slice()); + // Construct a list array from the above two + let list_data_type = DataType::List(Box::new(DataType::Int32)); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(4) + .add_buffer(value_offsets) + .null_count(1) + .null_bit_buffer(Buffer::from([0b01111101])) + .add_child_data(value_data) + .build(); + let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef; + + // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]] + let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]); + + let a = take(&list_array, &index, None).unwrap(); + let a: &ListArray = a.as_any().downcast_ref::().unwrap(); + + // construct a value aray with expected results: + // [null, null, [-1,-2,3], [5,null], [0,null,0]] + let expected_data = Int32Array::from(vec![ + Some(-1), + Some(-2), + Some(3), + Some(5), + None, + Some(0), + None, + Some(0), + ]) + .data(); + // construct offsets + let expected_offsets = Buffer::from(&[0, 0, 0, 3, 5, 8].to_byte_slice()); + // construct list array from the two + let mut null_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut null_bits, 2); + bit_util::set_bit(&mut null_bits, 3); + bit_util::set_bit(&mut null_bits, 4); + let expected_list_data = ArrayData::builder(list_data_type.clone()) + .len(5) + .null_count(2) + // null buffer must be recalculated as both values and indices have nulls + .null_bit_buffer(Buffer::from(null_bits)) + .add_buffer(expected_offsets) + .add_child_data(expected_data) + .build(); + let expected_list_array = ListArray::from(expected_list_data); + + assert!(a.equals(&expected_list_array)); + } + + #[test] + fn test_take_struct() { + let array = create_test_struct(); + + let index = UInt32Array::from(vec![0, 3, 1, 0, 2]); + let a = take(&array, &index, None).unwrap(); + let a: &StructArray = a.as_any().downcast_ref::().unwrap(); + assert_eq!(index.len(), a.len()); + assert_eq!(0, a.null_count()); + + let expected_bool_data = + BooleanArray::from(vec![true, true, false, true, false]).data(); + let expected_int_data = Int32Array::from(vec![42, 31, 28, 42, 19]).data(); + let mut field_types = vec![]; + field_types.push(Field::new("a", DataType::Boolean, true)); + field_types.push(Field::new("b", DataType::Int32, true)); + let struct_array_data = ArrayData::builder(DataType::Struct(field_types)) + .len(5) + .null_count(0) + .add_child_data(expected_bool_data) + .add_child_data(expected_int_data) + .build(); + let struct_array = StructArray::from(struct_array_data); + assert!(a.equals(&struct_array)); + } + + #[test] + fn test_take_struct_with_nulls() { + let array = create_test_struct(); + + let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0)]); + let a = take(&array, &index, None).unwrap(); + let a: &StructArray = a.as_any().downcast_ref::().unwrap(); + assert_eq!(index.len(), a.len()); + assert_eq!(0, a.null_count()); + + let expected_bool_data = + BooleanArray::from(vec![None, Some(true), Some(false), None, Some(true)]) + .data(); + let expected_int_data = + Int32Array::from(vec![None, Some(31), Some(28), None, Some(42)]).data(); + + let mut field_types = vec![]; + field_types.push(Field::new("a", DataType::Boolean, true)); + field_types.push(Field::new("b", DataType::Int32, true)); + let struct_array_data = ArrayData::builder(DataType::Struct(field_types)) + .len(5) + // TODO: see https://issues.apache.org/jira/browse/ARROW-5408 for why count != 2 + .null_count(0) + .add_child_data(expected_bool_data) + .add_child_data(expected_int_data) + .build(); + let struct_array = StructArray::from(struct_array_data); + assert!(a.equals(&struct_array)); + } + + #[test] + #[should_panic( + expected = "Array index out of bounds, cannot get item at index 6 from 5 entries" + )] + fn test_take_out_of_bounds() { + let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]); + let take_opt = TakeOptions { check_bounds: true }; + + // int64 + test_take_primitive_arrays::( + vec![Some(0), None, Some(2), Some(3), None], + &index, + Some(take_opt), + vec![None], + ); + } +} diff --git a/rust/arrow/src/compute/mod.rs b/rust/arrow/src/compute/mod.rs index 7e31c52d85d..15af978af0a 100644 --- a/rust/arrow/src/compute/mod.rs +++ b/rust/arrow/src/compute/mod.rs @@ -27,4 +27,5 @@ pub use self::kernels::arithmetic::*; pub use self::kernels::boolean::*; pub use self::kernels::cast::*; pub use self::kernels::comparison::*; +pub use self::kernels::take::*; pub use self::kernels::temporal::*; diff --git a/rust/arrow/src/compute/util.rs b/rust/arrow/src/compute/util.rs index 55726b85eda..dc1f54fdd2a 100644 --- a/rust/arrow/src/compute/util.rs +++ b/rust/arrow/src/compute/util.rs @@ -17,6 +17,7 @@ //! Common utilities for computation kernels. +use crate::array::*; use crate::bitmap::Bitmap; use crate::buffer::Buffer; use crate::error::Result; @@ -44,10 +45,57 @@ where } } +/// Takes/filters a list array's inner data using the offsets of the list array. +/// +/// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns +/// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2 +/// elements) +pub(super) fn take_value_indices_from_list( + values: &ArrayRef, + indices: &UInt32Array, +) -> (UInt32Array, Vec) { + // TODO: benchmark this function, there might be a faster unsafe alternative + // get list array's offsets + let list: &ListArray = values.as_any().downcast_ref::().unwrap(); + let offsets: Vec = (0..=list.len()) + .map(|i| list.value_offset(i) as u32) + .collect(); + let mut new_offsets = Vec::with_capacity(indices.len()); + let mut values = Vec::new(); + let mut current_offset = 0; + // add first offset + new_offsets.push(0); + // compute the value indices, and set offsets accordingly + for i in 0..indices.len() { + if indices.is_valid(i) { + let ix = indices.value(i) as usize; + let start = offsets[ix]; + let end = offsets[ix + 1]; + current_offset += (end - start) as i32; + new_offsets.push(current_offset); + // if start == end, this slot is empty + if start != end { + // type annotation needed to guide compiler a bit + let mut offsets: Vec> = + (start..end).map(|v| Some(v)).collect::>>(); + values.append(&mut offsets); + } + } else { + new_offsets.push(current_offset); + } + } + (UInt32Array::from(values), new_offsets) +} + #[cfg(test)] mod tests { use super::*; + use std::sync::Arc; + + use crate::array::ArrayData; + use crate::datatypes::{DataType, ToByteSlice}; + #[test] fn test_apply_bin_op_to_option_bitmap() { assert_eq!( @@ -80,4 +128,30 @@ mod tests { ); } + #[test] + fn test_take_value_index_from_list() { + let value_data = Int32Array::from((0..10).collect::>()).data(); + let value_offsets = Buffer::from(&[0, 2, 5, 10].to_byte_slice()); + let list_data_type = DataType::List(Box::new(DataType::Int32)); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .add_buffer(value_offsets.clone()) + .add_child_data(value_data.clone()) + .build(); + let array = Arc::new(ListArray::from(list_data)) as ArrayRef; + let index = UInt32Array::from(vec![2, 0]); + let (indexed, offsets) = take_value_indices_from_list(&array, &index); + assert_eq!(vec![0, 5, 7], offsets); + let data = UInt32Array::from(vec![ + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + Some(0), + Some(1), + ]) + .data(); + assert_eq!(data, indexed.data()); + } }