Skip to content
145 changes: 145 additions & 0 deletions rust/datafusion/src/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Functionality used both on logical and physical plans

use crate::error::{ExecutionError, Result};
use arrow::datatypes::{Field, Schema};
use std::collections::HashSet;

/// All valid types of joins.
#[derive(Clone, Debug)]
pub enum JoinHow {
/// Inner join
Inner,
}

/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join.
/// They are valid whenever their columns' intersection equals the set `on`
pub fn check_join_is_valid(
left: &Schema,
right: &Schema,
on: &HashSet<String>,
) -> Result<()> {
let left: HashSet<String> = left.fields().iter().map(|f| f.name().clone()).collect();
let right: HashSet<String> =
right.fields().iter().map(|f| f.name().clone()).collect();

check_join_set_is_valid(&left, &right, &on)?;
Ok(())
}

/// Checks whether the sets left, right and on compose a valid join.
/// They are valid whenever their intersection equals the set `on`
fn check_join_set_is_valid(
left: &HashSet<String>,
right: &HashSet<String>,
on: &HashSet<String>,
) -> Result<()> {
if on.len() == 0 {
return Err(ExecutionError::General(
"The 'on' clause of a join cannot be empty".to_string(),
));
}

let on_columns = on.iter().map(|s| s).collect::<HashSet<_>>();
let common_columns = left.intersection(&right).collect::<HashSet<_>>();
let missing = on_columns
.difference(&common_columns)
.collect::<HashSet<_>>();
if missing.len() > 0 {
return Err(ExecutionError::General(format!(
"The left or right side of the join does not have columns {:?} columns on \"on\": \nLeft: {:?}\nRight: {:?}\nOn: {:?}",
missing,
left,
right,
on,
).to_string()));
};
Ok(())
}

/// Creates a schema for a join operation.
/// The fields "on" from the left side are always first
pub fn build_join_schema(
left: &Schema,
right: &Schema,
on: &HashSet<String>,
how: &JoinHow,
) -> Result<Schema> {
let fields: Vec<Field> = match how {
JoinHow::Inner => {
// inner: all fields are there

let on_fields = left.fields().iter().filter(|f| on.contains(f.name()));

let left_fields = left.fields().iter().filter(|f| !on.contains(f.name()));

let right_fields = right.fields().iter().filter(|f| !on.contains(f.name()));

// "on" are first by construction, then left, then right
on_fields
.chain(left_fields)
.chain(right_fields)
.map(|f| f.clone())
.collect()
}
};
Ok(Schema::new(fields))
}

#[cfg(test)]
mod tests {

use super::*;

fn check(left: &[&str], right: &[&str], on: &[&str]) -> Result<()> {
let left = left.iter().map(|x| x.to_string()).collect::<HashSet<_>>();
let right = right.iter().map(|x| x.to_string()).collect::<HashSet<_>>();
let on = on.iter().map(|x| x.to_string()).collect::<HashSet<_>>();

check_join_set_is_valid(&left, &right, &on)
}

#[test]
fn check_valid() -> Result<()> {
let left = vec!["a", "b1"];
let right = vec!["a", "b2"];
let on = vec!["a"];

check(&left, &right, &on)?;
Ok(())
}

#[test]
fn check_not_in_right() {
let left = vec!["a", "b"];
let right = vec!["b"];
let on = vec!["a"];

assert!(check(&left, &right, &on).is_err());
}

#[test]
fn check_not_in_left() {
let left = vec!["b"];
let right = vec!["a"];
let on = vec!["a"];

assert!(check(&left, &right, &on).is_err());
}
}
78 changes: 58 additions & 20 deletions rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use arrow::csv;
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;

use super::physical_plan::hash_join::HashJoinExec;
use crate::datasource::csv::CsvFile;
use crate::datasource::parquet::ParquetTable;
use crate::datasource::TableProvider;
Expand Down Expand Up @@ -439,6 +440,17 @@ impl ExecutionContext {
merge,
)?))
}
LogicalPlan::Join {
left,
right,
on,
how,
..
} => {
let left = self.create_physical_plan(left, batch_size)?;
let right = self.create_physical_plan(right, batch_size)?;
Ok(Arc::new(HashJoinExec::try_new(left, right, on, how)?))
}
LogicalPlan::Selection { input, expr, .. } => {
let input = self.create_physical_plan(input, batch_size)?;
let input_schema = input.as_ref().schema().clone();
Expand Down Expand Up @@ -689,7 +701,7 @@ mod tests {
use crate::datasource::MemTable;
use crate::execution::physical_plan::udf::ScalarUdf;
use crate::logicalplan::{aggregate_expr, col, scalar_function};
use crate::test;
use crate::{common::JoinHow, test};
use arrow::array::{ArrayRef, Int32Array};
use arrow::compute::add;
use std::fs::File;
Expand Down Expand Up @@ -804,28 +816,31 @@ mod tests {
Ok(())
}

#[test]
fn projection_on_memory_scan() -> Result<()> {
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Int32, false),
]);
let plan = LogicalPlanBuilder::from(&LogicalPlan::InMemoryScan {
data: vec![vec![RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(Int32Array::from(vec![1, 10, 10, 100])),
Arc::new(Int32Array::from(vec![2, 12, 12, 120])),
Arc::new(Int32Array::from(vec![3, 12, 12, 120])),
],
)?]],
fn build_table(
a: (&str, &Vec<i32>),
b: (&str, &Vec<i32>),
c: (&str, &Vec<i32>),
) -> Result<LogicalPlan> {
let (batch, schema) = build_table_i32(a, b, c)?;

Ok(LogicalPlan::InMemoryScan {
data: vec![vec![batch]],
schema: Box::new(schema.clone()),
projection: None,
projected_schema: Box::new(schema.clone()),
projected_schema: Box::new(schema),
})
.project(vec![col("b")])?
.build()?;
}

#[test]
fn projection_on_memory_scan() -> Result<()> {
let plan = build_table(
("a", &vec![1, 10, 10, 100]),
("b", &vec![2, 12, 12, 120]),
("c", &vec![3, 12, 12, 120]),
)?;
let plan = LogicalPlanBuilder::from(&plan)
.project(vec![col("b")])?
.build()?;
assert_fields_eq(&plan, vec!["b"]);

let ctx = ExecutionContext::new();
Expand Down Expand Up @@ -862,6 +877,29 @@ mod tests {
Ok(())
}

#[test]
fn join() -> Result<()> {
let left =
build_table(("a", &vec![1, 1]), ("b", &vec![2, 3]), ("c", &vec![3, 4]))?;
let right = build_table(
("a", &vec![1, 1]),
("b2", &vec![12, 13]),
("c2", &vec![13, 14]),
)?;
let plan = LogicalPlanBuilder::from(&left)
.join(&right, &vec!["a".to_string()], &JoinHow::Inner)?
.build()?;

let ctx = ExecutionContext::new();
let physical_plan = ctx.create_physical_plan(&plan, 1024)?;

let batches = ctx.collect(physical_plan.as_ref())?;
let expected: Vec<&str> =
vec!["1,2,3,12,13", "1,2,3,13,14", "1,3,4,12,13", "1,3,4,13,14"];
assert_eq!(test::format_batch(&batches[0]), expected);
Ok(())
}

#[test]
fn sort() -> Result<()> {
let results = execute("SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC", 4)?;
Expand Down
Loading