From d87d24df879fa1335a10e904aa3196ff8d9b588d Mon Sep 17 00:00:00 2001 From: Jonathan Date: Tue, 9 Sep 2025 00:02:04 -0400 Subject: [PATCH 01/20] POC: `ClassicJoin` for PWMJ --- datafusion/core/src/physical_planner.rs | 138 +- .../src/joins/hash_join/stream.rs | 1 + datafusion/physical-plan/src/joins/mod.rs | 2 + .../piecewise_merge_join/classic_join.rs | 1494 +++++++++++++++++ .../src/joins/piecewise_merge_join/exec.rs | 730 ++++++++ .../src/joins/piecewise_merge_join/mod.rs | 22 + .../src/joins/piecewise_merge_join/utils.rs | 61 + .../src/joins/sort_merge_join/stream.rs | 91 +- datafusion/physical-plan/src/joins/utils.rs | 120 +- datafusion/sqllogictest/test_files/joins.slt | 71 +- datafusion/sqllogictest/test_files/pwmj.slt | 141 ++ 11 files changed, 2761 insertions(+), 110 deletions(-) create mode 100644 datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs create mode 100644 datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs create mode 100644 datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs create mode 100644 datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs create mode 100644 datafusion/sqllogictest/test_files/pwmj.slt diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 413b809c4e6ab..490fcb820141f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -77,10 +77,11 @@ use datafusion_expr::expr::{ }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; +use datafusion_expr::utils::split_conjunction; use datafusion_expr::{ - Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType, - Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame, - WindowFrameBound, WriteOp, + Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, + FetchType, Filter, JoinType, Operator, RecursiveQuery, SkipType, StringifiedPlan, + WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::{Column, Literal}; @@ -90,6 +91,7 @@ use datafusion_physical_expr::{ use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::execution_plan::InvariantLevel; +use datafusion_physical_plan::joins::PiecewiseMergeJoinExec; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::recursive_query::RecursiveQueryExec; use datafusion_physical_plan::unnest::ListUnnest; @@ -1089,8 +1091,42 @@ impl DefaultPhysicalPlanner { }) .collect::>()?; + // TODO: `num_range_filters` can be used later on for ASOF joins (`num_range_filters > 1`) + let mut num_range_filters = 0; + let mut range_filters: Vec = Vec::new(); + let mut total_filters = 0; + let join_filter = match filter { Some(expr) => { + let split_expr = split_conjunction(expr); + for expr in split_expr.iter() { + match *expr { + Expr::BinaryExpr(BinaryExpr { + left: _, + right: _, + op, + }) => { + if matches!( + op, + Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + ) { + range_filters.push((**expr).clone()); + num_range_filters += 1; + } + total_filters += 1; + } + // TODO: Want to deal with `Expr::Between` for IEJoins, it counts as two range predicates + // which is why it is not dealt with in PWMJ + // Expr::Between(_) 
=> {}, + _ => { + total_filters += 1; + } + } + } + // Extract columns from filter expression and saved in a HashSet let cols = expr.column_refs(); @@ -1146,6 +1182,7 @@ impl DefaultPhysicalPlanner { )?; let filter_schema = Schema::new_with_metadata(filter_fields, metadata); + let filter_expr = create_physical_expr( expr, &filter_df_schema, @@ -1168,10 +1205,105 @@ impl DefaultPhysicalPlanner { let prefer_hash_join = session_state.config_options().optimizer.prefer_hash_join; + let cfg = session_state.config(); + + let can_run_single = + cfg.target_partitions() == 1 || !cfg.repartition_joins(); + + // TODO: Allow PWMJ to deal with residual equijoin conditions let join: Arc = if join_on.is_empty() { if join_filter.is_none() && matches!(join_type, JoinType::Inner) { // cross join if there is no join conditions and no join filter set Arc::new(CrossJoinExec::new(physical_left, physical_right)) + } else if num_range_filters == 1 + && total_filters == 1 + && can_run_single + && !matches!( + join_type, + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark + ) + { + let Expr::BinaryExpr(be) = &range_filters[0] else { + return plan_err!( + "Unsupported expression for PWMJ: Expected `Expr::BinaryExpr`" + ); + }; + + let mut op = be.op; + if !matches!( + op, + Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq + ) { + return plan_err!( + "Unsupported operator for PWMJ: {:?}. Expected one of <, <=, >, >=", + op + ); + } + + fn reverse_ineq(op: Operator) -> Operator { + match op { + Operator::Lt => Operator::Gt, + Operator::LtEq => Operator::GtEq, + Operator::Gt => Operator::Lt, + Operator::GtEq => Operator::LtEq, + _ => op, + } + } + + let side_of = |e: &Expr| -> Result<&'static str> { + let cols = e.column_refs(); + let in_left = cols + .iter() + .all(|c| left_df_schema.index_of_column(c).is_ok()); + let in_right = cols + .iter() + .all(|c| right_df_schema.index_of_column(c).is_ok()); + match (in_left, in_right) { + (true, false) => Ok("left"), + (false, true) => Ok("right"), + _ => unreachable!(), + } + }; + + let mut lhs_logical = &be.left; + let mut rhs_logical = &be.right; + + let left_side = side_of(lhs_logical)?; + let right_side = side_of(rhs_logical)?; + if left_side == "right" && right_side == "left" { + std::mem::swap(&mut lhs_logical, &mut rhs_logical); + op = reverse_ineq(op); + } else if !(left_side == "left" && right_side == "right") { + return plan_err!( + "Unsupported operator for PWMJ: {:?}. Expected one of <, <=, >, >=", + op + ); + } + + let on_left = create_physical_expr( + lhs_logical, + &left_df_schema, + session_state.execution_props(), + )?; + let on_right = create_physical_expr( + rhs_logical, + &right_df_schema, + session_state.execution_props(), + )?; + + + Arc::new(PiecewiseMergeJoinExec::try_new( + physical_left, + physical_right, + (on_left, on_right), + op, + *join_type, + )?) 
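+                    // Note: this `PiecewiseMergeJoinExec` branch is only reached when the join has
+                    // no equijoin keys, the filter consists of exactly one conjunct, that conjunct
+                    // is a single `<` / `<=` / `>` / `>=` comparison whose operands reference only
+                    // the left and only the right input respectively, the join type is not a
+                    // semi/anti/mark join, and the plan can run single-partitioned; other shapes
+                    // fall back to the `NestedLoopJoinExec` branch below.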
} else { // there is no equal join condition, use the nested loop join Arc::new(NestedLoopJoinExec::try_new( diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index d368a9cf8ee29..0adb9b7a69cbe 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -589,6 +589,7 @@ impl HashJoinStream { let (left_side, right_side) = get_final_indices_from_shared_bitmap( build_side.left_data.visited_indices_bitmap(), self.join_type, + true, ); let empty_right_batch = RecordBatch::new_empty(self.right.schema()); // use the left and right indices to produce the batch result diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 1d36db996434e..b0c28cf994f71 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -24,11 +24,13 @@ pub use hash_join::HashJoinExec; pub use nested_loop_join::NestedLoopJoinExec; use parking_lot::Mutex; // Note: SortMergeJoin is not used in plans yet +pub use piecewise_merge_join::PiecewiseMergeJoinExec; pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; mod hash_join; mod nested_loop_join; +mod piecewise_merge_join; mod sort_merge_join; mod stream_join_utils; mod symmetric_hash_join; diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs new file mode 100644 index 0000000000000..cde31a7f5df1b --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -0,0 +1,1494 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Stream Implementation for PiecewiseMergeJoin's Classic Join (Left, Right, Full, Inner) + +use arrow::array::{ + new_null_array, Array, PrimitiveArray, PrimitiveBuilder, RecordBatchOptions, +}; +use arrow::compute::take; +use arrow::datatypes::{UInt32Type, UInt64Type}; +use arrow::{ + array::{ + ArrayRef, RecordBatch, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder, + }, + compute::{sort_to_indices, take_record_batch}, +}; +use arrow_schema::{ArrowError, Schema, SchemaRef, SortOptions}; +use datafusion_common::NullEquality; +use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::PhysicalExprRef; +use futures::{Stream, StreamExt}; +use log::debug; +use std::{cmp::Ordering, task::ready}; +use std::{sync::Arc, task::Poll}; + +use crate::handle_state; +use crate::joins::piecewise_merge_join::exec::{ + BufferedSide, BufferedSideReadyState, +}; +use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final; +use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap}; +use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult}; +pub(super) enum PiecewiseMergeJoinStreamState { + WaitBufferedSide, + FetchStreamBatch, + ProcessStreamBatch(StreamedBatch), + ExhaustedStreamSide, + Completed, +} + +impl PiecewiseMergeJoinStreamState { + // Grab mutable reference to the current stream batch + fn try_as_process_stream_batch_mut(&mut self) -> Result<&mut StreamedBatch> { + match self { + PiecewiseMergeJoinStreamState::ProcessStreamBatch(state) => Ok(state), + _ => internal_err!("Expected streamed batch in StreamBatch"), + } + } +} + +pub(super) struct StreamedBatch { + pub batch: RecordBatch, + values: Vec, +} + +impl StreamedBatch { + #[allow(dead_code)] + fn new(batch: RecordBatch, values: Vec) -> Self { + Self { batch, values } + } + + fn values(&self) -> &Vec { + &self.values + } +} + +pub(super) struct ClassicPWMJStream { + // Output schema of the `PiecewiseMergeJoin` + pub schema: Arc, + + // Physical expression that is evaluated on the streamed side + // We do not need on_buffered as this is already evaluated when + // creating the buffered side which happens before initializing + // `PiecewiseMergeJoinStream` + pub on_streamed: PhysicalExprRef, + // Type of join + pub join_type: JoinType, + // Comparison operator + pub operator: Operator, + // Streamed batch + pub streamed: SendableRecordBatchStream, + // Streamed schema + streamed_schema: SchemaRef, + // Buffered side data + buffered_side: BufferedSide, + // Tracks the state of the `PiecewiseMergeJoin` + state: PiecewiseMergeJoinStreamState, + // Sort option for buffered and streamed side (specifies whether + // the sort is ascending or descending) + sort_option: SortOptions, + // Metrics for build + probe joins + join_metrics: BuildProbeJoinMetrics, + // Tracking incremental state for emitting record batches + batch_process_state: BatchProcessState, + // Creates batch size + batch_size: usize, +} + +impl RecordBatchStream for ClassicPWMJStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +// `PiecewiseMergeJoinStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`, +// `ProcessStreamBatch`, `ExhaustedStreamSide` and `Completed`. +// +// Classic Joins +// 1. `WaitBufferedSide` - Load in the buffered side data into memory. +// 2. `FetchStreamBatch` - Fetch + sort incoming stream batches. 
We switch the state to +// `ExhaustedStreamBatch` once stream batches are exhausted. +// 3. `ProcessStreamBatch` - Compare stream batch row values against the buffered side data. +// 4. `ExhaustedStreamBatch` - If the join type is Left or Inner we will return state as +// `Completed` however for Full and Right we will need to process the matched/unmatched rows. +impl ClassicPWMJStream { + // Creates a new `PiecewiseMergeJoinStream` instance + #[allow(clippy::too_many_arguments)] + pub fn try_new( + schema: Arc, + on_streamed: PhysicalExprRef, + join_type: JoinType, + operator: Operator, + streamed: SendableRecordBatchStream, + buffered_side: BufferedSide, + state: PiecewiseMergeJoinStreamState, + sort_option: SortOptions, + join_metrics: BuildProbeJoinMetrics, + batch_size: usize, + ) -> Self { + let streamed_schema = streamed.schema(); + Self { + schema, + on_streamed, + join_type, + operator, + streamed_schema, + streamed, + buffered_side, + state, + sort_option, + join_metrics, + batch_process_state: BatchProcessState::new(), + batch_size, + } + } + + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + return match self.state { + PiecewiseMergeJoinStreamState::WaitBufferedSide => { + handle_state!(ready!(self.collect_buffered_side(cx))) + } + PiecewiseMergeJoinStreamState::FetchStreamBatch => { + handle_state!(ready!(self.fetch_stream_batch(cx))) + } + PiecewiseMergeJoinStreamState::ProcessStreamBatch(_) => { + handle_state!(self.process_stream_batch()) + } + PiecewiseMergeJoinStreamState::ExhaustedStreamSide => { + handle_state!(self.process_unmatched_buffered_batch()) + } + PiecewiseMergeJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + // Collects buffered side data + fn collect_buffered_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + let build_timer = self.join_metrics.build_time.timer(); + let buffered_data = ready!(self + .buffered_side + .try_as_initial_mut()? + .buffered_fut + .get_shared(cx))?; + build_timer.done(); + + // We will start fetching stream batches for classic joins + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + + self.buffered_side = + BufferedSide::Ready(BufferedSideReadyState { buffered_data }); + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + // Fetches incoming stream batches + fn fetch_stream_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.streamed.poll_next_unpin(cx)) { + None => { + self.state = PiecewiseMergeJoinStreamState::ExhaustedStreamSide; + } + Some(Ok(batch)) => { + // Evaluate the streamed physical expression on the stream batch + let stream_values: ArrayRef = self + .on_streamed + .evaluate(&batch)? + .into_array(batch.num_rows())?; + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + // Sort stream values and change the streamed record batch accordingly + let indices = sort_to_indices( + stream_values.as_ref(), + Some(self.sort_option), + None, + )?; + let stream_batch = take_record_batch(&batch, &indices)?; + let stream_values = take(stream_values.as_ref(), &indices, None)?; + + self.state = + PiecewiseMergeJoinStreamState::ProcessStreamBatch(StreamedBatch { + batch: stream_batch, + values: vec![stream_values], + }); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + // Only classic join will call. 
This function will process stream batches and evaluate against + // the buffered side data. + fn process_stream_batch( + &mut self, + ) -> Result>> { + let buffered_side = self.buffered_side.try_as_ready_mut()?; + let stream_batch = self.state.try_as_process_stream_batch_mut()?; + + let batch = resolve_classic_join( + buffered_side, + stream_batch, + Arc::clone(&self.schema), + self.operator, + self.sort_option, + self.join_type, + &mut self.batch_process_state, + self.batch_size, + )?; + + if self.batch_process_state.continue_process { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + Ok(StatefulStreamResult::Ready(Some(batch))) + } + + // Process remaining unmatched rows + fn process_unmatched_buffered_batch( + &mut self, + ) -> Result>> { + // Return early for `JoinType::Right` and `JoinType::Inner` + if matches!(self.join_type, JoinType::Right | JoinType::Inner) { + self.state = PiecewiseMergeJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(None)); + } + + let timer = self.join_metrics.join_time.timer(); + + let buffered_data = + Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data); + + // Check if the same batch needs to be checked for values again + if let Some(start_idx) = self.batch_process_state.process_rest { + if let Some(buffered_indices) = &self.batch_process_state.buffered_indices { + let remaining = buffered_indices.len() - start_idx; + + // Branch into this and return value if there are more rows to deal with + if remaining > self.batch_size { + let buffered_batch = buffered_data.batch(); + let empty_stream_batch = + RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); + + let buffered_chunk_ref = buffered_indices + .slice(start_idx, self.batch_size); + let new_buffered_indices = buffered_chunk_ref + .as_any() + .downcast_ref::() + .expect("downcast to UInt64Array after slice"); + + let streamed_indices: UInt32Array = + (0..new_buffered_indices.len() as u32).collect(); + + let batch = build_matched_indices( + Arc::clone(&self.schema), + &empty_stream_batch, + buffered_batch, + streamed_indices, + new_buffered_indices.clone(), + )?; + + self.batch_process_state.set_process_rest(Some( + start_idx + self.batch_size, + )); + self.batch_process_state.continue_process = true; + + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + let buffered_batch = buffered_data.batch(); + let empty_stream_batch = + RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); + + let buffered_chunk_ref = buffered_indices.slice(start_idx, remaining); + let new_buffered_indices = buffered_chunk_ref + .as_any() + .downcast_ref::() + .expect("downcast to UInt64Array after slice"); + + let streamed_indices: UInt32Array = + (0..new_buffered_indices.len() as u32).collect(); + + let batch = build_matched_indices( + Arc::clone(&self.schema), + &empty_stream_batch, + buffered_batch, + streamed_indices, + new_buffered_indices.clone(), + )?; + + self.batch_process_state.reset(); + + timer.done(); + self.join_metrics.output_batches.add(1); + self.state = PiecewiseMergeJoinStreamState::Completed; + + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + return exec_err!("Batch process state should hold buffered indices"); + } + + // Pass in piecewise flag to allow Right Semi/Anti/Mark joins to also be processed + let (buffered_indices, streamed_indices) = get_final_indices_from_shared_bitmap( + &buffered_data.visited_indices_bitmap, + self.join_type, + true, + ); + + // If the 
output indices is larger than the limit for the incremental batching then + // proceed to outputting all matches up to that index, return batch, and the matching + // will start next on the updated index (`process_rest`) + if buffered_indices.len() > self.batch_size { + let buffered_batch = buffered_data.batch(); + let empty_stream_batch = + RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); + + let indices_chunk_ref = buffered_indices.slice( + self.batch_process_state.start_idx, + self.batch_size, + ); + + let indices_chunk = indices_chunk_ref + .as_any() + .downcast_ref::() + .expect("downcast to UInt64Array after slice"); + + let batch = build_matched_indices( + Arc::clone(&self.schema), + &empty_stream_batch, + buffered_batch, + streamed_indices, + indices_chunk.clone(), + )?; + + self.batch_process_state.buffered_indices = Some(buffered_indices); + self.batch_process_state + .set_process_rest(Some(self.batch_size)); + self.batch_process_state.continue_process = true; + + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + let buffered_batch = buffered_data.batch(); + let empty_stream_batch = + RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); + + let batch = build_matched_indices( + Arc::clone(&self.schema), + &empty_stream_batch, + buffered_batch, + streamed_indices, + buffered_indices, + )?; + + timer.done(); + self.join_metrics.output_batches.add(1); + self.state = PiecewiseMergeJoinStreamState::Completed; + + Ok(StatefulStreamResult::Ready(Some(batch))) + } +} + +// Holds all information for processing incremental output +struct BatchProcessState { + // Used to pick up from the last index on the stream side + start_idx: usize, + // Used to pick up from the last index on the buffered side + pivot: usize, + // Tracks the number of rows processed; default starts at 0 + num_rows: usize, + // Processes the rest of the batch + process_rest: Option, + // Used to skip fully processing the row + not_found: bool, + // Signals whether to call `ProcessStreamBatch` again + continue_process: bool, + // Holding the buffered indices when processing the remaining marked rows. + buffered_indices: Option>, +} + +impl BatchProcessState { + pub fn new() -> Self { + Self { + start_idx: 0, + num_rows: 0, + pivot: 0, + process_rest: None, + not_found: false, + continue_process: false, + buffered_indices: None, + } + } + + fn reset(&mut self) { + self.start_idx = 0; + self.num_rows = 0; + self.pivot = 0; + self.process_rest = None; + self.not_found = false; + self.continue_process = false; + self.buffered_indices = None; + } + + fn pivot(&self) -> usize { + self.pivot + } + + fn set_pivot(&mut self, pivot: usize) { + self.pivot = pivot; + } + + fn set_start_idx(&mut self, start_idx: usize) { + self.start_idx = start_idx; + } + + fn set_rows(&mut self, num_rows: usize) { + self.num_rows = num_rows; + } + + fn set_process_rest(&mut self, process_rest: Option) { + self.process_rest = process_rest; + } +} + +impl Stream for ClassicPWMJStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.poll_next_impl(cx) + } +} + +// For Left, Right, Full, and Inner joins, incoming stream batches will already be sorted. 
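Before the implementation below, which interleaves this probe with incremental output limits, NULL padding for Right/Full joins, and the visited-row bitmap, here is a stand-alone sketch of the core index-matching idea. The `probe_less_than` name and the plain `i64` slices are illustrative inventions, not part of this patch; the sketch assumes both sides are sorted ascending with a strict `<` predicate, and `(stream, buffered)` index pairs stand in for the `UInt32`/`UInt64` index builders.

```rust
/// Illustrative only: the monotone-pivot probe behind the classic PWMJ path.
/// For each streamed value, advance the pivot over the sorted buffered values
/// until the predicate first holds; every buffered row from the pivot onward
/// is then a match, so whole suffixes are emitted as index ranges. The pivot
/// never moves backwards across stream rows.
fn probe_less_than(stream: &[i64], buffered: &[i64]) -> Vec<(usize, usize)> {
    let mut matches = Vec::new();
    let mut pivot = 0;
    for (stream_idx, stream_val) in stream.iter().enumerate() {
        // Skip buffered rows that do not satisfy `stream_val < buffered[pivot]`
        while pivot < buffered.len() && *stream_val >= buffered[pivot] {
            pivot += 1;
        }
        // Everything from the pivot to the end of the buffered side matches
        for buffered_idx in pivot..buffered.len() {
            matches.push((stream_idx, buffered_idx));
        }
    }
    matches
}

fn main() {
    // Mirrors the doc-comment example: every buffered value past the first match joins.
    let stream = vec![100, 200, 500];
    let buffered = vec![100, 200, 200, 300, 400];
    assert_eq!(
        probe_less_than(&stream, &buffered),
        vec![(0, 1), (0, 2), (0, 3), (0, 4), (1, 3), (1, 4)]
    );
}
```

In the real operator the sort direction is chosen per comparison operator via the exec's `SortOptions`, and `<=` / `>=` simply treat equality as a match as well.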
+fn resolve_classic_join( + buffered_side: &mut BufferedSideReadyState, + stream_batch: &StreamedBatch, + join_schema: Arc, + operator: Operator, + sort_options: SortOptions, + join_type: JoinType, + batch_process_state: &mut BatchProcessState, + batch_size: usize, +) -> Result { + let buffered_values = buffered_side.buffered_data.values(); + let buffered_len = buffered_values.len(); + let stream_values = stream_batch.values(); + + let mut buffered_indices = UInt64Builder::default(); + let mut stream_indices = UInt32Builder::default(); + debug!("wow!"); + // Our pivot variable allows us to start probing on the buffered side where we last matched + // in the previous stream row. + let mut pivot = batch_process_state.pivot(); + for row_idx in batch_process_state.start_idx..stream_values[0].len() { + let mut found = false; + + // Check once to see if it is a redo of a null value if not we do not try to process the batch + if !batch_process_state.not_found { + while pivot < buffered_values.len() + || batch_process_state.process_rest.is_some() + { + // If there is still data left in the batch to process, use the index and output + if let Some(start_idx) = batch_process_state.process_rest { + let count = buffered_values.len() - start_idx; + if count >= batch_size { + let stream_repeated = + vec![row_idx as u32; batch_size]; + batch_process_state.set_process_rest(Some( + start_idx + batch_size, + )); + batch_process_state.set_rows( + batch_process_state.num_rows + + batch_size, + ); + let buffered_range: Vec = (start_idx as u64 + ..((start_idx as u64) + + (batch_size as u64))) + .collect(); + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + + let batch = process_batch( + &mut buffered_indices, + &mut stream_indices, + stream_batch, + buffered_side, + join_type, + join_schema, + )?; + batch_process_state.continue_process = true; + batch_process_state.set_rows(0); + + return Ok(batch); + } + + batch_process_state.set_rows(batch_process_state.num_rows + count); + let stream_repeated = vec![row_idx as u32; count]; + let buffered_range: Vec = + (start_idx as u64..buffered_len as u64).collect(); + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + batch_process_state.process_rest = None; + + found = true; + + break; + } + + let compare = compare_join_arrays( + &[Arc::clone(&stream_values[0])], + row_idx, + &[Arc::clone(buffered_values)], + pivot, + &[sort_options], + NullEquality::NullEqualsNothing, + )?; + + // If we find a match we append all indices and move to the next stream row index + match operator { + Operator::Gt | Operator::Lt => { + if matches!(compare, Ordering::Less) { + let count = buffered_values.len() - pivot; + + // If the current output + new output is over our process value then we want to be + // able to change that + if batch_process_state.num_rows + count + >= batch_size + { + let process_batch_size = batch_size + - batch_process_state.num_rows; + let stream_repeated = + vec![row_idx as u32; process_batch_size]; + batch_process_state.set_rows( + batch_process_state.num_rows + process_batch_size, + ); + + debug!( + "pivot: {}, process_batch_size: {}", + pivot, process_batch_size + ); + let buffered_range: Vec = (pivot as u64 + ..(pivot + process_batch_size) as u64) + .collect(); + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + + let batch = process_batch( + &mut buffered_indices, + &mut stream_indices, + stream_batch, + 
buffered_side, + join_type, + join_schema, + )?; + + batch_process_state + .set_process_rest(Some(pivot + process_batch_size)); + batch_process_state.continue_process = true; + // Update the start index so it repeats the process + batch_process_state.set_start_idx(row_idx); + batch_process_state.set_pivot(pivot); + batch_process_state.set_rows(0); + + return Ok(batch); + } + + // Update the number of rows processed + batch_process_state + .set_rows(batch_process_state.num_rows + count); + debug!( + "pivot: {}, process_batch_size: {}", + pivot, buffered_len + ); + let stream_repeated = vec![row_idx as u32; count]; + let buffered_range: Vec = + (pivot as u64..buffered_len as u64).collect(); + + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + found = true; + + break; + } + } + Operator::GtEq | Operator::LtEq => { + if matches!(compare, Ordering::Equal | Ordering::Less) { + let count = buffered_values.len() - pivot; + + // If the current output + new output is over our process value then we want to be + // able to change that + if batch_process_state.num_rows + count + >= batch_size + { + // Update the start index so it repeats the process + batch_process_state.set_start_idx(row_idx); + batch_process_state.set_pivot(pivot); + + let process_batch_size = batch_size + - batch_process_state.num_rows; + let stream_repeated = + vec![row_idx as u32; process_batch_size]; + batch_process_state + .set_process_rest(Some(pivot + process_batch_size)); + batch_process_state.set_rows( + batch_process_state.num_rows + process_batch_size, + ); + let buffered_range: Vec = (pivot as u64 + ..(pivot + process_batch_size) as u64) + .collect(); + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + + let batch = process_batch( + &mut buffered_indices, + &mut stream_indices, + stream_batch, + buffered_side, + join_type, + join_schema, + )?; + + batch_process_state.continue_process = true; + batch_process_state.set_rows(0); + + return Ok(batch); + } + + // Update the number of rows processed + batch_process_state + .set_rows(batch_process_state.num_rows + count); + let stream_repeated = vec![row_idx as u32; count]; + let buffered_range: Vec = + (pivot as u64..buffered_len as u64).collect(); + + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + found = true; + + break; + } + } + _ => { + return exec_err!( + "PiecewiseMergeJoin should not contain operator, {}", + operator + ) + } + }; + + // Increment pivot after every row + pivot += 1; + } + } + + // If not found we append a null value for `JoinType::Right` and `JoinType::Full` + if (!found || batch_process_state.not_found) + && matches!(join_type, JoinType::Right | JoinType::Full) + { + let remaining = batch_size + .saturating_sub(batch_process_state.num_rows); + if remaining == 0 { + let batch = process_batch( + &mut buffered_indices, + &mut stream_indices, + stream_batch, + buffered_side, + join_type, + join_schema, + )?; + + // Update the start index so it repeats the process + batch_process_state.set_start_idx(row_idx); + batch_process_state.set_pivot(pivot); + batch_process_state.not_found = true; + batch_process_state.continue_process = true; + batch_process_state.set_rows(0); + + return Ok(batch); + } + + // Append right side value + null value for left + stream_indices.append_value(row_idx as u32); + buffered_indices.append_null(); + batch_process_state.set_rows(batch_process_state.num_rows + 1); + 
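+            // An unmatched stream row has just been emitted with a NULL buffered side; clear the
+            // carry-over flag so the following stream rows probe the buffered side again instead
+            // of skipping straight to NULL output.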
batch_process_state.not_found = false; + } + } + + let batch = process_batch( + &mut buffered_indices, + &mut stream_indices, + stream_batch, + buffered_side, + join_type, + join_schema, + )?; + + // Resets batch process state for processing `Left` + `Full` join + batch_process_state.reset(); + + Ok(batch) +} + +fn process_batch( + buffered_indices: &mut PrimitiveBuilder, + stream_indices: &mut PrimitiveBuilder, + stream_batch: &StreamedBatch, + buffered_side: &mut BufferedSideReadyState, + join_type: JoinType, + join_schema: Arc, +) -> Result { + let stream_indices_array = stream_indices.finish(); + let buffered_indices_array = buffered_indices.finish(); + + // We need to mark the buffered side matched indices for `JoinType::Full` and `JoinType::Left` + if need_produce_result_in_final(join_type) { + let mut bitmap = buffered_side.buffered_data.visited_indices_bitmap.lock(); + + buffered_indices_array.iter().flatten().for_each(|i| { + bitmap.set_bit(i as usize, true); + }); + } + + let batch = build_matched_indices( + join_schema, + &stream_batch.batch, + &buffered_side.buffered_data.batch, + stream_indices_array, + buffered_indices_array, + )?; + + Ok(batch) +} + +fn build_matched_indices( + schema: Arc, + streamed_batch: &RecordBatch, + buffered_batch: &RecordBatch, + streamed_indices: UInt32Array, + buffered_indices: UInt64Array, +) -> Result { + if schema.fields().is_empty() { + // Build an “empty” RecordBatch with just row‐count metadata + let options = RecordBatchOptions::new() + .with_match_field_names(true) + .with_row_count(Some(streamed_indices.len())); + return Ok(RecordBatch::try_new_with_options( + Arc::new((*schema).clone()), + vec![], + &options, + )?); + } + + // Gather stream columns after applying filter specified with stream indices + let streamed_columns = streamed_batch + .columns() + .iter() + .map(|column_array| { + if column_array.is_empty() + || streamed_indices.null_count() == streamed_indices.len() + { + assert_eq!(streamed_indices.null_count(), streamed_indices.len()); + Ok(new_null_array( + column_array.data_type(), + streamed_indices.len(), + )) + } else { + take(column_array, &streamed_indices, None) + } + }) + .collect::, ArrowError>>()?; + + let mut buffered_columns = buffered_batch + .columns() + .iter() + .map(|column_array| take(column_array, &buffered_indices, None)) + .collect::, ArrowError>>()?; + + buffered_columns.extend(streamed_columns); + + Ok(RecordBatch::try_new( + Arc::new((*schema).clone()), + buffered_columns, + )?) 
+} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + common, + joins::PiecewiseMergeJoinExec, + test::{build_table_i32, TestMemoryExec}, + ExecutionPlan, + }; + use arrow::array::{Date32Array, Date64Array}; + use arrow_schema::{DataType, Field}; + use datafusion_common::test_util::batches_to_string; + use datafusion_execution::TaskContext; + use datafusion_expr::JoinType; + use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; + use insta::assert_snapshot; + use std::sync::Arc; + + fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() + } + + fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_date_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date32, false), + Field::new(b.0, DataType::Date32, false), + Field::new(c.0, DataType::Date32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date32Array::from(a.1.clone())), + Arc::new(Date32Array::from(b.1.clone())), + Arc::new(Date32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_date64_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date64, false), + Field::new(b.0, DataType::Date64, false), + Field::new(c.0, DataType::Date64, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date64Array::from(a.1.clone())), + Arc::new(Date64Array::from(b.1.clone())), + Arc::new(Date64Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn join( + left: Arc, + right: Arc, + on: (Arc, Arc), + operator: Operator, + join_type: JoinType, + ) -> Result { + PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type) + } + + async fn join_collect( + left: Arc, + right: Arc, + on: (PhysicalExprRef, PhysicalExprRef), + operator: Operator, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + join_collect_with_options(left, right, on, operator, join_type).await + } + + async fn join_collect_with_options( + left: Arc, + right: Arc, + on: (PhysicalExprRef, PhysicalExprRef), + operator: Operator, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + let task_ctx = Arc::new(TaskContext::default()); + let join = join(left, right, on, operator, join_type)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + + #[tokio::test] + async fn join_inner_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 3 | 7 | + // | 2 | 2 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![3, 2, 1]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 4]), + ("c2", &vec![70, 80, 90]), + ); + + 
let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 3 | 7 | 30 | 4 | 90 | + | 2 | 2 | 8 | 30 | 4 | 90 | + | 3 | 1 | 9 | 30 | 4 | 90 | + | 2 | 2 | 8 | 20 | 3 | 80 | + | 3 | 1 | 9 | 20 | 3 | 80 | + | 3 | 1 | 9 | 10 | 2 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_less_than_unsorted() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 3 | 7 | + // | 2 | 2 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![3, 2, 1]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 2, 4]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 3 | 7 | 30 | 4 | 90 | + | 2 | 2 | 8 | 30 | 4 | 90 | + | 3 | 1 | 9 | 30 | 4 | 90 | + | 2 | 2 | 8 | 10 | 3 | 70 | + | 3 | 1 | 9 | 10 | 3 | 70 | + | 3 | 1 | 9 | 20 | 2 | 80 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_greater_than_equal_to() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 2 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![2, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 1 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 2, 1]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 2 | 7 | 30 | 1 | 90 | + | 2 | 3 | 8 | 30 | 1 | 90 | + | 3 | 4 | 9 | 30 | 1 | 90 | + | 1 | 2 | 7 | 20 | 2 | 80 | + | 2 | 3 | 8 | 20 | 2 | 80 | + | 3 | 4 | 9 | 20 | 2 | 80 | + | 2 | 3 | 8 | 10 | 3 | 70 | + | 3 | 4 | 9 | 10 | 3 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_empty_left() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // (empty) + // +----+----+----+ + let left = build_table( + ("a1", &Vec::::new()), + ("b1", &Vec::::new()), + ("c1", &Vec::::new()), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 1 | 1 | 1 | + // | 2 | 2 | 2 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![1, 2]), + ("b1", &vec![1, 2]), + ("c2", &vec![1, 2]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?; + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_full_greater_than_equal_to() -> Result<()> { + // +----+----+-----+ + // | a1 | b1 | c1 | + // +----+----+-----+ + // | 1 | 1 | 100 | + // | 2 | 2 | 200 | + // +----+----+-----+ + let left = build_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1, 2]), + ("c1", &vec![100, 200]), + ); + + // +----+----+-----+ + // | a2 | b1 | c2 | + // +----+----+-----+ + // | 10 | 3 | 300 | + // | 20 | 2 | 400 | + // +----+----+-----+ + let right = build_table( + ("a2", &vec![10, 20]), + ("b1", &vec![3, 2]), + ("c2", &vec![300, 400]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Full).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+-----+----+----+-----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+-----+----+----+-----+ + | 2 | 2 | 200 | 20 | 2 | 400 | + | | | | 10 | 3 | 300 | + | 1 | 1 | 100 | | | | + +----+----+-----+----+----+-----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 1 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 2, 1]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Left).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 3 | 8 | 30 | 1 | 90 | + | 3 | 4 | 9 | 30 | 1 | 90 | + | 2 | 3 | 8 | 20 | 2 | 80 | + | 3 | 4 | 9 | 20 | 2 | 80 | + | 3 | 4 | 9 | 10 | 3 | 70 | + | 1 | 1 | 7 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 5 | 70 | + // | 20 | 3 | 80 | + // | 30 | 2 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![5, 3, 2]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 3 | 8 | 30 | 2 | 90 | + | 3 | 4 | 9 | 30 | 2 | 90 | + | 3 | 4 | 9 | 20 | 3 | 80 | + | | | | 10 | 5 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 4 | 7 | + // | 2 | 3 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 3, 1]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 30 | 5 | 90 | + | 2 | 3 | 8 | 30 | 5 | 90 | + | 3 | 1 | 9 | 30 | 5 | 90 | + | 3 | 1 | 9 | 20 | 3 | 80 | + | 3 | 1 | 9 | 10 | 2 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date32_inner_less_than() -> Result<()> { + // +----+-------+----+ + // | a1 | b1 | c1 | + // +----+-------+----+ + // | 1 | 19107 | 7 | + // | 2 | 19107 | 8 | + // | 3 | 19105 | 9 | + // +----+-------+----+ + let left = build_date_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![19107, 19107, 19105]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+-------+----+ + // | a2 | b1 | c2 | + // +----+-------+----+ + // | 10 | 19105 | 70 | + // | 20 | 19103 | 80 | + // | 30 | 19107 | 90 | + // +----+-------+----+ + let right = build_date_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![19105, 19103, 19107]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +------------+------------+------------+------------+------------+------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +------------+------------+------------+------------+------------+------------+ + | 1970-01-04 | 2022-04-23 | 1970-01-10 | 1970-01-31 | 2022-04-25 | 1970-04-01 | + +------------+------------+------------+------------+------------+------------+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date64_inner_less_than() -> Result<()> { + // +----+---------------+----+ + // | a1 | b1 | c1 | + // +----+---------------+----+ + // | 1 | 1650903441000 | 7 | + // | 2 | 1650903441000 | 8 | + // | 3 | 1650703441000 | 9 | + // +----+---------------+----+ + let left = build_date64_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1650903441000, 1650903441000, 1650703441000]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+---------------+----+ + // | a2 | b1 | c2 | + // +----+---------------+----+ + // | 10 | 1650703441000 | 70 | + // | 20 | 1650503441000 | 80 | + // | 30 | 1650903441000 | 90 | + // +----+---------------+----+ + let right = build_date64_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.003 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date64_right_less_than() -> Result<()> { + // +----+---------------+----+ + // | a1 | b1 | c1 | + // +----+---------------+----+ + // | 1 | 1650903441000 | 7 | + // | 2 | 1650703441000 | 8 | + // +----+---------------+----+ + let left = build_date64_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1650903441000, 1650703441000]), + ("c1", &vec![7, 8]), + ); + + // +----+---------------+----+ + // | a2 | b1 | c2 | + // +----+---------------+----+ + // | 10 | 1650703441000 | 80 | + // | 20 | 1650903441000 | 90 | + // +----+---------------+----+ + let right = build_date64_table( + ("a2", &vec![10, 20]), + ("b1", &vec![1650703441000, 1650903441000]), + ("c2", &vec![80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.002 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.020 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + | | | | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.080 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ +"#); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs new file mode 100644 index 0000000000000..d4980032728ad --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -0,0 +1,730 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Array; +use arrow::{ + array::{ArrayRef, BooleanBufferBuilder, RecordBatch}, + compute::concat_batches, + util::bit_util, +}; +use arrow_schema::{SchemaRef, SortOptions}; +use datafusion_common::not_impl_err; +use datafusion_common::{internal_err, JoinSide, Result}; +use datafusion_execution::{ + memory_pool::{MemoryConsumer, MemoryReservation}, + SendableRecordBatchStream, +}; +use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::equivalence::join_equivalence_properties; +use datafusion_physical_expr::{ + LexOrdering, OrderingRequirements, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, +}; +use datafusion_physical_expr_common::physical_expr::fmt_sql; +use futures::TryStreamExt; +use parking_lot::Mutex; +use std::fmt::Formatter; +use std::sync::Arc; + +use crate::execution_plan::{boundedness_from_children, EmissionType}; + +use crate::joins::piecewise_merge_join::classic_join::{ + ClassicPWMJStream, PiecewiseMergeJoinStreamState, +}; +use crate::joins::piecewise_merge_join::utils::{ + build_visited_indices_map, is_existence_join, is_right_existence_join, +}; +use crate::joins::utils::symmetric_join_output_partitioning; +use crate::{ + joins::{ + utils::{build_join_schema, BuildProbeJoinMetrics, OnceAsync, OnceFut}, + SharedBitmapBuilder, + }, + metrics::ExecutionPlanMetricsSet, + spill::get_record_batch_memory_size, + ExecutionPlan, PlanProperties, +}; +use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; + + +/// `PiecewiseMergeJoinExec` is a join execution plan that only evaluates single range filter. +/// +/// The physical planner will choose to evalute this join when there is only one range predicate. This +/// is a binary expression which contains [`Operator::Lt`], [`Operator::LtEq`], [`Operator::Gt`], and +/// [`Operator::GtEq`].: +/// Examples: +/// - `col0` < `colb`, `col0` <= `colb`, `col0` > `colb`, `col0` >= `colb` +/// +/// Since the join only support range predicates, equijoins are not supported in `PiecewiseMergeJoinExec`, +/// however you can first evaluate another join and run `PiecewiseMergeJoinExec` if left with one range +/// predicate. +/// +/// # Execution Plan Inputs +/// For `PiecewiseMergeJoin` we label all right inputs as the `streamed' side and the left outputs as the +/// 'buffered' side. +/// +/// `PiecewiseMergeJoin` takes a sorted input for the side to be buffered and is able to sort streamed record +/// batches during processing. Sorted input must specifically be ascending/descending based on the operator. +/// +/// # Algorithms +/// Classic joins are processed differently compared to existence joins. +/// +/// ## Classic Joins (Inner, Full, Left, Right) +/// For classic joins we buffer the right side (buffered), and incrementally process the left side (streamed). +/// Every streamed batch is sorted so we can perform a sort merge algorithm. For the buffered side we want to +/// have it already sorted either ascending or descending based on the operator as this allows us to emit all +/// the rows from a given point to the end as matches. 
Sorting the streamed side allows us to start the pointer +/// from the previous row's match on the buffered side. +/// +/// For `Lt` (`<`) + `LtEq` (`<=`) operations both inputs are to be sorted in descending order and sorted in +/// ascending order for `Gt` (`>`) + `GtEq` (`>=`) than (`>`) operations. `SortExec` is used to enforce sorting +/// on the buffered side and streamed side is sorted in memory. +/// +/// The pseudocode for the algorithm looks like this: +/// +/// ```text +/// for stream_row in stream_batch: +/// for buffer_row in buffer_batch: +/// if compare(stream_row, probe_row): +/// output stream_row X buffer_batch[buffer_row:] +/// else: +/// continue +/// ``` +/// +/// The algorithm uses the streamed side to drive the loop. This is due to every row on the stream side iterating +/// the buffered side to find every first match. +/// +/// Here is an example: +/// +/// We perform a `JoinType::Left` with these two batches and the operator being `Operator::Lt`(<). For each +/// row on the streamed side we move a pointer on the buffered until it matches the condition. Once we reach +/// the row which matches (in this case with row 1 on streamed will have its first match on row 2 on +/// buffered; 100 < 200 is true), we can emit all rows after that match. We can emit the rows like this because +/// if the batch is sorted in ascending order, every subsequent row will also satisfy the condition as they will +/// all be larger values. +/// +/// ```text +/// SQL statement: +/// SELECT * +/// FROM (VALUES (100), (200), (500)) AS streamed(a) +/// LEFT JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b) +/// ON streamed.a < buffered.b; +/// +/// Processing Row 1: +/// +/// Sorted Buffered Side Sorted Streamed Side +/// ┌──────────────────┐ ┌──────────────────┐ +/// 1 │ 100 │ 1 │ 100 │ +/// ├──────────────────┤ ├──────────────────┤ +/// 2 │ 200 │ ─┐ 2 │ 200 │ +/// ├──────────────────┤ │ For row 1 on streamed side with ├──────────────────┤ +/// 3 │ 200 │ │ value 100, we emit rows 2 - 5. 3 │ 500 │ +/// ├──────────────────┤ │ as matches when the operator is └──────────────────┘ +/// 4 │ 300 │ │ `Operator::Lt` (<) Emitting all +/// ├──────────────────┤ │ rows after the first match (row +/// 5 │ 400 │ ─┘ 2 buffered side; 100 < 200) +/// └──────────────────┘ +/// +/// Processing Row 2: +/// By sorting the streamed side we know +/// +/// Sorted Buffered Side Sorted Streamed Side +/// ┌──────────────────┐ ┌──────────────────┐ +/// 1 │ 100 │ 1 │ 100 │ +/// ├──────────────────┤ ├──────────────────┤ +/// 2 │ 200 │ <- Start here when probing for the 2 │ 200 │ +/// ├──────────────────┤ streamed side row 2. ├──────────────────┤ +/// 3 │ 200 │ 3 │ 500 │ +/// ├──────────────────┤ └──────────────────┘ +/// 4 │ 300 │ +/// ├──────────────────┤ +/// 5 │ 400 │ +/// └──────────────────┘ +/// +/// ``` +/// +/// ## Existence Joins (Semi, Anti, Mark) +/// Existence joins are made magnitudes of times faster with a `PiecewiseMergeJoin` as we only need to find +/// the min/max value of the streamed side to be able to emit all matches on the buffered side. By putting +/// the side we need to mark onto the sorted buffer side, we can emit all these matches at once. +/// +/// For less than operations (`<`) both inputs are to be sorted in descending order and vice versa for greater +/// than (`>`) operations. `SortExec` is used to enforce sorting on the buffered side and streamed side does not +/// need to be sorted due to only needing to find the min/max. 
+///
+/// For Left Semi, Anti, and Mark joins we swap the inputs so that the marked side is on the buffered side.
+///
+/// The pseudocode for the algorithm looks like this:
+///
+/// ```text
+/// // Using the example of a less than `<` operation
+/// let max = max_batch(streamed_batch)
+///
+/// for buffer_row in buffer_batch:
+///     if buffer_row < max:
+///         output buffer_batch[buffer_row:]
+/// ```
+///
+/// We only need to find the min/max value and iterate through the buffered side once.
+///
+/// Here is an example:
+/// We perform a `JoinType::LeftSemi` with these two batches and the operator being `Operator::Lt`(<). Because
+/// the operator is `Operator::Lt` we can find the minimum value in the streamed side; in this case it is 200.
+/// We can then advance a pointer from the start of the buffered side until we find the first value that satisfies
+/// the predicate. All rows after that first matched value satisfy the condition 200 < x so we can mark all of
+/// those rows as matched.
+///
+/// ```text
+/// SQL statement:
+///     SELECT *
+///     FROM (VALUES (500), (200), (300)) AS streamed(a)
+///     LEFT SEMI JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b)
+///     ON streamed.a < buffered.b;
+///
+///      Sorted Buffered Side                                            Unsorted Streamed Side
+///       ┌──────────────────┐                                           ┌──────────────────┐
+///     1 │        100       │                                         1 │        500       │
+///       ├──────────────────┤                                           ├──────────────────┤
+///     2 │        200       │                                         2 │        200       │
+///       ├──────────────────┤                                           ├──────────────────┤
+///     3 │        200       │                                         3 │        300       │
+///       ├──────────────────┤                                           └──────────────────┘
+///     4 │        300       │ ─┐
+///       ├──────────────────┤  | We emit matches for rows 4 - 5
+///     5 │        400       │ ─┘ on the buffered side.
+///       └──────────────────┘
+///     min value: 200
+/// ```
+///
+/// For both types of joins, the buffered side must be sorted ascending for `Operator::Lt` (<) or
+/// `Operator::LtEq` (<=) and descending for `Operator::Gt` (>) or `Operator::GtEq` (>=).
+///
+/// # Performance Explanation (cost)
+/// Piecewise Merge Join is used over Nested Loop Join due to its superior performance. Here is the breakdown:
+///
+/// ## Piecewise Merge Join (PWMJ)
+/// ### Classic Join
+/// Requires sorting the probe side and, for each probe row, scanning the buffered side until the first match
+/// is found.
+/// Complexity: `O(sort(S) + |S| * scan(R))`.
+///
+/// ### Mark Join
+/// Sorts the probe side, then computes the min/max range of the probe keys and scans the buffered side only
+/// within that range.
+/// Complexity: `O(|S| + scan(R[range]))`.
+///
+/// ## Nested Loop Join
+/// Compares every row from `S` with every row from `R`.
+/// Complexity: `O(|S| * |R|)`.
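+///
+/// As a rough, hand-worked illustration (not a formal cost model), take the 3-row streamed batch and
+/// 5-row buffered batch used in the examples above:
+///
+/// ```text
+/// Nested Loop Join:     evaluates the predicate for every pair
+///                       3 * 5 = 15 comparisons
+///
+/// PWMJ classic join:    with both sides ordered, the buffered pointer resumes from the previous
+///                       row's first match and never moves backwards, so the merge pass touches
+///                       each buffered row at most once: at most 3 + 5 = 8 comparisons
+///
+/// PWMJ existence join:  one pass over the streamed side to find the min/max (3 values),
+///                       then a single scan of the buffered side (5 values)
+/// ```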
+/// +/// # Further Reference Material +/// DuckDB blog on Range Joins: [Range Joins in DuckDB](https://duckdb.org/2022/05/27/iejoin.html) +#[derive(Debug)] +pub struct PiecewiseMergeJoinExec { + /// Left buffered execution plan + pub buffered: Arc, + /// Right streamed execution plan + pub streamed: Arc, + /// The two expressions being compared + pub on: (Arc, Arc), + /// Comparison operator in the range predicate + pub operator: Operator, + /// How the join is performed + pub join_type: JoinType, + /// The schema once the join is applied + schema: SchemaRef, + /// Buffered data + buffered_fut: OnceAsync, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// The left sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations + left_sort_exprs: LexOrdering, + /// The right sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations + /// Unsorted for mark joins + right_sort_exprs: LexOrdering, + /// Sort options of join columns used in sorting the stream and buffered execution plans + sort_options: SortOptions, + /// Cache holding plan properties like equivalences, output partitioning etc. + cache: PlanProperties, +} + +impl PiecewiseMergeJoinExec { + pub fn try_new( + buffered: Arc, + streamed: Arc, + on: (Arc, Arc), + operator: Operator, + join_type: JoinType, + ) -> Result { + // TODO: Implement existence joins for PiecewiseMergeJoin + if is_existence_join(join_type) { + return not_impl_err!( + "Existence Joins are currently not supported for PiecewiseMergeJoin" + ); + } + + // Take the operator and enforce a sort order on the streamed + buffered side based on + // the operator type. + let sort_options = match operator { + Operator::Lt | Operator::LtEq => { + // For left existence joins the inputs will be swapped so the sort + // options are switched + if is_right_existence_join(join_type) { + SortOptions::new(false, false) + } else { + SortOptions::new(true, false) + } + } + Operator::Gt | Operator::GtEq => { + if is_right_existence_join(join_type) { + SortOptions::new(true, false) + } else { + SortOptions::new(false, false) + } + } + _ => { + return internal_err!( + "Cannot contain non-range operator in PiecewiseMergeJoinExec" + ) + } + }; + + // Give the same `sort_option for comparison later` + let left_sort_exprs = + vec![PhysicalSortExpr::new(Arc::clone(&on.0), sort_options)]; + let right_sort_exprs = + vec![PhysicalSortExpr::new(Arc::clone(&on.1), sort_options)]; + + let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else { + return internal_err!( + "PiecewiseMergeJoinExec requires valid sort expressions for its left side" + ); + }; + let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else { + return internal_err!( + "PiecewiseMergeJoinExec requires valid sort expressions for its right side" + ); + }; + + let buffered_schema = buffered.schema(); + let streamed_schema = streamed.schema(); + + // Create output schema for the join + let schema = + Arc::new(build_join_schema(&buffered_schema, &streamed_schema, &join_type).0); + let cache = Self::compute_properties( + &buffered, + &streamed, + Arc::clone(&schema), + join_type, + &on, + )?; + + Ok(Self { + streamed, + buffered, + on, + operator, + join_type, + schema, + buffered_fut: Default::default(), + metrics: ExecutionPlanMetricsSet::new(), + left_sort_exprs, + right_sort_exprs, + sort_options, + cache, + }) + } + + /// Reference to buffered side execution plan + pub fn buffered(&self) -> &Arc { + &self.buffered + } + + /// Reference to streamed 
side execution plan + pub fn streamed(&self) -> &Arc { + &self.streamed + } + + /// Join type + pub fn join_type(&self) -> JoinType { + self.join_type + } + + /// Reference to sort options + pub fn sort_options(&self) -> &SortOptions { + &self.sort_options + } + + /// Get probe side (streamed side) for the PiecewiseMergeJoin + /// In current implementation, probe side is determined according to join type. + pub fn probe_side(join_type: &JoinType) -> JoinSide { + match join_type { + JoinType::Right + | JoinType::Inner + | JoinType::Full + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => JoinSide::Right, + JoinType::Left + | JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::LeftMark => JoinSide::Left, + } + } + + pub fn compute_properties( + buffered: &Arc, + streamed: &Arc, + schema: SchemaRef, + join_type: JoinType, + join_on: &(PhysicalExprRef, PhysicalExprRef), + ) -> Result { + let eq_properties = join_equivalence_properties( + buffered.equivalence_properties().clone(), + streamed.equivalence_properties().clone(), + &join_type, + schema, + &Self::maintains_input_order(join_type), + Some(Self::probe_side(&join_type)), + std::slice::from_ref(join_on), + )?; + + let output_partitioning = + symmetric_join_output_partitioning(buffered, streamed, &join_type)?; + + Ok(PlanProperties::new( + eq_properties, + output_partitioning, + EmissionType::Incremental, + boundedness_from_children([buffered, streamed]), + )) + } + + // TODO: Add input order + fn maintains_input_order(join_type: JoinType) -> Vec { + match join_type { + // The existence side is expected to come in sorted + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + vec![false, false] + } + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + vec![false, false] + } + // Left, Right, Full, Inner Join is not guaranteed to maintain + // input order as the streamed side will be sorted during + // execution for `PiecewiseMergeJoin` + _ => vec![false, false], + } + } + + // TODO: We implement this with the physical planner. + pub fn swap_inputs(&self) -> Result> { + todo!() + } +} + +impl ExecutionPlan for PiecewiseMergeJoinExec { + fn name(&self) -> &str { + "PiecewiseMergeJoinExec" + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.buffered, &self.streamed] + } + + fn required_input_ordering(&self) -> Vec> { + // Existence joins don't need to be sorted on one side. + if is_right_existence_join(self.join_type) { + // Right side needs to be sorted because this will be swapped to the + // buffered side + vec![ + None, + Some(OrderingRequirements::from(self.right_sort_exprs.clone())), + ] + } else { + // Sort the right side in memory, so we do not need to enforce any sorting + vec![ + Some(OrderingRequirements::from(self.left_sort_exprs.clone())), + None, + ] + } + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match &children[..] 
{ + [left, right] => Ok(Arc::new(PiecewiseMergeJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + self.on.clone(), + self.operator, + self.join_type, + )?)), + _ => internal_err!( + "PiecewiseMergeJoin should have 2 children, found {}", + children.len() + ), + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let on_buffered = Arc::clone(&self.on.0); + let on_streamed = Arc::clone(&self.on.1); + + // If the join type is either RightSemi, RightAnti, or RightMark we will swap the inputs + // and sort ordering because we want the mark side to be the buffered side. + let (buffered, streamed, on_buffered, on_streamed, operator) = + if is_right_existence_join(self.join_type) { + ( + Arc::clone(&self.streamed), + Arc::clone(&self.buffered), + on_streamed, + on_buffered, + self.operator.swap().unwrap(), + ) + } else { + ( + Arc::clone(&self.buffered), + Arc::clone(&self.streamed), + on_buffered, + on_streamed, + self.operator, + ) + }; + + let metrics = BuildProbeJoinMetrics::new(0, &self.metrics); + let buffered_fut = self.buffered_fut.try_once(|| { + let reservation = MemoryConsumer::new("PiecewiseMergeJoinInput") + .register(context.memory_pool()); + let buffered_stream = buffered.execute(partition, Arc::clone(&context))?; + Ok(build_buffered_data( + buffered_stream, + Arc::clone(&on_buffered), + metrics.clone(), + reservation, + build_visited_indices_map(self.join_type), + )) + })?; + + let streamed = streamed.execute(partition, Arc::clone(&context))?; + + let batch_size = context.session_config().batch_size(); + + // TODO: Add existence joins + this is guarded at physical planner + if is_existence_join(self.join_type()) { + unreachable!() + } else { + Ok(Box::pin(ClassicPWMJStream::try_new( + Arc::clone(&self.schema), + on_streamed, + self.join_type, + operator, + streamed, + BufferedSide::Initial(BufferedSideInitialState { buffered_fut }), + PiecewiseMergeJoinStreamState::WaitBufferedSide, + self.sort_options, + metrics, + batch_size, + ))) + } + } +} + +impl DisplayAs for PiecewiseMergeJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + let on_str = format!( + "({} {} {})", + fmt_sql(self.on.0.as_ref()), + self.operator, + fmt_sql(self.on.1.as_ref()) + ); + + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "PiecewiseMergeJoin: operator={:?}, join_type={:?}, on={}", + self.operator, self.join_type, on_str + ) + } + + DisplayFormatType::TreeRender => { + writeln!(f, "operator={:?}", self.operator)?; + if self.join_type != JoinType::Inner { + writeln!(f, "join_type={:?}", self.join_type)?; + } + writeln!(f, "on={on_str}") + } + } + } +} + +async fn build_buffered_data( + buffered: SendableRecordBatchStream, + on_buffered: PhysicalExprRef, + metrics: BuildProbeJoinMetrics, + reservation: MemoryReservation, + build_map: bool, +) -> Result { + let schema = buffered.schema(); + + // Combine batches and record number of rows + let initial = (Vec::new(), 0, metrics, reservation); + let (batches, num_rows, metrics, mut reservation) = buffered + .try_fold(initial, |mut acc, batch| async { + let batch_size = get_record_batch_memory_size(&batch); + acc.3.try_grow(batch_size)?; + acc.2.build_mem_used.add(batch_size); + acc.2.build_input_batches.add(1); + acc.2.build_input_rows.add(batch.num_rows()); + // Update row count + acc.1 += batch.num_rows(); + // Push batch to output + acc.0.push(batch); + Ok(acc) + }) + .await?; + + let batches_iter = batches.iter().rev(); + let 
single_batch = concat_batches(&schema, batches_iter)?; + + // Evaluate physical expression on the buffered side. + let buffered_values = on_buffered + .evaluate(&single_batch)? + .into_array(single_batch.num_rows())?; + + // We add the single batch size + the memory of the join keys + // size of the size estimation + let size_estimation = get_record_batch_memory_size(&single_batch) + + buffered_values.get_array_memory_size(); + reservation.try_grow(size_estimation)?; + metrics.build_mem_used.add(size_estimation); + + // Created visited indices bitmap only if the join type requires it + let visited_indices_bitmap = if build_map { + let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); + reservation.try_grow(bitmap_size)?; + metrics.build_mem_used.add(bitmap_size); + + let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows()); + bitmap_buffer.append_n(num_rows, false); + bitmap_buffer + } else { + BooleanBufferBuilder::new(0) + }; + + let buffered_data = BufferedSideData::new( + single_batch, + buffered_values, + Mutex::new(visited_indices_bitmap), + reservation, + ); + + Ok(buffered_data) +} + +pub(super) struct BufferedSideData { + pub(super) batch: RecordBatch, + values: ArrayRef, + pub(super) visited_indices_bitmap: SharedBitmapBuilder, + _reservation: MemoryReservation, +} + +impl BufferedSideData { + pub(super) fn new( + batch: RecordBatch, + values: ArrayRef, + visited_indices_bitmap: SharedBitmapBuilder, + reservation: MemoryReservation, + ) -> Self { + Self { + batch, + values, + visited_indices_bitmap, + _reservation: reservation, + } + } + + pub(super) fn batch(&self) -> &RecordBatch { + &self.batch + } + + pub(super) fn values(&self) -> &ArrayRef { + &self.values + } +} + +pub(super) enum BufferedSide { + /// Indicates that build-side not collected yet + Initial(BufferedSideInitialState), + /// Indicates that build-side data has been collected + Ready(BufferedSideReadyState), +} + +impl BufferedSide { + // Takes a mutable state of the buffered row batches + pub(super) fn try_as_initial_mut(&mut self) -> Result<&mut BufferedSideInitialState> { + match self { + BufferedSide::Initial(state) => Ok(state), + _ => internal_err!("Expected build side in initial state"), + } + } + + pub(super) fn try_as_ready(&self) -> Result<&BufferedSideReadyState> { + match self { + BufferedSide::Ready(state) => Ok(state), + _ => { + internal_err!("Expected build side in ready state") + } + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + pub(super) fn try_as_ready_mut(&mut self) -> Result<&mut BufferedSideReadyState> { + match self { + BufferedSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } +} + +pub(super) struct BufferedSideInitialState { + pub(crate) buffered_fut: OnceFut, +} + +pub(super) struct BufferedSideReadyState { + /// Collected build-side data + pub(super) buffered_data: Arc, +} diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs new file mode 100644 index 0000000000000..f66de0ddab43c --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub use exec::PiecewiseMergeJoinExec; + +mod classic_join; +mod exec; +mod utils; diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs new file mode 100644 index 0000000000000..5bbb496322b5f --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::JoinType; + +// Returns boolean for whether the join is a right existence join +pub(super) fn is_right_existence_join(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark + ) +} + +// Returns boolean for whether the join is an existence join +pub(super) fn is_existence_join(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftMark + | JoinType::RightMark + ) +} + +// Returns boolean to check if the join type needs to record +// buffered side matches for classic joins +pub(super) fn need_produce_result_in_final(join_type: JoinType) -> bool { + matches!(join_type, JoinType::Full | JoinType::Left) +} + +// Returns boolean for whether or not we need to build the buffered side +// bitmap for marking matched rows on the buffered side. 
+pub(super) fn build_visited_indices_map(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::Full + | JoinType::Left + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftMark + | JoinType::RightMark + ) +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs index 334098150711c..7e645c632d8f9 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs @@ -34,7 +34,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics; -use crate::joins::utils::JoinFilter; +use crate::joins::utils::{compare_join_arrays, JoinFilter}; use crate::spill::spill_manager::SpillManager; use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream}; @@ -1849,95 +1849,6 @@ fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec Result { - let mut res = Ordering::Equal; - for ((left_array, right_array), sort_options) in - left_arrays.iter().zip(right_arrays).zip(sort_options) - { - macro_rules! compare_value { - ($T:ty) => {{ - let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); - let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); - match (left_array.is_null(left), right_array.is_null(right)) { - (false, false) => { - let left_value = &left_array.value(left); - let right_value = &right_array.value(right); - res = left_value.partial_cmp(right_value).unwrap(); - if sort_options.descending { - res = res.reverse(); - } - } - (true, false) => { - res = if sort_options.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (false, true) => { - res = if sort_options.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - _ => { - res = match null_equality { - NullEquality::NullEqualsNothing => Ordering::Less, - NullEquality::NullEqualsNull => Ordering::Equal, - }; - } - } - }}; - } - - match left_array.data_type() { - DataType::Null => {} - DataType::Boolean => compare_value!(BooleanArray), - DataType::Int8 => compare_value!(Int8Array), - DataType::Int16 => compare_value!(Int16Array), - DataType::Int32 => compare_value!(Int32Array), - DataType::Int64 => compare_value!(Int64Array), - DataType::UInt8 => compare_value!(UInt8Array), - DataType::UInt16 => compare_value!(UInt16Array), - DataType::UInt32 => compare_value!(UInt32Array), - DataType::UInt64 => compare_value!(UInt64Array), - DataType::Float32 => compare_value!(Float32Array), - DataType::Float64 => compare_value!(Float64Array), - DataType::Utf8 => compare_value!(StringArray), - DataType::Utf8View => compare_value!(StringViewArray), - DataType::LargeUtf8 => compare_value!(LargeStringArray), - DataType::Decimal128(..) 
=> compare_value!(Decimal128Array), - DataType::Timestamp(time_unit, None) => match time_unit { - TimeUnit::Second => compare_value!(TimestampSecondArray), - TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), - TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), - TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), - }, - DataType::Date32 => compare_value!(Date32Array), - DataType::Date64 => compare_value!(Date64Array), - dt => { - return not_impl_err!( - "Unsupported data type in sort merge join comparator: {}", - dt - ); - } - } - if !res.is_eq() { - break; - } - } - Ok(res) -} - /// A faster version of compare_join_arrays() that only output whether /// the given two rows are equal fn is_join_arrays_equal( diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index d392650f88dda..420825d831f89 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -17,7 +17,7 @@ //! Join related functionality used both on logical and physical plans -use std::cmp::min; +use std::cmp::{min, Ordering}; use std::collections::HashSet; use std::fmt::{self, Debug}; use std::future::Future; @@ -43,7 +43,12 @@ use arrow::array::{ BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions, UInt32Array, UInt32Builder, UInt64Array, }; -use arrow::array::{ArrayRef, BooleanArray}; +use arrow::array::{ + ArrayRef, BooleanArray, Date32Array, Date64Array, Decimal128Array, Float32Array, + Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, + StringArray, StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt8Array, +}; use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::kernels::cmp::eq; use arrow::compute::{self, and, take, FilterBuilder}; @@ -51,12 +56,13 @@ use arrow::datatypes::{ ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type, }; use arrow_ord::cmp::not_distinct; -use arrow_schema::ArrowError; +use arrow_schema::{ArrowError, DataType, SortOptions, TimeUnit}; use datafusion_common::cast::as_boolean_array; use datafusion_common::hash_utils::create_hashes; use datafusion_common::stats::Precision; use datafusion_common::{ - plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, SharedResult, + not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, + SharedResult, }; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::Operator; @@ -284,7 +290,7 @@ pub fn build_join_schema( JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(), JoinType::LeftMark => { let right_field = once(( - Field::new("mark", arrow::datatypes::DataType::Boolean, false), + Field::new("mark", DataType::Boolean, false), ColumnIndex { index: 0, side: JoinSide::None, @@ -295,7 +301,7 @@ pub fn build_join_schema( JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(), JoinType::RightMark => { let left_field = once(( - Field::new("mark", arrow_schema::DataType::Boolean, false), + Field::new("mark", DataType::Boolean, false), ColumnIndex { index: 0, side: JoinSide::None, @@ -817,9 +823,10 @@ pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool { pub(crate) fn get_final_indices_from_shared_bitmap( shared_bitmap: &SharedBitmapBuilder, join_type: JoinType, + piecewise: bool, ) -> (UInt64Array, UInt32Array) { let bitmap = 
shared_bitmap.lock(); - get_final_indices_from_bit_map(&bitmap, join_type) + get_final_indices_from_bit_map(&bitmap, join_type, piecewise) } /// In the end of join execution, need to use bit map of the matched @@ -834,16 +841,22 @@ pub(crate) fn get_final_indices_from_shared_bitmap( pub(crate) fn get_final_indices_from_bit_map( left_bit_map: &BooleanBufferBuilder, join_type: JoinType, + // We add a flag for whether this is being passed from the `PiecewiseMergeJoin` + // because the bitmap can be for left + right `JoinType`s + piecewise: bool, ) -> (UInt64Array, UInt32Array) { let left_size = left_bit_map.len(); - if join_type == JoinType::LeftMark { + if join_type == JoinType::LeftMark || (join_type == JoinType::RightMark && piecewise) + { let left_indices = (0..left_size as u64).collect::(); let right_indices = (0..left_size) .map(|idx| left_bit_map.get_bit(idx).then_some(0)) .collect::(); return (left_indices, right_indices); } - let left_indices = if join_type == JoinType::LeftSemi { + let left_indices = if join_type == JoinType::LeftSemi + || (join_type == JoinType::RightSemi && piecewise) + { (0..left_size) .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64)) .collect::() @@ -1753,6 +1766,95 @@ fn eq_dyn_null( } } +/// Get comparison result of two rows of join arrays +pub fn compare_join_arrays( + left_arrays: &[ArrayRef], + left: usize, + right_arrays: &[ArrayRef], + right: usize, + sort_options: &[SortOptions], + null_equality: NullEquality, +) -> Result { + let mut res = Ordering::Equal; + for ((left_array, right_array), sort_options) in + left_arrays.iter().zip(right_arrays).zip(sort_options) + { + macro_rules! compare_value { + ($T:ty) => {{ + let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); + let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); + match (left_array.is_null(left), right_array.is_null(right)) { + (false, false) => { + let left_value = &left_array.value(left); + let right_value = &right_array.value(right); + res = left_value.partial_cmp(right_value).unwrap(); + if sort_options.descending { + res = res.reverse(); + } + } + (true, false) => { + res = if sort_options.nulls_first { + Ordering::Less + } else { + Ordering::Greater + }; + } + (false, true) => { + res = if sort_options.nulls_first { + Ordering::Greater + } else { + Ordering::Less + }; + } + _ => { + res = match null_equality { + NullEquality::NullEqualsNothing => Ordering::Less, + NullEquality::NullEqualsNull => Ordering::Equal, + }; + } + } + }}; + } + + match left_array.data_type() { + DataType::Null => {} + DataType::Boolean => compare_value!(BooleanArray), + DataType::Int8 => compare_value!(Int8Array), + DataType::Int16 => compare_value!(Int16Array), + DataType::Int32 => compare_value!(Int32Array), + DataType::Int64 => compare_value!(Int64Array), + DataType::UInt8 => compare_value!(UInt8Array), + DataType::UInt16 => compare_value!(UInt16Array), + DataType::UInt32 => compare_value!(UInt32Array), + DataType::UInt64 => compare_value!(UInt64Array), + DataType::Float32 => compare_value!(Float32Array), + DataType::Float64 => compare_value!(Float64Array), + DataType::Utf8 => compare_value!(StringArray), + DataType::Utf8View => compare_value!(StringViewArray), + DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Decimal128(..) 
=> compare_value!(Decimal128Array), + DataType::Timestamp(time_unit, None) => match time_unit { + TimeUnit::Second => compare_value!(TimestampSecondArray), + TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), + TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), + TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), + }, + DataType::Date32 => compare_value!(Date32Array), + DataType::Date64 => compare_value!(Date64Array), + dt => { + return not_impl_err!( + "Unsupported data type in sort merge join comparator: {}", + dt + ); + } + } + if !res.is_eq() { + break; + } + } + Ok(res) +} + #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index ad21bdac6d2d7..daf4229cafd5d 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4148,10 +4148,11 @@ logical_plan 03)----TableScan: left_table projection=[a, b, c] 04)----TableScan: right_table projection=[x, y, z] physical_plan -01)SortExec: expr=[x@3 ASC NULLS LAST], preserve_partitioning=[false] -02)--NestedLoopJoinExec: join_type=Inner, filter=a@0 < x@1 -03)----DataSourceExec: partitions=1, partition_sizes=[0] -04)----DataSourceExec: partitions=1, partition_sizes=[0] +01)SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=Lt, join_type=Inner, on=(a < x) +03)----SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] +04)------DataSourceExec: partitions=1, partition_sizes=[0] +05)----DataSourceExec: partitions=1, partition_sizes=[0] query TT EXPLAIN SELECT * FROM left_table JOIN right_table ON left_table.a= t1.c2 LIMIT 20; +---- +01)CoalesceBatchesExec: target_batch_size=3, fetch=20 +01)Limit: skip=0, fetch=20 +02)--Full Join: t0.c1 = t1.c1 +02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)] +03)----DataSourceExec: partitions=1, partition_sizes=[2] +03)----TableScan: t0 projection=[c1, c2] +04)----DataSourceExec: partitions=1, partition_sizes=[2] +04)----TableScan: t1 projection=[c1, c2, c3] +logical_plan +physical_plan + query IIIIB rowsort -- Note: using LIMIT value higher than cardinality before LIMIT to avoid query non-determinism SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 20; @@ -4238,7 +4254,7 @@ physical_plan 03)----DataSourceExec: partitions=1, partition_sizes=[2] 04)----DataSourceExec: partitions=1, partition_sizes=[2] -## Test join.on.is_empty() && join.filter.is_some() +## Test join.on.is_empty() && join.filter.is_some() -> single filter now a PWMJ query TT EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 2; ---- @@ -4249,9 +4265,10 @@ logical_plan 04)----TableScan: t1 projection=[c1, c2, c3] physical_plan 01)GlobalLimitExec: skip=0, fetch=2 -02)--NestedLoopJoinExec: join_type=Full, filter=c2@0 >= c2@1 -03)----DataSourceExec: partitions=1, partition_sizes=[2] -04)----DataSourceExec: partitions=1, partition_sizes=[2] +02)--PiecewiseMergeJoin: operator=GtEq, join_type=Full, on=(c2 >= c2) +03)----SortExec: expr=[c2@1 ASC NULLS LAST], preserve_partitioning=[false] +04)------DataSourceExec: partitions=1, partition_sizes=[2] +05)----DataSourceExec: partitions=1, partition_sizes=[2] ## Test !join.on.is_empty() && join.filter.is_some() query TT @@ -5161,6 +5178,44 @@ WHERE k1 < 0 ---- +# PiecewiseMergeJoin Test +statement ok +set datafusion.execution.batch_size = 8192; + +# TODO: partitioned PWMJ execution +statement ok +set 
datafusion.execution.target_partitions = 1; + +query II +SELECT join_t1.t1_id, join_t2.t2_id +FROM join_t1 +INNER JOIN join_t2 ON join_t1.t1_id > join_t2.t2_id +WHERE join_t1.t1_id > 10 AND join_t2.t2_int > 1 +ORDER BY 1 +---- +22 11 +33 11 +44 11 + +query TT +EXPLAIN +SELECT join_t1.t1_id, join_t2.t2_id +FROM join_t1 +INNER JOIN join_t2 ON join_t1.t1_id > join_t2.t2_id +WHERE join_t1.t1_id > 10 AND join_t2.t2_int > 1 +ORDER BY 1 +---- +physical_plan +01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=Gt, join_type=Inner, on=(t1_id > t2_id) +03)----SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------FilterExec: t1_id@0 > 10 +06)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)----CoalesceBatchesExec: target_batch_size=8192 +08)------FilterExec: t2_int@1 > 1, projection=[t2_id@0] +09)--------DataSourceExec: partitions=1, partition_sizes=[1] + statement ok DROP TABLE t1; diff --git a/datafusion/sqllogictest/test_files/pwmj.slt b/datafusion/sqllogictest/test_files/pwmj.slt new file mode 100644 index 0000000000000..f597051e468e5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/pwmj.slt @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +CREATE TABLE join_t1 (t1_id INT); + +statement ok +CREATE TABLE join_t2 (t2_id INT, t2_name TEXT, t2_int INT); + +statement ok +INSERT INTO join_t1 VALUES (11), (22), (33), (44); + +statement ok +INSERT INTO join_t2 VALUES + (11, 'z', 3), + (22, 'y', 1), + (44, 'x', 3), + (55, 'w', 3); + +# --- sanity: ITI matches your sample ------------------------------------------------- +query ITI +SELECT t2_id, t2_name, t2_int +FROM join_t2 +ORDER BY t2_id; +---- +11 z 3 +22 y 1 +44 x 3 +55 w 3 + +# ===================================================================================== +# PWMJ candidates: exactly one join range predicate; any extra filters are per-side. 
+# ===================================================================================== + +# 1) GT with pushdowns (your exact example) +query II +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id > t2.t2_id -- single join predicate (range) +WHERE t1.t1_id > 10 -- left pushdown + AND t2.t2_int > 1 -- right pushdown +ORDER BY 1; +---- +22 11 +33 11 +44 11 + +# 2) GTE; right pushed to equality (t2_int = 3) +query II +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id >= t2.t2_id +WHERE t1.t1_id >= 22 + AND t2.t2_int = 3 +ORDER BY 1,2; +---- +22 11 +33 11 +44 11 +44 44 + +query II +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id < t2.t2_id +WHERE t2.t2_int >= 3 +ORDER BY 1,2; +---- +11 55 +11 44 +22 55 +22 44 +33 55 +33 44 +44 55 + +query TT +EXPLAIN +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id < t2.t2_id +WHERE t2.t2_int >= 3 +ORDER BY 1,2; +---- +logical_plan +01)Sort: t1.t1_id ASC NULLS LAST, t2.t2_id ASC NULLS LAST +02)--Inner Join: Filter: t1.t1_id < t2.t2_id +03)----SubqueryAlias: t1 +04)------TableScan: join_t1 projection=[t1_id] +05)----SubqueryAlias: t2 +06)------Projection: join_t2.t2_id +07)--------Filter: join_t2.t2_int >= Int32(3) +08)----------TableScan: join_t2 projection=[t2_id, t2_int] +physical_plan +01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=Lt, join_type=Inner, on=(t1_id < t2_id) +03)----SortExec: expr=[t1_id@0 DESC NULLS LAST], preserve_partitioning=[false] +04)------DataSourceExec: partitions=1, partition_sizes=[1] +05)----CoalesceBatchesExec: target_batch_size=8192 +06)------FilterExec: t2_int@1 >= 3, projection=[t2_id@0] +07)--------DataSourceExec: partitions=1, partition_sizes=[1] + + +query II +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id <= t2.t2_id +WHERE t1.t1_id IN (11, 44) + AND t2.t2_name <> 'y' +ORDER BY 1,2; +---- +11 55 +11 44 +11 11 +44 55 +44 44 + + From f343f71f31ed9a552f0a9e328f55b45ba7737778 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Tue, 9 Sep 2025 00:30:21 -0400 Subject: [PATCH 02/20] fmt --- datafusion/core/src/physical_planner.rs | 3 +- .../piecewise_merge_join/classic_join.rs | 69 +++++++------------ .../src/joins/piecewise_merge_join/exec.rs | 1 - datafusion/sqllogictest/test_files/joins.slt | 11 +-- 4 files changed, 30 insertions(+), 54 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index f79197dc99675..71f7e37b36a76 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1296,12 +1296,11 @@ impl DefaultPhysicalPlanner { session_state.execution_props(), )?; - Arc::new(PiecewiseMergeJoinExec::try_new( physical_left, physical_right, (on_left, on_right), - op, + op, *join_type, )?) 
} else { diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index cde31a7f5df1b..5e6647ec4a204 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -35,14 +35,11 @@ use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::PhysicalExprRef; use futures::{Stream, StreamExt}; -use log::debug; use std::{cmp::Ordering, task::ready}; use std::{sync::Arc, task::Poll}; use crate::handle_state; -use crate::joins::piecewise_merge_join::exec::{ - BufferedSide, BufferedSideReadyState, -}; +use crate::joins::piecewise_merge_join::exec::{BufferedSide, BufferedSideReadyState}; use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final; use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap}; use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult}; @@ -108,7 +105,7 @@ pub(super) struct ClassicPWMJStream { join_metrics: BuildProbeJoinMetrics, // Tracking incremental state for emitting record batches batch_process_state: BatchProcessState, - // Creates batch size + // Creates batch size batch_size: usize, } @@ -298,8 +295,8 @@ impl ClassicPWMJStream { let empty_stream_batch = RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); - let buffered_chunk_ref = buffered_indices - .slice(start_idx, self.batch_size); + let buffered_chunk_ref = + buffered_indices.slice(start_idx, self.batch_size); let new_buffered_indices = buffered_chunk_ref .as_any() .downcast_ref::() @@ -316,9 +313,8 @@ impl ClassicPWMJStream { new_buffered_indices.clone(), )?; - self.batch_process_state.set_process_rest(Some( - start_idx + self.batch_size, - )); + self.batch_process_state + .set_process_rest(Some(start_idx + self.batch_size)); self.batch_process_state.continue_process = true; return Ok(StatefulStreamResult::Ready(Some(batch))); @@ -372,10 +368,8 @@ impl ClassicPWMJStream { let empty_stream_batch = RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); - let indices_chunk_ref = buffered_indices.slice( - self.batch_process_state.start_idx, - self.batch_size, - ); + let indices_chunk_ref = buffered_indices + .slice(self.batch_process_state.start_idx, self.batch_size); let indices_chunk = indices_chunk_ref .as_any() @@ -508,7 +502,7 @@ fn resolve_classic_join( let mut buffered_indices = UInt64Builder::default(); let mut stream_indices = UInt32Builder::default(); - debug!("wow!"); + // Our pivot variable allows us to start probing on the buffered side where we last matched // in the previous stream row. 
let mut pivot = batch_process_state.pivot(); @@ -524,18 +518,13 @@ fn resolve_classic_join( if let Some(start_idx) = batch_process_state.process_rest { let count = buffered_values.len() - start_idx; if count >= batch_size { - let stream_repeated = - vec![row_idx as u32; batch_size]; - batch_process_state.set_process_rest(Some( - start_idx + batch_size, - )); - batch_process_state.set_rows( - batch_process_state.num_rows - + batch_size, - ); + let stream_repeated = vec![row_idx as u32; batch_size]; + batch_process_state + .set_process_rest(Some(start_idx + batch_size)); + batch_process_state + .set_rows(batch_process_state.num_rows + batch_size); let buffered_range: Vec = (start_idx as u64 - ..((start_idx as u64) - + (batch_size as u64))) + ..((start_idx as u64) + (batch_size as u64))) .collect(); stream_indices.append_slice(&stream_repeated); buffered_indices.append_slice(&buffered_range); @@ -584,21 +573,15 @@ fn resolve_classic_join( // If the current output + new output is over our process value then we want to be // able to change that - if batch_process_state.num_rows + count - >= batch_size - { - let process_batch_size = batch_size - - batch_process_state.num_rows; + if batch_process_state.num_rows + count >= batch_size { + let process_batch_size = + batch_size - batch_process_state.num_rows; let stream_repeated = vec![row_idx as u32; process_batch_size]; batch_process_state.set_rows( batch_process_state.num_rows + process_batch_size, ); - debug!( - "pivot: {}, process_batch_size: {}", - pivot, process_batch_size - ); let buffered_range: Vec = (pivot as u64 ..(pivot + process_batch_size) as u64) .collect(); @@ -628,10 +611,7 @@ fn resolve_classic_join( // Update the number of rows processed batch_process_state .set_rows(batch_process_state.num_rows + count); - debug!( - "pivot: {}, process_batch_size: {}", - pivot, buffered_len - ); + let stream_repeated = vec![row_idx as u32; count]; let buffered_range: Vec = (pivot as u64..buffered_len as u64).collect(); @@ -649,15 +629,13 @@ fn resolve_classic_join( // If the current output + new output is over our process value then we want to be // able to change that - if batch_process_state.num_rows + count - >= batch_size - { + if batch_process_state.num_rows + count >= batch_size { // Update the start index so it repeats the process batch_process_state.set_start_idx(row_idx); batch_process_state.set_pivot(pivot); - let process_batch_size = batch_size - - batch_process_state.num_rows; + let process_batch_size = + batch_size - batch_process_state.num_rows; let stream_repeated = vec![row_idx as u32; process_batch_size]; batch_process_state @@ -717,8 +695,7 @@ fn resolve_classic_join( if (!found || batch_process_state.not_found) && matches!(join_type, JoinType::Right | JoinType::Full) { - let remaining = batch_size - .saturating_sub(batch_process_state.num_rows); + let remaining = batch_size.saturating_sub(batch_process_state.num_rows); if remaining == 0 { let batch = process_batch( &mut buffered_indices, diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index d4980032728ad..ee39aae1653e8 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -59,7 +59,6 @@ use crate::{ }; use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; - /// `PiecewiseMergeJoinExec` is a join execution plan that only evaluates single range filter. 
/// /// The physical planner will choose to evalute this join when there is only one range predicate. This diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index daf4229cafd5d..1be58f63d06ae 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4201,14 +4201,15 @@ query TT rowsort -- Note: using LIMIT value higher than cardinality before LIMIT to avoid query non-determinism EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 20; ---- -01)CoalesceBatchesExec: target_batch_size=3, fetch=20 +01)GlobalLimitExec: skip=0, fetch=20 01)Limit: skip=0, fetch=20 -02)--Full Join: t0.c1 = t1.c1 -02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)] -03)----DataSourceExec: partitions=1, partition_sizes=[2] +02)--Full Join: Filter: t0.c2 >= t1.c2 +02)--PiecewiseMergeJoin: operator=GtEq, join_type=Full, on=(c2 >= c2) +03)----SortExec: expr=[c2@1 ASC NULLS LAST], preserve_partitioning=[false] 03)----TableScan: t0 projection=[c1, c2] -04)----DataSourceExec: partitions=1, partition_sizes=[2] +04)------DataSourceExec: partitions=1, partition_sizes=[2] 04)----TableScan: t1 projection=[c1, c2, c3] +05)----DataSourceExec: partitions=1, partition_sizes=[2] logical_plan physical_plan From 248ae4919220051aefbb0ac0b22dac794ab5eaf8 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Tue, 9 Sep 2025 02:23:03 -0400 Subject: [PATCH 03/20] clippy + fix test --- datafusion/core/src/physical_planner.rs | 4 +- .../piecewise_merge_join/classic_join.rs | 51 +++++++++++++++++++ .../src/joins/piecewise_merge_join/exec.rs | 2 +- datafusion/sqllogictest/test_files/joins.slt | 11 ++++ 4 files changed, 65 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 71f7e37b36a76..4ce845bbbdb28 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1287,12 +1287,12 @@ impl DefaultPhysicalPlanner { let on_left = create_physical_expr( lhs_logical, - &left_df_schema, + left_df_schema, session_state.execution_props(), )?; let on_right = create_physical_expr( rhs_logical, - &right_df_schema, + right_df_schema, session_state.execution_props(), )?; diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index 5e6647ec4a204..c1a2d9ffabca1 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -486,6 +486,7 @@ impl Stream for ClassicPWMJStream { } // For Left, Right, Full, and Inner joins, incoming stream batches will already be sorted. 
+#[allow(clippy::too_many_arguments)] fn resolve_classic_join( buffered_side: &mut BufferedSideReadyState, stream_batch: &StreamedBatch, @@ -1182,6 +1183,56 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_full_greater_than_equal_to_on_c2_limit20() -> Result<()> { + // t0 + // +----+----+ + // | c1 | c2 | + // +----+----+ + // | 1 | 1 | + // | 2 | 2 | + // | 3 | 3 | + // | 4 | 4 | + // +----+----+ + let right = build_table( + ("c1", &vec![1, 2, 3, 4]), + ("c2", &vec![1, 2, 3, 4]), + ("c3", &vec![1, 2, 3, 4]), + ); + + // t1 + // +------+----+------+ + // | c1_r | c2 | c3 | + // +------+----+------+ + // | 2 | 2 | false| + // | 2 | 2 | true | + // | 3 | 3 | false| + // | 3 | 3 | true | + // +------+----+------+ + let left = build_table( + ("c1_r", &vec![2, 2, 3, 3]), + ("c2", &vec![2, 2, 3, 3]), + ("c3", &vec![1, 2, 3, 4]), + ); + + // ON t0.c2 >= t1.c2 + let on = ( + Arc::new(Column::new_with_schema("c2", &left.schema())?) as _, + Arc::new(Column::new_with_schema("c2", &right.schema())?) as _, + ); + + // Run FULL join with >= + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Full).await?; + + // LIMIT 20 in the SQL simply ensured "all rows" (we have 11), so we assert the full result set. + assert_snapshot!(batches_to_string(&batches), @r#" + + "#); + + Ok(()) + } + #[tokio::test] async fn join_left_greater_than() -> Result<()> { // +----+----+----+ diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index ee39aae1653e8..2e7362131991b 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -61,7 +61,7 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; /// `PiecewiseMergeJoinExec` is a join execution plan that only evaluates single range filter. /// -/// The physical planner will choose to evalute this join when there is only one range predicate. This +/// The physical planner will choose to evaluate this join when there is only one range predicate. 
This /// is a binary expression which contains [`Operator::Lt`], [`Operator::LtEq`], [`Operator::Gt`], and /// [`Operator::GtEq`].: /// Examples: diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 1be58f63d06ae..c9bbf3cf5734f 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4197,6 +4197,14 @@ SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 LIMIT 20; 3 3 3 3 true 4 4 NULL NULL NULL + +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +set datafusion.execution.batch_size = 4; + + query TT rowsort -- Note: using LIMIT value higher than cardinality before LIMIT to avoid query non-determinism EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 20; @@ -4229,6 +4237,9 @@ SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 20; 4 4 3 3 false 4 4 3 3 true +statement ok +set datafusion.execution.batch_size = 3; + query IIIIB rowsort -- Note: using LIMIT value higher than cardinality before LIMIT to avoid query non-determinism SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 AND t0.c2 >= t1.c2 LIMIT 20; From 1020e651c7a9556630311944587017978f7b2cde Mon Sep 17 00:00:00 2001 From: Jonathan Date: Tue, 9 Sep 2025 02:52:34 -0400 Subject: [PATCH 04/20] fix tests --- .../piecewise_merge_join/classic_join.rs | 50 ------------------- datafusion/physical-plan/src/joins/utils.rs | 9 ++-- 2 files changed, 5 insertions(+), 54 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index c1a2d9ffabca1..c0fbadc34d8bd 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -1183,56 +1183,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn join_full_greater_than_equal_to_on_c2_limit20() -> Result<()> { - // t0 - // +----+----+ - // | c1 | c2 | - // +----+----+ - // | 1 | 1 | - // | 2 | 2 | - // | 3 | 3 | - // | 4 | 4 | - // +----+----+ - let right = build_table( - ("c1", &vec![1, 2, 3, 4]), - ("c2", &vec![1, 2, 3, 4]), - ("c3", &vec![1, 2, 3, 4]), - ); - - // t1 - // +------+----+------+ - // | c1_r | c2 | c3 | - // +------+----+------+ - // | 2 | 2 | false| - // | 2 | 2 | true | - // | 3 | 3 | false| - // | 3 | 3 | true | - // +------+----+------+ - let left = build_table( - ("c1_r", &vec![2, 2, 3, 3]), - ("c2", &vec![2, 2, 3, 3]), - ("c3", &vec![1, 2, 3, 4]), - ); - - // ON t0.c2 >= t1.c2 - let on = ( - Arc::new(Column::new_with_schema("c2", &left.schema())?) as _, - Arc::new(Column::new_with_schema("c2", &right.schema())?) as _, - ); - - // Run FULL join with >= - let (_, batches) = - join_collect(left, right, on, Operator::GtEq, JoinType::Full).await?; - - // LIMIT 20 in the SQL simply ensured "all rows" (we have 11), so we assert the full result set. 
- assert_snapshot!(batches_to_string(&batches), @r#" - - "#); - - Ok(()) - } - #[tokio::test] async fn join_left_greater_than() -> Result<()> { // +----+----+----+ diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 420825d831f89..005cb0a679f63 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -44,10 +44,7 @@ use arrow::array::{ UInt32Array, UInt32Builder, UInt64Array, }; use arrow::array::{ - ArrayRef, BooleanArray, Date32Array, Date64Array, Decimal128Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, - StringArray, StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt8Array, + ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, StringArray, StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt8Array }; use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::kernels::cmp::eq; @@ -1829,6 +1826,10 @@ pub fn compare_join_arrays( DataType::UInt64 => compare_value!(UInt64Array), DataType::Float32 => compare_value!(Float32Array), DataType::Float64 => compare_value!(Float64Array), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), DataType::Utf8 => compare_value!(StringArray), DataType::Utf8View => compare_value!(StringViewArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), From 29c0ff087f7d6a90c2e13dc47f9ccfa6cc829e68 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Tue, 9 Sep 2025 03:03:24 -0400 Subject: [PATCH 05/20] fmt --- datafusion/physical-plan/src/joins/utils.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 005cb0a679f63..b41a3e0514cf0 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -44,7 +44,11 @@ use arrow::array::{ UInt32Array, UInt32Builder, UInt64Array, }; use arrow::array::{ - ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, StringArray, StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt8Array + ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, + Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, StringArray, + StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt8Array, }; use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::kernels::cmp::eq; From cb94a2072717fc52b5cb4dd4e2e77cad62a11543 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Tue, 9 Sep 2025 13:58:55 -0400 Subject: [PATCH 06/20] clean up slt 
tests --- .../piecewise_merge_join/classic_join.rs | 1 - .../src/joins/piecewise_merge_join/exec.rs | 2 +- datafusion/sqllogictest/test_files/pwmj.slt | 116 ++++++++++++++---- 3 files changed, 96 insertions(+), 23 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index c0fbadc34d8bd..55c8245b45079 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -353,7 +353,6 @@ impl ClassicPWMJStream { return exec_err!("Batch process state should hold buffered indices"); } - // Pass in piecewise flag to allow Right Semi/Anti/Mark joins to also be processed let (buffered_indices, streamed_indices) = get_final_indices_from_shared_bitmap( &buffered_data.visited_indices_bitmap, self.join_type, diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index 2e7362131991b..4bcd1ffa6f801 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -433,7 +433,7 @@ impl PiecewiseMergeJoinExec { } } - // TODO: We implement this with the physical planner. + // TODO pub fn swap_inputs(&self) -> Result> { todo!() } diff --git a/datafusion/sqllogictest/test_files/pwmj.slt b/datafusion/sqllogictest/test_files/pwmj.slt index f597051e468e5..ee9622e6bb1b9 100644 --- a/datafusion/sqllogictest/test_files/pwmj.slt +++ b/datafusion/sqllogictest/test_files/pwmj.slt @@ -35,36 +35,50 @@ INSERT INTO join_t2 VALUES (44, 'x', 3), (55, 'w', 3); -# --- sanity: ITI matches your sample ------------------------------------------------- -query ITI -SELECT t2_id, t2_name, t2_int -FROM join_t2 -ORDER BY t2_id; ----- -11 z 3 -22 y 1 -44 x 3 -55 w 3 - -# ===================================================================================== -# PWMJ candidates: exactly one join range predicate; any extra filters are per-side. 
-# ===================================================================================== - -# 1) GT with pushdowns (your exact example) query II SELECT t1.t1_id, t2.t2_id FROM join_t1 t1 JOIN join_t2 t2 - ON t1.t1_id > t2.t2_id -- single join predicate (range) -WHERE t1.t1_id > 10 -- left pushdown - AND t2.t2_int > 1 -- right pushdown + ON t1.t1_id > t2.t2_id +WHERE t1.t1_id > 10 + AND t2.t2_int > 1 ORDER BY 1; ---- 22 11 33 11 44 11 -# 2) GTE; right pushed to equality (t2_int = 3) +query TT +EXPLAIN +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id > t2.t2_id +WHERE t1.t1_id > 10 + AND t2.t2_int > 1 +ORDER BY 1; +---- +logical_plan +01)Sort: t1.t1_id ASC NULLS LAST +02)--Inner Join: Filter: t1.t1_id > t2.t2_id +03)----SubqueryAlias: t1 +04)------Filter: join_t1.t1_id > Int32(10) +05)--------TableScan: join_t1 projection=[t1_id] +06)----SubqueryAlias: t2 +07)------Projection: join_t2.t2_id +08)--------Filter: join_t2.t2_int > Int32(1) +09)----------TableScan: join_t2 projection=[t2_id, t2_int] +physical_plan +01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=Gt, join_type=Inner, on=(t1_id > t2_id) +03)----SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------FilterExec: t1_id@0 > 10 +06)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)----CoalesceBatchesExec: target_batch_size=8192 +08)------FilterExec: t2_int@1 > 1, projection=[t2_id@0] +09)--------DataSourceExec: partitions=1, partition_sizes=[1] + query II SELECT t1.t1_id, t2.t2_id FROM join_t1 t1 @@ -79,6 +93,37 @@ ORDER BY 1,2; 44 11 44 44 +query TT +EXPLAIN +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id >= t2.t2_id +WHERE t1.t1_id >= 22 + AND t2.t2_int = 3 +ORDER BY 1,2; +---- +logical_plan +01)Sort: t1.t1_id ASC NULLS LAST, t2.t2_id ASC NULLS LAST +02)--Inner Join: Filter: t1.t1_id >= t2.t2_id +03)----SubqueryAlias: t1 +04)------Filter: join_t1.t1_id >= Int32(22) +05)--------TableScan: join_t1 projection=[t1_id] +06)----SubqueryAlias: t2 +07)------Projection: join_t2.t2_id +08)--------Filter: join_t2.t2_int = Int32(3) +09)----------TableScan: join_t2 projection=[t2_id, t2_int] +physical_plan +01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=GtEq, join_type=Inner, on=(t1_id >= t2_id) +03)----SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------FilterExec: t1_id@0 >= 22 +06)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)----CoalesceBatchesExec: target_batch_size=8192 +08)------FilterExec: t2_int@1 = 3, projection=[t2_id@0] +09)--------DataSourceExec: partitions=1, partition_sizes=[1] + query II SELECT t1.t1_id, t2.t2_id FROM join_t1 t1 @@ -138,4 +183,33 @@ ORDER BY 1,2; 44 55 44 44 - +query TT +EXPLAIN +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id <= t2.t2_id +WHERE t1.t1_id IN (11, 44) + AND t2.t2_name <> 'y' +ORDER BY 1,2; +---- +logical_plan +01)Sort: t1.t1_id ASC NULLS LAST, t2.t2_id ASC NULLS LAST +02)--Inner Join: Filter: t1.t1_id <= t2.t2_id +03)----SubqueryAlias: t1 +04)------Filter: join_t1.t1_id = Int32(11) OR join_t1.t1_id = Int32(44) +05)--------TableScan: join_t1 projection=[t1_id] +06)----SubqueryAlias: t2 +07)------Projection: join_t2.t2_id +08)--------Filter: join_t2.t2_name != Utf8View("y") +09)----------TableScan: join_t2 
projection=[t2_id, t2_name] +physical_plan +01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=LtEq, join_type=Inner, on=(t1_id <= t2_id) +03)----SortExec: expr=[t1_id@0 DESC NULLS LAST], preserve_partitioning=[false] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------FilterExec: t1_id@0 = 11 OR t1_id@0 = 44 +06)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)----CoalesceBatchesExec: target_batch_size=8192 +08)------FilterExec: t2_name@1 != y, projection=[t2_id@0] +09)--------DataSourceExec: partitions=1, partition_sizes=[1] From 18ee4cb4bc71641b2bf8a16e97a432728bd7ffb6 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Sun, 14 Sep 2025 19:48:49 -0400 Subject: [PATCH 07/20] fixes --- datafusion/common/src/config.rs | 4 ++ datafusion/core/src/physical_planner.rs | 30 +++++++--- .../src/joins/piecewise_merge_join/exec.rs | 58 ++++++++++--------- .../src/joins/piecewise_merge_join/mod.rs | 2 + 4 files changed, 60 insertions(+), 34 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index abc8862b675ed..e0b52d3dfc3f6 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -831,6 +831,10 @@ config_namespace! { /// HashJoin can work more efficiently than SortMergeJoin but consumes more memory pub prefer_hash_join: bool, default = true + /// When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. + /// HashJoin can work more efficiently than SortMergeJoin but consumes more memory + pub allow_hash_join: bool, default = true + /// The maximum estimated size in bytes for one input side of a HashJoin /// will be collected into a single partition pub hash_join_single_partition_threshold: usize, default = 1024 * 1024 diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 4ce845bbbdb28..2d615c3e9317f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1295,14 +1295,28 @@ impl DefaultPhysicalPlanner { right_df_schema, session_state.execution_props(), )?; - - Arc::new(PiecewiseMergeJoinExec::try_new( - physical_left, - physical_right, - (on_left, on_right), - op, - *join_type, - )?) + if matches!( + join_type, + JoinType::RightAnti + | JoinType::RightSemi + | JoinType::RightMark + ) { + Arc::new(PiecewiseMergeJoinExec::try_new( + physical_right, + physical_left, + (on_right, on_left), + op, + *join_type, + )?) + } else { + Arc::new(PiecewiseMergeJoinExec::try_new( + physical_left, + physical_right, + (on_left, on_right), + op, + *join_type, + )?) + } } else { // there is no equal join condition, use the nested loop join Arc::new(NestedLoopJoinExec::try_new( diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index 4bcd1ffa6f801..c33d12b3b5841 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -59,18 +59,15 @@ use crate::{ }; use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; -/// `PiecewiseMergeJoinExec` is a join execution plan that only evaluates single range filter. 
+/// `PiecewiseMergeJoinExec` is a join execution plan that only evaluates single range filter and show much +/// better performance for these workloads than `NestedLoopJoin` /// -/// The physical planner will choose to evaluate this join when there is only one range predicate. This +/// The physical planner will choose to evaluate this join when there is only one comparison filter. This /// is a binary expression which contains [`Operator::Lt`], [`Operator::LtEq`], [`Operator::Gt`], and /// [`Operator::GtEq`].: /// Examples: /// - `col0` < `colb`, `col0` <= `colb`, `col0` > `colb`, `col0` >= `colb` /// -/// Since the join only support range predicates, equijoins are not supported in `PiecewiseMergeJoinExec`, -/// however you can first evaluate another join and run `PiecewiseMergeJoinExec` if left with one range -/// predicate. -/// /// # Execution Plan Inputs /// For `PiecewiseMergeJoin` we label all right inputs as the `streamed' side and the left outputs as the /// 'buffered' side. @@ -82,15 +79,19 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; /// Classic joins are processed differently compared to existence joins. /// /// ## Classic Joins (Inner, Full, Left, Right) -/// For classic joins we buffer the right side (buffered), and incrementally process the left side (streamed). -/// Every streamed batch is sorted so we can perform a sort merge algorithm. For the buffered side we want to -/// have it already sorted either ascending or descending based on the operator as this allows us to emit all -/// the rows from a given point to the end as matches. Sorting the streamed side allows us to start the pointer -/// from the previous row's match on the buffered side. -/// -/// For `Lt` (`<`) + `LtEq` (`<=`) operations both inputs are to be sorted in descending order and sorted in -/// ascending order for `Gt` (`>`) + `GtEq` (`>=`) than (`>`) operations. `SortExec` is used to enforce sorting -/// on the buffered side and streamed side is sorted in memory. +/// For classic joins we buffer the right side (the "build" side) and stream the left side (the "probe" side). +/// Both sides are sorted so that we can iterate from index 0 to the end on each side. This ordering ensures +/// that when we find the first matching pair of rows, we can emit the current left row joined with all remaining +/// right rows from the match position onward, without rescanning earlier right rows. +/// +/// For `<` and `<=` operators, both inputs are sorted in **descending** order, while for `>` and `>=` operators +/// they are sorted in **ascending** order. This choice ensures that the pointer on the buffered side can advance +/// monotonically as we stream new batches from the left side. +/// +/// The streamed (left) side may arrive unsorted, so this operator sorts each incoming batch in memory before +/// processing. The buffered (right) side is required to be globally sorted; the plan declares this requirement +/// in `requires_input_order`, which allows the optimizer to automatically insert a `SortExec` on that side if needed. +/// By the time this operator runs, the right side is guaranteed to be in the proper order. /// /// The pseudocode for the algorithm looks like this: /// @@ -103,8 +104,9 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; /// continue /// ``` /// -/// The algorithm uses the streamed side to drive the loop. This is due to every row on the stream side iterating -/// the buffered side to find every first match. 
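As a minimal, self-contained sketch of the loop the pseudocode above describes, assuming the predicate is written `buffered > streamed` and both sides are sorted ascending (plain i32 slices instead of Arrow arrays; not the operator's actual code):

// For each streamed row, advance the buffered pointer until the first match,
// then every remaining buffered row is a match.
fn piecewise_matches(streamed: &[i32], buffered: &[i32]) -> Vec<(usize, usize)> {
    let mut out = Vec::new();
    let mut buffer_idx = 0;
    for (stream_idx, s) in streamed.iter().enumerate() {
        // The pointer only moves forward: a buffered row skipped here is <= s,
        // so it can never match a later (larger) streamed value either.
        while buffer_idx < buffered.len() && buffered[buffer_idx] <= *s {
            buffer_idx += 1;
        }
        // Every buffered row from the match position onward satisfies `buffered > streamed`.
        for b in buffer_idx..buffered.len() {
            out.push((stream_idx, b));
        }
    }
    out
}

fn main() {
    let streamed = vec![1, 3, 5];
    let buffered = vec![2, 4, 6];
    assert_eq!(
        piecewise_matches(&streamed, &buffered),
        vec![(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]
    );
}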
+/// The algorithm uses the streamed side (larger) to drive the loop. This is due to every row on the stream side iterating +/// the buffered side to find every first match. By doing this, each match can output more result so that output +/// handling can be better vectorized for performance. /// /// Here is an example: /// @@ -214,11 +216,15 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; /// # Performance Explanation (cost) /// Piecewise Merge Join is used over Nested Loop Join due to its superior performance. Here is the breakdown: /// +/// R: Buffered Side +/// S: Streamed Side +/// /// ## Piecewise Merge Join (PWMJ) +/// /// # Classic Join: /// Requires sorting the probe side and, for each probe row, scanning the buffered side until the first match /// is found. -/// Complexity: `O(sort(S) + |S| * scan(R))`. +/// Complexity: `O(sort(S) + num_of_batches(|S|) * scan(R))`. /// /// # Mark Join: /// Sorts the probe side, then computes the min/max range of the probe keys and scans the buffered side only @@ -230,7 +236,7 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; /// Complexity: `O(|S| * |R|)`. /// /// ## Nested Loop Join -/// Always going to be probe (O(N) * O(N)). +/// Always going to be probe (O(S) * O(R)). /// /// # Further Reference Material /// DuckDB blog on Range Joins: [Range Joins in DuckDB](https://duckdb.org/2022/05/27/iejoin.html) @@ -252,12 +258,17 @@ pub struct PiecewiseMergeJoinExec { buffered_fut: OnceAsync, /// Execution metrics metrics: ExecutionPlanMetricsSet, + + /// Sort expressions - See above for more details [`PiecewiseMergeJoinExec`] + /// /// The left sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations left_sort_exprs: LexOrdering, /// The right sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations /// Unsorted for mark joins + #[allow(unused)] right_sort_exprs: LexOrdering, - /// Sort options of join columns used in sorting the stream and buffered execution plans + + /// This determines the sort order of all join columns used in sorting the stream and buffered execution plans. sort_options: SortOptions, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, @@ -459,12 +470,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { fn required_input_ordering(&self) -> Vec> { // Existence joins don't need to be sorted on one side. if is_right_existence_join(self.join_type) { - // Right side needs to be sorted because this will be swapped to the - // buffered side - vec![ - None, - Some(OrderingRequirements::from(self.right_sort_exprs.clone())), - ] + unimplemented!() } else { // Sort the right side in memory, so we do not need to enforce any sorting vec![ diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs index f66de0ddab43c..c85a7cc16f657 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! 
PiecewiseMergeJoin is currently experimental + pub use exec::PiecewiseMergeJoinExec; mod classic_join; From e3d8606207eb86674d202b71ff184ec9e7f24fab Mon Sep 17 00:00:00 2001 From: Jonathan Date: Sun, 14 Sep 2025 22:58:36 -0400 Subject: [PATCH 08/20] remove swap --- datafusion/common/src/config.rs | 4 -- datafusion/core/src/physical_planner.rs | 30 +++-------- .../src/joins/piecewise_merge_join/exec.rs | 52 +++++++------------ 3 files changed, 27 insertions(+), 59 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index e0b52d3dfc3f6..abc8862b675ed 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -831,10 +831,6 @@ config_namespace! { /// HashJoin can work more efficiently than SortMergeJoin but consumes more memory pub prefer_hash_join: bool, default = true - /// When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. - /// HashJoin can work more efficiently than SortMergeJoin but consumes more memory - pub allow_hash_join: bool, default = true - /// The maximum estimated size in bytes for one input side of a HashJoin /// will be collected into a single partition pub hash_join_single_partition_threshold: usize, default = 1024 * 1024 diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 2d615c3e9317f..4ce845bbbdb28 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1295,28 +1295,14 @@ impl DefaultPhysicalPlanner { right_df_schema, session_state.execution_props(), )?; - if matches!( - join_type, - JoinType::RightAnti - | JoinType::RightSemi - | JoinType::RightMark - ) { - Arc::new(PiecewiseMergeJoinExec::try_new( - physical_right, - physical_left, - (on_right, on_left), - op, - *join_type, - )?) - } else { - Arc::new(PiecewiseMergeJoinExec::try_new( - physical_left, - physical_right, - (on_left, on_right), - op, - *join_type, - )?) - } + + Arc::new(PiecewiseMergeJoinExec::try_new( + physical_left, + physical_right, + (on_left, on_right), + op, + *join_type, + )?) } else { // there is no equal join condition, use the nested loop join Arc::new(NestedLoopJoinExec::try_new( diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index c33d12b3b5841..a51b35d5ab7b7 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -262,11 +262,11 @@ pub struct PiecewiseMergeJoinExec { /// Sort expressions - See above for more details [`PiecewiseMergeJoinExec`] /// /// The left sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations - left_sort_exprs: LexOrdering, + left_child_plan_required_order: LexOrdering, /// The right sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations /// Unsorted for mark joins #[allow(unused)] - right_sort_exprs: LexOrdering, + ight_batch_required_orders: LexOrdering, /// This determines the sort order of all join columns used in sorting the stream and buffered execution plans. 
sort_options: SortOptions, @@ -316,17 +316,21 @@ impl PiecewiseMergeJoinExec { }; // Give the same `sort_option for comparison later` - let left_sort_exprs = + let left_child_plan_required_order = vec![PhysicalSortExpr::new(Arc::clone(&on.0), sort_options)]; - let right_sort_exprs = + let ight_batch_required_orders = vec![PhysicalSortExpr::new(Arc::clone(&on.1), sort_options)]; - let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else { + let Some(left_child_plan_required_order) = + LexOrdering::new(left_child_plan_required_order) + else { return internal_err!( "PiecewiseMergeJoinExec requires valid sort expressions for its left side" ); }; - let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else { + let Some(ight_batch_required_orders) = + LexOrdering::new(ight_batch_required_orders) + else { return internal_err!( "PiecewiseMergeJoinExec requires valid sort expressions for its right side" ); @@ -355,8 +359,8 @@ impl PiecewiseMergeJoinExec { schema, buffered_fut: Default::default(), metrics: ExecutionPlanMetricsSet::new(), - left_sort_exprs, - right_sort_exprs, + left_child_plan_required_order, + ight_batch_required_orders, sort_options, cache, }) @@ -474,7 +478,9 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { } else { // Sort the right side in memory, so we do not need to enforce any sorting vec![ - Some(OrderingRequirements::from(self.left_sort_exprs.clone())), + Some(OrderingRequirements::from( + self.left_child_plan_required_order.clone(), + )), None, ] } @@ -507,32 +513,12 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { let on_buffered = Arc::clone(&self.on.0); let on_streamed = Arc::clone(&self.on.1); - // If the join type is either RightSemi, RightAnti, or RightMark we will swap the inputs - // and sort ordering because we want the mark side to be the buffered side. 
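For reference, the operator flip that goes along with the input swap described in the removed comment above is `Operator::swap` from `datafusion_expr` (a tiny illustrative check, not part of the patch):

use datafusion_expr::Operator;

fn main() {
    // Exchanging the two inputs of `a < b` means evaluating `b > a`.
    assert_eq!(Operator::Lt.swap(), Some(Operator::Gt));
    assert_eq!(Operator::GtEq.swap(), Some(Operator::LtEq));
}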
- let (buffered, streamed, on_buffered, on_streamed, operator) = - if is_right_existence_join(self.join_type) { - ( - Arc::clone(&self.streamed), - Arc::clone(&self.buffered), - on_streamed, - on_buffered, - self.operator.swap().unwrap(), - ) - } else { - ( - Arc::clone(&self.buffered), - Arc::clone(&self.streamed), - on_buffered, - on_streamed, - self.operator, - ) - }; - let metrics = BuildProbeJoinMetrics::new(0, &self.metrics); let buffered_fut = self.buffered_fut.try_once(|| { let reservation = MemoryConsumer::new("PiecewiseMergeJoinInput") .register(context.memory_pool()); - let buffered_stream = buffered.execute(partition, Arc::clone(&context))?; + let buffered_stream = + self.buffered.execute(partition, Arc::clone(&context))?; Ok(build_buffered_data( buffered_stream, Arc::clone(&on_buffered), @@ -542,7 +528,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { )) })?; - let streamed = streamed.execute(partition, Arc::clone(&context))?; + let streamed = self.streamed.execute(partition, Arc::clone(&context))?; let batch_size = context.session_config().batch_size(); @@ -554,7 +540,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { Arc::clone(&self.schema), on_streamed, self.join_type, - operator, + self.operator, streamed, BufferedSide::Initial(BufferedSideInitialState { buffered_fut }), PiecewiseMergeJoinStreamState::WaitBufferedSide, From 0834b983601484fb77b8fd198775a5fc09bd3021 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 17 Sep 2025 20:34:56 -0400 Subject: [PATCH 09/20] change varialbe names --- .../piecewise_merge_join/classic_join.rs | 109 ++++++++++-------- 1 file changed, 62 insertions(+), 47 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index 55c8245b45079..1864be82f3af4 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -285,9 +285,9 @@ impl ClassicPWMJStream { Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data); // Check if the same batch needs to be checked for values again - if let Some(start_idx) = self.batch_process_state.process_rest { + if let Some(stream_idx) = self.batch_process_state.process_rest { if let Some(buffered_indices) = &self.batch_process_state.buffered_indices { - let remaining = buffered_indices.len() - start_idx; + let remaining = buffered_indices.len() - stream_idx; // Branch into this and return value if there are more rows to deal with if remaining > self.batch_size { @@ -296,7 +296,7 @@ impl ClassicPWMJStream { RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); let buffered_chunk_ref = - buffered_indices.slice(start_idx, self.batch_size); + buffered_indices.slice(stream_idx, self.batch_size); let new_buffered_indices = buffered_chunk_ref .as_any() .downcast_ref::() @@ -314,7 +314,7 @@ impl ClassicPWMJStream { )?; self.batch_process_state - .set_process_rest(Some(start_idx + self.batch_size)); + .set_process_rest(Some(stream_idx + self.batch_size)); self.batch_process_state.continue_process = true; return Ok(StatefulStreamResult::Ready(Some(batch))); @@ -324,7 +324,7 @@ impl ClassicPWMJStream { let empty_stream_batch = RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); - let buffered_chunk_ref = buffered_indices.slice(start_idx, remaining); + let buffered_chunk_ref = buffered_indices.slice(stream_idx, remaining); let new_buffered_indices = buffered_chunk_ref .as_any() 
.downcast_ref::() @@ -368,7 +368,7 @@ impl ClassicPWMJStream { RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); let indices_chunk_ref = buffered_indices - .slice(self.batch_process_state.start_idx, self.batch_size); + .slice(self.batch_process_state.stream_idx, self.batch_size); let indices_chunk = indices_chunk_ref .as_any() @@ -412,11 +412,25 @@ impl ClassicPWMJStream { } // Holds all information for processing incremental output +// +// Responsibilities: +// - Keeps track of the current stream row index (`stream_idx`) so we can resume +// processing the same stream batch if we return early due to `batch_size`. +// - Remembers the last buffered row index we probed (`buffer_idx`) so we don’t +// restart from 0 for every stream row. +// - Stores `process_rest` to continue outputting matches for the same stream row +// when we previously hit the batch size limit mid-output. +// - Tracks how many rows we’ve produced so far (`num_rows`) to know when to flush. +// - Uses `not_found` to signal that the last stream row had no matches and we +// may need to emit NULLs for RIGHT/FULL OUTER joins. +// - Uses `continue_process` to tell the executor we are not done yet and must +// be called again with the same stream batch. +// - Optionally holds `buffered_indices` when resuming output of remaining matches. struct BatchProcessState { // Used to pick up from the last index on the stream side - start_idx: usize, + stream_idx: usize, // Used to pick up from the last index on the buffered side - pivot: usize, + buffer_idx: usize, // Tracks the number of rows processed; default starts at 0 num_rows: usize, // Processes the rest of the batch @@ -432,9 +446,9 @@ struct BatchProcessState { impl BatchProcessState { pub fn new() -> Self { Self { - start_idx: 0, + stream_idx: 0, num_rows: 0, - pivot: 0, + buffer_idx: 0, process_rest: None, not_found: false, continue_process: false, @@ -443,25 +457,25 @@ impl BatchProcessState { } fn reset(&mut self) { - self.start_idx = 0; + self.stream_idx = 0; self.num_rows = 0; - self.pivot = 0; + self.buffer_idx = 0; self.process_rest = None; self.not_found = false; self.continue_process = false; self.buffered_indices = None; } - fn pivot(&self) -> usize { - self.pivot + fn buffer_idx(&self) -> usize { + self.buffer_idx } - fn set_pivot(&mut self, pivot: usize) { - self.pivot = pivot; + fn set_buffer_idx(&mut self, buffer_idx: usize) { + self.buffer_idx = buffer_idx; } - fn set_start_idx(&mut self, start_idx: usize) { - self.start_idx = start_idx; + fn set_stream_idx(&mut self, stream_idx: usize) { + self.stream_idx = stream_idx; } fn set_rows(&mut self, num_rows: usize) { @@ -485,6 +499,7 @@ impl Stream for ClassicPWMJStream { } // For Left, Right, Full, and Inner joins, incoming stream batches will already be sorted. +// #[allow(clippy::too_many_arguments)] fn resolve_classic_join( buffered_side: &mut BufferedSideReadyState, @@ -503,28 +518,28 @@ fn resolve_classic_join( let mut buffered_indices = UInt64Builder::default(); let mut stream_indices = UInt32Builder::default(); - // Our pivot variable allows us to start probing on the buffered side where we last matched + // Our buffer_idx variable allows us to start probing on the buffered side where we last matched // in the previous stream row. 
- let mut pivot = batch_process_state.pivot(); - for row_idx in batch_process_state.start_idx..stream_values[0].len() { + let mut buffer_idx = batch_process_state.buffer_idx(); + for row_idx in batch_process_state.stream_idx..stream_values[0].len() { let mut found = false; // Check once to see if it is a redo of a null value if not we do not try to process the batch if !batch_process_state.not_found { - while pivot < buffered_values.len() + while buffer_idx < buffered_values.len() || batch_process_state.process_rest.is_some() { // If there is still data left in the batch to process, use the index and output - if let Some(start_idx) = batch_process_state.process_rest { - let count = buffered_values.len() - start_idx; + if let Some(stream_idx) = batch_process_state.process_rest { + let count = buffered_values.len() - stream_idx; if count >= batch_size { let stream_repeated = vec![row_idx as u32; batch_size]; batch_process_state - .set_process_rest(Some(start_idx + batch_size)); + .set_process_rest(Some(stream_idx + batch_size)); batch_process_state .set_rows(batch_process_state.num_rows + batch_size); - let buffered_range: Vec = (start_idx as u64 - ..((start_idx as u64) + (batch_size as u64))) + let buffered_range: Vec = (stream_idx as u64 + ..((stream_idx as u64) + (batch_size as u64))) .collect(); stream_indices.append_slice(&stream_repeated); buffered_indices.append_slice(&buffered_range); @@ -546,7 +561,7 @@ fn resolve_classic_join( batch_process_state.set_rows(batch_process_state.num_rows + count); let stream_repeated = vec![row_idx as u32; count]; let buffered_range: Vec = - (start_idx as u64..buffered_len as u64).collect(); + (stream_idx as u64..buffered_len as u64).collect(); stream_indices.append_slice(&stream_repeated); buffered_indices.append_slice(&buffered_range); batch_process_state.process_rest = None; @@ -560,7 +575,7 @@ fn resolve_classic_join( &[Arc::clone(&stream_values[0])], row_idx, &[Arc::clone(buffered_values)], - pivot, + buffer_idx, &[sort_options], NullEquality::NullEqualsNothing, )?; @@ -569,7 +584,7 @@ fn resolve_classic_join( match operator { Operator::Gt | Operator::Lt => { if matches!(compare, Ordering::Less) { - let count = buffered_values.len() - pivot; + let count = buffered_values.len() - buffer_idx; // If the current output + new output is over our process value then we want to be // able to change that @@ -582,8 +597,8 @@ fn resolve_classic_join( batch_process_state.num_rows + process_batch_size, ); - let buffered_range: Vec = (pivot as u64 - ..(pivot + process_batch_size) as u64) + let buffered_range: Vec = (buffer_idx as u64 + ..(buffer_idx + process_batch_size) as u64) .collect(); stream_indices.append_slice(&stream_repeated); buffered_indices.append_slice(&buffered_range); @@ -598,11 +613,11 @@ fn resolve_classic_join( )?; batch_process_state - .set_process_rest(Some(pivot + process_batch_size)); + .set_process_rest(Some(buffer_idx + process_batch_size)); batch_process_state.continue_process = true; // Update the start index so it repeats the process - batch_process_state.set_start_idx(row_idx); - batch_process_state.set_pivot(pivot); + batch_process_state.set_stream_idx(row_idx); + batch_process_state.set_buffer_idx(buffer_idx); batch_process_state.set_rows(0); return Ok(batch); @@ -614,7 +629,7 @@ fn resolve_classic_join( let stream_repeated = vec![row_idx as u32; count]; let buffered_range: Vec = - (pivot as u64..buffered_len as u64).collect(); + (buffer_idx as u64..buffered_len as u64).collect(); stream_indices.append_slice(&stream_repeated); 
buffered_indices.append_slice(&buffered_range); @@ -625,26 +640,26 @@ fn resolve_classic_join( } Operator::GtEq | Operator::LtEq => { if matches!(compare, Ordering::Equal | Ordering::Less) { - let count = buffered_values.len() - pivot; + let count = buffered_values.len() - buffer_idx; // If the current output + new output is over our process value then we want to be // able to change that if batch_process_state.num_rows + count >= batch_size { // Update the start index so it repeats the process - batch_process_state.set_start_idx(row_idx); - batch_process_state.set_pivot(pivot); + batch_process_state.set_stream_idx(row_idx); + batch_process_state.set_buffer_idx(buffer_idx); let process_batch_size = batch_size - batch_process_state.num_rows; let stream_repeated = vec![row_idx as u32; process_batch_size]; batch_process_state - .set_process_rest(Some(pivot + process_batch_size)); + .set_process_rest(Some(buffer_idx + process_batch_size)); batch_process_state.set_rows( batch_process_state.num_rows + process_batch_size, ); - let buffered_range: Vec = (pivot as u64 - ..(pivot + process_batch_size) as u64) + let buffered_range: Vec = (buffer_idx as u64 + ..(buffer_idx + process_batch_size) as u64) .collect(); stream_indices.append_slice(&stream_repeated); buffered_indices.append_slice(&buffered_range); @@ -669,7 +684,7 @@ fn resolve_classic_join( .set_rows(batch_process_state.num_rows + count); let stream_repeated = vec![row_idx as u32; count]; let buffered_range: Vec = - (pivot as u64..buffered_len as u64).collect(); + (buffer_idx as u64..buffered_len as u64).collect(); stream_indices.append_slice(&stream_repeated); buffered_indices.append_slice(&buffered_range); @@ -686,8 +701,8 @@ fn resolve_classic_join( } }; - // Increment pivot after every row - pivot += 1; + // Increment buffer_idx after every row + buffer_idx += 1; } } @@ -707,8 +722,8 @@ fn resolve_classic_join( )?; // Update the start index so it repeats the process - batch_process_state.set_start_idx(row_idx); - batch_process_state.set_pivot(pivot); + batch_process_state.set_stream_idx(row_idx); + batch_process_state.set_buffer_idx(buffer_idx); batch_process_state.not_found = true; batch_process_state.continue_process = true; batch_process_state.set_rows(0); From 120296282fa6f594120aa6429a581feba5992a3c Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 17 Sep 2025 20:47:46 -0400 Subject: [PATCH 10/20] add flag --- datafusion/common/src/config.rs | 4 ++++ datafusion/core/src/physical_planner.rs | 1 + 2 files changed, 5 insertions(+) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 6abb2f5c6d3ca..e51b2829ccdf6 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -832,6 +832,10 @@ config_namespace! { /// HashJoin can work more efficiently than SortMergeJoin but consumes more memory pub prefer_hash_join: bool, default = true + /// When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently + /// experimental. 
+ pub allow_piecewise_merge_join: bool, default = false + /// The maximum estimated size in bytes for one input side of a HashJoin /// will be collected into a single partition pub hash_join_single_partition_threshold: usize, default = 1024 * 1024 diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 4ce845bbbdb28..9bb323554bbde 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1227,6 +1227,7 @@ impl DefaultPhysicalPlanner { | JoinType::LeftMark | JoinType::RightMark ) + && session_state.config_options().optimizer.allow_piecewise_merge_join { let Expr::BinaryExpr(be) = &range_filters[0] else { return plan_err!( From 502c08845da65575322ffe6b954ed09cec20414e Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 1 Oct 2025 17:25:26 -0400 Subject: [PATCH 11/20] remove duplicate function --- datafusion/core/src/physical_planner.rs | 5 +- .../piecewise_merge_join/classic_join.rs | 10 +- .../src/joins/sort_merge_join/stream.rs | 95 ------------------- 3 files changed, 10 insertions(+), 100 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 97094b743c8ad..1da06ea93d996 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1269,7 +1269,10 @@ impl DefaultPhysicalPlanner { | JoinType::LeftMark | JoinType::RightMark ) - && session_state.config_options().optimizer.allow_piecewise_merge_join + && session_state + .config_options() + .optimizer + .allow_piecewise_merge_join { let Expr::BinaryExpr(be) = &range_filters[0] else { return plan_err!( diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index 1864be82f3af4..0c7df29bb5862 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -612,8 +612,9 @@ fn resolve_classic_join( join_schema, )?; - batch_process_state - .set_process_rest(Some(buffer_idx + process_batch_size)); + batch_process_state.set_process_rest(Some( + buffer_idx + process_batch_size, + )); batch_process_state.continue_process = true; // Update the start index so it repeats the process batch_process_state.set_stream_idx(row_idx); @@ -653,8 +654,9 @@ fn resolve_classic_join( batch_size - batch_process_state.num_rows; let stream_repeated = vec![row_idx as u32; process_batch_size]; - batch_process_state - .set_process_rest(Some(buffer_idx + process_batch_size)); + batch_process_state.set_process_rest(Some( + buffer_idx + process_batch_size, + )); batch_process_state.set_rows( batch_process_state.num_rows + process_batch_size, ); diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs index 22a840ab9d865..5a2e3669ab5ec 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs @@ -1865,101 +1865,6 @@ fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec Result { - let mut res = Ordering::Equal; - for ((left_array, right_array), sort_options) in - left_arrays.iter().zip(right_arrays).zip(sort_options) - { - macro_rules! 
compare_value { - ($T:ty) => {{ - let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); - let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); - match (left_array.is_null(left), right_array.is_null(right)) { - (false, false) => { - let left_value = &left_array.value(left); - let right_value = &right_array.value(right); - res = left_value.partial_cmp(right_value).unwrap(); - if sort_options.descending { - res = res.reverse(); - } - } - (true, false) => { - res = if sort_options.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (false, true) => { - res = if sort_options.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - _ => { - res = match null_equality { - NullEquality::NullEqualsNothing => Ordering::Less, - NullEquality::NullEqualsNull => Ordering::Equal, - }; - } - } - }}; - } - - match left_array.data_type() { - DataType::Null => {} - DataType::Boolean => compare_value!(BooleanArray), - DataType::Int8 => compare_value!(Int8Array), - DataType::Int16 => compare_value!(Int16Array), - DataType::Int32 => compare_value!(Int32Array), - DataType::Int64 => compare_value!(Int64Array), - DataType::UInt8 => compare_value!(UInt8Array), - DataType::UInt16 => compare_value!(UInt16Array), - DataType::UInt32 => compare_value!(UInt32Array), - DataType::UInt64 => compare_value!(UInt64Array), - DataType::Float32 => compare_value!(Float32Array), - DataType::Float64 => compare_value!(Float64Array), - DataType::Utf8 => compare_value!(StringArray), - DataType::Utf8View => compare_value!(StringViewArray), - DataType::LargeUtf8 => compare_value!(LargeStringArray), - DataType::Binary => compare_value!(BinaryArray), - DataType::BinaryView => compare_value!(BinaryViewArray), - DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), - DataType::LargeBinary => compare_value!(LargeBinaryArray), - DataType::Decimal32(..) => compare_value!(Decimal32Array), - DataType::Decimal64(..) => compare_value!(Decimal64Array), - DataType::Decimal128(..) 
=> compare_value!(Decimal128Array), - DataType::Timestamp(time_unit, None) => match time_unit { - TimeUnit::Second => compare_value!(TimestampSecondArray), - TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), - TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), - TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), - }, - DataType::Date32 => compare_value!(Date32Array), - DataType::Date64 => compare_value!(Date64Array), - dt => { - return not_impl_err!( - "Unsupported data type in sort merge join comparator: {}", - dt - ); - } - } - if !res.is_eq() { - break; - } - } - Ok(res) -} - /// A faster version of compare_join_arrays() that only output whether /// the given two rows are equal fn is_join_arrays_equal( From 55a4a1dcec4258f10bd020cea901ff10d0084c3b Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 8 Oct 2025 22:23:11 -0400 Subject: [PATCH 12/20] changes --- datafusion/core/src/physical_planner.rs | 12 +- .../piecewise_merge_join/classic_join.rs | 1005 +++++++++-------- .../src/joins/piecewise_merge_join/exec.rs | 40 +- .../test_files/information_schema.slt | 2 + datafusion/sqllogictest/test_files/joins.slt | 44 +- datafusion/sqllogictest/test_files/pwmj.slt | 83 +- 6 files changed, 620 insertions(+), 566 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 1da06ea93d996..d958f721418c7 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -19,6 +19,7 @@ use std::borrow::Cow; use std::collections::HashMap; +use std::sync::atomic::AtomicUsize; use std::sync::Arc; use crate::datasource::file_format::file_type_to_format; @@ -1247,11 +1248,6 @@ impl DefaultPhysicalPlanner { let prefer_hash_join = session_state.config_options().optimizer.prefer_hash_join; - let cfg = session_state.config(); - - let can_run_single = - cfg.target_partitions() == 1 || !cfg.repartition_joins(); - // TODO: Allow PWMJ to deal with residual equijoin conditions let join: Arc = if join_on.is_empty() { if join_filter.is_none() && matches!(join_type, JoinType::Inner) { @@ -1259,7 +1255,6 @@ impl DefaultPhysicalPlanner { Arc::new(CrossJoinExec::new(physical_left, physical_right)) } else if num_range_filters == 1 && total_filters == 1 - && can_run_single && !matches!( join_type, JoinType::LeftSemi @@ -1342,12 +1337,17 @@ impl DefaultPhysicalPlanner { session_state.execution_props(), )?; + let num_partitions = Arc::new(AtomicUsize::new( + session_state.config().target_partitions(), + )); + Arc::new(PiecewiseMergeJoinExec::try_new( physical_left, physical_right, (on_left, on_right), op, *join_type, + num_partitions, )?) } else { // there is no equal join condition, use the nested loop join diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index 0c7df29bb5862..80eb4d5ecc61e 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -17,24 +17,21 @@ //! 
Stream Implementation for PiecewiseMergeJoin's Classic Join (Left, Right, Full, Inner) -use arrow::array::{ - new_null_array, Array, PrimitiveArray, PrimitiveBuilder, RecordBatchOptions, -}; -use arrow::compute::take; -use arrow::datatypes::{UInt32Type, UInt64Type}; +use arrow::array::{new_null_array, Array, PrimitiveBuilder}; +use arrow::compute::{take, BatchCoalescer}; +use arrow::datatypes::UInt32Type; use arrow::{ - array::{ - ArrayRef, RecordBatch, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder, - }, + array::{ArrayRef, RecordBatch, UInt32Array}, compute::{sort_to_indices, take_record_batch}, }; -use arrow_schema::{ArrowError, Schema, SchemaRef, SortOptions}; +use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::NullEquality; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::PhysicalExprRef; use futures::{Stream, StreamExt}; +use std::sync::atomic::AtomicUsize; use std::{cmp::Ordering, task::ready}; use std::{sync::Arc, task::Poll}; @@ -43,6 +40,7 @@ use crate::joins::piecewise_merge_join::exec::{BufferedSide, BufferedSideReadySt use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final; use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap}; use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult}; + pub(super) enum PiecewiseMergeJoinStreamState { WaitBufferedSide, FetchStreamBatch, @@ -105,8 +103,8 @@ pub(super) struct ClassicPWMJStream { join_metrics: BuildProbeJoinMetrics, // Tracking incremental state for emitting record batches batch_process_state: BatchProcessState, - // Creates batch size - batch_size: usize, + // To synchronize when partition needs to finish + remaining_partitions: Arc, } impl RecordBatchStream for ClassicPWMJStream { @@ -121,10 +119,11 @@ impl RecordBatchStream for ClassicPWMJStream { // Classic Joins // 1. `WaitBufferedSide` - Load in the buffered side data into memory. // 2. `FetchStreamBatch` - Fetch + sort incoming stream batches. We switch the state to -// `ExhaustedStreamBatch` once stream batches are exhausted. +// `Completed` if there are are still remaining partitions to process. It is only switched to +// `ExhaustedStreamBatch` if all partitions have been processed. // 3. `ProcessStreamBatch` - Compare stream batch row values against the buffered side data. // 4. `ExhaustedStreamBatch` - If the join type is Left or Inner we will return state as -// `Completed` however for Full and Right we will need to process the matched/unmatched rows. +// `Completed` however for Full and Right we will need to process the unmatched buffered rows. 
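A hedged sketch of those transitions as a bare state machine (illustrative only; the real operator drives the states from `poll_next` and uses the `remaining_partitions` counter to decide whether this partition performs the final unmatched-row pass):

#[derive(Debug, PartialEq)]
enum State {
    WaitBufferedSide,
    FetchStreamBatch,
    ProcessStreamBatch,
    ExhaustedStreamSide,
    Completed,
}

// `last_partition` stands in for the check against `remaining_partitions`.
fn next_state(state: State, stream_exhausted: bool, last_partition: bool) -> State {
    match state {
        State::WaitBufferedSide => State::FetchStreamBatch,
        State::FetchStreamBatch if !stream_exhausted => State::ProcessStreamBatch,
        State::FetchStreamBatch if last_partition => State::ExhaustedStreamSide,
        State::FetchStreamBatch => State::Completed,
        State::ProcessStreamBatch => State::FetchStreamBatch,
        State::ExhaustedStreamSide | State::Completed => State::Completed,
    }
}

fn main() {
    assert_eq!(next_state(State::WaitBufferedSide, false, false), State::FetchStreamBatch);
    assert_eq!(next_state(State::FetchStreamBatch, true, true), State::ExhaustedStreamSide);
    assert_eq!(next_state(State::FetchStreamBatch, true, false), State::Completed);
}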
impl ClassicPWMJStream { // Creates a new `PiecewiseMergeJoinStream` instance #[allow(clippy::too_many_arguments)] @@ -139,21 +138,21 @@ impl ClassicPWMJStream { sort_option: SortOptions, join_metrics: BuildProbeJoinMetrics, batch_size: usize, + remaining_partitions: Arc, ) -> Self { - let streamed_schema = streamed.schema(); Self { - schema, + schema: Arc::clone(&schema), on_streamed, join_type, operator, - streamed_schema, + streamed_schema: streamed.schema(), streamed, buffered_side, state, sort_option, join_metrics, - batch_process_state: BatchProcessState::new(), - batch_size, + batch_process_state: BatchProcessState::new(schema, batch_size), + remaining_partitions, } } @@ -209,7 +208,16 @@ impl ClassicPWMJStream { ) -> Poll>>> { match ready!(self.streamed.poll_next_unpin(cx)) { None => { - self.state = PiecewiseMergeJoinStreamState::ExhaustedStreamSide; + if self + .remaining_partitions + .fetch_sub(1, std::sync::atomic::Ordering::SeqCst) + == 1 + { + self.batch_process_state.reset(); + self.state = PiecewiseMergeJoinStreamState::ExhaustedStreamSide; + } else { + self.state = PiecewiseMergeJoinStreamState::Completed; + } } Some(Ok(batch)) => { // Evaluate the streamed physical expression on the stream batch @@ -230,6 +238,8 @@ impl ClassicPWMJStream { let stream_batch = take_record_batch(&batch, &indices)?; let stream_values = take(stream_values.as_ref(), &indices, None)?; + // Reset BatchProcessState before processing a new stream batch + self.batch_process_state.reset(); self.state = PiecewiseMergeJoinStreamState::ProcessStreamBatch(StreamedBatch { batch: stream_batch, @@ -250,6 +260,15 @@ impl ClassicPWMJStream { let buffered_side = self.buffered_side.try_as_ready_mut()?; let stream_batch = self.state.try_as_process_stream_batch_mut()?; + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + // Produce more work let batch = resolve_classic_join( buffered_side, stream_batch, @@ -258,14 +277,26 @@ impl ClassicPWMJStream { self.sort_option, self.join_type, &mut self.batch_process_state, - self.batch_size, )?; - if self.batch_process_state.continue_process { + if !self.batch_process_state.continue_process { + // We finished scanning this stream batch. + self.batch_process_state + .output_batches + .finish_buffered_batch()?; + if let Some(b) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + return Ok(StatefulStreamResult::Ready(Some(b))); + } + // Nothing pending; hand back whatever `resolve` returned (often empty) and move on. 
+ self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; return Ok(StatefulStreamResult::Ready(Some(batch))); } - self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; Ok(StatefulStreamResult::Ready(Some(batch))) } @@ -279,211 +310,116 @@ impl ClassicPWMJStream { return Ok(StatefulStreamResult::Ready(None)); } - let timer = self.join_metrics.join_time.timer(); - - let buffered_data = - Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data); - - // Check if the same batch needs to be checked for values again - if let Some(stream_idx) = self.batch_process_state.process_rest { - if let Some(buffered_indices) = &self.batch_process_state.buffered_indices { - let remaining = buffered_indices.len() - stream_idx; - - // Branch into this and return value if there are more rows to deal with - if remaining > self.batch_size { - let buffered_batch = buffered_data.batch(); - let empty_stream_batch = - RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); - - let buffered_chunk_ref = - buffered_indices.slice(stream_idx, self.batch_size); - let new_buffered_indices = buffered_chunk_ref - .as_any() - .downcast_ref::() - .expect("downcast to UInt64Array after slice"); - - let streamed_indices: UInt32Array = - (0..new_buffered_indices.len() as u32).collect(); - - let batch = build_matched_indices( - Arc::clone(&self.schema), - &empty_stream_batch, - buffered_batch, - streamed_indices, - new_buffered_indices.clone(), - )?; - - self.batch_process_state - .set_process_rest(Some(stream_idx + self.batch_size)); - self.batch_process_state.continue_process = true; - - return Ok(StatefulStreamResult::Ready(Some(batch))); - } - - let buffered_batch = buffered_data.batch(); - let empty_stream_batch = - RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); - - let buffered_chunk_ref = buffered_indices.slice(stream_idx, remaining); - let new_buffered_indices = buffered_chunk_ref - .as_any() - .downcast_ref::() - .expect("downcast to UInt64Array after slice"); - - let streamed_indices: UInt32Array = - (0..new_buffered_indices.len() as u32).collect(); - - let batch = build_matched_indices( - Arc::clone(&self.schema), - &empty_stream_batch, - buffered_batch, - streamed_indices, - new_buffered_indices.clone(), - )?; - - self.batch_process_state.reset(); + if !self.batch_process_state.continue_process { + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } - timer.done(); - self.join_metrics.output_batches.add(1); + self.batch_process_state + .output_batches + .finish_buffered_batch()?; + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { self.state = PiecewiseMergeJoinStreamState::Completed; - return Ok(StatefulStreamResult::Ready(Some(batch))); } - - return exec_err!("Batch process state should hold buffered indices"); } - let (buffered_indices, streamed_indices) = get_final_indices_from_shared_bitmap( + let buffered_data = + Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data); + + let (buffered_indices, _streamed_indices) = get_final_indices_from_shared_bitmap( &buffered_data.visited_indices_bitmap, self.join_type, true, ); - // If the output indices is larger than the limit for the incremental batching then - // proceed to outputting all matches up to that index, return batch, and the matching - // will start next on the updated index (`process_rest`) - if buffered_indices.len() > self.batch_size { - let 
buffered_batch = buffered_data.batch(); - let empty_stream_batch = - RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); - - let indices_chunk_ref = buffered_indices - .slice(self.batch_process_state.stream_idx, self.batch_size); - - let indices_chunk = indices_chunk_ref - .as_any() - .downcast_ref::() - .expect("downcast to UInt64Array after slice"); - - let batch = build_matched_indices( - Arc::clone(&self.schema), - &empty_stream_batch, - buffered_batch, - streamed_indices, - indices_chunk.clone(), - )?; - - self.batch_process_state.buffered_indices = Some(buffered_indices); - self.batch_process_state - .set_process_rest(Some(self.batch_size)); - self.batch_process_state.continue_process = true; + let new_buffered_batch = + take_record_batch(buffered_data.batch(), &buffered_indices)?; + let mut buffered_columns = new_buffered_batch.columns().to_vec(); + + let streamed_columns: Vec = self + .streamed_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), new_buffered_batch.num_rows())) + .collect(); + + buffered_columns.extend(streamed_columns); + + let batch = RecordBatch::try_new(Arc::clone(&self.schema), buffered_columns)?; + + self.batch_process_state.output_batches.push_batch(batch)?; + self.batch_process_state.continue_process = false; + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { return Ok(StatefulStreamResult::Ready(Some(batch))); } - let buffered_batch = buffered_data.batch(); - let empty_stream_batch = - RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); - - let batch = build_matched_indices( - Arc::clone(&self.schema), - &empty_stream_batch, - buffered_batch, - streamed_indices, - buffered_indices, - )?; + self.batch_process_state + .output_batches + .finish_buffered_batch()?; + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + self.state = PiecewiseMergeJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(Some(batch))); + } - timer.done(); - self.join_metrics.output_batches.add(1); self.state = PiecewiseMergeJoinStreamState::Completed; - - Ok(StatefulStreamResult::Ready(Some(batch))) + self.batch_process_state.reset(); + Ok(StatefulStreamResult::Ready(None)) } } -// Holds all information for processing incremental output -// -// Responsibilities: -// - Keeps track of the current stream row index (`stream_idx`) so we can resume -// processing the same stream batch if we return early due to `batch_size`. -// - Remembers the last buffered row index we probed (`buffer_idx`) so we don’t -// restart from 0 for every stream row. -// - Stores `process_rest` to continue outputting matches for the same stream row -// when we previously hit the batch size limit mid-output. -// - Tracks how many rows we’ve produced so far (`num_rows`) to know when to flush. -// - Uses `not_found` to signal that the last stream row had no matches and we -// may need to emit NULLs for RIGHT/FULL OUTER joins. -// - Uses `continue_process` to tell the executor we are not done yet and must -// be called again with the same stream batch. -// - Optionally holds `buffered_indices` when resuming output of remaining matches. 
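The reworked bookkeeping below leans on arrow's `BatchCoalescer` (the same `push_batch` / `next_completed_batch` / `finish_buffered_batch` calls used throughout this patch) instead of hand-rolled index slicing. A small standalone sketch of that pattern, with made-up data:

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, RecordBatch};
use arrow::compute::BatchCoalescer;
use arrow::datatypes::{DataType, Field, Schema};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    // Target output size of 4 rows, analogous to `batch_size` in the stream.
    let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 4);

    for chunk in [vec![1, 2, 3], vec![4, 5], vec![6]] {
        let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(chunk)) as ArrayRef];
        let batch = RecordBatch::try_new(Arc::clone(&schema), columns)?;
        coalescer.push_batch(batch)?;
        // Emit a batch as soon as one reaches the target size.
        while let Some(out) = coalescer.next_completed_batch() {
            println!("emitting {} rows", out.num_rows());
        }
    }
    // No more input: flush whatever is still buffered.
    coalescer.finish_buffered_batch()?;
    while let Some(out) = coalescer.next_completed_batch() {
        println!("final {} rows", out.num_rows());
    }
    Ok(())
}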
struct BatchProcessState { // Used to pick up from the last index on the stream side - stream_idx: usize, - // Used to pick up from the last index on the buffered side - buffer_idx: usize, - // Tracks the number of rows processed; default starts at 0 - num_rows: usize, - // Processes the rest of the batch - process_rest: Option, - // Used to skip fully processing the row - not_found: bool, - // Signals whether to call `ProcessStreamBatch` again + output_batches: Box, + // Used to store the unmatched stream indices for `JoinType::Right` and `JoinType::Full` + unmatched_indices: PrimitiveBuilder, + // Used to store the start index on the buffered side; used to resume processing on the correct + // row + start_buffer_idx: usize, + // Used to store the start index on the stream side; used to resume processing on the correct + // row + start_stream_idx: usize, + // Signals if we found a match for the current stream row + found: bool, + // Signals to continue processing the current stream batch continue_process: bool, - // Holding the buffered indices when processing the remaining marked rows. - buffered_indices: Option>, } impl BatchProcessState { - pub fn new() -> Self { + pub(crate) fn new(schema: Arc, batch_size: usize) -> Self { Self { - stream_idx: 0, - num_rows: 0, - buffer_idx: 0, - process_rest: None, - not_found: false, - continue_process: false, - buffered_indices: None, + output_batches: Box::new(BatchCoalescer::new(schema, batch_size)), + unmatched_indices: PrimitiveBuilder::new(), + start_buffer_idx: 0, + start_stream_idx: 0, + found: false, + continue_process: true, } } - fn reset(&mut self) { - self.stream_idx = 0; - self.num_rows = 0; - self.buffer_idx = 0; - self.process_rest = None; - self.not_found = false; - self.continue_process = false; - self.buffered_indices = None; - } - - fn buffer_idx(&self) -> usize { - self.buffer_idx - } - - fn set_buffer_idx(&mut self, buffer_idx: usize) { - self.buffer_idx = buffer_idx; - } - - fn set_stream_idx(&mut self, stream_idx: usize) { - self.stream_idx = stream_idx; - } - - fn set_rows(&mut self, num_rows: usize) { - self.num_rows = num_rows; - } - - fn set_process_rest(&mut self, process_rest: Option) { - self.process_rest = process_rest; + pub(crate) fn reset(&mut self) { + self.unmatched_indices = PrimitiveBuilder::new(); + self.start_buffer_idx = 0; + self.start_stream_idx = 0; + self.found = false; + self.continue_process = true; } } @@ -499,7 +435,6 @@ impl Stream for ClassicPWMJStream { } // For Left, Right, Full, and Inner joins, incoming stream batches will already be sorted. -// #[allow(clippy::too_many_arguments)] fn resolve_classic_join( buffered_side: &mut BufferedSideReadyState, @@ -509,332 +444,188 @@ fn resolve_classic_join( sort_options: SortOptions, join_type: JoinType, batch_process_state: &mut BatchProcessState, - batch_size: usize, ) -> Result { - let buffered_values = buffered_side.buffered_data.values(); - let buffered_len = buffered_values.len(); + let buffered_len = buffered_side.buffered_data.values().len(); let stream_values = stream_batch.values(); - let mut buffered_indices = UInt64Builder::default(); - let mut stream_indices = UInt32Builder::default(); + let mut buffer_idx = batch_process_state.start_buffer_idx; + let stream_idx = batch_process_state.start_stream_idx; // Our buffer_idx variable allows us to start probing on the buffered side where we last matched // in the previous stream row. 
- let mut buffer_idx = batch_process_state.buffer_idx(); - for row_idx in batch_process_state.stream_idx..stream_values[0].len() { - let mut found = false; - - // Check once to see if it is a redo of a null value if not we do not try to process the batch - if !batch_process_state.not_found { - while buffer_idx < buffered_values.len() - || batch_process_state.process_rest.is_some() - { - // If there is still data left in the batch to process, use the index and output - if let Some(stream_idx) = batch_process_state.process_rest { - let count = buffered_values.len() - stream_idx; - if count >= batch_size { - let stream_repeated = vec![row_idx as u32; batch_size]; - batch_process_state - .set_process_rest(Some(stream_idx + batch_size)); - batch_process_state - .set_rows(batch_process_state.num_rows + batch_size); - let buffered_range: Vec = (stream_idx as u64 - ..((stream_idx as u64) + (batch_size as u64))) - .collect(); - stream_indices.append_slice(&stream_repeated); - buffered_indices.append_slice(&buffered_range); - - let batch = process_batch( - &mut buffered_indices, - &mut stream_indices, - stream_batch, - buffered_side, - join_type, - join_schema, - )?; - batch_process_state.continue_process = true; - batch_process_state.set_rows(0); - - return Ok(batch); - } - - batch_process_state.set_rows(batch_process_state.num_rows + count); - let stream_repeated = vec![row_idx as u32; count]; - let buffered_range: Vec = - (stream_idx as u64..buffered_len as u64).collect(); - stream_indices.append_slice(&stream_repeated); - buffered_indices.append_slice(&buffered_range); - batch_process_state.process_rest = None; - - found = true; - - break; - } - - let compare = compare_join_arrays( + for row_idx in stream_idx..stream_batch.batch.num_rows() { + while buffer_idx < buffered_len { + let compare = { + let buffered_values = buffered_side.buffered_data.values(); + compare_join_arrays( &[Arc::clone(&stream_values[0])], row_idx, &[Arc::clone(buffered_values)], buffer_idx, &[sort_options], NullEquality::NullEqualsNothing, - )?; + )? 
+ }; + + // If we find a match we append all indices and move to the next stream row index + match operator { + Operator::Gt | Operator::Lt => { + if matches!(compare, Ordering::Less) { + batch_process_state.found = true; + let count = buffered_len - buffer_idx; - // If we find a match we append all indices and move to the next stream row index - match operator { - Operator::Gt | Operator::Lt => { - if matches!(compare, Ordering::Less) { - let count = buffered_values.len() - buffer_idx; - - // If the current output + new output is over our process value then we want to be - // able to change that - if batch_process_state.num_rows + count >= batch_size { - let process_batch_size = - batch_size - batch_process_state.num_rows; - let stream_repeated = - vec![row_idx as u32; process_batch_size]; - batch_process_state.set_rows( - batch_process_state.num_rows + process_batch_size, - ); - - let buffered_range: Vec = (buffer_idx as u64 - ..(buffer_idx + process_batch_size) as u64) - .collect(); - stream_indices.append_slice(&stream_repeated); - buffered_indices.append_slice(&buffered_range); - - let batch = process_batch( - &mut buffered_indices, - &mut stream_indices, - stream_batch, - buffered_side, - join_type, - join_schema, - )?; - - batch_process_state.set_process_rest(Some( - buffer_idx + process_batch_size, - )); - batch_process_state.continue_process = true; - // Update the start index so it repeats the process - batch_process_state.set_stream_idx(row_idx); - batch_process_state.set_buffer_idx(buffer_idx); - batch_process_state.set_rows(0); - - return Ok(batch); - } - - // Update the number of rows processed - batch_process_state - .set_rows(batch_process_state.num_rows + count); - - let stream_repeated = vec![row_idx as u32; count]; - let buffered_range: Vec = - (buffer_idx as u64..buffered_len as u64).collect(); - - stream_indices.append_slice(&stream_repeated); - buffered_indices.append_slice(&buffered_range); - found = true; - - break; + let batch = build_matched_indices( + (buffer_idx, count), + (row_idx, count), + buffered_side, + stream_batch, + join_type, + Arc::clone(&join_schema), + )?; + + batch_process_state.output_batches.push_batch(batch)?; + + // Flush batch and update pointers if we have a completed batch + if let Some(batch) = + batch_process_state.output_batches.next_completed_batch() + { + batch_process_state.found = false; + batch_process_state.start_buffer_idx = buffer_idx; + batch_process_state.start_stream_idx = row_idx + 1; + return Ok(batch); } + + break; } - Operator::GtEq | Operator::LtEq => { - if matches!(compare, Ordering::Equal | Ordering::Less) { - let count = buffered_values.len() - buffer_idx; - - // If the current output + new output is over our process value then we want to be - // able to change that - if batch_process_state.num_rows + count >= batch_size { - // Update the start index so it repeats the process - batch_process_state.set_stream_idx(row_idx); - batch_process_state.set_buffer_idx(buffer_idx); - - let process_batch_size = - batch_size - batch_process_state.num_rows; - let stream_repeated = - vec![row_idx as u32; process_batch_size]; - batch_process_state.set_process_rest(Some( - buffer_idx + process_batch_size, - )); - batch_process_state.set_rows( - batch_process_state.num_rows + process_batch_size, - ); - let buffered_range: Vec = (buffer_idx as u64 - ..(buffer_idx + process_batch_size) as u64) - .collect(); - stream_indices.append_slice(&stream_repeated); - buffered_indices.append_slice(&buffered_range); - - let batch = process_batch( - 
&mut buffered_indices, - &mut stream_indices, - stream_batch, - buffered_side, - join_type, - join_schema, - )?; - - batch_process_state.continue_process = true; - batch_process_state.set_rows(0); - - return Ok(batch); - } - - // Update the number of rows processed - batch_process_state - .set_rows(batch_process_state.num_rows + count); - let stream_repeated = vec![row_idx as u32; count]; - let buffered_range: Vec = - (buffer_idx as u64..buffered_len as u64).collect(); - - stream_indices.append_slice(&stream_repeated); - buffered_indices.append_slice(&buffered_range); - found = true; - - break; + } + Operator::GtEq | Operator::LtEq => { + if matches!(compare, Ordering::Equal | Ordering::Less) { + batch_process_state.found = true; + let count = buffered_len - buffer_idx; + let batch = build_matched_indices( + (buffer_idx, count), + (row_idx, count), + buffered_side, + stream_batch, + join_type, + Arc::clone(&join_schema), + )?; + + // Flush batch and update pointers if we have a completed batch + batch_process_state.output_batches.push_batch(batch)?; + if let Some(batch) = + batch_process_state.output_batches.next_completed_batch() + { + batch_process_state.found = false; + batch_process_state.start_buffer_idx = buffer_idx; + batch_process_state.start_stream_idx = row_idx + 1; + return Ok(batch); } + + break; } - _ => { - return exec_err!( - "PiecewiseMergeJoin should not contain operator, {}", - operator - ) - } - }; + } + _ => { + return exec_err!( + "PiecewiseMergeJoin should not contain operator, {}", + operator + ) + } + }; - // Increment buffer_idx after every row - buffer_idx += 1; - } + // Increment buffer_idx after every row + buffer_idx += 1; } - // If not found we append a null value for `JoinType::Right` and `JoinType::Full` - if (!found || batch_process_state.not_found) - && matches!(join_type, JoinType::Right | JoinType::Full) + // If a match was not found for the current stream row index the stream indice is appended + // to the unmatched indices to be flushed later. 
+ if matches!(join_type, JoinType::Right | JoinType::Full) + && !batch_process_state.found { - let remaining = batch_size.saturating_sub(batch_process_state.num_rows); - if remaining == 0 { - let batch = process_batch( - &mut buffered_indices, - &mut stream_indices, - stream_batch, - buffered_side, - join_type, - join_schema, - )?; - - // Update the start index so it repeats the process - batch_process_state.set_stream_idx(row_idx); - batch_process_state.set_buffer_idx(buffer_idx); - batch_process_state.not_found = true; - batch_process_state.continue_process = true; - batch_process_state.set_rows(0); - - return Ok(batch); - } - - // Append right side value + null value for left - stream_indices.append_value(row_idx as u32); - buffered_indices.append_null(); - batch_process_state.set_rows(batch_process_state.num_rows + 1); - batch_process_state.not_found = false; + batch_process_state + .unmatched_indices + .append_value(row_idx as u32); } + + batch_process_state.found = false; } - let batch = process_batch( - &mut buffered_indices, - &mut stream_indices, - stream_batch, - buffered_side, - join_type, - join_schema, - )?; + // Flushed all unmatched indices on the streamed side + if matches!(join_type, JoinType::Right | JoinType::Full) { + let batch = create_unmatched_batch( + &mut batch_process_state.unmatched_indices, + stream_batch, + Arc::clone(&join_schema), + )?; - // Resets batch process state for processing `Left` + `Full` join - batch_process_state.reset(); + batch_process_state.output_batches.push_batch(batch)?; + } - Ok(batch) + batch_process_state.continue_process = false; + Ok(RecordBatch::new_empty(Arc::clone(&join_schema))) } -fn process_batch( - buffered_indices: &mut PrimitiveBuilder, - stream_indices: &mut PrimitiveBuilder, - stream_batch: &StreamedBatch, +// Builds a record batch from indices ranges on the buffered and streamed side. +// +// The two ranges are: buffered_range: (start index, count) and streamed_range: (start index, count) due +// to batch.slice(start, count). 
+fn build_matched_indices( + buffered_range: (usize, usize), + streamed_range: (usize, usize), buffered_side: &mut BufferedSideReadyState, + stream_batch: &StreamedBatch, join_type: JoinType, join_schema: Arc, ) -> Result { - let stream_indices_array = stream_indices.finish(); - let buffered_indices_array = buffered_indices.finish(); - - // We need to mark the buffered side matched indices for `JoinType::Full` and `JoinType::Left` + // Mark the buffered indices as visited if need_produce_result_in_final(join_type) { let mut bitmap = buffered_side.buffered_data.visited_indices_bitmap.lock(); - - buffered_indices_array.iter().flatten().for_each(|i| { - bitmap.set_bit(i as usize, true); - }); + for i in buffered_range.0..buffered_range.0 + buffered_range.1 { + bitmap.set_bit(i, true); + } } - let batch = build_matched_indices( - join_schema, - &stream_batch.batch, - &buffered_side.buffered_data.batch, - stream_indices_array, - buffered_indices_array, - )?; + let new_buffered_batch = buffered_side + .buffered_data + .batch() + .slice(buffered_range.0, buffered_range.1); + let mut buffered_columns = new_buffered_batch.columns().to_vec(); - Ok(batch) -} + let indices = UInt32Array::from_value(streamed_range.0 as u32, streamed_range.1); + let new_stream_batch = take_record_batch(&stream_batch.batch, &indices)?; + let streamed_columns = new_stream_batch.columns().to_vec(); -fn build_matched_indices( - schema: Arc, - streamed_batch: &RecordBatch, - buffered_batch: &RecordBatch, - streamed_indices: UInt32Array, - buffered_indices: UInt64Array, -) -> Result { - if schema.fields().is_empty() { - // Build an “empty” RecordBatch with just row‐count metadata - let options = RecordBatchOptions::new() - .with_match_field_names(true) - .with_row_count(Some(streamed_indices.len())); - return Ok(RecordBatch::try_new_with_options( - Arc::new((*schema).clone()), - vec![], - &options, - )?); - } + buffered_columns.extend(streamed_columns); - // Gather stream columns after applying filter specified with stream indices - let streamed_columns = streamed_batch - .columns() - .iter() - .map(|column_array| { - if column_array.is_empty() - || streamed_indices.null_count() == streamed_indices.len() - { - assert_eq!(streamed_indices.null_count(), streamed_indices.len()); - Ok(new_null_array( - column_array.data_type(), - streamed_indices.len(), - )) - } else { - take(column_array, &streamed_indices, None) - } - }) - .collect::, ArrowError>>()?; + Ok(RecordBatch::try_new( + Arc::clone(&join_schema), + buffered_columns, + )?) +} - let mut buffered_columns = buffered_batch - .columns() +// Creates a record batch from the unmatched indices on the streamed side +fn create_unmatched_batch( + streamed_indices: &mut PrimitiveBuilder, + stream_batch: &StreamedBatch, + join_schema: Arc, +) -> Result { + let streamed_indices = streamed_indices.finish(); + let new_stream_batch = take_record_batch(&stream_batch.batch, &streamed_indices)?; + let streamed_columns = new_stream_batch.columns().to_vec(); + let buffered_cols_len = join_schema.fields().len() - streamed_columns.len(); + + let num_rows = new_stream_batch.num_rows(); + let mut buffered_columns: Vec = join_schema + .fields() .iter() - .map(|column_array| take(column_array, &buffered_indices, None)) - .collect::, ArrowError>>()?; + .take(buffered_cols_len) + .map(|field| new_null_array(field.data_type(), num_rows)) + .collect(); buffered_columns.extend(streamed_columns); Ok(RecordBatch::try_new( - Arc::new((*schema).clone()), + Arc::clone(&join_schema), buffered_columns, )?) 
} @@ -928,7 +719,14 @@ mod tests { operator: Operator, join_type: JoinType, ) -> Result { - PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type) + PiecewiseMergeJoinExec::try_new( + left, + right, + on, + operator, + join_type, + Arc::new(AtomicUsize::new(1)), + ) } async fn join_collect( @@ -1349,6 +1147,253 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_inner_less_than_equal_with_dups() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 4 | 7 | + // | 2 | 4 | 8 | + // | 3 | 2 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 4, 2]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 4 | 70 | + // | 20 | 3 | 80 | + // | 30 | 2 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 3, 2]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?; + + // Expected grouping follows right.b1 descending (4, 3, 2) + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 4 | 8 | 10 | 4 | 70 | + | 3 | 2 | 9 | 10 | 4 | 70 | + | 3 | 2 | 9 | 20 | 3 | 80 | + | 3 | 2 | 9 | 30 | 2 | 90 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_greater_than_unsorted_right() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 2 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 2, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 1 | 80 | + // | 30 | 2 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 1, 2]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Inner).await?; + + // Grouped by right in ascending evaluation for > (1,2,3) + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 2 | 8 | 20 | 1 | 80 | + | 3 | 4 | 9 | 20 | 1 | 80 | + | 3 | 4 | 9 | 30 | 2 | 90 | + | 3 | 4 | 9 | 10 | 3 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_left_less_than_equal_with_left_nulls_on_no_match() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 5 | 7 | + // | 2 | 4 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![5, 4, 1]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // +----+----+----+ + let right = build_table(("a2", &vec![10]), ("b1", &vec![3]), ("c2", &vec![70])); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) 
as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::Left).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 3 | 1 | 9 | 10 | 3 | 70 | + | 1 | 5 | 7 | | | | + | 2 | 4 | 8 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_greater_than_equal_with_right_nulls_on_no_match() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 2 | 8 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1, 2]), + ("c1", &vec![7, 8]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 5 | 80 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20]), + ("b1", &vec![3, 5]), + ("c2", &vec![70, 80]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | | | | 10 | 3 | 70 | + | | | | 20 | 5 | 80 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_single_row_left_less_than() -> Result<()> { + let left = build_table(("a1", &vec![42]), ("b1", &vec![5]), ("c1", &vec![999])); + + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1, 5, 7]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+-----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+-----+----+----+----+ + | 42 | 5 | 999 | 30 | 7 | 90 | + +----+----+-----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_empty_right() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 2, 3]), + ("c1", &vec![7, 8, 9]), + ); + + let right = build_table( + ("a2", &Vec::::new()), + ("b1", &Vec::::new()), + ("c2", &Vec::::new()), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + +----+----+----+----+----+----+ + "#); + Ok(()) + } + #[tokio::test] async fn join_date32_inner_less_than() -> Result<()> { // +----+-------+----+ diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index a51b35d5ab7b7..8fa58219ef9ce 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -31,12 +31,14 @@ use datafusion_execution::{ use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{ - LexOrdering, OrderingRequirements, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, + Distribution, LexOrdering, OrderingRequirements, PhysicalExpr, PhysicalExprRef, + PhysicalSortExpr, }; use datafusion_physical_expr_common::physical_expr::fmt_sql; use futures::TryStreamExt; use parking_lot::Mutex; use std::fmt::Formatter; +use std::sync::atomic::AtomicUsize; use std::sync::Arc; use crate::execution_plan::{boundedness_from_children, EmissionType}; @@ -47,7 +49,7 @@ use crate::joins::piecewise_merge_join::classic_join::{ use crate::joins::piecewise_merge_join::utils::{ build_visited_indices_map, is_existence_join, is_right_existence_join, }; -use crate::joins::utils::symmetric_join_output_partitioning; +use crate::joins::utils::asymmetric_join_output_partitioning; use crate::{ joins::{ utils::{build_join_schema, BuildProbeJoinMetrics, OnceAsync, OnceFut}, @@ -266,12 +268,14 @@ pub struct PiecewiseMergeJoinExec { /// The right sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations /// Unsorted for mark joins #[allow(unused)] - ight_batch_required_orders: LexOrdering, + right_batch_required_orders: LexOrdering, /// This determines the sort order of all join columns used in sorting the stream and buffered execution plans. sort_options: SortOptions, /// Cache holding plan properties like equivalences, output partitioning etc. 
cache: PlanProperties, + /// Number of partitions to process + remaining_partitions: Arc, } impl PiecewiseMergeJoinExec { @@ -281,6 +285,7 @@ impl PiecewiseMergeJoinExec { on: (Arc, Arc), operator: Operator, join_type: JoinType, + num_partitions: Arc, ) -> Result { // TODO: Implement existence joins for PiecewiseMergeJoin if is_existence_join(join_type) { @@ -318,7 +323,7 @@ impl PiecewiseMergeJoinExec { // Give the same `sort_option for comparison later` let left_child_plan_required_order = vec![PhysicalSortExpr::new(Arc::clone(&on.0), sort_options)]; - let ight_batch_required_orders = + let right_batch_required_orders = vec![PhysicalSortExpr::new(Arc::clone(&on.1), sort_options)]; let Some(left_child_plan_required_order) = @@ -328,8 +333,8 @@ impl PiecewiseMergeJoinExec { "PiecewiseMergeJoinExec requires valid sort expressions for its left side" ); }; - let Some(ight_batch_required_orders) = - LexOrdering::new(ight_batch_required_orders) + let Some(right_batch_required_orders) = + LexOrdering::new(right_batch_required_orders) else { return internal_err!( "PiecewiseMergeJoinExec requires valid sort expressions for its right side" @@ -360,9 +365,10 @@ impl PiecewiseMergeJoinExec { buffered_fut: Default::default(), metrics: ExecutionPlanMetricsSet::new(), left_child_plan_required_order, - ight_batch_required_orders, + right_batch_required_orders, sort_options, cache, + remaining_partitions: num_partitions, }) } @@ -421,7 +427,7 @@ impl PiecewiseMergeJoinExec { )?; let output_partitioning = - symmetric_join_output_partitioning(buffered, streamed, &join_type)?; + asymmetric_join_output_partitioning(buffered, streamed, &join_type)?; Ok(PlanProperties::new( eq_properties, @@ -471,6 +477,13 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { vec![&self.buffered, &self.streamed] } + fn required_input_distribution(&self) -> Vec { + vec![ + Distribution::SinglePartition, + Distribution::UnspecifiedDistribution, + ] + } + fn required_input_ordering(&self) -> Vec> { // Existence joins don't need to be sorted on one side. if is_right_existence_join(self.join_type) { @@ -497,6 +510,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { self.on.clone(), self.operator, self.join_type, + Arc::clone(&self.remaining_partitions), )?)), _ => internal_err!( "PiecewiseMergeJoin should have 2 children, found {}", @@ -513,12 +527,12 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { let on_buffered = Arc::clone(&self.on.0); let on_streamed = Arc::clone(&self.on.1); - let metrics = BuildProbeJoinMetrics::new(0, &self.metrics); + let metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); let buffered_fut = self.buffered_fut.try_once(|| { let reservation = MemoryConsumer::new("PiecewiseMergeJoinInput") .register(context.memory_pool()); - let buffered_stream = - self.buffered.execute(partition, Arc::clone(&context))?; + + let buffered_stream = self.buffered.execute(0, Arc::clone(&context))?; Ok(build_buffered_data( buffered_stream, Arc::clone(&on_buffered), @@ -547,6 +561,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { self.sort_options, metrics, batch_size, + Arc::clone(&self.remaining_partitions), ))) } } @@ -607,8 +622,7 @@ async fn build_buffered_data( }) .await?; - let batches_iter = batches.iter().rev(); - let single_batch = concat_batches(&schema, batches_iter)?; + let single_batch = concat_batches(&schema, batches.iter())?; // Evaluate physical expression on the buffered side. 
let buffered_values = on_buffered diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 670992633bb85..99b454ee80949 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -285,6 +285,7 @@ datafusion.format.time_format %H:%M:%S%.f datafusion.format.timestamp_format %Y-%m-%dT%H:%M:%S%.f datafusion.format.timestamp_tz_format NULL datafusion.format.types_info false +datafusion.optimizer.allow_piecewise_merge_join false datafusion.optimizer.allow_symmetric_joins_without_pruning true datafusion.optimizer.default_filter_selectivity 20 datafusion.optimizer.enable_distinct_aggregation_soft_limit true @@ -401,6 +402,7 @@ datafusion.format.time_format %H:%M:%S%.f Time format for time arrays datafusion.format.timestamp_format %Y-%m-%dT%H:%M:%S%.f Timestamp format for timestamp arrays datafusion.format.timestamp_tz_format NULL Timestamp format for timestamp with timezone arrays. When `None`, ISO 8601 format is used. datafusion.format.types_info false Show types in visual representation batches +datafusion.optimizer.allow_piecewise_merge_join false When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. 
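
The pointer-advancing scan that the `resolve_classic_join` rewrite above relies on can be summarized with a small standalone sketch. This is not part of the patch: it assumes the `>` operator with both inputs already sorted ascending, and plain `i32` slices stand in for the Arrow arrays and index builders the operator actually uses.

```rust
// A simplified, self-contained sketch of the piecewise merge scan (not the operator itself).
// Assumptions: `buffered > streamed` predicate, both inputs sorted ascending, i32 slices
// standing in for the Arrow arrays.
fn piecewise_merge_gt(buffered: &[i32], streamed: &[i32]) -> Vec<(usize, usize)> {
    let mut matches = Vec::new();
    let mut buf_idx = 0;
    for (stream_idx, stream_val) in streamed.iter().enumerate() {
        // Skip buffered rows that cannot satisfy `buffered > stream_val`; because both sides
        // are sorted, rows skipped here can never match a later stream row either.
        while buf_idx < buffered.len() && buffered[buf_idx] <= *stream_val {
            buf_idx += 1;
        }
        // Everything from `buf_idx` to the end matches the current stream row.
        for b_idx in buf_idx..buffered.len() {
            matches.push((stream_idx, b_idx));
        }
    }
    matches
}

fn main() {
    let buffered = [1, 3, 5, 7];
    let streamed = [2, 4];
    // stream row 0 (value 2) pairs with buffered rows 1..4 (values 3, 5, 7);
    // stream row 1 (value 4) pairs with buffered rows 2..4 (values 5, 7).
    assert_eq!(
        piecewise_merge_gt(&buffered, &streamed),
        vec![(0, 1), (0, 2), (0, 3), (1, 2), (1, 3)]
    );
}
```

Because the stream values are also sorted, `buf_idx` never has to move backwards, so the comparison cost stays linear in the two input sizes even though the emitted output can be much larger.
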
diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 6b92c71261f92..7f3f9ab5fcac4 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4148,11 +4148,10 @@ logical_plan 03)----TableScan: left_table projection=[a, b, c] 04)----TableScan: right_table projection=[x, y, z] physical_plan -01)SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] -02)--PiecewiseMergeJoin: operator=Lt, join_type=Inner, on=(a < x) -03)----SortExec: expr=[a@0 DESC NULLS LAST], preserve_partitioning=[false] -04)------DataSourceExec: partitions=1, partition_sizes=[0] -05)----DataSourceExec: partitions=1, partition_sizes=[0] +01)SortExec: expr=[x@3 ASC NULLS LAST], preserve_partitioning=[false] +02)--NestedLoopJoinExec: join_type=Inner, filter=a@0 < x@1 +03)----DataSourceExec: partitions=1, partition_sizes=[0] +04)----DataSourceExec: partitions=1, partition_sizes=[0] query TT EXPLAIN SELECT * FROM left_table JOIN right_table ON left_table.a= t1.c2 LIMIT 20; @@ -4212,12 +4203,11 @@ EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 20; 01)GlobalLimitExec: skip=0, fetch=20 01)Limit: skip=0, fetch=20 02)--Full Join: Filter: t0.c2 >= t1.c2 -02)--PiecewiseMergeJoin: operator=GtEq, join_type=Full, on=(c2 >= c2) -03)----SortExec: expr=[c2@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--NestedLoopJoinExec: join_type=Full, filter=c2@0 >= c2@1 +03)----DataSourceExec: partitions=1, partition_sizes=[2] 03)----TableScan: t0 projection=[c1, c2] -04)------DataSourceExec: partitions=1, partition_sizes=[2] +04)----DataSourceExec: partitions=1, partition_sizes=[2] 04)----TableScan: t1 projection=[c1, c2, c3] -05)----DataSourceExec: partitions=1, partition_sizes=[2] logical_plan physical_plan @@ -4237,9 +4227,6 @@ SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 20; 4 4 3 3 false 4 4 3 3 true -statement ok -set datafusion.execution.batch_size = 3; - query IIIIB rowsort -- Note: using LIMIT value higher than cardinality before LIMIT to avoid query non-determinism SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 AND t0.c2 >= t1.c2 LIMIT 20; @@ -4277,10 +4264,9 @@ logical_plan 04)----TableScan: t1 projection=[c1, c2, c3] physical_plan 01)GlobalLimitExec: skip=0, fetch=2 -02)--PiecewiseMergeJoin: operator=GtEq, join_type=Full, on=(c2 >= c2) -03)----SortExec: expr=[c2@1 ASC NULLS LAST], preserve_partitioning=[false] -04)------DataSourceExec: partitions=1, partition_sizes=[2] -05)----DataSourceExec: partitions=1, partition_sizes=[2] +02)--NestedLoopJoinExec: join_type=Full, filter=c2@0 >= c2@1 +03)----DataSourceExec: partitions=1, partition_sizes=[2] +04)----DataSourceExec: partitions=1, partition_sizes=[2] ## Test !join.on.is_empty() && join.filter.is_some() query TT @@ -5222,11 +5208,7 @@ SELECT c # PiecewiseMergeJoin Test statement ok -set datafusion.execution.batch_size = 8192; - -# TODO: partitioned PWMJ execution -statement ok -set datafusion.execution.target_partitions = 1; +set datafusion.optimizer.allow_piecewise_merge_join = true; query II SELECT join_t1.t1_id, join_t2.t2_id @@ -5251,10 +5233,10 @@ physical_plan 01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] 02)--PiecewiseMergeJoin: operator=Gt, join_type=Inner, on=(t1_id > t2_id) 03)----SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] -04)------CoalesceBatchesExec: target_batch_size=8192 +04)------CoalesceBatchesExec: target_batch_size=3 05)--------FilterExec: t1_id@0 > 10 
06)----------DataSourceExec: partitions=1, partition_sizes=[1] -07)----CoalesceBatchesExec: target_batch_size=8192 +07)----CoalesceBatchesExec: target_batch_size=3 08)------FilterExec: t2_int@1 > 1, projection=[t2_id@0] 09)--------DataSourceExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/pwmj.slt b/datafusion/sqllogictest/test_files/pwmj.slt index ee9622e6bb1b9..aff19d2ecb50b 100644 --- a/datafusion/sqllogictest/test_files/pwmj.slt +++ b/datafusion/sqllogictest/test_files/pwmj.slt @@ -16,8 +16,8 @@ # under the License. -statement ok -set datafusion.execution.target_partitions = 1; +statement ok +set datafusion.optimizer.allow_piecewise_merge_join = true; statement ok CREATE TABLE join_t1 (t1_id INT); @@ -69,15 +69,17 @@ logical_plan 08)--------Filter: join_t2.t2_int > Int32(1) 09)----------TableScan: join_t2 projection=[t2_id, t2_int] physical_plan -01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] -02)--PiecewiseMergeJoin: operator=Gt, join_type=Inner, on=(t1_id > t2_id) -03)----SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] -04)------CoalesceBatchesExec: target_batch_size=8192 -05)--------FilterExec: t1_id@0 > 10 -06)----------DataSourceExec: partitions=1, partition_sizes=[1] -07)----CoalesceBatchesExec: target_batch_size=8192 -08)------FilterExec: t2_int@1 > 1, projection=[t2_id@0] -09)--------DataSourceExec: partitions=1, partition_sizes=[1] +01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----PiecewiseMergeJoin: operator=Gt, join_type=Inner, on=(t1_id > t2_id) +04)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------FilterExec: t1_id@0 > 10 +07)------------DataSourceExec: partitions=1, partition_sizes=[1] +08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)--------CoalesceBatchesExec: target_batch_size=8192 +10)----------FilterExec: t2_int@1 > 1, projection=[t2_id@0] +11)------------DataSourceExec: partitions=1, partition_sizes=[1] query II SELECT t1.t1_id, t2.t2_id @@ -114,15 +116,17 @@ logical_plan 08)--------Filter: join_t2.t2_int = Int32(3) 09)----------TableScan: join_t2 projection=[t2_id, t2_int] physical_plan -01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] -02)--PiecewiseMergeJoin: operator=GtEq, join_type=Inner, on=(t1_id >= t2_id) -03)----SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] -04)------CoalesceBatchesExec: target_batch_size=8192 -05)--------FilterExec: t1_id@0 >= 22 -06)----------DataSourceExec: partitions=1, partition_sizes=[1] -07)----CoalesceBatchesExec: target_batch_size=8192 -08)------FilterExec: t2_int@1 = 3, projection=[t2_id@0] -09)--------DataSourceExec: partitions=1, partition_sizes=[1] +01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----PiecewiseMergeJoin: operator=GtEq, join_type=Inner, on=(t1_id >= t2_id) +04)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------FilterExec: t1_id@0 >= 22 +07)------------DataSourceExec: partitions=1, partition_sizes=[1] +08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)--------CoalesceBatchesExec: target_batch_size=8192 +10)----------FilterExec: 
t2_int@1 = 3, projection=[t2_id@0] +11)------------DataSourceExec: partitions=1, partition_sizes=[1] query II SELECT t1.t1_id, t2.t2_id @@ -159,13 +163,15 @@ logical_plan 07)--------Filter: join_t2.t2_int >= Int32(3) 08)----------TableScan: join_t2 projection=[t2_id, t2_int] physical_plan -01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] -02)--PiecewiseMergeJoin: operator=Lt, join_type=Inner, on=(t1_id < t2_id) -03)----SortExec: expr=[t1_id@0 DESC NULLS LAST], preserve_partitioning=[false] -04)------DataSourceExec: partitions=1, partition_sizes=[1] -05)----CoalesceBatchesExec: target_batch_size=8192 -06)------FilterExec: t2_int@1 >= 3, projection=[t2_id@0] -07)--------DataSourceExec: partitions=1, partition_sizes=[1] +01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----PiecewiseMergeJoin: operator=Lt, join_type=Inner, on=(t1_id < t2_id) +04)------SortExec: expr=[t1_id@0 DESC NULLS LAST], preserve_partitioning=[false] +05)--------DataSourceExec: partitions=1, partition_sizes=[1] +06)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)--------CoalesceBatchesExec: target_batch_size=8192 +08)----------FilterExec: t2_int@1 >= 3, projection=[t2_id@0] +09)------------DataSourceExec: partitions=1, partition_sizes=[1] query II @@ -204,12 +210,17 @@ logical_plan 08)--------Filter: join_t2.t2_name != Utf8View("y") 09)----------TableScan: join_t2 projection=[t2_id, t2_name] physical_plan -01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] -02)--PiecewiseMergeJoin: operator=LtEq, join_type=Inner, on=(t1_id <= t2_id) -03)----SortExec: expr=[t1_id@0 DESC NULLS LAST], preserve_partitioning=[false] -04)------CoalesceBatchesExec: target_batch_size=8192 -05)--------FilterExec: t1_id@0 = 11 OR t1_id@0 = 44 -06)----------DataSourceExec: partitions=1, partition_sizes=[1] -07)----CoalesceBatchesExec: target_batch_size=8192 -08)------FilterExec: t2_name@1 != y, projection=[t2_id@0] -09)--------DataSourceExec: partitions=1, partition_sizes=[1] +01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----PiecewiseMergeJoin: operator=LtEq, join_type=Inner, on=(t1_id <= t2_id) +04)------SortExec: expr=[t1_id@0 DESC NULLS LAST], preserve_partitioning=[false] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------FilterExec: t1_id@0 = 11 OR t1_id@0 = 44 +07)------------DataSourceExec: partitions=1, partition_sizes=[1] +08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)--------CoalesceBatchesExec: target_batch_size=8192 +10)----------FilterExec: t2_name@1 != y, projection=[t2_id@0] +11)------------DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +set datafusion.optimizer.allow_piecewise_merge_join = false; From 10526fe9c00cee15182a0d6052bf4156f4f6f867 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 8 Oct 2025 22:44:25 -0400 Subject: [PATCH 13/20] update configs --- docs/source/user-guide/configs.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 4d0b897648f27..02f109cec4d4c 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -146,6 +146,7 @@ The following configuration settings are available: | datafusion.optimizer.max_passes | 3 | Number of times that the 
optimizer will attempt to optimize the plan | | datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | | datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | +| datafusion.optimizer.allow_piecewise_merge_join | false | When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. | | datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | | datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition | | datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | From 14d945c36cdcb0c9ea2ed33c5a405f26e4718ec6 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Mon, 13 Oct 2025 21:39:22 -0400 Subject: [PATCH 14/20] Update datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs Co-authored-by: Yongting You <2010youy01@gmail.com> --- .../physical-plan/src/joins/piecewise_merge_join/exec.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index 8fa58219ef9ce..2564a992885c7 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -437,7 +437,9 @@ impl PiecewiseMergeJoinExec { )) } - // TODO: Add input order + // TODO: Add input order. Now they're all `false` indicating it will not maintain the input order. + // However, for certain join types the order is maintained. This can be updated in the future after + // more testing. fn maintains_input_order(join_type: JoinType) -> Vec { match join_type { // The existence side is expected to come in sorted From 2bf7940dd50d67637470bd5fd42f31bf0921cfcb Mon Sep 17 00:00:00 2001 From: Jonathan Date: Tue, 14 Oct 2025 15:58:38 -0400 Subject: [PATCH 15/20] fix proposed changes --- datafusion/common/src/config.rs | 5 +- datafusion/core/src/physical_planner.rs | 9 +- .../piecewise_merge_join/classic_join.rs | 85 ++++++++++--------- .../src/joins/piecewise_merge_join/exec.rs | 35 +++++--- .../test_files/information_schema.slt | 4 +- datafusion/sqllogictest/test_files/joins.slt | 20 +---- datafusion/sqllogictest/test_files/pwmj.slt | 47 +++++++++- docs/source/user-guide/configs.md | 2 +- 8 files changed, 125 insertions(+), 82 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 375ffd0d32c0f..9468c08fb4200 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -841,8 +841,9 @@ config_namespace! { pub prefer_hash_join: bool, default = true /// When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently - /// experimental. - pub allow_piecewise_merge_join: bool, default = false + /// experimental. Physical planner will opt for PiecewiseMergeJoin when there is only + /// one range filter. 
+ pub enable_piecewise_merge_join: bool, default = false /// The maximum estimated size in bytes for one input side of a HashJoin /// will be collected into a single partition diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index c673660ae1629..a79dec54b16e0 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -19,7 +19,6 @@ use std::borrow::Cow; use std::collections::HashMap; -use std::sync::atomic::AtomicUsize; use std::sync::Arc; use crate::datasource::file_format::file_type_to_format; @@ -1267,7 +1266,7 @@ impl DefaultPhysicalPlanner { && session_state .config_options() .optimizer - .allow_piecewise_merge_join + .enable_piecewise_merge_join { let Expr::BinaryExpr(be) = &range_filters[0] else { return plan_err!( @@ -1337,17 +1336,13 @@ impl DefaultPhysicalPlanner { session_state.execution_props(), )?; - let num_partitions = Arc::new(AtomicUsize::new( - session_state.config().target_partitions(), - )); - Arc::new(PiecewiseMergeJoinExec::try_new( physical_left, physical_right, (on_left, on_right), op, *join_type, - num_partitions, + session_state.config().target_partitions(), )?) } else { // there is no equal join condition, use the nested loop join diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index 80eb4d5ecc61e..23ccded9aa81b 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -26,12 +26,11 @@ use arrow::{ }; use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::NullEquality; -use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_common::{internal_err, Result}; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::PhysicalExprRef; use futures::{Stream, StreamExt}; -use std::sync::atomic::AtomicUsize; use std::{cmp::Ordering, task::ready}; use std::{sync::Arc, task::Poll}; @@ -44,14 +43,14 @@ use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult}; pub(super) enum PiecewiseMergeJoinStreamState { WaitBufferedSide, FetchStreamBatch, - ProcessStreamBatch(StreamedBatch), - ExhaustedStreamSide, + ProcessStreamBatch(SortedStreamBatch), + ProcessUnmatched, Completed, } impl PiecewiseMergeJoinStreamState { // Grab mutable reference to the current stream batch - fn try_as_process_stream_batch_mut(&mut self) -> Result<&mut StreamedBatch> { + fn try_as_process_stream_batch_mut(&mut self) -> Result<&mut SortedStreamBatch> { match self { PiecewiseMergeJoinStreamState::ProcessStreamBatch(state) => Ok(state), _ => internal_err!("Expected streamed batch in StreamBatch"), @@ -59,19 +58,28 @@ impl PiecewiseMergeJoinStreamState { } } -pub(super) struct StreamedBatch { +/// The stream side incoming batch with required sort order. +/// +/// Note the compare key in the join predicate might include expressions on the original +/// columns, so we store the evaluated compare key separately. +/// e.g. For join predicate `buffer.v1 < (stream.v1 + 1)`, the `compare_key_values` field stores +/// the evaluted `stream.v1 + 1` array. 
+pub(super) struct SortedStreamBatch { pub batch: RecordBatch, - values: Vec, + compare_key_values: Vec, } -impl StreamedBatch { +impl SortedStreamBatch { #[allow(dead_code)] - fn new(batch: RecordBatch, values: Vec) -> Self { - Self { batch, values } + fn new(batch: RecordBatch, compare_key_values: Vec) -> Self { + Self { + batch, + compare_key_values, + } } - fn values(&self) -> &Vec { - &self.values + fn compare_key_values(&self) -> &Vec { + &self.compare_key_values } } @@ -96,15 +104,13 @@ pub(super) struct ClassicPWMJStream { buffered_side: BufferedSide, // Tracks the state of the `PiecewiseMergeJoin` state: PiecewiseMergeJoinStreamState, - // Sort option for buffered and streamed side (specifies whether + // Sort option for streamed side (specifies whether // the sort is ascending or descending) sort_option: SortOptions, // Metrics for build + probe joins join_metrics: BuildProbeJoinMetrics, // Tracking incremental state for emitting record batches batch_process_state: BatchProcessState, - // To synchronize when partition needs to finish - remaining_partitions: Arc, } impl RecordBatchStream for ClassicPWMJStream { @@ -114,7 +120,7 @@ impl RecordBatchStream for ClassicPWMJStream { } // `PiecewiseMergeJoinStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`, -// `ProcessStreamBatch`, `ExhaustedStreamSide` and `Completed`. +// `ProcessStreamBatch`, `ProcessUnmatched` and `Completed`. // // Classic Joins // 1. `WaitBufferedSide` - Load in the buffered side data into memory. @@ -138,7 +144,6 @@ impl ClassicPWMJStream { sort_option: SortOptions, join_metrics: BuildProbeJoinMetrics, batch_size: usize, - remaining_partitions: Arc, ) -> Self { Self { schema: Arc::clone(&schema), @@ -152,7 +157,6 @@ impl ClassicPWMJStream { sort_option, join_metrics, batch_process_state: BatchProcessState::new(schema, batch_size), - remaining_partitions, } } @@ -171,7 +175,7 @@ impl ClassicPWMJStream { PiecewiseMergeJoinStreamState::ProcessStreamBatch(_) => { handle_state!(self.process_stream_batch()) } - PiecewiseMergeJoinStreamState::ExhaustedStreamSide => { + PiecewiseMergeJoinStreamState::ProcessUnmatched => { handle_state!(self.process_unmatched_buffered_batch()) } PiecewiseMergeJoinStreamState::Completed => Poll::Ready(None), @@ -209,12 +213,15 @@ impl ClassicPWMJStream { match ready!(self.streamed.poll_next_unpin(cx)) { None => { if self + .buffered_side + .try_as_ready_mut()? + .buffered_data .remaining_partitions .fetch_sub(1, std::sync::atomic::Ordering::SeqCst) == 1 { self.batch_process_state.reset(); - self.state = PiecewiseMergeJoinStreamState::ExhaustedStreamSide; + self.state = PiecewiseMergeJoinStreamState::ProcessUnmatched; } else { self.state = PiecewiseMergeJoinStreamState::Completed; } @@ -240,11 +247,12 @@ impl ClassicPWMJStream { // Reset BatchProcessState before processing a new stream batch self.batch_process_state.reset(); - self.state = - PiecewiseMergeJoinStreamState::ProcessStreamBatch(StreamedBatch { + self.state = PiecewiseMergeJoinStreamState::ProcessStreamBatch( + SortedStreamBatch { batch: stream_batch, - values: vec![stream_values], - }); + compare_key_values: vec![stream_values], + }, + ); } Some(Err(err)) => return Poll::Ready(Err(err)), }; @@ -292,9 +300,13 @@ impl ClassicPWMJStream { self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; return Ok(StatefulStreamResult::Ready(Some(b))); } + // Nothing pending; hand back whatever `resolve` returned (often empty) and move on. 
+ // if self.batch_process_state.output_batches.is_empty() { self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + return Ok(StatefulStreamResult::Ready(Some(batch))); + // } } Ok(StatefulStreamResult::Ready(Some(batch))) @@ -438,7 +450,7 @@ impl Stream for ClassicPWMJStream { #[allow(clippy::too_many_arguments)] fn resolve_classic_join( buffered_side: &mut BufferedSideReadyState, - stream_batch: &StreamedBatch, + stream_batch: &SortedStreamBatch, join_schema: Arc, operator: Operator, sort_options: SortOptions, @@ -446,7 +458,7 @@ fn resolve_classic_join( batch_process_state: &mut BatchProcessState, ) -> Result { let buffered_len = buffered_side.buffered_data.values().len(); - let stream_values = stream_batch.values(); + let stream_values = stream_batch.compare_key_values(); let mut buffer_idx = batch_process_state.start_buffer_idx; let stream_idx = batch_process_state.start_stream_idx; @@ -474,7 +486,7 @@ fn resolve_classic_join( batch_process_state.found = true; let count = buffered_len - buffer_idx; - let batch = build_matched_indices( + let batch = build_matched_indices_and_set_buffered_bitmap( (buffer_idx, count), (row_idx, count), buffered_side, @@ -502,7 +514,7 @@ fn resolve_classic_join( if matches!(compare, Ordering::Equal | Ordering::Less) { batch_process_state.found = true; let count = buffered_len - buffer_idx; - let batch = build_matched_indices( + let batch = build_matched_indices_and_set_buffered_bitmap( (buffer_idx, count), (row_idx, count), buffered_side, @@ -526,7 +538,7 @@ fn resolve_classic_join( } } _ => { - return exec_err!( + return internal_err!( "PiecewiseMergeJoin should not contain operator, {}", operator ) @@ -569,11 +581,11 @@ fn resolve_classic_join( // // The two ranges are: buffered_range: (start index, count) and streamed_range: (start index, count) due // to batch.slice(start, count). -fn build_matched_indices( +fn build_matched_indices_and_set_buffered_bitmap( buffered_range: (usize, usize), streamed_range: (usize, usize), buffered_side: &mut BufferedSideReadyState, - stream_batch: &StreamedBatch, + stream_batch: &SortedStreamBatch, join_type: JoinType, join_schema: Arc, ) -> Result { @@ -606,7 +618,7 @@ fn build_matched_indices( // Creates a record batch from the unmatched indices on the streamed side fn create_unmatched_batch( streamed_indices: &mut PrimitiveBuilder, - stream_batch: &StreamedBatch, + stream_batch: &SortedStreamBatch, join_schema: Arc, ) -> Result { let streamed_indices = streamed_indices.finish(); @@ -719,14 +731,7 @@ mod tests { operator: Operator, join_type: JoinType, ) -> Result { - PiecewiseMergeJoinExec::try_new( - left, - right, - on, - operator, - join_type, - Arc::new(AtomicUsize::new(1)), - ) + PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type, 1) } async fn join_collect( diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index 8fa58219ef9ce..9e91afd0dc877 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -81,19 +81,19 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; /// Classic joins are processed differently compared to existence joins. /// /// ## Classic Joins (Inner, Full, Left, Right) -/// For classic joins we buffer the right side (the "build" side) and stream the left side (the "probe" side). 
+/// For classic joins we buffer the build side and stream the probe side (the "probe" side). /// Both sides are sorted so that we can iterate from index 0 to the end on each side. This ordering ensures -/// that when we find the first matching pair of rows, we can emit the current left row joined with all remaining -/// right rows from the match position onward, without rescanning earlier right rows. +/// that when we find the first matching pair of rows, we can emit the current stream row joined with all remaining +/// probe rows from the match position onward, without rescanning earlier probe rows. /// /// For `<` and `<=` operators, both inputs are sorted in **descending** order, while for `>` and `>=` operators /// they are sorted in **ascending** order. This choice ensures that the pointer on the buffered side can advance -/// monotonically as we stream new batches from the left side. +/// monotonically as we stream new batches from the stream side. /// -/// The streamed (left) side may arrive unsorted, so this operator sorts each incoming batch in memory before -/// processing. The buffered (right) side is required to be globally sorted; the plan declares this requirement +/// The streamed side may arrive unsorted, so this operator sorts each incoming batch in memory before +/// processing. The buffered side is required to be globally sorted; the plan declares this requirement /// in `requires_input_order`, which allows the optimizer to automatically insert a `SortExec` on that side if needed. -/// By the time this operator runs, the right side is guaranteed to be in the proper order. +/// By the time this operator runs, the buffered side is guaranteed to be in the proper order. /// /// The pseudocode for the algorithm looks like this: /// @@ -215,6 +215,12 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; /// For both types of joins, the buffered side must be sorted ascending for `Operator::Lt` (<) or /// `Operator::LtEq` (<=) and descending for `Operator::Gt` (>) or `Operator::GtEq` (>=). /// +/// # Partitioning Logic +/// Piecewise Merge Join requires one buffered side partition + round robin partitioned stream side. A counter +/// is used in the buffered side to coordinate when all streamed partitions are finished execution. This allows +/// for processing the rest of the unmatched rows for Left and Full joins. The last partition that finishes +/// execution will be responsible for outputting the unmatched rows. +/// /// # Performance Explanation (cost) /// Piecewise Merge Join is used over Nested Loop Join due to its superior performance. Here is the breakdown: /// @@ -275,7 +281,7 @@ pub struct PiecewiseMergeJoinExec { /// Cache holding plan properties like equivalences, output partitioning etc. 
cache: PlanProperties, /// Number of partitions to process - remaining_partitions: Arc, + num_partitions: usize, } impl PiecewiseMergeJoinExec { @@ -285,7 +291,7 @@ impl PiecewiseMergeJoinExec { on: (Arc, Arc), operator: Operator, join_type: JoinType, - num_partitions: Arc, + num_partitions: usize, ) -> Result { // TODO: Implement existence joins for PiecewiseMergeJoin if is_existence_join(join_type) { @@ -368,7 +374,7 @@ impl PiecewiseMergeJoinExec { right_batch_required_orders, sort_options, cache, - remaining_partitions: num_partitions, + num_partitions, }) } @@ -510,7 +516,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { self.on.clone(), self.operator, self.join_type, - Arc::clone(&self.remaining_partitions), + self.num_partitions, )?)), _ => internal_err!( "PiecewiseMergeJoin should have 2 children, found {}", @@ -539,6 +545,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { metrics.clone(), reservation, build_visited_indices_map(self.join_type), + partition, )) })?; @@ -561,7 +568,6 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { self.sort_options, metrics, batch_size, - Arc::clone(&self.remaining_partitions), ))) } } @@ -602,6 +608,7 @@ async fn build_buffered_data( metrics: BuildProbeJoinMetrics, reservation: MemoryReservation, build_map: bool, + remaining_partitions: usize, ) -> Result { let schema = buffered.schema(); @@ -653,6 +660,7 @@ async fn build_buffered_data( single_batch, buffered_values, Mutex::new(visited_indices_bitmap), + remaining_partitions, reservation, ); @@ -663,6 +671,7 @@ pub(super) struct BufferedSideData { pub(super) batch: RecordBatch, values: ArrayRef, pub(super) visited_indices_bitmap: SharedBitmapBuilder, + pub(super) remaining_partitions: AtomicUsize, _reservation: MemoryReservation, } @@ -671,12 +680,14 @@ impl BufferedSideData { batch: RecordBatch, values: ArrayRef, visited_indices_bitmap: SharedBitmapBuilder, + remaining_partitions: usize, reservation: MemoryReservation, ) -> Self { Self { batch, values, visited_indices_bitmap, + remaining_partitions: AtomicUsize::new(remaining_partitions), _reservation: reservation, } } diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 99b454ee80949..fed8b851cdaed 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -285,7 +285,7 @@ datafusion.format.time_format %H:%M:%S%.f datafusion.format.timestamp_format %Y-%m-%dT%H:%M:%S%.f datafusion.format.timestamp_tz_format NULL datafusion.format.types_info false -datafusion.optimizer.allow_piecewise_merge_join false +datafusion.optimizer.enable_piecewise_merge_join false datafusion.optimizer.allow_symmetric_joins_without_pruning true datafusion.optimizer.default_filter_selectivity 20 datafusion.optimizer.enable_distinct_aggregation_soft_limit true @@ -402,7 +402,7 @@ datafusion.format.time_format %H:%M:%S%.f Time format for time arrays datafusion.format.timestamp_format %Y-%m-%dT%H:%M:%S%.f Timestamp format for timestamp arrays datafusion.format.timestamp_tz_format NULL Timestamp format for timestamp with timezone arrays. When `None`, ISO 8601 format is used. datafusion.format.types_info false Show types in visual representation batches -datafusion.optimizer.allow_piecewise_merge_join false When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. +datafusion.optimizer.enable_piecewise_merge_join false When set to true, piecewise merge join is enabled. 
PiecewiseMergeJoin is currently experimental. datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 7f3f9ab5fcac4..7c63ad19d5770 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4196,21 +4196,6 @@ SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 LIMIT 20; 3 3 3 3 true 4 4 NULL NULL NULL -query TT rowsort --- Note: using LIMIT value higher than cardinality before LIMIT to avoid query non-determinism -EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 20; ----- -01)GlobalLimitExec: skip=0, fetch=20 -01)Limit: skip=0, fetch=20 -02)--Full Join: Filter: t0.c2 >= t1.c2 -02)--NestedLoopJoinExec: join_type=Full, filter=c2@0 >= c2@1 -03)----DataSourceExec: partitions=1, partition_sizes=[2] -03)----TableScan: t0 projection=[c1, c2] -04)----DataSourceExec: partitions=1, partition_sizes=[2] -04)----TableScan: t1 projection=[c1, c2, c3] -logical_plan -physical_plan - query IIIIB rowsort -- Note: using LIMIT value higher than cardinality before LIMIT to avoid query non-determinism SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 20; @@ -5208,7 +5193,7 @@ SELECT c # PiecewiseMergeJoin Test statement ok -set datafusion.optimizer.allow_piecewise_merge_join = true; +set datafusion.optimizer.enable_piecewise_merge_join = true; query II SELECT join_t1.t1_id, join_t2.t2_id @@ -5248,3 +5233,6 @@ DROP TABLE t2; statement ok set datafusion.explain.physical_plan_only = false; + +statement ok +set datafusion.optimizer.enable_piecewise_merge_join = false; diff --git a/datafusion/sqllogictest/test_files/pwmj.slt b/datafusion/sqllogictest/test_files/pwmj.slt index aff19d2ecb50b..30d1205e2b45e 100644 --- a/datafusion/sqllogictest/test_files/pwmj.slt +++ b/datafusion/sqllogictest/test_files/pwmj.slt @@ -17,7 +17,7 @@ statement ok -set datafusion.optimizer.allow_piecewise_merge_join = true; +set datafusion.optimizer.enable_piecewise_merge_join = true; statement ok CREATE TABLE join_t1 (t1_id INT); @@ -48,6 +48,20 @@ ORDER BY 1; 33 11 44 11 +# Checking `SELECT *` +query IITI +SELECT * +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id > t2.t2_id +WHERE t1.t1_id > 10 + AND t2.t2_int > 1 +ORDER BY 1; +---- +22 11 +33 11 +44 11 + query TT EXPLAIN SELECT t1.t1_id, t2.t2_id @@ -174,6 +188,35 @@ physical_plan 09)------------DataSourceExec: partitions=1, partition_sizes=[1] +query II 
+SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id < (t2.t2_id + 1) +WHERE t2.t2_int >= 3 +ORDER BY 1,2; +---- +11 11 +11 44 +11 55 +22 44 +22 55 +33 44 +33 55 +44 44 +44 55 + +query TT +EXPLAIN +SELECT t1.t1_id, t2.t2_id +FROM join_t1 t1 +JOIN join_t2 t2 + ON t1.t1_id < (t2.t2_id + 1) +WHERE t2.t2_int >= 3 +ORDER BY 1,2; +---- + + query II SELECT t1.t1_id, t2.t2_id FROM join_t1 t1 @@ -223,4 +266,4 @@ physical_plan 11)------------DataSourceExec: partitions=1, partition_sizes=[1] statement ok -set datafusion.optimizer.allow_piecewise_merge_join = false; +set datafusion.optimizer.enable_piecewise_merge_join = false; diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 680d45d39b860..6cf8e7bbe5a6c 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -146,7 +146,7 @@ The following configuration settings are available: | datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | | datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | | datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | -| datafusion.optimizer.allow_piecewise_merge_join | false | When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. | +| datafusion.optimizer.enable_piecewise_merge_join | false | When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. | | datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | | datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition | | datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | From 499ccd39a02592ac02b3d2e52aaa6f89244386eb Mon Sep 17 00:00:00 2001 From: Jonathan Date: Tue, 14 Oct 2025 19:21:17 -0400 Subject: [PATCH 16/20] fix --- .../src/joins/piecewise_merge_join/classic_join.rs | 8 ++++---- .../physical-plan/src/joins/piecewise_merge_join/exec.rs | 2 +- docs/source/user-guide/configs.md | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index 23ccded9aa81b..2c99d23206f35 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -302,11 +302,11 @@ impl ClassicPWMJStream { } // Nothing pending; hand back whatever `resolve` returned (often empty) and move on. 
- // if self.batch_process_state.output_batches.is_empty() { - self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + if self.batch_process_state.output_batches.is_empty() { + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; - return Ok(StatefulStreamResult::Ready(Some(batch))); - // } + return Ok(StatefulStreamResult::Ready(Some(batch))); + } } Ok(StatefulStreamResult::Ready(Some(batch))) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index b687f40172726..119e3ca906c89 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -547,7 +547,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { metrics.clone(), reservation, build_visited_indices_map(self.join_type), - partition, + self.num_partitions, )) })?; diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 6cf8e7bbe5a6c..e38cc3a2b4cb7 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -146,7 +146,7 @@ The following configuration settings are available: | datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | | datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | | datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | -| datafusion.optimizer.enable_piecewise_merge_join | false | When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. | +| datafusion.optimizer.enable_piecewise_merge_join | false | When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. | | datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | | datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition | | datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). 
| From c9819479eb8864ccc7e42d6e3d79b9f877ae113f Mon Sep 17 00:00:00 2001 From: Jonathan Date: Tue, 14 Oct 2025 20:30:03 -0400 Subject: [PATCH 17/20] fix null handling --- .../piecewise_merge_join/classic_join.rs | 18 ++++- .../src/joins/piecewise_merge_join/exec.rs | 8 +-- .../test_files/information_schema.slt | 4 +- datafusion/sqllogictest/test_files/joins.slt | 2 +- datafusion/sqllogictest/test_files/pwmj.slt | 65 ++++++++++++++++--- 5 files changed, 79 insertions(+), 18 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index 2c99d23206f35..373eb62d97566 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -63,7 +63,7 @@ impl PiecewiseMergeJoinStreamState { /// Note the compare key in the join predicate might include expressions on the original /// columns, so we store the evaluated compare key separately. /// e.g. For join predicate `buffer.v1 < (stream.v1 + 1)`, the `compare_key_values` field stores -/// the evaluted `stream.v1 + 1` array. +/// the evaluated `stream.v1 + 1` array. pub(super) struct SortedStreamBatch { pub batch: RecordBatch, compare_key_values: Vec, @@ -412,6 +412,8 @@ struct BatchProcessState { found: bool, // Signals to continue processing the current stream batch continue_process: bool, + // Skip nulls + processed_null_count: bool, } impl BatchProcessState { @@ -423,6 +425,7 @@ impl BatchProcessState { start_stream_idx: 0, found: false, continue_process: true, + processed_null_count: false, } } @@ -432,6 +435,7 @@ impl BatchProcessState { self.start_stream_idx = 0; self.found = false; self.continue_process = true; + self.processed_null_count = false; } } @@ -459,9 +463,17 @@ fn resolve_classic_join( ) -> Result { let buffered_len = buffered_side.buffered_data.values().len(); let stream_values = stream_batch.compare_key_values(); - + let mut buffer_idx = batch_process_state.start_buffer_idx; - let stream_idx = batch_process_state.start_stream_idx; + let mut stream_idx = batch_process_state.start_stream_idx; + + if !batch_process_state.processed_null_count { + let buffered_null_idx = buffered_side.buffered_data.values().null_count(); + let stream_null_idx = stream_values[0].null_count(); + buffer_idx = buffered_null_idx; + stream_idx = stream_null_idx; + batch_process_state.processed_null_count = true; + } // Our buffer_idx variable allows us to start probing on the buffered side where we last matched // in the previous stream row. 
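A minimal standalone sketch of the idea behind this null handling fix, assuming toy arrow arrays and a helper name (`leading_non_null_start`) that are not part of the patch. With both join keys sorted nulls-first (see the SortOptions change in exec.rs just below), every null key sits at the front of its array, so the merge cursors can start at `null_count()` instead of evaluating NULL comparisons row by row; the `processed_null_count` flag above ensures the skip happens only once per stream batch.

// Illustrative sketch only: why nulls-first ordering lets the join skip
// null keys by index. For an inequality predicate such as `l < r`, a NULL
// key can never match, so the cursors may start past the leading nulls.
use arrow::array::{ArrayRef, Int32Array};
use std::sync::Arc;

fn leading_non_null_start(sorted_nulls_first: &ArrayRef) -> usize {
    // With nulls sorted first, the first `null_count()` positions hold all
    // of the array's nulls, so this index is the first comparable value.
    sorted_nulls_first.null_count()
}

fn main() {
    // Buffered side, sorted ascending with nulls first: [NULL, 1, 2]
    let buffered: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(1), Some(2)]));
    // Stream side, sorted the same way: [NULL, 1, 3]
    let stream: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(1), Some(3)]));

    let buffer_idx = leading_non_null_start(&buffered);
    let stream_idx = leading_non_null_start(&stream);

    // The merge loop would begin at (buffer_idx, stream_idx) and never
    // visit a null key.
    assert_eq!((buffer_idx, stream_idx), (1, 1));
    println!("start at buffered={buffer_idx}, stream={stream_idx}");
}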
diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index 119e3ca906c89..987f3e9df45ac 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -307,16 +307,16 @@ impl PiecewiseMergeJoinExec { // For left existence joins the inputs will be swapped so the sort // options are switched if is_right_existence_join(join_type) { - SortOptions::new(false, false) + SortOptions::new(false, true) } else { - SortOptions::new(true, false) + SortOptions::new(true, true) } } Operator::Gt | Operator::GtEq => { if is_right_existence_join(join_type) { - SortOptions::new(true, false) + SortOptions::new(true, true) } else { - SortOptions::new(false, false) + SortOptions::new(false, true) } } _ => { diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index fed8b851cdaed..c4ce0eb8a433e 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -285,11 +285,11 @@ datafusion.format.time_format %H:%M:%S%.f datafusion.format.timestamp_format %Y-%m-%dT%H:%M:%S%.f datafusion.format.timestamp_tz_format NULL datafusion.format.types_info false -datafusion.optimizer.enable_piecewise_merge_join false datafusion.optimizer.allow_symmetric_joins_without_pruning true datafusion.optimizer.default_filter_selectivity 20 datafusion.optimizer.enable_distinct_aggregation_soft_limit true datafusion.optimizer.enable_dynamic_filter_pushdown true +datafusion.optimizer.enable_piecewise_merge_join false datafusion.optimizer.enable_round_robin_repartition true datafusion.optimizer.enable_topk_aggregation true datafusion.optimizer.enable_window_limits true @@ -402,11 +402,11 @@ datafusion.format.time_format %H:%M:%S%.f Time format for time arrays datafusion.format.timestamp_format %Y-%m-%dT%H:%M:%S%.f Timestamp format for timestamp arrays datafusion.format.timestamp_tz_format NULL Timestamp format for timestamp with timezone arrays. When `None`, ISO 8601 format is used. datafusion.format.types_info false Show types in visual representation batches -datafusion.optimizer.enable_piecewise_merge_join false When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). 
datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. datafusion.optimizer.enable_dynamic_filter_pushdown true When set to true attempts to push down dynamic filters generated by operators into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. +datafusion.optimizer.enable_piecewise_merge_join false When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. Physical planner will opt for PiecewiseMergeJoin when there is only one range filter. datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible datafusion.optimizer.enable_window_limits true When set to true, the optimizer will attempt to push limit operations past window functions, if possible diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 7c63ad19d5770..fd648f3ea0843 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -5217,7 +5217,7 @@ ORDER BY 1 physical_plan 01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] 02)--PiecewiseMergeJoin: operator=Gt, join_type=Inner, on=(t1_id > t2_id) -03)----SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +03)----SortExec: expr=[t1_id@0 ASC], preserve_partitioning=[false] 04)------CoalesceBatchesExec: target_batch_size=3 05)--------FilterExec: t1_id@0 > 10 06)----------DataSourceExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/pwmj.slt b/datafusion/sqllogictest/test_files/pwmj.slt index 30d1205e2b45e..fe018939a397d 100644 --- a/datafusion/sqllogictest/test_files/pwmj.slt +++ b/datafusion/sqllogictest/test_files/pwmj.slt @@ -58,9 +58,9 @@ WHERE t1.t1_id > 10 AND t2.t2_int > 1 ORDER BY 1; ---- -22 11 -33 11 -44 11 +22 11 z 3 +33 11 z 3 +44 11 z 3 query TT EXPLAIN @@ -86,7 +86,7 @@ physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] 02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----PiecewiseMergeJoin: operator=Gt, join_type=Inner, on=(t1_id > t2_id) -04)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------SortExec: expr=[t1_id@0 ASC], preserve_partitioning=[false] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------FilterExec: t1_id@0 > 10 07)------------DataSourceExec: partitions=1, partition_sizes=[1] @@ -133,7 +133,7 @@ physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST] 02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----PiecewiseMergeJoin: operator=GtEq, join_type=Inner, on=(t1_id >= t2_id) -04)------SortExec: expr=[t1_id@0 ASC 
NULLS LAST], preserve_partitioning=[false] +04)------SortExec: expr=[t1_id@0 ASC], preserve_partitioning=[false] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------FilterExec: t1_id@0 >= 22 07)------------DataSourceExec: partitions=1, partition_sizes=[1] @@ -180,7 +180,7 @@ physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST] 02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----PiecewiseMergeJoin: operator=Lt, join_type=Inner, on=(t1_id < t2_id) -04)------SortExec: expr=[t1_id@0 DESC NULLS LAST], preserve_partitioning=[false] +04)------SortExec: expr=[t1_id@0 DESC], preserve_partitioning=[false] 05)--------DataSourceExec: partitions=1, partition_sizes=[1] 06)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 07)--------CoalesceBatchesExec: target_batch_size=8192 @@ -215,7 +215,25 @@ JOIN join_t2 t2 WHERE t2.t2_int >= 3 ORDER BY 1,2; ---- - +logical_plan +01)Sort: t1.t1_id ASC NULLS LAST, t2.t2_id ASC NULLS LAST +02)--Inner Join: Filter: CAST(t1.t1_id AS Int64) < CAST(t2.t2_id AS Int64) + Int64(1) +03)----SubqueryAlias: t1 +04)------TableScan: join_t1 projection=[t1_id] +05)----SubqueryAlias: t2 +06)------Projection: join_t2.t2_id +07)--------Filter: join_t2.t2_int >= Int32(3) +08)----------TableScan: join_t2 projection=[t2_id, t2_int] +physical_plan +01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST], preserve_partitioning=[true] +03)----PiecewiseMergeJoin: operator=Lt, join_type=Inner, on=(CAST(t1_id AS Int64) < CAST(t2_id AS Int64) + 1) +04)------SortExec: expr=[CAST(t1_id@0 AS Int64) DESC], preserve_partitioning=[false] +05)--------DataSourceExec: partitions=1, partition_sizes=[1] +06)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)--------CoalesceBatchesExec: target_batch_size=8192 +08)----------FilterExec: t2_int@1 >= 3, projection=[t2_id@0] +09)------------DataSourceExec: partitions=1, partition_sizes=[1] query II SELECT t1.t1_id, t2.t2_id @@ -256,7 +274,7 @@ physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST] 02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----PiecewiseMergeJoin: operator=LtEq, join_type=Inner, on=(t1_id <= t2_id) -04)------SortExec: expr=[t1_id@0 DESC NULLS LAST], preserve_partitioning=[false] +04)------SortExec: expr=[t1_id@0 DESC], preserve_partitioning=[false] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------FilterExec: t1_id@0 = 11 OR t1_id@0 = 44 07)------------DataSourceExec: partitions=1, partition_sizes=[1] @@ -265,5 +283,36 @@ physical_plan 10)----------FilterExec: t2_name@1 != y, projection=[t2_id@0] 11)------------DataSourceExec: partitions=1, partition_sizes=[1] +statement ok +CREATE TABLE null_join_t1 (id INT); + +statement ok +CREATE TABLE null_join_t2 (id INT); + +statement ok +INSERT INTO null_join_t1 VALUES (1), (2), (NULL); + +statement ok +INSERT INTO null_join_t2 VALUES (1), (NULL), (3); + +query II +SELECT t1.id AS left_id, t2.id AS right_id +FROM null_join_t1 t1 +JOIN null_join_t2 t2 + ON t1.id > t2.id +ORDER BY 1,2; +---- +2 1 + +query II +SELECT t1.id AS left_id, t2.id AS right_id +FROM null_join_t1 t1 +JOIN null_join_t2 t2 + ON t1.id < t2.id +ORDER BY 1,2; +---- +1 3 +2 3 + statement ok set datafusion.optimizer.enable_piecewise_merge_join = false; From 2946585e7e0073b191fc5086052c3d32d543e784 Mon Sep 17 00:00:00 2001 From: 
Jonathan Date: Tue, 14 Oct 2025 20:30:42 -0400 Subject: [PATCH 18/20] fmt --- .../src/joins/piecewise_merge_join/classic_join.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index 373eb62d97566..646905e0d7875 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -463,7 +463,7 @@ fn resolve_classic_join( ) -> Result { let buffered_len = buffered_side.buffered_data.values().len(); let stream_values = stream_batch.compare_key_values(); - + let mut buffer_idx = batch_process_state.start_buffer_idx; let mut stream_idx = batch_process_state.start_stream_idx; From 86cf18a9aaecc3e0082c4b07f7dc7af4534b1eb5 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 15 Oct 2025 10:14:26 -0400 Subject: [PATCH 19/20] fix configs.md + fix Both side logic --- datafusion/core/src/physical_planner.rs | 44 +++++++++++++++------ datafusion/sqllogictest/test_files/pwmj.slt | 36 +++++++++++++++++ docs/source/user-guide/configs.md | 2 +- 3 files changed, 70 insertions(+), 12 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index a79dec54b16e0..15b558c183f10 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1295,19 +1295,28 @@ impl DefaultPhysicalPlanner { } } - let side_of = |e: &Expr| -> Result<&'static str> { + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + enum Side { + Left, + Right, + Both, + } + + let side_of = |e: &Expr| -> Result { let cols = e.column_refs(); - let in_left = cols + let any_left = cols .iter() - .all(|c| left_df_schema.index_of_column(c).is_ok()); - let in_right = cols + .any(|c| left_df_schema.index_of_column(c).is_ok()); + let any_right = cols .iter() - .all(|c| right_df_schema.index_of_column(c).is_ok()); - match (in_left, in_right) { - (true, false) => Ok("left"), - (false, true) => Ok("right"), + .any(|c| right_df_schema.index_of_column(c).is_ok()); + + Ok(match (any_left, any_right) { + (true, false) => Side::Left, + (false, true) => Side::Right, + (true, true) => Side::Both, _ => unreachable!(), - } + }) }; let mut lhs_logical = &be.left; @@ -1315,10 +1324,23 @@ impl DefaultPhysicalPlanner { let left_side = side_of(lhs_logical)?; let right_side = side_of(rhs_logical)?; - if left_side == "right" && right_side == "left" { + if matches!(left_side, Side::Both) + || matches!(right_side, Side::Both) + { + return Ok(Arc::new(NestedLoopJoinExec::try_new( + physical_left, + physical_right, + join_filter, + join_type, + None, + )?)); + } + + if left_side == Side::Right && right_side == Side::Left { std::mem::swap(&mut lhs_logical, &mut rhs_logical); op = reverse_ineq(op); - } else if !(left_side == "left" && right_side == "right") { + } else if !(left_side == Side::Left && right_side == Side::Right) + { return plan_err!( "Unsupported operator for PWMJ: {:?}. 
Expected one of <, <=, >, >=", op diff --git a/datafusion/sqllogictest/test_files/pwmj.slt b/datafusion/sqllogictest/test_files/pwmj.slt index fe018939a397d..0014b3c545f29 100644 --- a/datafusion/sqllogictest/test_files/pwmj.slt +++ b/datafusion/sqllogictest/test_files/pwmj.slt @@ -304,6 +304,42 @@ ORDER BY 1,2; ---- 2 1 +# Verify this will offload this query to Nested Loop Join +query II +SELECT t1.id AS left_id, t2.id AS right_id +FROM null_join_t1 t1 +JOIN null_join_t2 t2 + ON t1.id < (t1.id + t2.id) +ORDER BY 1,2; +---- +1 1 +1 3 +2 1 +2 3 + +query TT +EXPLAIN +SELECT t1.id AS left_id, t2.id AS right_id +FROM null_join_t1 t1 +JOIN null_join_t2 t2 + ON t1.id < (t1.id + t2.id) +ORDER BY 1,2; +---- +logical_plan +01)Sort: left_id ASC NULLS LAST, right_id ASC NULLS LAST +02)--Projection: t1.id AS left_id, t2.id AS right_id +03)----Inner Join: Filter: t1.id < t1.id + t2.id +04)------SubqueryAlias: t1 +05)--------TableScan: null_join_t1 projection=[id] +06)------SubqueryAlias: t2 +07)--------TableScan: null_join_t2 projection=[id] +physical_plan +01)SortExec: expr=[left_id@0 ASC NULLS LAST, right_id@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[id@0 as left_id, id@1 as right_id] +03)----NestedLoopJoinExec: join_type=Inner, filter=id@0 < id@0 + id@1 +04)------DataSourceExec: partitions=1, partition_sizes=[1] +05)------DataSourceExec: partitions=1, partition_sizes=[1] + query II SELECT t1.id AS left_id, t2.id AS right_id FROM null_join_t1 t1 diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index e38cc3a2b4cb7..b2c66a53a4da2 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -146,7 +146,7 @@ The following configuration settings are available: | datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | | datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | | datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | -| datafusion.optimizer.enable_piecewise_merge_join | false | When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. | +| datafusion.optimizer.enable_piecewise_merge_join | false | When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. Physical planner will opt for PiecewiseMergeJoin when there is only one range filter. | | datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | | datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition | | datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). 
| From 84c3255a85c3481823ab86c8d07a4de16dc6adb1 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Thu, 16 Oct 2025 12:38:07 -0400 Subject: [PATCH 20/20] fix slt --- datafusion/sqllogictest/test_files/information_schema.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 7d85254780d3e..765d045917e9d 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -408,8 +408,8 @@ datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusio datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. datafusion.optimizer.enable_dynamic_filter_pushdown true When set to true attempts to push down dynamic filters generated by operators (topk & join) into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. The config will suppress `enable_join_dynamic_filter_pushdown` & `enable_topk_dynamic_filter_pushdown` So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden. -datafusion.optimizer.enable_piecewise_merge_join false When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. Physical planner will opt for PiecewiseMergeJoin when there is only one range filter. datafusion.optimizer.enable_join_dynamic_filter_pushdown true When set to true, the optimizer will attempt to push down Join dynamic filters into the file scan phase. +datafusion.optimizer.enable_piecewise_merge_join false When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. Physical planner will opt for PiecewiseMergeJoin when there is only one range filter. datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible datafusion.optimizer.enable_topk_dynamic_filter_pushdown true When set to true, the optimizer will attempt to push down TopK dynamic filters into the file scan phase.
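As a closing illustration of the Both side logic introduced in PATCH 19, here is a small self-contained model of the side classification, assuming plain string column names and HashSets rather than DataFusion's Expr and DFSchema types. An operand of the range predicate is usable by PiecewiseMergeJoin only when all of its column references come from a single input; Left and Right operands may be swapped (with the inequality reversed), while an operand touching both inputs forces the NestedLoopJoin fallback exercised by the new pwmj.slt test.

use std::collections::HashSet;

// Simplified stand-in for the planner's side classification; not DataFusion API.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Side {
    Left,
    Right,
    Both,
}

fn side_of(cols: &HashSet<&str>, left: &HashSet<&str>, right: &HashSet<&str>) -> Side {
    let any_left = cols.iter().any(|c| left.contains(c));
    let any_right = cols.iter().any(|c| right.contains(c));
    match (any_left, any_right) {
        (true, false) => Side::Left,
        (false, true) => Side::Right,
        // In this toy model anything else (both sides, or no known columns)
        // is treated as Both and falls back to NestedLoopJoin.
        _ => Side::Both,
    }
}

fn main() {
    let left: HashSet<&str> = ["t1.id"].into_iter().collect();
    let right: HashSet<&str> = ["t2.id"].into_iter().collect();

    // `t1.id < t2.id` classifies as (Left, Right): eligible for PiecewiseMergeJoin.
    let lhs: HashSet<&str> = ["t1.id"].into_iter().collect();
    let rhs: HashSet<&str> = ["t2.id"].into_iter().collect();
    assert_eq!(side_of(&lhs, &left, &right), Side::Left);
    assert_eq!(side_of(&rhs, &left, &right), Side::Right);

    // `t1.id < t1.id + t2.id`: the right operand references both inputs,
    // so the planner keeps NestedLoopJoinExec instead.
    let mixed: HashSet<&str> = ["t1.id", "t2.id"].into_iter().collect();
    assert_eq!(side_of(&mixed, &left, &right), Side::Both);
}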