diff --git a/rust/datafusion/src/execution/physical_plan/hash.rs b/rust/datafusion/src/execution/physical_plan/hash.rs
new file mode 100644
index 00000000000..a7b981431e6
--- /dev/null
+++ b/rust/datafusion/src/execution/physical_plan/hash.rs
@@ -0,0 +1,176 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines auxiliary functions for hashing keys
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::sync::Arc;
+
+use crate::error::{ExecutionError, Result};
+use crate::execution::physical_plan::Accumulator;
+
+use arrow::array::{
+    ArrayRef, BooleanArray, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
+    UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+};
+use arrow::array::{
+    BooleanBuilder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, StringBuilder,
+    UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder,
+};
+use arrow::datatypes::DataType;
+
+use fnv::FnvHashMap;
+
+/// Enumeration of types that can be used in any expression that requires hashing
+/// (all primitive types except floating point numerics, which do not implement `Hash`)
+#[derive(Debug, PartialEq, Eq, Hash, Clone)]
+pub enum KeyScalar {
+    /// Boolean
+    Boolean(bool),
+    /// 8-bit unsigned integer
+    UInt8(u8),
+    /// 16-bit unsigned integer
+    UInt16(u16),
+    /// 32-bit unsigned integer
+    UInt32(u32),
+    /// 64-bit unsigned integer
+    UInt64(u64),
+    /// 8-bit signed integer
+    Int8(i8),
+    /// 16-bit signed integer
+    Int16(i16),
+    /// 32-bit signed integer
+    Int32(i32),
+    /// 64-bit signed integer
+    Int64(i64),
+    /// UTF-8 encoded string
+    Utf8(String),
+}
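+
+// `KeyScalar` derives `Hash`, `Eq` and `Clone`, so a `Vec<KeyScalar>` can serve
+// directly as a hash-map key. A minimal sketch (the `usize` value type below is
+// illustrative only, not part of this module):
+//
+//     let key = vec![KeyScalar::Int32(1), KeyScalar::Utf8("a".to_string())];
+//     let mut map: FnvHashMap<Vec<KeyScalar>, usize> = FnvHashMap::default();
+//     map.insert(key.clone(), 0);
+//     assert_eq!(map.get(&key), Some(&0));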
+
+/// Returns a `KeyScalar` holding the value at `row` of the given `ArrayRef`.
+pub fn create_key(col: &ArrayRef, row: usize) -> Result<KeyScalar> {
+    match col.data_type() {
+        DataType::Boolean => {
+            let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
+            Ok(KeyScalar::Boolean(array.value(row)))
+        }
+        DataType::UInt8 => {
+            let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
+            Ok(KeyScalar::UInt8(array.value(row)))
+        }
+        DataType::UInt16 => {
+            let array = col.as_any().downcast_ref::<UInt16Array>().unwrap();
+            Ok(KeyScalar::UInt16(array.value(row)))
+        }
+        DataType::UInt32 => {
+            let array = col.as_any().downcast_ref::<UInt32Array>().unwrap();
+            Ok(KeyScalar::UInt32(array.value(row)))
+        }
+        DataType::UInt64 => {
+            let array = col.as_any().downcast_ref::<UInt64Array>().unwrap();
+            Ok(KeyScalar::UInt64(array.value(row)))
+        }
+        DataType::Int8 => {
+            let array = col.as_any().downcast_ref::<Int8Array>().unwrap();
+            Ok(KeyScalar::Int8(array.value(row)))
+        }
+        DataType::Int16 => {
+            let array = col.as_any().downcast_ref::<Int16Array>().unwrap();
+            Ok(KeyScalar::Int16(array.value(row)))
+        }
+        DataType::Int32 => {
+            let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
+            Ok(KeyScalar::Int32(array.value(row)))
+        }
+        DataType::Int64 => {
+            let array = col.as_any().downcast_ref::<Int64Array>().unwrap();
+            Ok(KeyScalar::Int64(array.value(row)))
+        }
+        DataType::Utf8 => {
+            let array = col.as_any().downcast_ref::<StringArray>().unwrap();
+            Ok(KeyScalar::Utf8(String::from(array.value(row))))
+        }
+        _ => Err(ExecutionError::ExecutionError(
+            "Unsupported key data type".to_string(),
+        )),
+    }
+}
+
+/// Create array from `key` attribute in map entry (representing a grouping scalar value)
+macro_rules! key_array_from_map_entries {
+    ($BUILDER:ident, $TY:ident, $MAP:expr, $COL_INDEX:expr) => {{
+        let mut builder = $BUILDER::new($MAP.len());
+        let mut err = false;
+        for k in $MAP.keys() {
+            match k[$COL_INDEX] {
+                KeyScalar::$TY(n) => builder.append_value(n).unwrap(),
+                _ => err = true,
+            }
+        }
+        if err {
+            Err(ExecutionError::ExecutionError(
+                "unexpected type when creating grouping array from aggregate map"
+                    .to_string(),
+            ))
+        } else {
+            Ok(Arc::new(builder.finish()) as ArrayRef)
+        }
+    }};
+}
+
+/// A set of accumulators, one per aggregate expression
+pub type AccumulatorSet = Vec<Rc<RefCell<dyn Accumulator>>>;
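+
+// For reference, `key_array_from_map_entries!(Int32Builder, Int32, map, i)`
+// expands to roughly the following: one output column is built from the i-th
+// component of every key in the map (a sketch, not a verbatim expansion):
+//
+//     let mut builder = Int32Builder::new(map.len());
+//     let mut err = false;
+//     for k in map.keys() {
+//         match k[i] {
+//             KeyScalar::Int32(n) => builder.append_value(n).unwrap(),
+//             _ => err = true,
+//         }
+//     }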
+
+/// Builds an Arrow array of type `data_type` from the `i`-th component of every key in `map`.
+pub fn create_key_array(
+    i: usize,
+    data_type: DataType,
+    map: &FnvHashMap<Vec<KeyScalar>, Rc<AccumulatorSet>>,
+) -> Result<ArrayRef> {
+    match data_type {
+        DataType::Boolean => key_array_from_map_entries!(BooleanBuilder, Boolean, map, i),
+        DataType::UInt8 => key_array_from_map_entries!(UInt8Builder, UInt8, map, i),
+        DataType::UInt16 => key_array_from_map_entries!(UInt16Builder, UInt16, map, i),
+        DataType::UInt32 => key_array_from_map_entries!(UInt32Builder, UInt32, map, i),
+        DataType::UInt64 => key_array_from_map_entries!(UInt64Builder, UInt64, map, i),
+        DataType::Int8 => key_array_from_map_entries!(Int8Builder, Int8, map, i),
+        DataType::Int16 => key_array_from_map_entries!(Int16Builder, Int16, map, i),
+        DataType::Int32 => key_array_from_map_entries!(Int32Builder, Int32, map, i),
+        DataType::Int64 => key_array_from_map_entries!(Int64Builder, Int64, map, i),
+        DataType::Utf8 => {
+            let mut builder = StringBuilder::new(1);
+            for k in map.keys() {
+                match &k[i] {
+                    KeyScalar::Utf8(s) => builder.append_value(&s).unwrap(),
+                    _ => {
+                        return Err(ExecutionError::ExecutionError(
+                            "Unexpected value for Utf8 key column".to_string(),
+                        ))
+                    }
+                }
+            }
+            Ok(Arc::new(builder.finish()) as ArrayRef)
+        }
+        _ => Err(ExecutionError::ExecutionError(
+            "Unsupported key data type".to_string(),
+        )),
+    }
+}
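+
+// How the two halves fit together, sketched for a single Int32 key column
+// (`columns` and `map` stand in for the caller's data):
+//
+//     // on the way in: build the composite key for one row
+//     let key = vec![create_key(&columns[0], row)?];
+//     // ... accumulate into `map`, keyed by `key` ...
+//     // on the way out: rebuild one Arrow column per key component
+//     let group_col = create_key_array(0, DataType::Int32, &map)?;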
diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
index 19836fd864d..7edb16d97bb 100644
--- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
+++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
@@ -26,20 +26,17 @@ use crate::execution::physical_plan::{
     Accumulator, AggregateExpr, ExecutionPlan, Partition, PhysicalExpr,
 };
 
-use arrow::array::{
-    ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
-    StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
-};
-use arrow::array::{
-    Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder,
-    Int8Builder, StringBuilder, UInt16Builder, UInt32Builder, UInt64Builder,
-    UInt8Builder,
-};
+use arrow::array::ArrayRef;
+use arrow::array::{Float32Builder, Float64Builder, Int64Builder, UInt64Builder};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::{RecordBatch, RecordBatchReader};
 
+use crate::execution::physical_plan::common::get_scalar_value;
 use crate::execution::physical_plan::expressions::col;
+use crate::execution::physical_plan::hash::{
+    create_key, create_key_array, AccumulatorSet, KeyScalar,
+};
 use crate::logicalplan::ScalarValue;
 
 use fnv::FnvHashMap;
@@ -164,28 +161,6 @@ impl Partition for HashAggregatePartition {
     }
 }
 
-/// Create array from `key` attribute in map entry (representing a grouping scalar value)
-macro_rules! group_array_from_map_entries {
-    ($BUILDER:ident, $TY:ident, $MAP:expr, $COL_INDEX:expr) => {{
-        let mut builder = $BUILDER::new($MAP.len());
-        let mut err = false;
-        for k in $MAP.keys() {
-            match k[$COL_INDEX] {
-                GroupByScalar::$TY(n) => builder.append_value(n).unwrap(),
-                _ => err = true,
-            }
-        }
-        if err {
-            Err(ExecutionError::ExecutionError(
-                "unexpected type when creating grouping array from aggregate map"
-                    .to_string(),
-            ))
-        } else {
-            Ok(Arc::new(builder.finish()) as ArrayRef)
-        }
-    }};
-}
-
 /// Create array from `value` attribute in map entry (representing an aggregate scalar
 /// value)
 macro_rules! aggr_array_from_map_entries {
@@ -236,12 +211,6 @@ macro_rules! aggr_array_from_accumulator {
     }};
 }
 
-#[derive(Debug)]
-struct MapEntry {
-    k: Vec<GroupByScalar>,
-    v: Vec<Option<ScalarValue>>,
-}
-
 struct GroupedHashAggregateIterator {
     schema: SchemaRef,
     group_expr: Vec<Arc<dyn PhysicalExpr>>,
@@ -268,24 +237,6 @@ impl GroupedHashAggregateIterator {
     }
 }
 
-type AccumulatorSet = Vec<Rc<RefCell<dyn Accumulator>>>;
-
-macro_rules! update_accumulators {
-    ($ARRAY:ident, $ARRAY_TY:ident, $SCALAR_TY:expr, $COL:expr, $ACCUM:expr) => {{
-        let primitive_array = $ARRAY.as_any().downcast_ref::<$ARRAY_TY>().unwrap();
-
-        for row in 0..$ARRAY.len() {
-            if $ARRAY.is_valid(row) {
-                let value = Some($SCALAR_TY(primitive_array.value(row)));
-                let mut accum = $ACCUM[row][$COL].borrow_mut();
-                accum
-                    .accumulate_scalar(value)
-                    .map_err(ExecutionError::into_arrow_external_error)?;
-            }
-        }
-    }};
-}
-
 impl RecordBatchReader for GroupedHashAggregateIterator {
     fn schema(&self) -> SchemaRef {
         self.schema.clone()
@@ -299,7 +250,7 @@ impl RecordBatchReader for GroupedHashAggregateIterator {
         self.finished = true;
 
         // create map to store accumulators for each unique grouping key
-        let mut map: FnvHashMap<Vec<GroupByScalar>, Rc<AccumulatorSet>> =
+        let mut map: FnvHashMap<Vec<KeyScalar>, Rc<AccumulatorSet>> =
            FnvHashMap::default();
 
        // iterate over all input batches and update the accumulators
@@ -327,120 +278,49 @@ impl RecordBatchReader for GroupedHashAggregateIterator {
                 })
                 .collect::<Result<Vec<_>>>()?;
 
-            // create vector large enough to hold the grouping key
+            // create vector to hold the grouping key
             let mut key = Vec::with_capacity(group_values.len());
             for _ in 0..group_values.len() {
-                key.push(GroupByScalar::UInt32(0));
+                key.push(KeyScalar::UInt32(0));
             }
 
-            // iterate over each row in the batch and create the accumulators for each grouping key
-            let mut accumulators: Vec<Rc<AccumulatorSet>> =
-                Vec::with_capacity(batch.num_rows());
-
             for row in 0..batch.num_rows() {
-                // create grouping key for this row
-                create_key(&group_values, row, &mut key)
-                    .map_err(ExecutionError::into_arrow_external_error)?;
-
-                if let Some(accumulator_set) = map.get(&key) {
-                    accumulators.push(accumulator_set.clone());
-                } else {
-                    let accumulator_set: AccumulatorSet = self
-                        .aggr_expr
-                        .iter()
-                        .map(|expr| expr.create_accumulator())
-                        .collect();
-
-                    let accumulator_set = Rc::new(accumulator_set);
-
-                    map.insert(key.clone(), accumulator_set.clone());
-                    accumulators.push(accumulator_set);
+                // create and assign the grouping key of this row
+                for i in 0..group_values.len() {
+                    key[i] = create_key(&group_values[i], row)
+                        .map_err(ExecutionError::into_arrow_external_error)?;
                 }
-            }
-
-            // iterate over each non-grouping column in the batch and update the accumulator
-            // for each row
-            for col in 0..aggr_input_values.len() {
-                let array = &aggr_input_values[col];
-
-                match array.data_type() {
-                    DataType::Int8 => update_accumulators!(
-                        array,
-                        Int8Array,
-                        ScalarValue::Int8,
-                        col,
-                        accumulators
-                    ),
-                    DataType::Int16 => update_accumulators!(
-                        array,
-                        Int16Array,
-                        ScalarValue::Int16,
-                        col,
-                        accumulators
-                    ),
-                    DataType::Int32 => update_accumulators!(
-                        array,
-                        Int32Array,
-                        ScalarValue::Int32,
-                        col,
-                        accumulators
-                    ),
-                    DataType::Int64 => update_accumulators!(
-                        array,
-                        Int64Array,
-                        ScalarValue::Int64,
-                        col,
-                        accumulators
-                    ),
-                    DataType::UInt8 => update_accumulators!(
-                        array,
-                        UInt8Array,
-                        ScalarValue::UInt8,
-                        col,
-                        accumulators
-                    ),
-                    DataType::UInt16 => update_accumulators!(
-                        array,
-                        UInt16Array,
-                        ScalarValue::UInt16,
-                        col,
-                        accumulators
-                    ),
-                    DataType::UInt32 => update_accumulators!(
-                        array,
-                        UInt32Array,
-                        ScalarValue::UInt32,
-                        col,
-                        accumulators
-                    ),
-                    DataType::UInt64 => update_accumulators!(
-                        array,
-                        UInt64Array,
-                        ScalarValue::UInt64,
-                        col,
-                        accumulators
-                    ),
-                    DataType::Float32 => update_accumulators!(
-                        array,
-                        Float32Array,
-                        ScalarValue::Float32,
-                        col,
-                        accumulators
-                    ),
-                    DataType::Float64 => update_accumulators!(
-                        array,
-                        Float64Array,
-                        ScalarValue::Float64,
-                        col,
-                        accumulators
-                    ),
-                    other => {
-                        return Err(ExecutionError::ExecutionError(format!(
-                            "Unsupported data type {:?} for result of aggregate expression",
-                            other
-                        )).into_arrow_external_error());
+                // for each new key, add an AccumulatorSet to the map
+                match map.get(&key) {
+                    None => {
+                        let accumulator_set: AccumulatorSet = self
+                            .aggr_expr
+                            .iter()
+                            .map(|expr| expr.create_accumulator())
+                            .collect();
+                        map.insert(key.clone(), Rc::new(accumulator_set));
                     }
+                    _ => (),
                 };
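+
+                // `map` is guaranteed to contain `key` after the match above,
+                // so the lookup in the loop below can only return `Some`.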
+
+                // iterate over each non-grouping column in the batch and update the accumulator
+                // for each row
+                for col in 0..aggr_input_values.len() {
+                    let value = get_scalar_value(&aggr_input_values[col], row)
+                        .map_err(ExecutionError::into_arrow_external_error)?;
+
+                    match map.get(&key) {
+                        None => panic!("This code cannot be reached."),
+                        Some(accumulator_set) => {
+                            let mut accum = accumulator_set[col].borrow_mut();
+                            accum
+                                .accumulate_scalar(value)
+                                .map_err(ExecutionError::into_arrow_external_error)?;
+                        }
+                    }
+                }
             }
         }
@@ -452,54 +332,12 @@ impl RecordBatchReader for GroupedHashAggregateIterator {
 
         // grouping values
         for i in 0..self.group_expr.len() {
-            let array: Result<ArrayRef> = match self.group_expr[i]
+            let data_type = self.group_expr[i]
                 .data_type(&input_schema)
-                .map_err(ExecutionError::into_arrow_external_error)?
-            {
-                DataType::UInt8 => {
-                    group_array_from_map_entries!(UInt8Builder, UInt8, map, i)
-                }
-                DataType::UInt16 => {
-                    group_array_from_map_entries!(UInt16Builder, UInt16, map, i)
-                }
-                DataType::UInt32 => {
-                    group_array_from_map_entries!(UInt32Builder, UInt32, map, i)
-                }
-                DataType::UInt64 => {
-                    group_array_from_map_entries!(UInt64Builder, UInt64, map, i)
-                }
-                DataType::Int8 => {
-                    group_array_from_map_entries!(Int8Builder, Int8, map, i)
-                }
-                DataType::Int16 => {
-                    group_array_from_map_entries!(Int16Builder, Int16, map, i)
-                }
-                DataType::Int32 => {
-                    group_array_from_map_entries!(Int32Builder, Int32, map, i)
-                }
-                DataType::Int64 => {
-                    group_array_from_map_entries!(Int64Builder, Int64, map, i)
-                }
-                DataType::Utf8 => {
-                    let mut builder = StringBuilder::new(1);
-                    for k in map.keys() {
-                        match &k[i] {
-                            GroupByScalar::Utf8(s) => builder.append_value(&s).unwrap(),
-                            _ => {
-                                return Err(ExecutionError::ExecutionError(
-                                    "Unexpected value for Utf8 group column".to_string(),
-                                )
-                                .into_arrow_external_error())
-                            }
-                        }
-                    }
-                    Ok(Arc::new(builder.finish()) as ArrayRef)
-                }
-                _ => Err(ExecutionError::ExecutionError(
-                    "Unsupported group by expr".to_string(),
-                )),
-            };
-            result_arrays.push(array.map_err(ExecutionError::into_arrow_external_error)?);
+                .map_err(ExecutionError::into_arrow_external_error)?;
+            let array = create_key_array(i, data_type, &map)
+                .map_err(ExecutionError::into_arrow_external_error)?;
+            result_arrays.push(array);
         }
 
         // aggregate values
         for i in 0..self.aggr_expr.len() {
@@ -677,76 +515,6 @@ impl RecordBatchReader for HashAggregateIterator {
     }
 }
 
-/// Enumeration of types that can be used in a GROUP BY expression (all primitives except
-/// for floating point numerics)
-#[derive(Debug, PartialEq, Eq, Hash, Clone)]
-enum GroupByScalar {
-    UInt8(u8),
-    UInt16(u16),
-    UInt32(u32),
-    UInt64(u64),
-    Int8(i8),
-    Int16(i16),
-    Int32(i32),
-    Int64(i64),
-    Utf8(String),
-}
-
-/// Create a Vec<GroupByScalar> that can be used as a map key
-fn create_key(
-    group_by_keys: &[ArrayRef],
-    row: usize,
-    vec: &mut Vec<GroupByScalar>,
-) -> Result<()> {
-    for i in 0..group_by_keys.len() {
-        let col = &group_by_keys[i];
-        match col.data_type() {
-            DataType::UInt8 => {
-                let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
-                vec[i] = GroupByScalar::UInt8(array.value(row))
-            }
-            DataType::UInt16 => {
-                let array = col.as_any().downcast_ref::<UInt16Array>().unwrap();
-                vec[i] = GroupByScalar::UInt16(array.value(row))
-            }
-            DataType::UInt32 => {
-                let array = col.as_any().downcast_ref::<UInt32Array>().unwrap();
-                vec[i] = GroupByScalar::UInt32(array.value(row))
-            }
-            DataType::UInt64 => {
-                let array = col.as_any().downcast_ref::<UInt64Array>().unwrap();
-                vec[i] = GroupByScalar::UInt64(array.value(row))
-            }
-            DataType::Int8 => {
-                let array = col.as_any().downcast_ref::<Int8Array>().unwrap();
-                vec[i] = GroupByScalar::Int8(array.value(row))
-            }
-            DataType::Int16 => {
-                let array = col.as_any().downcast_ref::<Int16Array>().unwrap();
-                vec[i] = GroupByScalar::Int16(array.value(row))
-            }
-            DataType::Int32 => {
-                let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
-                vec[i] = GroupByScalar::Int32(array.value(row))
-            }
-            DataType::Int64 => {
-                let array = col.as_any().downcast_ref::<Int64Array>().unwrap();
-                vec[i] = GroupByScalar::Int64(array.value(row))
-            }
-            DataType::Utf8 => {
-                let array = col.as_any().downcast_ref::<StringArray>().unwrap();
-                vec[i] = GroupByScalar::Utf8(String::from(array.value(row)))
-            }
-            _ => {
-                return Err(ExecutionError::ExecutionError(
-                    "Unsupported GROUP BY data type".to_string(),
-                ))
-            }
-        }
-    }
-    Ok(())
-}
-
 #[cfg(test)]
 mod tests {
 
@@ -755,6 +523,7 @@ mod tests {
     use crate::execution::physical_plan::expressions::{col, sum};
     use crate::execution::physical_plan::merge::MergeExec;
     use crate::test;
+    use arrow::array::{Int64Array, UInt32Array};
 
     #[test]
     fn aggregate() -> Result<()> {
diff --git a/rust/datafusion/src/execution/physical_plan/hash_partition.rs b/rust/datafusion/src/execution/physical_plan/hash_partition.rs
new file mode 100644
index 00000000000..da0d1f4b13b
--- /dev/null
+++ b/rust/datafusion/src/execution/physical_plan/hash_partition.rs
@@ -0,0 +1,257 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines the execution plan for a repartition operation
+
+use std::hash::Hash;
+use std::hash::Hasher;
+use std::sync::{Arc, Mutex};
+
+use arrow::array::BooleanBuilder;
+use arrow::compute::filter;
+use arrow::datatypes::Schema;
+use arrow::record_batch::{RecordBatch, RecordBatchReader};
+
+use crate::error::Result;
+use crate::execution::physical_plan::common::RecordBatchIterator;
+use crate::execution::physical_plan::expressions::col;
+use crate::execution::physical_plan::hash::{create_key, KeyScalar};
+use crate::execution::physical_plan::ExecutionPlan;
+use crate::execution::physical_plan::Partition;
+
+/// Executes its input partitions and re-combines them into a new set of partitions,
+/// partitioned by a partitioning key. This operation is also known as a shuffle or exchange.
+/// The resulting partitions are guaranteed to hold disjoint sets of `partitioning_key`
+/// values: each row is mapped to the partition with index
+/// `hash(partitioning_key) % new_partitions`.
+pub struct RepartitionExec {
+    /// Input schema
+    schema: Arc<Schema>,
+    /// Input partitions
+    partitions: Vec<Arc<dyn Partition>>,
+    /// The new number of partitions
+    new_partitions: usize,
+    /// The partitioning key used to map rows to partitions
+    partitioning_key: Vec<String>,
+}
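+
+// Intended usage, sketched (`input` is any Vec<Arc<dyn Partition>> and "c1" an
+// existing column of `schema`):
+//
+//     let exec = RepartitionExec::new(schema.clone(), input, 4, vec!["c1".to_string()]);
+//     let output = exec.partitions()?; // 4 partitions with disjoint "c1" key sets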
+
+impl RepartitionExec {
+    /// Create a new RepartitionExec
+    pub fn new(
+        schema: Arc<Schema>,
+        partitions: Vec<Arc<dyn Partition>>,
+        new_partitions: usize,
+        partitioning_key: Vec<String>,
+    ) -> Self {
+        RepartitionExec {
+            schema,
+            partitions,
+            new_partitions,
+            partitioning_key,
+        }
+    }
+}
+
+impl ExecutionPlan for RepartitionExec {
+    fn schema(&self) -> Arc<Schema> {
+        self.schema.clone()
+    }
+
+    fn partitions(&self) -> Result<Vec<Arc<dyn Partition>>> {
+        // create `new_partitions` new partitions, each with a different index `i`
+        let mut partitions: Vec<Arc<dyn Partition>> =
+            Vec::with_capacity(self.new_partitions);
+        for i in 0..self.new_partitions {
+            partitions.push(Arc::new(HashPartition {
+                schema: self.schema.clone(),
+                partitions: self.partitions.clone(),
+                index: i,
+                new_partitions: self.new_partitions,
+                partitioning_key: self.partitioning_key.clone(),
+            }))
+        }
+        Ok(partitions)
+    }
+}
+
+struct HashPartition {
+    /// Input schema
+    schema: Arc<Schema>,
+    /// Input partitions
+    partitions: Vec<Arc<dyn Partition>>,
+    /// Index of the output partition that rows are mapped to by hash
+    index: usize,
+    /// Number of new partitions
+    new_partitions: usize,
+    /// The columns used to compute the hash
+    partitioning_key: Vec<String>,
+}
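+
+// The mapping rule implemented below, in isolation (`Hash` and `Hasher` are the
+// std traits imported above; two output partitions assumed for the sketch):
+//
+//     let mut hasher = fnv::FnvHasher::default();
+//     vec![KeyScalar::UInt32(7)].hash(&mut hasher);
+//     let index = (hasher.finish() as usize) % 2;
+//     assert!(index < 2);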
+
+fn selection_for_partition(
+    partitioning_key: &[String],
+    predicate: &dyn Fn(&Vec<KeyScalar>) -> bool,
+    partition: &Arc<dyn Partition>,
+) -> Result<Option<RecordBatch>> {
+    let iterator = partition.execute()?;
+    let mut input = iterator.lock().unwrap();
+
+    match input.next_batch()? {
+        None => Ok(None),
+        Some(batch) => {
+            // evaluate the partitioning key columns
+            let keys_values = partitioning_key
+                .iter()
+                .map(|name| col(name).evaluate(&batch))
+                .collect::<Result<Vec<_>>>()?;
+
+            // evaluate the predicate for this partition from `keys_values`.
+            // Arrow does not provide an out-of-the-box kernel for this hashing,
+            // so we compute it row by row.
+            let mut builder = BooleanBuilder::new(batch.num_rows());
+            for row in 0..batch.num_rows() {
+                let mut key = Vec::with_capacity(keys_values.len());
+                for i in 0..keys_values.len() {
+                    key.push(create_key(&keys_values[i], row)?);
+                }
+                builder.append_value(predicate(&key))?;
+            }
+
+            let predicate_array = builder.finish();
+
+            // filter each column based on the predicate
+            let mut filtered_arrays = vec![];
+            for i in 0..batch.num_columns() {
+                let array = batch.column(i);
+                let filtered_array = filter(array.as_ref(), &predicate_array)?;
+                filtered_arrays.push(filtered_array);
+            }
+            Ok(Some(RecordBatch::try_new(
+                batch.schema().clone(),
+                filtered_arrays,
+            )?))
+        }
+    }
+}
+
+impl Partition for HashPartition {
+    /// Execute the partitioning
+    fn execute(&self) -> Result<Arc<Mutex<dyn RecordBatchReader>>> {
+        let r = self.partitions.iter().filter_map(|partition| {
+            selection_for_partition(
+                &self.partitioning_key,
+                &|key: &Vec<KeyScalar>| {
+                    // hash the composite key
+                    let mut hasher = fnv::FnvHasher::default();
+                    key.hash(&mut hasher);
+                    let key_hash = hasher.finish() as usize;
+
+                    // map the hash to this partition's index
+                    (key_hash % self.new_partitions) == self.index
+                },
+                partition,
+                // TODO: propagate the error instead of panicking here
+            )
+            .expect("Valid partition")
+        });
+
+        let mut batches = vec![];
+        r.for_each(|batch| batches.push(Arc::new(batch)));
+
+        Ok(Arc::new(Mutex::new(RecordBatchIterator::new(
+            self.schema.clone(),
+            batches,
+        ))))
+    }
+}
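+
+// The filter step above, in isolation: `arrow::compute::filter` keeps the rows
+// whose predicate bit is `true` (a minimal sketch with literal arrays):
+//
+//     use arrow::array::{BooleanArray, Int32Array};
+//     let values = Int32Array::from(vec![1, 2, 3]);
+//     let predicate = BooleanArray::from(vec![true, false, true]);
+//     let kept = filter(&values, &predicate)?; // [1, 3]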
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+    use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions};
+    use crate::test;
+    use arrow::array::{Array, StringArray};
+    use std::collections::{HashMap, HashSet};
+
+    #[test]
+    fn repartition() -> Result<()> {
+        let schema = test::aggr_test_schema();
+
+        let num_partitions = 4;
+        let path =
+            test::create_partitioned_csv("aggregate_test_100.csv", num_partitions)?;
+
+        let csv =
+            CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema), None, 1024)?;
+
+        // input should have 4 partitions
+        let input = csv.partitions()?;
+        assert_eq!(input.len(), num_partitions);
+
+        let new_num_partitions = 2;
+        let repartition = RepartitionExec::new(
+            schema.clone(),
+            input,
+            new_num_partitions,
+            vec!["c1".to_string()],
+        );
+
+        // compute some statistics over the partitions
+        let mut partition_count = 0;
+        let mut row_count = 0;
+        let mut batch_count = 0;
+        let mut hash_all = HashMap::new();
+        for partition in repartition.partitions()? {
+            partition_count += 1;
+            let mut hash = HashSet::new();
+            let iterator = partition.execute()?;
+            let mut iterator = iterator.lock().unwrap();
+            while let Some(batch) = iterator.next_batch()? {
+                row_count += batch.num_rows();
+                batch_count += 1;
+                let array = batch
+                    .column(0)
+                    .as_any()
+                    .downcast_ref::<StringArray>()
+                    .unwrap();
+                for i in 0..array.data().len() {
+                    hash.insert(array.value(i).to_string());
+                }
+            }
+            hash_all.insert(partition_count, hash);
+        }
+
+        // correct number of rows and partitions
+        assert_eq!(new_num_partitions, partition_count);
+        assert_eq!(100, row_count);
+        // one batch per (input partition, output partition) pair
+        assert_eq!(new_num_partitions * num_partitions, batch_count);
+
+        // there is no intersection of the keys across partitions
+        for i in hash_all.keys() {
+            for j in hash_all.keys() {
+                if j >= i {
+                    continue;
+                }
+                let lhs = hash_all.get(i).unwrap();
+                let rhs = hash_all.get(j).unwrap();
+                assert_eq!(lhs.intersection(rhs).next(), None)
+            }
+        }
+
+        Ok(())
+    }
+}
diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs
index 2e191784678..1cfbeb0d8b4 100644
--- a/rust/datafusion/src/execution/physical_plan/mod.rs
+++ b/rust/datafusion/src/execution/physical_plan/mod.rs
@@ -80,7 +80,9 @@ pub mod common;
 pub mod csv;
 pub mod datasource;
 pub mod expressions;
+pub mod hash;
 pub mod hash_aggregate;
+pub mod hash_partition;
 pub mod limit;
 pub mod math_expressions;
 pub mod memory;