diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs
index 7ef1948490741..a7b17c4161b0a 100644
--- a/datafusion/src/physical_plan/repartition.rs
+++ b/datafusion/src/physical_plan/repartition.rs
@@ -255,12 +255,15 @@ impl RepartitionExec {
         let mut counter = 0;
         let hashes_buf = &mut vec![];
 
-        loop {
+        // While there are still outputs to send to, keep
+        // pulling inputs
+        while !txs.is_empty() {
             // fetch the next batch
             let now = Instant::now();
             let result = stream.next().await;
             metrics.fetch_nanos.add_elapsed(now);
+            // Input is done
            if result.is_none() {
                 break;
             }
@@ -270,9 +273,13 @@ impl RepartitionExec {
                 Partitioning::RoundRobinBatch(_) => {
                     let now = Instant::now();
                     let output_partition = counter % num_output_partitions;
-                    let tx = txs.get_mut(&output_partition).unwrap();
-                    tx.send(Some(result))
-                        .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+                    // if there is still a receiver, send to it
+                    if let Some(tx) = txs.get_mut(&output_partition) {
+                        if tx.send(Some(result)).is_err() {
+                            // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
+                            txs.remove(&output_partition);
+                        }
+                    }
                     metrics.send_nanos.add_elapsed(now);
                 }
                 Partitioning::Hash(exprs, _) => {
@@ -315,9 +322,13 @@ impl RepartitionExec {
                         RecordBatch::try_new(input_batch.schema(), columns);
                     metrics.repart_nanos.add_elapsed(now);
                     let now = Instant::now();
-                    let tx = txs.get_mut(&num_output_partition).unwrap();
-                    tx.send(Some(output_batch))
-                        .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+                    // if there is still a receiver, send to it
+                    if let Some(tx) = txs.get_mut(&num_output_partition) {
+                        if tx.send(Some(output_batch)).is_err() {
+                            // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
+                            txs.remove(&num_output_partition);
+                        }
+                    }
                     metrics.send_nanos.add_elapsed(now);
                 }
             }
@@ -425,7 +436,7 @@ mod tests {
     use crate::{
         assert_batches_sorted_eq,
         physical_plan::memory::MemoryExec,
-        test::exec::{ErrorExec, MockExec},
+        test::exec::{BarrierExec, ErrorExec, MockExec},
     };
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::record_batch::RecordBatch;
@@ -723,4 +734,105 @@ mod tests {
 
         assert_batches_sorted_eq!(&expected, &batches);
     }
+
+    #[tokio::test]
+    async fn repartition_with_dropping_output_stream() {
+        #[derive(Debug)]
+        struct Case<'a> {
+            partitioning: Partitioning,
+            expected: Vec<&'a str>,
+        }
+
+        let cases = vec![
+            Case {
+                partitioning: Partitioning::RoundRobinBatch(2),
+                expected: vec![
+                    "+------------------+",
+                    "| my_awesome_field |",
+                    "+------------------+",
+                    "| baz              |",
+                    "| frob             |",
+                    "| gaz              |",
+                    "| grob             |",
+                    "+------------------+",
+                ],
+            },
+            Case {
+                partitioning: Partitioning::Hash(
+                    vec![Arc::new(crate::physical_plan::expressions::Column::new(
+                        "my_awesome_field",
+                    ))],
+                    2,
+                ),
+                expected: vec![
+                    "+------------------+",
+                    "| my_awesome_field |",
+                    "+------------------+",
+                    "| frob             |",
+                    "+------------------+",
+                ],
+            },
+        ];
+
+        for case in cases {
+            println!("Running case {:?}", case.partitioning);
+
+            // The barrier exec does not produce anything until it is pinged
+            // (the test must call `wait()` at least once)
+            let input = Arc::new(make_barrier_exec());
+
+            // partition into two output streams
+            let exec =
+                RepartitionExec::try_new(input.clone(), case.partitioning).unwrap();
+
+            let output_stream0 = exec.execute(0).await.unwrap();
+            let output_stream1 = exec.execute(1).await.unwrap();
+
+            // now, purposely drop output stream 0
+            // *before* any outputs are produced
+            std::mem::drop(output_stream0);
+
+            // Now, start sending input
+            input.wait().await;
+
+            // output stream 1 should *not* error and have one of the input batches
+            let batches = crate::physical_plan::common::collect(output_stream1)
+                .await
+                .unwrap();
+
+            assert_batches_sorted_eq!(&case.expected, &batches);
+        }
+    }
+
+    /// Create a BarrierExec that returns two partitions of two batches each
+    fn make_barrier_exec() -> BarrierExec {
+        let batch1 = RecordBatch::try_from_iter(vec![(
+            "my_awesome_field",
+            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
+        )])
+        .unwrap();
+
+        let batch2 = RecordBatch::try_from_iter(vec![(
+            "my_awesome_field",
+            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
+        )])
+        .unwrap();
+
+        let batch3 = RecordBatch::try_from_iter(vec![(
+            "my_awesome_field",
+            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
+        )])
+        .unwrap();
+
+        let batch4 = RecordBatch::try_from_iter(vec![(
+            "my_awesome_field",
+            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
+        )])
+        .unwrap();
+
+        // The barrier exec does not produce anything until it is pinged
+        // (the test must call `wait()` at least once)
+        let schema = batch1.schema();
+        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
+    }
 }
diff --git a/datafusion/src/test/exec.rs b/datafusion/src/test/exec.rs
index bcd94dd6d6397..3971db3adf823 100644
--- a/datafusion/src/test/exec.rs
+++ b/datafusion/src/test/exec.rs
@@ -23,6 +23,7 @@ use std::{
     sync::Arc,
     task::{Context, Poll},
 };
+use tokio::sync::Barrier;
 
 use arrow::{
     datatypes::{DataType, Field, Schema, SchemaRef},
@@ -226,6 +227,95 @@ impl RecordBatchStream for DelayedStream {
     }
 }
 
+/// A Mock ExecutionPlan that does not start producing data until
+/// its barrier is released
+///
+#[derive(Debug)]
+pub struct BarrierExec {
+    /// partitions to send back
+    data: Vec<Vec<RecordBatch>>,
+    schema: SchemaRef,
+
+    /// all streams wait on this barrier to produce
+    barrier: Arc<Barrier>,
+}
+
+impl BarrierExec {
+    /// Create a new exec with one partition per entry in `data`.
+    pub fn new(data: Vec<Vec<RecordBatch>>, schema: SchemaRef) -> Self {
+        // wait for all streams and the input
+        let barrier = Arc::new(Barrier::new(data.len() + 1));
+        Self {
+            data,
+            schema,
+            barrier,
+        }
+    }
+
+    /// wait until all the input streams and this function are ready
+    pub async fn wait(&self) {
+        println!("BarrierExec::wait waiting on barrier");
+        self.barrier.wait().await;
+        println!("BarrierExec::wait done waiting");
+    }
+}
+
+#[async_trait]
+impl ExecutionPlan for BarrierExec {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    fn output_partitioning(&self) -> Partitioning {
+        Partitioning::UnknownPartitioning(self.data.len())
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        unimplemented!()
+    }
+
+    fn with_new_children(
+        &self,
+        _children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        unimplemented!()
+    }
+
+    /// Returns a stream which yields data
+    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
+        assert!(partition < self.data.len());
+
+        let schema = self.schema();
+
+        let (tx, rx) = tokio::sync::mpsc::channel(2);
+
+        // task simply sends data in order after barrier is reached
+        let data = self.data[partition].clone();
+        let b = self.barrier.clone();
+        tokio::task::spawn(async move {
+            println!("Partition {} waiting on barrier", partition);
+            b.wait().await;
+            for batch in data {
+                println!("Partition {} sending batch", partition);
+                if let Err(e) = tx.send(Ok(batch)).await {
+                    println!("ERROR sending batch via barrier stream: {}", e);
+                }
+            }
+        });
+
+        // returned stream simply reads off the rx stream
+        let stream = DelayedStream {
+            schema,
+            inner: ReceiverStream::new(rx),
+        };
+        Ok(Box::pin(stream))
+    }
+}
+
 /// A mock execution plan that errors on a call to execute
 #[derive(Debug)]
 pub struct ErrorExec {
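
Note on the approach: the `repartition.rs` hunks above turn a send failure into a shutdown signal rather than an error. Below is a minimal standalone sketch of that pattern, assuming tokio unbounded mpsc channels (consistent with the non-`await`ed `send` calls in the diff); the map layout mirrors `txs` in the patch, but the values and loop bound here are illustrative:

```rust
use std::collections::HashMap;
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // One unbounded channel per output partition, keyed by partition index,
    // analogous to the `txs` map in RepartitionExec.
    let mut txs: HashMap<usize, mpsc::UnboundedSender<u64>> = HashMap::new();
    let mut rxs: HashMap<usize, mpsc::UnboundedReceiver<u64>> = HashMap::new();
    for partition in 0..2 {
        let (tx, rx) = mpsc::unbounded_channel();
        txs.insert(partition, tx);
        rxs.insert(partition, rx);
    }

    // Simulate an early shutdown (e.g. a LIMIT that is already satisfied):
    // the consumer of partition 0 drops its receiving end.
    drop(rxs.remove(&0));

    // Distribute values round-robin. A failed send means the other end has
    // hung up, so that output is removed instead of being treated as an error.
    let mut counter = 0usize;
    while !txs.is_empty() && counter < 4 {
        let output_partition = counter % 2;
        if let Some(tx) = txs.get_mut(&output_partition) {
            if tx.send(counter as u64).is_err() {
                txs.remove(&output_partition);
            }
        }
        // Once a sender is removed, later values for that partition are
        // silently skipped, just as batches are in the patched loop.
        counter += 1;
    }

    // Drop the remaining senders so the receiver below terminates.
    drop(txs);

    // Only partition 1 is still connected; it receives 1 and 3.
    let mut rx1 = rxs.remove(&1).unwrap();
    while let Some(v) = rx1.recv().await {
        println!("partition 1 received {}", v);
    }
}
```

Once the receiver for partition 0 is dropped, the first failed `send` removes that sender from the map, and the `while !txs.is_empty()` condition stops pulling input as soon as every consumer has hung up.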
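The test side relies on `tokio::sync::Barrier` so that `RepartitionExec` cannot see any input until the test has dropped one of its output streams. A sketch of that rendezvous pattern, independent of DataFusion (the partition count and printouts are illustrative):

```rust
use std::sync::Arc;
use tokio::sync::Barrier;

#[tokio::main]
async fn main() {
    // Two producer tasks plus the test itself rendezvous on one barrier,
    // matching `Barrier::new(data.len() + 1)` in `BarrierExec::new`.
    let num_partitions = 2;
    let barrier = Arc::new(Barrier::new(num_partitions + 1));

    let mut handles = Vec::new();
    for partition in 0..num_partitions {
        let b = Arc::clone(&barrier);
        handles.push(tokio::spawn(async move {
            // Each producer parks here until every participant has arrived,
            // so no batches can be sent before the test allows it.
            b.wait().await;
            println!("partition {} now producing", partition);
        }));
    }

    // The test is free to set up (or drop) output streams here, before any
    // input has been produced.

    // This is the equivalent of calling `BarrierExec::wait`: it is the final
    // participant, so it releases all producer tasks at once.
    barrier.wait().await;

    for handle in handles {
        handle.await.unwrap();
    }
}
```

Sizing the barrier at `data.len() + 1` is what lets `BarrierExec::wait` act as the release switch: each partition's producer task is one participant, and the caller of `wait` is the last one.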