Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 101 additions & 61 deletions datafusion/core/src/physical_plan/cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,40 @@
//! Defines the cross join plan for loading the left side of the cross join
//! and producing batches in parallel for the right partitions

use futures::{lock::Mutex, StreamExt};
use futures::{ready, FutureExt, StreamExt};
use futures::{Stream, TryStreamExt};
use parking_lot::Mutex;
use std::{any::Any, sync::Arc, task::Poll};

use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::error::{ArrowError, Result as ArrowResult};
use arrow::record_batch::RecordBatch;

use futures::{Stream, TryStreamExt};

use super::expressions::PhysicalSortExpr;
use super::{
coalesce_partitions::CoalescePartitionsExec, join_utils::check_join_is_valid,
ColumnStatistics, Statistics,
};
use crate::{error::Result, scalar::ScalarValue};
use async_trait::async_trait;
use futures::future::{BoxFuture, Shared};
use std::time::Instant;

use super::{
coalesce_batches::concat_batches, memory::MemoryStream, DisplayFormatType,
ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
coalesce_batches::concat_batches, DisplayFormatType, ExecutionPlan, Partitioning,
RecordBatchStream, SendableRecordBatchStream,
};
use crate::execution::context::TaskContext;
use log::debug;

/// Data of the left side
type JoinLeftData = RecordBatch;

/// Type of future for collecting left data
///
/// [`Shared`] allows potentially multiple output streams to poll the same future to completion
type JoinLeftFut = Shared<BoxFuture<'static, Arc<Result<RecordBatch>>>>;

/// executes partitions in parallel and combines them into a set of
/// partitions by combining all values from the left with all values on the right
#[derive(Debug)]
Expand All @@ -57,7 +63,11 @@ pub struct CrossJoinExec {
/// The schema once the join is applied
schema: SchemaRef,
/// Build-side data
build_side: Arc<Mutex<Option<JoinLeftData>>>,
///
/// Ideally we would instantiate this in the constructor, avoiding the need for a
/// mutex and an option, but we need the [`TaskContext`] to evaluate the left
/// side data, which is only provided in [`ExecutionPlan::execute`]
left_fut: Mutex<Option<JoinLeftFut>>,
}

impl CrossJoinExec {
Expand Down Expand Up @@ -87,7 +97,7 @@ impl CrossJoinExec {
left,
right,
schema,
build_side: Arc::new(Mutex::new(None)),
left_fut: Mutex::new(None),
})
}

Expand All @@ -102,6 +112,37 @@ impl CrossJoinExec {
}
}

/// Asynchronously collect the result of the left child
async fn load_left_input(
left: Arc<dyn ExecutionPlan>,
context: Arc<TaskContext>,
) -> Result<JoinLeftData> {
let start = Instant::now();

// merge all left parts into a single stream
let merge = CoalescePartitionsExec::new(left.clone());
let stream = merge.execute(0, context).await?;

// Load all batches and count the rows
let (batches, num_rows) = stream
.try_fold((Vec::new(), 0usize), |mut acc, batch| async {
acc.1 += batch.num_rows();
acc.0.push(batch);
Ok(acc)
})
.await?;

let merged_batch = concat_batches(&left.schema(), &batches, num_rows)?;

debug!(
"Built build-side of cross join containing {} rows in {} ms",
num_rows,
start.elapsed().as_millis()
);

Ok(merged_batch)
}

