Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/binder/statement/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use sqlparser::ast::SetExpr::Select;
use sqlparser::ast::{Query, SelectItem};
use sqlparser::ast::{Join, JoinOperator, Query, SelectItem, TableWithJoins};

use super::expression::BoundExpr;
use super::table::BoundTableRef;
Expand Down Expand Up @@ -38,6 +38,22 @@ impl Binder {
// currently, only support select one table
let from_table = if select.from.is_empty() {
None
} else if select.from.len() > 1 {
// merge select from multiple tables into one cross join
// TODO: add more checks
let first_talbe = select.from[0].clone();
let joins = select.from[1..]
.iter()
.map(|a| Join {
relation: a.relation.clone(),
join_operator: JoinOperator::CrossJoin,
})
.collect();
let table_with_joins = TableWithJoins {
relation: first_talbe.relation,
joins,
};
Some(self.bind_table_with_joins(&table_with_joins)?)
} else {
Some(self.bind_table_with_joins(&select.from[0])?)
};
Expand Down
57 changes: 57 additions & 0 deletions src/executor/join/cross_join.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use arrow::datatypes::{Schema, SchemaRef};

use crate::catalog::ColumnCatalog;
use crate::executor::*;
use crate::types::{build_scalar_value_array, ScalarValue};

/// Executor that pairs every row of the left child with every batch of the right child.
///
/// `execute` first collects all batches of `left_child` and then reads
/// `right_child` one batch at a time.
pub struct CrossJoinExecutor {
/// Input whose batches are all collected before the right input is read.
pub left_child: BoxedExecutor,
/// Input that is consumed one batch at a time.
pub right_child: BoxedExecutor,
/// The schema once the join is applied
pub join_output_schema: Vec<ColumnCatalog>,
}

impl CrossJoinExecutor {
    /// Builds the Arrow schema of the join output from `join_output_schema`.
    fn join_output_arrow_schema(&self) -> SchemaRef {
        let mut fields = Vec::with_capacity(self.join_output_schema.len());
        for column in &self.join_output_schema {
            fields.push(column.to_arrow_field());
        }
        SchemaRef::new(Schema::new(fields))
    }

    /// Streams the cross join of the two children.
    ///
    /// All left batches are collected and concatenated into one batch. Then, for
    /// each right batch and each left row, one output batch is yielded: the left
    /// row's values repeated to the right batch's row count, followed by the right
    /// batch's columns. If the left child yields no batches, the stream ends
    /// without reading the right child. Errors from either child, from the
    /// concatenation, or from building an output batch are propagated with `?`.
    #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)]
    pub async fn execute(self) {
        let output_schema = self.join_output_arrow_schema();

        // Buffer the whole left input; the right input is streamed afterwards.
        let left_batches = self.left_child.try_collect::<Vec<_>>().await?;
        if left_batches.is_empty() {
            return Ok(());
        }
        let left_rows = RecordBatch::concat(&left_batches[0].schema(), &left_batches)?;

        #[for_await]
        for right_batch in self.right_child {
            let right_rows = right_batch?;
            let repeat = right_rows.num_rows();

            for left_idx in 0..left_rows.num_rows() {
                // Expand the current left row so each of its columns is as long as the right batch.
                let mut joined_columns = left_rows
                    .columns()
                    .iter()
                    .map(|left_col| {
                        let value = ScalarValue::try_from_array(left_col, left_idx);
                        build_scalar_value_array(&value, repeat)
                    })
                    .collect::<Vec<_>>();
                // Left columns come first, right columns are appended after them.
                joined_columns.extend_from_slice(right_rows.columns());
                yield RecordBatch::try_new(output_schema.clone(), joined_columns)?;
            }
        }
    }
}
1 change: 1 addition & 0 deletions src/executor/join/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod cross_join;
pub mod hash_join;
18 changes: 15 additions & 3 deletions src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ use futures_async_stream::try_stream;
use self::aggregate::hash_agg::HashAggExecutor;
use self::aggregate::simple_agg::SimpleAggExecutor;
use self::filter::FilterExecutor;
use self::join::cross_join::CrossJoinExecutor;
use self::join::hash_join::HashJoinExecutor;
use self::limit::LimitExecutor;
use self::order::OrderExecutor;
use self::project::ProjectExecutor;
use self::table_scan::TableScanExecutor;
use crate::optimizer::{
PhysicalFilter, PhysicalHashAgg, PhysicalHashJoin, PhysicalLimit, PhysicalOrder,
PhysicalProject, PhysicalSimpleAgg, PhysicalTableScan, PlanNode, PlanRef, PlanTreeNode,
PlanVisitor,
PhysicalCrossJoin, PhysicalFilter, PhysicalHashAgg, PhysicalHashJoin, PhysicalLimit,
PhysicalOrder, PhysicalProject, PhysicalSimpleAgg, PhysicalTableScan, PlanNode, PlanRef,
PlanTreeNode, PlanVisitor,
};
use crate::storage::{StorageError, StorageImpl};

Expand Down Expand Up @@ -112,6 +113,17 @@ impl PlanVisitor<BoxedExecutor> for ExecutorBuilder {
)
}

