diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index 8edd20a8b7a85..b6e662a6e90a3 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -32,7 +32,7 @@ use crate::physical_plan::metrics::{
 };
 use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream;
 use crate::physical_plan::sorts::SortedStream;
-use crate::physical_plan::stream::RecordBatchReceiverStream;
+use crate::physical_plan::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter};
 use crate::physical_plan::{
     DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, Partitioning,
     RecordBatchStream, SendableRecordBatchStream, Statistics,
@@ -42,12 +42,12 @@ use arrow::array::{make_array, Array, ArrayRef, MutableArrayData, UInt32Array};
 pub use arrow::compute::SortOptions;
 use arrow::compute::{concat, lexsort_to_indices, take, SortColumn, TakeOptions};
 use arrow::datatypes::SchemaRef;
-use arrow::error::Result as ArrowResult;
+use arrow::error::{ArrowError, Result as ArrowResult};
 use arrow::ipc::reader::FileReader;
 use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
 use futures::lock::Mutex;
-use futures::{Stream, StreamExt};
+use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
 use log::{debug, error};
 use std::any::Any;
 use std::cmp::min;
@@ -779,17 +779,20 @@ impl ExecutionPlan for SortExec {
 
         debug!("End SortExec's input.execute for partition: {}", partition);
 
-        let result = do_sort(
-            input,
-            partition,
-            self.expr.clone(),
-            self.metrics_set.clone(),
-            context,
-        )
-        .await;
-
-        debug!("End SortExec::execute for partition {}", partition);
-        result
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            self.schema(),
+            futures::stream::once(
+                do_sort(
+                    input,
+                    partition,
+                    self.expr.clone(),
+                    self.metrics_set.clone(),
+                    context,
+                )
+                .map_err(|e| ArrowError::ExternalError(Box::new(e))),
+            )
+            .try_flatten(),
+        )))
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs
index 67b7090406901..99209121ffa7b 100644
--- a/datafusion/core/src/physical_plan/stream.rs
+++ b/datafusion/core/src/physical_plan/stream.rs
@@ -21,6 +21,7 @@ use arrow::{
     datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch,
 };
 use futures::{Stream, StreamExt};
+use pin_project_lite::pin_project;
 use tokio::task::JoinHandle;
 use tokio_stream::wrappers::ReceiverStream;
 
@@ -73,3 +74,52 @@ impl RecordBatchStream for RecordBatchReceiverStream {
         self.schema.clone()
     }
 }
+
+pin_project! {
+    /// Combines a [`Stream`] with a [`SchemaRef`] implementing
+    /// [`RecordBatchStream`] for the combination
+    pub(crate) struct RecordBatchStreamAdapter<S> {
+        schema: SchemaRef,
+
+        #[pin]
+        stream: S,
+    }
+}
+
+impl<S> RecordBatchStreamAdapter<S> {
+    /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream
+    pub(crate) fn new(schema: SchemaRef, stream: S) -> Self {
+        Self { schema, stream }
+    }
+}
+
+impl<S> std::fmt::Debug for RecordBatchStreamAdapter<S> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("RecordBatchStreamAdapter")
+            .field("schema", &self.schema)
+            .finish()
+    }
+}
+
+impl<S> Stream for RecordBatchStreamAdapter<S>
+where
+    S: Stream<Item = ArrowResult<RecordBatch>>,
+{
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        self.project().stream.poll_next(cx)
+    }
+}
+
+impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>
+where
+    S: Stream<Item = ArrowResult<RecordBatch>>,
+{
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}