From 10e435089f232203aaabf71f7e46b604e167fce9 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Wed, 8 Jul 2020 06:19:46 +0200 Subject: [PATCH 01/11] Moved hash-related code to its own module. --- .../src/execution/physical_plan/hash.rs | 169 +++++++++++++++++ .../execution/physical_plan/hash_aggregate.rs | 170 ++---------------- .../src/execution/physical_plan/mod.rs | 1 + 3 files changed, 188 insertions(+), 152 deletions(-) create mode 100644 rust/datafusion/src/execution/physical_plan/hash.rs diff --git a/rust/datafusion/src/execution/physical_plan/hash.rs b/rust/datafusion/src/execution/physical_plan/hash.rs new file mode 100644 index 00000000000..76266a31fc8 --- /dev/null +++ b/rust/datafusion/src/execution/physical_plan/hash.rs @@ -0,0 +1,169 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines anxiliary functions for hashing keys + +use std::cell::RefCell; +use std::rc::Rc; +use std::sync::Arc; + +use crate::error::{ExecutionError, Result}; +use crate::execution::physical_plan::Accumulator; + +use arrow::array::{ + ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, + UInt32Array, UInt64Array, UInt8Array, +}; +use arrow::array::{ + Int16Builder, Int32Builder, Int64Builder, Int8Builder, StringBuilder, UInt16Builder, + UInt32Builder, UInt64Builder, UInt8Builder, +}; +use arrow::datatypes::DataType; + +use fnv::FnvHashMap; + +/// Enumeration of types that can be used in any expression that uses an hash (all primitives except +/// for floating point numerics) +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub enum KeyScalar { + /// 8 bits + UInt8(u8), + /// 16 bits + UInt16(u16), + /// 32 bits + UInt32(u32), + /// 64 bits + UInt64(u64), + /// 8 bits signed + Int8(i8), + /// 16 bits signed + Int16(i16), + /// 32 bits signed + Int32(i32), + /// 64 bits signed + Int64(i64), + /// string + Utf8(String), +} + +/// Return an KeyScalar from an ArrayRef of a given row. +pub fn create_key(col: &ArrayRef, row: usize) -> Result { + match col.data_type() { + DataType::UInt8 => { + let array = col.as_any().downcast_ref::().unwrap(); + return Ok(KeyScalar::UInt8(array.value(row))); + } + DataType::UInt16 => { + let array = col.as_any().downcast_ref::().unwrap(); + return Ok(KeyScalar::UInt16(array.value(row))); + } + DataType::UInt32 => { + let array = col.as_any().downcast_ref::().unwrap(); + return Ok(KeyScalar::UInt32(array.value(row))); + } + DataType::UInt64 => { + let array = col.as_any().downcast_ref::().unwrap(); + return Ok(KeyScalar::UInt64(array.value(row))); + } + DataType::Int8 => { + let array = col.as_any().downcast_ref::().unwrap(); + return Ok(KeyScalar::Int8(array.value(row))); + } + DataType::Int16 => { + let array = col.as_any().downcast_ref::().unwrap(); + return Ok(KeyScalar::Int16(array.value(row))); + } + DataType::Int32 => { + let array = col.as_any().downcast_ref::().unwrap(); + return Ok(KeyScalar::Int32(array.value(row))); + } + DataType::Int64 => { + let array = col.as_any().downcast_ref::().unwrap(); + return Ok(KeyScalar::Int64(array.value(row))); + } + DataType::Utf8 => { + let array = col.as_any().downcast_ref::().unwrap(); + return Ok(KeyScalar::Utf8(String::from(array.value(row)))); + } + _ => { + return Err(ExecutionError::ExecutionError( + "Unsupported key data type".to_string(), + )) + } + } +} + +/// Create array from `key` attribute in map entry (representing a grouping scalar value) +macro_rules! key_array_from_map_entries { + ($BUILDER:ident, $TY:ident, $MAP:expr, $COL_INDEX:expr) => {{ + let mut builder = $BUILDER::new($MAP.len()); + let mut err = false; + for k in $MAP.keys() { + match k[$COL_INDEX] { + KeyScalar::$TY(n) => builder.append_value(n).unwrap(), + _ => err = true, + } + } + if err { + Err(ExecutionError::ExecutionError( + "unexpected type when creating grouping array from aggregate map" + .to_string(), + )) + } else { + Ok(Arc::new(builder.finish()) as ArrayRef) + } + }}; +} + +/// A set of accumulators +pub type AccumulatorSet = Vec>>; + +/// Builds the array of KeyScalars from `data_type`. +pub fn create_key_array( + i: usize, + data_type: DataType, + map: &FnvHashMap, Rc>, +) -> Result { + let array: Result = match data_type { + DataType::UInt8 => key_array_from_map_entries!(UInt8Builder, UInt8, map, i), + DataType::UInt16 => key_array_from_map_entries!(UInt16Builder, UInt16, map, i), + DataType::UInt32 => key_array_from_map_entries!(UInt32Builder, UInt32, map, i), + DataType::UInt64 => key_array_from_map_entries!(UInt64Builder, UInt64, map, i), + DataType::Int8 => key_array_from_map_entries!(Int8Builder, Int8, map, i), + DataType::Int16 => key_array_from_map_entries!(Int16Builder, Int16, map, i), + DataType::Int32 => key_array_from_map_entries!(Int32Builder, Int32, map, i), + DataType::Int64 => key_array_from_map_entries!(Int64Builder, Int64, map, i), + DataType::Utf8 => { + let mut builder = StringBuilder::new(1); + for k in map.keys() { + match &k[i] { + KeyScalar::Utf8(s) => builder.append_value(&s).unwrap(), + _ => { + return Err(ExecutionError::ExecutionError( + "Unexpected value for Utf8 group column".to_string(), + )) + } + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + } + _ => Err(ExecutionError::ExecutionError( + "Unsupported key by expr".to_string(), + )), + }; + Ok(array?) +} diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs index 19836fd864d..8223e5e0869 100644 --- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs @@ -28,18 +28,17 @@ use crate::execution::physical_plan::{ use arrow::array::{ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, -}; -use arrow::array::{ - Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, - Int8Builder, StringBuilder, UInt16Builder, UInt32Builder, UInt64Builder, - UInt8Builder, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; +use arrow::array::{Float32Builder, Float64Builder, Int64Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::{RecordBatch, RecordBatchReader}; use crate::execution::physical_plan::expressions::col; +use crate::execution::physical_plan::hash::{ + create_key, create_key_array, AccumulatorSet, KeyScalar, +}; use crate::logicalplan::ScalarValue; use fnv::FnvHashMap; @@ -164,28 +163,6 @@ impl Partition for HashAggregatePartition { } } -/// Create array from `key` attribute in map entry (representing a grouping scalar value) -macro_rules! group_array_from_map_entries { - ($BUILDER:ident, $TY:ident, $MAP:expr, $COL_INDEX:expr) => {{ - let mut builder = $BUILDER::new($MAP.len()); - let mut err = false; - for k in $MAP.keys() { - match k[$COL_INDEX] { - GroupByScalar::$TY(n) => builder.append_value(n).unwrap(), - _ => err = true, - } - } - if err { - Err(ExecutionError::ExecutionError( - "unexpected type when creating grouping array from aggregate map" - .to_string(), - )) - } else { - Ok(Arc::new(builder.finish()) as ArrayRef) - } - }}; -} - /// Create array from `value` attribute in map entry (representing an aggregate scalar /// value) macro_rules! aggr_array_from_map_entries { @@ -238,7 +215,7 @@ macro_rules! aggr_array_from_accumulator { #[derive(Debug)] struct MapEntry { - k: Vec, + k: Vec, v: Vec>, } @@ -268,8 +245,6 @@ impl GroupedHashAggregateIterator { } } -type AccumulatorSet = Vec>>; - macro_rules! update_accumulators { ($ARRAY:ident, $ARRAY_TY:ident, $SCALAR_TY:expr, $COL:expr, $ACCUM:expr) => {{ let primitive_array = $ARRAY.as_any().downcast_ref::<$ARRAY_TY>().unwrap(); @@ -299,7 +274,7 @@ impl RecordBatchReader for GroupedHashAggregateIterator { self.finished = true; // create map to store accumulators for each unique grouping key - let mut map: FnvHashMap, Rc> = + let mut map: FnvHashMap, Rc> = FnvHashMap::default(); // iterate over all input batches and update the accumulators @@ -330,7 +305,7 @@ impl RecordBatchReader for GroupedHashAggregateIterator { // create vector large enough to hold the grouping key let mut key = Vec::with_capacity(group_values.len()); for _ in 0..group_values.len() { - key.push(GroupByScalar::UInt32(0)); + key.push(KeyScalar::UInt32(0)); } // iterate over each row in the batch and create the accumulators for each grouping key @@ -339,8 +314,11 @@ impl RecordBatchReader for GroupedHashAggregateIterator { for row in 0..batch.num_rows() { // create grouping key for this row - create_key(&group_values, row, &mut key) - .map_err(ExecutionError::into_arrow_external_error)?; + for i in 0..group_values.len() { + let col = &group_values[i]; + key[i] = create_key(col, row) + .map_err(ExecutionError::into_arrow_external_error)?; + } if let Some(accumulator_set) = map.get(&key) { accumulators.push(accumulator_set.clone()); @@ -452,54 +430,12 @@ impl RecordBatchReader for GroupedHashAggregateIterator { // grouping values for i in 0..self.group_expr.len() { - let array: Result = match self.group_expr[i] + let data_type = self.group_expr[i] .data_type(&input_schema) - .map_err(ExecutionError::into_arrow_external_error)? - { - DataType::UInt8 => { - group_array_from_map_entries!(UInt8Builder, UInt8, map, i) - } - DataType::UInt16 => { - group_array_from_map_entries!(UInt16Builder, UInt16, map, i) - } - DataType::UInt32 => { - group_array_from_map_entries!(UInt32Builder, UInt32, map, i) - } - DataType::UInt64 => { - group_array_from_map_entries!(UInt64Builder, UInt64, map, i) - } - DataType::Int8 => { - group_array_from_map_entries!(Int8Builder, Int8, map, i) - } - DataType::Int16 => { - group_array_from_map_entries!(Int16Builder, Int16, map, i) - } - DataType::Int32 => { - group_array_from_map_entries!(Int32Builder, Int32, map, i) - } - DataType::Int64 => { - group_array_from_map_entries!(Int64Builder, Int64, map, i) - } - DataType::Utf8 => { - let mut builder = StringBuilder::new(1); - for k in map.keys() { - match &k[i] { - GroupByScalar::Utf8(s) => builder.append_value(&s).unwrap(), - _ => { - return Err(ExecutionError::ExecutionError( - "Unexpected value for Utf8 group column".to_string(), - ) - .into_arrow_external_error()) - } - } - } - Ok(Arc::new(builder.finish()) as ArrayRef) - } - _ => Err(ExecutionError::ExecutionError( - "Unsupported group by expr".to_string(), - )), - }; - result_arrays.push(array.map_err(ExecutionError::into_arrow_external_error)?); + .map_err(ExecutionError::into_arrow_external_error)?; + let array = create_key_array(i, data_type, &map) + .map_err(ExecutionError::into_arrow_external_error)?; + result_arrays.push(array); // aggregate values for i in 0..self.aggr_expr.len() { @@ -677,76 +613,6 @@ impl RecordBatchReader for HashAggregateIterator { } } -/// Enumeration of types that can be used in a GROUP BY expression (all primitives except -/// for floating point numerics) -#[derive(Debug, PartialEq, Eq, Hash, Clone)] -enum GroupByScalar { - UInt8(u8), - UInt16(u16), - UInt32(u32), - UInt64(u64), - Int8(i8), - Int16(i16), - Int32(i32), - Int64(i64), - Utf8(String), -} - -/// Create a Vec that can be used as a map key -fn create_key( - group_by_keys: &[ArrayRef], - row: usize, - vec: &mut Vec, -) -> Result<()> { - for i in 0..group_by_keys.len() { - let col = &group_by_keys[i]; - match col.data_type() { - DataType::UInt8 => { - let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::UInt8(array.value(row)) - } - DataType::UInt16 => { - let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::UInt16(array.value(row)) - } - DataType::UInt32 => { - let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::UInt32(array.value(row)) - } - DataType::UInt64 => { - let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::UInt64(array.value(row)) - } - DataType::Int8 => { - let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::Int8(array.value(row)) - } - DataType::Int16 => { - let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::Int16(array.value(row)) - } - DataType::Int32 => { - let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::Int32(array.value(row)) - } - DataType::Int64 => { - let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::Int64(array.value(row)) - } - DataType::Utf8 => { - let array = col.as_any().downcast_ref::().unwrap(); - vec[i] = GroupByScalar::Utf8(String::from(array.value(row))) - } - _ => { - return Err(ExecutionError::ExecutionError( - "Unsupported GROUP BY data type".to_string(), - )) - } - } - } - Ok(()) -} - #[cfg(test)] mod tests { diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index 2e191784678..4d5c56e349b 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -80,6 +80,7 @@ pub mod common; pub mod csv; pub mod datasource; pub mod expressions; +pub mod hash; pub mod hash_aggregate; pub mod limit; pub mod math_expressions; From e7ec0a15c0dce064387321ffa90fcee1228e48c0 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Wed, 8 Jul 2020 07:24:01 +0200 Subject: [PATCH 02/11] Removed unused struct. --- .../src/execution/physical_plan/hash_aggregate.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs index 8223e5e0869..c402c1551e7 100644 --- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs @@ -213,12 +213,6 @@ macro_rules! aggr_array_from_accumulator { }}; } -#[derive(Debug)] -struct MapEntry { - k: Vec, - v: Vec>, -} - struct GroupedHashAggregateIterator { schema: SchemaRef, group_expr: Vec>, From c02c4e2cdfa54bc35fd45ed38c59b20fbb552c16 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Wed, 8 Jul 2020 08:21:28 +0200 Subject: [PATCH 03/11] Simplified grouped aggregation. This reduces * the runtime complexity of this operation from O(N*(1 + M)) to O(N*M) (N=number of rows, M=number of aggregations), * the memory footprint from O(N*M) acumulators to O(M) accumulators * the code complexity via DRY. --- .../execution/physical_plan/hash_aggregate.rs | 155 ++++-------------- 1 file changed, 32 insertions(+), 123 deletions(-) diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs index c402c1551e7..7edb16d97bb 100644 --- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs @@ -26,15 +26,13 @@ use crate::execution::physical_plan::{ Accumulator, AggregateExpr, ExecutionPlan, Partition, PhysicalExpr, }; -use arrow::array::{ - ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, -}; +use arrow::array::ArrayRef; use arrow::array::{Float32Builder, Float64Builder, Int64Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::{RecordBatch, RecordBatchReader}; +use crate::execution::physical_plan::common::get_scalar_value; use crate::execution::physical_plan::expressions::col; use crate::execution::physical_plan::hash::{ create_key, create_key_array, AccumulatorSet, KeyScalar, @@ -239,22 +237,6 @@ impl GroupedHashAggregateIterator { } } -macro_rules! update_accumulators { - ($ARRAY:ident, $ARRAY_TY:ident, $SCALAR_TY:expr, $COL:expr, $ACCUM:expr) => {{ - let primitive_array = $ARRAY.as_any().downcast_ref::<$ARRAY_TY>().unwrap(); - - for row in 0..$ARRAY.len() { - if $ARRAY.is_valid(row) { - let value = Some($SCALAR_TY(primitive_array.value(row))); - let mut accum = $ACCUM[row][$COL].borrow_mut(); - accum - .accumulate_scalar(value) - .map_err(ExecutionError::into_arrow_external_error)?; - } - } - }}; -} - impl RecordBatchReader for GroupedHashAggregateIterator { fn schema(&self) -> SchemaRef { self.schema.clone() @@ -296,123 +278,49 @@ impl RecordBatchReader for GroupedHashAggregateIterator { }) .collect::>>()?; - // create vector large enough to hold the grouping key + // create vector to hold the grouping key let mut key = Vec::with_capacity(group_values.len()); for _ in 0..group_values.len() { key.push(KeyScalar::UInt32(0)); } // iterate over each row in the batch and create the accumulators for each grouping key - let mut accumulators: Vec> = - Vec::with_capacity(batch.num_rows()); - for row in 0..batch.num_rows() { - // create grouping key for this row + // create and assign the grouping key of this row for i in 0..group_values.len() { - let col = &group_values[i]; - key[i] = create_key(col, row) + key[i] = create_key(&group_values[i], row) .map_err(ExecutionError::into_arrow_external_error)?; } - if let Some(accumulator_set) = map.get(&key) { - accumulators.push(accumulator_set.clone()); - } else { - let accumulator_set: AccumulatorSet = self - .aggr_expr - .iter() - .map(|expr| expr.create_accumulator()) - .collect(); - - let accumulator_set = Rc::new(accumulator_set); + // for each new key on the map, add an accumulatorSet to the map + match map.get(&key) { + None => { + let accumulator_set: AccumulatorSet = self + .aggr_expr + .iter() + .map(|expr| expr.create_accumulator()) + .collect(); + map.insert(key.clone(), Rc::new(accumulator_set)); + } + _ => (), + }; - map.insert(key.clone(), accumulator_set.clone()); - accumulators.push(accumulator_set); - } - } + // iterate over each non-grouping column in the batch and update the accumulator + // for each row + for col in 0..aggr_input_values.len() { + let value = get_scalar_value(&aggr_input_values[col], row) + .map_err(ExecutionError::into_arrow_external_error)?; - // iterate over each non-grouping column in the batch and update the accumulator - // for each row - for col in 0..aggr_input_values.len() { - let array = &aggr_input_values[col]; - - match array.data_type() { - DataType::Int8 => update_accumulators!( - array, - Int8Array, - ScalarValue::Int8, - col, - accumulators - ), - DataType::Int16 => update_accumulators!( - array, - Int16Array, - ScalarValue::Int16, - col, - accumulators - ), - DataType::Int32 => update_accumulators!( - array, - Int32Array, - ScalarValue::Int32, - col, - accumulators - ), - DataType::Int64 => update_accumulators!( - array, - Int64Array, - ScalarValue::Int64, - col, - accumulators - ), - DataType::UInt8 => update_accumulators!( - array, - UInt8Array, - ScalarValue::UInt8, - col, - accumulators - ), - DataType::UInt16 => update_accumulators!( - array, - UInt16Array, - ScalarValue::UInt16, - col, - accumulators - ), - DataType::UInt32 => update_accumulators!( - array, - UInt32Array, - ScalarValue::UInt32, - col, - accumulators - ), - DataType::UInt64 => update_accumulators!( - array, - UInt64Array, - ScalarValue::UInt64, - col, - accumulators - ), - DataType::Float32 => update_accumulators!( - array, - Float32Array, - ScalarValue::Float32, - col, - accumulators - ), - DataType::Float64 => update_accumulators!( - array, - Float64Array, - ScalarValue::Float64, - col, - accumulators - ), - other => { - return Err(ExecutionError::ExecutionError(format!( - "Unsupported data type {:?} for result of aggregate expression", - other - )).into_arrow_external_error()); + match map.get(&key) { + None => panic!("This code cannot be reached."), + Some(accumulator_set) => { + let mut accum = accumulator_set[col].borrow_mut(); + accum + .accumulate_scalar(value) + .map_err(ExecutionError::into_arrow_external_error)?; + } } - }; + } } } @@ -615,6 +523,7 @@ mod tests { use crate::execution::physical_plan::expressions::{col, sum}; use crate::execution::physical_plan::merge::MergeExec; use crate::test; + use arrow::array::{Int64Array, UInt32Array}; #[test] fn aggregate() -> Result<()> { From 994329ffdfd4131ed65fb978b0cd7d2c64112a74 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Thu, 9 Jul 2020 07:10:28 +0200 Subject: [PATCH 04/11] Added boolean to set of valid types to group by. --- .../src/execution/physical_plan/hash.rs | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/rust/datafusion/src/execution/physical_plan/hash.rs b/rust/datafusion/src/execution/physical_plan/hash.rs index 76266a31fc8..a7b981431e6 100644 --- a/rust/datafusion/src/execution/physical_plan/hash.rs +++ b/rust/datafusion/src/execution/physical_plan/hash.rs @@ -25,12 +25,12 @@ use crate::error::{ExecutionError, Result}; use crate::execution::physical_plan::Accumulator; use arrow::array::{ - ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, + ArrayRef, BooleanArray, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::array::{ - Int16Builder, Int32Builder, Int64Builder, Int8Builder, StringBuilder, UInt16Builder, - UInt32Builder, UInt64Builder, UInt8Builder, + BooleanBuilder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, StringBuilder, + UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, }; use arrow::datatypes::DataType; @@ -40,6 +40,8 @@ use fnv::FnvHashMap; /// for floating point numerics) #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub enum KeyScalar { + /// Boolean + Boolean(bool), /// 8 bits UInt8(u8), /// 16 bits @@ -63,6 +65,10 @@ pub enum KeyScalar { /// Return an KeyScalar from an ArrayRef of a given row. pub fn create_key(col: &ArrayRef, row: usize) -> Result { match col.data_type() { + DataType::Boolean => { + let array = col.as_any().downcast_ref::().unwrap(); + return Ok(KeyScalar::Boolean(array.value(row))); + } DataType::UInt8 => { let array = col.as_any().downcast_ref::().unwrap(); return Ok(KeyScalar::UInt8(array.value(row))); @@ -139,6 +145,7 @@ pub fn create_key_array( map: &FnvHashMap, Rc>, ) -> Result { let array: Result = match data_type { + DataType::Boolean => key_array_from_map_entries!(BooleanBuilder, Boolean, map, i), DataType::UInt8 => key_array_from_map_entries!(UInt8Builder, UInt8, map, i), DataType::UInt16 => key_array_from_map_entries!(UInt16Builder, UInt16, map, i), DataType::UInt32 => key_array_from_map_entries!(UInt32Builder, UInt32, map, i), From 43623a734794abcd178f6047a67998408e47e942 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 24 Jul 2020 21:17:07 +0200 Subject: [PATCH 05/11] Added physical plan for inner join The gist of the implementation for a given partition is: ``` for left_record in left_records: hash_left = build_hash_of_keys(left_record) for right_record in right_records: hash_right = build_hash_of_keys(right_record) indexes = inner_join(hash_left, hash_right) yield concat(left_record, right_record)[indexes] ``` I.e. inefficient. The implementation is currently sequential, even though it can be trivially distributed as each RecordBatch is evaluated independently (we still lock the mutex on partition reading, as in other physical plans). Since we have not committed to a distributed computational model, IMO the sequential is enough for now. --- rust/datafusion/src/common.rs | 97 ++++ .../src/execution/physical_plan/hash_join.rs | 494 ++++++++++++++++++ .../src/execution/physical_plan/mod.rs | 2 + .../src/execution/physical_plan/utils.rs | 81 +++ rust/datafusion/src/lib.rs | 1 + 5 files changed, 675 insertions(+) create mode 100644 rust/datafusion/src/common.rs create mode 100644 rust/datafusion/src/execution/physical_plan/hash_join.rs create mode 100644 rust/datafusion/src/execution/physical_plan/utils.rs diff --git a/rust/datafusion/src/common.rs b/rust/datafusion/src/common.rs new file mode 100644 index 00000000000..b457f1efe27 --- /dev/null +++ b/rust/datafusion/src/common.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Functionality used both on logical and physical plans + +use crate::error::{ExecutionError, Result}; +use arrow::datatypes::{Field, Schema}; +use std::collections::HashSet; + +/// All valid types of joins. +#[derive(Clone, Debug)] +pub enum JoinHow { + /// Inner join + Inner, +} + +/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. +/// They are valid whenever their columns' intersection equals the set `on` +pub fn check_join_is_valid( + left: &Schema, + right: &Schema, + on: &HashSet, +) -> Result<()> { + let left: HashSet = left.fields().iter().map(|f| f.name().clone()).collect(); + let right: HashSet = + right.fields().iter().map(|f| f.name().clone()).collect(); + + check_join_set_is_valid(&left, &right, &on)?; + Ok(()) +} + +/// Checks whether the sets left, right and on compose a valid join. +/// They are valid whenever their intersection equals the set `on` +fn check_join_set_is_valid( + left: &HashSet, + right: &HashSet, + on: &HashSet, +) -> Result<()> { + let mandatory_columns = on.iter().map(|s| s).collect::>(); + let common_columns = left.intersection(&right).collect::>(); + let missing_columns = mandatory_columns + .difference(&common_columns) + .collect::>(); + if missing_columns.len() > 0 { + return Err(ExecutionError::General(format!( + "The left or right side of the join does not have columns {:?} columns on \"on\": \nLeft: {:?}\nRight: {:?}\nOn: {:?}", + missing_columns, + left, + right, + on, + ).to_string())); + }; + Ok(()) +} + +/// Creates a schema for a join operation. +/// The fields "on" from the left side are always first +pub fn build_join_schema( + left: &Schema, + right: &Schema, + on: &HashSet, + how: &JoinHow, +) -> Result { + let fields: Vec = match how { + JoinHow::Inner => { + // inner: all fields are there + + let on_fields = left.fields().iter().filter(|f| on.contains(f.name())); + + let left_fields = left.fields().iter().filter(|f| !on.contains(f.name())); + + let right_fields = right.fields().iter().filter(|f| !on.contains(f.name())); + + // "on" are first by construction, then left, then right + on_fields + .chain(left_fields) + .chain(right_fields) + .map(|f| f.clone()) + .collect() + } + }; + Ok(Schema::new(fields)) +} diff --git a/rust/datafusion/src/execution/physical_plan/hash_join.rs b/rust/datafusion/src/execution/physical_plan/hash_join.rs new file mode 100644 index 00000000000..d6a4caf0c36 --- /dev/null +++ b/rust/datafusion/src/execution/physical_plan/hash_join.rs @@ -0,0 +1,494 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the join plan for executing partitions in parallel and then joining the results +//! into a set of partitions. + +use std::collections::{HashMap, HashSet}; +use std::sync::{Arc, Mutex}; + +use arrow::array::Array; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow::record_batch::{RecordBatch, RecordBatchReader}; + +use super::{utils::build_array, ExecutionPlan}; +use crate::common::{build_join_schema, check_join_is_valid, JoinHow}; +use crate::error::{ExecutionError, Result}; +use crate::execution::physical_plan::common::RecordBatchIterator; +use crate::execution::physical_plan::expressions::col; +use crate::execution::physical_plan::hash::{create_key, KeyScalar}; +use crate::execution::physical_plan::Partition; + +// A mapping "on" value -> list of row indexes with this key's value +// E.g. [1, 2] -> [3, 6, 8] indicates that rows 3, 6 and 8 have (column1, column2) = [1, 2] +type JoinHashMap = HashMap, Vec>; + +/// join execution plan executes partitions in parallel and combines them into a set of +/// partitions. +pub struct HashJoinExec { + /// left side + left: Arc, + /// right side + right: Arc, + /// Set of common columns used to join on + on: HashSet, + /// How the join is performed + how: JoinHow, + /// The schema once the join is applied + schema: SchemaRef, +} + +impl HashJoinExec { + /// Create a new HashJoinExec + pub fn try_new( + left: Arc, + right: Arc, + on: &HashSet, + how: &JoinHow, + ) -> Result { + let left_schema = left.schema(); + let right_schema = left.schema(); + check_join_is_valid(&left_schema, &right_schema, &on)?; + + let on = on.iter().map(|s| s.clone()).collect::>(); + + let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &on, &how)?); + + Ok(HashJoinExec { + left, + right, + on: on.clone(), + how: how.clone(), + schema, + }) + } +} + +impl ExecutionPlan for HashJoinExec { + fn schema(&self) -> Arc { + self.schema.clone() + } + + fn partitions(&self) -> Result>> { + self.left + .partitions()? + .iter() + .map(move |p| { + let projection: Arc = Arc::new(HashJoinPartition { + schema: self.schema.clone(), + on: self.on.clone(), + how: self.how.clone(), + left: p.clone(), + rights: self.right.partitions()?.clone(), + }); + + Ok(projection) + }) + .collect::>>() + } +} + +/// Partition with a computed hash table +struct HashJoinPartition { + /// Input schema + schema: Arc, + /// columns used to compute the hash + on: HashSet, + /// how to join + how: JoinHow, + /// left partition + left: Arc, + /// partitions on the right + rights: Vec>, +} + +/// returns a HashMap +/// The size of this vector corresponds to the total size of a joined batch +fn build_hash_batch(on: &HashSet, batch: &RecordBatch) -> Result { + let mut hash: JoinHashMap = HashMap::new(); + + // evaluate the keys + let keys_values = on + .iter() + .map(|name| col(name).evaluate(batch)) + .collect::>>()?; + + // build the hash map + for row in 0..batch.num_rows() { + let mut key = Vec::with_capacity(keys_values.len()); + for i in 0..keys_values.len() { + key.push(create_key(&keys_values[i], row)?); + } + match hash.get_mut(&key) { + Some(v) => v.push(row), + None => { + hash.insert(key, vec![row]); + } + }; + } + Ok(hash) +} + +fn build_join_batch( + on: &HashSet, + schema: &Schema, + how: &JoinHow, + left: &RecordBatch, + right: &RecordBatch, + left_hash: &HashMap, Vec>, +) -> Result { + let right_hash = build_hash_batch(on, right)?; + + let join_indexes: Vec<(usize, usize)> = + build_join_indexes(&left_hash, &right_hash, how)?; + + // build the columns for the RecordBatch + let mut columns: Vec> = vec![]; + for field in schema.fields() { + // pick the column (left or right) based on the field name + // if two fields have the same name on left and right, the left is given preference + let (is_left, array) = match left.schema().index_of(field.name()) { + Ok(i) => Ok((true, left.column(i))), + Err(_) => { + match right.schema().index_of(field.name()) { + Ok(i) => Ok((false, right.column(i))), + _ => Err(ExecutionError::InternalError( + format!("During execution, the column {} was not found in neither the left or right side of the join", field.name()).to_string() + )) + } + } + }?; + + // pick the (left or right) indexes of the array + let indexes = join_indexes + .iter() + .map(|(left, right)| if is_left { *left } else { *right }) + .collect(); + + // build of the array out of the indexes. On a join, we expect more entries (due to duplicates) + let array = build_array(&array, &indexes, field.data_type())?; + columns.push(array); + } + Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?) +} + +/// returns a vector with (index from left, index from right). +/// The size of this vector corresponds to the total size of a joined batch +fn build_join_indexes( + left: &JoinHashMap, + right: &JoinHashMap, + how: &JoinHow, +) -> Result> { + // unfortunately rust does not support intersection of map keys :( + let left_set: HashSet> = left.keys().cloned().collect(); + let left_right: HashSet> = right.keys().cloned().collect(); + + match how { + JoinHow::Inner => { + let inner = left_set.intersection(&left_right); + + let mut indexes = Vec::new(); // unknown a prior size + for key in inner { + // the unwrap never happens by construction of the key + let left_indexes = left.get(key).unwrap(); + let right_indexes = right.get(key).unwrap(); + + // for every item on the left and right with this key, add the respective pair + left_indexes.iter().for_each(|x| { + right_indexes.iter().for_each(|y| { + indexes.push((*x, *y)); + }) + }) + } + Ok(indexes) + } + } +} + +/// filter values base on predicate +pub fn build_joined_partition( + schema: &Schema, + on: &HashSet, + how: &JoinHow, + left: &Arc, + right: &Arc, +) -> Result> { + let iterator = left.execute()?; + let mut input = iterator.lock().unwrap(); + + match input.next_batch()? { + None => Ok(None), + Some(left) => { + let left_hash = build_hash_batch(on, &left)?; + + let iterator_other = right.execute()?; + let mut input_other = iterator_other.lock().unwrap(); + match input_other.next_batch()? { + None => Ok(None), + Some(right) => Ok(Some(build_join_batch( + on, schema, how, &left, &right, &left_hash, + )?)), + } + } + } +} + +impl Partition for HashJoinPartition { + /// Execute the join + fn execute(&self) -> Result>> { + let batches = self + .rights + .iter() + .map(|right| { + build_joined_partition( + &self.schema, + &self.on, + &self.how, + &self.left, + &right, + ) + }) + .collect::>>>()?; + let batches = batches + .iter() + .filter_map(|x| match x { + Some(x) => Some(Arc::new(x.clone())), + None => None, + }) + .collect::>(); + + Ok(Arc::new(Mutex::new(RecordBatchIterator::new( + self.schema.clone(), + batches, + )))) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::execution::physical_plan::{memory::MemoryExec, ExecutionPlan}; + use arrow::{ + array::{Array, Int32Array}, + datatypes::{DataType, Field}, + }; + use std::collections::{HashMap, HashSet}; + use std::sync::Arc; + + fn statistics( + partitions: &Vec>, + ) -> Result<(usize, usize, HashMap)>)> { + // compute some statistics over the partitions + let mut partition_count = 0; + let mut row_count = 0; + let mut on_all = HashMap::new(); + for partition in partitions { + partition_count += 1; + let mut hash = HashSet::new(); + let iterator = partition.execute()?; + let mut iterator = iterator.lock().unwrap(); + let mut batch_row_count = 0; + while let Some(batch) = iterator.next_batch()? { + row_count += batch.num_rows(); + batch_row_count += batch.num_rows(); + let array = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..array.data().len() { + hash.insert(array.value(i).to_string()); + } + } + on_all.insert(partition_count, (batch_row_count, hash)); + } + Ok((row_count, partition_count, on_all)) + } + + fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Result> { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Int32, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + )?; + + Ok(Arc::new(MemoryExec::try_new( + &vec![vec![batch]], + Arc::new(schema.clone()), + None, + )?)) + } + + fn join( + left: Arc, + right: Arc, + on: &[&str], + ) -> Result { + let on = on.iter().map(|s| s.to_string()).collect::>(); + HashJoinExec::try_new(left, right, &on, &JoinHow::Inner) + } + + fn get_self_join( + on: &[&str], + a: &Vec, + b: &Vec, + c: &Vec, + ) -> Result { + let table = build_table(("a", a), ("b", b), ("c", c))?; + join(table, table.clone(), on) + } + + fn columns(join: &HashJoinExec) -> Vec { + join.schema() + .fields() + .iter() + .map(|f| f.name().clone()) + .collect() + } + + #[test] + fn self_join() -> Result<()> { + let on = vec!["a"]; + let a = vec![1, 2]; + let b = vec![1, 2]; + let c = vec![1, 2]; + + let (row_count, _, _) = + statistics(&get_self_join(&on, &a, &b, &c)?.partitions()?)?; + + // unique keys => no change in row count + assert_eq!(a.len(), row_count); + + Ok(()) + } + + #[test] + fn self_join_duplicates() -> Result<()> { + let on = vec!["a"]; + let a = vec![1, 2, 2]; + let b = vec![1, 2, 2]; + let c = vec![1, 2, 2]; + + let (row_count, _, _) = + statistics(&get_self_join(&on, &a, &b, &c)?.partitions()?)?; + + // one 1 + two 2s + assert_eq!(1 + 2 * 2, row_count); + + Ok(()) + } + + #[test] + fn self_join_two_columns() -> Result<()> { + let on = vec!["a", "b"]; + let a = vec![1, 2, 2]; + let b = vec![1, 2, 3]; + let c = vec![1, 2, 2]; + + let (row_count, _, _) = + statistics(&get_self_join(&on, &a, &b, &c)?.partitions()?)?; + + // one (1, 1), one (2, 2), one (2, 3) + assert_eq!(3, row_count); + + Ok(()) + } + + #[test] + fn self_join_two_columns_duplicates() -> Result<()> { + let on = vec!["a", "b"]; + let a = vec![1, 2, 2]; + let b = vec![1, 2, 2]; + let c = vec![1, 2, 2]; + + let (row_count, _, _) = + statistics(&get_self_join(&on, &a, &b, &c)?.partitions()?)?; + + // one (1, 1), two (2, 2) + assert_eq!(1 + 2 * 2, row_count); + + Ok(()) + } + + #[test] + fn join_one() -> Result<()> { + let t1 = build_table( + ("a1", &vec![1, 2, 2]), + ("b1", &vec![1, 2, 2]), + ("c1", &vec![1, 2, 2]), + )?; + let t2 = build_table( + ("a2", &vec![1, 2, 2]), + ("b1", &vec![1, 2, 3]), + ("c2", &vec![1, 2, 2]), + )?; + let on = vec!["b1"]; + + let join = join(t1, t2, &on)?; + + let columns = columns(&join); + assert_eq!(columns, vec!["b1", "a1", "c1", "a2", "c2"]); + + let (row_count, _, _) = statistics(&join.partitions()?)?; + + // one 1, two 2, 3 is only on the right + assert_eq!(1 + 2, row_count); + + Ok(()) + } + + #[test] + fn join_two() -> Result<()> { + let t1 = build_table( + ("a1", &vec![1, 2, 2]), + ("b2", &vec![1, 2, 2]), + ("c1", &vec![1, 2, 2]), + )?; + let t2 = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![1, 2, 2]), + )?; + let on = vec!["a1", "b2"]; + + let join = join(t1, t2, &on)?; + + let columns = columns(&join); + assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); + + let (row_count, _, _) = statistics(&join.partitions()?)?; + + // one (1, 1), two (2, 2), 3 is only on the right + assert_eq!(1 + 2, row_count); + + Ok(()) + } +} diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index 4d5c56e349b..c9bbef9f3ec 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -82,6 +82,7 @@ pub mod datasource; pub mod expressions; pub mod hash; pub mod hash_aggregate; +pub mod hash_join; pub mod limit; pub mod math_expressions; pub mod memory; @@ -91,3 +92,4 @@ pub mod projection; pub mod selection; pub mod sort; pub mod udf; +pub mod utils; diff --git a/rust/datafusion/src/execution/physical_plan/utils.rs b/rust/datafusion/src/execution/physical_plan/utils.rs new file mode 100644 index 00000000000..93b6b08a075 --- /dev/null +++ b/rust/datafusion/src/execution/physical_plan/utils.rs @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines utility functions for low-level operations on Arrays + +use std::sync::Arc; + +use crate::error::{ExecutionError, Result}; +use arrow::array::{ + Array, ArrayRef, BooleanArray, BooleanBuilder, Float32Array, Float32Builder, + Float64Array, Float64Builder, Int16Array, Int16Builder, Int32Array, Int32Builder, + Int64Array, Int64Builder, Int8Array, Int8Builder, StringArray, StringBuilder, + UInt16Array, UInt16Builder, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder, + UInt8Array, UInt8Builder, +}; +use arrow::datatypes::DataType; + +// cast, iterate over and re-build array +macro_rules! _build_array { + ($ARRAY:expr, $INDEXES:expr, $TYPE:ty, $BUILDER:ty) => {{ + let array = match $ARRAY.as_any().downcast_ref::<$TYPE>() { + Some(n) => Ok(n), + None => Err(ExecutionError::InternalError( + format!("Invalid data type for ").to_string(), + )), + }?; + + let mut builder = <$BUILDER>::new($INDEXES.len()); + for index in $INDEXES { + if array.is_null(*index) { + builder.append_null()?; + } else { + builder.append_value(array.value(*index))?; + } + } + Ok(Arc::new(builder.finish())) + };}; +} + +/// Builds and array +pub fn build_array( + array: &ArrayRef, + indexes: &Vec, + datatype: &DataType, +) -> Result { + match datatype { + DataType::Boolean => _build_array!(array, indexes, BooleanArray, BooleanBuilder), + DataType::Int8 => _build_array!(array, indexes, Int8Array, Int8Builder), + DataType::Int16 => _build_array!(array, indexes, Int16Array, Int16Builder), + DataType::Int32 => _build_array!(array, indexes, Int32Array, Int32Builder), + DataType::Int64 => _build_array!(array, indexes, Int64Array, Int64Builder), + DataType::UInt8 => _build_array!(array, indexes, UInt8Array, UInt8Builder), + DataType::UInt16 => _build_array!(array, indexes, UInt16Array, UInt16Builder), + DataType::UInt32 => _build_array!(array, indexes, UInt32Array, UInt32Builder), + DataType::UInt64 => _build_array!(array, indexes, UInt64Array, UInt64Builder), + DataType::Float64 => _build_array!(array, indexes, Float64Array, Float64Builder), + DataType::Float32 => _build_array!(array, indexes, Float32Array, Float32Builder), + DataType::Utf8 => _build_array!(array, indexes, StringArray, StringBuilder), + _ => Err(ExecutionError::NotImplemented( + format!( + "Conversions for type {:?} are still not implemented", + datatype + ) + .to_string(), + )), + } +} diff --git a/rust/datafusion/src/lib.rs b/rust/datafusion/src/lib.rs index fb4e5af303f..c69c0b55698 100644 --- a/rust/datafusion/src/lib.rs +++ b/rust/datafusion/src/lib.rs @@ -29,6 +29,7 @@ extern crate arrow; extern crate sqlparser; +pub mod common; pub mod datasource; pub mod error; pub mod execution; From 9e85a5b9ed6559b125ac48a9091b27d4e58c4112 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 25 Jul 2020 09:08:39 +0200 Subject: [PATCH 06/11] Added LogicalPlan::Join and respective integration. --- rust/datafusion/src/execution/context.rs | 12 +++++ rust/datafusion/src/logicalplan.rs | 51 ++++++++++++++++++- .../src/optimizer/projection_push_down.rs | 19 +++++++ .../datafusion/src/optimizer/type_coercion.rs | 20 ++++++++ 4 files changed, 101 insertions(+), 1 deletion(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index cec0f1531e0..cc0a04b0fd8 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -28,6 +28,7 @@ use arrow::csv; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; +use super::physical_plan::hash_join::HashJoinExec; use crate::datasource::csv::CsvFile; use crate::datasource::parquet::ParquetTable; use crate::datasource::TableProvider; @@ -439,6 +440,17 @@ impl ExecutionContext { merge, )?)) } + LogicalPlan::Join { + left, + right, + on, + how, + .. + } => { + let left = self.create_physical_plan(left, batch_size)?; + let right = self.create_physical_plan(right, batch_size)?; + Ok(Arc::new(HashJoinExec::try_new(left, right, on, how)?)) + } LogicalPlan::Selection { input, expr, .. } => { let input = self.create_physical_plan(input, batch_size)?; let input_schema = input.as_ref().schema().clone(); diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 25780f9a127..66adc0d29a2 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -21,10 +21,11 @@ //! Logical query plans can then be optimized and executed directly, or translated into //! physical query plans and executed. -use std::fmt; +use std::{collections::HashSet, fmt}; use arrow::datatypes::{DataType, Field, Schema}; +use crate::common::{build_join_schema, check_join_is_valid, JoinHow}; use crate::datasource::csv::{CsvFile, CsvReadOptions}; use crate::datasource::parquet::ParquetTable; use crate::datasource::TableProvider; @@ -636,6 +637,19 @@ pub enum LogicalPlan { /// The schema description schema: Box, }, + /// Represents a join operation + Join { + /// The left side + left: Box, + /// The right side + right: Box, + /// Grouping column names + on: HashSet, + /// How the join is performed + how: JoinHow, + /// The schema description + schema: Box, + }, /// Represents a list of sort expressions to be applied to a relation Sort { /// The sort expressions @@ -747,6 +761,7 @@ impl LogicalPlan { LogicalPlan::Sort { schema, .. } => &schema, LogicalPlan::Limit { schema, .. } => &schema, LogicalPlan::CreateExternalTable { schema, .. } => &schema, + LogicalPlan::Join { schema, .. } => &schema, } } } @@ -834,6 +849,18 @@ impl LogicalPlan { write!(f, "Limit: {}", n)?; input.fmt_with_indent(f, indent + 1) } + LogicalPlan::Join { + ref left, + ref right, + ref on, + ref how, + .. + } => { + let on = on.iter().map(|x| x.clone()).collect::>().join(", #"); + write!(f, "Join: on=[#{}] how={:?}", on, how)?; + left.fmt_with_indent(f, indent + 1)?; + right.fmt_with_indent(f, indent + 1) + } LogicalPlan::CreateExternalTable { ref name, .. } => { write!(f, "CreateExternalTable: {:?}", name) } @@ -1045,6 +1072,28 @@ impl LogicalPlanBuilder { })) } + /// Apply a join + pub fn join( + &self, + right: &LogicalPlan, + on: &Vec, + how: &JoinHow, + ) -> Result { + let on = on.iter().map(|s| s.clone()).collect::>(); + + check_join_is_valid(&self.plan.schema(), &right.schema(), &on)?; + + let schema = build_join_schema(&self.plan.schema(), &right.schema(), &on, how)?; + + Ok(Self::from(&LogicalPlan::Join { + left: Box::new(self.plan.clone()), + right: Box::new(right.clone()), + on, + how: how.clone(), + schema: Box::new(schema), + })) + } + /// Apply an aggregate pub fn aggregate(&self, group_expr: Vec, aggr_expr: Vec) -> Result { let mut all_expr: Vec = group_expr.clone(); diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs index e99a996d7c9..0baac025576 100644 --- a/rust/datafusion/src/optimizer/projection_push_down.rs +++ b/rust/datafusion/src/optimizer/projection_push_down.rs @@ -110,6 +110,25 @@ impl ProjectionPushDown { input: Box::new(self.optimize_plan(&input, accum, has_projection)?), schema: schema.clone(), }), + LogicalPlan::Join { + left, + right, + on, + how, + schema, + } => { + // optimize each of the plans + let left = self.optimize_plan(&left, accum, has_projection)?; + let right = self.optimize_plan(&right, accum, has_projection)?; + + Ok(LogicalPlan::Join { + left: Box::new(left), + right: Box::new(right), + on: on.clone(), + how: how.clone(), + schema: schema.clone(), + }) + } LogicalPlan::EmptyRelation { .. } => Ok(plan.clone()), LogicalPlan::TableScan { schema_name, diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index a03a92cdfe5..0e9ff73ce7f 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -165,6 +165,26 @@ impl<'a> OptimizerRule for TypeCoercionRule<'a> { self.rewrite_expr_list(aggr_expr, input.schema())?, )? .build(), + LogicalPlan::Join { + left, + right, + on, + how, + schema, + } => { + // no optimization to be made on this node, so we just pass the optimization + // to its children + let left = self.optimize(&left)?; + let right = self.optimize(&right)?; + + Ok(LogicalPlan::Join { + left: Box::new(left), + right: Box::new(right), + on: on.clone(), + how: how.clone(), + schema: schema.clone(), + }) + } LogicalPlan::TableScan { .. } => Ok(plan.clone()), LogicalPlan::InMemoryScan { .. } => Ok(plan.clone()), LogicalPlan::ParquetScan { .. } => Ok(plan.clone()), From e4bfa276efb709f618e7247ff564d647b35afebb Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 25 Jul 2020 09:55:17 +0200 Subject: [PATCH 07/11] Added test for LogicalPla::Join --- .../src/execution/physical_plan/hash_join.rs | 187 ++++-------------- rust/datafusion/src/logicalplan.rs | 43 +++- rust/datafusion/src/test/mod.rs | 29 +++ 3 files changed, 107 insertions(+), 152 deletions(-) diff --git a/rust/datafusion/src/execution/physical_plan/hash_join.rs b/rust/datafusion/src/execution/physical_plan/hash_join.rs index d6a4caf0c36..b7eddee33fd 100644 --- a/rust/datafusion/src/execution/physical_plan/hash_join.rs +++ b/rust/datafusion/src/execution/physical_plan/hash_join.rs @@ -61,7 +61,7 @@ impl HashJoinExec { how: &JoinHow, ) -> Result { let left_schema = left.schema(); - let right_schema = left.schema(); + let right_schema = right.schema(); check_join_is_valid(&left_schema, &right_schema, &on)?; let on = on.iter().map(|s| s.clone()).collect::>(); @@ -282,67 +282,22 @@ impl Partition for HashJoinPartition { mod tests { use super::*; - use crate::execution::physical_plan::{memory::MemoryExec, ExecutionPlan}; - use arrow::{ - array::{Array, Int32Array}, - datatypes::{DataType, Field}, + use crate::{ + execution::physical_plan::{common, memory::MemoryExec, ExecutionPlan}, + test::{build_table_i32, columns, format_batch}, }; - use std::collections::{HashMap, HashSet}; + use std::collections::HashSet; use std::sync::Arc; - fn statistics( - partitions: &Vec>, - ) -> Result<(usize, usize, HashMap)>)> { - // compute some statistics over the partitions - let mut partition_count = 0; - let mut row_count = 0; - let mut on_all = HashMap::new(); - for partition in partitions { - partition_count += 1; - let mut hash = HashSet::new(); - let iterator = partition.execute()?; - let mut iterator = iterator.lock().unwrap(); - let mut batch_row_count = 0; - while let Some(batch) = iterator.next_batch()? { - row_count += batch.num_rows(); - batch_row_count += batch.num_rows(); - let array = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - for i in 0..array.data().len() { - hash.insert(array.value(i).to_string()); - } - } - on_all.insert(partition_count, (batch_row_count, hash)); - } - Ok((row_count, partition_count, on_all)) - } - fn build_table( a: (&str, &Vec), b: (&str, &Vec), c: (&str, &Vec), ) -> Result> { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Int32, false), - Field::new(b.0, DataType::Int32, false), - Field::new(c.0, DataType::Int32, false), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![ - Arc::new(Int32Array::from(a.1.clone())), - Arc::new(Int32Array::from(b.1.clone())), - Arc::new(Int32Array::from(c.1.clone())), - ], - )?; - + let (batch, schema) = build_table_i32(a, b, c)?; Ok(Arc::new(MemoryExec::try_new( &vec![vec![batch]], - Arc::new(schema.clone()), + Arc::new(schema), None, )?)) } @@ -356,111 +311,50 @@ mod tests { HashJoinExec::try_new(left, right, &on, &JoinHow::Inner) } - fn get_self_join( - on: &[&str], - a: &Vec, - b: &Vec, - c: &Vec, - ) -> Result { - let table = build_table(("a", a), ("b", b), ("c", c))?; - join(table, table.clone(), on) - } + /// Asserts that the rows are the same, taking into account that their order + /// is irrelevant + fn assert_same_rows(result: &[String], expected: &[&str]) { + assert_eq!(result.len(), expected.len()); - fn columns(join: &HashJoinExec) -> Vec { - join.schema() - .fields() + // convert to set since row order is irrelevant + let result = result .iter() - .map(|f| f.name().clone()) - .collect() - } - - #[test] - fn self_join() -> Result<()> { - let on = vec!["a"]; - let a = vec![1, 2]; - let b = vec![1, 2]; - let c = vec![1, 2]; - - let (row_count, _, _) = - statistics(&get_self_join(&on, &a, &b, &c)?.partitions()?)?; - - // unique keys => no change in row count - assert_eq!(a.len(), row_count); - - Ok(()) - } - - #[test] - fn self_join_duplicates() -> Result<()> { - let on = vec!["a"]; - let a = vec![1, 2, 2]; - let b = vec![1, 2, 2]; - let c = vec![1, 2, 2]; - - let (row_count, _, _) = - statistics(&get_self_join(&on, &a, &b, &c)?.partitions()?)?; + .map(|s| s.clone()) + .collect::>(); - // one 1 + two 2s - assert_eq!(1 + 2 * 2, row_count); - - Ok(()) - } - - #[test] - fn self_join_two_columns() -> Result<()> { - let on = vec!["a", "b"]; - let a = vec![1, 2, 2]; - let b = vec![1, 2, 3]; - let c = vec![1, 2, 2]; - - let (row_count, _, _) = - statistics(&get_self_join(&on, &a, &b, &c)?.partitions()?)?; - - // one (1, 1), one (2, 2), one (2, 3) - assert_eq!(3, row_count); - - Ok(()) - } - - #[test] - fn self_join_two_columns_duplicates() -> Result<()> { - let on = vec!["a", "b"]; - let a = vec![1, 2, 2]; - let b = vec![1, 2, 2]; - let c = vec![1, 2, 2]; - - let (row_count, _, _) = - statistics(&get_self_join(&on, &a, &b, &c)?.partitions()?)?; - - // one (1, 1), two (2, 2) - assert_eq!(1 + 2 * 2, row_count); - - Ok(()) + let expected = expected + .iter() + .map(|s| s.to_string()) + .collect::>(); + assert_eq!(result, expected); } #[test] fn join_one() -> Result<()> { let t1 = build_table( - ("a1", &vec![1, 2, 2]), - ("b1", &vec![1, 2, 2]), - ("c1", &vec![1, 2, 2]), + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), )?; let t2 = build_table( - ("a2", &vec![1, 2, 2]), - ("b1", &vec![1, 2, 3]), - ("c2", &vec![1, 2, 2]), + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), )?; let on = vec!["b1"]; let join = join(t1, t2, &on)?; - let columns = columns(&join); + let columns = columns(&join.schema()); assert_eq!(columns, vec!["b1", "a1", "c1", "a2", "c2"]); - let (row_count, _, _) = statistics(&join.partitions()?)?; + let batches = common::collect(join.partitions()?[0].execute()?)?; + assert_eq!(batches.len(), 1); - // one 1, two 2, 3 is only on the right - assert_eq!(1 + 2, row_count); + let result = format_batch(&batches[0]); + let expected = vec!["5,2,8,20,80", "5,3,9,20,80", "4,1,7,10,70"]; + + assert_same_rows(&result, &expected); Ok(()) } @@ -470,24 +364,27 @@ mod tests { let t1 = build_table( ("a1", &vec![1, 2, 2]), ("b2", &vec![1, 2, 2]), - ("c1", &vec![1, 2, 2]), + ("c1", &vec![7, 8, 9]), )?; let t2 = build_table( ("a1", &vec![1, 2, 3]), ("b2", &vec![1, 2, 2]), - ("c2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), )?; let on = vec!["a1", "b2"]; let join = join(t1, t2, &on)?; - let columns = columns(&join); + let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]); - let (row_count, _, _) = statistics(&join.partitions()?)?; + let batches = common::collect(join.partitions()?[0].execute()?)?; + assert_eq!(batches.len(), 1); + + let result = format_batch(&batches[0]); + let expected = vec!["1,1,7,70", "2,2,8,80", "2,2,9,80"]; - // one (1, 1), two (2, 2), 3 is only on the right - assert_eq!(1 + 2, row_count); + assert_same_rows(&result, &expected); Ok(()) } diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 66adc0d29a2..1b5a8261503 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -1073,13 +1073,8 @@ impl LogicalPlanBuilder { } /// Apply a join - pub fn join( - &self, - right: &LogicalPlan, - on: &Vec, - how: &JoinHow, - ) -> Result { - let on = on.iter().map(|s| s.clone()).collect::>(); + pub fn join(&self, right: &LogicalPlan, on: &[&str], how: &JoinHow) -> Result { + let on = on.iter().map(|s| s.to_string()).collect::>(); check_join_is_valid(&self.plan.schema(), &right.schema(), &on)?; @@ -1118,6 +1113,7 @@ impl LogicalPlanBuilder { #[cfg(test)] mod tests { use super::*; + use crate::test::{build_table_i32, columns}; #[test] fn plan_builder_simple() -> Result<()> { @@ -1185,6 +1181,39 @@ mod tests { Ok(()) } + fn build_table(a: &str, b: &str, c: &str) -> Result { + let (batch, schema) = build_table_i32((a, &vec![]), (b, &vec![]), (c, &vec![]))?; + + Ok(LogicalPlan::InMemoryScan { + data: vec![vec![batch]], + schema: Box::new(schema.clone()), + projection: None, + projected_schema: Box::new(schema), + }) + } + + #[test] + fn plan_builder_join() -> Result<()> { + let on = vec!["a1"]; + let t1 = build_table("a1", "b1", "c1")?; + let t2 = build_table("a1", "b2", "c2")?; + + let plan = LogicalPlanBuilder::from(&t1) + .join(&t2, &on, &JoinHow::Inner)? + .build()?; + + let expected = "\ + Join: on=[#a1] how=Inner\ + \n InMemoryScan: projection=None\ + \n InMemoryScan: projection=None"; + assert_eq!(expected, format!("{:?}", plan)); + + let columns = columns(&plan.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "b2", "c2"]); + + Ok(()) + } + #[test] fn plan_builder_sort() -> Result<()> { let plan = LogicalPlanBuilder::scan( diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs index 317c14564f1..72eba71ccd9 100644 --- a/rust/datafusion/src/test/mod.rs +++ b/rust/datafusion/src/test/mod.rs @@ -21,6 +21,7 @@ use crate::error::Result; use crate::execution::context::ExecutionContext; use crate::execution::physical_plan::ExecutionPlan; use crate::logicalplan::{Expr, LogicalPlan, LogicalPlanBuilder}; +use array::Int32Array; use arrow::array; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; @@ -231,3 +232,31 @@ pub fn max(expr: Expr) -> Expr { return_type: DataType::Float64, } } + +/// returns a table with 3 columns of i32 in memory +pub fn build_table_i32( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), +) -> Result<(RecordBatch, Schema)> { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Int32, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + )?; + Ok((batch, schema)) +} + +/// Returns the column names on the schema +pub fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() +} From 221694a3b4d283fa7088add15ced35dfc29d1041 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 25 Jul 2020 10:03:04 +0200 Subject: [PATCH 08/11] Minor generalization of a test. --- rust/datafusion/src/execution/context.rs | 41 +++++++++++++----------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index cc0a04b0fd8..69acbe0dd6e 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -816,28 +816,31 @@ mod tests { Ok(()) } - #[test] - fn projection_on_memory_scan() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), - ]); - let plan = LogicalPlanBuilder::from(&LogicalPlan::InMemoryScan { - data: vec![vec![RecordBatch::try_new( - Arc::new(schema.clone()), - vec![ - Arc::new(Int32Array::from(vec![1, 10, 10, 100])), - Arc::new(Int32Array::from(vec![2, 12, 12, 120])), - Arc::new(Int32Array::from(vec![3, 12, 12, 120])), - ], - )?]], + fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Result { + let (batch, schema) = build_table_i32(a, b, c)?; + + Ok(LogicalPlan::InMemoryScan { + data: vec![vec![batch]], schema: Box::new(schema.clone()), projection: None, - projected_schema: Box::new(schema.clone()), + projected_schema: Box::new(schema), }) - .project(vec![col("b")])? - .build()?; + } + + #[test] + fn projection_on_memory_scan() -> Result<()> { + let plan = build_table( + ("a", &vec![1, 10, 10, 100]), + ("b", &vec![2, 12, 12, 120]), + ("c", &vec![3, 12, 12, 120]), + )?; + let plan = LogicalPlanBuilder::from(&plan) + .project(vec![col("b")])? + .build()?; assert_fields_eq(&plan, vec!["b"]); let ctx = ExecutionContext::new(); From 0e86c6796c4532647a69e89ed0b814f3cfcdb69a Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 25 Jul 2020 11:07:08 +0200 Subject: [PATCH 09/11] Added test for context using joins. --- rust/datafusion/src/execution/context.rs | 25 +++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 69acbe0dd6e..758e8830b0b 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -701,7 +701,7 @@ mod tests { use crate::datasource::MemTable; use crate::execution::physical_plan::udf::ScalarUdf; use crate::logicalplan::{aggregate_expr, col, scalar_function}; - use crate::test; + use crate::{common::JoinHow, test}; use arrow::array::{ArrayRef, Int32Array}; use arrow::compute::add; use std::fs::File; @@ -877,6 +877,29 @@ mod tests { Ok(()) } + #[test] + fn join() -> Result<()> { + let left = + build_table(("a", &vec![1, 1]), ("b", &vec![2, 3]), ("c", &vec![3, 4]))?; + let right = build_table( + ("a", &vec![1, 1]), + ("b2", &vec![12, 13]), + ("c2", &vec![13, 14]), + )?; + let plan = LogicalPlanBuilder::from(&left) + .join(&right, &vec!["a"], &JoinHow::Inner)? + .build()?; + + let ctx = ExecutionContext::new(); + let physical_plan = ctx.create_physical_plan(&plan, 1024)?; + + let batches = ctx.collect(physical_plan.as_ref())?; + let expected: Vec<&str> = + vec!["1,2,3,12,13", "1,2,3,13,14", "1,3,4,12,13", "1,3,4,13,14"]; + assert_eq!(test::format_batch(&batches[0]), expected); + Ok(()) + } + #[test] fn sort() -> Result<()> { let results = execute("SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC", 4)?; From e8ed369a191b43b65c848dbe7efa9f59af5486ce Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 25 Jul 2020 12:10:46 +0200 Subject: [PATCH 10/11] Added support for INNER JOIN on sql. --- rust/datafusion/src/execution/context.rs | 2 +- .../src/execution/physical_plan/hash_join.rs | 5 +- rust/datafusion/src/logicalplan.rs | 9 ++- rust/datafusion/src/sql/planner.rs | 75 ++++++++++++++++++- 4 files changed, 82 insertions(+), 9 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 758e8830b0b..02f7162751e 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -887,7 +887,7 @@ mod tests { ("c2", &vec![13, 14]), )?; let plan = LogicalPlanBuilder::from(&left) - .join(&right, &vec!["a"], &JoinHow::Inner)? + .join(&right, &vec!["a".to_string()], &JoinHow::Inner)? .build()?; let ctx = ExecutionContext::new(); diff --git a/rust/datafusion/src/execution/physical_plan/hash_join.rs b/rust/datafusion/src/execution/physical_plan/hash_join.rs index b7eddee33fd..aa8211ba521 100644 --- a/rust/datafusion/src/execution/physical_plan/hash_join.rs +++ b/rust/datafusion/src/execution/physical_plan/hash_join.rs @@ -317,10 +317,7 @@ mod tests { assert_eq!(result.len(), expected.len()); // convert to set since row order is irrelevant - let result = result - .iter() - .map(|s| s.clone()) - .collect::>(); + let result = result.iter().map(|s| s.clone()).collect::>(); let expected = expected .iter() diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 1b5a8261503..b2b878c9b3a 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -1073,7 +1073,12 @@ impl LogicalPlanBuilder { } /// Apply a join - pub fn join(&self, right: &LogicalPlan, on: &[&str], how: &JoinHow) -> Result { + pub fn join( + &self, + right: &LogicalPlan, + on: &[String], + how: &JoinHow, + ) -> Result { let on = on.iter().map(|s| s.to_string()).collect::>(); check_join_is_valid(&self.plan.schema(), &right.schema(), &on)?; @@ -1194,7 +1199,7 @@ mod tests { #[test] fn plan_builder_join() -> Result<()> { - let on = vec!["a1"]; + let on = vec!["a1".to_string()]; let t1 = build_table("a1", "b1", "c1")?; let t2 = build_table("a1", "b2", "c2")?; diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 21cf870ba34..9e883752bdf 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -26,7 +26,7 @@ use crate::logicalplan::{ use arrow::datatypes::*; -use crate::logicalplan::Expr::Alias; +use crate::{common::JoinHow, logicalplan::Expr::Alias}; use sqlparser::sqlast::*; /// The SchemaProvider trait allows the query planner to obtain meta-data about tables and @@ -60,6 +60,7 @@ impl SqlToRel { ref limit, ref group_by, ref having, + ref joins, .. } => { if having.is_some() { @@ -74,7 +75,9 @@ impl SqlToRel { None => LogicalPlanBuilder::empty().build()?, }; - // selection first + // join first, since a filter may include columns from both sides + let plan = self.join(&plan, &joins)?; + let plan = self.filter(&plan, selection)?; let projection_expr: Vec = projection @@ -125,6 +128,43 @@ impl SqlToRel { } } + /// Apply a join to the plan + pub fn join(&self, plan: &LogicalPlan, joins: &[Join]) -> Result { + if joins.len() == 0 { + // short-circuit if no join exists + return Ok(plan.clone()); + } + if joins.len() > 1 { + return Err(ExecutionError::NotImplemented( + "statements with more than one join relation are still not supported" + .to_owned(), + )); + }; + let join: &Join = &joins[0]; + + match &join.join_operator { + JoinOperator::Inner(JoinConstraint::On(relation)) => { + let expr = self.sql_to_rex(&relation, &plan.schema())?; + + let names = match expr { + Expr::Column(name) => Ok(vec![name]), + _ => Err(ExecutionError::NotImplemented( + "Only joins on single columns are supported".to_owned(), + )), + }?; + + let right = self.sql_to_rel(&join.relation)?; + + LogicalPlanBuilder::from(&plan) + .join(&right, &names, &JoinHow::Inner)? + .build() + } + _ => Err(ExecutionError::NotImplemented( + "Only inner joins (ON) are currently supported".to_owned(), + )), + } + } + /// Apply a filter to the plan fn filter( &self, @@ -556,6 +596,27 @@ mod tests { ); } + #[test] + fn one_column_join() { + quick_test( + "SELECT a FROM simple1 JOIN simple2 ON a", + "\ + Projection: #a\ + \n Join: on=[#a] how=Inner\ + \n TableScan: simple1 projection=None\ + \n TableScan: simple2 projection=None", + ); + + quick_test( + "SELECT a FROM simple1 JOIN simple2 ON b", + "\ + Projection: #a\ + \n Join: on=[#b] how=Inner\ + \n TableScan: simple1 projection=None\ + \n TableScan: simple2 projection=None", + ); + } + #[test] fn test_wildcard() { quick_test( @@ -719,6 +780,16 @@ mod tests { Field::new("c12", DataType::Float64, false), Field::new("c13", DataType::Utf8, false), ]))), + "simple1" => Some(Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::UInt32, false), + Field::new("c1", DataType::UInt32, false), + ]))), + "simple2" => Some(Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::UInt32, false), + Field::new("c2", DataType::UInt32, false), + ]))), _ => None, } } From ae8dbd3e7e1166a5d5e7e465a2da729c193b88fe Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 25 Jul 2020 14:10:43 +0200 Subject: [PATCH 11/11] Added tests to join checks. --- rust/datafusion/src/common.rs | 56 ++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/rust/datafusion/src/common.rs b/rust/datafusion/src/common.rs index b457f1efe27..c904d5c0c6a 100644 --- a/rust/datafusion/src/common.rs +++ b/rust/datafusion/src/common.rs @@ -50,15 +50,21 @@ fn check_join_set_is_valid( right: &HashSet, on: &HashSet, ) -> Result<()> { - let mandatory_columns = on.iter().map(|s| s).collect::>(); + if on.len() == 0 { + return Err(ExecutionError::General( + "The 'on' clause of a join cannot be empty".to_string(), + )); + } + + let on_columns = on.iter().map(|s| s).collect::>(); let common_columns = left.intersection(&right).collect::>(); - let missing_columns = mandatory_columns + let missing = on_columns .difference(&common_columns) .collect::>(); - if missing_columns.len() > 0 { + if missing.len() > 0 { return Err(ExecutionError::General(format!( "The left or right side of the join does not have columns {:?} columns on \"on\": \nLeft: {:?}\nRight: {:?}\nOn: {:?}", - missing_columns, + missing, left, right, on, @@ -95,3 +101,45 @@ pub fn build_join_schema( }; Ok(Schema::new(fields)) } + +#[cfg(test)] +mod tests { + + use super::*; + + fn check(left: &[&str], right: &[&str], on: &[&str]) -> Result<()> { + let left = left.iter().map(|x| x.to_string()).collect::>(); + let right = right.iter().map(|x| x.to_string()).collect::>(); + let on = on.iter().map(|x| x.to_string()).collect::>(); + + check_join_set_is_valid(&left, &right, &on) + } + + #[test] + fn check_valid() -> Result<()> { + let left = vec!["a", "b1"]; + let right = vec!["a", "b2"]; + let on = vec!["a"]; + + check(&left, &right, &on)?; + Ok(()) + } + + #[test] + fn check_not_in_right() { + let left = vec!["a", "b"]; + let right = vec!["b"]; + let on = vec!["a"]; + + assert!(check(&left, &right, &on).is_err()); + } + + #[test] + fn check_not_in_left() { + let left = vec!["b"]; + let right = vec!["a"]; + let on = vec!["a"]; + + assert!(check(&left, &right, &on).is_err()); + } +}