From f0303cad2a59b41e0a4ed57619a4ab5d68319791 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 19 Nov 2020 16:59:25 -0700 Subject: [PATCH 1/7] Implement DataFrame.join --- rust/datafusion/src/dataframe.rs | 25 +++++++++++- .../src/execution/dataframe_impl.rs | 23 ++++++++++- rust/datafusion/src/logical_plan/builder.rs | 34 ++++++++++++++++- rust/datafusion/src/logical_plan/mod.rs | 4 +- rust/datafusion/src/logical_plan/plan.rs | 38 +++++++++++++++++++ .../src/optimizer/projection_push_down.rs | 1 + rust/datafusion/src/optimizer/utils.rs | 18 +++++++++ rust/datafusion/src/physical_plan/planner.rs | 5 +++ rust/datafusion/src/prelude.rs | 2 +- 9 files changed, 145 insertions(+), 5 deletions(-) diff --git a/rust/datafusion/src/dataframe.rs b/rust/datafusion/src/dataframe.rs index d2ad7e6917a..c814affb463 100644 --- a/rust/datafusion/src/dataframe.rs +++ b/rust/datafusion/src/dataframe.rs @@ -19,7 +19,7 @@ use crate::arrow::record_batch::RecordBatch; use crate::error::Result; -use crate::logical_plan::{Expr, FunctionRegistry, LogicalPlan}; +use crate::logical_plan::{Expr, FunctionRegistry, JoinType, LogicalPlan}; use arrow::datatypes::Schema; use std::sync::Arc; @@ -146,6 +146,29 @@ pub trait DataFrame { /// ``` fn sort(&self, expr: Vec) -> Result>; + /// Join this DataFrame with another DataFrame using the specified columns as join keys + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let mut ctx = ExecutionContext::new(); + /// let left = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; + /// let right = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; + /// let join = left.join(right, JoinType::Inner, vec!["a", "b"], vec!["a", "b"])?; + /// let batches = join.collect().await?; + /// # Ok(()) + /// # } + /// ``` + fn join( + &self, + right: Arc, + join_type: JoinType, + left_cols: Vec<&str>, + right_cols: Vec<&str>, + ) -> Result>; + /// Executes this DataFrame and collects all results into a vector of RecordBatch. /// /// ``` diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index a93e5745a03..b4179639f65 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -23,7 +23,9 @@ use crate::arrow::record_batch::RecordBatch; use crate::dataframe::*; use crate::error::Result; use crate::execution::context::{ExecutionContext, ExecutionContextState}; -use crate::logical_plan::{col, Expr, FunctionRegistry, LogicalPlan, LogicalPlanBuilder}; +use crate::logical_plan::{ + col, Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder, +}; use arrow::datatypes::Schema; use async_trait::async_trait; @@ -102,6 +104,25 @@ impl DataFrame for DataFrameImpl { Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) } + /// Join with another DataFrame + fn join( + &self, + right: Arc, + join_type: JoinType, + left_cols: Vec<&str>, + right_cols: Vec<&str>, + ) -> Result> { + let plan = LogicalPlanBuilder::from(&self.plan) + .join( + Arc::new(right.to_logical_plan()), + join_type, + left_cols, + right_cols, + )? + .build()?; + Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) + } + /// Convert to logical plan fn to_logical_plan(&self) -> LogicalPlan { self.plan.clone() diff --git a/rust/datafusion/src/logical_plan/builder.rs b/rust/datafusion/src/logical_plan/builder.rs index 32e53039281..89a7613d282 100644 --- a/rust/datafusion/src/logical_plan/builder.rs +++ b/rust/datafusion/src/logical_plan/builder.rs @@ -27,7 +27,8 @@ use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; use super::{ - col, exprlist_to_fields, Expr, LogicalPlan, PlanType, StringifiedPlan, TableSource, + col, exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan, + TableSource, }; /// Builder for logical plans @@ -181,6 +182,37 @@ impl LogicalPlanBuilder { })) } + /// Apply a join + pub fn join( + &self, + right: Arc, + join_type: JoinType, + left_keys: Vec<&str>, + right_keys: Vec<&str>, + ) -> Result { + //TODO reconcile this with the logic in https://github.com/apache/arrow/pull/8709 + let mut fields = vec![]; + self.plan + .schema() + .fields() + .iter() + .for_each(|f| fields.push(f.to_owned())); + right + .schema() + .fields() + .iter() + .for_each(|f| fields.push(f.to_owned())); + + Ok(Self::from(&LogicalPlan::Join { + left: Arc::new(self.plan.clone()), + right, + left_keys: left_keys.iter().map(|k| k.to_string()).collect(), + right_keys: right_keys.iter().map(|k| k.to_string()).collect(), + join_type, + schema: Arc::new(Schema::new(fields)), + })) + } + /// Apply an aggregate pub fn aggregate(&self, group_expr: Vec, aggr_expr: Vec) -> Result { let mut all_expr: Vec = group_expr.clone(); diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index c2f90da3581..a3a44d5341a 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -38,5 +38,7 @@ pub use expr::{ }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; -pub use plan::{LogicalPlan, PlanType, PlanVisitor, StringifiedPlan, TableSource}; +pub use plan::{ + JoinType, LogicalPlan, PlanType, PlanVisitor, StringifiedPlan, TableSource, +}; pub use registry::FunctionRegistry; diff --git a/rust/datafusion/src/logical_plan/plan.rs b/rust/datafusion/src/logical_plan/plan.rs index ced13fa89ce..525984b895e 100644 --- a/rust/datafusion/src/logical_plan/plan.rs +++ b/rust/datafusion/src/logical_plan/plan.rs @@ -41,6 +41,13 @@ pub enum TableSource { FromProvider(Arc), } +/// Join type +#[derive(Debug, Clone)] +pub enum JoinType { + /// Inner join + Inner, +} + /// A LogicalPlan represents the different types of relational /// operators (such as Projection, Filter, etc) and can be created by /// the SQL query planner and the DataFrame API. @@ -94,6 +101,21 @@ pub enum LogicalPlan { /// The incoming logical plan input: Arc, }, + /// Join two logical plans on one or more join columns + Join { + /// Left input + left: Arc, + /// Right input + right: Arc, + /// Columns in left input to use for join keys + left_keys: Vec, + /// Columns in right input to use for join keys + right_keys: Vec, + /// Join type + join_type: JoinType, + /// The output schema, containing fields from the left and right inputs + schema: SchemaRef, + }, /// Produces rows from a table provider by reference or from the context TableScan { /// The name of the schema @@ -211,6 +233,7 @@ impl LogicalPlan { LogicalPlan::Filter { input, .. } => input.schema(), LogicalPlan::Aggregate { schema, .. } => &schema, LogicalPlan::Sort { input, .. } => input.schema(), + LogicalPlan::Join { schema, .. } => &schema, LogicalPlan::Limit { input, .. } => input.schema(), LogicalPlan::CreateExternalTable { schema, .. } => &schema, LogicalPlan::Explain { schema, .. } => &schema, @@ -292,6 +315,9 @@ impl LogicalPlan { LogicalPlan::Filter { input, .. } => input.accept(visitor)?, LogicalPlan::Aggregate { input, .. } => input.accept(visitor)?, LogicalPlan::Sort { input, .. } => input.accept(visitor)?, + LogicalPlan::Join { left, right, .. } => { + left.accept(visitor)? && right.accept(visitor)? + } LogicalPlan::Limit { input, .. } => input.accept(visitor)?, LogicalPlan::Extension { node } => { for input in node.inputs() { @@ -555,6 +581,18 @@ impl LogicalPlan { } Ok(()) } + LogicalPlan::Join { + ref left_keys, + ref right_keys, + .. + } => { + let join_expr: Vec = left_keys + .iter() + .zip(right_keys) + .map(|(l, r)| format!("{} = {}", l, r)) + .collect(); + write!(f, "Join: {}", join_expr.join(", ")) + } LogicalPlan::Limit { ref n, .. } => write!(f, "Limit: {}", n), LogicalPlan::CreateExternalTable { ref name, .. } => { write!(f, "CreateExternalTable: {:?}", name) diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs index f4f190cd193..4c99ab5b72f 100644 --- a/rust/datafusion/src/optimizer/projection_push_down.rs +++ b/rust/datafusion/src/optimizer/projection_push_down.rs @@ -312,6 +312,7 @@ fn optimize_plan( | LogicalPlan::Filter { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Sort { .. } + | LogicalPlan::Join { .. } | LogicalPlan::CreateExternalTable { .. } | LogicalPlan::Extension { .. } => { let expr = utils::expressions(plan); diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index b4c4e890ba9..20b46eb96ed 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -118,6 +118,9 @@ pub fn expressions(plan: &LogicalPlan) -> Vec { result.extend(aggr_expr.clone()); result } + LogicalPlan::Join { .. } => { + vec![] + } LogicalPlan::Sort { expr, .. } => expr.clone(), LogicalPlan::Extension { node } => node.expressions(), // plans without expressions @@ -139,6 +142,7 @@ pub fn inputs(plan: &LogicalPlan) -> Vec<&LogicalPlan> { LogicalPlan::Filter { input, .. } => vec![input], LogicalPlan::Aggregate { input, .. } => vec![input], LogicalPlan::Sort { input, .. } => vec![input], + LogicalPlan::Join { left, right, .. } => vec![left, right], LogicalPlan::Limit { input, .. } => vec![input], LogicalPlan::Extension { node } => node.inputs(), // plans without inputs @@ -180,6 +184,20 @@ pub fn from_plan( expr: expr.clone(), input: Arc::new(inputs[0].clone()), }), + LogicalPlan::Join { + join_type, + left_keys, + right_keys, + schema, + .. + } => Ok(LogicalPlan::Join { + left: Arc::new(inputs[0].clone()), + right: Arc::new(inputs[1].clone()), + join_type: join_type.clone(), + left_keys: left_keys.clone(), + right_keys: right_keys.clone(), + schema: schema.clone(), + }), LogicalPlan::Limit { n, .. } => Ok(LogicalPlan::Limit { n: *n, input: Arc::new(inputs[0].clone()), diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index 4740a5dde14..023205b43ec 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -287,6 +287,11 @@ impl DefaultPhysicalPlanner { ctx_state.config.concurrency, )?)) } + LogicalPlan::Join { .. } => { + // TODO once https://github.com/apache/arrow/pull/8709 is merged we can + // create the physical operator here + todo!() + } LogicalPlan::EmptyRelation { produce_one_row, schema, diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index aac2ebf71f1..20bbbe47c97 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,6 +28,6 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - array, avg, col, concat, count, create_udf, length, lit, max, min, sum, + array, avg, col, concat, count, create_udf, length, lit, max, min, sum, JoinType, }; pub use crate::physical_plan::csv::CsvReadOptions; From def6c52ec073a0b2640ee3a91c901cd5bf1ce9e2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 21 Nov 2020 11:07:09 -0700 Subject: [PATCH 2/7] Implement planner for join --- rust/datafusion/src/dataframe.rs | 8 ++- rust/datafusion/src/logical_plan/builder.rs | 50 +++++++++++-------- .../datafusion/src/physical_plan/hash_join.rs | 14 +++--- .../src/physical_plan/hash_utils.rs | 9 ++-- rust/datafusion/src/physical_plan/planner.rs | 32 ++++++++++-- 5 files changed, 76 insertions(+), 37 deletions(-) diff --git a/rust/datafusion/src/dataframe.rs b/rust/datafusion/src/dataframe.rs index c814affb463..1f29c53b66e 100644 --- a/rust/datafusion/src/dataframe.rs +++ b/rust/datafusion/src/dataframe.rs @@ -155,8 +155,12 @@ pub trait DataFrame { /// # async fn main() -> Result<()> { /// let mut ctx = ExecutionContext::new(); /// let left = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; - /// let right = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; - /// let join = left.join(right, JoinType::Inner, vec!["a", "b"], vec!["a", "b"])?; + /// let right = ctx.read_csv("tests/example.csv", CsvReadOptions::new())? + /// .select(vec![ + /// col("a").alias("a2"), + /// col("b").alias("b2"), + /// col("c").alias("c2")])?; + /// let join = left.join(right, JoinType::Inner, vec!["a", "b"], vec!["a2", "b2"])?; /// let batches = join.collect().await?; /// # Ok(()) /// # } diff --git a/rust/datafusion/src/logical_plan/builder.rs b/rust/datafusion/src/logical_plan/builder.rs index 89a7613d282..2877b0c5a3e 100644 --- a/rust/datafusion/src/logical_plan/builder.rs +++ b/rust/datafusion/src/logical_plan/builder.rs @@ -30,6 +30,7 @@ use super::{ col, exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan, TableSource, }; +use crate::physical_plan::hash_utils; /// Builder for logical plans pub struct LogicalPlanBuilder { @@ -190,27 +191,34 @@ impl LogicalPlanBuilder { left_keys: Vec<&str>, right_keys: Vec<&str>, ) -> Result { - //TODO reconcile this with the logic in https://github.com/apache/arrow/pull/8709 - let mut fields = vec![]; - self.plan - .schema() - .fields() - .iter() - .for_each(|f| fields.push(f.to_owned())); - right - .schema() - .fields() - .iter() - .for_each(|f| fields.push(f.to_owned())); - - Ok(Self::from(&LogicalPlan::Join { - left: Arc::new(self.plan.clone()), - right, - left_keys: left_keys.iter().map(|k| k.to_string()).collect(), - right_keys: right_keys.iter().map(|k| k.to_string()).collect(), - join_type, - schema: Arc::new(Schema::new(fields)), - })) + if left_keys.len() != right_keys.len() { + Err(DataFusionError::Plan( + "left_keys and right_keys were not the same length".to_string(), + )) + } else { + let on: Vec<_> = left_keys + .iter() + .zip(right_keys.iter()) + .map(|(x, y)| (x.to_string(), y.to_string())) + .collect::>(); + let physical_join_type = match join_type { + JoinType::Inner => hash_utils::JoinType::Inner, + }; + let physical_schema = hash_utils::build_join_schema( + self.plan.schema(), + right.schema(), + &on, + &physical_join_type, + ); + Ok(Self::from(&LogicalPlan::Join { + left: Arc::new(self.plan.clone()), + right, + left_keys: left_keys.iter().map(|k| k.to_string()).collect(), + right_keys: right_keys.iter().map(|k| k.to_string()).collect(), + join_type, + schema: Arc::new(physical_schema), + })) + } } /// Apply an aggregate diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs index 69a3d5a432e..c5cde23c90a 100644 --- a/rust/datafusion/src/physical_plan/hash_join.rs +++ b/rust/datafusion/src/physical_plan/hash_join.rs @@ -130,11 +130,7 @@ impl ExecutionPlan for HashJoinExec { 2 => Ok(Arc::new(HashJoinExec::try_new( children[0].clone(), children[1].clone(), - &self - .on - .iter() - .map(|(x, y)| (x.as_str(), y.as_str())) - .collect::>(), + &self.on, &self.join_type, )?)), _ => Err(DataFusionError::Internal( @@ -438,9 +434,13 @@ mod tests { fn join( left: Arc, right: Arc, - on: &JoinOn, + on: &[(&str, &str)], ) -> Result { - HashJoinExec::try_new(left, right, on, &JoinType::Inner) + let on: Vec<_> = on + .iter() + .map(|(a, b)| (a.to_string(), b.to_string())) + .collect(); + HashJoinExec::try_new(left, right, &on, &JoinType::Inner) } /// Asserts that the rows are the same, taking into account that their order diff --git a/rust/datafusion/src/physical_plan/hash_utils.rs b/rust/datafusion/src/physical_plan/hash_utils.rs index c3987faa88b..be8f4090242 100644 --- a/rust/datafusion/src/physical_plan/hash_utils.rs +++ b/rust/datafusion/src/physical_plan/hash_utils.rs @@ -29,7 +29,7 @@ pub enum JoinType { } /// The on clause of the join, as vector of (left, right) columns. -pub type JoinOn<'a> = [(&'a str, &'a str)]; +pub type JoinOn = [(String, String)]; /// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. /// They are valid whenever their columns' intersection equals the set `on` @@ -119,8 +119,11 @@ mod tests { fn check(left: &[&str], right: &[&str], on: &[(&str, &str)]) -> Result<()> { let left = left.iter().map(|x| x.to_string()).collect::>(); let right = right.iter().map(|x| x.to_string()).collect::>(); - - check_join_set_is_valid(&left, &right, on) + let on: Vec<_> = on + .iter() + .map(|(a, b)| (a.to_string(), b.to_string())) + .collect(); + check_join_set_is_valid(&left, &right, &on) } #[test] diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index 023205b43ec..8a2fdb96bc7 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -30,6 +30,8 @@ use crate::physical_plan::explain::ExplainExec; use crate::physical_plan::expressions::{Column, Literal, PhysicalSortExpr}; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; +use crate::physical_plan::hash_join::HashJoinExec; +use crate::physical_plan::hash_utils; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::memory::MemoryExec; use crate::physical_plan::merge::MergeExec; @@ -39,6 +41,7 @@ use crate::physical_plan::sort::SortExec; use crate::physical_plan::udf; use crate::physical_plan::{expressions, Distribution}; use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, PhysicalPlanner}; +use crate::prelude::JoinType; use crate::variable::VarType; use arrow::compute::SortOptions; use arrow::datatypes::Schema; @@ -287,10 +290,31 @@ impl DefaultPhysicalPlanner { ctx_state.config.concurrency, )?)) } - LogicalPlan::Join { .. } => { - // TODO once https://github.com/apache/arrow/pull/8709 is merged we can - // create the physical operator here - todo!() + LogicalPlan::Join { + left, + right, + left_keys, + right_keys, + join_type, + schema, + } => { + let left = self.create_physical_plan(left, ctx_state)?; + let right = self.create_physical_plan(right, ctx_state)?; + let physical_join_type = match join_type { + JoinType::Inner => hash_utils::JoinType::Inner, + }; + let on: Vec<_> = left_keys + .iter() + .zip(right_keys.iter()) + .map(|(a, b)| (a.to_string(), b.to_string())) + .collect(); + let hash_join = + HashJoinExec::try_new(left, right, &on, &physical_join_type)?; + if schema.as_ref() == hash_join.schema().as_ref() { + Ok(Arc::new(hash_join)) + } else { + Err(DataFusionError::Plan("schema mismatch".to_string())) + } } LogicalPlan::EmptyRelation { produce_one_row, From c8fb67424a4a7965b71e9624d7ca7e3401ee48b8 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 21 Nov 2020 11:19:37 -0700 Subject: [PATCH 3/7] fix rustdoc test failure --- rust/datafusion/src/logical_plan/builder.rs | 3 ++ .../datafusion/src/physical_plan/hash_join.rs | 30 +++++++++++++++++++ .../src/physical_plan/hash_utils.rs | 10 +++++-- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/rust/datafusion/src/logical_plan/builder.rs b/rust/datafusion/src/logical_plan/builder.rs index 2877b0c5a3e..f4a8136233e 100644 --- a/rust/datafusion/src/logical_plan/builder.rs +++ b/rust/datafusion/src/logical_plan/builder.rs @@ -210,6 +210,9 @@ impl LogicalPlanBuilder { &on, &physical_join_type, ); + println!("left: {:?}", self.plan.schema()); + println!("right: {:?}", right.schema()); + println!("join: {:?}", physical_schema); Ok(Self::from(&LogicalPlan::Join { left: Arc::new(self.plan.clone()), right, diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs index c5cde23c90a..94c0e09d741 100644 --- a/rust/datafusion/src/physical_plan/hash_join.rs +++ b/rust/datafusion/src/physical_plan/hash_join.rs @@ -486,6 +486,36 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_one_no_shared_column_names() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = &[("b1", "b2")]; + + let join = join(left, right, on)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let result = format_batch(&batches[0]); + let expected = vec!["2,5,8,20,5,80", "3,5,9,20,5,80", "1,4,7,10,4,70"]; + + assert_same_rows(&result, &expected); + + Ok(()) + } + #[tokio::test] async fn join_two() -> Result<()> { let left = build_table( diff --git a/rust/datafusion/src/physical_plan/hash_utils.rs b/rust/datafusion/src/physical_plan/hash_utils.rs index be8f4090242..1175a6886d5 100644 --- a/rust/datafusion/src/physical_plan/hash_utils.rs +++ b/rust/datafusion/src/physical_plan/hash_utils.rs @@ -94,15 +94,19 @@ pub fn build_join_schema( ) -> Schema { let fields: Vec = match join_type { JoinType::Inner => { - // inner: all fields are there - let on_right = &on.iter().map(|on| on.1.to_string()).collect::>(); + // remove right-side join keys if they have the same names as the left-side + let duplicate_keys = &on + .iter() + .filter(|(l, r)| l == r) + .map(|on| on.1.to_string()) + .collect::>(); let left_fields = left.fields().iter(); let right_fields = right .fields() .iter() - .filter(|f| !on_right.contains(f.name())); + .filter(|f| !duplicate_keys.contains(f.name())); // left then right left_fields.chain(right_fields).cloned().collect() From 24e03c5f8af499459b62287e0d4a2cae052ab3f0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 21 Nov 2020 11:34:46 -0700 Subject: [PATCH 4/7] add unit test --- rust/datafusion/src/dataframe.rs | 4 ++-- .../datafusion/src/execution/dataframe_impl.rs | 18 ++++++++++++++++-- rust/datafusion/src/logical_plan/builder.rs | 7 ++----- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/rust/datafusion/src/dataframe.rs b/rust/datafusion/src/dataframe.rs index 1f29c53b66e..edf454fa96a 100644 --- a/rust/datafusion/src/dataframe.rs +++ b/rust/datafusion/src/dataframe.rs @@ -169,8 +169,8 @@ pub trait DataFrame { &self, right: Arc, join_type: JoinType, - left_cols: Vec<&str>, - right_cols: Vec<&str>, + left_cols: &[&str], + right_cols: &[&str], ) -> Result>; /// Executes this DataFrame and collects all results into a vector of RecordBatch. diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index b4179639f65..6b93db5070b 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -109,8 +109,8 @@ impl DataFrame for DataFrameImpl { &self, right: Arc, join_type: JoinType, - left_cols: Vec<&str>, - right_cols: Vec<&str>, + left_cols: &[&str], + right_cols: &[&str], ) -> Result> { let plan = LogicalPlanBuilder::from(&self.plan) .join( @@ -224,6 +224,20 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join() -> Result<()> { + let left = test_table()?.select_columns(vec!["c1", "c2"])?; + let right = test_table()?.select_columns(vec!["c1", "c3"])?; + let left_rows = left.collect().await?; + let right_rows = right.collect().await?; + let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?; + let join_rows = join.collect().await?; + assert_eq!(100, left_rows.len()); + assert_eq!(100, right_rows.len()); + assert_eq!(1000, join_rows.len()); //TODO determine expected number but should be > 100 + Ok(()) + } + #[test] fn limit() -> Result<()> { // build query using Table API diff --git a/rust/datafusion/src/logical_plan/builder.rs b/rust/datafusion/src/logical_plan/builder.rs index f4a8136233e..7f849dcc290 100644 --- a/rust/datafusion/src/logical_plan/builder.rs +++ b/rust/datafusion/src/logical_plan/builder.rs @@ -188,8 +188,8 @@ impl LogicalPlanBuilder { &self, right: Arc, join_type: JoinType, - left_keys: Vec<&str>, - right_keys: Vec<&str>, + left_keys: &[&str], + right_keys: &[&str], ) -> Result { if left_keys.len() != right_keys.len() { Err(DataFusionError::Plan( @@ -210,9 +210,6 @@ impl LogicalPlanBuilder { &on, &physical_join_type, ); - println!("left: {:?}", self.plan.schema()); - println!("right: {:?}", right.schema()); - println!("join: {:?}", physical_schema); Ok(Self::from(&LogicalPlan::Join { left: Arc::new(self.plan.clone()), right, From 2318b47b44f62a138e7f79e7f4bfeebdab5ff223 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 21 Nov 2020 11:42:55 -0700 Subject: [PATCH 5/7] change logical plan for join to use pairs instead of separate left/right vectors --- rust/datafusion/src/logical_plan/builder.rs | 7 +++++-- rust/datafusion/src/logical_plan/plan.rs | 19 +++++-------------- rust/datafusion/src/optimizer/utils.rs | 6 ++---- .../src/physical_plan/distinct_expressions.rs | 2 +- .../datafusion/src/physical_plan/hash_join.rs | 2 +- .../src/physical_plan/hash_utils.rs | 2 +- rust/datafusion/src/physical_plan/planner.rs | 10 ++-------- 7 files changed, 17 insertions(+), 31 deletions(-) diff --git a/rust/datafusion/src/logical_plan/builder.rs b/rust/datafusion/src/logical_plan/builder.rs index 7f849dcc290..a88944448b9 100644 --- a/rust/datafusion/src/logical_plan/builder.rs +++ b/rust/datafusion/src/logical_plan/builder.rs @@ -213,8 +213,11 @@ impl LogicalPlanBuilder { Ok(Self::from(&LogicalPlan::Join { left: Arc::new(self.plan.clone()), right, - left_keys: left_keys.iter().map(|k| k.to_string()).collect(), - right_keys: right_keys.iter().map(|k| k.to_string()).collect(), + on: left_keys + .iter() + .zip(right_keys.iter()) + .map(|(l, r)| (l.to_string(), r.to_string())) + .collect(), join_type, schema: Arc::new(physical_schema), })) diff --git a/rust/datafusion/src/logical_plan/plan.rs b/rust/datafusion/src/logical_plan/plan.rs index 525984b895e..74f2082481d 100644 --- a/rust/datafusion/src/logical_plan/plan.rs +++ b/rust/datafusion/src/logical_plan/plan.rs @@ -107,10 +107,8 @@ pub enum LogicalPlan { left: Arc, /// Right input right: Arc, - /// Columns in left input to use for join keys - left_keys: Vec, - /// Columns in right input to use for join keys - right_keys: Vec, + /// Equijoin clause expressed as pairs of (left, right) join columns + on: Vec<(String, String)>, /// Join type join_type: JoinType, /// The output schema, containing fields from the left and right inputs @@ -581,16 +579,9 @@ impl LogicalPlan { } Ok(()) } - LogicalPlan::Join { - ref left_keys, - ref right_keys, - .. - } => { - let join_expr: Vec = left_keys - .iter() - .zip(right_keys) - .map(|(l, r)| format!("{} = {}", l, r)) - .collect(); + LogicalPlan::Join { on: ref keys, .. } => { + let join_expr: Vec = + keys.iter().map(|(l, r)| format!("{} = {}", l, r)).collect(); write!(f, "Join: {}", join_expr.join(", ")) } LogicalPlan::Limit { ref n, .. } => write!(f, "Limit: {}", n), diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 20b46eb96ed..72cf75525b6 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -186,16 +186,14 @@ pub fn from_plan( }), LogicalPlan::Join { join_type, - left_keys, - right_keys, + on, schema, .. } => Ok(LogicalPlan::Join { left: Arc::new(inputs[0].clone()), right: Arc::new(inputs[1].clone()), join_type: join_type.clone(), - left_keys: left_keys.clone(), - right_keys: right_keys.clone(), + on: on.clone(), schema: schema.clone(), }), LogicalPlan::Limit { n, .. } => Ok(LogicalPlan::Limit { diff --git a/rust/datafusion/src/physical_plan/distinct_expressions.rs b/rust/datafusion/src/physical_plan/distinct_expressions.rs index 09194439777..bbccc3be6eb 100644 --- a/rust/datafusion/src/physical_plan/distinct_expressions.rs +++ b/rust/datafusion/src/physical_plan/distinct_expressions.rs @@ -262,7 +262,7 @@ mod tests { let mut states = state1 .iter() .zip(state2.iter()) - .map(|(a, b)| (a.clone(), b.clone())) + .map(|(l, r)| (l.clone(), r.clone())) .collect::, Option)>>(); states.sort(); states diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs index 94c0e09d741..4933e3370ba 100644 --- a/rust/datafusion/src/physical_plan/hash_join.rs +++ b/rust/datafusion/src/physical_plan/hash_join.rs @@ -438,7 +438,7 @@ mod tests { ) -> Result { let on: Vec<_> = on .iter() - .map(|(a, b)| (a.to_string(), b.to_string())) + .map(|(l, r)| (l.to_string(), r.to_string())) .collect(); HashJoinExec::try_new(left, right, &on, &JoinType::Inner) } diff --git a/rust/datafusion/src/physical_plan/hash_utils.rs b/rust/datafusion/src/physical_plan/hash_utils.rs index 1175a6886d5..1492c036990 100644 --- a/rust/datafusion/src/physical_plan/hash_utils.rs +++ b/rust/datafusion/src/physical_plan/hash_utils.rs @@ -125,7 +125,7 @@ mod tests { let right = right.iter().map(|x| x.to_string()).collect::>(); let on: Vec<_> = on .iter() - .map(|(a, b)| (a.to_string(), b.to_string())) + .map(|(l, r)| (l.to_string(), r.to_string())) .collect(); check_join_set_is_valid(&left, &right, &on) } diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index 8a2fdb96bc7..90dc2da9426 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -293,8 +293,7 @@ impl DefaultPhysicalPlanner { LogicalPlan::Join { left, right, - left_keys, - right_keys, + on: keys, join_type, schema, } => { @@ -303,13 +302,8 @@ impl DefaultPhysicalPlanner { let physical_join_type = match join_type { JoinType::Inner => hash_utils::JoinType::Inner, }; - let on: Vec<_> = left_keys - .iter() - .zip(right_keys.iter()) - .map(|(a, b)| (a.to_string(), b.to_string())) - .collect(); let hash_join = - HashJoinExec::try_new(left, right, &on, &physical_join_type)?; + HashJoinExec::try_new(left, right, &keys, &physical_join_type)?; if schema.as_ref() == hash_join.schema().as_ref() { Ok(Arc::new(hash_join)) } else { From 12ff573a15af4e2952972cadd248856f32ed1639 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 21 Nov 2020 12:07:09 -0700 Subject: [PATCH 6/7] fix bad test --- rust/datafusion/src/dataframe.rs | 2 +- rust/datafusion/src/execution/dataframe_impl.rs | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/rust/datafusion/src/dataframe.rs b/rust/datafusion/src/dataframe.rs index edf454fa96a..0945fa593d5 100644 --- a/rust/datafusion/src/dataframe.rs +++ b/rust/datafusion/src/dataframe.rs @@ -160,7 +160,7 @@ pub trait DataFrame { /// col("a").alias("a2"), /// col("b").alias("b2"), /// col("c").alias("c2")])?; - /// let join = left.join(right, JoinType::Inner, vec!["a", "b"], vec!["a2", "b2"])?; + /// let join = left.join(right, JoinType::Inner, &["a", "b"], &["a2", "b2"])?; /// let batches = join.collect().await?; /// # Ok(()) /// # } diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index 6b93db5070b..44a557b3107 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -232,9 +232,12 @@ mod tests { let right_rows = right.collect().await?; let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?; let join_rows = join.collect().await?; - assert_eq!(100, left_rows.len()); - assert_eq!(100, right_rows.len()); - assert_eq!(1000, join_rows.len()); //TODO determine expected number but should be > 100 + assert_eq!(1, left_rows.len()); + assert_eq!(100, left_rows[0].num_rows()); + assert_eq!(1, right_rows.len()); + assert_eq!(100, right_rows[0].num_rows()); + assert_eq!(1, join_rows.len()); + assert_eq!(2008, join_rows[0].num_rows()); Ok(()) } From 11a9c27082caa17d27241ef62aee2b75323ebc33 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 21 Nov 2020 12:08:52 -0700 Subject: [PATCH 7/7] code cleanup --- rust/datafusion/src/optimizer/utils.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 72cf75525b6..7b19856cea0 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -118,9 +118,6 @@ pub fn expressions(plan: &LogicalPlan) -> Vec { result.extend(aggr_expr.clone()); result } - LogicalPlan::Join { .. } => { - vec![] - } LogicalPlan::Sort { expr, .. } => expr.clone(), LogicalPlan::Extension { node } => node.expressions(), // plans without expressions @@ -130,6 +127,7 @@ pub fn expressions(plan: &LogicalPlan) -> Vec { | LogicalPlan::CsvScan { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Limit { .. } + | LogicalPlan::Join { .. } | LogicalPlan::CreateExternalTable { .. } | LogicalPlan::Explain { .. } => vec![], }