From e35394ecfe5e52c533e220c9e04c647dcb8c9179 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Wed, 4 Nov 2020 17:02:56 +0100 Subject: [PATCH 1/4] Added generic array builder. --- rust/arrow/src/array/mod.rs | 3 + rust/arrow/src/array/transform/boolean.rs | 40 ++ rust/arrow/src/array/transform/list.rs | 75 +++ rust/arrow/src/array/transform/mod.rs | 536 ++++++++++++++++++ rust/arrow/src/array/transform/primitive.rs | 35 ++ rust/arrow/src/array/transform/utils.rs | 63 ++ .../src/array/transform/variable_size.rs | 93 +++ rust/arrow/src/buffer.rs | 9 + 8 files changed, 854 insertions(+) create mode 100644 rust/arrow/src/array/transform/boolean.rs create mode 100644 rust/arrow/src/array/transform/list.rs create mode 100644 rust/arrow/src/array/transform/mod.rs create mode 100644 rust/arrow/src/array/transform/primitive.rs create mode 100644 rust/arrow/src/array/transform/utils.rs create mode 100644 rust/arrow/src/array/transform/variable_size.rs diff --git a/rust/arrow/src/array/mod.rs b/rust/arrow/src/array/mod.rs index d8cfb46449f..dd1fa0f57ee 100644 --- a/rust/arrow/src/array/mod.rs +++ b/rust/arrow/src/array/mod.rs @@ -99,6 +99,7 @@ mod iterator; mod null; mod ord; mod raw_pointer; +mod transform; use crate::datatypes::*; @@ -249,6 +250,8 @@ pub type DurationMillisecondBuilder = PrimitiveBuilder; pub type DurationMicrosecondBuilder = PrimitiveBuilder; pub type DurationNanosecondBuilder = PrimitiveBuilder; +pub use self::transform::MutableArrayData; + // --------------------- Array Iterator --------------------- pub use self::iterator::*; diff --git a/rust/arrow/src/array/transform/boolean.rs b/rust/arrow/src/array/transform/boolean.rs new file mode 100644 index 00000000000..889b99be88e --- /dev/null +++ b/rust/arrow/src/array/transform/boolean.rs @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::array::ArrayData; + +use super::{ + Extend, _MutableArrayData, + utils::{reserve_for_bits, set_bits}, +}; + +pub(super) fn build_extend(array: &ArrayData) -> Extend { + let values = array.buffers()[0].data(); + Box::new( + move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + let buffer = &mut mutable.buffers[0]; + reserve_for_bits(buffer, mutable.len + len); + set_bits( + &mut buffer.data_mut(), + values, + mutable.len, + array.offset() + start, + len, + ); + }, + ) +} diff --git a/rust/arrow/src/array/transform/list.rs b/rust/arrow/src/array/transform/list.rs new file mode 100644 index 00000000000..ff4df854643 --- /dev/null +++ b/rust/arrow/src/array/transform/list.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ + array::{ArrayData, OffsetSizeTrait}, + datatypes::ToByteSlice, +}; + +use super::{Extend, _MutableArrayData, utils::extend_offsets}; + +pub(super) fn build_extend(array: &ArrayData) -> Extend { + let offsets = array.buffer::(0); + if array.null_count() == 0 { + // fast case where we can copy regions without nullability checks + Box::new( + move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + let mutable_offsets = mutable.buffer::(0); + let last_offset = mutable_offsets[mutable_offsets.len() - 1]; + // offsets + extend_offsets::( + &mut mutable.buffers[0], + last_offset, + &offsets[start..start + len + 1], + ); + + mutable.child_data[0].extend( + offsets[start].to_usize().unwrap(), + offsets[start + len].to_usize().unwrap(), + ) + }, + ) + } else { + // nulls present: append item by item, ignoring null entries + Box::new( + move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + let mutable_offsets = mutable.buffer::(0); + let mut last_offset = mutable_offsets[mutable_offsets.len() - 1]; + + let buffer = &mut mutable.buffers[0]; + let delta_len = array.len() - array.null_count(); + buffer.reserve(buffer.len() + delta_len * std::mem::size_of::()); + + let child = &mut mutable.child_data[0]; + (start..start + len).for_each(|i| { + if array.is_valid(i) { + // compute the new offset + last_offset = last_offset + offsets[i + 1] - offsets[i]; + + // append value + child.extend( + offsets[i].to_usize().unwrap(), + offsets[i + 1].to_usize().unwrap(), + ); + } + // append offset + 
buffer.extend_from_slice(last_offset.to_byte_slice()); + }) + }, + ) + } +} diff --git a/rust/arrow/src/array/transform/mod.rs b/rust/arrow/src/array/transform/mod.rs new file mode 100644 index 00000000000..96362ab0284 --- /dev/null +++ b/rust/arrow/src/array/transform/mod.rs @@ -0,0 +1,536 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{mem::size_of, sync::Arc}; + +use crate::{buffer::MutableBuffer, datatypes::DataType, util::bit_util}; + +use super::{ArrayData, ArrayDataRef}; + +mod boolean; +mod list; +mod primitive; +mod utils; +mod variable_size; + +type ExtendNullBits<'a> = Box; +// function that extends `[start..start+len]` to the mutable array. +// this is dynamic because different data_types influence how buffers and childs are extended. +type Extend<'a> = Box; + +/// A mutable [ArrayData] that knows how to freeze itself into an [ArrayData]. +/// This is just a data container. 
+#[derive(Debug)] +struct _MutableArrayData<'a> { + pub data_type: DataType, + pub null_count: usize, + + pub len: usize, + pub null_buffer: MutableBuffer, + + pub buffers: Vec, + pub child_data: Vec>, +} + +impl<'a> _MutableArrayData<'a> { + fn freeze(self, dictionary: Option) -> ArrayData { + let mut buffers = Vec::with_capacity(self.buffers.len()); + for buffer in self.buffers { + buffers.push(buffer.freeze()); + } + + let child_data = match self.data_type { + DataType::Dictionary(_, _) => vec![dictionary.unwrap()], + _ => { + let mut child_data = Vec::with_capacity(self.child_data.len()); + for child in self.child_data { + child_data.push(Arc::new(child.freeze())); + } + child_data + } + }; + ArrayData::new( + self.data_type, + self.len, + Some(self.null_count), + if self.null_count > 0 { + Some(self.null_buffer.freeze()) + } else { + None + }, + 0, + buffers, + child_data, + ) + } + + /// Returns the buffer `buffer` as a slice of type `T`. When the expected buffer is bit-packed, + /// the slice is not offset. + #[inline] + pub(super) fn buffer(&self, buffer: usize) -> &[T] { + let values = unsafe { self.buffers[buffer].data().align_to::() }; + if !values.0.is_empty() || !values.2.is_empty() { + // this is unreachable because + unreachable!("The buffer is not byte-aligned with its interpretation") + }; + &values.1 + } +} + +fn build_extend_nulls(array: &ArrayData) -> ExtendNullBits { + if let Some(bitmap) = array.null_bitmap() { + let bytes = bitmap.bits.data(); + Box::new(move |mutable, start, len| { + utils::reserve_for_bits(&mut mutable.null_buffer, mutable.len + len); + mutable.null_count += utils::set_bits( + mutable.null_buffer.data_mut(), + bytes, + mutable.len, + array.offset() + start, + len, + ); + }) + } else { + Box::new(|_, _, _| {}) + } +} + +/// Struct to efficiently and interactively create an [ArrayData] from an existing [ArrayData] by +/// copying chunks. 
+/// The main use case of this struct is to perform unary operations to arrays of arbitrary types, such as `filter` and `take`. +/// # Example: +/// +/// ``` +/// use std::sync::Arc; +/// use arrow::{array::{Int32Array, Array, MutableArrayData}}; +/// +/// let array = Int32Array::from(vec![1, 2, 3, 4, 5]).data(); +/// // Create a new `MutableArrayData` from an array and with a capacity. +/// // Capacity here is equivalent to `Vec::with_capacity` +/// let mut mutable = MutableArrayData::new(&array, 4); +/// mutable.extend(1, 3); // extend from the slice [1..3], [2,3] +/// mutable.extend(0, 3); // extend from the slice [0..3], [1,2,3] +/// // `.freeze()` to convert `MutableArrayData` into a `ArrayData`. +/// let new_array = Int32Array::from(Arc::new(mutable.freeze())); +/// assert_eq!(Int32Array::from(vec![2, 3, 1, 2, 3]), new_array); +/// ``` +pub struct MutableArrayData<'a> { + // The attributes in [_MutableArrayData] cannot be in [MutableArrayData] due to + // mutability invariants (interior mutability): + // [MutableArrayData] contains a function that can only mutate [_MutableArrayData], not + // [MutableArrayData] itself + data: _MutableArrayData<'a>, + + // the child data of the `Array` in Dictionary arrays. + // This is not stored in `MutableArrayData` because these values constant and only needed + // at the end, when freezing [_MutableArrayData]. + dictionary: Option, + + // the function used to extend values. This function's lifetime is bound to the array + // because it reads values from it. + extend_values: Extend<'a>, + // the function used to extend nulls. This function's lifetime is bound to the array + // because it reads nulls from it. + extend_nulls: ExtendNullBits<'a>, +} + +impl<'a> std::fmt::Debug for MutableArrayData<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + // ignores the closures. 
+ f.debug_struct("MutableArrayData") + .field("data", &self.data) + .finish() + } +} + +impl<'a> MutableArrayData<'a> { + /// returns a new [MutableArrayData] with capacity to `capacity` slots and specialized to create an + /// [ArrayData] from `array` + pub fn new(array: &'a ArrayData, capacity: usize) -> Self { + let data_type = array.data_type(); + use crate::datatypes::*; + let extend_values = match &data_type { + DataType::Boolean => boolean::build_extend(array), + DataType::UInt8 => primitive::build_extend::(array), + DataType::UInt16 => primitive::build_extend::(array), + DataType::UInt32 => primitive::build_extend::(array), + DataType::UInt64 => primitive::build_extend::(array), + DataType::Int8 => primitive::build_extend::(array), + DataType::Int16 => primitive::build_extend::(array), + DataType::Int32 => primitive::build_extend::(array), + DataType::Int64 => primitive::build_extend::(array), + DataType::Float32 => primitive::build_extend::(array), + DataType::Float64 => primitive::build_extend::(array), + DataType::Date32(_) + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + primitive::build_extend::(array) + } + DataType::Date64(_) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(IntervalUnit::DayTime) => { + primitive::build_extend::(array) + } + DataType::Utf8 | DataType::Binary => { + variable_size::build_extend::(array) + } + DataType::LargeUtf8 | DataType::LargeBinary => { + variable_size::build_extend::(array) + } + DataType::List(_) => list::build_extend::(array), + DataType::LargeList(_) => list::build_extend::(array), + DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { + DataType::UInt8 => primitive::build_extend::(array), + DataType::UInt16 => primitive::build_extend::(array), + DataType::UInt32 => primitive::build_extend::(array), + DataType::UInt64 => primitive::build_extend::(array), + DataType::Int8 => 
primitive::build_extend::(array), + DataType::Int16 => primitive::build_extend::(array), + DataType::Int32 => primitive::build_extend::(array), + DataType::Int64 => primitive::build_extend::(array), + _ => unreachable!(), + }, + DataType::Float16 => unreachable!(), + /* + DataType::Null => {} + DataType::FixedSizeBinary(_) => {} + DataType::FixedSizeList(_, _) => {} + DataType::Struct(_) => {} + DataType::Union(_) => {} + */ + _ => { + todo!("Take and filter operations still not supported for this datatype") + } + }; + + let buffers = match &data_type { + DataType::Boolean => { + let bytes = bit_util::ceil(capacity, 8); + let buffer = MutableBuffer::new(bytes).with_bitset(bytes, false); + vec![buffer] + } + DataType::UInt8 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt16 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt32 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt64 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int8 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int16 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int32 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int64 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Float32 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Float64 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Date32(_) | DataType::Time32(_) => { + vec![MutableBuffer::new(capacity * size_of::())] + } + DataType::Date64(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Timestamp(_, _) => { + vec![MutableBuffer::new(capacity * size_of::())] + } + DataType::Interval(IntervalUnit::YearMonth) => { + vec![MutableBuffer::new(capacity * size_of::())] + } + DataType::Interval(IntervalUnit::DayTime) => { + vec![MutableBuffer::new(capacity * size_of::())] + } + DataType::Utf8 | DataType::Binary => { + let mut buffer = MutableBuffer::new((1 + 
capacity) * size_of::()); + buffer.extend_from_slice(&[0i32].to_byte_slice()); + vec![buffer, MutableBuffer::new(capacity * size_of::())] + } + DataType::LargeUtf8 | DataType::LargeBinary => { + let mut buffer = MutableBuffer::new((1 + capacity) * size_of::()); + buffer.extend_from_slice(&[0i64].to_byte_slice()); + vec![buffer, MutableBuffer::new(capacity * size_of::())] + } + DataType::List(_) => { + // offset buffer always starts with a zero + let mut buffer = MutableBuffer::new((1 + capacity) * size_of::()); + buffer.extend_from_slice(0i32.to_byte_slice()); + vec![buffer] + } + DataType::LargeList(_) => { + // offset buffer always starts with a zero + let mut buffer = MutableBuffer::new((1 + capacity) * size_of::()); + buffer.extend_from_slice(&[0i64].to_byte_slice()); + vec![buffer] + } + DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { + DataType::UInt8 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt16 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt32 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::UInt64 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int8 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int16 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int32 => vec![MutableBuffer::new(capacity * size_of::())], + DataType::Int64 => vec![MutableBuffer::new(capacity * size_of::())], + _ => unreachable!(), + }, + DataType::Float16 => unreachable!(), + _ => { + todo!("Take and filter operations still not supported for this datatype") + } + }; + + let child_data = match &data_type { + DataType::Null + | DataType::Boolean + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32(_) + | DataType::Date64(_) + | DataType::Time32(_) + | DataType::Time64(_) + | 
DataType::Duration(_) + | DataType::Timestamp(_, _) + | DataType::Utf8 + | DataType::Binary + | DataType::LargeUtf8 + | DataType::LargeBinary + | DataType::Interval(_) + | DataType::FixedSizeBinary(_) => vec![], + DataType::List(_) | DataType::LargeList(_) => { + vec![MutableArrayData::new(&array.child_data()[0], capacity)] + } + // the dictionary type just appends keys and clones the values. + DataType::Dictionary(_, _) => vec![], + DataType::Float16 => unreachable!(), + _ => { + todo!("Take and filter operations still not supported for this datatype") + } + }; + + let dictionary = match &data_type { + DataType::Dictionary(_, _) => Some(array.child_data()[0].clone()), + _ => None, + }; + + let extend_nulls = build_extend_nulls(array); + + let null_bytes = bit_util::ceil(capacity, 8); + let null_buffer = MutableBuffer::new(null_bytes).with_bitset(null_bytes, false); + + let data = _MutableArrayData { + data_type: data_type.clone(), + len: 0, + null_count: 0, + null_buffer, + buffers, + child_data, + }; + Self { + data, + dictionary, + extend_values: Box::new(extend_values), + extend_nulls, + } + } + + /// Extends this [MutableArrayData] with elements from the bounded [ArrayData] at `start` + /// and for a size of `len`. + /// # Panic + /// This function panics if the range is out of bounds, i.e. if `start + len >= array.len()`. + pub fn extend(&mut self, start: usize, end: usize) { + let len = end - start; + (self.extend_nulls)(&mut self.data, start, len); + (self.extend_values)(&mut self.data, start, len); + self.data.len += len; + } + + /// Creates a [ArrayData] from the pushed regions up to this point, consuming `self`. 
+ pub fn freeze(self) -> ArrayData { + self.data.freeze(self.dictionary) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::array::{ + Array, ArrayDataRef, BooleanArray, DictionaryArray, Int16Array, Int16Type, + Int64Builder, ListBuilder, PrimitiveBuilder, StringArray, + StringDictionaryBuilder, UInt8Array, + }; + use crate::{array::ListArray, error::Result}; + + /// tests extending from a primitive array w/ offset nor nulls + #[test] + fn test_primitive() { + let b = UInt8Array::from(vec![Some(1), Some(2), Some(3)]).data(); + let mut a = MutableArrayData::new(&b, 3); + a.extend(0, 2); + let result = a.freeze(); + let array = UInt8Array::from(Arc::new(result)); + let expected = UInt8Array::from(vec![Some(1), Some(2)]); + assert_eq!(array, expected); + } + + /// tests extending from a primitive array with offset w/ nulls + #[test] + fn test_primitive_offset() { + let b = UInt8Array::from(vec![Some(1), Some(2), Some(3)]); + let b = b.slice(1, 2).data(); + let mut a = MutableArrayData::new(&b, 2); + a.extend(0, 2); + let result = a.freeze(); + let array = UInt8Array::from(Arc::new(result)); + let expected = UInt8Array::from(vec![Some(2), Some(3)]); + assert_eq!(array, expected); + } + + /// tests extending from a primitive array with offset and nulls + #[test] + fn test_primitive_null_offset() { + let b = UInt8Array::from(vec![Some(1), None, Some(3)]); + let b = b.slice(1, 2).data(); + let mut a = MutableArrayData::new(&b, 2); + a.extend(0, 2); + let result = a.freeze(); + let array = UInt8Array::from(Arc::new(result)); + let expected = UInt8Array::from(vec![None, Some(3)]); + assert_eq!(array, expected); + } + + #[test] + fn test_list_null_offset() -> Result<()> { + let int_builder = Int64Builder::new(24); + let mut builder = ListBuilder::::new(int_builder); + builder.values().append_slice(&[1, 2, 3])?; + builder.append(true)?; + builder.values().append_slice(&[4, 5])?; + builder.append(true)?; + builder.values().append_slice(&[6, 7, 8])?; + 
builder.append(true)?; + let array = builder.finish().data(); + + let mut mutable = MutableArrayData::new(&array, 0); + mutable.extend(0, 1); + + let result = mutable.freeze(); + let array = ListArray::from(Arc::new(result)); + + let int_builder = Int64Builder::new(24); + let mut builder = ListBuilder::::new(int_builder); + builder.values().append_slice(&[1, 2, 3])?; + builder.append(true)?; + let expected = builder.finish(); + + assert_eq!(array, expected); + + Ok(()) + } + + /// tests extending from a variable-sized (strings and binary) array w/ offset with nulls + #[test] + fn test_variable_sized_nulls() { + let array = + StringArray::from(vec![Some("a"), Some("bc"), None, Some("defh")]).data(); + + let mut mutable = MutableArrayData::new(&array, 0); + + mutable.extend(1, 3); + + let result = mutable.freeze(); + let result = StringArray::from(Arc::new(result)); + + let expected = StringArray::from(vec![Some("bc"), None]); + assert_eq!(result, expected); + } + + /// tests extending from a variable-sized (strings and binary) array + /// with an offset and nulls + #[test] + fn test_variable_sized_offsets() { + let array = + StringArray::from(vec![Some("a"), Some("bc"), None, Some("defh")]).data(); + let array = array.slice(1, 3); + + let mut mutable = MutableArrayData::new(&array, 0); + + mutable.extend(0, 3); + + let result = mutable.freeze(); + let result = StringArray::from(Arc::new(result)); + + let expected = StringArray::from(vec![Some("bc"), None, Some("defh")]); + assert_eq!(result, expected); + } + + #[test] + fn test_bool() { + let array = + BooleanArray::from(vec![Some(false), Some(true), None, Some(false)]).data(); + + let mut mutable = MutableArrayData::new(&array, 0); + + mutable.extend(1, 3); + + let result = mutable.freeze(); + let result = BooleanArray::from(Arc::new(result)); + + let expected = BooleanArray::from(vec![Some(true), None]); + assert_eq!(result, expected); + } + + fn create_dictionary_array(values: &[&str], keys: &[Option<&str>]) -> 
ArrayDataRef { + let values = StringArray::from(values.to_vec()); + let mut builder = StringDictionaryBuilder::new_with_dictionary( + PrimitiveBuilder::::new(3), + &values, + ) + .unwrap(); + for key in keys { + if let Some(v) = key { + builder.append(v).unwrap(); + } else { + builder.append_null().unwrap() + } + } + builder.finish().data() + } + + #[test] + fn test_dictionary() { + // (a, b, c), (0, 1, 0, 2) => (a, b, a, c) + let array = create_dictionary_array( + &["a", "b", "c"], + &[Some("a"), Some("b"), None, Some("c")], + ); + + let mut mutable = MutableArrayData::new(&array, 0); + + mutable.extend(1, 3); + + let result = mutable.freeze(); + let result = DictionaryArray::from(Arc::new(result)); + + let expected = Int16Array::from(vec![Some(1), None]); + assert_eq!(result.keys(), &expected); + } +} diff --git a/rust/arrow/src/array/transform/primitive.rs b/rust/arrow/src/array/transform/primitive.rs new file mode 100644 index 00000000000..d2b44f28d42 --- /dev/null +++ b/rust/arrow/src/array/transform/primitive.rs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::mem::size_of; + +use crate::{array::ArrayData, datatypes::ArrowNativeType}; + +use super::{Extend, _MutableArrayData}; + +pub(super) fn build_extend(array: &ArrayData) -> Extend { + let values = &array.buffers()[0].data()[array.offset() * size_of::()..]; + Box::new( + move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + let start = start * size_of::(); + let len = len * size_of::(); + let bytes = &values[start..start + len]; + let buffer = &mut mutable.buffers[0]; + buffer.extend_from_slice(bytes); + }, + ) +} diff --git a/rust/arrow/src/array/transform/utils.rs b/rust/arrow/src/array/transform/utils.rs new file mode 100644 index 00000000000..df9ce2453be --- /dev/null +++ b/rust/arrow/src/array/transform/utils.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ + array::OffsetSizeTrait, buffer::MutableBuffer, datatypes::ToByteSlice, util::bit_util, +}; + +/// extends the `buffer` to be able to hold `len` bits, setting all bits of the new size to zero. 
+#[inline] +pub(super) fn reserve_for_bits(buffer: &mut MutableBuffer, len: usize) { + let needed_bytes = bit_util::ceil(len, 8); + if buffer.len() < needed_bytes { + buffer.extend(needed_bytes - buffer.len()); + } +} + +/// sets all bits on `write_data` on the range `[offset_write..offset_write+len]` to be equal to the +/// bits on `data` on the range `[offset_read..offset_read+len]` +pub(super) fn set_bits( + write_data: &mut [u8], + data: &[u8], + offset_write: usize, + offset_read: usize, + len: usize, +) -> usize { + let mut count = 0; + (0..len).for_each(|i| { + if bit_util::get_bit(data, offset_read + i) { + bit_util::set_bit(write_data, offset_write + i); + } else { + count += 1; + } + }); + count +} + +pub(super) fn extend_offsets( + buffer: &mut MutableBuffer, + mut last_offset: T, + offsets: &[T], +) { + buffer.reserve(buffer.len() + offsets.len() * std::mem::size_of::()); + offsets.windows(2).for_each(|offsets| { + // compute the new offset + let length = offsets[1] - offsets[0]; + last_offset = last_offset + length; + buffer.extend_from_slice(last_offset.to_byte_slice()); + }); +} diff --git a/rust/arrow/src/array/transform/variable_size.rs b/rust/arrow/src/array/transform/variable_size.rs new file mode 100644 index 00000000000..6e7c80a97e1 --- /dev/null +++ b/rust/arrow/src/array/transform/variable_size.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ + array::{ArrayData, OffsetSizeTrait}, + buffer::MutableBuffer, + datatypes::ToByteSlice, +}; + +use super::{Extend, _MutableArrayData, utils::extend_offsets}; + +fn extend_offset_values( + buffer: &mut MutableBuffer, + offsets: &[T], + values: &[u8], + start: usize, + len: usize, +) { + let start_values = offsets[start].to_usize().unwrap(); + let end_values = offsets[start + len].to_usize().unwrap(); + let new_values = &values[start_values..end_values]; + buffer.extend_from_slice(new_values); +} + +pub(super) fn build_extend(array: &ArrayData) -> Extend { + let offsets = array.buffer::(0); + let values = &array.buffers()[1].data()[array.offset()..]; + if array.null_count() == 0 { + // fast case where we can copy regions without null issues + Box::new( + move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + let mutable_offsets = mutable.buffer::(0); + let last_offset = mutable_offsets[mutable_offsets.len() - 1]; + // offsets + let buffer = &mut mutable.buffers[0]; + extend_offsets::( + buffer, + last_offset, + &offsets[start..start + len + 1], + ); + // values + let buffer = &mut mutable.buffers[1]; + extend_offset_values::(buffer, offsets, values, start, len); + }, + ) + } else { + Box::new( + move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + let mutable_offsets = mutable.buffer::(0); + let mut last_offset = mutable_offsets[mutable_offsets.len() - 1]; + + // nulls present: append item by item, ignoring null entries + let (offset_buffer, values_buffer) = mutable.buffers.split_at_mut(1); + let 
offset_buffer = &mut offset_buffer[0]; + let values_buffer = &mut values_buffer[0]; + offset_buffer.reserve( + offset_buffer.len() + array.len() * std::mem::size_of::(), + ); + + (start..start + len).for_each(|i| { + if array.is_valid(i) { + // compute the new offset + let length = offsets[i + 1] - offsets[i]; + last_offset = last_offset + length; + let length = length.to_usize().unwrap(); + + // append value + let start = offsets[i].to_usize().unwrap() + - offsets[0].to_usize().unwrap(); + let bytes = &values[start..(start + length)]; + values_buffer.extend_from_slice(bytes); + } + // offsets are always present + offset_buffer.extend_from_slice(last_offset.to_byte_slice()); + }) + }, + ) + } +} diff --git a/rust/arrow/src/buffer.rs b/rust/arrow/src/buffer.rs index ece811b742f..0e9d481b79d 100644 --- a/rust/arrow/src/buffer.rs +++ b/rust/arrow/src/buffer.rs @@ -888,6 +888,15 @@ impl MutableBuffer { } self.len += bytes.len(); } + + /// Extends the buffer by `len` with all bytes equal to `0u8`, incrementing its capacity if needed. + pub fn extend(&mut self, len: usize) { + let remaining_capacity = self.capacity - self.len; + if len > remaining_capacity { + self.reserve(self.len + len); + } + self.len += len; + } } impl Drop for MutableBuffer { From 3776085d5ef7a53a817c5bcee11f39d39a2711b4 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Wed, 18 Nov 2020 21:06:36 +0100 Subject: [PATCH 2/4] Generalized mutableArrayData for multiple arrays. 
--- rust/arrow/src/array/transform/boolean.rs | 2 +- rust/arrow/src/array/transform/list.rs | 12 +- rust/arrow/src/array/transform/mod.rs | 185 ++++++++++-------- rust/arrow/src/array/transform/primitive.rs | 2 +- .../src/array/transform/variable_size.rs | 4 +- 5 files changed, 117 insertions(+), 88 deletions(-) diff --git a/rust/arrow/src/array/transform/boolean.rs b/rust/arrow/src/array/transform/boolean.rs index 889b99be88e..31634f4bf8b 100644 --- a/rust/arrow/src/array/transform/boolean.rs +++ b/rust/arrow/src/array/transform/boolean.rs @@ -25,7 +25,7 @@ use super::{ pub(super) fn build_extend(array: &ArrayData) -> Extend { let values = array.buffers()[0].data(); Box::new( - move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { let buffer = &mut mutable.buffers[0]; reserve_for_bits(buffer, mutable.len + len); set_bits( diff --git a/rust/arrow/src/array/transform/list.rs b/rust/arrow/src/array/transform/list.rs index ff4df854643..8a8ccdf631a 100644 --- a/rust/arrow/src/array/transform/list.rs +++ b/rust/arrow/src/array/transform/list.rs @@ -27,7 +27,10 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { if array.null_count() == 0 { // fast case where we can copy regions without nullability checks Box::new( - move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + move |mutable: &mut _MutableArrayData, + index: usize, + start: usize, + len: usize| { let mutable_offsets = mutable.buffer::(0); let last_offset = mutable_offsets[mutable_offsets.len() - 1]; // offsets @@ -38,6 +41,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { ); mutable.child_data[0].extend( + index, offsets[start].to_usize().unwrap(), offsets[start + len].to_usize().unwrap(), ) @@ -46,7 +50,10 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { } else { // nulls present: append item by item, ignoring null entries Box::new( - move |mutable: &mut _MutableArrayData, start: 
usize, len: usize| { + move |mutable: &mut _MutableArrayData, + index: usize, + start: usize, + len: usize| { let mutable_offsets = mutable.buffer::(0); let mut last_offset = mutable_offsets[mutable_offsets.len() - 1]; @@ -62,6 +69,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { // append value child.extend( + index, offsets[i].to_usize().unwrap(), offsets[i + 1].to_usize().unwrap(), ); diff --git a/rust/arrow/src/array/transform/mod.rs b/rust/arrow/src/array/transform/mod.rs index 96362ab0284..410ca369f67 100644 --- a/rust/arrow/src/array/transform/mod.rs +++ b/rust/arrow/src/array/transform/mod.rs @@ -30,7 +30,7 @@ mod variable_size; type ExtendNullBits<'a> = Box; // function that extends `[start..start+len]` to the mutable array. // this is dynamic because different data_types influence how buffers and childs are extended. -type Extend<'a> = Box; +type Extend<'a> = Box; /// A mutable [ArrayData] that knows how to freeze itself into an [ArrayData]. /// This is just a data container. @@ -121,14 +121,16 @@ fn build_extend_nulls(array: &ArrayData) -> ExtendNullBits { /// let array = Int32Array::from(vec![1, 2, 3, 4, 5]).data(); /// // Create a new `MutableArrayData` from an array and with a capacity. /// // Capacity here is equivalent to `Vec::with_capacity` -/// let mut mutable = MutableArrayData::new(&array, 4); -/// mutable.extend(1, 3); // extend from the slice [1..3], [2,3] -/// mutable.extend(0, 3); // extend from the slice [0..3], [1,2,3] +/// let arrays = vec![array.as_ref()]; +/// let mut mutable = MutableArrayData::new(arrays, 4); +/// mutable.extend(0, 1, 3); // extend from the slice [1..3], [2,3] +/// mutable.extend(0, 0, 3); // extend from the slice [0..3], [1,2,3] /// // `.freeze()` to convert `MutableArrayData` into a `ArrayData`. 
/// let new_array = Int32Array::from(Arc::new(mutable.freeze())); /// assert_eq!(Int32Array::from(vec![2, 3, 1, 2, 3]), new_array); /// ``` pub struct MutableArrayData<'a> { + arrays: Vec<&'a ArrayData>, // The attributes in [_MutableArrayData] cannot be in [MutableArrayData] due to // mutability invariants (interior mutability): // [MutableArrayData] contains a function that can only mutate [_MutableArrayData], not @@ -142,10 +144,10 @@ pub struct MutableArrayData<'a> { // the function used to extend values. This function's lifetime is bound to the array // because it reads values from it. - extend_values: Extend<'a>, + extend_values: Vec>, // the function used to extend nulls. This function's lifetime is bound to the array // because it reads nulls from it. - extend_nulls: ExtendNullBits<'a>, + extend_nulls: Vec>, } impl<'a> std::fmt::Debug for MutableArrayData<'a> { @@ -157,14 +159,39 @@ impl<'a> std::fmt::Debug for MutableArrayData<'a> { } } -impl<'a> MutableArrayData<'a> { - /// returns a new [MutableArrayData] with capacity to `capacity` slots and specialized to create an - /// [ArrayData] from `array` - pub fn new(array: &'a ArrayData, capacity: usize) -> Self { - let data_type = array.data_type(); - use crate::datatypes::*; - let extend_values = match &data_type { - DataType::Boolean => boolean::build_extend(array), +fn build_extend<'a>(array: &'a ArrayData) -> Extend<'a> { + use crate::datatypes::*; + match array.data_type() { + DataType::Boolean => boolean::build_extend(array), + DataType::UInt8 => primitive::build_extend::(array), + DataType::UInt16 => primitive::build_extend::(array), + DataType::UInt32 => primitive::build_extend::(array), + DataType::UInt64 => primitive::build_extend::(array), + DataType::Int8 => primitive::build_extend::(array), + DataType::Int16 => primitive::build_extend::(array), + DataType::Int32 => primitive::build_extend::(array), + DataType::Int64 => primitive::build_extend::(array), + DataType::Float32 => 
primitive::build_extend::(array), + DataType::Float64 => primitive::build_extend::(array), + DataType::Date32(_) + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + primitive::build_extend::(array) + } + DataType::Date64(_) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(IntervalUnit::DayTime) => { + primitive::build_extend::(array) + } + DataType::Utf8 | DataType::Binary => variable_size::build_extend::(array), + DataType::LargeUtf8 | DataType::LargeBinary => { + variable_size::build_extend::(array) + } + DataType::List(_) => list::build_extend::(array), + DataType::LargeList(_) => list::build_extend::(array), + DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { DataType::UInt8 => primitive::build_extend::(array), DataType::UInt16 => primitive::build_extend::(array), DataType::UInt32 => primitive::build_extend::(array), @@ -173,51 +200,26 @@ impl<'a> MutableArrayData<'a> { DataType::Int16 => primitive::build_extend::(array), DataType::Int32 => primitive::build_extend::(array), DataType::Int64 => primitive::build_extend::(array), - DataType::Float32 => primitive::build_extend::(array), - DataType::Float64 => primitive::build_extend::(array), - DataType::Date32(_) - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { - primitive::build_extend::(array) - } - DataType::Date64(_) - | DataType::Time64(_) - | DataType::Timestamp(_, _) - | DataType::Duration(_) - | DataType::Interval(IntervalUnit::DayTime) => { - primitive::build_extend::(array) - } - DataType::Utf8 | DataType::Binary => { - variable_size::build_extend::(array) - } - DataType::LargeUtf8 | DataType::LargeBinary => { - variable_size::build_extend::(array) - } - DataType::List(_) => list::build_extend::(array), - DataType::LargeList(_) => list::build_extend::(array), - DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { - DataType::UInt8 => 
primitive::build_extend::(array), - DataType::UInt16 => primitive::build_extend::(array), - DataType::UInt32 => primitive::build_extend::(array), - DataType::UInt64 => primitive::build_extend::(array), - DataType::Int8 => primitive::build_extend::(array), - DataType::Int16 => primitive::build_extend::(array), - DataType::Int32 => primitive::build_extend::(array), - DataType::Int64 => primitive::build_extend::(array), - _ => unreachable!(), - }, - DataType::Float16 => unreachable!(), - /* - DataType::Null => {} - DataType::FixedSizeBinary(_) => {} - DataType::FixedSizeList(_, _) => {} - DataType::Struct(_) => {} - DataType::Union(_) => {} - */ - _ => { - todo!("Take and filter operations still not supported for this datatype") - } - }; + _ => unreachable!(), + }, + DataType::Float16 => unreachable!(), + /* + DataType::Null => {} + DataType::FixedSizeBinary(_) => {} + DataType::FixedSizeList(_, _) => {} + DataType::Struct(_) => {} + DataType::Union(_) => {} + */ + _ => todo!("Take and filter operations still not supported for this datatype"), + } +} + +impl<'a> MutableArrayData<'a> { + /// returns a new [MutableArrayData] with capacity to `capacity` slots and specialized to create an + /// [ArrayData] from `array` + pub fn new(arrays: Vec<&'a ArrayData>, capacity: usize) -> Self { + let data_type = arrays[0].data_type(); + use crate::datatypes::*; let buffers = match &data_type { DataType::Boolean => { @@ -315,7 +317,11 @@ impl<'a> MutableArrayData<'a> { | DataType::Interval(_) | DataType::FixedSizeBinary(_) => vec![], DataType::List(_) | DataType::LargeList(_) => { - vec![MutableArrayData::new(&array.child_data()[0], capacity)] + let childs = arrays + .iter() + .map(|array| array.child_data()[0].as_ref()) + .collect::>(); + vec![MutableArrayData::new(childs, capacity)] } // the dictionary type just appends keys and clones the values. 
DataType::Dictionary(_, _) => vec![], @@ -326,15 +332,20 @@ impl<'a> MutableArrayData<'a> { }; let dictionary = match &data_type { - DataType::Dictionary(_, _) => Some(array.child_data()[0].clone()), + DataType::Dictionary(_, _) => Some(arrays[0].child_data()[0].clone()), _ => None, }; - let extend_nulls = build_extend_nulls(array); + let extend_nulls = arrays + .iter() + .map(|array| build_extend_nulls(array)) + .collect(); let null_bytes = bit_util::ceil(capacity, 8); let null_buffer = MutableBuffer::new(null_bytes).with_bitset(null_bytes, false); + let extend_values = arrays.iter().map(|array| build_extend(array)).collect(); + let data = _MutableArrayData { data_type: data_type.clone(), len: 0, @@ -344,9 +355,10 @@ impl<'a> MutableArrayData<'a> { child_data, }; Self { + arrays: arrays.to_vec(), data, dictionary, - extend_values: Box::new(extend_values), + extend_values, extend_nulls, } } @@ -355,10 +367,10 @@ impl<'a> MutableArrayData<'a> { /// and for a size of `len`. /// # Panic /// This function panics if the range is out of bounds, i.e. if `start + len >= array.len()`. 
- pub fn extend(&mut self, start: usize, end: usize) { + pub fn extend(&mut self, index: usize, start: usize, end: usize) { let len = end - start; - (self.extend_nulls)(&mut self.data, start, len); - (self.extend_values)(&mut self.data, start, len); + (self.extend_nulls[index])(&mut self.data, start, len); + (self.extend_values[index])(&mut self.data, index, start, len); self.data.len += len; } @@ -383,8 +395,9 @@ mod tests { #[test] fn test_primitive() { let b = UInt8Array::from(vec![Some(1), Some(2), Some(3)]).data(); - let mut a = MutableArrayData::new(&b, 3); - a.extend(0, 2); + let arrays = vec![b.as_ref()]; + let mut a = MutableArrayData::new(arrays, 3); + a.extend(0, 0, 2); let result = a.freeze(); let array = UInt8Array::from(Arc::new(result)); let expected = UInt8Array::from(vec![Some(1), Some(2)]); @@ -396,8 +409,9 @@ mod tests { fn test_primitive_offset() { let b = UInt8Array::from(vec![Some(1), Some(2), Some(3)]); let b = b.slice(1, 2).data(); - let mut a = MutableArrayData::new(&b, 2); - a.extend(0, 2); + let arrays = vec![b.as_ref()]; + let mut a = MutableArrayData::new(arrays, 2); + a.extend(0, 0, 2); let result = a.freeze(); let array = UInt8Array::from(Arc::new(result)); let expected = UInt8Array::from(vec![Some(2), Some(3)]); @@ -409,8 +423,9 @@ mod tests { fn test_primitive_null_offset() { let b = UInt8Array::from(vec![Some(1), None, Some(3)]); let b = b.slice(1, 2).data(); - let mut a = MutableArrayData::new(&b, 2); - a.extend(0, 2); + let arrays = vec![b.as_ref()]; + let mut a = MutableArrayData::new(arrays, 2); + a.extend(0, 0, 2); let result = a.freeze(); let array = UInt8Array::from(Arc::new(result)); let expected = UInt8Array::from(vec![None, Some(3)]); @@ -428,9 +443,10 @@ mod tests { builder.values().append_slice(&[6, 7, 8])?; builder.append(true)?; let array = builder.finish().data(); + let arrays = vec![array.as_ref()]; - let mut mutable = MutableArrayData::new(&array, 0); - mutable.extend(0, 1); + let mut mutable = 
MutableArrayData::new(arrays, 0); + mutable.extend(0, 0, 1); let result = mutable.freeze(); let array = ListArray::from(Arc::new(result)); @@ -451,10 +467,11 @@ mod tests { fn test_variable_sized_nulls() { let array = StringArray::from(vec![Some("a"), Some("bc"), None, Some("defh")]).data(); + let arrays = vec![array.as_ref()]; - let mut mutable = MutableArrayData::new(&array, 0); + let mut mutable = MutableArrayData::new(arrays, 0); - mutable.extend(1, 3); + mutable.extend(0, 1, 3); let result = mutable.freeze(); let result = StringArray::from(Arc::new(result)); @@ -471,9 +488,11 @@ mod tests { StringArray::from(vec![Some("a"), Some("bc"), None, Some("defh")]).data(); let array = array.slice(1, 3); - let mut mutable = MutableArrayData::new(&array, 0); + let arrays = vec![&array]; + + let mut mutable = MutableArrayData::new(arrays, 0); - mutable.extend(0, 3); + mutable.extend(0, 0, 3); let result = mutable.freeze(); let result = StringArray::from(Arc::new(result)); @@ -486,10 +505,11 @@ mod tests { fn test_bool() { let array = BooleanArray::from(vec![Some(false), Some(true), None, Some(false)]).data(); + let arrays = vec![array.as_ref()]; - let mut mutable = MutableArrayData::new(&array, 0); + let mut mutable = MutableArrayData::new(arrays, 0); - mutable.extend(1, 3); + mutable.extend(0, 1, 3); let result = mutable.freeze(); let result = BooleanArray::from(Arc::new(result)); @@ -522,10 +542,11 @@ mod tests { &["a", "b", "c"], &[Some("a"), Some("b"), None, Some("c")], ); + let arrays = vec![array.as_ref()]; - let mut mutable = MutableArrayData::new(&array, 0); + let mut mutable = MutableArrayData::new(arrays, 0); - mutable.extend(1, 3); + mutable.extend(0, 1, 3); let result = mutable.freeze(); let result = DictionaryArray::from(Arc::new(result)); diff --git a/rust/arrow/src/array/transform/primitive.rs b/rust/arrow/src/array/transform/primitive.rs index d2b44f28d42..356080221d4 100644 --- a/rust/arrow/src/array/transform/primitive.rs +++ 
b/rust/arrow/src/array/transform/primitive.rs @@ -24,7 +24,7 @@ use super::{Extend, _MutableArrayData}; pub(super) fn build_extend(array: &ArrayData) -> Extend { let values = &array.buffers()[0].data()[array.offset() * size_of::()..]; Box::new( - move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { let start = start * size_of::(); let len = len * size_of::(); let bytes = &values[start..start + len]; diff --git a/rust/arrow/src/array/transform/variable_size.rs b/rust/arrow/src/array/transform/variable_size.rs index 6e7c80a97e1..48e77d53415 100644 --- a/rust/arrow/src/array/transform/variable_size.rs +++ b/rust/arrow/src/array/transform/variable_size.rs @@ -42,7 +42,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { if array.null_count() == 0 { // fast case where we can copy regions without null issues Box::new( - move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { let mutable_offsets = mutable.buffer::(0); let last_offset = mutable_offsets[mutable_offsets.len() - 1]; // offsets @@ -59,7 +59,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { ) } else { Box::new( - move |mutable: &mut _MutableArrayData, start: usize, len: usize| { + move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { let mutable_offsets = mutable.buffer::(0); let mut last_offset = mutable_offsets[mutable_offsets.len() - 1]; From 65d3d23004666e28000a89b4748ebf7f129f3146 Mon Sep 17 00:00:00 2001 From: "Jorge C. 
Leitao" Date: Tue, 17 Nov 2020 06:07:47 +0100 Subject: [PATCH 3/4] Added join --- .../src/physical_plan/hash_aggregate.rs | 2 +- .../datafusion/src/physical_plan/hash_join.rs | 524 ++++++++++++++++++ .../src/physical_plan/hash_utils.rs | 172 ++++++ rust/datafusion/src/physical_plan/mod.rs | 2 + rust/datafusion/src/test/mod.rs | 30 +- 5 files changed, 728 insertions(+), 2 deletions(-) create mode 100644 rust/datafusion/src/physical_plan/hash_join.rs create mode 100644 rust/datafusion/src/physical_plan/hash_utils.rs diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs index cc893b96a4c..57a4a01f0fc 100644 --- a/rust/datafusion/src/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs @@ -721,7 +721,7 @@ fn finalize_aggregation( } /// Create a Vec that can be used as a map key -fn create_key( +pub(crate) fn create_key( group_by_keys: &[ArrayRef], row: usize, vec: &mut Vec, diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs new file mode 100644 index 00000000000..de12d912ba0 --- /dev/null +++ b/rust/datafusion/src/physical_plan/hash_join.rs @@ -0,0 +1,524 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the join plan for executing partitions in parallel and then joining the results +//! into a set of partitions. + +use std::sync::Arc; +use std::{ + any::Any, + collections::{HashMap, HashSet}, +}; + +use async_trait::async_trait; +use futures::{Stream, StreamExt, TryStreamExt}; + +use arrow::array::{make_array, Array, MutableArrayData}; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::error::Result as ArrowResult; +use arrow::record_batch::RecordBatch; + +use super::{expressions::col, hash_aggregate::create_key}; +use super::{ + hash_utils::{build_join_schema, check_join_is_valid, JoinOn, JoinType}, + merge::MergeExec, +}; +use crate::error::{DataFusionError, Result}; + +use super::{ + group_scalar::GroupByScalar, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, +}; + +// An index of (batch, row) uniquely identifying a row in a part. +type Index = (usize, usize); +// A pair (left index, right index) +// Note that while this is currently equal to `Index`, the `JoinIndex` is semantically different +// as a left join may issue None indices, in which case the index of the side without a match is `None` +type JoinIndex = Option<(usize, usize)>; +// Maps ["on" value] -> [list of indices with this key's value] +// E.g. [1, 2] -> [(0, 3), (1, 6), (0, 8)] indicates that (column1, column2) = [1, 2] is true +// for rows 3 and 8 from batch 0 and row 6 from batch 1. +type JoinHashMap = HashMap, Vec>; +type JoinLeftData = (JoinHashMap, Vec); + +/// join execution plan executes partitions in parallel and combines them into a set of +/// partitions. 
+#[derive(Debug)] +pub struct HashJoinExec { + /// left (build) side which gets hashed + left: Arc, + /// right (probe) side which are filtered by the hash table + right: Arc, + /// Set of common columns used to join on + on: Vec<(String, String)>, + /// How the join is performed + join_type: JoinType, + /// The schema once the join is applied + schema: SchemaRef, +} + +impl HashJoinExec { + /// Tries to create a new [HashJoinExec]. + /// # Error + /// This function errors when it is not possible to join the left and right sides on keys `on`. + pub fn try_new( + left: Arc, + right: Arc, + on: &JoinOn, + join_type: &JoinType, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + check_join_is_valid(&left_schema, &right_schema, &on)?; + + let schema = Arc::new(build_join_schema( + &left_schema, + &right_schema, + on, + &join_type, + )); + + let on = on + .iter() + .map(|(l, r)| (l.to_string(), r.to_string())) + .collect(); + + Ok(HashJoinExec { + left, + right, + on, + join_type: join_type.clone(), + schema, + }) + } +} + +#[async_trait] +impl ExecutionPlan for HashJoinExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![self.left.clone(), self.right.clone()] + } + + fn with_new_children( + &self, + children: Vec>, + ) -> Result> { + match children.len() { + 2 => Ok(Arc::new(HashJoinExec::try_new( + children[0].clone(), + children[1].clone(), + &self + .on + .iter() + .map(|(x, y)| (x.as_str(), y.as_str())) + .collect::>(), + &self.join_type, + )?)), + _ => Err(DataFusionError::Internal( + "HashJoinExec wrong number of children".to_string(), + )), + } + } + + fn output_partitioning(&self) -> Partitioning { + self.right.output_partitioning() + } + + async fn execute(&self, partition: usize) -> Result { + // merge all parts into a single stream + // this is currently expensive as we re-compute this for every part from the right + // 
TODO: Fix this issue: we can't share this state across parts on the right. + // We need to change this `execute` to allow sharing state across parts... + let merge = MergeExec::new(self.left.clone()); + let stream = merge.execute(0).await?; + + let on_left = self + .on + .iter() + .map(|on| on.0.clone()) + .collect::>(); + let on_right = self + .on + .iter() + .map(|on| on.1.clone()) + .collect::>(); + + // This operation performs 2 steps at once: + // 1. creates a [JoinHashMap] of all batches from the stream + // 2. stores the batches in a vector. + let initial = (JoinHashMap::new(), Vec::new(), 0); + let left_data = stream + .try_fold(initial, |mut acc, batch| async { + let hash = &mut acc.0; + let values = &mut acc.1; + let index = acc.2; + update_hash(&on_left, &batch, hash, index).unwrap(); + values.push(batch); + acc.2 += 1; + Ok(acc) + }) + .await?; + // we have the batches and the hash map with their keys. We can now create a stream + // over the right that uses this information to issue new batches. 
+ + let stream = self.right.execute(partition).await?; + Ok(Box::pin(HashJoinStream { + schema: self.schema.clone(), + on_right, + join_type: self.join_type.clone(), + left_data: (left_data.0, left_data.1), + right: stream, + })) + } +} + +/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, +/// assuming that the [RecordBatch] corresponds to the `index`th batch +fn update_hash( + on: &HashSet, + batch: &RecordBatch, + hash: &mut JoinHashMap, + index: usize, +) -> Result<()> { + // evaluate the keys + let keys_values = on + .iter() + .map(|name| Ok(col(name).evaluate(batch)?.into_array(batch))) + .collect::>>()?; + + let mut key = Vec::with_capacity(keys_values.len()); + for _ in 0..keys_values.len() { + key.push(GroupByScalar::UInt32(0)); + } + + // update the hash map + for row in 0..batch.num_rows() { + create_key(&keys_values, row, &mut key)?; + match hash.get_mut(&key) { + Some(v) => v.push((index, row)), + None => { + hash.insert(key.clone(), vec![(index, row)]); + } + }; + } + Ok(()) +} + +/// A stream that issues [RecordBatch]es as they arrive from the right of the join. +struct HashJoinStream { + /// Input schema + schema: Arc, + /// columns from the right used to compute the hash + on_right: HashSet, + /// type of the join + join_type: JoinType, + /// information from the left + left_data: JoinLeftData, + /// right + right: SendableRecordBatchStream, +} + +impl RecordBatchStream for HashJoinStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +/// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`. +/// The resulting batch has [Schema] `schema`. +/// # Error +/// This function errors when: +/// * a column in `schema` is found in neither `left` nor `right` +fn build_batch_from_indices( + schema: &Schema, + left: &Vec, + right: &RecordBatch, + indices: &[(JoinIndex, JoinIndex)], +) -> ArrowResult { + if left.is_empty() { + todo!("Create empty record batch"); + } + // this is just for symmetry of the code below. 
+ let right = vec![right.clone()]; + + // build the columns of the new [RecordBatch]: + // 1. pick whether the column is from the left or right + // 2. based on the pick, `take` items from the different recordBatches + let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); + for field in schema.fields() { + // pick the column (left or right) based on the field name. + // Note that we take `.data()` to gather the [ArrayData] of each array. + let (is_left, arrays) = match left[0].schema().index_of(field.name()) { + Ok(i) => Ok((true, left.iter().map(|batch| batch.column(i).data()).collect::>())), + Err(_) => { + match right[0].schema().index_of(field.name()) { + Ok(i) => Ok((false, right.iter().map(|batch| batch.column(i).data()).collect::>())), + _ => Err(DataFusionError::Internal( + format!("During execution, the column {} was not found in neither the left or right side of the join", field.name()).to_string() + )) + } + } + }.map_err(DataFusionError::into_arrow_external_error)?; + + // create a vector of references to be passed to [MutableArrayData] + let arrays = arrays + .iter() + .map(|array| array.as_ref()) + .collect::>(); + let capacity = arrays.iter().map(|array| array.len()).sum(); + let mut mutable = MutableArrayData::new(arrays, capacity); + + let array = if is_left { + // build the array using the left + for (join_index, _) in indices { + match join_index { + Some((batch, row)) => mutable.extend(*batch, *row, *row + 1), + // something like `mutable.extend_nulls(*row, *row + 1)` + None => unimplemented!(), + } + } + make_array(Arc::new(mutable.freeze())) + } else { + // build the array using the right + for (_, join_index) in indices { + match join_index { + Some((batch, row)) => mutable.extend(*batch, *row, *row + 1), + // something like `mutable.extend_nulls(*row, *row + 1)` + None => unimplemented!(), + } + } + make_array(Arc::new(mutable.freeze())) + }; + columns.push(array); + } + Ok(RecordBatch::try_new(Arc::new(schema.clone()), 
columns)?) +} + +fn build_batch( + batch: &RecordBatch, + left_data: &JoinLeftData, + on_right: &HashSet, + join_type: &JoinType, + schema: &Schema, +) -> ArrowResult { + let mut right_hash = JoinHashMap::with_capacity(batch.num_rows()); + update_hash(on_right, batch, &mut right_hash, 0).unwrap(); + + let indices = build_join_indexes(&left_data.0, &right_hash, join_type).unwrap(); + + build_batch_from_indices(schema, &left_data.1, &batch, &indices) +} + +/// returns a vector with (index from left, index from right). +/// The size of this vector corresponds to the total size of a joined batch +// For a join on column A: +// left right +// batch 1 +// A B A D +// --------------- +// 1 a 3 6 +// 2 b 1 2 +// 3 c 2 4 +// batch 2 +// A B A D +// --------------- +// 1 a 5 10 +// 2 b 2 2 +// 4 d 1 1 +// indices (batch, batch_row) +// left right +// (0, 2) (0, 0) +// (0, 0) (0, 1) +// (0, 1) (0, 2) +// (1, 0) (0, 1) +// (1, 1) (0, 2) +// (0, 1) (1, 1) +// (0, 0) (1, 2) +// (1, 1) (1, 1) +// (1, 0) (1, 2) +fn build_join_indexes( + left: &JoinHashMap, + right: &JoinHashMap, + join_type: &JoinType, +) -> Result> { + match join_type { + JoinType::Inner => { + // inner => key intersection + // unfortunately rust does not support intersection of map keys :( + let left_set: HashSet> = left.keys().cloned().collect(); + let left_right: HashSet> = right.keys().cloned().collect(); + let inner = left_set.intersection(&left_right); + + let mut indexes = Vec::new(); // unknown a prior size + for key in inner { + // the unwrap never happens by construction of the key + let left_indexes = left.get(key).unwrap(); + let right_indexes = right.get(key).unwrap(); + + // for every item on the left and right with this key, add the respective pair + left_indexes.iter().for_each(|x| { + right_indexes.iter().for_each(|y| { + // on an inner join, left and right indices are present + indexes.push((Some(*x), Some(*y))); + }) + }) + } + Ok(indexes) + } + } +} + +impl Stream for HashJoinStream { + type 
Item = ArrowResult; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.right + .poll_next_unpin(cx) + .map(|maybe_batch| match maybe_batch { + Some(Ok(batch)) => Some(build_batch( + &batch, + &self.left_data, + &self.on_right, + &self.join_type, + &self.schema, + )), + other => other, + }) + } +} + +#[cfg(test)] +mod tests { + + use crate::{ + physical_plan::{common, memory::MemoryExec}, + test::{build_table_i32, columns, format_batch}, + }; + + use super::*; + use std::collections::HashSet; + use std::sync::Arc; + + fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Result> { + let (batch, schema) = build_table_i32(a, b, c)?; + Ok(Arc::new(MemoryExec::try_new( + &vec![vec![batch]], + Arc::new(schema), + None, + )?)) + } + + fn join( + left: Arc, + right: Arc, + on: &JoinOn, + ) -> Result { + HashJoinExec::try_new(left, right, on, &JoinType::Inner) + } + + /// Asserts that the rows are the same, taking into account that their order + /// is irrelevant + fn assert_same_rows(result: &[String], expected: &[&str]) { + assert_eq!(result.len(), expected.len()); + + // convert to set since row order is irrelevant + let result = result.iter().map(|s| s.clone()).collect::>(); + + let expected = expected + .iter() + .map(|s| s.to_string()) + .collect::>(); + assert_eq!(result, expected); + } + + #[tokio::test] + async fn join_one() -> Result<()> { + let t1 = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + )?; + let t2 = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + )?; + let on = &[("b1", "b1")]; + + let join = join(t1, t2, on)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let result = 
format_batch(&batches[0]); + let expected = vec!["2,5,8,20,80", "3,5,9,20,80", "1,4,7,10,70"]; + + assert_same_rows(&result, &expected); + + Ok(()) + } + + #[tokio::test] + async fn join_two() -> Result<()> { + let t1 = build_table( + ("a1", &vec![1, 2, 2]), + ("b2", &vec![1, 2, 2]), + ("c1", &vec![7, 8, 9]), + )?; + let t2 = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + )?; + let on = &[("a1", "a1"), ("b2", "b2")]; + + let join = join(t1, t2, on)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + assert_eq!(batches.len(), 1); + + let result = format_batch(&batches[0]); + let expected = vec!["1,1,7,70", "2,2,8,80", "2,2,9,80"]; + + assert_same_rows(&result, &expected); + + Ok(()) + } +} diff --git a/rust/datafusion/src/physical_plan/hash_utils.rs b/rust/datafusion/src/physical_plan/hash_utils.rs new file mode 100644 index 00000000000..c3987faa88b --- /dev/null +++ b/rust/datafusion/src/physical_plan/hash_utils.rs @@ -0,0 +1,172 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Functionality used both on logical and physical plans + +use crate::error::{DataFusionError, Result}; +use arrow::datatypes::{Field, Schema}; +use std::collections::HashSet; + +/// All valid types of joins. +#[derive(Clone, Debug)] +pub enum JoinType { + /// Inner join + Inner, +} + +/// The on clause of the join, as vector of (left, right) columns. +pub type JoinOn<'a> = [(&'a str, &'a str)]; + +/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. +/// They are valid whenever their columns' intersection equals the set `on` +pub fn check_join_is_valid(left: &Schema, right: &Schema, on: &JoinOn) -> Result<()> { + let left: HashSet = left.fields().iter().map(|f| f.name().clone()).collect(); + let right: HashSet = + right.fields().iter().map(|f| f.name().clone()).collect(); + + check_join_set_is_valid(&left, &right, on) +} + +/// Checks whether the sets left, right and on compose a valid join. +/// They are valid whenever their intersection equals the set `on` +fn check_join_set_is_valid( + left: &HashSet, + right: &HashSet, + on: &JoinOn, +) -> Result<()> { + if on.len() == 0 { + return Err(DataFusionError::Plan( + "The 'on' clause of a join cannot be empty".to_string(), + )); + } + let on_left = &on.iter().map(|on| on.0.to_string()).collect::>(); + let left_missing = on_left.difference(left).collect::>(); + + let on_right = &on.iter().map(|on| on.1.to_string()).collect::>(); + let right_missing = on_right.difference(right).collect::>(); + + if (left_missing.len() > 0) | (right_missing.len() > 0) { + return Err(DataFusionError::Plan(format!( + "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {:?}\nMissing on the right: {:?}", + left_missing, + right_missing, + ))); + }; + + let remaining = right + .difference(on_right) + .cloned() + .collect::>(); + + let collisions = left.intersection(&remaining).collect::>(); + + if collisions.len() > 0 { + return Err(DataFusionError::Plan(format!( 
+ "The left schema and the right schema have the following columns with the same name without being on the ON statement: {:?}. Consider aliasing them.", + collisions, + ))); + }; + + Ok(()) +} + +/// Creates a schema for a join operation. +/// The fields from the left side are first +pub fn build_join_schema( + left: &Schema, + right: &Schema, + on: &JoinOn, + join_type: &JoinType, +) -> Schema { + let fields: Vec = match join_type { + JoinType::Inner => { + // inner: all fields are there + let on_right = &on.iter().map(|on| on.1.to_string()).collect::>(); + + let left_fields = left.fields().iter(); + + let right_fields = right + .fields() + .iter() + .filter(|f| !on_right.contains(f.name())); + + // left then right + left_fields.chain(right_fields).cloned().collect() + } + }; + Schema::new(fields) +} + +#[cfg(test)] +mod tests { + + use super::*; + + fn check(left: &[&str], right: &[&str], on: &[(&str, &str)]) -> Result<()> { + let left = left.iter().map(|x| x.to_string()).collect::>(); + let right = right.iter().map(|x| x.to_string()).collect::>(); + + check_join_set_is_valid(&left, &right, on) + } + + #[test] + fn check_valid() -> Result<()> { + let left = vec!["a", "b1"]; + let right = vec!["a", "b2"]; + let on = &[("a", "a")]; + + check(&left, &right, on)?; + Ok(()) + } + + #[test] + fn check_not_in_right() { + let left = vec!["a", "b"]; + let right = vec!["b"]; + let on = &[("a", "a")]; + + assert!(check(&left, &right, on).is_err()); + } + + #[test] + fn check_not_in_left() { + let left = vec!["b"]; + let right = vec!["a"]; + let on = &[("a", "a")]; + + assert!(check(&left, &right, on).is_err()); + } + + #[test] + fn check_collision() { + // column "a" would appear both in left and right + let left = vec!["a", "c"]; + let right = vec!["a", "b"]; + let on = &[("a", "b")]; + + assert!(check(&left, &right, on).is_err()); + } + + #[test] + fn check_in_right() { + let left = vec!["a", "c"]; + let right = vec!["b"]; + let on = &[("a", "b")]; + + 
assert!(check(&left, &right, on).is_ok()); + } +} diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs index 0a9711ac8aa..f1d0a344b93 100644 --- a/rust/datafusion/src/physical_plan/mod.rs +++ b/rust/datafusion/src/physical_plan/mod.rs @@ -235,6 +235,8 @@ pub mod filter; pub mod functions; pub mod group_scalar; pub mod hash_aggregate; +pub mod hash_join; +pub mod hash_utils; pub mod limit; pub mod math_expressions; pub mod memory; diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs index e995c9428d5..e9ed7acb9a6 100644 --- a/rust/datafusion/src/test/mod.rs +++ b/rust/datafusion/src/test/mod.rs @@ -22,7 +22,7 @@ use crate::error::Result; use crate::execution::context::ExecutionContext; use crate::logical_plan::{LogicalPlan, LogicalPlanBuilder}; use crate::physical_plan::ExecutionPlan; -use arrow::array; +use arrow::array::{self, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use std::env; @@ -249,6 +249,34 @@ pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { assert_eq!(actual, expected); } +/// returns a table with 3 columns of i32 in memory +pub fn build_table_i32( + a: (&str, &Vec<i32>), + b: (&str, &Vec<i32>), + c: (&str, &Vec<i32>), +) -> Result<(RecordBatch, Schema)> { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Int32, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + )?; + Ok((batch, schema)) +} + +/// Returns the column names on the schema +pub fn columns(schema: &Schema) -> Vec<String> { + schema.fields().iter().map(|f| f.name().clone()).collect() +} + pub mod user_defined; pub mod variable; From 4e471fd7ec34d05ddbbde397c15d95dedad7b7af Mon Sep 17 00:00:00 2001 
From: "Jorge C. Leitao" <jorgecarleitao@gmail.com> Date: Fri, 20 Nov 2020 21:32:26 +0100 Subject: [PATCH 4/4] Added more tests. --- .../datafusion/src/physical_plan/hash_join.rs | 121 +++++++++++++++--- rust/datafusion/src/test/mod.rs | 8 +- 2 files changed, 106 insertions(+), 23 deletions(-) diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs index de12d912ba0..69a3d5a432e 100644 --- a/rust/datafusion/src/physical_plan/hash_join.rs +++ b/rust/datafusion/src/physical_plan/hash_join.rs @@ -429,13 +429,10 @@ mod tests { a: (&str, &Vec<i32>), b: (&str, &Vec<i32>), c: (&str, &Vec<i32>), - ) -> Result<Arc<dyn ExecutionPlan>> { - let (batch, schema) = build_table_i32(a, b, c)?; - Ok(Arc::new(MemoryExec::try_new( - &vec![vec![batch]], - Arc::new(schema), - None, - )?)) + ) -> Arc<dyn ExecutionPlan> { + let batch = build_table_i32(a, b, c); + let schema = batch.schema(); + Arc::new(MemoryExec::try_new(&vec![vec![batch]], schema, None).unwrap()) } fn join( @@ -449,8 +446,6 @@ /// Asserts that the rows are the same, taking into account that their order /// is irrelevant fn assert_same_rows(result: &[String], expected: &[&str]) { - assert_eq!(result.len(), expected.len()); - // convert to set since row order is irrelevant let result = result.iter().map(|s| s.clone()).collect::<HashSet<_>>(); @@ -463,19 +458,19 @@ #[tokio::test] async fn join_one() -> Result<()> { - let t1 = build_table( + let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition ("c1", &vec![7, 8, 9]), - )?; - let t2 = build_table( + ); + let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 6]), ("c2", &vec![70, 80, 90]), - )?; + ); let on = &[("b1", "b1")]; - let join = join(t1, t2, on)?; + let join = join(left, right, on)?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); @@ -493,19 +488,58 @@ #[tokio::test] async fn join_two() -> Result<()> { - let t1 = build_table( + let left = build_table( ("a1", 
&vec![1, 2, 2]), ("b2", &vec![1, 2, 2]), ("c1", &vec![7, 8, 9]), - )?; - let t2 = build_table( + ); + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = &[("a1", "a1"), ("b2", "b2")]; + + let join = join(left, right, on)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + assert_eq!(batches.len(), 1); + + let result = format_batch(&batches[0]); + let expected = vec!["1,1,7,70", "2,2,8,80", "2,2,9,80"]; + + assert_same_rows(&result, &expected); + + Ok(()) + } + + /// Test where the left has 2 parts, the right with 1 part => 1 part + #[tokio::test] + async fn join_one_two_parts_left() -> Result<()> { + let batch1 = build_table_i32( + ("a1", &vec![1, 2]), + ("b2", &vec![1, 2]), + ("c1", &vec![7, 8]), + ); + let batch2 = + build_table_i32(("a1", &vec![2]), ("b2", &vec![2]), ("c1", &vec![9])); + let schema = batch1.schema(); + let left = Arc::new( + MemoryExec::try_new(&vec![vec![batch1], vec![batch2]], schema, None).unwrap(), + ); + + let right = build_table( ("a1", &vec![1, 2, 3]), ("b2", &vec![1, 2, 2]), ("c2", &vec![70, 80, 90]), - )?; + ); let on = &[("a1", "a1"), ("b2", "b2")]; - let join = join(t1, t2, on)?; + let join = join(left, right, on)?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); @@ -521,4 +555,53 @@ mod tests { Ok(()) } + + /// Test where the left has 1 part, the right has 2 parts => 2 parts + #[tokio::test] + async fn join_one_two_parts_right() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + + let batch1 = build_table_i32( + ("a2", &vec![10, 20]), + ("b1", &vec![4, 6]), + ("c2", &vec![70, 80]), + ); + let batch2 = + build_table_i32(("a2", &vec![30]), ("b1", &vec![5]), ("c2", &vec![90])); + let schema = 
batch1.schema(); + let right = Arc::new( + MemoryExec::try_new(&vec![vec![batch1], vec![batch2]], schema, None).unwrap(), + ); + + let on = &[("b1", "b1")]; + + let join = join(left, right, on)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]); + + // first part + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + assert_eq!(batches.len(), 1); + + let result = format_batch(&batches[0]); + let expected = vec!["1,4,7,10,70"]; + assert_same_rows(&result, &expected); + + // second part + let stream = join.execute(1).await?; + let batches = common::collect(stream).await?; + assert_eq!(batches.len(), 1); + let result = format_batch(&batches[0]); + let expected = vec!["2,5,8,30,90", "3,5,9,30,90"]; + + assert_same_rows(&result, &expected); + + Ok(()) + } } diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs index e9ed7acb9a6..d27cbc44893 100644 --- a/rust/datafusion/src/test/mod.rs +++ b/rust/datafusion/src/test/mod.rs @@ -254,22 +254,22 @@ pub fn build_table_i32( a: (&str, &Vec<i32>), b: (&str, &Vec<i32>), c: (&str, &Vec<i32>), -) -> Result<(RecordBatch, Schema)> { +) -> RecordBatch { let schema = Schema::new(vec![ Field::new(a.0, DataType::Int32, false), Field::new(b.0, DataType::Int32, false), Field::new(c.0, DataType::Int32, false), ]); - let batch = RecordBatch::try_new( + RecordBatch::try_new( Arc::new(schema.clone()), vec![ Arc::new(Int32Array::from(a.1.clone())), Arc::new(Int32Array::from(b.1.clone())), Arc::new(Int32Array::from(c.1.clone())), ], - )?; - Ok((batch, schema)) + ) + .unwrap() } /// Returns the column names on the schema