diff --git a/src/binder/statement/mod.rs b/src/binder/statement/mod.rs index f655d4f..eca1a9a 100644 --- a/src/binder/statement/mod.rs +++ b/src/binder/statement/mod.rs @@ -1,5 +1,5 @@ use sqlparser::ast::SetExpr::Select; -use sqlparser::ast::{Query, SelectItem}; +use sqlparser::ast::{Join, JoinOperator, Query, SelectItem, TableWithJoins}; use super::expression::BoundExpr; use super::table::BoundTableRef; @@ -38,6 +38,22 @@ impl Binder { // currently, only support select one table let from_table = if select.from.is_empty() { None + } else if select.from.len() > 1 { + // merge select from multiple tables into one cross join + // TODO: add more checks + let first_talbe = select.from[0].clone(); + let joins = select.from[1..] + .iter() + .map(|a| Join { + relation: a.relation.clone(), + join_operator: JoinOperator::CrossJoin, + }) + .collect(); + let table_with_joins = TableWithJoins { + relation: first_talbe.relation, + joins, + }; + Some(self.bind_table_with_joins(&table_with_joins)?) } else { Some(self.bind_table_with_joins(&select.from[0])?) }; diff --git a/src/executor/join/cross_join.rs b/src/executor/join/cross_join.rs new file mode 100644 index 0000000..6180ad4 --- /dev/null +++ b/src/executor/join/cross_join.rs @@ -0,0 +1,57 @@ +use arrow::datatypes::{Schema, SchemaRef}; + +use crate::catalog::ColumnCatalog; +use crate::executor::*; +use crate::types::{build_scalar_value_array, ScalarValue}; + +pub struct CrossJoinExecutor { + pub left_child: BoxedExecutor, + pub right_child: BoxedExecutor, + /// The schema once the join is applied + pub join_output_schema: Vec, +} + +impl CrossJoinExecutor { + fn join_output_arrow_schema(&self) -> SchemaRef { + let fields = self + .join_output_schema + .iter() + .map(|c| c.to_arrow_field()) + .collect::>(); + SchemaRef::new(Schema::new(fields)) + } + + #[try_stream(boxed, ok = RecordBatch, error = ExecutorError)] + pub async fn execute(self) { + let schema = self.join_output_arrow_schema(); + + // consume all left stream data and then iterate right stream chunk to build result + let left_batches = self.left_child.try_collect::>().await?; + + if left_batches.is_empty() { + return Ok(()); + } + + let left_single_batch = RecordBatch::concat(&left_batches[0].schema(), &left_batches)?; + + #[for_await] + for right_batch in self.right_child { + let right_data = right_batch?; + + // repeat left value n times to match right batch size + for row_idx in 0..left_single_batch.num_rows() { + let new_left_data = left_single_batch + .columns() + .iter() + .map(|col_arr| { + let scalar = ScalarValue::try_from_array(col_arr, row_idx); + build_scalar_value_array(&scalar, right_data.num_rows()) + }) + .collect::>(); + // concat left and right data + let data = vec![new_left_data, right_data.columns().to_vec()].concat(); + yield RecordBatch::try_new(schema.clone(), data)? + } + } + } +} diff --git a/src/executor/join/mod.rs b/src/executor/join/mod.rs index 788e4e7..550a72b 100644 --- a/src/executor/join/mod.rs +++ b/src/executor/join/mod.rs @@ -1 +1,2 @@ +pub mod cross_join; pub mod hash_join; diff --git a/src/executor/mod.rs b/src/executor/mod.rs index 1e54df8..e01c762 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -18,15 +18,16 @@ use futures_async_stream::try_stream; use self::aggregate::hash_agg::HashAggExecutor; use self::aggregate::simple_agg::SimpleAggExecutor; use self::filter::FilterExecutor; +use self::join::cross_join::CrossJoinExecutor; use self::join::hash_join::HashJoinExecutor; use self::limit::LimitExecutor; use self::order::OrderExecutor; use self::project::ProjectExecutor; use self::table_scan::TableScanExecutor; use crate::optimizer::{ - PhysicalFilter, PhysicalHashAgg, PhysicalHashJoin, PhysicalLimit, PhysicalOrder, - PhysicalProject, PhysicalSimpleAgg, PhysicalTableScan, PlanNode, PlanRef, PlanTreeNode, - PlanVisitor, + PhysicalCrossJoin, PhysicalFilter, PhysicalHashAgg, PhysicalHashJoin, PhysicalLimit, + PhysicalOrder, PhysicalProject, PhysicalSimpleAgg, PhysicalTableScan, PlanNode, PlanRef, + PlanTreeNode, PlanVisitor, }; use crate::storage::{StorageError, StorageImpl}; @@ -112,6 +113,17 @@ impl PlanVisitor for ExecutorBuilder { ) } + fn visit_physical_cross_join(&mut self, plan: &PhysicalCrossJoin) -> Option { + Some( + CrossJoinExecutor { + left_child: self.visit(plan.left()).unwrap(), + right_child: self.visit(plan.right()).unwrap(), + join_output_schema: plan.output_columns(), + } + .execute(), + ) + } + fn visit_physical_project(&mut self, plan: &PhysicalProject) -> Option { Some( ProjectExecutor { diff --git a/src/optimizer/input_ref_rewriter.rs b/src/optimizer/input_ref_rewriter.rs index 0cd5940..714a400 100644 --- a/src/optimizer/input_ref_rewriter.rs +++ b/src/optimizer/input_ref_rewriter.rs @@ -3,8 +3,9 @@ use std::sync::Arc; use super::expr_rewriter::ExprRewriter; use super::{ LogicalAgg, LogicalFilter, LogicalJoin, LogicalLimit, LogicalOrder, LogicalProject, - LogicalTableScan, PhysicalFilter, PhysicalHashAgg, PhysicalHashJoin, PhysicalLimit, - PhysicalOrder, PhysicalProject, PhysicalSimpleAgg, PhysicalTableScan, PlanRef, PlanRewriter, + LogicalTableScan, PhysicalCrossJoin, PhysicalFilter, PhysicalHashAgg, PhysicalHashJoin, + PhysicalLimit, PhysicalOrder, PhysicalProject, PhysicalSimpleAgg, PhysicalTableScan, PlanRef, + PlanRewriter, }; use crate::binder::{BoundColumnRef, BoundExpr, BoundInputRef, JoinCondition}; @@ -168,6 +169,13 @@ impl PlanRewriter for InputRefRewriter { )) } + fn rewrite_physical_cross_join(&mut self, plan: &super::PhysicalCrossJoin) -> PlanRef { + let logical = self.rewrite_logical_join(plan.logical()); + Arc::new(PhysicalCrossJoin::new( + logical.as_logical_join().unwrap().clone(), + )) + } + fn rewrite_logical_project(&mut self, plan: &LogicalProject) -> PlanRef { let new_child = self.rewrite(plan.input()); diff --git a/src/optimizer/physical_rewriter.rs b/src/optimizer/physical_rewriter.rs index 9ef8263..5259baa 100644 --- a/src/optimizer/physical_rewriter.rs +++ b/src/optimizer/physical_rewriter.rs @@ -2,10 +2,11 @@ use std::sync::Arc; use super::plan_rewriter::PlanRewriter; use super::{ - LogicalAgg, LogicalFilter, LogicalJoin, LogicalProject, LogicalTableScan, PhysicalHashAgg, - PhysicalHashJoin, PhysicalLimit, PhysicalOrder, PhysicalSimpleAgg, PhysicalTableScan, PlanRef, - PlanTreeNode, + LogicalAgg, LogicalFilter, LogicalJoin, LogicalProject, LogicalTableScan, PhysicalCrossJoin, + PhysicalHashAgg, PhysicalHashJoin, PhysicalLimit, PhysicalOrder, PhysicalSimpleAgg, + PhysicalTableScan, PlanRef, PlanTreeNode, }; +use crate::binder::JoinType; use crate::optimizer::{PhysicalFilter, PhysicalProject}; #[derive(Default)] @@ -21,12 +22,12 @@ impl PlanRewriter for PhysicalRewriter { let right = self.rewrite(plan.right()); let join_type = plan.join_type(); let join_condition = plan.join_condition(); - Arc::new(PhysicalHashJoin::new(LogicalJoin::new( - left, - right, - join_type, - join_condition, - ))) + let logical = LogicalJoin::new(left, right, join_type.clone(), join_condition); + if join_type == JoinType::Cross { + Arc::new(PhysicalCrossJoin::new(logical)) + } else { + Arc::new(PhysicalHashJoin::new(logical)) + } } fn rewrite_logical_project(&mut self, plan: &LogicalProject) -> PlanRef { diff --git a/src/optimizer/plan_node/logical_join.rs b/src/optimizer/plan_node/logical_join.rs index 02abfd9..45ff931 100644 --- a/src/optimizer/plan_node/logical_join.rs +++ b/src/optimizer/plan_node/logical_join.rs @@ -85,7 +85,7 @@ impl LogicalJoin { JoinType::Left => (false, true), JoinType::Right => (true, false), JoinType::Full => (true, true), - JoinType::Cross => unreachable!(""), + JoinType::Cross => (true, true), }; let left_fields = self .left diff --git a/src/optimizer/plan_node/mod.rs b/src/optimizer/plan_node/mod.rs index 89128c6..877365f 100644 --- a/src/optimizer/plan_node/mod.rs +++ b/src/optimizer/plan_node/mod.rs @@ -6,6 +6,7 @@ mod logical_limit; mod logical_order; mod logical_project; mod logical_table_scan; +mod physical_cross_join; mod physical_filter; mod physical_hash_agg; mod physical_hash_join; @@ -29,6 +30,7 @@ pub use logical_order::*; pub use logical_project::*; pub use logical_table_scan::*; use paste::paste; +pub use physical_cross_join::*; pub use physical_filter::*; pub use physical_hash_agg::*; pub use physical_hash_join::*; @@ -82,7 +84,8 @@ impl dyn PlanNode { | PlanNodeType::PhysicalHashAgg | PlanNodeType::PhysicalLimit | PlanNodeType::PhysicalOrder - | PlanNodeType::PhysicalHashJoin => false, + | PlanNodeType::PhysicalHashJoin + | PlanNodeType::PhysicalCrossJoin => false, } } @@ -125,6 +128,7 @@ impl dyn PlanNode { PlanNodeType::PhysicalLimit => false, PlanNodeType::PhysicalOrder => false, PlanNodeType::PhysicalHashJoin => false, + PlanNodeType::PhysicalCrossJoin => false, } } } @@ -155,7 +159,8 @@ macro_rules! for_all_plan_nodes { PhysicalHashAgg, PhysicalLimit, PhysicalOrder, - PhysicalHashJoin + PhysicalHashJoin, + PhysicalCrossJoin } }; } diff --git a/src/optimizer/plan_node/physical_cross_join.rs b/src/optimizer/plan_node/physical_cross_join.rs new file mode 100644 index 0000000..e02479f --- /dev/null +++ b/src/optimizer/plan_node/physical_cross_join.rs @@ -0,0 +1,68 @@ +use core::fmt; +use std::sync::Arc; + +use super::{LogicalJoin, PlanNode, PlanRef, PlanTreeNode}; +use crate::binder::JoinType; +use crate::catalog::ColumnCatalog; + +#[derive(Debug, Clone)] +pub struct PhysicalCrossJoin { + logical: LogicalJoin, +} + +impl PhysicalCrossJoin { + pub fn new(logical: LogicalJoin) -> Self { + Self { logical } + } + + pub fn left(&self) -> PlanRef { + self.logical.left() + } + + pub fn right(&self) -> PlanRef { + self.logical.right() + } + + pub fn join_type(&self) -> JoinType { + self.logical.join_type() + } + + pub fn logical(&self) -> &LogicalJoin { + &self.logical + } +} + +impl PlanNode for PhysicalCrossJoin { + fn referenced_columns(&self) -> Vec { + self.logical.referenced_columns() + } + + fn output_columns(&self) -> Vec { + self.logical.output_columns() + } +} + +impl PlanTreeNode for PhysicalCrossJoin { + fn children(&self) -> Vec { + vec![self.left(), self.right()] + } + + fn clone_with_children(&self, children: Vec) -> PlanRef { + let p = self.logical().clone_with_children(children); + Arc::new(Self::new(p.as_logical_join().unwrap().clone())) + } +} + +impl fmt::Display for PhysicalCrossJoin { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "PhysicalCrossJoin: type {:?}", self.join_type(),) + } +} + +impl PartialEq for PhysicalCrossJoin { + fn eq(&self, other: &Self) -> bool { + self.join_type() == other.join_type() + && self.left() == other.left() + && self.right() == other.right() + } +} diff --git a/src/optimizer/plan_node/plan_node_traits.rs b/src/optimizer/plan_node/plan_node_traits.rs index b7ddbdd..94d54c8 100644 --- a/src/optimizer/plan_node/plan_node_traits.rs +++ b/src/optimizer/plan_node/plan_node_traits.rs @@ -78,6 +78,9 @@ impl PartialEq for dyn PlanNode { PlanNodeType::PhysicalHashJoin => { self.as_physical_hash_join() == other.as_physical_hash_join() } + PlanNodeType::PhysicalCrossJoin => { + self.as_physical_cross_join() == other.as_physical_cross_join() + } } } } diff --git a/tests/slt/join.slt b/tests/slt/join.slt index 4be7e38..64da005 100644 --- a/tests/slt/join.slt +++ b/tests/slt/join.slt @@ -74,3 +74,20 @@ full join state on state.state_code=employee.state; 3 John Engineering Colorado State CO NULL NULL NULL New Jersey NJ 4 Von NULL NULL NULL + + +query IIIIII +select t1.*, t2.* from t1, t2 where t1.a = 0; +---- +0 4 7 10 2 7 +0 4 7 20 2 5 +0 4 7 30 3 6 +0 4 7 40 4 6 + +query IIIIII +select t1.*, t2.* from t1 cross join t2 where t1.a = 0; +---- +0 4 7 10 2 7 +0 4 7 20 2 5 +0 4 7 30 3 6 +0 4 7 40 4 6