diff --git a/datafusion/core/src/physical_plan/cross_join.rs b/datafusion/core/src/physical_plan/cross_join.rs index 43555f07799dd..240c5dda4a523 100644 --- a/datafusion/core/src/physical_plan/cross_join.rs +++ b/datafusion/core/src/physical_plan/cross_join.rs @@ -18,15 +18,15 @@ //! Defines the cross join plan for loading the left side of the cross join //! and producing batches in parallel for the right partitions -use futures::{lock::Mutex, StreamExt}; +use futures::{ready, FutureExt, StreamExt}; +use futures::{Stream, TryStreamExt}; +use parking_lot::Mutex; use std::{any::Any, sync::Arc, task::Poll}; use arrow::datatypes::{Schema, SchemaRef}; -use arrow::error::Result as ArrowResult; +use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; -use futures::{Stream, TryStreamExt}; - use super::expressions::PhysicalSortExpr; use super::{ coalesce_partitions::CoalescePartitionsExec, join_utils::check_join_is_valid, @@ -34,11 +34,12 @@ use super::{ }; use crate::{error::Result, scalar::ScalarValue}; use async_trait::async_trait; +use futures::future::{BoxFuture, Shared}; use std::time::Instant; use super::{ - coalesce_batches::concat_batches, memory::MemoryStream, DisplayFormatType, - ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, + coalesce_batches::concat_batches, DisplayFormatType, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, }; use crate::execution::context::TaskContext; use log::debug; @@ -46,6 +47,11 @@ use log::debug; /// Data of the left side type JoinLeftData = RecordBatch; +/// Type of future for collecting left data +/// +/// [`Shared`] allows potentially multiple output streams to poll the same future to completion +type JoinLeftFut = Shared>>>; + /// executes partitions in parallel and combines them into a set of /// partitions by combining all values from the left with all values on the right #[derive(Debug)] @@ -57,7 +63,11 @@ pub struct CrossJoinExec { /// The schema once the join is applied schema: SchemaRef, /// Build-side data - build_side: Arc>>, + /// + /// Ideally we would instantiate this in the constructor, avoiding the need for a + /// mutex and an option, but we need the [`TaskContext`] to evaluate the left + /// side data, which is only provided in [`ExecutionPlan::execute`] + left_fut: Mutex>, } impl CrossJoinExec { @@ -87,7 +97,7 @@ impl CrossJoinExec { left, right, schema, - build_side: Arc::new(Mutex::new(None)), + left_fut: Mutex::new(None), }) } @@ -102,6 +112,37 @@ impl CrossJoinExec { } } +/// Asynchronously collect the result of the left child +async fn load_left_input( + left: Arc, + context: Arc, +) -> Result { + let start = Instant::now(); + + // merge all left parts into a single stream + let merge = CoalescePartitionsExec::new(left.clone()); + let stream = merge.execute(0, context).await?; + + // Load all batches and count the rows + let (batches, num_rows) = stream + .try_fold((Vec::new(), 0usize), |mut acc, batch| async { + acc.1 += batch.num_rows(); + acc.0.push(batch); + Ok(acc) + }) + .await?; + + let merged_batch = concat_batches(&left.schema(), &batches, num_rows)?; + + debug!( + "Built build-side of cross join containing {} rows in {} ms", + num_rows, + start.elapsed().as_millis() + ); + + Ok(merged_batch) +} + #[async_trait] impl ExecutionPlan for CrossJoinExec { fn as_any(&self) -> &dyn Any { @@ -143,55 +184,23 @@ impl ExecutionPlan for CrossJoinExec { partition: usize, context: Arc, ) -> Result { - // we only want to compute the build side once - let left_data = { - let mut build_side = self.build_side.lock().await; - - match build_side.as_ref() { - Some(stream) => stream.clone(), - None => { - let start = Instant::now(); - - // merge all left parts into a single stream - let merge = CoalescePartitionsExec::new(self.left.clone()); - let stream = merge.execute(0, context.clone()).await?; - - // Load all batches and count the rows - let (batches, num_rows) = stream - .try_fold((Vec::new(), 0usize), |mut acc, batch| async { - acc.1 += batch.num_rows(); - acc.0.push(batch); - Ok(acc) - }) - .await?; - let merged_batch = - concat_batches(&self.left.schema(), &batches, num_rows)?; - *build_side = Some(merged_batch.clone()); - - debug!( - "Built build-side of cross join containing {} rows in {} ms", - num_rows, - start.elapsed().as_millis() - ); - - merged_batch - } - } - }; - let stream = self.right.execute(partition, context.clone()).await?; - if left_data.num_rows() == 0 { - return Ok(Box::pin(MemoryStream::try_new( - vec![], - self.schema.clone(), - None, - )?)); - } + let left_fut = self + .left_fut + .lock() + .get_or_insert_with(|| { + load_left_input(self.left.clone(), context) + .map(Arc::new) + .boxed() + .shared() + }) + .clone(); Ok(Box::pin(CrossJoinStream { schema: self.schema.clone(), - left_data, + left_fut, + left_result: None, right: stream, right_batch: Arc::new(parking_lot::Mutex::new(None)), left_index: 0, @@ -293,8 +302,10 @@ fn stats_cartesian_product( struct CrossJoinStream { /// Input schema schema: Arc, + /// future for data from left side + left_fut: JoinLeftFut, /// data from the left side - left_data: JoinLeftData, + left_result: Option>>, /// right right: SendableRecordBatchStream, /// Current value on the left @@ -318,6 +329,7 @@ impl RecordBatchStream for CrossJoinStream { self.schema.clone() } } + fn build_batch( left_index: usize, batch: &RecordBatch, @@ -352,14 +364,46 @@ impl Stream for CrossJoinStream { mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - if self.left_index > 0 && self.left_index < self.left_data.num_rows() { + self.poll_next_impl(cx) + } +} + +impl CrossJoinStream { + /// Separate implementation function that unpins the [`CrossJoinStream`] so + /// that partial borrows work correctly + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll>> { + let left_result = match &self.left_result { + Some(data) => data, + None => { + let result = ready!(self.left_fut.poll_unpin(cx)); + self.left_result.insert(result) + } + }; + + let left_data = match left_result.as_ref() { + Ok(left_data) => left_data, + Err(e) => { + return Poll::Ready(Some(Err(ArrowError::ExternalError( + e.to_string().into(), + )))) + } + }; + + if left_data.num_rows() == 0 { + return Poll::Ready(None); + } + + if self.left_index > 0 && self.left_index < left_data.num_rows() { let start = Instant::now(); let right_batch = { let right_batch = self.right_batch.lock(); right_batch.clone().unwrap() }; let result = - build_batch(self.left_index, &right_batch, &self.left_data, &self.schema); + build_batch(self.left_index, &right_batch, left_data, &self.schema); self.num_input_rows += right_batch.num_rows(); if let Ok(ref batch) = result { self.join_time += start.elapsed().as_millis() as usize; @@ -375,12 +419,8 @@ impl Stream for CrossJoinStream { .map(|maybe_batch| match maybe_batch { Some(Ok(batch)) => { let start = Instant::now(); - let result = build_batch( - self.left_index, - &batch, - &self.left_data, - &self.schema, - ); + let result = + build_batch(self.left_index, &batch, left_data, &self.schema); self.num_input_batches += 1; self.num_input_rows += batch.num_rows(); if let Ok(ref batch) = result {