/// Builds the executor for a physical cross join by visiting both children
/// (left first, then right) and wiring them into a `CrossJoinExecutor`.
///
/// Panics if visiting either child returns `None`.
fn visit_physical_cross_join(&mut self, plan: &PhysicalCrossJoin) -> Option<BoxedExecutor> {
    let left_child = self.visit(plan.left()).unwrap();
    let right_child = self.visit(plan.right()).unwrap();
    let join_output_schema = plan.output_columns();
    let executor = CrossJoinExecutor {
        left_child,
        right_child,
        join_output_schema,
    };
    Some(executor.execute())
}

fn visit_physical_project(&mut self, plan: &PhysicalProject) -> Option<BoxedExecutor> {
Some(
ProjectExecutor {
Expand Down
12 changes: 10 additions & 2 deletions src/optimizer/input_ref_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use std::sync::Arc;
use super::expr_rewriter::ExprRewriter;
use super::{
LogicalAgg, LogicalFilter, LogicalJoin, LogicalLimit, LogicalOrder, LogicalProject,
LogicalTableScan, PhysicalFilter, PhysicalHashAgg, PhysicalHashJoin, PhysicalLimit,
PhysicalOrder, PhysicalProject, PhysicalSimpleAgg, PhysicalTableScan, PlanRef, PlanRewriter,
LogicalTableScan, PhysicalCrossJoin, PhysicalFilter, PhysicalHashAgg, PhysicalHashJoin,
PhysicalLimit, PhysicalOrder, PhysicalProject, PhysicalSimpleAgg, PhysicalTableScan, PlanRef,
PlanRewriter,
};
use crate::binder::{BoundColumnRef, BoundExpr, BoundInputRef, JoinCondition};

Expand Down Expand Up @@ -168,6 +169,13 @@ impl PlanRewriter for InputRefRewriter {
))
}

/// Rewrites the wrapped logical join and wraps the result in a new
/// `PhysicalCrossJoin`.
///
/// Panics if the rewritten plan is not a logical join.
fn rewrite_physical_cross_join(&mut self, plan: &super::PhysicalCrossJoin) -> PlanRef {
    let rewritten = self.rewrite_logical_join(plan.logical());
    let logical_join = rewritten.as_logical_join().unwrap().clone();
    Arc::new(PhysicalCrossJoin::new(logical_join))
}

fn rewrite_logical_project(&mut self, plan: &LogicalProject) -> PlanRef {
let new_child = self.rewrite(plan.input());

Expand Down
19 changes: 10 additions & 9 deletions src/optimizer/physical_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use std::sync::Arc;

use super::plan_rewriter::PlanRewriter;
use super::{
LogicalAgg, LogicalFilter, LogicalJoin, LogicalProject, LogicalTableScan, PhysicalHashAgg,
PhysicalHashJoin, PhysicalLimit, PhysicalOrder, PhysicalSimpleAgg, PhysicalTableScan, PlanRef,
PlanTreeNode,
LogicalAgg, LogicalFilter, LogicalJoin, LogicalProject, LogicalTableScan, PhysicalCrossJoin,
PhysicalHashAgg, PhysicalHashJoin, PhysicalLimit, PhysicalOrder, PhysicalSimpleAgg,
PhysicalTableScan, PlanRef, PlanTreeNode,
};
use crate::binder::JoinType;
use crate::optimizer::{PhysicalFilter, PhysicalProject};

#[derive(Default)]
Expand All @@ -21,12 +22,12 @@ impl PlanRewriter for PhysicalRewriter {
let right = self.rewrite(plan.right());
let join_type = plan.join_type();
let join_condition = plan.join_condition();
Arc::new(PhysicalHashJoin::new(LogicalJoin::new(
left,
right,
join_type,
join_condition,
)))
let logical = LogicalJoin::new(left, right, join_type.clone(), join_condition);
if join_type == JoinType::Cross {
Arc::new(PhysicalCrossJoin::new(logical))
} else {
Arc::new(PhysicalHashJoin::new(logical))
}
}

fn rewrite_logical_project(&mut self, plan: &LogicalProject) -> PlanRef {
Expand Down
2 changes: 1 addition & 1 deletion src/optimizer/plan_node/logical_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl LogicalJoin {
JoinType::Left => (false, true),
JoinType::Right => (true, false),
JoinType::Full => (true, true),
JoinType::Cross => unreachable!(""),
JoinType::Cross => (true, true),
};
let left_fields = self
.left
Expand Down
9 changes: 7 additions & 2 deletions src/optimizer/plan_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod logical_limit;
mod logical_order;
mod logical_project;
mod logical_table_scan;
mod physical_cross_join;
mod physical_filter;
mod physical_hash_agg;
mod physical_hash_join;
Expand All @@ -29,6 +30,7 @@ pub use logical_order::*;
pub use logical_project::*;
pub use logical_table_scan::*;
use paste::paste;
pub use physical_cross_join::*;
pub use physical_filter::*;
pub use physical_hash_agg::*;
pub use physical_hash_join::*;
Expand Down Expand Up @@ -82,7 +84,8 @@ impl dyn PlanNode {
| PlanNodeType::PhysicalHashAgg
| PlanNodeType::PhysicalLimit
| PlanNodeType::PhysicalOrder
| PlanNodeType::PhysicalHashJoin => false,
| PlanNodeType::PhysicalHashJoin
| PlanNodeType::PhysicalCrossJoin => false,
}
}

Expand Down Expand Up @@ -125,6 +128,7 @@ impl dyn PlanNode {
PlanNodeType::PhysicalLimit => false,
PlanNodeType::PhysicalOrder => false,
PlanNodeType::PhysicalHashJoin => false,
PlanNodeType::PhysicalCrossJoin => false,
}
}
}
Expand Down Expand Up @@ -155,7 +159,8 @@ macro_rules! for_all_plan_nodes {
PhysicalHashAgg,
PhysicalLimit,
PhysicalOrder,
PhysicalHashJoin
PhysicalHashJoin,
PhysicalCrossJoin
}
};
}
Expand Down
68 changes: 68 additions & 0 deletions src/optimizer/plan_node/physical_cross_join.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use core::fmt;
use std::sync::Arc;

