diff --git a/rust/datafusion/src/dataframe.rs b/rust/datafusion/src/dataframe.rs index d2ad7e6917a..0945fa593d5 100644 --- a/rust/datafusion/src/dataframe.rs +++ b/rust/datafusion/src/dataframe.rs @@ -19,7 +19,7 @@ use crate::arrow::record_batch::RecordBatch; use crate::error::Result; -use crate::logical_plan::{Expr, FunctionRegistry, LogicalPlan}; +use crate::logical_plan::{Expr, FunctionRegistry, JoinType, LogicalPlan}; use arrow::datatypes::Schema; use std::sync::Arc; @@ -146,6 +146,33 @@ pub trait DataFrame { /// ``` fn sort(&self, expr: Vec) -> Result>; + /// Join this DataFrame with another DataFrame using the specified columns as join keys + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let mut ctx = ExecutionContext::new(); + /// let left = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?; + /// let right = ctx.read_csv("tests/example.csv", CsvReadOptions::new())? + /// .select(vec![ + /// col("a").alias("a2"), + /// col("b").alias("b2"), + /// col("c").alias("c2")])?; + /// let join = left.join(right, JoinType::Inner, &["a", "b"], &["a2", "b2"])?; + /// let batches = join.collect().await?; + /// # Ok(()) + /// # } + /// ``` + fn join( + &self, + right: Arc, + join_type: JoinType, + left_cols: &[&str], + right_cols: &[&str], + ) -> Result>; + /// Executes this DataFrame and collects all results into a vector of RecordBatch. /// /// ``` diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index a93e5745a03..44a557b3107 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -23,7 +23,9 @@ use crate::arrow::record_batch::RecordBatch; use crate::dataframe::*; use crate::error::Result; use crate::execution::context::{ExecutionContext, ExecutionContextState}; -use crate::logical_plan::{col, Expr, FunctionRegistry, LogicalPlan, LogicalPlanBuilder}; +use crate::logical_plan::{ + col, Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder, +}; use arrow::datatypes::Schema; use async_trait::async_trait; @@ -102,6 +104,25 @@ impl DataFrame for DataFrameImpl { Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) } + /// Join with another DataFrame + fn join( + &self, + right: Arc, + join_type: JoinType, + left_cols: &[&str], + right_cols: &[&str], + ) -> Result> { + let plan = LogicalPlanBuilder::from(&self.plan) + .join( + Arc::new(right.to_logical_plan()), + join_type, + left_cols, + right_cols, + )? + .build()?; + Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) + } + /// Convert to logical plan fn to_logical_plan(&self) -> LogicalPlan { self.plan.clone() @@ -203,6 +224,23 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join() -> Result<()> { + let left = test_table()?.select_columns(vec!["c1", "c2"])?; + let right = test_table()?.select_columns(vec!["c1", "c3"])?; + let left_rows = left.collect().await?; + let right_rows = right.collect().await?; + let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?; + let join_rows = join.collect().await?; + assert_eq!(1, left_rows.len()); + assert_eq!(100, left_rows[0].num_rows()); + assert_eq!(1, right_rows.len()); + assert_eq!(100, right_rows[0].num_rows()); + assert_eq!(1, join_rows.len()); + assert_eq!(2008, join_rows[0].num_rows()); + Ok(()) + } + #[test] fn limit() -> Result<()> { // build query using Table API diff --git a/rust/datafusion/src/logical_plan/builder.rs b/rust/datafusion/src/logical_plan/builder.rs index 32e53039281..a88944448b9 100644 --- a/rust/datafusion/src/logical_plan/builder.rs +++ b/rust/datafusion/src/logical_plan/builder.rs @@ -27,8 +27,10 @@ use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; use super::{ - col, exprlist_to_fields, Expr, LogicalPlan, PlanType, StringifiedPlan, TableSource, + col, exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan, + TableSource, }; +use crate::physical_plan::hash_utils; /// Builder for logical plans pub struct LogicalPlanBuilder { @@ -181,6 +183,47 @@ impl LogicalPlanBuilder { })) } + /// Apply a join + pub fn join( + &self, + right: Arc, + join_type: JoinType, + left_keys: &[&str], + right_keys: &[&str], + ) -> Result { + if left_keys.len() != right_keys.len() { + Err(DataFusionError::Plan( + "left_keys and right_keys were not the same length".to_string(), + )) + } else { + let on: Vec<_> = left_keys + .iter() + .zip(right_keys.iter()) + .map(|(x, y)| (x.to_string(), y.to_string())) + .collect::>(); + let physical_join_type = match join_type { + JoinType::Inner => hash_utils::JoinType::Inner, + }; + let physical_schema = hash_utils::build_join_schema( + self.plan.schema(), + right.schema(), + &on, + &physical_join_type, + ); + Ok(Self::from(&LogicalPlan::Join { + left: Arc::new(self.plan.clone()), + right, + on: left_keys + .iter() + .zip(right_keys.iter()) + .map(|(l, r)| (l.to_string(), r.to_string())) + .collect(), + join_type, + schema: Arc::new(physical_schema), + })) + } + } + /// Apply an aggregate pub fn aggregate(&self, group_expr: Vec, aggr_expr: Vec) -> Result { let mut all_expr: Vec = group_expr.clone(); diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index c2f90da3581..a3a44d5341a 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -38,5 +38,7 @@ pub use expr::{ }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; -pub use plan::{LogicalPlan, PlanType, PlanVisitor, StringifiedPlan, TableSource}; +pub use plan::{ + JoinType, LogicalPlan, PlanType, PlanVisitor, StringifiedPlan, TableSource, +}; pub use registry::FunctionRegistry; diff --git a/rust/datafusion/src/logical_plan/plan.rs b/rust/datafusion/src/logical_plan/plan.rs index ced13fa89ce..74f2082481d 100644 --- a/rust/datafusion/src/logical_plan/plan.rs +++ b/rust/datafusion/src/logical_plan/plan.rs @@ -41,6 +41,13 @@ pub enum TableSource { FromProvider(Arc), } +/// Join type +#[derive(Debug, Clone)] +pub enum JoinType { + /// Inner join + Inner, +} + /// A LogicalPlan represents the different types of relational /// operators (such as Projection, Filter, etc) and can be created by /// the SQL query planner and the DataFrame API. @@ -94,6 +101,19 @@ pub enum LogicalPlan { /// The incoming logical plan input: Arc, }, + /// Join two logical plans on one or more join columns + Join { + /// Left input + left: Arc, + /// Right input + right: Arc, + /// Equijoin clause expressed as pairs of (left, right) join columns + on: Vec<(String, String)>, + /// Join type + join_type: JoinType, + /// The output schema, containing fields from the left and right inputs + schema: SchemaRef, + }, /// Produces rows from a table provider by reference or from the context TableScan { /// The name of the schema @@ -211,6 +231,7 @@ impl LogicalPlan { LogicalPlan::Filter { input, .. } => input.schema(), LogicalPlan::Aggregate { schema, .. } => &schema, LogicalPlan::Sort { input, .. } => input.schema(), + LogicalPlan::Join { schema, .. } => &schema, LogicalPlan::Limit { input, .. } => input.schema(), LogicalPlan::CreateExternalTable { schema, .. } => &schema, LogicalPlan::Explain { schema, .. } => &schema, @@ -292,6 +313,9 @@ impl LogicalPlan { LogicalPlan::Filter { input, .. } => input.accept(visitor)?, LogicalPlan::Aggregate { input, .. } => input.accept(visitor)?, LogicalPlan::Sort { input, .. } => input.accept(visitor)?, + LogicalPlan::Join { left, right, .. } => { + left.accept(visitor)? && right.accept(visitor)? + } LogicalPlan::Limit { input, .. } => input.accept(visitor)?, LogicalPlan::Extension { node } => { for input in node.inputs() { @@ -555,6 +579,11 @@ impl LogicalPlan { } Ok(()) } + LogicalPlan::Join { on: ref keys, .. } => { + let join_expr: Vec = + keys.iter().map(|(l, r)| format!("{} = {}", l, r)).collect(); + write!(f, "Join: {}", join_expr.join(", ")) + } LogicalPlan::Limit { ref n, .. } => write!(f, "Limit: {}", n), LogicalPlan::CreateExternalTable { ref name, .. } => { write!(f, "CreateExternalTable: {:?}", name) diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs index f4f190cd193..4c99ab5b72f 100644 --- a/rust/datafusion/src/optimizer/projection_push_down.rs +++ b/rust/datafusion/src/optimizer/projection_push_down.rs @@ -312,6 +312,7 @@ fn optimize_plan( | LogicalPlan::Filter { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Sort { .. } + | LogicalPlan::Join { .. } | LogicalPlan::CreateExternalTable { .. } | LogicalPlan::Extension { .. } => { let expr = utils::expressions(plan); diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index b4c4e890ba9..7b19856cea0 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -127,6 +127,7 @@ pub fn expressions(plan: &LogicalPlan) -> Vec { | LogicalPlan::CsvScan { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Limit { .. } + | LogicalPlan::Join { .. } | LogicalPlan::CreateExternalTable { .. } | LogicalPlan::Explain { .. } => vec![], } @@ -139,6 +140,7 @@ pub fn inputs(plan: &LogicalPlan) -> Vec<&LogicalPlan> { LogicalPlan::Filter { input, .. } => vec![input], LogicalPlan::Aggregate { input, .. } => vec![input], LogicalPlan::Sort { input, .. } => vec![input], + LogicalPlan::Join { left, right, .. } => vec![left, right], LogicalPlan::Limit { input, .. } => vec![input], LogicalPlan::Extension { node } => node.inputs(), // plans without inputs @@ -180,6 +182,18 @@ pub fn from_plan( expr: expr.clone(), input: Arc::new(inputs[0].clone()), }), + LogicalPlan::Join { + join_type, + on, + schema, + .. + } => Ok(LogicalPlan::Join { + left: Arc::new(inputs[0].clone()), + right: Arc::new(inputs[1].clone()), + join_type: join_type.clone(), + on: on.clone(), + schema: schema.clone(), + }), LogicalPlan::Limit { n, .. } => Ok(LogicalPlan::Limit { n: *n, input: Arc::new(inputs[0].clone()), diff --git a/rust/datafusion/src/physical_plan/distinct_expressions.rs b/rust/datafusion/src/physical_plan/distinct_expressions.rs index 09194439777..bbccc3be6eb 100644 --- a/rust/datafusion/src/physical_plan/distinct_expressions.rs +++ b/rust/datafusion/src/physical_plan/distinct_expressions.rs @@ -262,7 +262,7 @@ mod tests { let mut states = state1 .iter() .zip(state2.iter()) - .map(|(a, b)| (a.clone(), b.clone())) + .map(|(l, r)| (l.clone(), r.clone())) .collect::, Option)>>(); states.sort(); states diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs index 69a3d5a432e..4933e3370ba 100644 --- a/rust/datafusion/src/physical_plan/hash_join.rs +++ b/rust/datafusion/src/physical_plan/hash_join.rs @@ -130,11 +130,7 @@ impl ExecutionPlan for HashJoinExec { 2 => Ok(Arc::new(HashJoinExec::try_new( children[0].clone(), children[1].clone(), - &self - .on - .iter() - .map(|(x, y)| (x.as_str(), y.as_str())) - .collect::>(), + &self.on, &self.join_type, )?)), _ => Err(DataFusionError::Internal( @@ -438,9 +434,13 @@ mod tests { fn join( left: Arc, right: Arc, - on: &JoinOn, + on: &[(&str, &str)], ) -> Result { - HashJoinExec::try_new(left, right, on, &JoinType::Inner) + let on: Vec<_> = on + .iter() + .map(|(l, r)| (l.to_string(), r.to_string())) + .collect(); + HashJoinExec::try_new(left, right, &on, &JoinType::Inner) } /// Asserts that the rows are the same, taking into account that their order @@ -486,6 +486,36 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_one_no_shared_column_names() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = &[("b1", "b2")]; + + let join = join(left, right, on)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let result = format_batch(&batches[0]); + let expected = vec!["2,5,8,20,5,80", "3,5,9,20,5,80", "1,4,7,10,4,70"]; + + assert_same_rows(&result, &expected); + + Ok(()) + } + #[tokio::test] async fn join_two() -> Result<()> { let left = build_table( diff --git a/rust/datafusion/src/physical_plan/hash_utils.rs b/rust/datafusion/src/physical_plan/hash_utils.rs index c3987faa88b..1492c036990 100644 --- a/rust/datafusion/src/physical_plan/hash_utils.rs +++ b/rust/datafusion/src/physical_plan/hash_utils.rs @@ -29,7 +29,7 @@ pub enum JoinType { } /// The on clause of the join, as vector of (left, right) columns. -pub type JoinOn<'a> = [(&'a str, &'a str)]; +pub type JoinOn = [(String, String)]; /// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. /// They are valid whenever their columns' intersection equals the set `on` @@ -94,15 +94,19 @@ pub fn build_join_schema( ) -> Schema { let fields: Vec = match join_type { JoinType::Inner => { - // inner: all fields are there - let on_right = &on.iter().map(|on| on.1.to_string()).collect::>(); + // remove right-side join keys if they have the same names as the left-side + let duplicate_keys = &on + .iter() + .filter(|(l, r)| l == r) + .map(|on| on.1.to_string()) + .collect::>(); let left_fields = left.fields().iter(); let right_fields = right .fields() .iter() - .filter(|f| !on_right.contains(f.name())); + .filter(|f| !duplicate_keys.contains(f.name())); // left then right left_fields.chain(right_fields).cloned().collect() @@ -119,8 +123,11 @@ mod tests { fn check(left: &[&str], right: &[&str], on: &[(&str, &str)]) -> Result<()> { let left = left.iter().map(|x| x.to_string()).collect::>(); let right = right.iter().map(|x| x.to_string()).collect::>(); - - check_join_set_is_valid(&left, &right, on) + let on: Vec<_> = on + .iter() + .map(|(l, r)| (l.to_string(), r.to_string())) + .collect(); + check_join_set_is_valid(&left, &right, &on) } #[test] diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index 4740a5dde14..90dc2da9426 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -30,6 +30,8 @@ use crate::physical_plan::explain::ExplainExec; use crate::physical_plan::expressions::{Column, Literal, PhysicalSortExpr}; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; +use crate::physical_plan::hash_join::HashJoinExec; +use crate::physical_plan::hash_utils; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::memory::MemoryExec; use crate::physical_plan::merge::MergeExec; @@ -39,6 +41,7 @@ use crate::physical_plan::sort::SortExec; use crate::physical_plan::udf; use crate::physical_plan::{expressions, Distribution}; use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, PhysicalPlanner}; +use crate::prelude::JoinType; use crate::variable::VarType; use arrow::compute::SortOptions; use arrow::datatypes::Schema; @@ -287,6 +290,26 @@ impl DefaultPhysicalPlanner { ctx_state.config.concurrency, )?)) } + LogicalPlan::Join { + left, + right, + on: keys, + join_type, + schema, + } => { + let left = self.create_physical_plan(left, ctx_state)?; + let right = self.create_physical_plan(right, ctx_state)?; + let physical_join_type = match join_type { + JoinType::Inner => hash_utils::JoinType::Inner, + }; + let hash_join = + HashJoinExec::try_new(left, right, &keys, &physical_join_type)?; + if schema.as_ref() == hash_join.schema().as_ref() { + Ok(Arc::new(hash_join)) + } else { + Err(DataFusionError::Plan("schema mismatch".to_string())) + } + } LogicalPlan::EmptyRelation { produce_one_row, schema, diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index aac2ebf71f1..20bbbe47c97 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,6 +28,6 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - array, avg, col, concat, count, create_udf, length, lit, max, min, sum, + array, avg, col, concat, count, create_udf, length, lit, max, min, sum, JoinType, }; pub use crate::physical_plan::csv::CsvReadOptions;