From df9a2add895312d92a745c8546a75475986a7cbd Mon Sep 17 00:00:00 2001
From: zhangli20
Date: Fri, 8 Apr 2022 21:35:28 +0800
Subject: [PATCH 01/10] Implement Sort-Merge join (#141)

---
 datafusion/core/src/physical_plan/mod.rs      |    1 +
 .../core/src/physical_plan/sort_merge_join.rs | 1601 +++++++++++++++++
 2 files changed, 1602 insertions(+)
 create mode 100644 datafusion/core/src/physical_plan/sort_merge_join.rs

diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs
index f3d9ced624189..6f47a91d91ad9 100644
--- a/datafusion/core/src/physical_plan/mod.rs
+++ b/datafusion/core/src/physical_plan/mod.rs
@@ -566,6 +566,7 @@ pub mod metrics;
 pub mod planner;
 pub mod projection;
 pub mod repartition;
+pub mod sort_merge_join;
 pub mod sorts;
 pub mod stream;
 pub mod type_coercion;
diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs b/datafusion/core/src/physical_plan/sort_merge_join.rs
new file mode 100644
index 0000000000000..6ea552a4933d5
--- /dev/null
+++ b/datafusion/core/src/physical_plan/sort_merge_join.rs
@@ -0,0 +1,1601 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::cmp::Ordering;
+use std::collections::VecDeque;
+use std::ops::Range;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use arrow::array::*;
+use arrow::compute::SortOptions;
+use arrow::datatypes::{DataType, Schema, SchemaRef};
+use arrow::error::{ArrowError, Result as ArrowResult};
+use arrow::record_batch::RecordBatch;
+use async_trait::async_trait;
+use futures::{Stream, StreamExt};
+
+use crate::error::DataFusionError;
+use crate::error::Result;
+use crate::execution::context::TaskContext;
+use crate::logical_plan::JoinType;
+use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
+use crate::physical_plan::expressions::Column;
+use crate::physical_plan::expressions::PhysicalSortExpr;
+use crate::physical_plan::join_utils::{build_join_schema, check_join_is_valid, JoinOn};
+use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
+use crate::physical_plan::{
+    metrics, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
+    Statistics,
+};
+
+#[derive(Debug)]
+pub struct SortMergeJoinExec {
+    left: Arc<dyn ExecutionPlan>,
+    right: Arc<dyn ExecutionPlan>,
+    on: JoinOn,
+    join_type: JoinType,
+    schema: SchemaRef,
+    metrics: ExecutionPlanMetricsSet,
+    sort_options: SortOptions,
+    null_equals_null: bool,
+}
+
+impl SortMergeJoinExec {
+    pub fn try_new(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: JoinOn,
+        join_type: JoinType,
+        sort_options: SortOptions,
+        null_equals_null: bool,
+    ) -> Result<Self> {
+        let left_schema = left.schema();
+        let right_schema = right.schema();
+
+        check_join_is_valid(&left_schema, &right_schema, &on)?;
+        let schema =
+            Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
+
+        Ok(Self {
+            left,
+            right,
+            on,
+            join_type,
+            schema,
+            metrics: ExecutionPlanMetricsSet::new(),
+            sort_options,
+            null_equals_null,
+        })
+    }
+}
+
+#[async_trait]
+impl ExecutionPlan for SortMergeJoinExec {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn output_partitioning(&self) -> Partitioning {
+        self.right.output_partitioning()
+    }
+
+    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+        self.right.output_ordering()
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![self.left.clone(), self.right.clone()]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        match &children[..] {
+            [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new(
+                left.clone(),
+                right.clone(),
+                self.on.clone(),
+                self.join_type,
+                self.sort_options,
+                self.null_equals_null,
+            )?)),
+            _ => Err(DataFusionError::Internal(
+                "SortMergeJoin wrong number of children".to_string(),
+            )),
+        }
+    }
+
+    async fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        let (streamed, buffered, on_streamed, on_buffered) = match self.join_type {
+            JoinType::Inner
+            | JoinType::Left
+            | JoinType::Full
+            | JoinType::Anti
+            | JoinType::Semi => (
+                self.left.clone(),
+                self.right.clone(),
+                self.on.iter().map(|on| on.0.clone()).collect(),
+                self.on.iter().map(|on| on.1.clone()).collect(),
+            ),
+            JoinType::Right => (
+                self.right.clone(),
+                self.left.clone(),
+                self.on.iter().map(|on| on.1.clone()).collect(),
+                self.on.iter().map(|on| on.0.clone()).collect(),
+            ),
+        };
+
+        // execute children plans
+        let streamed = CoalescePartitionsExec::new(streamed)
+            .execute(0, context.clone())
+            .await?;
+        let buffered = buffered.execute(partition, context.clone()).await?;
+
+        // create output buffer
+        let batch_size = context.session_config().batch_size;
+        let output_buffer = new_array_builders(self.schema(), batch_size)
+            .map_err(DataFusionError::ArrowError)?;
+
+        // create join stream
+        Ok(Box::pin(SMJStream::try_new(
+            self.schema.clone(),
+            self.sort_options,
+            self.null_equals_null,
+            streamed,
+            buffered,
+            on_streamed,
+            on_buffered,
+            self.join_type,
+            output_buffer,
+            batch_size,
+        )?))
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+
+    fn statistics(&self) -> Statistics {
+        todo!()
+    }
+}
+
+/// Metrics for SortMergeJoinExec
+#[warn(dead_code)]
+struct SortMergeJoinMetrics {
+    /// Total time for joining probe-side batches to the build-side batches
+    join_time: metrics::Time,
+    /// Number of batches consumed by this operator
+    input_batches: metrics::Count,
+    /// Number of rows consumed by this operator
+    input_rows: metrics::Count,
+    /// Number of batches produced by this operator
+    output_batches: metrics::Count,
+    /// Number of rows produced by this operator
+    output_rows: metrics::Count,
+}
+
+impl SortMergeJoinMetrics {
+    #[allow(dead_code)]
+    pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
+        let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition);
+        let input_batches =
+            MetricBuilder::new(metrics).counter("input_batches", partition);
+        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
+        let output_batches =
+            MetricBuilder::new(metrics).counter("output_batches", partition);
+        let output_rows = MetricBuilder::new(metrics).output_rows(partition);
+
+        Self {
+            join_time,
+            input_batches,
+            input_rows,
+            output_batches,
+            output_rows,
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq)]
+enum SMJState {
+    Init,
+    Polling,
+    JoinOutput,
+    Exhausted,
+}
+
+#[derive(Debug, PartialEq, Eq)]
+enum StreamedState {
+    Init,
+    Polling,
+    Ready,
+    Exhausted,
+}
+
+#[derive(Debug, PartialEq, Eq)]
+enum BufferedState {
+    Init,
+    PollingFirst,
+    PollingRest,
+    Ready,
+    Exhausted,
+}
+
+#[derive(Debug)]
+struct BufferedBatch {
+    pub batch: RecordBatch,
+    pub range: Range<usize>,
+    pub join_arrays: Vec<ArrayRef>,
+}
+impl BufferedBatch {
+    fn new(batch: RecordBatch, range: Range<usize>, on_column: &[Column]) -> Self {
+        let join_arrays = join_arrays(&batch, on_column);
+        BufferedBatch {
+            batch,
+            range,
+            join_arrays,
+        }
+    }
+}
+
+struct SMJStream {
+    pub state: SMJState,
+    pub schema: SchemaRef,
+    pub sort_options: SortOptions,
+    pub null_equals_null: bool,
+    pub streamed_schema: SchemaRef,
+    pub buffered_schema: SchemaRef,
+    pub num_streamed_columns: usize,
+    pub num_buffered_columns: usize,
+    pub streamed: SendableRecordBatchStream,
+    pub buffered: SendableRecordBatchStream,
+    pub streamed_batch: RecordBatch,
+    pub streamed_idx: usize,
+    pub buffered_data: BufferedData,
+    pub streamed_joined: bool,
+    pub buffered_joined: bool,
+    pub streamed_state: StreamedState,
+    pub buffered_state: BufferedState,
+    pub current_ordering: Ordering,
+    pub on_streamed: Vec<Column>,
+    pub on_buffered: Vec<Column>,
+    pub output_buffer: Vec<Box<dyn ArrayBuilder>>,
+    pub output_size: usize,
+    pub batch_size: usize,
+    pub join_type: JoinType,
+}
+
+impl RecordBatchStream for SMJStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+impl Stream for SMJStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        loop {
+            match &self.state {
+                SMJState::Init => {
+                    self.buffered_data.scanning_reset();
+                    let streamed_exhausted =
+                        self.streamed_state == StreamedState::Exhausted;
+                    let buffered_exhausted =
+                        self.buffered_state == BufferedState::Exhausted;
+                    self.state = if streamed_exhausted && buffered_exhausted {
+                        SMJState::Exhausted
+                    } else {
+                        match self.current_ordering {
+                            Ordering::Less | Ordering::Equal => {
+                                if !streamed_exhausted {
+                                    self.streamed_joined = false;
+                                    self.streamed_state = StreamedState::Init;
+                                }
+                            }
+                            Ordering::Greater => {
+                                if !buffered_exhausted {
+                                    self.buffered_joined = false;
+                                    self.buffered_state = BufferedState::Init;
+                                }
+                            }
+                        }
+                        SMJState::Polling
+                    };
+                }
+                SMJState::Polling => {
+                    if ![StreamedState::Exhausted, StreamedState::Ready]
+                        .contains(&self.streamed_state)
+                    {
+                        match self.poll_streamed_row(cx) {
+                            Poll::Ready(Some(Ok(()))) => {}
+                            Poll::Ready(Some(Err(e))) => {
+                                return Poll::Ready(Some(Err(e)))
+                            }
+                            Poll::Ready(None) => {}
+                            Poll::Pending => return Poll::Pending,
+                        }
+                    }
+
+                    if ![BufferedState::Exhausted, BufferedState::Ready]
+                        .contains(&self.buffered_state)
+                    {
+                        match self.poll_buffered_batches(cx) {
+                            Poll::Ready(Some(Ok(()))) => {}
+                            Poll::Ready(Some(Err(e))) => {
+                                return Poll::Ready(Some(Err(e)))
+                            }
+                            Poll::Ready(None) => {}
+                            Poll::Pending => return Poll::Pending,
+                        }
+                    }
+                    let streamed_exhausted =
+                        self.streamed_state == StreamedState::Exhausted;
+                    let buffered_exhausted =
+                        self.buffered_state == BufferedState::Exhausted;
+                    if streamed_exhausted && buffered_exhausted {
+                        self.state = SMJState::Exhausted;
+                        continue;
+                    }
+                    self.current_ordering = self.compare_streamed_buffered()?;
+                    self.state = SMJState::JoinOutput;
+                }
+                SMJState::JoinOutput => {
+                    self.join_partial()?;
+                    if self.output_size == self.batch_size {
+                        let record_batch = self.output_record_batch_and_reset()?;
+                        return Poll::Ready(Some(Ok(record_batch)));
+                    }
+                    if self.buffered_data.scanning_finished() {
+                        if self.current_ordering.is_le() {
+                            self.streamed_joined = true;
+                        }
+                        if self.current_ordering.is_ge() {
+                            self.buffered_joined = true;
+                        }
+                        self.state = SMJState::Init;
+                    }
+                }
+                SMJState::Exhausted => {
+                    if self.output_size > 0 {
+                        let record_batch = self.output_record_batch_and_reset()?;
+                        return Poll::Ready(Some(Ok(record_batch)));
+                    }
+                    return Poll::Ready(None);
+                }
+            }
+        }
+    }
+}
+
+impl SMJStream {
+    pub fn try_new(
+        schema: SchemaRef,
+        sort_options: SortOptions,
+        null_equals_null: bool,
+        streamed: SendableRecordBatchStream,
+        buffered: SendableRecordBatchStream,
+        on_streamed: Vec<Column>,
+        on_buffered: Vec<Column>,
+        join_type: JoinType,
+        output_buffer: Vec<Box<dyn ArrayBuilder>>,
+        batch_size: usize,
+    ) -> Result<Self> {
+        Ok(Self {
+            state: SMJState::Init,
+            sort_options,
+            null_equals_null,
+            schema: schema.clone(),
+            streamed_schema: streamed.schema(),
+            buffered_schema: buffered.schema(),
+            num_streamed_columns: streamed.schema().fields().len(),
+            num_buffered_columns: buffered.schema().fields().len(),
+            streamed,
+            buffered,
+            streamed_batch: RecordBatch::new_empty(schema),
+            streamed_idx: 0,
+            buffered_data: BufferedData::default(),
+            streamed_joined: false,
+            buffered_joined: false,
+            streamed_state: StreamedState::Init,
+            buffered_state: BufferedState::Init,
+            current_ordering: Ordering::Equal,
+            on_streamed,
+            on_buffered,
+            output_buffer,
+            output_size: 0,
+            batch_size,
+            join_type,
+        })
+    }
+
+    fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll<Option<ArrowResult<()>>> {
+        loop {
+            match &self.streamed_state {
+                StreamedState::Init => {
+                    if self.streamed_idx + 1 < self.streamed_batch.num_rows() {
+                        self.streamed_idx += 1;
+                        self.streamed_state = StreamedState::Ready;
+                        return Poll::Ready(Some(Ok(())));
+                    } else {
+                        self.streamed_state = StreamedState::Polling;
+                    }
+                    continue;
+                }
+                StreamedState::Polling => match self.streamed.poll_next_unpin(cx)? {
+                    Poll::Pending => {
+                        return Poll::Pending;
+                    }
+                    Poll::Ready(None) => {
+                        self.streamed_state = StreamedState::Exhausted;
+                    }
+                    Poll::Ready(Some(batch)) => {
+                        if batch.num_rows() > 0 {
+                            self.streamed_batch = batch;
+                            self.streamed_idx = 0;
+                            self.streamed_state = StreamedState::Ready;
+                        }
+                    }
+                },
+                StreamedState::Ready => {
+                    return Poll::Ready(Some(Ok(())));
+                }
+                StreamedState::Exhausted => {
+                    return Poll::Ready(None);
+                }
+            }
+        }
+    }
+
+    fn poll_buffered_batches(
+        &mut self,
+        cx: &mut Context,
+    ) -> Poll<Option<ArrowResult<()>>> {
+        loop {
+            match &self.buffered_state {
+                BufferedState::Init => {
+                    // pop previous buffered batches
+                    while !self.buffered_data.batches.is_empty() {
+                        let head_batch = self.buffered_data.head_batch();
+                        if head_batch.range.end == head_batch.batch.num_rows() {
+                            self.buffered_data.batches.pop_front();
+                        } else {
+                            break;
+                        }
+                    }
+                    if self.buffered_data.batches.is_empty() {
+                        self.buffered_state = BufferedState::PollingFirst;
+                    } else {
+                        let tail_batch = self.buffered_data.tail_batch_mut();
+                        tail_batch.range.start = tail_batch.range.end;
+                        tail_batch.range.end += 1;
+                        self.buffered_state = BufferedState::PollingRest;
+                    }
+                }
+                BufferedState::PollingFirst => match self.buffered.poll_next_unpin(cx)? {
+                    Poll::Pending => {
+                        return Poll::Pending;
+                    }
+                    Poll::Ready(None) => {
+                        self.buffered_state = BufferedState::Exhausted;
+                        return Poll::Ready(None);
+                    }
+                    Poll::Ready(Some(batch)) => {
+                        if batch.num_rows() > 0 {
+                            self.buffered_data.batches.push_back(BufferedBatch::new(
+                                batch,
+                                0..1,
+                                &self.on_buffered,
+                            ));
+                            self.buffered_state = BufferedState::PollingRest;
+                        }
+                    }
+                },
+                BufferedState::PollingRest => {
+                    if self.buffered_data.tail_batch().range.end
+                        < self.buffered_data.tail_batch().batch.num_rows()
+                    {
+                        while self.buffered_data.tail_batch().range.end
+                            < self.buffered_data.tail_batch().batch.num_rows()
+                        {
+                            // compare the join-key arrays only, not all columns
+                            if is_join_arrays_equal(
+                                &self.buffered_data.head_batch().join_arrays,
+                                self.buffered_data.head_batch().range.start,
+                                &self.buffered_data.tail_batch().join_arrays,
+                                self.buffered_data.tail_batch().range.end,
+                            )? {
+                                self.buffered_data.tail_batch_mut().range.end += 1;
+                            } else {
+                                self.buffered_state = BufferedState::Ready;
+                                return Poll::Ready(Some(Ok(())));
+                            }
+                        }
+                    } else {
+                        match self.buffered.poll_next_unpin(cx)? {
+                            Poll::Pending => {
+                                return Poll::Pending;
+                            }
+                            Poll::Ready(None) => {
+                                self.buffered_state = BufferedState::Ready;
+                            }
+                            Poll::Ready(Some(batch)) => {
+                                self.buffered_data.batches.push_back(BufferedBatch::new(
+                                    batch,
+                                    0..0,
+                                    &self.on_buffered,
+                                ));
+                            }
+                        }
+                    }
+                }
+                BufferedState::Ready => {
+                    return Poll::Ready(Some(Ok(())));
+                }
+                BufferedState::Exhausted => {
+                    return Poll::Ready(None);
+                }
+            }
+        }
+    }
+
+    fn compare_streamed_buffered(&self) -> ArrowResult<Ordering> {
+        if self.streamed_state == StreamedState::Exhausted {
+            return Ok(Ordering::Greater);
+        }
+        if !self.buffered_data.has_buffered_rows() {
+            return Ok(Ordering::Less);
+        }
+
+        return compare_join_arrays(
+            &join_arrays(&self.streamed_batch, &self.on_streamed),
+            self.streamed_idx,
+            &join_arrays(&self.buffered_data.head_batch().batch, &self.on_buffered),
+            self.buffered_data.head_batch().range.start,
+            self.sort_options,
+            self.null_equals_null,
+        );
+    }
+
+    fn join_partial(&mut self) -> ArrowResult<()> {
+        // decide streamed/buffered output columns by join type
+        let output_parts =
+            self.output_buffer
+                .split_at_mut(if self.join_type != JoinType::Right {
+                    self.num_streamed_columns
+                } else {
+                    self.num_buffered_columns
+                });
+        let (streamed_output, buffered_output) = if self.join_type != JoinType::Right {
+            (output_parts.0, output_parts.1)
+        } else {
+            (output_parts.1, output_parts.0)
+        };
+
+        match self.current_ordering {
+            Ordering::Less => {
+                let output_streamed_join = match self.join_type {
+                    JoinType::Inner | JoinType::Semi => false,
+                    JoinType::Left
+                    | JoinType::Right
+                    | JoinType::Full
+                    | JoinType::Anti => !self.streamed_joined,
+                };
+
+                // streamed joins null
+                if output_streamed_join {
+                    append_row_to_output(
+                        &self.streamed_batch,
+                        self.streamed_idx,
+                        streamed_output,
+                    )?;
+                    append_nulls_row_to_output(&self.buffered_schema, buffered_output)?;
+                    self.output_size += 1;
+                }
+                self.buffered_data.scanning_finish();
+            }
+            Ordering::Equal => {
+                let output_equal_join = match self.join_type {
+                    JoinType::Inner
+                    | JoinType::Left
+                    | JoinType::Right
+                    | JoinType::Full
+                    | JoinType::Semi => true,
+                    JoinType::Anti => false,
+                };
+
+                // streamed joins buffered
+                if !output_equal_join {
+                    self.buffered_data.scanning_finish();
+                }
+            }
+            Ordering::Greater => {
+                let output_buffered_join = match self.join_type {
+                    JoinType::Inner
+                    | JoinType::Left
+                    | JoinType::Right
+                    | JoinType::Anti
+                    | JoinType::Semi => false,
+                    JoinType::Full => !self.buffered_joined,
+                };
+
+                // null joins buffered
+                if !output_buffered_join {
+                    self.buffered_data.scanning_finish();
+                }
+            }
+        }
+
+        // scan buffered stream and write to output buffer
+        while !self.buffered_data.scanning_finished()
+            && self.output_size < self.batch_size
+        {
+            if self.current_ordering == Ordering::Equal {
+                append_row_to_output(
+                    &self.streamed_batch,
+                    self.streamed_idx,
+                    streamed_output,
+                )?;
+            } else {
+                append_nulls_row_to_output(&self.streamed_schema, streamed_output)?;
+            }
+
+            append_row_to_output(
+                &self.buffered_data.scanning_batch().batch,
+                self.buffered_data.scanning_idx(),
+                buffered_output,
+            )?;
+            self.output_size += 1;
+            self.buffered_data.scanning_advance();
+        }
+        Ok(())
+    }
+
+    #[inline]
+    fn output_buffer_full(&self) -> bool {
+        self.output_size == self.batch_size
+    }
+
+    fn output_record_batch_and_reset(&mut self) -> ArrowResult<RecordBatch> {
+        let record_batch =
+            make_batch(self.schema.clone(), self.output_buffer.drain(..).collect())?;
+        self.output_size = 0;
+        self.output_buffer
+            .extend(new_array_builders(self.schema.clone(), self.batch_size)?);
+        Ok(record_batch)
+    }
+}
+
+#[derive(Default)]
+struct BufferedData {
+    pub batches: VecDeque<BufferedBatch>,
+    pub scanning_batch_idx: usize,
+    pub scanning_offset: usize,
+}
+impl BufferedData {
+    pub fn head_batch(&self) -> &BufferedBatch {
+        self.batches.front().unwrap()
+    }
+
+    pub fn tail_batch(&self) -> &BufferedBatch {
+        self.batches.back().unwrap()
+    }
+
+    pub fn head_batch_mut(&mut self) -> &mut BufferedBatch {
+        self.batches.front_mut().unwrap()
+    }
+
+    pub fn tail_batch_mut(&mut self) -> &mut BufferedBatch {
+        self.batches.back_mut().unwrap()
+    }
+
+    pub fn has_buffered_rows(&self) -> bool {
+        self.batches.iter().any(|batch| !batch.range.is_empty())
+    }
+
+    pub fn scanning_reset(&mut self) {
+        self.scanning_batch_idx = 0;
+        self.scanning_offset = 0;
+    }
+
+    pub fn scanning_advance(&mut self) {
+        self.scanning_offset += 1;
+        while !self.scanning_finished() && self.scanning_batch_finished() {
+            self.scanning_batch_idx += 1;
+            self.scanning_offset = 0;
+        }
+    }
+
+    pub fn scanning_batch(&self) -> &BufferedBatch {
+        &self.batches[self.scanning_batch_idx]
+    }
+
+    pub fn scanning_idx(&self) -> usize {
+        self.scanning_batch().range.start + self.scanning_offset
+    }
+
+    pub fn scanning_batch_finished(&self) -> bool {
+        self.scanning_offset == self.scanning_batch().range.len()
+    }
+
+    pub fn scanning_finished(&self) -> bool {
+        self.scanning_batch_idx == self.batches.len()
+    }
+
+    pub fn scanning_finish(&mut self) {
+        self.scanning_batch_idx = self.batches.len();
+        self.scanning_offset = 0;
+    }
+}
+
+fn join_arrays(batch: &RecordBatch, on_column: &[Column]) -> Vec<ArrayRef> {
+    on_column
+        .iter()
+        .map(|c| batch.column(c.index()).clone())
+        .collect()
+}
+
+fn compare_join_arrays(
+    left_arrays: &[ArrayRef],
+    left: usize,
+    right_arrays: &[ArrayRef],
+    right: usize,
+    sort_options: SortOptions,
+    null_equals_null: bool,
+) -> ArrowResult<Ordering> {
+    let mut res = Ordering::Equal;
+    for (left_array, right_array) in left_arrays.iter().zip(right_arrays) {
+        macro_rules! compare_value {
+            ($T:ty) => {{
+                let left_array = left_array.as_any().downcast_ref::<$T>().unwrap();
+                let right_array = right_array.as_any().downcast_ref::<$T>().unwrap();
+                match (left_array.is_null(left), right_array.is_null(right)) {
+                    (false, false) => {
+                        let left_value = &left_array.value(left);
+                        let right_value = &right_array.value(right);
+                        res = left_value.partial_cmp(right_value).unwrap();
+                        if sort_options.descending {
+                            res = res.reverse();
+                        }
+                    }
+                    (true, false) => {
+                        res = if sort_options.nulls_first {
+                            Ordering::Less
+                        } else {
+                            Ordering::Greater
+                        };
+                    }
+                    (false, true) => {
+                        res = if sort_options.nulls_first {
+                            Ordering::Greater
+                        } else {
+                            Ordering::Less
+                        };
+                    }
+                    _ => {
+                        res = if null_equals_null {
+                            Ordering::Equal
+                        } else {
+                            Ordering::Less
+                        };
+                    }
+                }
+            }};
+        }
+
+        match left_array.data_type() {
+            DataType::Null => {}
+            DataType::Boolean => compare_value!(BooleanArray),
+            DataType::Int8 => compare_value!(Int8Array),
+            DataType::Int16 => compare_value!(Int16Array),
+            DataType::Int32 => compare_value!(Int32Array),
+            DataType::Int64 => compare_value!(Int64Array),
+            DataType::UInt8 => compare_value!(UInt8Array),
+            DataType::UInt16 => compare_value!(UInt16Array),
+            DataType::UInt32 => compare_value!(UInt32Array),
+            DataType::UInt64 => compare_value!(UInt64Array),
+            DataType::Timestamp(_, None) => compare_value!(Int64Array),
+            DataType::Utf8 => compare_value!(StringArray),
+            DataType::LargeUtf8 => compare_value!(LargeStringArray),
+            _ => {
+                return Err(ArrowError::NotYetImplemented(
+                    "Unsupported data type in sort merge join comparator".to_owned(),
+                ));
+            }
+        }
+        if !res.is_eq() {
+            break;
+        }
+    }
+    Ok(res)
+}
+
+fn is_join_arrays_equal(
+    left_arrays: &[ArrayRef],
+    left: usize,
+    right_arrays: &[ArrayRef],
+    right: usize,
+) -> ArrowResult<bool> {
+    let mut is_equal = true;
+    for (left_array, right_array) in left_arrays.iter().zip(right_arrays) {
+        macro_rules! compare_value {
+            ($T:ty) => {{
+                let left_array = left_array.as_any().downcast_ref::<$T>().unwrap();
+                let right_array = right_array.as_any().downcast_ref::<$T>().unwrap();
+                match (left_array.is_null(left), right_array.is_null(right)) {
+                    (false, false) => {
+                        if left_array.value(left) != right_array.value(right) {
+                            is_equal = false;
+                        }
+                    }
+                    (true, false) => is_equal = false,
+                    (false, true) => is_equal = false,
+                    _ => {}
+                }
+            }};
+        }
+
+        match left_array.data_type() {
+            DataType::Null => {}
+            DataType::Boolean => compare_value!(BooleanArray),
+            DataType::Int8 => compare_value!(Int8Array),
+            DataType::Int16 => compare_value!(Int16Array),
+            DataType::Int32 => compare_value!(Int32Array),
+            DataType::Int64 => compare_value!(Int64Array),
+            DataType::UInt8 => compare_value!(UInt8Array),
+            DataType::UInt16 => compare_value!(UInt16Array),
+            DataType::UInt32 => compare_value!(UInt32Array),
+            DataType::UInt64 => compare_value!(UInt64Array),
+            DataType::Timestamp(_, None) => compare_value!(Int64Array),
+            DataType::Utf8 => compare_value!(StringArray),
+            DataType::LargeUtf8 => compare_value!(LargeStringArray),
+            _ => {
+                return Err(ArrowError::NotYetImplemented(
+                    "Unsupported data type in sort merge join comparator".to_owned(),
+                ));
+            }
+        }
+        if !is_equal {
+            return Ok(false);
+        }
+    }
+    Ok(true)
+}
+
+fn new_array_builders(
+    schema: SchemaRef,
+    batch_size: usize,
+) -> ArrowResult<Vec<Box<dyn ArrayBuilder>>> {
+    let arrays: Vec<Box<dyn ArrayBuilder>> = schema
+        .fields()
+        .iter()
+        .map(|field| {
+            let dt = field.data_type();
+            make_builder(dt, batch_size)
+        })
+        .collect();
+    Ok(arrays)
+}
+
+fn append_row_to_output(
+    batch: &RecordBatch,
+    idx: usize,
+    arrays: &mut [Box<dyn ArrayBuilder>],
+) -> ArrowResult<()> {
+    if !arrays.is_empty() {
+        return batch
+            .columns()
+            .iter()
+            .zip(batch.schema().fields())
+            .enumerate()
+            .try_for_each(|(i, (column, f))| {
+                array_append_value(f.data_type(), &mut arrays[i], &*column, idx)
+            });
+    }
+    Ok(())
+}
+
+fn append_nulls_row_to_output(
+    schema: &Schema,
+    arrays: &mut [Box<dyn ArrayBuilder>],
+) -> ArrowResult<()> {
+    if !arrays.is_empty() {
+        return schema
+            .fields()
+            .iter()
+            .enumerate()
+            .try_for_each(|(i, f)| array_append_null(f.data_type(), &mut arrays[i]));
+    }
+    Ok(())
+}
+
+fn make_batch(
+    schema: SchemaRef,
+    mut arrays: Vec<Box<dyn ArrayBuilder>>,
+) -> ArrowResult<RecordBatch> {
+    let columns = arrays.iter_mut().map(|array| array.finish()).collect();
+    RecordBatch::try_new(schema, columns)
+}
+
+/// repeat times of cell located by `idx` at streamed side to output
+fn array_append_null(
+    data_type: &DataType,
+    to: &mut Box<dyn ArrayBuilder>,
+) -> ArrowResult<()> {
+    macro_rules! append_null {
+        ($TO:ty) => {{
+            to.as_any_mut().downcast_mut::<$TO>().unwrap().append_null()
+        }};
+    }
+    match data_type {
+        DataType::Boolean => append_null!(BooleanBuilder),
+        DataType::Int8 => append_null!(Int8Builder),
+        DataType::Int16 => append_null!(Int16Builder),
+        DataType::Int32 => append_null!(Int32Builder),
+        DataType::Int64 => append_null!(Int64Builder),
+        DataType::UInt8 => append_null!(UInt8Builder),
+        DataType::UInt16 => append_null!(UInt16Builder),
+        DataType::UInt32 => append_null!(UInt32Builder),
+        DataType::UInt64 => append_null!(UInt64Builder),
+        DataType::Float32 => append_null!(Float32Builder),
+        DataType::Float64 => append_null!(Float64Builder),
+        DataType::Utf8 => append_null!(GenericStringBuilder<i32>),
+        _ => todo!(),
+    }
+}
+
+fn array_append_value(
+    data_type: &DataType,
+    to: &mut Box<dyn ArrayBuilder>,
+    from: &dyn Array,
+    idx: usize,
+) -> ArrowResult<()> {
append_value { + ($TO:ty, $FROM:ty) => {{ + let to = to.as_any_mut().downcast_mut::<$TO>().unwrap(); + let from = from.as_any().downcast_ref::<$FROM>().unwrap(); + if from.is_valid(idx) { + to.append_value(from.value(idx)) + } else { + to.append_null() + } + }}; + } + + match data_type { + DataType::Boolean => append_value!(BooleanBuilder, BooleanArray), + DataType::Int8 => append_value!(Int8Builder, Int8Array), + DataType::Int16 => append_value!(Int16Builder, Int16Array), + DataType::Int32 => append_value!(Int32Builder, Int32Array), + DataType::Int64 => append_value!(Int64Builder, Int64Array), + DataType::UInt8 => append_value!(UInt8Builder, UInt8Array), + DataType::UInt16 => append_value!(UInt16Builder, UInt16Array), + DataType::UInt32 => append_value!(UInt32Builder, UInt32Array), + DataType::UInt64 => append_value!(UInt64Builder, UInt64Array), + DataType::Float32 => append_value!(Float32Builder, Float32Array), + DataType::Float64 => append_value!(Float64Builder, Float64Array), + DataType::Utf8 => { + append_value!(GenericStringBuilder, GenericStringArray) + } + _ => todo!(), + } +} + +#[cfg(test)] +mod tests { + use arrow::array::Int32Array; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use std::sync::Arc; + + use arrow::record_batch::RecordBatch; + + use crate::assert_batches_sorted_eq; + use crate::error::Result; + use crate::logical_plan::JoinType; + use crate::physical_plan::expressions::Column; + use crate::physical_plan::join_utils::JoinOn; + use crate::physical_plan::memory::MemoryExec; + use crate::physical_plan::sort_merge_join::SortMergeJoinExec; + use crate::physical_plan::{common, ExecutionPlan}; + use crate::prelude::{SessionConfig, SessionContext}; + use crate::test::{build_table_i32, columns}; + + fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + /// returns a table with 3 columns of i32 in memory + pub fn build_table_i32_nullable( + a: (&str, &Vec>), + b: (&str, &Vec>), + c: (&str, &Vec>), + ) -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new(a.0, DataType::Int32, true), + Field::new(b.0, DataType::Int32, true), + Field::new(c.0, DataType::Int32, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + fn join( + left: Arc, + right: Arc, + on: JoinOn, + join_type: JoinType, + ) -> Result { + SortMergeJoinExec::try_new( + left, + right, + on, + join_type, + SortOptions::default(), + false, + ) + } + + fn join_with_options( + left: Arc, + right: Arc, + on: JoinOn, + join_type: JoinType, + sort_options: SortOptions, + null_equals_null: bool, + ) -> Result { + SortMergeJoinExec::try_new( + left, + right, + on, + join_type, + sort_options, + null_equals_null, + ) + } + + async fn join_collect( + left: Arc, + right: Arc, + on: JoinOn, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + join_collect_with_options( + left, + right, + on, + join_type, + SortOptions::default(), + false, + ) + .await + } + + async fn join_collect_with_options( + left: Arc, + right: Arc, + on: JoinOn, + join_type: JoinType, + sort_options: SortOptions, + null_equals_null: bool, + ) -> Result<(Vec, Vec)> { + let 
session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let join = join_with_options( + left, + right, + on, + join_type, + sort_options, + null_equals_null, + )?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx).await?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + + async fn join_collect_batch_size_equals_two( + left: Arc, + right: Arc, + on: JoinOn, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + let session_ctx = + SessionContext::with_config(SessionConfig::new().with_batch_size(2)); + let task_ctx = session_ctx.task_ctx(); + let join = join(left, right, on, join_type)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx).await?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + + #[tokio::test] + async fn join_inner_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_inner_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b2", &vec![1, 2, 2]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + ), + ]; + + let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_inner_with_nulls() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]), + ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field + ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]), + ("b2", &vec![None, Some(1), Some(2), Some(2)]), + ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]), + ); + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + ), + ]; + + let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| 
a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_inner_with_nulls_with_options() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(2), Some(2), Some(1), Some(1)]), + ("b2", &vec![Some(2), Some(2), Some(1), None]), // null in key field + ("c1", &vec![Some(9), Some(8), None, Some(1)]), // null in non-key field + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(1), Some(1)]), + ("b2", &vec![Some(2), Some(2), Some(1), None]), + ("c2", &vec![Some(90), Some(80), Some(70), Some(10)]), + ); + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + ), + ]; + + let (_, batches) = join_collect_with_options( + left, + right, + on, + JoinType::Inner, + SortOptions { + descending: true, + nulls_first: false, + }, + true, + ) + .await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | | 1 | 1 | | 10 |", + "| 1 | 1 | | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + //assert_eq!(batches.len(), 1); + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] + async fn join_inner_output_two_batches() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b2", &vec![1, 2, 2]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + ), + ]; + + let (_, batches) = + join_collect_batch_size_equals_two(left, right, on, JoinType::Inner).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 2); + assert_eq!(batches[1].num_rows(), 1); + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + /// Test where the left has 1 part, the right has 2 parts => 2 parts + #[tokio::test] + async fn join_inner_one_two_parts_right() -> Result<()> { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + + let batch1 = build_table_i32( + ("a2", &vec![10, 20]), + ("b1", &vec![4, 6]), + ("c2", &vec![70, 80]), + ); + let batch2 = + build_table_i32(("a2", &vec![30]), ("b1", &vec![5]), ("c2", &vec![90])); + let schema = batch1.schema(); + let right = Arc::new( + MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), + ); + + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let join = 
join(left, right, on, JoinType::Inner)?; + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + // first part + let stream = join.execute(0, task_ctx.clone()).await?; + let batches = common::collect(stream).await?; + assert_eq!(batches.len(), 1); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + // second part + let stream = join.execute(1, task_ctx.clone()).await?; + let batches = common::collect(stream).await?; + assert_eq!(batches.len(), 1); + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 2 | 5 | 8 | 30 | 5 | 90 |", + "| 3 | 5 | 9 | 30 | 5 | 90 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_left_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_full_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; + + let (_, batches) = join_collect(left, right, on, JoinType::Full).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", 
+        ];
+        assert_batches_sorted_eq!(expected, &batches);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_anti() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![1, 2, 2, 3, 5]),
+            ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right
+            ("c1", &vec![7, 8, 8, 9, 11]),
+        );
+        let right = build_table(
+            ("a2", &vec![10, 20, 30]),
+            ("b1", &vec![4, 5, 6]),
+            ("c2", &vec![70, 80, 90]),
+        );
+        let on = vec![(
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b1", &right.schema())?,
+        )];
+
+        let (_, batches) = join_collect(left, right, on, JoinType::Anti).await?;
+        let expected = vec![
+            "+----+----+----+",
+            "| a1 | b1 | c1 |",
+            "+----+----+----+",
+            "| 3  | 7  | 9  |",
+            "| 5  | 7  | 11 |",
+            "+----+----+----+",
+        ];
+        assert_batches_sorted_eq!(expected, &batches);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_semi() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![1, 2, 2, 3]),
+            ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
+            ("c1", &vec![7, 8, 8, 9]),
+        );
+        let right = build_table(
+            ("a2", &vec![10, 20, 30]),
+            ("b1", &vec![4, 5, 6]), // 5 is double on the right
+            ("c2", &vec![70, 80, 90]),
+        );
+        let on = vec![(
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b1", &right.schema())?,
+        )];
+
+        let (_, batches) = join_collect(left, right, on, JoinType::Semi).await?;
+        let expected = vec![
+            "+----+----+----+",
+            "| a1 | b1 | c1 |",
+            "+----+----+----+",
+            "| 1  | 4  | 7  |",
+            "| 2  | 5  | 8  |",
+            "| 2  | 5  | 8  |",
+            "+----+----+----+",
+        ];
+        assert_batches_sorted_eq!(expected, &batches);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_with_duplicated_column_names() -> Result<()> {
+        let left = build_table(
+            ("a", &vec![1, 2, 3]),
+            ("b", &vec![4, 5, 7]),
+            ("c", &vec![7, 8, 9]),
+        );
+        let right = build_table(
+            ("a", &vec![10, 20, 30]),
+            ("b", &vec![1, 2, 7]),
+            ("c", &vec![70, 80, 90]),
+        );
+        let on = vec![(
+            // join on a=b so there are duplicate column names on unjoined columns
+            Column::new_with_schema("a", &left.schema())?,
+            Column::new_with_schema("b", &right.schema())?,
+        )];
+
+        let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?;
+        let expected = vec![
+            "+---+---+---+----+---+----+",
+            "| a | b | c | a  | b | c  |",
+            "+---+---+---+----+---+----+",
+            "| 1 | 4 | 7 | 10 | 1 | 70 |",
+            "| 2 | 5 | 8 | 20 | 2 | 80 |",
+            "+---+---+---+----+---+----+",
+        ];
+        assert_batches_sorted_eq!(expected, &batches);
+        Ok(())
+    }
+}

From d1603933bff0ee32059b088844ba5138665f2414 Mon Sep 17 00:00:00 2001
From: zhangli20
Date: Mon, 18 Apr 2022 11:35:19 +0800
Subject: [PATCH 02/10] Complete doc comments and pass cargo clippy

---
 .../core/src/physical_plan/sort_merge_join.rs | 130 ++++++++++++++----
 1 file changed, 115 insertions(+), 15 deletions(-)

diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs b/datafusion/core/src/physical_plan/sort_merge_join.rs
index 6ea552a4933d5..db61f8d722cd1 100644
--- a/datafusion/core/src/physical_plan/sort_merge_join.rs
+++ b/datafusion/core/src/physical_plan/sort_merge_join.rs
@@ -15,6 +15,30 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! Defines the Sort-Merge join execution plan.
+//! A sort-merge join plan consumes two sorted child plans and produces
+//! joined output by the given join type and other options.
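+//!
+//! A minimal usage sketch (for illustration only; `left`, `right` and
+//! `task_ctx` are assumed to exist, with both inputs already sorted on
+//! the join key `b1` as this operator requires):
+//!
+//! ```ignore
+//! let on = vec![(
+//!     Column::new_with_schema("b1", &left.schema())?,
+//!     Column::new_with_schema("b1", &right.schema())?,
+//! )];
+//! let join = SortMergeJoinExec::try_new(
+//!     left,
+//!     right,
+//!     on,
+//!     JoinType::Inner,
+//!     SortOptions::default(), // ascending, nulls first
+//!     false,                  // null_equals_null
+//! )?;
+//! let stream = join.execute(0, task_ctx).await?;
+//! ```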
+
 use std::any::Any;
 use std::cmp::Ordering;
 use std::collections::VecDeque;
@@ -45,19 +49,32 @@ use crate::physical_plan::{
     Statistics,
 };
 
+/// Join execution plan that executes partitions in parallel and combines them into a set of
+/// partitions.
 #[derive(Debug)]
 pub struct SortMergeJoinExec {
+    /// Left sorted joining execution plan
     left: Arc<dyn ExecutionPlan>,
+    /// Right sorted joining execution plan
    right: Arc<dyn ExecutionPlan>,
+    /// Set of common columns used to join on
     on: JoinOn,
+    /// How the join is performed
     join_type: JoinType,
+    /// The schema once the join is applied
     schema: SchemaRef,
+    /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
+    /// Sort options used in sorting left and right execution plans
     sort_options: SortOptions,
+    /// If null_equals_null is true, null == null else null != null
     null_equals_null: bool,
 }
 
 impl SortMergeJoinExec {
+    /// Tries to create a new [SortMergeJoinExec].
+    /// # Error
+    /// This function errors when it is not possible to join the left and right sides on keys `on`.
     pub fn try_new(
         left: Arc<dyn ExecutionPlan>,
         right: Arc<dyn ExecutionPlan>,
@@ -186,8 +203,8 @@ impl ExecutionPlan for SortMergeJoinExec {
     }
 }
 
-/// Metrics for SortMergeJoinExec
-#[warn(dead_code)]
+/// Metrics for SortMergeJoinExec (Not yet implemented)
+#[allow(dead_code)]
 struct SortMergeJoinMetrics {
     /// Total time for joining probe-side batches to the build-side batches
     join_time: metrics::Time,
@@ -222,35 +239,54 @@ impl SortMergeJoinMetrics {
     }
 }
 
+/// State of SMJ stream
 #[derive(Debug, PartialEq, Eq)]
 enum SMJState {
+    /// Init joining with a new streamed row or new buffered batches
     Init,
+    /// Polling one streamed row or one buffered batch, or both
     Polling,
+    /// Joining polled data and making output
     JoinOutput,
+    /// No more output
     Exhausted,
 }
 
+/// State of streamed data stream
 #[derive(Debug, PartialEq, Eq)]
 enum StreamedState {
+    /// Init polling
     Init,
+    /// Polling one streamed row
     Polling,
+    /// Ready to produce one streamed row
     Ready,
+    /// No more streamed row
     Exhausted,
 }
 
+/// State of buffered data stream
 #[derive(Debug, PartialEq, Eq)]
 enum BufferedState {
+    /// Init polling
     Init,
+    /// Polling first row in the next batch
     PollingFirst,
+    /// Polling rest rows in the next batch
     PollingRest,
+    /// Ready to produce one batch
     Ready,
+    /// No more buffered batches
     Exhausted,
 }
 
-#[derive(Debug)]
+/// A buffered batch that contains contiguous rows with same join key
 struct BufferedBatch {
+    /// The buffered record batch
     pub batch: RecordBatch,
+    /// The range in which the rows share the same join key
     pub range: Range<usize>,
+    /// Array refs of the join key
     pub join_arrays: Vec<ArrayRef>,
 }
 impl BufferedBatch {
@@ -264,30 +300,58 @@ impl BufferedBatch {
     }
 }
 
+/// Sort-merge join stream that consumes streamed and buffered data stream
+/// and produces joined output
 struct SMJStream {
+    /// Current state of the stream
     pub state: SMJState,
+    /// Output schema
     pub schema: SchemaRef,
+    /// Sort options used to sort streamed and buffered data stream
     pub sort_options: SortOptions,
+    /// null == null?
     pub null_equals_null: bool,
+    /// Input schema of streamed
     pub streamed_schema: SchemaRef,
+    /// Input schema of buffered
     pub buffered_schema: SchemaRef,
+    /// Number of columns of streamed
     pub num_streamed_columns: usize,
+    /// Number of columns of buffered
     pub num_buffered_columns: usize,
+    /// Streamed data stream
     pub streamed: SendableRecordBatchStream,
+    /// Buffered data stream
     pub buffered: SendableRecordBatchStream,
+    /// Current processing record batch of streamed
     pub streamed_batch: RecordBatch,
+    /// Current processing streamed join arrays
+    pub streamed_join_arrays: Vec<ArrayRef>,
+    /// Current processing row of streamed
     pub streamed_idx: usize,
+    /// Current buffered data
     pub buffered_data: BufferedData,
+    /// (used in outer join) Is current streamed row joined at least once?
     pub streamed_joined: bool,
+    /// (used in outer join) Is current buffered batches joined at least once?
     pub buffered_joined: bool,
+    /// State of streamed
     pub streamed_state: StreamedState,
+    /// State of buffered
     pub buffered_state: BufferedState,
+    /// The comparison result of current streamed row and buffered batches
     pub current_ordering: Ordering,
+    /// Join key columns of streamed
     pub on_streamed: Vec<Column>,
+    /// Join key columns of buffered
     pub on_buffered: Vec<Column>,
+    /// Staging output array builders
     pub output_buffer: Vec<Box<dyn ArrayBuilder>>,
+    /// Staging output size
     pub output_size: usize,
+    /// Target output batch size
     pub batch_size: usize,
+    /// How the join is performed
     pub join_type: JoinType,
 }
 
@@ -398,6 +462,7 @@ impl Stream for SMJStream {
 }
 
 impl SMJStream {
+    #[allow(clippy::too_many_arguments)]
     pub fn try_new(
         schema: SchemaRef,
         sort_options: SortOptions,
@@ -422,6 +487,7 @@ impl SMJStream {
             streamed,
             buffered,
             streamed_batch: RecordBatch::new_empty(schema),
+            streamed_join_arrays: vec![],
             streamed_idx: 0,
             buffered_data: BufferedData::default(),
             streamed_joined: false,
@@ -438,6 +504,7 @@ impl SMJStream {
         })
     }
 
+    /// Poll next streamed row
     fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll<Option<ArrowResult<()>>> {
         loop {
             match &self.streamed_state {
@@ -461,6 +528,10 @@ impl SMJStream {
                     Poll::Ready(Some(batch)) => {
                         if batch.num_rows() > 0 {
                             self.streamed_batch = batch;
+                            self.streamed_join_arrays = join_arrays(
+                                &self.streamed_batch,
+                                &self.on_streamed
+                            );
                             self.streamed_idx = 0;
                             self.streamed_state = StreamedState::Ready;
                         }
@@ -476,6 +547,7 @@ impl SMJStream {
         }
     }
 
+    /// Poll next buffered batches
     fn poll_buffered_batches(
         &mut self,
         cx: &mut Context,
@@ -567,6 +639,7 @@ impl SMJStream {
         }
     }
 
+    /// Get comparison result of streamed row and buffered batches
     fn compare_streamed_buffered(&self) -> ArrowResult<Ordering> {
         if self.streamed_state == StreamedState::Exhausted {
             return Ok(Ordering::Greater);
@@ -576,15 +649,17 @@ impl SMJStream {
         }
 
         return compare_join_arrays(
-            &join_arrays(&self.streamed_batch, &self.on_streamed),
+            &self.streamed_join_arrays,
             self.streamed_idx,
-            &join_arrays(&self.buffered_data.head_batch().batch, &self.on_buffered),
+            &self.buffered_data.head_batch().join_arrays,
             self.buffered_data.head_batch().range.start,
             self.sort_options,
             self.null_equals_null,
         );
     }
 
+    /// Produce join and fill output buffer until reaching target batch size
+    /// or the join is finished
     fn join_partial(&mut self) -> ArrowResult<()> {
         // decide streamed/buffered output columns by join type
         let output_parts =
@@ -679,11 +754,6 @@ impl SMJStream {
         Ok(())
     }
 
-    #[inline]
-    fn output_buffer_full(&self) -> bool {
-        self.output_size == self.batch_size
-    }
-
     fn output_record_batch_and_reset(&mut self) -> ArrowResult<RecordBatch> {
         let record_batch =
             make_batch(self.schema.clone(), self.output_buffer.drain(..).collect())?;
@@ -694,10 +764,14 @@ impl SMJStream {
     }
 }
 
+/// Buffered data contains all buffered batches with one unique join key
 #[derive(Default)]
 struct BufferedData {
+    /// Buffered batches with the same key
     pub batches: VecDeque<BufferedBatch>,
+    /// current scanning batch index used in join_partial()
     pub scanning_batch_idx: usize,
+    /// current scanning offset used in join_partial()
     pub scanning_offset: usize,
 }
 impl BufferedData {
@@ -709,10 +783,6 @@ impl BufferedData {
         self.batches.back().unwrap()
     }
 
-    pub fn head_batch_mut(&mut self) -> &mut BufferedBatch {
-        self.batches.front_mut().unwrap()
-    }
-
     pub fn tail_batch_mut(&mut self) -> &mut BufferedBatch {
         self.batches.back_mut().unwrap()
     }
@@ -756,6 +826,7 @@ impl BufferedData {
     }
 }
 
+/// Get join array refs of given batch and join columns
 fn join_arrays(batch: &RecordBatch, on_column: &[Column]) -> Vec<ArrayRef> {
     on_column
         .iter()
@@ -763,6 +834,7 @@ fn join_arrays(batch: &RecordBatch, on_column: &[Column]) -> Vec<ArrayRef> {
         .collect()
 }
 
+/// Get comparison result of two rows of join arrays
 fn compare_join_arrays(
     left_arrays: &[ArrayRef],
     left: usize,
@@ -838,6 +910,8 @@ fn compare_join_arrays(
     Ok(res)
 }
 
+/// A faster version of compare_join_arrays() that only outputs whether
+/// the given two rows are equal
 fn is_join_arrays_equal(
     left_arrays: &[ArrayRef],
     left: usize,
@@ -890,6 +964,7 @@ fn is_join_arrays_equal(
     Ok(true)
 }
 
+/// Create new array builders of given schema and batch size
 fn new_array_builders(
     schema: SchemaRef,
     batch_size: usize,
@@ -905,6 +980,7 @@ fn new_array_builders(
     Ok(arrays)
 }
 
+/// Append one row to part of output buffer (the array builders)
 fn append_row_to_output(
     batch: &RecordBatch,
     idx: usize,
@@ -923,6 +999,8 @@ fn append_row_to_output(
     Ok(())
 }
 
+/// Append one row in which all values are null to part of output buffer (the
+/// array builders), used in outer join
 fn append_nulls_row_to_output(
     schema: &Schema,
     arrays: &mut [Box<dyn ArrayBuilder>],
@@ -937,6 +1015,7 @@ fn append_nulls_row_to_output(
     Ok(())
 }
 
+/// Finish output buffer and produce one record batch
 fn make_batch(
     schema: SchemaRef,
     mut arrays: Vec<Box<dyn ArrayBuilder>>,
@@ -945,7 +1024,7 @@ fn make_batch(
     RecordBatch::try_new(schema, columns)
 }
 
-/// repeat times of cell located by `idx` at streamed side to output
+/// Append a null value to an array builder
 fn array_append_null(
     data_type: &DataType,
     to: &mut Box<dyn ArrayBuilder>,
@@ -972,6 +1051,7 @@ fn array_append_null(
     }
 }
 
+/// Append a value to an array builder
 fn array_append_value(
     data_type: &DataType,
     to: &mut Box<dyn ArrayBuilder>,

From fcb596e720865b77fc5e74540ad02bc9a5d3d29e Mon Sep 17 00:00:00 2001
From: zhangli20
Date: Mon, 18 Apr 2022 15:03:09 +0800
Subject: [PATCH 03/10] Implement metrics for SMJ

---
 .../core/src/physical_plan/sort_merge_join.rs | 58 +++++++++++++++++--
 1 file changed, 54 insertions(+), 4 deletions(-)

diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs b/datafusion/core/src/physical_plan/sort_merge_join.rs
index db61f8d722cd1..b5692769bdd80 100644
--- a/datafusion/core/src/physical_plan/sort_merge_join.rs
+++ b/datafusion/core/src/physical_plan/sort_merge_join.rs
@@ -191,6 +191,7 @@ impl ExecutionPlan for SortMergeJoinExec {
             self.join_type,
             output_buffer,
             batch_size,
+            SortMergeJoinMetrics::new(partition, &self.metrics),
         )?))
     }
 
@@ -353,6 +354,8 @@ struct SMJStream {
     pub batch_size: usize,
     /// How the join is performed
     pub join_type: JoinType,
+    /// Metrics
+    pub join_metrics: SortMergeJoinMetrics,
 }
 
 impl RecordBatchStream for SMJStream {
@@ -368,6 +371,7 @@ impl Stream for SMJStream {
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
+        self.join_metrics.join_time.timer();
         loop {
             match &self.state {
                 SMJState::Init => {
@@ -437,6 +441,8 @@ impl Stream for SMJStream {
                     self.join_partial()?;
                     if self.output_size == self.batch_size {
                         let record_batch = self.output_record_batch_and_reset()?;
+                        self.join_metrics.output_batches.add(1);
+                        self.join_metrics.output_rows.add(record_batch.num_rows());
                         return Poll::Ready(Some(Ok(record_batch)));
                     }
                     if self.buffered_data.scanning_finished() {
@@ -452,6 +458,8 @@ impl Stream for SMJStream {
                 SMJState::Exhausted => {
                     if self.output_size > 0 {
                         let record_batch = self.output_record_batch_and_reset()?;
+                        self.join_metrics.output_batches.add(1);
+                        self.join_metrics.output_rows.add(record_batch.num_rows());
                         return Poll::Ready(Some(Ok(record_batch)));
                     }
                     return Poll::Ready(None);
@@ -474,6 +482,7 @@ impl SMJStream {
         join_type: JoinType,
         output_buffer: Vec<Box<dyn ArrayBuilder>>,
         batch_size: usize,
+        join_metrics: SortMergeJoinMetrics,
     ) -> Result<Self> {
         Ok(Self {
             state: SMJState::Init,
@@ -501,6 +510,7 @@ impl SMJStream {
             output_size: 0,
             batch_size,
             join_type,
+            join_metrics,
         })
     }
 
@@ -527,11 +537,11 @@ impl SMJStream {
                     }
                     Poll::Ready(Some(batch)) => {
                         if batch.num_rows() > 0 {
+                            self.join_metrics.input_batches.add(1);
+                            self.join_metrics.input_rows.add(batch.num_rows());
                             self.streamed_batch = batch;
-                            self.streamed_join_arrays = join_arrays(
-                                &self.streamed_batch,
-                                &self.on_streamed
-                            );
+                            self.streamed_join_arrays =
+                                join_arrays(&self.streamed_batch, &self.on_streamed);
                             self.streamed_idx = 0;
                             self.streamed_state = StreamedState::Ready;
                         }
@@ -582,6 +592,8 @@ impl SMJStream {
                         return Poll::Ready(None);
                     }
                     Poll::Ready(Some(batch)) => {
+                        self.join_metrics.input_batches.add(1);
+                        self.join_metrics.input_rows.add(batch.num_rows());
                        if batch.num_rows() > 0 {
                             self.buffered_data.batches.push_back(BufferedBatch::new(
                                 batch,
@@ -620,6 +632,8 @@ impl SMJStream {
                             self.buffered_state = BufferedState::Ready;
                         }
                         Poll::Ready(Some(batch)) => {
+                            self.join_metrics.input_batches.add(1);
+                            self.join_metrics.input_rows.add(batch.num_rows());
                             self.buffered_data.batches.push_back(BufferedBatch::new(
                                 batch,
                                 0..0,
@@ -1490,6 +1504,42 @@ mod tests {
             "+----+----+----+----+----+----+",
         ];
         assert_batches_sorted_eq!(expected, &batches);
+
+        let metrics = join.metrics().unwrap();
+        assert!(
+            0 < metrics
+                .sum(|m| m.value().name() == "join_time")
+                .map(|v| v.as_usize())
+                .unwrap()
+        );
+        assert_eq!(
+            2,
+            metrics
+                .sum(|m| m.value().name() == "output_batches")
+                .map(|v| v.as_usize())
+                .unwrap()
+        ); // 1+1
+        assert_eq!(
+            3,
+            metrics
+                .sum(|m| m.value().name() == "output_rows")
+                .map(|v| v.as_usize())
+                .unwrap()
+        ); // 2+1
+        assert_eq!(
+            4,
+            metrics
+                .sum(|m| m.value().name() == "input_batches")
+                .map(|v| v.as_usize())
+                .unwrap()
+        ); // (1+1) + (1+1)
+        assert_eq!(
+            9,
+            metrics
+                .sum(|m| m.value().name() == "input_rows")
+                .map(|v| v.as_usize())
+                .unwrap()
+        ); // (3+2) + (3+1)
         Ok(())
     }
 
From f8515594ba603520580ca9d06ef95ed31b68c388 Mon Sep 17 00:00:00 2001
From: zhangli20
Date: Mon, 18 Apr 2022 15:52:47 +0800
Subject: [PATCH 04/10] Support join columns with different sort options

---
 .../core/src/physical_plan/sort_merge_join.rs | 69 +++++++++----------
 1 file changed, 34 insertions(+), 35 deletions(-)

diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs b/datafusion/core/src/physical_plan/sort_merge_join.rs
index b5692769bdd80..004115cdc19d3 100644
--- a/datafusion/core/src/physical_plan/sort_merge_join.rs
+++ b/datafusion/core/src/physical_plan/sort_merge_join.rs
@@ -65,8 +65,8 @@ pub struct SortMergeJoinExec {
     schema: SchemaRef,
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
-    /// Sort options used in sorting left and right execution plans
-    sort_options: SortOptions,
+    /// Sort options of join columns used in sorting left and right execution plans
+    sort_options: Vec<SortOptions>,
     /// If null_equals_null is true, null == null else null != null
     null_equals_null: bool,
 }
@@ -80,13 +80,21 @@ impl SortMergeJoinExec {
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
         join_type: JoinType,
-        sort_options: SortOptions,
+        sort_options: Vec<SortOptions>,
         null_equals_null: bool,
     ) -> Result<Self> {
         let left_schema = left.schema();
         let right_schema = right.schema();
 
         check_join_is_valid(&left_schema, &right_schema, &on)?;
+        if sort_options.len() != on.len() {
+            return Err(DataFusionError::Plan(format!(
+                "Expected number of sort options: {}, actual: {}",
+                on.len(),
+                sort_options.len()
+            )));
+        }
+
         let schema =
             Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
 
@@ -135,7 +143,7 @@ impl ExecutionPlan for SortMergeJoinExec {
                 right.clone(),
                 self.on.clone(),
                 self.join_type,
-                self.sort_options,
+                self.sort_options.clone(),
                 self.null_equals_null,
             )?)),
             _ => Err(DataFusionError::Internal(
@@ -182,7 +190,7 @@ impl ExecutionPlan for SortMergeJoinExec {
         // create join stream
         Ok(Box::pin(SMJStream::try_new(
             self.schema.clone(),
-            self.sort_options,
+            self.sort_options.clone(),
             self.null_equals_null,
             streamed,
             buffered,
@@ -308,8 +316,8 @@ struct SMJStream {
     pub state: SMJState,
     /// Output schema
     pub schema: SchemaRef,
-    /// Sort options used to sort streamed and buffered data stream
-    pub sort_options: SortOptions,
+    /// Sort options of join columns used to sort streamed and buffered data stream
+    pub sort_options: Vec<SortOptions>,
     /// null == null?
     pub null_equals_null: bool,
     /// Input schema of streamed
@@ -473,7 +481,7 @@ impl SMJStream {
     #[allow(clippy::too_many_arguments)]
     pub fn try_new(
         schema: SchemaRef,
-        sort_options: SortOptions,
+        sort_options: Vec<SortOptions>,
         null_equals_null: bool,
         streamed: SendableRecordBatchStream,
         buffered: SendableRecordBatchStream,
@@ -667,7 +675,7 @@ impl SMJStream {
             self.streamed_idx,
             &self.buffered_data.head_batch().join_arrays,
             self.buffered_data.head_batch().range.start,
-            self.sort_options,
+            &self.sort_options,
             self.null_equals_null,
         );
     }
@@ -854,11 +862,13 @@ fn compare_join_arrays(
     left: usize,
     right_arrays: &[ArrayRef],
     right: usize,
-    sort_options: SortOptions,
+    sort_options: &[SortOptions],
     null_equals_null: bool,
 ) -> ArrowResult<Ordering> {
     let mut res = Ordering::Equal;
-    for (left_array, right_array) in left_arrays.iter().zip(right_arrays) {
+    for ((left_array, right_array), sort_options) in
+        left_arrays.iter().zip(right_arrays).zip(sort_options)
+    {
         macro_rules! compare_value {
compare_value {
            ($T:ty) => {{
                let left_array = left_array.as_any().downcast_ref::<$T>().unwrap();
@@ -1162,14 +1172,8 @@ mod tests {
         on: JoinOn,
         join_type: JoinType,
     ) -> Result<SortMergeJoinExec> {
-        SortMergeJoinExec::try_new(
-            left,
-            right,
-            on,
-            join_type,
-            SortOptions::default(),
-            false,
-        )
+        let sort_options = vec![SortOptions::default(); on.len()];
+        SortMergeJoinExec::try_new(left, right, on, join_type, sort_options, false)
     }

     fn join_with_options(
@@ -1177,7 +1181,7 @@
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
         join_type: JoinType,
-        sort_options: SortOptions,
+        sort_options: Vec<SortOptions>,
         null_equals_null: bool,
     ) -> Result<SortMergeJoinExec> {
         SortMergeJoinExec::try_new(
@@ -1196,15 +1200,8 @@
         on: JoinOn,
         join_type: JoinType,
     ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
-        join_collect_with_options(
-            left,
-            right,
-            on,
-            join_type,
-            SortOptions::default(),
-            false,
-        )
-        .await
+        let sort_options = vec![SortOptions::default(); on.len()];
+        join_collect_with_options(left, right, on, join_type, sort_options, false).await
     }

     async fn join_collect_with_options(
@@ -1212,7 +1209,7 @@
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
         join_type: JoinType,
-        sort_options: SortOptions,
+        sort_options: Vec<SortOptions>,
         null_equals_null: bool,
     ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
         let session_ctx = SessionContext::new();
@@ -1378,16 +1375,18 @@
             Column::new_with_schema("b2", &right.schema())?,
         ),
     ];
-
     let (_, batches) = join_collect_with_options(
         left,
         right,
         on,
         JoinType::Inner,
-        SortOptions {
-            descending: true,
-            nulls_first: false,
-        },
+        vec![
+            SortOptions {
+                descending: true,
+                nulls_first: false
+            };
+            2
+        ],
         true,
     )
     .await?;

From 78eeb7e5ea26b4024000b3a3c45320fc4a3e32e9 Mon Sep 17 00:00:00 2001
From: Zhang Li
Date: Tue, 19 Apr 2022 22:19:01 +0800
Subject: [PATCH 05/10] Update datafusion/core/src/physical_plan/sort_merge_join.rs

Add detailed comments on the ordering requirements of the two input children.

Co-authored-by: Andrew Lamb
---
 datafusion/core/src/physical_plan/sort_merge_join.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs b/datafusion/core/src/physical_plan/sort_merge_join.rs
index 004115cdc19d3..e4ab7c4286146 100644
--- a/datafusion/core/src/physical_plan/sort_merge_join.rs
+++ b/datafusion/core/src/physical_plan/sort_merge_join.rs
@@ -73,6 +73,7 @@ pub struct SortMergeJoinExec {
 impl SortMergeJoinExec {
     /// Tries to create a new [SortMergeJoinExec].
+    /// The inputs are sorted by `sort_options`, which are applied to the join columns in `on`.
     /// # Error
     /// This function errors when it is not possible to join the left and right sides on keys `on`.
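+    ///
+    /// A minimal construction sketch (assuming `left` and `right` are
+    /// already sorted on the two join key columns `a` and `b`; the bindings
+    /// here are illustrative only):
+    ///
+    /// ```ignore
+    /// let on = vec![
+    ///     (
+    ///         Column::new_with_schema("a", &left.schema())?,
+    ///         Column::new_with_schema("a", &right.schema())?,
+    ///     ),
+    ///     (
+    ///         Column::new_with_schema("b", &left.schema())?,
+    ///         Column::new_with_schema("b", &right.schema())?,
+    ///     ),
+    /// ];
+    /// // one SortOptions entry per join key pair; try_new returns a plan
+    /// // error when the two lengths differ
+    /// let join = SortMergeJoinExec::try_new(
+    ///     left,
+    ///     right,
+    ///     on,
+    ///     JoinType::Inner,
+    ///     vec![SortOptions::default(); 2],
+    ///     false, // null_equals_null
+    /// )?;
+    /// ```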
pub fn try_new( From d6531e92fbfcf6eec41464ed27c9a55778508e56 Mon Sep 17 00:00:00 2001 From: zhangli20 Date: Wed, 20 Apr 2022 19:59:07 +0800 Subject: [PATCH 06/10] use indices instead of ArrayBuilders for constructing output record batches --- .../core/src/physical_plan/sort_merge_join.rs | 578 +++++++++--------- 1 file changed, 288 insertions(+), 290 deletions(-) diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs b/datafusion/core/src/physical_plan/sort_merge_join.rs index e4ab7c4286146..ffbf23b491d7b 100644 --- a/datafusion/core/src/physical_plan/sort_merge_join.rs +++ b/datafusion/core/src/physical_plan/sort_merge_join.rs @@ -28,8 +28,8 @@ use std::sync::Arc; use std::task::{Context, Poll}; use arrow::array::*; -use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::compute::{take, SortOptions}; +use arrow::datatypes::{DataType, SchemaRef}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use async_trait::async_trait; @@ -39,7 +39,7 @@ use crate::error::DataFusionError; use crate::error::Result; use crate::execution::context::TaskContext; use crate::logical_plan::JoinType; -use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::common::combine_batches; use crate::physical_plan::expressions::Column; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::join_utils::{build_join_schema, check_join_is_valid, JoinOn}; @@ -178,15 +178,11 @@ impl ExecutionPlan for SortMergeJoinExec { }; // execute children plans - let streamed = CoalescePartitionsExec::new(streamed) - .execute(0, context.clone()) - .await?; + let streamed = streamed.execute(partition, context.clone()).await?; let buffered = buffered.execute(partition, context.clone()).await?; // create output buffer let batch_size = context.session_config().batch_size; - let output_buffer = new_array_builders(self.schema(), batch_size) - .map_err(DataFusionError::ArrowError)?; // create join stream Ok(Box::pin(SMJStream::try_new( @@ -198,7 +194,6 @@ impl ExecutionPlan for SortMergeJoinExec { on_streamed, on_buffered, self.join_type, - output_buffer, batch_size, SortMergeJoinMetrics::new(partition, &self.metrics), )?)) @@ -325,10 +320,6 @@ struct SMJStream { pub streamed_schema: SchemaRef, /// Input schema of buffered pub buffered_schema: SchemaRef, - /// Number of columns of streamed - pub num_streamed_columns: usize, - /// Number of columns of buffered - pub num_buffered_columns: usize, /// Streamed data stream pub streamed: SendableRecordBatchStream, /// Buffered data stream @@ -356,7 +347,7 @@ struct SMJStream { /// Join key columns of buffered pub on_buffered: Vec, /// Staging output array builders - pub output_buffer: Vec>, + pub output_record_batches: Vec, /// Staging output size pub output_size: usize, /// Target output batch size @@ -447,11 +438,11 @@ impl Stream for SMJStream { self.state = SMJState::JoinOutput; } SMJState::JoinOutput => { - self.join_partial()?; + let output_indices = self.join_partial()?; + self.output_partial(&output_indices)?; + if self.output_size == self.batch_size { let record_batch = self.output_record_batch_and_reset()?; - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(record_batch.num_rows()); return Poll::Ready(Some(Ok(record_batch))); } if self.buffered_data.scanning_finished() { @@ -465,10 +456,8 @@ impl Stream for SMJStream { } } SMJState::Exhausted => { - if self.output_size > 0 { + if 
!self.output_record_batches.is_empty() { let record_batch = self.output_record_batch_and_reset()?; - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(record_batch.num_rows()); return Poll::Ready(Some(Ok(record_batch))); } return Poll::Ready(None); @@ -489,7 +478,6 @@ impl SMJStream { on_streamed: Vec, on_buffered: Vec, join_type: JoinType, - output_buffer: Vec>, batch_size: usize, join_metrics: SortMergeJoinMetrics, ) -> Result { @@ -500,8 +488,6 @@ impl SMJStream { schema: schema.clone(), streamed_schema: streamed.schema(), buffered_schema: buffered.schema(), - num_streamed_columns: streamed.schema().fields().len(), - num_buffered_columns: buffered.schema().fields().len(), streamed, buffered, streamed_batch: RecordBatch::new_empty(schema), @@ -515,7 +501,7 @@ impl SMJStream { current_ordering: Ordering::Equal, on_streamed, on_buffered, - output_buffer, + output_record_batches: vec![], output_size: 0, batch_size, join_type, @@ -683,21 +669,8 @@ impl SMJStream { /// Produce join and fill output buffer until reaching target batch size /// or the join is finished - fn join_partial(&mut self) -> ArrowResult<()> { - // decide streamed/buffered output columns by join type - let output_parts = - self.output_buffer - .split_at_mut(if self.join_type != JoinType::Right { - self.num_streamed_columns - } else { - self.num_buffered_columns - }); - let (streamed_output, buffered_output) = if self.join_type != JoinType::Right { - (output_parts.0, output_parts.1) - } else { - (output_parts.1, output_parts.0) - }; - + fn join_partial(&mut self) -> ArrowResult> { + let mut output_indices = vec![]; match self.current_ordering { Ordering::Less => { let output_streamed_join = match self.join_type { @@ -710,12 +683,10 @@ impl SMJStream { // streamed joins null if output_streamed_join { - append_row_to_output( - &self.streamed_batch, - self.streamed_idx, - streamed_output, - )?; - append_nulls_row_to_output(&self.buffered_schema, buffered_output)?; + output_indices.push(OutputIndex { + streamed_idx: Some(self.streamed_idx), + buffered_idx: None, + }); self.output_size += 1; } self.buffered_data.scanning_finish(); @@ -731,7 +702,18 @@ impl SMJStream { }; // streamed joins buffered - if !output_equal_join { + if output_equal_join { + if JoinType::Semi == self.join_type { + if !self.streamed_joined { + output_indices.push(OutputIndex { + streamed_idx: Some(self.streamed_idx), + buffered_idx: None, + }); + self.output_size += 1; + } + self.buffered_data.scanning_finish(); + } + } else { self.buffered_data.scanning_finish(); } } @@ -746,7 +728,18 @@ impl SMJStream { }; // null joins buffered - if !output_buffered_join { + if output_buffered_join { + if JoinType::Anti == self.join_type { + if !self.streamed_joined { + output_indices.push(OutputIndex { + streamed_idx: Some(self.streamed_idx), + buffered_idx: None, + }); + self.output_size += 1; + } + self.buffered_data.scanning_finish(); + } + } else { self.buffered_data.scanning_finish(); } } @@ -757,34 +750,251 @@ impl SMJStream { && self.output_size < self.batch_size { if self.current_ordering == Ordering::Equal { - append_row_to_output( - &self.streamed_batch, - self.streamed_idx, - streamed_output, - )?; + output_indices.push(OutputIndex { + streamed_idx: Some(self.streamed_idx), + buffered_idx: Some(( + self.buffered_data.scanning_batch_idx, + self.buffered_data.scanning_idx(), + )), + }); } else { - append_nulls_row_to_output(&self.streamed_schema, streamed_output)?; + output_indices.push(OutputIndex { + streamed_idx: None, + 
buffered_idx: Some(( + self.buffered_data.scanning_batch_idx, + self.buffered_data.scanning_idx(), + )), + }); } - - append_row_to_output( - &self.buffered_data.scanning_batch().batch, - self.buffered_data.scanning_idx(), - buffered_output, - )?; self.output_size += 1; self.buffered_data.scanning_advance(); } - Ok(()) + Ok(output_indices) } fn output_record_batch_and_reset(&mut self) -> ArrowResult { + assert!(!self.output_record_batches.is_empty()); + let record_batch = - make_batch(self.schema.clone(), self.output_buffer.drain(..).collect())?; + combine_batches(&self.output_record_batches, self.schema.clone())?.unwrap(); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(record_batch.num_rows()); self.output_size = 0; - self.output_buffer - .extend(new_array_builders(self.schema.clone(), self.batch_size)?); + self.output_record_batches.clear(); Ok(record_batch) } + + fn output_partial(&mut self, output_indices: &[OutputIndex]) -> ArrowResult<()> { + match self.join_type { + JoinType::Inner => { + self.output_partial_streamed_joining_buffered(output_indices)?; + } + JoinType::Left | JoinType::Right => { + self.output_partial_streamed_joining_buffered(output_indices)?; + self.output_partial_streamed_joining_null(output_indices)?; + } + JoinType::Full => { + self.output_partial_streamed_joining_buffered(output_indices)?; + self.output_partial_streamed_joining_null(output_indices)?; + self.output_partial_null_joining_buffered(output_indices)?; + } + JoinType::Semi | JoinType::Anti => { + self.output_partial_streamed_joining_null(output_indices)?; + } + } + Ok(()) + } + + fn output_partial_streamed_joining_buffered( + &mut self, + output_indices: &[OutputIndex], + ) -> ArrowResult<()> { + let mut output = |buffered_batch_idx: usize, indices: &[OutputIndex]| { + if indices.is_empty() { + return ArrowResult::Ok(()); + } + + // take streamed columns + let streamed_indices = UInt64Array::from_iter_values( + indices + .iter() + .map(|index| index.streamed_idx.unwrap() as u64), + ); + let mut streamed_columns = self + .streamed_batch + .columns() + .iter() + .map(|column| take(column, &streamed_indices, None)) + .collect::>>()?; + + // take buffered columns + let buffered_indices = UInt64Array::from_iter_values( + indices + .iter() + .map(|index| index.buffered_idx.unwrap().1 as u64), + ); + let mut buffered_columns = self.buffered_data.batches[buffered_batch_idx] + .batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::>>()?; + + // combine columns and produce record batch + let columns = match self.join_type { + JoinType::Inner | JoinType::Left | JoinType::Full => { + streamed_columns.extend(buffered_columns); + streamed_columns + } + JoinType::Right => { + buffered_columns.extend(streamed_columns); + buffered_columns + } + JoinType::Semi | JoinType::Anti => { + unreachable!() + } + }; + let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?; + self.output_record_batches.push(record_batch); + Ok(()) + }; + + let mut buffered_batch_idx = 0; + let mut indices = vec![]; + for &index in output_indices + .iter() + .filter(|index| index.streamed_idx.is_some()) + .filter(|index| index.buffered_idx.is_some()) + { + let buffered_idx = index.buffered_idx.unwrap(); + if buffered_idx.0 == buffered_batch_idx { + indices.push(index); + } else { + output(buffered_batch_idx, &indices)?; + buffered_batch_idx = buffered_idx.0; + indices.clear(); + } + } + output(buffered_batch_idx, &indices)?; + Ok(()) + } + + fn 
output_partial_streamed_joining_null( + &mut self, + output_indices: &[OutputIndex], + ) -> ArrowResult<()> { + // streamed joining null + let streamed_indices = UInt64Array::from_iter_values( + output_indices + .iter() + .filter(|index| index.streamed_idx.is_some()) + .filter(|index| index.buffered_idx.is_none()) + .map(|index| index.streamed_idx.unwrap() as u64), + ); + let mut streamed_columns = self + .streamed_batch + .columns() + .iter() + .map(|column| take(column, &streamed_indices, None)) + .collect::>>()?; + + let mut buffered_columns = self + .buffered_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), streamed_indices.len())) + .collect::>(); + + let columns = match self.join_type { + JoinType::Inner => { + unreachable!() + } + JoinType::Left | JoinType::Full => { + streamed_columns.extend(buffered_columns); + streamed_columns + } + JoinType::Right => { + buffered_columns.extend(streamed_columns); + buffered_columns + } + JoinType::Anti | JoinType::Semi => streamed_columns, + }; + + if !streamed_indices.is_empty() { + let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?; + self.output_record_batches.push(record_batch); + } + Ok(()) + } + + fn output_partial_null_joining_buffered( + &mut self, + output_indices: &[OutputIndex], + ) -> ArrowResult<()> { + let mut output = |buffered_batch_idx: usize, indices: &[OutputIndex]| { + if indices.is_empty() { + return ArrowResult::Ok(()); + } + + // take buffered columns + let buffered_indices = UInt64Array::from_iter_values( + indices + .iter() + .map(|index| index.buffered_idx.unwrap().1 as u64), + ); + let buffered_columns = self.buffered_data.batches[buffered_batch_idx] + .batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::>>()?; + + // create null streamed columns + let mut streamed_columns = self + .streamed_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), buffered_indices.len())) + .collect::>(); + + // combine columns and produce record batch + let columns = match self.join_type { + JoinType::Full => { + streamed_columns.extend(buffered_columns); + streamed_columns + } + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Semi + | JoinType::Anti => { + unreachable!() + } + }; + let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?; + self.output_record_batches.push(record_batch); + Ok(()) + }; + + let mut buffered_batch_idx = 0; + let mut indices = vec![]; + for &index in output_indices + .iter() + .filter(|index| index.streamed_idx.is_none()) + .filter(|index| index.buffered_idx.is_some()) + { + let buffered_idx = index.buffered_idx.unwrap(); + if buffered_idx.0 == buffered_batch_idx { + indices.push(index); + } else { + output(buffered_batch_idx, &indices)?; + buffered_batch_idx = buffered_idx.0; + indices.clear(); + } + } + output(buffered_batch_idx, &indices)?; + Ok(()) + } } /// Buffered data contains all buffered batches with one unique join key @@ -849,6 +1059,14 @@ impl BufferedData { } } +#[derive(Clone, Copy, Debug)] +struct OutputIndex { + /// joined streamed row index + streamed_idx: Option, + /// joined buffered batch index and row index + buffered_idx: Option<(usize, usize)>, +} + /// Get join array refs of given batch and join columns fn join_arrays(batch: &RecordBatch, on_column: &[Column]) -> Vec { on_column @@ -919,7 +1137,7 @@ fn compare_join_arrays( DataType::UInt16 => compare_value!(UInt16Array), DataType::UInt32 => compare_value!(UInt32Array), DataType::UInt64 => 
compare_value!(UInt64Array), - DataType::Timestamp(_, None) => compare_value!(Int64Array), + DataType::Timestamp(_, _) => compare_value!(Int64Array), DataType::Utf8 => compare_value!(StringArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), _ => { @@ -947,10 +1165,12 @@ fn is_join_arrays_equal( for (left_array, right_array) in left_arrays.iter().zip(right_arrays) { macro_rules! compare_value { ($T:ty) => {{ - let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); - let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); match (left_array.is_null(left), right_array.is_null(right)) { (false, false) => { + let left_array = + left_array.as_any().downcast_ref::<$T>().unwrap(); + let right_array = + right_array.as_any().downcast_ref::<$T>().unwrap(); if left_array.value(left) != right_array.value(right) { is_equal = false; } @@ -989,138 +1209,13 @@ fn is_join_arrays_equal( Ok(true) } -/// Create new array builders of given schema and batch size -fn new_array_builders( - schema: SchemaRef, - batch_size: usize, -) -> ArrowResult>> { - let arrays: Vec> = schema - .fields() - .iter() - .map(|field| { - let dt = field.data_type(); - make_builder(dt, batch_size) - }) - .collect(); - Ok(arrays) -} - -/// Append one row to part of output buffer (the array builders) -fn append_row_to_output( - batch: &RecordBatch, - idx: usize, - arrays: &mut [Box], -) -> ArrowResult<()> { - if !arrays.is_empty() { - return batch - .columns() - .iter() - .zip(batch.schema().fields()) - .enumerate() - .try_for_each(|(i, (column, f))| { - array_append_value(f.data_type(), &mut arrays[i], &*column, idx) - }); - } - Ok(()) -} - -/// Append one row which all values are null to part of output buffer (the -/// array builders), used in outer join -fn append_nulls_row_to_output( - schema: &Schema, - arrays: &mut [Box], -) -> ArrowResult<()> { - if !arrays.is_empty() { - return schema - .fields() - .iter() - .enumerate() - .try_for_each(|(i, f)| array_append_null(f.data_type(), &mut arrays[i])); - } - Ok(()) -} - -/// Finish output buffer and produce one record batch -fn make_batch( - schema: SchemaRef, - mut arrays: Vec>, -) -> ArrowResult { - let columns = arrays.iter_mut().map(|array| array.finish()).collect(); - RecordBatch::try_new(schema, columns) -} - -/// Append null value to a array builder -fn array_append_null( - data_type: &DataType, - to: &mut Box, -) -> ArrowResult<()> { - macro_rules! append_null { - ($TO:ty) => {{ - to.as_any_mut().downcast_mut::<$TO>().unwrap().append_null() - }}; - } - match data_type { - DataType::Boolean => append_null!(BooleanBuilder), - DataType::Int8 => append_null!(Int8Builder), - DataType::Int16 => append_null!(Int16Builder), - DataType::Int32 => append_null!(Int32Builder), - DataType::Int64 => append_null!(Int64Builder), - DataType::UInt8 => append_null!(UInt8Builder), - DataType::UInt16 => append_null!(UInt16Builder), - DataType::UInt32 => append_null!(UInt32Builder), - DataType::UInt64 => append_null!(UInt64Builder), - DataType::Float32 => append_null!(Float32Builder), - DataType::Float64 => append_null!(Float64Builder), - DataType::Utf8 => append_null!(GenericStringBuilder), - _ => todo!(), - } -} - -/// Append value to a array builder -fn array_append_value( - data_type: &DataType, - to: &mut Box, - from: &dyn Array, - idx: usize, -) -> ArrowResult<()> { - macro_rules! 
append_value { - ($TO:ty, $FROM:ty) => {{ - let to = to.as_any_mut().downcast_mut::<$TO>().unwrap(); - let from = from.as_any().downcast_ref::<$FROM>().unwrap(); - if from.is_valid(idx) { - to.append_value(from.value(idx)) - } else { - to.append_null() - } - }}; - } - - match data_type { - DataType::Boolean => append_value!(BooleanBuilder, BooleanArray), - DataType::Int8 => append_value!(Int8Builder, Int8Array), - DataType::Int16 => append_value!(Int16Builder, Int16Array), - DataType::Int32 => append_value!(Int32Builder, Int32Array), - DataType::Int64 => append_value!(Int64Builder, Int64Array), - DataType::UInt8 => append_value!(UInt8Builder, UInt8Array), - DataType::UInt16 => append_value!(UInt16Builder, UInt16Array), - DataType::UInt32 => append_value!(UInt32Builder, UInt32Array), - DataType::UInt64 => append_value!(UInt64Builder, UInt64Array), - DataType::Float32 => append_value!(Float32Builder, Float32Array), - DataType::Float64 => append_value!(Float64Builder, Float64Array), - DataType::Utf8 => { - append_value!(GenericStringBuilder, GenericStringArray) - } - _ => todo!(), - } -} - #[cfg(test)] mod tests { + use std::sync::Arc; + use arrow::array::Int32Array; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; - use std::sync::Arc; - use arrow::record_batch::RecordBatch; use crate::assert_batches_sorted_eq; @@ -1445,103 +1540,6 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); Ok(()) } - /// Test where the left has 1 part, the right has 2 parts => 2 parts - #[tokio::test] - async fn join_inner_one_two_parts_right() -> Result<()> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 5]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - - let batch1 = build_table_i32( - ("a2", &vec![10, 20]), - ("b1", &vec![4, 6]), - ("c2", &vec![70, 80]), - ); - let batch2 = - build_table_i32(("a2", &vec![30]), ("b1", &vec![5]), ("c2", &vec![90])); - let schema = batch1.schema(); - let right = Arc::new( - MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), - ); - - let on = vec![( - Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, - )]; - - let join = join(left, right, on, JoinType::Inner)?; - let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - - // first part - let stream = join.execute(0, task_ctx.clone()).await?; - let batches = common::collect(stream).await?; - assert_eq!(batches.len(), 1); - - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - - // second part - let stream = join.execute(1, task_ctx.clone()).await?; - let batches = common::collect(stream).await?; - assert_eq!(batches.len(), 1); - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+----+", - "| 2 | 5 | 8 | 30 | 5 | 90 |", - "| 3 | 5 | 9 | 30 | 5 | 90 |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - - let metrics = join.metrics().unwrap(); - assert!( - 0 < metrics - .sum(|m| m.value().name() == "join_time") - .map(|v| v.as_usize()) - .unwrap() - ); - assert_eq!( - 2, - metrics - .sum(|m| m.value().name() == "output_batches") - 
.map(|v| v.as_usize()) - .unwrap() - ); // 1+1 - assert_eq!( - 3, - metrics - .sum(|m| m.value().name() == "output_rows") - .map(|v| v.as_usize()) - .unwrap() - ); // 2+1 - assert_eq!( - 4, - metrics - .sum(|m| m.value().name() == "input_batches") - .map(|v| v.as_usize()) - .unwrap() - ); // (1+1) + (1+1) - assert_eq!( - 9, - metrics - .sum(|m| m.value().name() == "input_rows") - .map(|v| v.as_usize()) - .unwrap() - ); // (3+2) + (3+1) - Ok(()) - } #[tokio::test] async fn join_left_one() -> Result<()> { From f5f24db46c7410642cfa004f813acffa729eab7c Mon Sep 17 00:00:00 2001 From: zhangli20 Date: Wed, 20 Apr 2022 21:18:43 +0800 Subject: [PATCH 07/10] Support timestamp/decimal types in join columns --- .../core/src/physical_plan/sort_merge_join.rs | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs b/datafusion/core/src/physical_plan/sort_merge_join.rs index ffbf23b491d7b..93477dc1172c1 100644 --- a/datafusion/core/src/physical_plan/sort_merge_join.rs +++ b/datafusion/core/src/physical_plan/sort_merge_join.rs @@ -29,7 +29,7 @@ use std::task::{Context, Poll}; use arrow::array::*; use arrow::compute::{take, SortOptions}; -use arrow::datatypes::{DataType, SchemaRef}; +use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use async_trait::async_trait; @@ -208,7 +208,7 @@ impl ExecutionPlan for SortMergeJoinExec { } } -/// Metrics for SortMergeJoinExec (Not yet implemented) +/// Metrics for SortMergeJoinExec #[allow(dead_code)] struct SortMergeJoinMetrics { /// Total time for joining probe-side batches to the build-side batches @@ -1137,9 +1137,15 @@ fn compare_join_arrays( DataType::UInt16 => compare_value!(UInt16Array), DataType::UInt32 => compare_value!(UInt32Array), DataType::UInt64 => compare_value!(UInt64Array), - DataType::Timestamp(_, _) => compare_value!(Int64Array), DataType::Utf8 => compare_value!(StringArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Decimal(..) => compare_value!(DecimalArray), + DataType::Timestamp(time_unit, None) => match time_unit { + TimeUnit::Second => compare_value!(TimestampSecondArray), + TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), + TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), + TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), + }, _ => { return Err(ArrowError::NotYetImplemented( "Unsupported data type in sort merge join comparator".to_owned(), @@ -1193,9 +1199,15 @@ fn is_join_arrays_equal( DataType::UInt16 => compare_value!(UInt16Array), DataType::UInt32 => compare_value!(UInt32Array), DataType::UInt64 => compare_value!(UInt64Array), - DataType::Timestamp(_, None) => compare_value!(Int64Array), DataType::Utf8 => compare_value!(StringArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Decimal(..) 
=> compare_value!(DecimalArray), + DataType::Timestamp(time_unit, None) => match time_unit { + TimeUnit::Second => compare_value!(TimestampSecondArray), + TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), + TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), + TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), + }, _ => { return Err(ArrowError::NotYetImplemented( "Unsupported data type in sort merge join comparator".to_owned(), From a0fe903d8288b9365b84c01726731ed6949745c6 Mon Sep 17 00:00:00 2001 From: zhangli20 Date: Thu, 21 Apr 2022 02:13:38 +0800 Subject: [PATCH 08/10] Add fuzz test and fix edge cases --- .../core/src/physical_plan/sort_merge_join.rs | 237 +++++++++--------- datafusion/core/tests/join_fuzz.rs | 223 ++++++++++++++++ 2 files changed, 344 insertions(+), 116 deletions(-) create mode 100644 datafusion/core/tests/join_fuzz.rs diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs b/datafusion/core/src/physical_plan/sort_merge_join.rs index 93477dc1172c1..ba102b069a664 100644 --- a/datafusion/core/src/physical_plan/sort_merge_join.rs +++ b/datafusion/core/src/physical_plan/sort_merge_join.rs @@ -286,6 +286,7 @@ enum BufferedState { } /// A buffered batch that contains contiguous rows with same join key +#[derive(Debug)] struct BufferedBatch { /// The buffered record batch pub batch: RecordBatch, @@ -375,7 +376,6 @@ impl Stream for SMJStream { loop { match &self.state { SMJState::Init => { - self.buffered_data.scanning_reset(); let streamed_exhausted = self.streamed_state == StreamedState::Exhausted; let buffered_exhausted = @@ -404,12 +404,8 @@ impl Stream for SMJStream { if ![StreamedState::Exhausted, StreamedState::Ready] .contains(&self.streamed_state) { - match self.poll_streamed_row(cx) { - Poll::Ready(Some(Ok(()))) => {} - Poll::Ready(Some(Err(e))) => { - return Poll::Ready(Some(Err(e))) - } - Poll::Ready(None) => {} + match self.poll_streamed_row(cx)? { + Poll::Ready(_) => {} Poll::Pending => return Poll::Pending, } } @@ -417,12 +413,8 @@ impl Stream for SMJStream { if ![BufferedState::Exhausted, BufferedState::Ready] .contains(&self.buffered_state) { - match self.poll_buffered_batches(cx) { - Poll::Ready(Some(Ok(()))) => {} - Poll::Ready(Some(Err(e))) => { - return Poll::Ready(Some(Err(e))) - } - Poll::Ready(None) => {} + match self.poll_buffered_batches(cx)? 
{ + Poll::Ready(_) => {} Poll::Pending => return Poll::Pending, } } @@ -439,21 +431,19 @@ impl Stream for SMJStream { } SMJState::JoinOutput => { let output_indices = self.join_partial()?; - self.output_partial(&output_indices)?; + if !output_indices.is_empty() { + self.output_partial(&output_indices)?; + } - if self.output_size == self.batch_size { + if self.output_size < self.batch_size { + if self.buffered_data.scanning_finished() { + self.buffered_data.scanning_reset(); + self.state = SMJState::Init; + } + } else { let record_batch = self.output_record_batch_and_reset()?; return Poll::Ready(Some(Ok(record_batch))); } - if self.buffered_data.scanning_finished() { - if self.current_ordering.is_le() { - self.streamed_joined = true; - } - if self.current_ordering.is_ge() { - self.buffered_joined = true; - } - self.state = SMJState::Init; - } } SMJState::Exhausted => { if !self.output_record_batches.is_empty() { @@ -607,9 +597,9 @@ impl SMJStream { < self.buffered_data.tail_batch().batch.num_rows() { if is_join_arrays_equal( - self.buffered_data.head_batch().batch.columns(), + &self.buffered_data.head_batch().join_arrays, self.buffered_data.head_batch().range.start, - self.buffered_data.tail_batch().batch.columns(), + &self.buffered_data.tail_batch().join_arrays, self.buffered_data.tail_batch().range.end, )? { self.buffered_data.tail_batch_mut().range.end += 1; @@ -628,12 +618,16 @@ impl SMJStream { } Poll::Ready(Some(batch)) => { self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - self.buffered_data.batches.push_back(BufferedBatch::new( - batch, - 0..0, - &self.on_buffered, - )); + if batch.num_rows() > 0 { + self.join_metrics.input_rows.add(batch.num_rows()); + self.buffered_data.batches.push_back( + BufferedBatch::new( + batch, + 0..0, + &self.on_buffered, + ), + ); + } } } } @@ -670,104 +664,79 @@ impl SMJStream { /// Produce join and fill output buffer until reaching target batch size /// or the join is finished fn join_partial(&mut self) -> ArrowResult> { - let mut output_indices = vec![]; + let mut join_streamed = false; + let mut join_buffered = false; + + // determine whether we need to join streamed/buffered rows match self.current_ordering { Ordering::Less => { - let output_streamed_join = match self.join_type { - JoinType::Inner | JoinType::Semi => false, - JoinType::Left - | JoinType::Right - | JoinType::Full - | JoinType::Anti => !self.streamed_joined, - }; - - // streamed joins null - if output_streamed_join { - output_indices.push(OutputIndex { - streamed_idx: Some(self.streamed_idx), - buffered_idx: None, - }); - self.output_size += 1; + if matches!( + self.join_type, + JoinType::Left | JoinType::Right | JoinType::Full | JoinType::Anti + ) { + join_streamed = !self.streamed_joined; } - self.buffered_data.scanning_finish(); } Ordering::Equal => { - let output_equal_join = match self.join_type { - JoinType::Inner - | JoinType::Left - | JoinType::Right - | JoinType::Full - | JoinType::Semi => true, - JoinType::Anti => false, - }; - - // streamed joins buffered - if output_equal_join { - if JoinType::Semi == self.join_type { - if !self.streamed_joined { - output_indices.push(OutputIndex { - streamed_idx: Some(self.streamed_idx), - buffered_idx: None, - }); - self.output_size += 1; - } - self.buffered_data.scanning_finish(); - } - } else { - self.buffered_data.scanning_finish(); + if matches!(self.join_type, JoinType::Semi) { + join_streamed = !self.streamed_joined; } + if matches!( + self.join_type, + JoinType::Inner | JoinType::Left 
| JoinType::Right | JoinType::Full + ) { + join_streamed = true; + join_buffered = true; + }; } Ordering::Greater => { - let output_buffered_join = match self.join_type { - JoinType::Inner - | JoinType::Left - | JoinType::Right - | JoinType::Anti - | JoinType::Semi => false, - JoinType::Full => !self.buffered_joined, + if matches!(self.join_type, JoinType::Full) { + join_buffered = !self.buffered_joined; }; - - // null joins buffered - if output_buffered_join { - if JoinType::Anti == self.join_type { - if !self.streamed_joined { - output_indices.push(OutputIndex { - streamed_idx: Some(self.streamed_idx), - buffered_idx: None, - }); - self.output_size += 1; - } - self.buffered_data.scanning_finish(); - } - } else { - self.buffered_data.scanning_finish(); - } } } + if !join_streamed && !join_buffered { + // no joined data + self.buffered_data.scanning_finish(); + return Ok(vec![]); + } - // scan buffered stream and write to output buffer - while !self.buffered_data.scanning_finished() - && self.output_size < self.batch_size - { - if self.current_ordering == Ordering::Equal { - output_indices.push(OutputIndex { - streamed_idx: Some(self.streamed_idx), - buffered_idx: Some(( - self.buffered_data.scanning_batch_idx, - self.buffered_data.scanning_idx(), - )), - }); + let mut output_indices = vec![]; + + if join_buffered { + // joining streamed/nulls and buffered + let streamed_idx = if join_streamed { + Some(self.streamed_idx) } else { + None + }; + while !self.buffered_data.scanning_finished() + && self.output_size < self.batch_size + { output_indices.push(OutputIndex { - streamed_idx: None, + streamed_idx, buffered_idx: Some(( self.buffered_data.scanning_batch_idx, self.buffered_data.scanning_idx(), )), }); + self.output_size += 1; + self.buffered_data.scanning_advance(); + + if self.buffered_data.scanning_finished() { + self.streamed_joined = join_streamed; + self.buffered_joined = true; + } } + } else { + // joining streamed and nulls + output_indices.push(OutputIndex { + streamed_idx: Some(self.streamed_idx), + buffered_idx: None, + }); self.output_size += 1; - self.buffered_data.scanning_advance(); + self.buffered_data.scanning_finish(); + self.streamed_joined = true; } Ok(output_indices) } @@ -867,13 +836,12 @@ impl SMJStream { .filter(|index| index.buffered_idx.is_some()) { let buffered_idx = index.buffered_idx.unwrap(); - if buffered_idx.0 == buffered_batch_idx { - indices.push(index); - } else { + if index.buffered_idx.unwrap().0 != buffered_batch_idx { output(buffered_batch_idx, &indices)?; buffered_batch_idx = buffered_idx.0; indices.clear(); } + indices.push(index); } output(buffered_batch_idx, &indices)?; Ok(()) @@ -984,13 +952,12 @@ impl SMJStream { .filter(|index| index.buffered_idx.is_some()) { let buffered_idx = index.buffered_idx.unwrap(); - if buffered_idx.0 == buffered_batch_idx { - indices.push(index); - } else { + if buffered_idx.0 != buffered_batch_idx { output(buffered_batch_idx, &indices)?; buffered_batch_idx = buffered_idx.0; indices.clear(); } + indices.push(index); } output(buffered_batch_idx, &indices)?; Ok(()) @@ -998,7 +965,7 @@ impl SMJStream { } /// Buffered data contains all buffered batches with one unique join key -#[derive(Default)] +#[derive(Debug, Default)] struct BufferedData { /// Buffered batches with the same key pub batches: VecDeque, @@ -1424,6 +1391,44 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_inner_two_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 1, 2]), + ("b2", &vec![1, 1, 2]), + ("c1", &vec![7, 8, 9]), + ); + 
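+        // the key (1, 1) appears twice on each side below, so the single
+        // equal-key group joins 2 x 2 = 4 output rows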
let right = build_table( + ("a1", &vec![1, 1, 3]), + ("b2", &vec![1, 1, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + ), + ]; + + let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 1 | 1 | 7 | 1 | 1 | 80 |", + "| 1 | 1 | 8 | 1 | 1 | 70 |", + "| 1 | 1 | 8 | 1 | 1 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn join_inner_with_nulls() -> Result<()> { let left = build_table_i32_nullable( diff --git a/datafusion/core/tests/join_fuzz.rs b/datafusion/core/tests/join_fuzz.rs new file mode 100644 index 0000000000000..503dd94223bfb --- /dev/null +++ b/datafusion/core/tests/join_fuzz.rs @@ -0,0 +1,223 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
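+
+//! Fuzz test for sort-merge join: identical randomly generated, pre-sorted
+//! inputs are joined with both `SortMergeJoinExec` and `HashJoinExec` across
+//! a range of batch sizes, and the two outputs are compared row by row.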
+ +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array}; +use arrow::compute::SortOptions; +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::pretty_format_batches; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +use datafusion::logical_plan::JoinType; +use datafusion::physical_plan::collect; +use datafusion::physical_plan::expressions::Column; +use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode}; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::sort_merge_join::SortMergeJoinExec; + +use datafusion::prelude::{SessionConfig, SessionContext}; +use fuzz_utils::{add_empty_batches}; + +#[tokio::test] +async fn test_inner_join_1k() { + run_join_test( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Inner, + ) + .await +} + +#[tokio::test] +async fn test_left_join_1k() { + run_join_test( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Left, + ) + .await +} + +#[tokio::test] +async fn test_right_join_1k() { + run_join_test( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Right, + ) + .await +} + +#[tokio::test] +async fn test_full_join_1k() { + run_join_test( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Full, + ) + .await +} + +#[tokio::test] +async fn test_semi_join_1k() { + run_join_test( + make_staggered_batches(10000), + make_staggered_batches(10000), + JoinType::Semi, + ) + .await +} + +#[tokio::test] +async fn test_anti_join_1k() { + run_join_test( + make_staggered_batches(10000), + make_staggered_batches(10000), + JoinType::Anti, + ) + .await +} + +/// Perform sort-merge join and hash join on same input +/// and verify two outputs are equal +async fn run_join_test( + input1: Vec, + input2: Vec, + join_type: JoinType, +) { + let batch_sizes = [1, 2, 7, 49, 50, 51, 100]; + for batch_size in batch_sizes { + let session_config = SessionConfig::new().with_batch_size(batch_size); + let ctx = SessionContext::with_config(session_config); + let task_ctx = ctx.task_ctx(); + + let schema1 = input1[0].schema(); + let schema2 = input2[0].schema(); + let on_columns = vec![ + ( + Column::new_with_schema("a", &schema1).unwrap(), + Column::new_with_schema("a", &schema2).unwrap(), + ), + ( + Column::new_with_schema("b", &schema1).unwrap(), + Column::new_with_schema("b", &schema2).unwrap(), + ), + ]; + + // sort-merge join + let left = Arc::new( + MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(), + ); + let right = Arc::new( + MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(), + ); + let smj = Arc::new( + SortMergeJoinExec::try_new( + left, + right, + on_columns.clone(), + join_type, + vec![SortOptions::default(), SortOptions::default()], + false, + ) + .unwrap(), + ); + let smj_collected = collect(smj, task_ctx.clone()).await.unwrap(); + + // hash join + let left = Arc::new( + MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(), + ); + let right = Arc::new( + MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(), + ); + let hj = Arc::new( + HashJoinExec::try_new( + left, + right, + on_columns.clone(), + &join_type, + PartitionMode::Partitioned, + &false, + ) + .unwrap(), + ); + let hj_collected = collect(hj, task_ctx.clone()).await.unwrap(); + + // compare + let smj_formatted = pretty_format_batches(&smj_collected).unwrap().to_string(); + let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string(); + + let mut 
smj_formatted_sorted: Vec<&str> = smj_formatted.trim().lines().collect(); + smj_formatted_sorted.sort_unstable(); + + let mut hj_formatted_sorted: Vec<&str> = hj_formatted.trim().lines().collect(); + hj_formatted_sorted.sort_unstable(); + + for (i, (smj_line, hj_line)) in smj_formatted_sorted + .iter() + .zip(&hj_formatted_sorted) + .enumerate() + { + assert_eq!((i, smj_line), (i, hj_line)); + } + } +} + +/// Return randomly sized record batches with: +/// two sorted int32 columns 'a', 'b' ranged from 0..99 as join columns +/// two random int32 columns 'x', 'y' as other columns +fn make_staggered_batches(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut input12: Vec<(i32, i32)> = vec![(0, 0); len]; + let mut input3: Vec = vec![0; len]; + let mut input4: Vec = vec![0; len]; + input12 + .iter_mut() + .for_each(|v| *v = (rng.gen_range(0..100), rng.gen_range(0..100))); + rng.fill(&mut input3[..]); + rng.fill(&mut input4[..]); + input12.sort_unstable(); + let input1 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.0)); + let input2 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.1)); + let input3 = Int32Array::from_iter_values(input3.into_iter()); + let input4 = Int32Array::from_iter_values(input4.into_iter()); + + // split into several record batches + let mut remainder = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(input1) as ArrayRef), + ("b", Arc::new(input2) as ArrayRef), + ("x", Arc::new(input3) as ArrayRef), + ("y", Arc::new(input4) as ArrayRef), + ]) + .unwrap(); + + let mut batches = vec![]; + + // use a random number generator to pick a random sized output + let mut rng = StdRng::seed_from_u64(42); + while remainder.num_rows() > 0 { + let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + + batches.push(remainder.slice(0, batch_size)); + remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); + } + + add_empty_batches(batches, &mut rng) +} From 4a39c9c35dabc90b6ce4a1ab8a97d98cf6d08b16 Mon Sep 17 00:00:00 2001 From: zhangli20 Date: Thu, 21 Apr 2022 21:46:59 +0800 Subject: [PATCH 09/10] Support float32/64 data types in comparison --- datafusion/core/src/physical_plan/sort_merge_join.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs b/datafusion/core/src/physical_plan/sort_merge_join.rs index ba102b069a664..9c769bd961654 100644 --- a/datafusion/core/src/physical_plan/sort_merge_join.rs +++ b/datafusion/core/src/physical_plan/sort_merge_join.rs @@ -1104,6 +1104,8 @@ fn compare_join_arrays( DataType::UInt16 => compare_value!(UInt16Array), DataType::UInt32 => compare_value!(UInt32Array), DataType::UInt64 => compare_value!(UInt64Array), + DataType::Float32 => compare_value!(Float32Array), + DataType::Float64 => compare_value!(Float64Array), DataType::Utf8 => compare_value!(StringArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), DataType::Decimal(..) => compare_value!(DecimalArray), @@ -1166,6 +1168,8 @@ fn is_join_arrays_equal( DataType::UInt16 => compare_value!(UInt16Array), DataType::UInt32 => compare_value!(UInt32Array), DataType::UInt64 => compare_value!(UInt64Array), + DataType::Float32 => compare_value!(Float32Array), + DataType::Float64 => compare_value!(Float64Array), DataType::Utf8 => compare_value!(StringArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), DataType::Decimal(..) 
=> compare_value!(DecimalArray), From 2aa915163d68ee66ddd86443982826867a0dfde0 Mon Sep 17 00:00:00 2001 From: zhangli20 Date: Fri, 22 Apr 2022 10:24:35 +0800 Subject: [PATCH 10/10] Fix lint issues --- datafusion/core/tests/join_fuzz.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/join_fuzz.rs b/datafusion/core/tests/join_fuzz.rs index 503dd94223bfb..2f47356236b69 100644 --- a/datafusion/core/tests/join_fuzz.rs +++ b/datafusion/core/tests/join_fuzz.rs @@ -32,7 +32,7 @@ use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sort_merge_join::SortMergeJoinExec; use datafusion::prelude::{SessionConfig, SessionContext}; -use fuzz_utils::{add_empty_batches}; +use fuzz_utils::add_empty_batches; #[tokio::test] async fn test_inner_join_1k() {
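For reference, the indices-based output path introduced in [PATCH 06/10] reduces to gathering rows by position with arrow's `take` kernel. A minimal sketch of that technique, using only calls that already appear in the series (`gather_rows` itself is a hypothetical helper name, not part of the patches):

use arrow::array::{ArrayRef, UInt64Array};
use arrow::compute::take;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;

/// Copy the given row positions of `batch` into freshly allocated columns,
/// the way the output_partial_*() methods turn OutputIndex lists into
/// output record batches.
fn gather_rows(batch: &RecordBatch, rows: &[usize]) -> ArrowResult<Vec<ArrayRef>> {
    // take() gathers the selected rows in one vectorized pass per column
    let indices = UInt64Array::from_iter_values(rows.iter().map(|&i| i as u64));
    batch
        .columns()
        .iter()
        .map(|column| take(column, &indices, None))
        .collect()
}

Columns gathered this way from the streamed and buffered sides are then concatenated in the column order dictated by the join type, as output_partial_streamed_joining_buffered() does; this replaces the per-cell, type-dispatched append path (array_append_value and friends) that [PATCH 06/10] removed.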