#[async_trait]
impl ExecutionPlan for CrossJoinExec {
fn as_any(&self) -> &dyn Any {
Expand Down Expand Up @@ -143,55 +184,23 @@ impl ExecutionPlan for CrossJoinExec {
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
// we only want to compute the build side once
let left_data = {
let mut build_side = self.build_side.lock().await;

match build_side.as_ref() {
Some(stream) => stream.clone(),
None => {
let start = Instant::now();

// merge all left parts into a single stream
let merge = CoalescePartitionsExec::new(self.left.clone());
let stream = merge.execute(0, context.clone()).await?;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually a bug in the old implementation, if part of evaluation errored - it would try it again for the next partition


// Load all batches and count the rows
let (batches, num_rows) = stream
.try_fold((Vec::new(), 0usize), |mut acc, batch| async {
acc.1 += batch.num_rows();
acc.0.push(batch);
Ok(acc)
})
.await?;
let merged_batch =
concat_batches(&self.left.schema(), &batches, num_rows)?;
*build_side = Some(merged_batch.clone());

debug!(
"Built build-side of cross join containing {} rows in {} ms",
num_rows,
start.elapsed().as_millis()
);

merged_batch
}
}
};

let stream = self.right.execute(partition, context.clone()).await?;

if left_data.num_rows() == 0 {
return Ok(Box::pin(MemoryStream::try_new(
vec![],
self.schema.clone(),
None,
)?));
}
let left_fut = self
.left_fut
.lock()
.get_or_insert_with(|| {
load_left_input(self.left.clone(), context)
.map(Arc::new)
.boxed()
.shared()
})
.clone();

Ok(Box::pin(CrossJoinStream {
schema: self.schema.clone(),
left_data,
left_fut,
left_result: None,
right: stream,
right_batch: Arc::new(parking_lot::Mutex::new(None)),
left_index: 0,
Expand Down Expand Up @@ -293,8 +302,10 @@ fn stats_cartesian_product(
struct CrossJoinStream {
/// Input schema
schema: Arc<Schema>,
/// future for data from left side
left_fut: JoinLeftFut,
/// data from the left side
left_data: JoinLeftData,
left_result: Option<Arc<Result<RecordBatch>>>,
/// right
right: SendableRecordBatchStream,
/// Current value on the left
Expand All @@ -318,6 +329,7 @@ impl RecordBatchStream for CrossJoinStream {
self.schema.clone()
}
}

fn build_batch(
left_index: usize,
batch: &RecordBatch,
Expand Down Expand Up @@ -352,14 +364,46 @@ impl Stream for CrossJoinStream {
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
if self.left_index > 0 && self.left_index < self.left_data.num_rows() {
self.poll_next_impl(cx)
}
}

impl CrossJoinStream {
/// Separate implementation function that unpins the [`CrossJoinStream`] so
/// that partial borrows work correctly
fn poll_next_impl(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<ArrowResult<RecordBatch>>> {
let left_result = match &self.left_result {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is kind of arcane, hopefully the future changes to ExecutionPlan as part of #2199 will make writing these sorts of pipelines easier

Some(data) => data,
None => {
let result = ready!(self.left_fut.poll_unpin(cx));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Futures-rs does have a MaybeDone construct, but this seemed simpler to understand

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ready! effectvely calls return Poll::Pending if the left_fut does the same, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup

self.left_result.insert(result)
}
};

let left_data = match left_result.as_ref() {
Ok(left_data) => left_data,
Err(e) => {
return Poll::Ready(Some(Err(ArrowError::ExternalError(
e.to_string().into(),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit of a hack as DatafusionError isn't clone-able

))))
}
};

if left_data.num_rows() == 0 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure this is correct in the context of outer joins -- even if the left data has no rows, the stream may still produce output...

However, I see the original code did the same, so 🤷

       if left_data.num_rows() == 0 {
            return Ok(Box::pin(MemoryStream::try_new(
                vec![],
                self.schema.clone(),
                None,
            )?));
        }

(it probably only matters for joins that don't have an equality predicate)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I just blindly reproduced the existing behaviour - I presumed this special case was there for a reason

return Poll::Ready(None);
}

if self.left_index > 0 && self.left_index < left_data.num_rows() {
let start = Instant::now();
let right_batch = {
let right_batch = self.right_batch.lock();
right_batch.clone().unwrap()
};
let result =
build_batch(self.left_index, &right_batch, &self.left_data, &self.schema);
build_batch(self.left_index, &right_batch, left_data, &self.schema);
self.num_input_rows += right_batch.num_rows();
if let Ok(ref batch) = result {
self.join_time += start.elapsed().as_millis() as usize;
Expand All @@ -375,12 +419,8 @@ impl Stream for CrossJoinStream {
.map(|maybe_batch| match maybe_batch {
Some(Ok(batch)) => {
let start = Instant::now();
let result = build_batch(
self.left_index,
&batch,
&self.left_data,
&self.schema,
);
let result =
build_batch(self.left_index, &batch, left_data, &self.schema);
self.num_input_batches += 1;
self.num_input_rows += batch.num_rows();
if let Ok(ref batch) = result {
Expand Down