diff --git a/rust/datafusion/src/common.rs b/rust/datafusion/src/common.rs
new file mode 100644
index 00000000000..c904d5c0c6a
--- /dev/null
+++ b/rust/datafusion/src/common.rs
@@ -0,0 +1,145 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Functionality used both on logical and physical plans
+
+use crate::error::{ExecutionError, Result};
+use arrow::datatypes::{Field, Schema};
+use std::collections::HashSet;
+
+/// All valid types of joins.
+#[derive(Clone, Debug)]
+pub enum JoinHow {
+    /// Inner join
+    Inner,
+}
+
+/// Checks whether the schemas `left` and `right` and the columns `on` represent a valid join.
+/// The join is valid whenever every column in `on` exists on both sides.
+pub fn check_join_is_valid(
+    left: &Schema,
+    right: &Schema,
+    on: &HashSet<String>,
+) -> Result<()> {
+    let left: HashSet<String> =
+        left.fields().iter().map(|f| f.name().clone()).collect();
+    let right: HashSet<String> =
+        right.fields().iter().map(|f| f.name().clone()).collect();
+
+    check_join_set_is_valid(&left, &right, &on)?;
+    Ok(())
+}
+
+/// Checks whether the sets `left`, `right` and `on` compose a valid join.
+/// They are valid whenever `on` is non-empty and contained in the intersection of `left` and `right`.
+fn check_join_set_is_valid(
+    left: &HashSet<String>,
+    right: &HashSet<String>,
+    on: &HashSet<String>,
+) -> Result<()> {
+    if on.is_empty() {
+        return Err(ExecutionError::General(
+            "The 'on' clause of a join cannot be empty".to_string(),
+        ));
+    }
+
+    let on_columns = on.iter().collect::<HashSet<_>>();
+    let common_columns = left.intersection(&right).collect::<HashSet<_>>();
+    let missing = on_columns
+        .difference(&common_columns)
+        .collect::<HashSet<_>>();
+    if !missing.is_empty() {
+        return Err(ExecutionError::General(format!(
+            "The left or right side of the join does not contain the \"on\" columns {:?}:\nLeft: {:?}\nRight: {:?}\nOn: {:?}",
+            missing,
+            left,
+            right,
+            on,
+        )));
+    };
+    Ok(())
+}
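// Illustrative sketch, not part of the original patch: how the validity check
// above behaves on hypothetical column sets. Only "on" columns present on
// *both* sides pass:
//
//     let left: HashSet<String> = ["a", "b1"].iter().map(|s| s.to_string()).collect();
//     let right: HashSet<String> = ["a", "b2"].iter().map(|s| s.to_string()).collect();
//
//     let on: HashSet<String> = ["a"].iter().map(|s| s.to_string()).collect();
//     assert!(check_join_set_is_valid(&left, &right, &on).is_ok());  // "a" is common
//
//     let on: HashSet<String> = ["b1"].iter().map(|s| s.to_string()).collect();
//     assert!(check_join_set_is_valid(&left, &right, &on).is_err()); // "b1" only exists on the left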
+
+/// Creates a schema for a join operation.
+/// The fields "on" from the left side are always first.
+pub fn build_join_schema(
+    left: &Schema,
+    right: &Schema,
+    on: &HashSet<String>,
+    how: &JoinHow,
+) -> Result<Schema> {
+    let fields: Vec<Field> = match how {
+        JoinHow::Inner => {
+            // inner: all fields are there
+
+            let on_fields = left.fields().iter().filter(|f| on.contains(f.name()));
+
+            let left_fields = left.fields().iter().filter(|f| !on.contains(f.name()));
+
+            let right_fields = right.fields().iter().filter(|f| !on.contains(f.name()));
+
+            // "on" fields are first by construction, then left, then right
+            on_fields
+                .chain(left_fields)
+                .chain(right_fields)
+                .cloned()
+                .collect()
+        }
+    };
+    Ok(Schema::new(fields))
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+
+    fn check(left: &[&str], right: &[&str], on: &[&str]) -> Result<()> {
+        let left = left
+            .iter()
+            .map(|x| x.to_string())
+            .collect::<HashSet<String>>();
+        let right = right
+            .iter()
+            .map(|x| x.to_string())
+            .collect::<HashSet<String>>();
+        let on = on.iter().map(|x| x.to_string()).collect::<HashSet<String>>();
+
+        check_join_set_is_valid(&left, &right, &on)
+    }
+
+    #[test]
+    fn check_valid() -> Result<()> {
+        let left = vec!["a", "b1"];
+        let right = vec!["a", "b2"];
+        let on = vec!["a"];
+
+        check(&left, &right, &on)?;
+        Ok(())
+    }
+
+    #[test]
+    fn check_not_in_right() {
+        let left = vec!["a", "b"];
+        let right = vec!["b"];
+        let on = vec!["a"];
+
+        assert!(check(&left, &right, &on).is_err());
+    }
+
+    #[test]
+    fn check_not_in_left() {
+        let left = vec!["b"];
+        let right = vec!["a"];
+        let on = vec!["a"];
+
+        assert!(check(&left, &right, &on).is_err());
+    }
+}
diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index cec0f1531e0..02f7162751e 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -28,6 +28,7 @@ use arrow::csv;
 use arrow::datatypes::*;
 use arrow::record_batch::RecordBatch;
 
+use super::physical_plan::hash_join::HashJoinExec;
 use crate::datasource::csv::CsvFile;
 use crate::datasource::parquet::ParquetTable;
 use crate::datasource::TableProvider;
@@ -439,6 +440,17 @@ impl ExecutionContext {
                     merge,
                 )?))
             }
+            LogicalPlan::Join {
+                left,
+                right,
+                on,
+                how,
+                ..
+            } => {
+                let left = self.create_physical_plan(left, batch_size)?;
+                let right = self.create_physical_plan(right, batch_size)?;
+                Ok(Arc::new(HashJoinExec::try_new(left, right, on, how)?))
+            }
            LogicalPlan::Selection { input, expr, .. } => {
                 let input = self.create_physical_plan(input, batch_size)?;
                 let input_schema = input.as_ref().schema().clone();
@@ -689,7 +701,7 @@ mod tests {
     use crate::datasource::MemTable;
     use crate::execution::physical_plan::udf::ScalarUdf;
     use crate::logicalplan::{aggregate_expr, col, scalar_function};
-    use crate::test;
+    use crate::{common::JoinHow, test};
     use arrow::array::{ArrayRef, Int32Array};
     use arrow::compute::add;
     use std::fs::File;
@@ -804,28 +816,31 @@ mod tests {
         Ok(())
     }
 
-    #[test]
-    fn projection_on_memory_scan() -> Result<()> {
-        let schema = Schema::new(vec![
-            Field::new("a", DataType::Int32, false),
-            Field::new("b", DataType::Int32, false),
-            Field::new("c", DataType::Int32, false),
-        ]);
-        let plan = LogicalPlanBuilder::from(&LogicalPlan::InMemoryScan {
-            data: vec![vec![RecordBatch::try_new(
-                Arc::new(schema.clone()),
-                vec![
-                    Arc::new(Int32Array::from(vec![1, 10, 10, 100])),
-                    Arc::new(Int32Array::from(vec![2, 12, 12, 120])),
-                    Arc::new(Int32Array::from(vec![3, 12, 12, 120])),
-                ],
-            )?]],
+    fn build_table(
+        a: (&str, &Vec<i32>),
+        b: (&str, &Vec<i32>),
+        c: (&str, &Vec<i32>),
+    ) -> Result<LogicalPlan> {
+        let (batch, schema) = test::build_table_i32(a, b, c)?;
+
+        Ok(LogicalPlan::InMemoryScan {
+            data: vec![vec![batch]],
             schema: Box::new(schema.clone()),
             projection: None,
-            projected_schema: Box::new(schema.clone()),
+            projected_schema: Box::new(schema),
         })
-        .project(vec![col("b")])?
-        .build()?;
+    }
+
+    #[test]
+    fn projection_on_memory_scan() -> Result<()> {
+        let plan = build_table(
+            ("a", &vec![1, 10, 10, 100]),
+            ("b", &vec![2, 12, 12, 120]),
+            ("c", &vec![3, 12, 12, 120]),
+        )?;
+        let plan = LogicalPlanBuilder::from(&plan)
+            .project(vec![col("b")])?
+            .build()?;
 
         assert_fields_eq(&plan, vec!["b"]);
 
         let ctx = ExecutionContext::new();
@@ -862,6 +877,29 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn join() -> Result<()> {
+        let left =
+            build_table(("a", &vec![1, 1]), ("b", &vec![2, 3]), ("c", &vec![3, 4]))?;
+        let right = build_table(
+            ("a", &vec![1, 1]),
+            ("b2", &vec![12, 13]),
+            ("c2", &vec![13, 14]),
+        )?;
+        let plan = LogicalPlanBuilder::from(&left)
+            .join(&right, &vec!["a".to_string()], &JoinHow::Inner)?
+            .build()?;
+
+        let ctx = ExecutionContext::new();
+        let physical_plan = ctx.create_physical_plan(&plan, 1024)?;
+
+        let batches = ctx.collect(physical_plan.as_ref())?;
+        let expected: Vec<&str> =
+            vec!["1,2,3,12,13", "1,2,3,13,14", "1,3,4,12,13", "1,3,4,13,14"];
+        assert_eq!(test::format_batch(&batches[0]), expected);
+        Ok(())
+    }
+
     #[test]
     fn sort() -> Result<()> {
         let results = execute("SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC", 4)?;
diff --git a/rust/datafusion/src/execution/physical_plan/hash.rs b/rust/datafusion/src/execution/physical_plan/hash.rs
new file mode 100644
index 00000000000..a7b981431e6
--- /dev/null
+++ b/rust/datafusion/src/execution/physical_plan/hash.rs
@@ -0,0 +1,176 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines auxiliary functions for hashing keys
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::sync::Arc;
+
+use crate::error::{ExecutionError, Result};
+use crate::execution::physical_plan::Accumulator;
+
+use arrow::array::{
+    ArrayRef, BooleanArray, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
+    UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+};
+use arrow::array::{
+    BooleanBuilder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, StringBuilder,
+    UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder,
+};
+use arrow::datatypes::DataType;
+
+use fnv::FnvHashMap;
+
+/// Enumeration of types that can be used in any expression that uses a hash (all primitives except
+/// for floating point numerics)
+#[derive(Debug, PartialEq, Eq, Hash, Clone)]
+pub enum KeyScalar {
+    /// Boolean
+    Boolean(bool),
+    /// 8 bits
+    UInt8(u8),
+    /// 16 bits
+    UInt16(u16),
+    /// 32 bits
+    UInt32(u32),
+    /// 64 bits
+    UInt64(u64),
+    /// 8 bits signed
+    Int8(i8),
+    /// 16 bits signed
+    Int16(i16),
+    /// 32 bits signed
+    Int32(i32),
+    /// 64 bits signed
+    Int64(i64),
+    /// string
+    Utf8(String),
+}
+
+/// Returns a KeyScalar with the value of a given row of an ArrayRef.
+pub fn create_key(col: &ArrayRef, row: usize) -> Result<KeyScalar> {
+    match col.data_type() {
+        DataType::Boolean => {
+            let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
+            Ok(KeyScalar::Boolean(array.value(row)))
+        }
+        DataType::UInt8 => {
+            let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
+            Ok(KeyScalar::UInt8(array.value(row)))
+        }
+        DataType::UInt16 => {
+            let array = col.as_any().downcast_ref::<UInt16Array>().unwrap();
+            Ok(KeyScalar::UInt16(array.value(row)))
+        }
+        DataType::UInt32 => {
+            let array = col.as_any().downcast_ref::<UInt32Array>().unwrap();
+            Ok(KeyScalar::UInt32(array.value(row)))
+        }
+        DataType::UInt64 => {
+            let array = col.as_any().downcast_ref::<UInt64Array>().unwrap();
+            Ok(KeyScalar::UInt64(array.value(row)))
+        }
+        DataType::Int8 => {
+            let array = col.as_any().downcast_ref::<Int8Array>().unwrap();
+            Ok(KeyScalar::Int8(array.value(row)))
+        }
+        DataType::Int16 => {
+            let array = col.as_any().downcast_ref::<Int16Array>().unwrap();
+            Ok(KeyScalar::Int16(array.value(row)))
+        }
+        DataType::Int32 => {
+            let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
+            Ok(KeyScalar::Int32(array.value(row)))
+        }
+        DataType::Int64 => {
+            let array = col.as_any().downcast_ref::<Int64Array>().unwrap();
+            Ok(KeyScalar::Int64(array.value(row)))
+        }
+        DataType::Utf8 => {
+            let array = col.as_any().downcast_ref::<StringArray>().unwrap();
+            Ok(KeyScalar::Utf8(String::from(array.value(row))))
+        }
+        _ => Err(ExecutionError::ExecutionError(
+            "Unsupported key data type".to_string(),
+        )),
+    }
+}
+
+/// Create array from `key` attribute in map entry (representing a grouping scalar value)
+macro_rules! key_array_from_map_entries {
+    ($BUILDER:ident, $TY:ident, $MAP:expr, $COL_INDEX:expr) => {{
+        let mut builder = $BUILDER::new($MAP.len());
+        let mut err = false;
+        for k in $MAP.keys() {
+            match k[$COL_INDEX] {
+                KeyScalar::$TY(n) => builder.append_value(n).unwrap(),
+                _ => err = true,
+            }
+        }
+        if err {
+            Err(ExecutionError::ExecutionError(
+                "unexpected type when creating grouping array from aggregate map"
+                    .to_string(),
+            ))
+        } else {
+            Ok(Arc::new(builder.finish()) as ArrayRef)
+        }
+    }};
+}
+
+/// A set of accumulators
+pub type AccumulatorSet = Vec<Rc<RefCell<dyn Accumulator>>>;
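// Illustrative sketch, not part of the original patch: `create_key` extracts one
// cell of a column as a hashable `KeyScalar`. For a hypothetical Int32 column:
//
//     use arrow::array::Int32Array;
//     let col: ArrayRef = Arc::new(Int32Array::from(vec![7, 8, 9]));
//     assert_eq!(create_key(&col, 1)?, KeyScalar::Int32(8));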
+
+/// Builds an array from the keys at position `i` of the map's entries, with the given `data_type`.
+pub fn create_key_array(
+    i: usize,
+    data_type: DataType,
+    map: &FnvHashMap<Vec<KeyScalar>, Rc<AccumulatorSet>>,
+) -> Result<ArrayRef> {
+    let array: Result<ArrayRef> = match data_type {
+        DataType::Boolean => key_array_from_map_entries!(BooleanBuilder, Boolean, map, i),
+        DataType::UInt8 => key_array_from_map_entries!(UInt8Builder, UInt8, map, i),
+        DataType::UInt16 => key_array_from_map_entries!(UInt16Builder, UInt16, map, i),
+        DataType::UInt32 => key_array_from_map_entries!(UInt32Builder, UInt32, map, i),
+        DataType::UInt64 => key_array_from_map_entries!(UInt64Builder, UInt64, map, i),
+        DataType::Int8 => key_array_from_map_entries!(Int8Builder, Int8, map, i),
+        DataType::Int16 => key_array_from_map_entries!(Int16Builder, Int16, map, i),
+        DataType::Int32 => key_array_from_map_entries!(Int32Builder, Int32, map, i),
+        DataType::Int64 => key_array_from_map_entries!(Int64Builder, Int64, map, i),
+        DataType::Utf8 => {
+            let mut builder = StringBuilder::new(1);
+            for k in map.keys() {
+                match &k[i] {
+                    KeyScalar::Utf8(s) => builder.append_value(&s).unwrap(),
+                    _ => {
+                        return Err(ExecutionError::ExecutionError(
+                            "Unexpected value for Utf8 group column".to_string(),
+                        ))
+                    }
+                }
+            }
+            Ok(Arc::new(builder.finish()) as ArrayRef)
+        }
+        _ => Err(ExecutionError::ExecutionError(
+            "Unsupported key by expr".to_string(),
+        )),
+    };
+    Ok(array?)
+}
diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
index 19836fd864d..7edb16d97bb 100644
--- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
+++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
@@ -26,20 +26,17 @@ use crate::execution::physical_plan::{
     Accumulator, AggregateExpr, ExecutionPlan, Partition, PhysicalExpr,
 };
 
-use arrow::array::{
-    ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
-    StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
-};
-use arrow::array::{
-    Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder,
-    Int8Builder, StringBuilder, UInt16Builder, UInt32Builder, UInt64Builder,
-    UInt8Builder,
-};
+use arrow::array::ArrayRef;
+use arrow::array::{Float32Builder, Float64Builder, Int64Builder, UInt64Builder};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::{RecordBatch, RecordBatchReader};
 
+use crate::execution::physical_plan::common::get_scalar_value;
 use crate::execution::physical_plan::expressions::col;
+use crate::execution::physical_plan::hash::{
+    create_key, create_key_array, AccumulatorSet, KeyScalar,
+};
 use crate::logicalplan::ScalarValue;
 
 use fnv::FnvHashMap;
@@ -164,28 +161,6 @@ impl Partition for HashAggregatePartition {
     }
 }
 
-/// Create array from `key` attribute in map entry (representing a grouping scalar value)
-macro_rules! group_array_from_map_entries {
-    ($BUILDER:ident, $TY:ident, $MAP:expr, $COL_INDEX:expr) => {{
-        let mut builder = $BUILDER::new($MAP.len());
-        let mut err = false;
-        for k in $MAP.keys() {
-            match k[$COL_INDEX] {
-                GroupByScalar::$TY(n) => builder.append_value(n).unwrap(),
-                _ => err = true,
-            }
-        }
-        if err {
-            Err(ExecutionError::ExecutionError(
-                "unexpected type when creating grouping array from aggregate map"
-                    .to_string(),
-            ))
-        } else {
-            Ok(Arc::new(builder.finish()) as ArrayRef)
-        }
-    }};
-}
-
 /// Create array from `value` attribute in map entry (representing an aggregate scalar
 /// value)
 macro_rules! aggr_array_from_map_entries {
@@ -236,12 +211,6 @@ macro_rules! aggr_array_from_map_entries {
aggr_array_from_accumulator { }}; } -#[derive(Debug)] -struct MapEntry { - k: Vec, - v: Vec>, -} - struct GroupedHashAggregateIterator { schema: SchemaRef, group_expr: Vec>, @@ -268,24 +237,6 @@ impl GroupedHashAggregateIterator { } } -type AccumulatorSet = Vec>>; - -macro_rules! update_accumulators { - ($ARRAY:ident, $ARRAY_TY:ident, $SCALAR_TY:expr, $COL:expr, $ACCUM:expr) => {{ - let primitive_array = $ARRAY.as_any().downcast_ref::<$ARRAY_TY>().unwrap(); - - for row in 0..$ARRAY.len() { - if $ARRAY.is_valid(row) { - let value = Some($SCALAR_TY(primitive_array.value(row))); - let mut accum = $ACCUM[row][$COL].borrow_mut(); - accum - .accumulate_scalar(value) - .map_err(ExecutionError::into_arrow_external_error)?; - } - } - }}; -} - impl RecordBatchReader for GroupedHashAggregateIterator { fn schema(&self) -> SchemaRef { self.schema.clone() @@ -299,7 +250,7 @@ impl RecordBatchReader for GroupedHashAggregateIterator { self.finished = true; // create map to store accumulators for each unique grouping key - let mut map: FnvHashMap, Rc> = + let mut map: FnvHashMap, Rc> = FnvHashMap::default(); // iterate over all input batches and update the accumulators @@ -327,120 +278,49 @@ impl RecordBatchReader for GroupedHashAggregateIterator { }) .collect::>>()?; - // create vector large enough to hold the grouping key + // create vector to hold the grouping key let mut key = Vec::with_capacity(group_values.len()); for _ in 0..group_values.len() { - key.push(GroupByScalar::UInt32(0)); + key.push(KeyScalar::UInt32(0)); } // iterate over each row in the batch and create the accumulators for each grouping key - let mut accumulators: Vec> = - Vec::with_capacity(batch.num_rows()); - for row in 0..batch.num_rows() { - // create grouping key for this row - create_key(&group_values, row, &mut key) - .map_err(ExecutionError::into_arrow_external_error)?; - - if let Some(accumulator_set) = map.get(&key) { - accumulators.push(accumulator_set.clone()); - } else { - let accumulator_set: AccumulatorSet = self - .aggr_expr - .iter() - .map(|expr| expr.create_accumulator()) - .collect(); - - let accumulator_set = Rc::new(accumulator_set); - - map.insert(key.clone(), accumulator_set.clone()); - accumulators.push(accumulator_set); + // create and assign the grouping key of this row + for i in 0..group_values.len() { + key[i] = create_key(&group_values[i], row) + .map_err(ExecutionError::into_arrow_external_error)?; } - } - // iterate over each non-grouping column in the batch and update the accumulator - // for each row - for col in 0..aggr_input_values.len() { - let array = &aggr_input_values[col]; - - match array.data_type() { - DataType::Int8 => update_accumulators!( - array, - Int8Array, - ScalarValue::Int8, - col, - accumulators - ), - DataType::Int16 => update_accumulators!( - array, - Int16Array, - ScalarValue::Int16, - col, - accumulators - ), - DataType::Int32 => update_accumulators!( - array, - Int32Array, - ScalarValue::Int32, - col, - accumulators - ), - DataType::Int64 => update_accumulators!( - array, - Int64Array, - ScalarValue::Int64, - col, - accumulators - ), - DataType::UInt8 => update_accumulators!( - array, - UInt8Array, - ScalarValue::UInt8, - col, - accumulators - ), - DataType::UInt16 => update_accumulators!( - array, - UInt16Array, - ScalarValue::UInt16, - col, - accumulators - ), - DataType::UInt32 => update_accumulators!( - array, - UInt32Array, - ScalarValue::UInt32, - col, - accumulators - ), - DataType::UInt64 => update_accumulators!( - array, - UInt64Array, - ScalarValue::UInt64, - col, - 
-                    DataType::UInt64 => update_accumulators!(
-                        array,
-                        UInt64Array,
-                        ScalarValue::UInt64,
-                        col,
-                        accumulators
-                    ),
-                    DataType::Float32 => update_accumulators!(
-                        array,
-                        Float32Array,
-                        ScalarValue::Float32,
-                        col,
-                        accumulators
-                    ),
-                    DataType::Float64 => update_accumulators!(
-                        array,
-                        Float64Array,
-                        ScalarValue::Float64,
-                        col,
-                        accumulators
-                    ),
-                    other => {
-                        return Err(ExecutionError::ExecutionError(format!(
-                            "Unsupported data type {:?} for result of aggregate expression",
-                            other
-                        )).into_arrow_external_error());
+                // for each new key on the map, add an AccumulatorSet to the map
+                match map.get(&key) {
+                    None => {
+                        let accumulator_set: AccumulatorSet = self
+                            .aggr_expr
+                            .iter()
+                            .map(|expr| expr.create_accumulator())
+                            .collect();
+                        map.insert(key.clone(), Rc::new(accumulator_set));
                     }
+                    _ => (),
                 };
+
+                // iterate over each non-grouping column in the batch and update the accumulator
+                // for each row
+                for col in 0..aggr_input_values.len() {
+                    let value = get_scalar_value(&aggr_input_values[col], row)
+                        .map_err(ExecutionError::into_arrow_external_error)?;
+
+                    match map.get(&key) {
+                        None => panic!("This code cannot be reached."),
+                        Some(accumulator_set) => {
+                            let mut accum = accumulator_set[col].borrow_mut();
+                            accum
+                                .accumulate_scalar(value)
+                                .map_err(ExecutionError::into_arrow_external_error)?;
+                        }
+                    }
+                }
             }
         }
 
@@ -452,54 +332,12 @@ impl RecordBatchReader for GroupedHashAggregateIterator {
 
         // grouping values
         for i in 0..self.group_expr.len() {
-            let array: Result<ArrayRef> = match self.group_expr[i]
+            let data_type = self.group_expr[i]
                 .data_type(&input_schema)
-                .map_err(ExecutionError::into_arrow_external_error)?
-            {
-                DataType::UInt8 => {
-                    group_array_from_map_entries!(UInt8Builder, UInt8, map, i)
-                }
-                DataType::UInt16 => {
-                    group_array_from_map_entries!(UInt16Builder, UInt16, map, i)
-                }
-                DataType::UInt32 => {
-                    group_array_from_map_entries!(UInt32Builder, UInt32, map, i)
-                }
-                DataType::UInt64 => {
-                    group_array_from_map_entries!(UInt64Builder, UInt64, map, i)
-                }
-                DataType::Int8 => {
-                    group_array_from_map_entries!(Int8Builder, Int8, map, i)
-                }
-                DataType::Int16 => {
-                    group_array_from_map_entries!(Int16Builder, Int16, map, i)
-                }
-                DataType::Int32 => {
-                    group_array_from_map_entries!(Int32Builder, Int32, map, i)
-                }
-                DataType::Int64 => {
-                    group_array_from_map_entries!(Int64Builder, Int64, map, i)
-                }
-                DataType::Utf8 => {
-                    let mut builder = StringBuilder::new(1);
-                    for k in map.keys() {
-                        match &k[i] {
-                            GroupByScalar::Utf8(s) => builder.append_value(&s).unwrap(),
-                            _ => {
-                                return Err(ExecutionError::ExecutionError(
-                                    "Unexpected value for Utf8 group column".to_string(),
-                                )
-                                .into_arrow_external_error())
-                            }
-                        }
-                    }
-                    Ok(Arc::new(builder.finish()) as ArrayRef)
-                }
-                _ => Err(ExecutionError::ExecutionError(
-                    "Unsupported group by expr".to_string(),
-                )),
-            };
-            result_arrays.push(array.map_err(ExecutionError::into_arrow_external_error)?);
+                .map_err(ExecutionError::into_arrow_external_error)?;
+            let array = create_key_array(i, data_type, &map)
+                .map_err(ExecutionError::into_arrow_external_error)?;
+            result_arrays.push(array);
         }
 
         // aggregate values
         for i in 0..self.aggr_expr.len() {
@@ -677,76 +515,6 @@ impl RecordBatchReader for HashAggregateIterator {
     }
 }
 
-/// Enumeration of types that can be used in a GROUP BY expression (all primitives except
-/// for floating point numerics)
-#[derive(Debug, PartialEq, Eq, Hash, Clone)]
-enum GroupByScalar {
-    UInt8(u8),
-    UInt16(u16),
-    UInt32(u32),
-    UInt64(u64),
-    Int8(i8),
-    Int16(i16),
-    Int32(i32),
-    Int64(i64),
-    Utf8(String),
-}
-
-/// Create a Vec<GroupByScalar> that can be used as a map key
-fn create_key(
-    group_by_keys: &[ArrayRef],
-    row: usize,
-    vec: &mut Vec<GroupByScalar>,
-) -> Result<()> {
-    for i in 0..group_by_keys.len() {
-        let col = &group_by_keys[i];
-        match col.data_type() {
-            DataType::UInt8 => {
-                let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
-                vec[i] = GroupByScalar::UInt8(array.value(row))
-            }
-            DataType::UInt16 => {
-                let array = col.as_any().downcast_ref::<UInt16Array>().unwrap();
-                vec[i] = GroupByScalar::UInt16(array.value(row))
-            }
-            DataType::UInt32 => {
-                let array = col.as_any().downcast_ref::<UInt32Array>().unwrap();
-                vec[i] = GroupByScalar::UInt32(array.value(row))
-            }
-            DataType::UInt64 => {
-                let array = col.as_any().downcast_ref::<UInt64Array>().unwrap();
-                vec[i] = GroupByScalar::UInt64(array.value(row))
-            }
-            DataType::Int8 => {
-                let array = col.as_any().downcast_ref::<Int8Array>().unwrap();
-                vec[i] = GroupByScalar::Int8(array.value(row))
-            }
-            DataType::Int16 => {
-                let array = col.as_any().downcast_ref::<Int16Array>().unwrap();
-                vec[i] = GroupByScalar::Int16(array.value(row))
-            }
-            DataType::Int32 => {
-                let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
-                vec[i] = GroupByScalar::Int32(array.value(row))
-            }
-            DataType::Int64 => {
-                let array = col.as_any().downcast_ref::<Int64Array>().unwrap();
-                vec[i] = GroupByScalar::Int64(array.value(row))
-            }
-            DataType::Utf8 => {
-                let array = col.as_any().downcast_ref::<StringArray>().unwrap();
-                vec[i] = GroupByScalar::Utf8(String::from(array.value(row)))
-            }
-            _ => {
-                return Err(ExecutionError::ExecutionError(
-                    "Unsupported GROUP BY data type".to_string(),
-                ))
-            }
-        }
-    }
-    Ok(())
-}
-
 #[cfg(test)]
 mod tests {
 
@@ -755,6 +523,7 @@ mod tests {
     use crate::execution::physical_plan::expressions::{col, sum};
     use crate::execution::physical_plan::merge::MergeExec;
     use crate::test;
+    use arrow::array::{Int64Array, UInt32Array};
 
     #[test]
     fn aggregate() -> Result<()> {
diff --git a/rust/datafusion/src/execution/physical_plan/hash_join.rs b/rust/datafusion/src/execution/physical_plan/hash_join.rs
new file mode 100644
index 00000000000..aa8211ba521
--- /dev/null
+++ b/rust/datafusion/src/execution/physical_plan/hash_join.rs
@@ -0,0 +1,388 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines the join plan for executing partitions in parallel and then joining the results
+//! into a set of partitions.
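// Illustrative note, not part of the original patch: this module implements a
// hash join in two phases. A sketch of the idea on a hypothetical single-column key:
//
//   build: hash the "on" key of every row of one side into key -> row indexes,
//          e.g. left "a" = [1, 1, 2]  =>  {[1]: [0, 1], [2]: [2]}
//   probe: hash the other side the same way and, for each key present in both
//          maps, emit the cartesian product of the index lists as (left, right) pairs.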
+
+use std::collections::{HashMap, HashSet};
+use std::sync::{Arc, Mutex};
+
+use arrow::array::Array;
+use arrow::datatypes::{Schema, SchemaRef};
+use arrow::record_batch::{RecordBatch, RecordBatchReader};
+
+use super::{utils::build_array, ExecutionPlan};
+use crate::common::{build_join_schema, check_join_is_valid, JoinHow};
+use crate::error::{ExecutionError, Result};
+use crate::execution::physical_plan::common::RecordBatchIterator;
+use crate::execution::physical_plan::expressions::col;
+use crate::execution::physical_plan::hash::{create_key, KeyScalar};
+use crate::execution::physical_plan::Partition;
+
+// A mapping "on" value -> list of row indexes with this key's value
+// E.g. [1, 2] -> [3, 6, 8] indicates that rows 3, 6 and 8 have (column1, column2) = [1, 2]
+type JoinHashMap = HashMap<Vec<KeyScalar>, Vec<usize>>;
+
+/// join execution plan executes partitions in parallel and combines them into a set of
+/// partitions.
+pub struct HashJoinExec {
+    /// left side
+    left: Arc<dyn ExecutionPlan>,
+    /// right side
+    right: Arc<dyn ExecutionPlan>,
+    /// Set of common columns used to join on
+    on: HashSet<String>,
+    /// How the join is performed
+    how: JoinHow,
+    /// The schema once the join is applied
+    schema: SchemaRef,
+}
+
+impl HashJoinExec {
+    /// Create a new HashJoinExec
+    pub fn try_new(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: &HashSet<String>,
+        how: &JoinHow,
+    ) -> Result<Self> {
+        let left_schema = left.schema();
+        let right_schema = right.schema();
+        check_join_is_valid(&left_schema, &right_schema, &on)?;
+
+        let on = on.iter().cloned().collect::<HashSet<String>>();
+
+        let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &on, &how)?);
+
+        Ok(HashJoinExec {
+            left,
+            right,
+            on: on.clone(),
+            how: how.clone(),
+            schema,
+        })
+    }
+}
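// Illustrative note, not part of the original patch: `HashJoinExec` produces one
// output partition per left partition, and each of them is joined against every
// right partition (see `partitions()` below). A hypothetical 2x2 case:
//
//   left partitions [L0, L1], right partitions [R0, R1]
//   output partition 0 = join(L0, R0) ++ join(L0, R1)
//   output partition 1 = join(L1, R0) ++ join(L1, R1)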
+
+impl ExecutionPlan for HashJoinExec {
+    fn schema(&self) -> Arc<Schema> {
+        self.schema.clone()
+    }
+
+    fn partitions(&self) -> Result<Vec<Arc<dyn Partition>>> {
+        self.left
+            .partitions()?
+            .iter()
+            .map(move |p| {
+                let projection: Arc<dyn Partition> = Arc::new(HashJoinPartition {
+                    schema: self.schema.clone(),
+                    on: self.on.clone(),
+                    how: self.how.clone(),
+                    left: p.clone(),
+                    rights: self.right.partitions()?.clone(),
+                });
+
+                Ok(projection)
+            })
+            .collect::<Result<Vec<_>>>()
+    }
+}
+
+/// Partition with a computed hash table
+struct HashJoinPartition {
+    /// Input schema
+    schema: Arc<Schema>,
+    /// columns used to compute the hash
+    on: HashSet<String>,
+    /// how to join
+    how: JoinHow,
+    /// left partition
+    left: Arc<dyn Partition>,
+    /// partitions on the right
+    rights: Vec<Arc<dyn Partition>>,
+}
+
+/// Returns a JoinHashMap mapping each distinct "on" key of the batch to the row indexes
+/// where it occurs.
+fn build_hash_batch(on: &HashSet<String>, batch: &RecordBatch) -> Result<JoinHashMap> {
+    let mut hash: JoinHashMap = HashMap::new();
+
+    // evaluate the keys
+    let keys_values = on
+        .iter()
+        .map(|name| col(name).evaluate(batch))
+        .collect::<Result<Vec<_>>>()?;
+
+    // build the hash map
+    for row in 0..batch.num_rows() {
+        let mut key = Vec::with_capacity(keys_values.len());
+        for i in 0..keys_values.len() {
+            key.push(create_key(&keys_values[i], row)?);
+        }
+        match hash.get_mut(&key) {
+            Some(v) => v.push(row),
+            None => {
+                hash.insert(key, vec![row]);
+            }
+        };
+    }
+    Ok(hash)
+}
+
+fn build_join_batch(
+    on: &HashSet<String>,
+    schema: &Schema,
+    how: &JoinHow,
+    left: &RecordBatch,
+    right: &RecordBatch,
+    left_hash: &HashMap<Vec<KeyScalar>, Vec<usize>>,
+) -> Result<RecordBatch> {
+    let right_hash = build_hash_batch(on, right)?;
+
+    let join_indexes: Vec<(usize, usize)> =
+        build_join_indexes(&left_hash, &right_hash, how)?;
+
+    // build the columns for the RecordBatch
+    let mut columns: Vec<Arc<dyn Array>> = vec![];
+    for field in schema.fields() {
+        // pick the column (left or right) based on the field name
+        // if two fields have the same name on left and right, the left is given preference
+        let (is_left, array) = match left.schema().index_of(field.name()) {
+            Ok(i) => Ok((true, left.column(i))),
+            Err(_) => {
+                match right.schema().index_of(field.name()) {
+                    Ok(i) => Ok((false, right.column(i))),
+                    _ => Err(ExecutionError::InternalError(
+                        format!("During execution, the column {} was not found in either the left or right side of the join", field.name())
+                    ))
+                }
+            }
+        }?;
+
+        // pick the (left or right) indexes of the array
+        let indexes = join_indexes
+            .iter()
+            .map(|(left, right)| if is_left { *left } else { *right })
+            .collect();
+
+        // build the array out of the indexes. On a join, we expect more entries (due to duplicates)
+        let array = build_array(&array, &indexes, field.data_type())?;
+        columns.push(array);
+    }
+    Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
+}
+
+/// Returns a vector with (index from left, index from right) pairs.
+/// The size of this vector corresponds to the total size of a joined batch.
+fn build_join_indexes(
+    left: &JoinHashMap,
+    right: &JoinHashMap,
+    how: &JoinHow,
+) -> Result<Vec<(usize, usize)>> {
+    // unfortunately rust does not support intersection of map keys :(
+    let left_set: HashSet<Vec<KeyScalar>> = left.keys().cloned().collect();
+    let right_set: HashSet<Vec<KeyScalar>> = right.keys().cloned().collect();
+
+    match how {
+        JoinHow::Inner => {
+            let inner = left_set.intersection(&right_set);
+
+            let mut indexes = Vec::new(); // unknown a priori size
+            for key in inner {
+                // the unwrap never fails by construction of the key
+                let left_indexes = left.get(key).unwrap();
+                let right_indexes = right.get(key).unwrap();
+
+                // for every item on the left and right with this key, add the respective pair
+                left_indexes.iter().for_each(|x| {
+                    right_indexes.iter().for_each(|y| {
+                        indexes.push((*x, *y));
+                    })
+                })
+            }
+            Ok(indexes)
+        }
+    }
+}
+
+/// Joins the first batch of the left partition with the first batch of the right partition.
+pub fn build_joined_partition(
+    schema: &Schema,
+    on: &HashSet<String>,
+    how: &JoinHow,
+    left: &Arc<dyn Partition>,
+    right: &Arc<dyn Partition>,
+) -> Result<Option<RecordBatch>> {
+    let iterator = left.execute()?;
+    let mut input = iterator.lock().unwrap();
+
+    match input.next_batch()? {
+        None => Ok(None),
+        Some(left) => {
+            let left_hash = build_hash_batch(on, &left)?;
+
+            let iterator_other = right.execute()?;
+            let mut input_other = iterator_other.lock().unwrap();
+            match input_other.next_batch()? {
+                None => Ok(None),
+                Some(right) => Ok(Some(build_join_batch(
+                    on, schema, how, &left, &right, &left_hash,
+                )?)),
+            }
+        }
+    }
+}
+
+impl Partition for HashJoinPartition {
+    /// Execute the join
+    fn execute(&self) -> Result<Arc<Mutex<dyn RecordBatchReader + Send + Sync>>> {
+        let batches = self
+            .rights
+            .iter()
+            .map(|right| {
+                build_joined_partition(
+                    &self.schema,
+                    &self.on,
+                    &self.how,
+                    &self.left,
+                    &right,
+                )
+            })
+            .collect::<Result<Vec<Option<RecordBatch>>>>()?;
+        let batches = batches
+            .iter()
+            .filter_map(|x| match x {
+                Some(x) => Some(Arc::new(x.clone())),
+                None => None,
+            })
+            .collect::<Vec<_>>();
+
+        Ok(Arc::new(Mutex::new(RecordBatchIterator::new(
+            self.schema.clone(),
+            batches,
+        ))))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+    use crate::{
+        execution::physical_plan::{common, memory::MemoryExec, ExecutionPlan},
+        test::{build_table_i32, columns, format_batch},
+    };
+    use std::collections::HashSet;
+    use std::sync::Arc;
+
+    fn build_table(
+        a: (&str, &Vec<i32>),
+        b: (&str, &Vec<i32>),
+        c: (&str, &Vec<i32>),
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let (batch, schema) = build_table_i32(a, b, c)?;
+        Ok(Arc::new(MemoryExec::try_new(
+            &vec![vec![batch]],
+            Arc::new(schema),
+            None,
+        )?))
+    }
+
+    fn join(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: &[&str],
+    ) -> Result<HashJoinExec> {
+        let on = on.iter().map(|s| s.to_string()).collect::<HashSet<String>>();
+        HashJoinExec::try_new(left, right, &on, &JoinHow::Inner)
+    }
+
+    /// Asserts that the rows are the same, taking into account that their order
+    /// is irrelevant
+    fn assert_same_rows(result: &[String], expected: &[&str]) {
+        assert_eq!(result.len(), expected.len());
+
+        // convert to set since row order is irrelevant
+        let result = result.iter().cloned().collect::<HashSet<String>>();
+
+        let expected = expected
+            .iter()
+            .map(|s| s.to_string())
+            .collect::<HashSet<String>>();
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    fn join_one() -> Result<()> {
+        let t1 = build_table(
+            ("a1", &vec![1, 2, 3]),
+            ("b1", &vec![4, 5, 5]), // this has a repetition
+            ("c1", &vec![7, 8, 9]),
+        )?;
+        let t2 = build_table(
+            ("a2", &vec![10, 20, 30]),
+            ("b1", &vec![4, 5, 6]),
+            ("c2", &vec![70, 80, 90]),
+        )?;
+        let on = vec!["b1"];
+
+        let join = join(t1, t2, &on)?;
+
+        let columns = columns(&join.schema());
+        assert_eq!(columns, vec!["b1", "a1", "c1", "a2", "c2"]);
+
+        let batches = common::collect(join.partitions()?[0].execute()?)?;
+        assert_eq!(batches.len(), 1);
+
+        let result = format_batch(&batches[0]);
+        let expected = vec!["5,2,8,20,80", "5,3,9,20,80", "4,1,7,10,70"];
+
+        assert_same_rows(&result, &expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn join_two() -> Result<()> {
+        let t1 = build_table(
+            ("a1", &vec![1, 2, 2]),
+            ("b2", &vec![1, 2, 2]),
+            ("c1", &vec![7, 8, 9]),
+        )?;
+        let t2 = build_table(
+            ("a1", &vec![1, 2, 3]),
+            ("b2", &vec![1, 2, 2]),
+            ("c2", &vec![70, 80, 90]),
+        )?;
+        let on = vec!["a1", "b2"];
+
+        let join = join(t1, t2, &on)?;
+
+        let columns = columns(&join.schema());
+        assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]);
+
+        let batches = common::collect(join.partitions()?[0].execute()?)?;
+        assert_eq!(batches.len(), 1);
+
+        let result = format_batch(&batches[0]);
+        let expected = vec!["1,1,7,70", "2,2,8,80", "2,2,9,80"];
+
+        assert_same_rows(&result, &expected);
+
+        Ok(())
+    }
+}
diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs
index 2e191784678..c9bbef9f3ec 100644
--- a/rust/datafusion/src/execution/physical_plan/mod.rs
+++ b/rust/datafusion/src/execution/physical_plan/mod.rs
@@ -80,7 +80,9 @@ pub mod common;
 pub mod csv;
 pub mod datasource;
 pub mod expressions;
+pub mod hash;
 pub mod hash_aggregate;
+pub mod hash_join;
 pub mod limit;
 pub mod math_expressions;
 pub mod memory;
@@ -90,3 +92,4 @@ pub mod projection;
 pub mod selection;
 pub mod sort;
 pub mod udf;
+pub mod utils;
diff --git a/rust/datafusion/src/execution/physical_plan/utils.rs b/rust/datafusion/src/execution/physical_plan/utils.rs
new file mode 100644
index 00000000000..93b6b08a075
--- /dev/null
+++ b/rust/datafusion/src/execution/physical_plan/utils.rs
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines utility functions for low-level operations on Arrays
+
+use std::sync::Arc;
+
+use crate::error::{ExecutionError, Result};
+use arrow::array::{
+    Array, ArrayRef, BooleanArray, BooleanBuilder, Float32Array, Float32Builder,
+    Float64Array, Float64Builder, Int16Array, Int16Builder, Int32Array, Int32Builder,
+    Int64Array, Int64Builder, Int8Array, Int8Builder, StringArray, StringBuilder,
+    UInt16Array, UInt16Builder, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder,
+    UInt8Array, UInt8Builder,
+};
+use arrow::datatypes::DataType;
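// Illustrative sketch, not part of the original patch: `build_array` (below)
// gathers rows of an array by index, so repeated indexes duplicate rows, which
// is exactly what a join needs. For a hypothetical Int32 column:
//
//     use arrow::array::Int32Array;
//     let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30]));
//     let taken = build_array(&a, &vec![0, 2, 2], &DataType::Int32)?;
//     // taken now contains [10, 30, 30]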
+
+// cast, iterate over and re-build the array
+macro_rules! _build_array {
+    ($ARRAY:expr, $INDEXES:expr, $TYPE:ty, $BUILDER:ty) => {{
+        let array = match $ARRAY.as_any().downcast_ref::<$TYPE>() {
+            Some(n) => Ok(n),
+            None => Err(ExecutionError::InternalError(
+                "Invalid data type for the array".to_string(),
+            )),
+        }?;
+
+        let mut builder = <$BUILDER>::new($INDEXES.len());
+        for index in $INDEXES {
+            if array.is_null(*index) {
+                builder.append_null()?;
+            } else {
+                builder.append_value(array.value(*index))?;
+            }
+        }
+        Ok(Arc::new(builder.finish()))
+    }};
+}
+
+/// Builds an array by picking the entries of `array` at the positions in `indexes`.
+pub fn build_array(
+    array: &ArrayRef,
+    indexes: &Vec<usize>,
+    datatype: &DataType,
+) -> Result<ArrayRef> {
+    match datatype {
+        DataType::Boolean => _build_array!(array, indexes, BooleanArray, BooleanBuilder),
+        DataType::Int8 => _build_array!(array, indexes, Int8Array, Int8Builder),
+        DataType::Int16 => _build_array!(array, indexes, Int16Array, Int16Builder),
+        DataType::Int32 => _build_array!(array, indexes, Int32Array, Int32Builder),
+        DataType::Int64 => _build_array!(array, indexes, Int64Array, Int64Builder),
+        DataType::UInt8 => _build_array!(array, indexes, UInt8Array, UInt8Builder),
+        DataType::UInt16 => _build_array!(array, indexes, UInt16Array, UInt16Builder),
+        DataType::UInt32 => _build_array!(array, indexes, UInt32Array, UInt32Builder),
+        DataType::UInt64 => _build_array!(array, indexes, UInt64Array, UInt64Builder),
+        DataType::Float64 => _build_array!(array, indexes, Float64Array, Float64Builder),
+        DataType::Float32 => _build_array!(array, indexes, Float32Array, Float32Builder),
+        DataType::Utf8 => _build_array!(array, indexes, StringArray, StringBuilder),
+        _ => Err(ExecutionError::NotImplemented(format!(
+            "Conversions for type {:?} are still not implemented",
+            datatype
+        ))),
+    }
+}
diff --git a/rust/datafusion/src/lib.rs b/rust/datafusion/src/lib.rs
index fb4e5af303f..c69c0b55698 100644
--- a/rust/datafusion/src/lib.rs
+++ b/rust/datafusion/src/lib.rs
@@ -29,6 +29,7 @@ extern crate arrow;
 extern crate sqlparser;
 
+pub mod common;
 pub mod datasource;
 pub mod error;
 pub mod execution;
diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs
index 25780f9a127..b2b878c9b3a 100644
--- a/rust/datafusion/src/logicalplan.rs
+++ b/rust/datafusion/src/logicalplan.rs
@@ -21,10 +21,11 @@
 //! Logical query plans can then be optimized and executed directly, or translated into
 //! physical query plans and executed.
 
-use std::fmt;
+use std::{collections::HashSet, fmt};
 
 use arrow::datatypes::{DataType, Field, Schema};
 
+use crate::common::{build_join_schema, check_join_is_valid, JoinHow};
 use crate::datasource::csv::{CsvFile, CsvReadOptions};
 use crate::datasource::parquet::ParquetTable;
 use crate::datasource::TableProvider;
@@ -636,6 +637,19 @@ pub enum LogicalPlan {
         /// The schema description
         schema: Box<Schema>,
     },
+    /// Represents a join operation
+    Join {
+        /// The left side
+        left: Box<LogicalPlan>,
+        /// The right side
+        right: Box<LogicalPlan>,
+        /// The columns to join on
+        on: HashSet<String>,
+        /// How the join is performed
+        how: JoinHow,
+        /// The schema description
+        schema: Box<Schema>,
+    },
     /// Represents a list of sort expressions to be applied to a relation
     Sort {
         /// The sort expressions
@@ -747,6 +761,7 @@ impl LogicalPlan {
             LogicalPlan::Sort { schema, .. } => &schema,
             LogicalPlan::Limit { schema, .. } => &schema,
             LogicalPlan::CreateExternalTable { schema, .. } => &schema,
+            LogicalPlan::Join { schema, .. } => &schema,
         }
     }
 }
@@ -834,6 +849,18 @@ impl LogicalPlan {
                 write!(f, "Limit: {}", n)?;
                 input.fmt_with_indent(f, indent + 1)
             }
+            LogicalPlan::Join {
+                ref left,
+                ref right,
+                ref on,
+                ref how,
+                ..
+            } => {
+                let on = on.iter().cloned().collect::<Vec<String>>().join(", #");
+                write!(f, "Join: on=[#{}] how={:?}", on, how)?;
+                left.fmt_with_indent(f, indent + 1)?;
+                right.fmt_with_indent(f, indent + 1)
+            }
             LogicalPlan::CreateExternalTable { ref name, .. } => {
                 write!(f, "CreateExternalTable: {:?}", name)
             }
@@ -1045,6 +1072,28 @@ impl LogicalPlanBuilder {
         }))
     }
 
+    /// Apply a join
+    pub fn join(
+        &self,
+        right: &LogicalPlan,
+        on: &[String],
+        how: &JoinHow,
+    ) -> Result<Self> {
+        let on = on.iter().map(|s| s.to_string()).collect::<HashSet<String>>();
+
+        check_join_is_valid(&self.plan.schema(), &right.schema(), &on)?;
+
+        let schema = build_join_schema(&self.plan.schema(), &right.schema(), &on, how)?;
+
+        Ok(Self::from(&LogicalPlan::Join {
+            left: Box::new(self.plan.clone()),
+            right: Box::new(right.clone()),
+            on,
+            how: how.clone(),
+            schema: Box::new(schema),
+        }))
+    }
+
     /// Apply an aggregate
     pub fn aggregate(&self, group_expr: Vec<Expr>, aggr_expr: Vec<Expr>) -> Result<Self> {
         let mut all_expr: Vec<Expr> = group_expr.clone();
@@ -1069,6 +1118,7 @@ impl LogicalPlanBuilder {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::test::{build_table_i32, columns};
 
     #[test]
     fn plan_builder_simple() -> Result<()> {
@@ -1136,6 +1186,39 @@ mod tests {
         Ok(())
     }
 
+    fn build_table(a: &str, b: &str, c: &str) -> Result<LogicalPlan> {
+        let (batch, schema) = build_table_i32((a, &vec![]), (b, &vec![]), (c, &vec![]))?;
+
+        Ok(LogicalPlan::InMemoryScan {
+            data: vec![vec![batch]],
+            schema: Box::new(schema.clone()),
+            projection: None,
+            projected_schema: Box::new(schema),
+        })
+    }
+
+    #[test]
+    fn plan_builder_join() -> Result<()> {
+        let on = vec!["a1".to_string()];
+        let t1 = build_table("a1", "b1", "c1")?;
+        let t2 = build_table("a1", "b2", "c2")?;
+
+        let plan = LogicalPlanBuilder::from(&t1)
+            .join(&t2, &on, &JoinHow::Inner)?
+            .build()?;
+
+        let expected = "\
+        Join: on=[#a1] how=Inner\
+        \n  InMemoryScan: projection=None\
+        \n  InMemoryScan: projection=None";
+        assert_eq!(expected, format!("{:?}", plan));
+
+        let columns = columns(&plan.schema());
+        assert_eq!(columns, vec!["a1", "b1", "c1", "b2", "c2"]);
+
+        Ok(())
+    }
+
     #[test]
     fn plan_builder_sort() -> Result<()> {
         let plan = LogicalPlanBuilder::scan(
diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs
index e99a996d7c9..0baac025576 100644
--- a/rust/datafusion/src/optimizer/projection_push_down.rs
+++ b/rust/datafusion/src/optimizer/projection_push_down.rs
@@ -110,6 +110,25 @@ impl ProjectionPushDown {
                 input: Box::new(self.optimize_plan(&input, accum, has_projection)?),
                 schema: schema.clone(),
             }),
+            LogicalPlan::Join {
+                left,
+                right,
+                on,
+                how,
+                schema,
+            } => {
+                // optimize each of the plans
+                let left = self.optimize_plan(&left, accum, has_projection)?;
+                let right = self.optimize_plan(&right, accum, has_projection)?;
+
+                Ok(LogicalPlan::Join {
+                    left: Box::new(left),
+                    right: Box::new(right),
+                    on: on.clone(),
+                    how: how.clone(),
+                    schema: schema.clone(),
+                })
+            }
            LogicalPlan::EmptyRelation { .. } => Ok(plan.clone()),
             LogicalPlan::TableScan {
                 schema_name,
diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs
index a03a92cdfe5..0e9ff73ce7f 100644
--- a/rust/datafusion/src/optimizer/type_coercion.rs
+++ b/rust/datafusion/src/optimizer/type_coercion.rs
@@ -165,6 +165,26 @@ impl<'a> OptimizerRule for TypeCoercionRule<'a> {
                     self.rewrite_expr_list(aggr_expr, input.schema())?,
                 )?
                 .build(),
+            LogicalPlan::Join {
+                left,
+                right,
+                on,
+                how,
+                schema,
+            } => {
+                // no optimization to be made on this node, so we just pass the optimization
+                // to its children
+                let left = self.optimize(&left)?;
+                let right = self.optimize(&right)?;
+
+                Ok(LogicalPlan::Join {
+                    left: Box::new(left),
+                    right: Box::new(right),
+                    on: on.clone(),
+                    how: how.clone(),
+                    schema: schema.clone(),
+                })
+            }
             LogicalPlan::TableScan { .. } => Ok(plan.clone()),
             LogicalPlan::InMemoryScan { .. } => Ok(plan.clone()),
             LogicalPlan::ParquetScan { .. } => Ok(plan.clone()),
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index 21cf870ba34..9e883752bdf 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -26,7 +26,7 @@ use crate::logicalplan::{
 
 use arrow::datatypes::*;
 
-use crate::logicalplan::Expr::Alias;
+use crate::{common::JoinHow, logicalplan::Expr::Alias};
 use sqlparser::sqlast::*;
 
 /// The SchemaProvider trait allows the query planner to obtain meta-data about tables and
@@ -60,6 +60,7 @@ impl SqlToRel {
                 ref limit,
                 ref group_by,
                 ref having,
+                ref joins,
                 ..
             } => {
                 if having.is_some() {
@@ -74,7 +75,9 @@ impl SqlToRel {
                     None => LogicalPlanBuilder::empty().build()?,
                 };
 
-                // selection first
+                // join first, since a filter may include columns from both sides
+                let plan = self.join(&plan, &joins)?;
+
                 let plan = self.filter(&plan, selection)?;
 
                 let projection_expr: Vec<Expr> = projection
@@ -125,6 +128,43 @@ impl SqlToRel {
         }
     }
 
+    /// Apply a join to the plan
+    pub fn join(&self, plan: &LogicalPlan, joins: &[Join]) -> Result<LogicalPlan> {
+        if joins.is_empty() {
+            // short-circuit if no join exists
+            return Ok(plan.clone());
+        }
+        if joins.len() > 1 {
+            return Err(ExecutionError::NotImplemented(
+                "statements with more than one join relation are still not supported"
+                    .to_owned(),
+            ));
+        };
+        let join: &Join = &joins[0];
+
+        match &join.join_operator {
+            JoinOperator::Inner(JoinConstraint::On(relation)) => {
+                let expr = self.sql_to_rex(&relation, &plan.schema())?;
+
+                let names = match expr {
+                    Expr::Column(name) => Ok(vec![name]),
+                    _ => Err(ExecutionError::NotImplemented(
+                        "Only joins on single columns are supported".to_owned(),
+                    )),
+                }?;
+
+                let right = self.sql_to_rel(&join.relation)?;
+
+                LogicalPlanBuilder::from(&plan)
+                    .join(&right, &names, &JoinHow::Inner)?
+                    .build()
+            }
+            _ => Err(ExecutionError::NotImplemented(
+                "Only inner joins (ON) are currently supported".to_owned(),
+            )),
+        }
+    }
+
     /// Apply a filter to the plan
     fn filter(
         &self,
@@ -556,6 +596,27 @@ mod tests {
         );
     }
 
+    #[test]
+    fn one_column_join() {
+        quick_test(
+            "SELECT a FROM simple1 JOIN simple2 ON a",
+            "\
+            Projection: #a\
+            \n  Join: on=[#a] how=Inner\
+            \n    TableScan: simple1 projection=None\
+            \n    TableScan: simple2 projection=None",
+        );
+
+        quick_test(
+            "SELECT a FROM simple1 JOIN simple2 ON b",
+            "\
+            Projection: #a\
+            \n  Join: on=[#b] how=Inner\
+            \n    TableScan: simple1 projection=None\
+            \n    TableScan: simple2 projection=None",
+        );
+    }
+
     #[test]
     fn test_wildcard() {
         quick_test(
@@ -719,6 +780,16 @@ mod tests {
                 Field::new("c12", DataType::Float64, false),
                 Field::new("c13", DataType::Utf8, false),
             ]))),
+            "simple1" => Some(Arc::new(Schema::new(vec![
+                Field::new("a", DataType::UInt32, false),
+                Field::new("b", DataType::UInt32, false),
+                Field::new("c1", DataType::UInt32, false),
+            ]))),
+            "simple2" => Some(Arc::new(Schema::new(vec![
+                Field::new("a", DataType::UInt32, false),
+                Field::new("b", DataType::UInt32, false),
+                Field::new("c2", DataType::UInt32, false),
+            ]))),
             _ => None,
         }
     }
diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs
index 317c14564f1..72eba71ccd9 100644
--- a/rust/datafusion/src/test/mod.rs
+++ b/rust/datafusion/src/test/mod.rs
@@ -21,6 +21,7 @@ use crate::error::Result;
 use crate::execution::context::ExecutionContext;
 use crate::execution::physical_plan::ExecutionPlan;
 use crate::logicalplan::{Expr, LogicalPlan, LogicalPlanBuilder};
+use array::Int32Array;
 use arrow::array;
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
@@ -231,3 +232,31 @@ pub fn max(expr: Expr) -> Expr {
         return_type: DataType::Float64,
     }
 }
+
+/// Returns an in-memory RecordBatch with 3 columns of i32, together with its Schema
+pub fn build_table_i32(
+    a: (&str, &Vec<i32>),
+    b: (&str, &Vec<i32>),
+    c: (&str, &Vec<i32>),
+) -> Result<(RecordBatch, Schema)> {
+    let schema = Schema::new(vec![
+        Field::new(a.0, DataType::Int32, false),
+        Field::new(b.0, DataType::Int32, false),
+        Field::new(c.0, DataType::Int32, false),
+    ]);
+
+    let batch = RecordBatch::try_new(
+        Arc::new(schema.clone()),
+        vec![
+            Arc::new(Int32Array::from(a.1.clone())),
+            Arc::new(Int32Array::from(b.1.clone())),
+            Arc::new(Int32Array::from(c.1.clone())),
+        ],
+    )?;
+    Ok((batch, schema))
+}
+
+/// Returns the column names of the schema
+pub fn columns(schema: &Schema) -> Vec<String> {
+    schema.fields().iter().map(|f| f.name().clone()).collect()
+}
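// Illustrative sketch, not part of the original patch: end-to-end use of the new
// join API, assuming two in-memory scans that share column "a" (the variable
// names `left_scan`, `right_scan` and `ctx` are hypothetical):
//
//     use datafusion::common::JoinHow;
//     use datafusion::logicalplan::LogicalPlanBuilder;
//
//     let plan = LogicalPlanBuilder::from(&left_scan)            // a LogicalPlan
//         .join(&right_scan, &["a".to_string()], &JoinHow::Inner)?
//         .build()?;
//     let physical = ctx.create_physical_plan(&plan, 1024)?;     // ExecutionContext
//     let batches = ctx.collect(physical.as_ref())?;             // joined RecordBatches
//
// The same plan is reachable through SQL, as the planner tests above show:
// "SELECT a FROM simple1 JOIN simple2 ON a".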