Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion rust/datafusion/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use crate::arrow::record_batch::RecordBatch;
use crate::error::Result;
use crate::logical_plan::{Expr, FunctionRegistry, LogicalPlan};
use crate::logical_plan::{Expr, FunctionRegistry, JoinType, LogicalPlan};
use arrow::datatypes::Schema;
use std::sync::Arc;

Expand Down Expand Up @@ -146,6 +146,33 @@ pub trait DataFrame {
/// ```
fn sort(&self, expr: Vec<Expr>) -> Result<Arc<dyn DataFrame>>;

/// Join this DataFrame with another DataFrame using the specified columns as join keys
///
/// * `right` - the DataFrame to join with; this DataFrame is the left side of the join.
/// * `join_type` - the kind of join to perform (see `JoinType`).
/// * `left_cols` - names of the join key columns in this DataFrame.
/// * `right_cols` - names of the join key columns in `right`.
///
/// NOTE(review): the key columns appear to be paired positionally
/// (`left_cols[i]` with `right_cols[i]`), so both slices are expected to have
/// the same length - confirm against the implementation being used.
///
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let mut ctx = ExecutionContext::new();
/// let left = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
/// let right = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?
/// .select(vec![
/// col("a").alias("a2"),
/// col("b").alias("b2"),
/// col("c").alias("c2")])?;
/// let join = left.join(right, JoinType::Inner, &["a", "b"], &["a2", "b2"])?;
/// let batches = join.collect().await?;
/// # Ok(())
/// # }
/// ```
fn join(
&self,
right: Arc<dyn DataFrame>,
join_type: JoinType,
left_cols: &[&str],
right_cols: &[&str],
) -> Result<Arc<dyn DataFrame>>;

/// Executes this DataFrame and collects all results into a vector of RecordBatch.
///
/// ```
Expand Down
40 changes: 39 additions & 1 deletion rust/datafusion/src/execution/dataframe_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ use crate::arrow::record_batch::RecordBatch;
use crate::dataframe::*;
use crate::error::Result;
use crate::execution::context::{ExecutionContext, ExecutionContextState};
use crate::logical_plan::{col, Expr, FunctionRegistry, LogicalPlan, LogicalPlanBuilder};
use crate::logical_plan::{
col, Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder,
};
use arrow::datatypes::Schema;

use async_trait::async_trait;
Expand Down Expand Up @@ -102,6 +104,25 @@ impl DataFrame for DataFrameImpl {
Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
}

/// Join with another DataFrame.
///
/// Builds a new logical plan that joins this DataFrame's plan (left side)
/// with the logical plan of `right`, using `left_cols` / `right_cols` as the
/// join keys, and wraps the result in a new `DataFrameImpl` that shares this
/// DataFrame's context state.
fn join(
    &self,
    right: Arc<dyn DataFrame>,
    join_type: JoinType,
    left_cols: &[&str],
    right_cols: &[&str],
) -> Result<Arc<dyn DataFrame>> {
    let builder = LogicalPlanBuilder::from(&self.plan);
    let right_plan = Arc::new(right.to_logical_plan());
    let joined_plan = builder
        .join(right_plan, join_type, left_cols, right_cols)?
        .build()?;
    let joined_frame = DataFrameImpl::new(self.ctx_state.clone(), &joined_plan);
    Ok(Arc::new(joined_frame))
}

/// Convert to logical plan
fn to_logical_plan(&self) -> LogicalPlan {
self.plan.clone()
Expand Down Expand Up @@ -203,6 +224,23 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn join() -> Result<()> {
    // Two projections of the same test table that share the key column `c1`.
    let lhs = test_table()?.select_columns(vec!["c1", "c2"])?;
    let rhs = test_table()?.select_columns(vec!["c1", "c3"])?;

    // Materialize both inputs before `rhs` is moved into the join.
    let lhs_batches = lhs.collect().await?;
    let rhs_batches = rhs.collect().await?;

    let joined = lhs.join(rhs, JoinType::Inner, &["c1"], &["c1"])?;
    let joined_batches = joined.collect().await?;

    // Each input arrives as a single batch of 100 rows.
    assert_eq!(1, lhs_batches.len());
    assert_eq!(100, lhs_batches[0].num_rows());
    assert_eq!(1, rhs_batches.len());
    assert_eq!(100, rhs_batches[0].num_rows());

    // The inner join on `c1` yields a single, larger batch.
    assert_eq!(1, joined_batches.len());
    assert_eq!(2008, joined_batches[0].num_rows());
    Ok(())
}

#[test]
fn limit() -> Result<()> {
// build query using Table API
Expand Down
45 changes: 44 additions & 1 deletion rust/datafusion/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ use crate::datasource::TableProvider;
use crate::error::{DataFusionError, Result};

use super::{
col, exprlist_to_fields, Expr, LogicalPlan, PlanType, StringifiedPlan, TableSource,
col, exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan,
TableSource,
};
use crate::physical_plan::hash_utils;

/// Builder for logical plans
pub struct LogicalPlanBuilder {
Expand Down Expand Up @@ -181,6 +183,47 @@ impl LogicalPlanBuilder {
}))
}

/// Apply a join
///
/// Joins the plan held by this builder (left side) with `right`, matching
/// `left_keys[i]` against `right_keys[i]` for every position `i`.
///
/// # Errors
///
/// Returns `DataFusionError::Plan` when `left_keys` and `right_keys` do not
/// have the same number of entries.
pub fn join(
    &self,
    right: Arc<LogicalPlan>,
    join_type: JoinType,
    left_keys: &[&str],
    right_keys: &[&str],
) -> Result<Self> {
    if left_keys.len() != right_keys.len() {
        return Err(DataFusionError::Plan(
            "left_keys and right_keys were not the same length".to_string(),
        ));
    }

    // Pair the key columns once; the same list is used to derive the output
    // schema and is then moved into the plan node.
    let on: Vec<(String, String)> = left_keys
        .iter()
        .zip(right_keys.iter())
        .map(|(l, r)| (l.to_string(), r.to_string()))
        .collect();

    // The schema helper works with the physical join type, so translate the
    // logical one first.
    let physical_join_type = match join_type {
        JoinType::Inner => hash_utils::JoinType::Inner,
    };
    let physical_schema = hash_utils::build_join_schema(
        self.plan.schema(),
        right.schema(),
        &on,
        &physical_join_type,
    );

    Ok(Self::from(&LogicalPlan::Join {
        left: Arc::new(self.plan.clone()),
        right,
        on,
        join_type,
        schema: Arc::new(physical_schema),
    }))
}

/// Apply an aggregate
pub fn aggregate(&self, group_expr: Vec<Expr>, aggr_expr: Vec<Expr>) -> Result<Self> {
let mut all_expr: Vec<Expr> = group_expr.clone();
Expand Down
4 changes: 3 additions & 1 deletion rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,7 @@ pub use expr::{
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
pub use plan::{LogicalPlan, PlanType, PlanVisitor, StringifiedPlan, TableSource};
pub use plan::{
JoinType, LogicalPlan, PlanType, PlanVisitor, StringifiedPlan, TableSource,
};
pub use registry::FunctionRegistry;
29 changes: 29 additions & 0 deletions rust/datafusion/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ pub enum TableSource {
FromProvider(Arc<dyn TableProvider + Send + Sync>),
}

/// Join type
///
/// Describes how rows from the left and right inputs of a join are combined.
/// Values can be compared for equality, which allows tests and plan rewrites
/// to check which kind of join a plan node performs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join
    Inner,
}

/// A LogicalPlan represents the different types of relational
/// operators (such as Projection, Filter, etc) and can be created by
/// the SQL query planner and the DataFrame API.
Expand Down Expand Up @@ -94,6 +101,19 @@ pub enum LogicalPlan {
/// The incoming logical plan
input: Arc<LogicalPlan>,
},
/// Join two logical plans on one or more join columns
Join {
/// Left input
left: Arc<LogicalPlan>,
/// Right input
right: Arc<LogicalPlan>,
/// Equijoin clause expressed as pairs of (left, right) join columns
on: Vec<(String, String)>,
/// Join type
join_type: JoinType,
/// The output schema, containing fields from the left and right inputs
schema: SchemaRef,
},
/// Produces rows from a table provider by reference or from the context
TableScan {
/// The name of the schema
Expand Down Expand Up @@ -211,6 +231,7 @@ impl LogicalPlan {
LogicalPlan::Filter { input, .. } => input.schema(),
LogicalPlan::Aggregate { schema, .. } => &schema,
LogicalPlan::Sort { input, .. } => input.schema(),
LogicalPlan::Join { schema, .. } => &schema,
LogicalPlan::Limit { input, .. } => input.schema(),
LogicalPlan::CreateExternalTable { schema, .. } => &schema,
LogicalPlan::Explain { schema, .. } => &schema,
Expand Down Expand Up @@ -292,6 +313,9 @@ impl LogicalPlan {
LogicalPlan::Filter { input, .. } => input.accept(visitor)?,
LogicalPlan::Aggregate { input, .. } => input.accept(visitor)?,
LogicalPlan::Sort { input, .. } => input.accept(visitor)?,
LogicalPlan::Join { left, right, .. } => {
left.accept(visitor)? && right.accept(visitor)?
}
LogicalPlan::Limit { input, .. } => input.accept(visitor)?,
LogicalPlan::Extension { node } => {
for input in node.inputs() {
Expand Down Expand Up @@ -555,6 +579,11 @@ impl LogicalPlan {
}
Ok(())
}
LogicalPlan::Join { on: ref keys, .. } => {
let join_expr: Vec<String> =
keys.iter().map(|(l, r)| format!("{} = {}", l, r)).collect();
write!(f, "Join: {}", join_expr.join(", "))
}
LogicalPlan::Limit { ref n, .. } => write!(f, "Limit: {}", n),
LogicalPlan::CreateExternalTable { ref name, .. } => {
write!(f, "CreateExternalTable: {:?}", name)
Expand Down
1 change: 1 addition & 0 deletions rust/datafusion/src/optimizer/projection_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ fn optimize_plan(
| LogicalPlan::Filter { .. }
| LogicalPlan::EmptyRelation { .. }
| LogicalPlan::Sort { .. }
| LogicalPlan::Join { .. }
| LogicalPlan::CreateExternalTable { .. }
| LogicalPlan::Extension { .. } => {
let expr = utils::expressions(plan);
Expand Down
14 changes: 14 additions & 0 deletions rust/datafusion/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ pub fn expressions(plan: &LogicalPlan) -> Vec<Expr> {
| LogicalPlan::CsvScan { .. }
| LogicalPlan::EmptyRelation { .. }
| LogicalPlan::Limit { .. }
| LogicalPlan::Join { .. }
| LogicalPlan::CreateExternalTable { .. }
| LogicalPlan::Explain { .. } => vec![],
}
Expand All @@ -139,6 +140,7 @@ pub fn inputs(plan: &LogicalPlan) -> Vec<&LogicalPlan> {
LogicalPlan::Filter { input, .. } => vec![input],
LogicalPlan::Aggregate { input, .. } => vec![input],
LogicalPlan::Sort { input, .. } => vec![input],
LogicalPlan::Join { left, right, .. } => vec![left, right],
LogicalPlan::Limit { input, .. } => vec![input],
LogicalPlan::Extension { node } => node.inputs(),
// plans without inputs
Expand Down Expand Up @@ -180,6 +182,18 @@ pub fn from_plan(
expr: expr.clone(),
input: Arc::new(inputs[0].clone()),
}),
LogicalPlan::Join {
join_type,
on,
schema,
..
} => Ok(LogicalPlan::Join {
left: Arc::new(inputs[0].clone()),
right: Arc::new(inputs[1].clone()),
join_type: join_type.clone(),
on: on.clone(),
schema: schema.clone(),
}),
LogicalPlan::Limit { n, .. } => Ok(LogicalPlan::Limit {
n: *n,
input: Arc::new(inputs[0].clone()),
Expand Down
2 changes: 1 addition & 1 deletion rust/datafusion/src/physical_plan/distinct_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ mod tests {
let mut states = state1
.iter()
.zip(state2.iter())
.map(|(a, b)| (a.clone(), b.clone()))
.map(|(l, r)| (l.clone(), r.clone()))
.collect::<Vec<(Option<T>, Option<S>)>>();
states.sort();
states
Expand Down
44 changes: 37 additions & 7 deletions rust/datafusion/src/physical_plan/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,7 @@ impl ExecutionPlan for HashJoinExec {
2 => Ok(Arc::new(HashJoinExec::try_new(
children[0].clone(),
children[1].clone(),
&self
.on
.iter()
.map(|(x, y)| (x.as_str(), y.as_str()))
.collect::<Vec<_>>(),
&self.on,
&self.join_type,
)?)),
_ => Err(DataFusionError::Internal(
Expand Down Expand Up @@ -438,9 +434,13 @@ mod tests {
fn join(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
on: &JoinOn,
on: &[(&str, &str)],
) -> Result<HashJoinExec> {
HashJoinExec::try_new(left, right, on, &JoinType::Inner)
let on: Vec<_> = on
.iter()
.map(|(l, r)| (l.to_string(), r.to_string()))
.collect();
HashJoinExec::try_new(left, right, &on, &JoinType::Inner)
}

/// Asserts that the rows are the same, taking into account that their order
Expand Down Expand Up @@ -486,6 +486,36 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn join_one_no_shared_column_names() -> Result<()> {
    // Left input: the key column `b1` holds the value 5 twice.
    let left = build_table(
        ("a1", &vec![1, 2, 3]),
        ("b1", &vec![4, 5, 5]),
        ("c1", &vec![7, 8, 9]),
    );
    // Right input: no column name overlaps with the left input.
    let right = build_table(
        ("a2", &vec![10, 20, 30]),
        ("b2", &vec![4, 5, 6]),
        ("c2", &vec![70, 80, 90]),
    );

    let join_exec = join(left, right, &[("b1", "b2")])?;

    // All columns from both sides are present in the output schema.
    let column_names = columns(&join_exec.schema());
    assert_eq!(column_names, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);

    let batches = common::collect(join_exec.execute(0).await?).await?;

    let actual = format_batch(&batches[0]);
    let expected = vec!["2,5,8,20,5,80", "3,5,9,20,5,80", "1,4,7,10,4,70"];
    assert_same_rows(&actual, &expected);

    Ok(())
}

#[tokio::test]
async fn join_two() -> Result<()> {
let left = build_table(
Expand Down
Loading