use super::{LogicalJoin, PlanNode, PlanRef, PlanTreeNode};
use crate::binder::JoinType;
use crate::catalog::ColumnCatalog;

/// Physical plan node for a cross join; it wraps the `LogicalJoin` it was built from.
#[derive(Debug, Clone)]
pub struct PhysicalCrossJoin {
/// Logical join that supplies the children, the join type and the column lists.
logical: LogicalJoin,
}

impl PhysicalCrossJoin {
pub fn new(logical: LogicalJoin) -> Self {
Self { logical }
}

pub fn left(&self) -> PlanRef {
self.logical.left()
}

pub fn right(&self) -> PlanRef {
self.logical.right()
}

pub fn join_type(&self) -> JoinType {
self.logical.join_type()
}

pub fn logical(&self) -> &LogicalJoin {
&self.logical
}
}

impl PlanNode for PhysicalCrossJoin {
    /// Delegates to the wrapped logical join.
    fn referenced_columns(&self) -> Vec<ColumnCatalog> {
        let logical = self.logical();
        logical.referenced_columns()
    }

    /// Delegates to the wrapped logical join.
    fn output_columns(&self) -> Vec<ColumnCatalog> {
        let logical = self.logical();
        logical.output_columns()
    }
}

impl PlanTreeNode for PhysicalCrossJoin {
    /// Returns the two inputs in `[left, right]` order.
    fn children(&self) -> Vec<PlanRef> {
        let left = self.left();
        let right = self.right();
        vec![left, right]
    }

    /// Rebuilds this node on top of `children` by letting the wrapped logical
    /// join replace its children, then wrapping the resulting logical join again.
    ///
    /// Panics if the rebuilt plan is not a logical join.
    fn clone_with_children(&self, children: Vec<PlanRef>) -> PlanRef {
        let rebuilt = self.logical().clone_with_children(children);
        let logical_join = rebuilt.as_logical_join().unwrap().clone();
        Arc::new(Self::new(logical_join))
    }
}

impl fmt::Display for PhysicalCrossJoin {
    /// Writes one line naming the node and the `Debug` form of its join type.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let join_type = self.join_type();
        writeln!(f, "PhysicalCrossJoin: type {:?}", join_type)
    }
}

impl PartialEq for PhysicalCrossJoin {
    /// Two nodes are equal when their join types, left children and right
    /// children are all equal; comparison stops at the first mismatch.
    fn eq(&self, other: &Self) -> bool {
        if self.join_type() != other.join_type() {
            return false;
        }
        if self.left() != other.left() {
            return false;
        }
        self.right() == other.right()
    }
}
3 changes: 3 additions & 0 deletions src/optimizer/plan_node/plan_node_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ impl PartialEq for dyn PlanNode {
PlanNodeType::PhysicalHashJoin => {
self.as_physical_hash_join() == other.as_physical_hash_join()
}
PlanNodeType::PhysicalCrossJoin => {
self.as_physical_cross_join() == other.as_physical_cross_join()
}
}
}
}
Expand Down
17 changes: 17 additions & 0 deletions tests/slt/join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,20 @@ full join state on state.state_code=employee.state;
3 John Engineering Colorado State CO
NULL NULL NULL New Jersey NJ
4 Von NULL NULL NULL


query IIIIII
select t1.*, t2.* from t1, t2 where t1.a = 0;
----
0 4 7 10 2 7
0 4 7 20 2 5
0 4 7 30 3 6
0 4 7 40 4 6

query IIIIII
select t1.*, t2.* from t1 cross join t2 where t1.a = 0;
----
0 4 7 10 2 7
0 4 7 20 2 5
0 4 7 30 3 6
0 4 7 40 4 6