From db06080572f944ee6da1a02f592b89ad0115659b Mon Sep 17 00:00:00 2001
From: Raphael Taylor-Davies
Date: Tue, 3 May 2022 15:29:40 +0100
Subject: [PATCH 1/2] Fix Ballista executing during plan

---
 .../src/execution_plans/distributed_query.rs | 195 +++++-----
 .../src/execution_plans/shuffle_reader.rs    |  37 +-
 .../src/execution_plans/shuffle_writer.rs    | 352 +++++++++---------
 ballista/rust/core/src/utils.rs              |  42 +--
 datafusion/core/src/physical_plan/stream.rs  |   8 +-
 5 files changed, 315 insertions(+), 319 deletions(-)
diff --git a/ballista/rust/core/src/execution_plans/distributed_query.rs b/ballista/rust/core/src/execution_plans/distributed_query.rs
index b0d3bef1f062a..ed3c9fceb01bd 100644
--- a/ballista/rust/core/src/execution_plans/distributed_query.rs
+++ b/ballista/rust/core/src/execution_plans/distributed_query.rs
@@ -30,9 +30,8 @@ use crate::serde::protobuf::{
     ExecuteQueryParams, GetJobStatusParams, GetJobStatusResult, KeyValuePair,
     PartitionLocation,
 };
-use crate::utils::WrappedStream;

-use datafusion::arrow::datatypes::{Schema, SchemaRef};
+use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::logical_plan::LogicalPlan;
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -43,12 +42,14 @@ use datafusion::physical_plan::{
 use crate::serde::protobuf::execute_query_params::OptionalSessionId;
 use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec};
 use async_trait::async_trait;
+use datafusion::arrow::error::{ArrowError, Result as ArrowResult};
+use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::execution::context::TaskContext;
-use futures::future;
-use futures::StreamExt;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
 use log::{error, info};

-/// This operator sends a logial plan to a Ballista scheduler for execution and
+/// This operator sends a logical plan to a Ballista scheduler for execution and
 /// polls the scheduler until the query is complete and then fetches the resulting
 /// batches directly from the executors that hold the results from the final
 /// query stage.
@@ -168,15 +169,6 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
     ) -> Result<SendableRecordBatchStream> {
         assert_eq!(0, partition);

-        info!("Connecting to Ballista scheduler at {}", self.scheduler_url);
-        // TODO reuse the scheduler to avoid connecting to the Ballista scheduler again and again
-
-        let mut scheduler = SchedulerGrpcClient::connect(self.scheduler_url.clone())
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
-
-        let schema: Schema = self.plan.schema().as_ref().clone().into();
-
         let mut buf: Vec<u8> = vec![];
         let plan_message =
             T::try_from_logical_plan(&self.plan, self.extension_codec.as_ref()).map_err(
@@ -191,88 +183,30 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
             DataFusionError::Execution(format!("failed to encode logical plan: {:?}", e))
         })?;

-        let query_result = scheduler
-            .execute_query(ExecuteQueryParams {
-                query: Some(Query::LogicalPlan(buf)),
-                settings: self
-                    .config
-                    .settings()
-                    .iter()
-                    .map(|(k, v)| KeyValuePair {
-                        key: k.to_owned(),
-                        value: v.to_owned(),
-                    })
-                    .collect::<Vec<_>>(),
-                optional_session_id: Some(OptionalSessionId::SessionId(
-                    self.session_id.clone(),
-                )),
-            })
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
-            .into_inner();
-
-        let response_session_id = query_result.session_id;
-        assert_eq!(
-            self.session_id.clone(),
-            response_session_id,
-            "Session id inconsistent between Client and Server side in DistributedQueryExec."
-        );
+        let query = ExecuteQueryParams {
+            query: Some(Query::LogicalPlan(buf)),
+            settings: self
+                .config
+                .settings()
+                .iter()
+                .map(|(k, v)| KeyValuePair {
+                    key: k.to_owned(),
+                    value: v.to_owned(),
+                })
+                .collect::<Vec<_>>(),
+            optional_session_id: Some(OptionalSessionId::SessionId(
+                self.session_id.clone(),
+            )),
+        };

-        let job_id = query_result.job_id;
-        let mut prev_status: Option<job_status::Status> = None;
+        let stream = futures::stream::once(
+            execute_query(self.scheduler_url.clone(), self.session_id.clone(), query)
+                .map_err(|e| ArrowError::ExternalError(Box::new(e))),
+        )
+        .try_flatten();

-        loop {
-            let GetJobStatusResult { status } = scheduler
-                .get_job_status(GetJobStatusParams {
-                    job_id: job_id.clone(),
-                })
-                .await
-                .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
-                .into_inner();
-            let status = status.and_then(|s| s.status).ok_or_else(|| {
-                DataFusionError::Internal("Received empty status message".to_owned())
-            })?;
-            let wait_future = tokio::time::sleep(Duration::from_millis(100));
-            let has_status_change = prev_status.map(|x| x != status).unwrap_or(true);
-            match status {
-                job_status::Status::Queued(_) => {
-                    if has_status_change {
-                        info!("Job {} still queued...", job_id);
-                    }
-                    wait_future.await;
-                    prev_status = Some(status);
-                }
-                job_status::Status::Running(_) => {
-                    if has_status_change {
-                        info!("Job {} is running...", job_id);
-                    }
-                    wait_future.await;
-                    prev_status = Some(status);
-                }
-                job_status::Status::Failed(err) => {
-                    let msg = format!("Job {} failed: {}", job_id, err.error);
-                    error!("{}", msg);
-                    break Err(DataFusionError::Execution(msg));
-                }
-                job_status::Status::Completed(completed) => {
-                    let result = future::join_all(
-                        completed
-                            .partition_location
-                            .into_iter()
-                            .map(fetch_partition),
-                    )
-                    .await
-                    .into_iter()
-                    .collect::<Result<Vec<_>>>()?;
-
-                    let result = WrappedStream::new(
-                        Box::pin(futures::stream::iter(result).flatten()),
-                        Arc::new(schema),
-                    );
-                    break Ok(Box::pin(result));
-                }
-            };
-        }
+        let schema = self.schema();
+        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
     }

     fn fmt_as(
@@ -299,6 +233,79 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
     }
 }

+async fn execute_query(
+    scheduler_url: String,
+    session_id: String,
+    query: ExecuteQueryParams,
+) -> Result<impl Stream<Item = ArrowResult<RecordBatch>> + Send> {
+    info!("Connecting to Ballista scheduler at {}", scheduler_url);
+    // TODO reuse the scheduler to avoid connecting to the Ballista scheduler again and again
+
+    let mut scheduler = SchedulerGrpcClient::connect(scheduler_url.clone())
+        .await
+        .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
+
+    let query_result = scheduler
+        .execute_query(query)
+        .await
+        .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
+        .into_inner();
+
+    assert_eq!(
+        session_id, query_result.session_id,
+        "Session id inconsistent between Client and Server side in DistributedQueryExec."
+    );
+
+    let job_id = query_result.job_id;
+    let mut prev_status: Option<job_status::Status> = None;
+
+    loop {
+        let GetJobStatusResult { status } = scheduler
+            .get_job_status(GetJobStatusParams {
+                job_id: job_id.clone(),
+            })
+            .await
+            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
+            .into_inner();
+        let status = status.and_then(|s| s.status).ok_or_else(|| {
+            DataFusionError::Internal("Received empty status message".to_owned())
+        })?;
+        let wait_future = tokio::time::sleep(Duration::from_millis(100));
+        let has_status_change = prev_status.map(|x| x != status).unwrap_or(true);
+        match status {
+            job_status::Status::Queued(_) => {
+                if has_status_change {
+                    info!("Job {} still queued...", job_id);
+                }
+                wait_future.await;
+                prev_status = Some(status);
+            }
+            job_status::Status::Running(_) => {
+                if has_status_change {
+                    info!("Job {} is running...", job_id);
+                }
+                wait_future.await;
+                prev_status = Some(status);
+            }
+            job_status::Status::Failed(err) => {
+                let msg = format!("Job {} failed: {}", job_id, err.error);
+                error!("{}", msg);
+                break Err(DataFusionError::Execution(msg));
+            }
+            job_status::Status::Completed(completed) => {
+                let streams = completed.partition_location.into_iter().map(|p| {
+                    let f = fetch_partition(p)
+                        .map_err(|e| ArrowError::ExternalError(Box::new(e)));
+
+                    futures::stream::once(f).try_flatten()
+                });
+
+                break Ok(futures::stream::iter(streams).flatten());
+            }
+        };
+    }
+}
+
 async fn fetch_partition(
     location: PartitionLocation,
 ) -> Result<SendableRecordBatchStream> {
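The core pattern in the change above: `execute` now does only cheap, synchronous work (serializing the plan and building the request), and every await point, including the gRPC connection and job polling, is deferred into the returned stream via `futures::stream::once(..).try_flatten()` wrapped in a `RecordBatchStreamAdapter`. Below is a minimal sketch of that once-plus-try_flatten pattern reduced to plain `futures` types so it stands alone; `connect_and_stream` and the integer "batches" are stand-ins for the scheduler round-trip, not Ballista APIs, and the `futures` and `tokio` crates are assumed.

use futures::stream::{self, Stream, TryStreamExt};

type Result<T> = std::result::Result<T, String>;

// Stand-in for the async setup (connect, submit, poll): nothing in here
// runs until the combined stream is first polled.
async fn connect_and_stream() -> Result<impl Stream<Item = Result<i32>>> {
    let batches: Vec<Result<i32>> = vec![Ok(1), Ok(2), Ok(3)];
    Ok(stream::iter(batches))
}

#[tokio::main]
async fn main() -> Result<()> {
    // `once` turns the setup future into a single-item stream; `try_flatten`
    // then yields the inner stream's items, surfacing a setup error as the
    // first stream item instead of failing eagerly inside `execute`.
    let stream = stream::once(connect_and_stream()).try_flatten();
    let collected: Vec<i32> = stream.try_collect().await?;
    assert_eq!(collected, vec![1, 2, 3]);
    Ok(())
}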
diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs
index b0aa6af11b506..27252b980d117 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_reader.rs
@@ -15,16 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.

+use std::any::Any;
 use std::sync::Arc;
-use std::{any::Any, pin::Pin};

 use crate::client::BallistaClient;
 use crate::serde::scheduler::{PartitionLocation, PartitionStats};
-use crate::utils::WrappedStream;

 use async_trait::async_trait;
 use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::error::{DataFusionError, Result};
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
 use datafusion::physical_plan::metrics::{
     ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
@@ -32,13 +32,11 @@ use datafusion::physical_plan::metrics::{
 use datafusion::physical_plan::{
     DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
 };
-use datafusion::{
-    error::{DataFusionError, Result},
-    physical_plan::RecordBatchStream,
-};
-use futures::{future, StreamExt};
+use futures::{StreamExt, TryStreamExt};

+use datafusion::arrow::error::ArrowError;
 use datafusion::execution::context::TaskContext;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use log::info;

 /// ShuffleReaderExec reads partitions that have already been materialized by a ShuffleWriterExec
@@ -112,18 +110,23 @@ impl ExecutionPlan for ShuffleReaderExec {
         let fetch_time =
             MetricBuilder::new(&self.metrics).subset_time("fetch_time", partition);
-        let timer = fetch_time.timer();
-        let partition_locations = &self.partition[partition];
-        let result = future::join_all(partition_locations.iter().map(fetch_partition))
-            .await
-            .into_iter()
-            .collect::<Result<Vec<_>>>()?;
-        timer.done();

+        let locations = self.partition[partition].clone();
+        let stream = locations.into_iter().map(move |p| {
+            let fetch_time = fetch_time.clone();
+            futures::stream::once(async move {
+                let timer = fetch_time.timer();
+                let r = fetch_partition(&p).await;
+                timer.done();
+
+                r.map_err(|e| ArrowError::ExternalError(Box::new(e)))
+            })
+            .try_flatten()
+        });

-        let result = WrappedStream::new(
-            Box::pin(futures::stream::iter(result).flatten()),
+        let result = RecordBatchStreamAdapter::new(
             Arc::new(self.schema.as_ref().clone()),
+            futures::stream::iter(stream).flatten(),
         );
         Ok(Box::pin(result))
     }
@@ -201,7 +204,7 @@ fn stats_for_partitions(

 async fn fetch_partition(
     location: &PartitionLocation,
-) -> Result<Pin<Box<dyn RecordBatchStream + Send>>> {
+) -> Result<SendableRecordBatchStream> {
     let metadata = &location.executor_meta;
     let partition_id = &location.partition_id;
     let mut ballista_client =
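The reader side applies the same idea per partition: each location becomes its own deferred stream, and the `fetch_time` metric handle is cloned into each one so the timer runs when the partition is actually fetched rather than when the plan is built. A sketch of that shape, with an `AtomicU64` standing in for DataFusion's `Time` metric and integers standing in for record batches (assumed crates: `futures`, `tokio`):

use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;

use futures::stream::{self, Stream, StreamExt, TryStreamExt};

type Result<T> = std::result::Result<T, String>;

// Stand-in for fetching one shuffle partition from a remote executor.
async fn fetch_partition(id: usize) -> Result<impl Stream<Item = Result<usize>>> {
    let rows: Vec<Result<usize>> = (0..2).map(|i| Ok(id * 10 + i)).collect();
    Ok(stream::iter(rows))
}

// One lazy sub-stream per location; a clone of the metric handle moves into
// each async block, so timing happens at poll time, not at plan time.
fn lazy_partitions(
    ids: Vec<usize>,
    fetch_nanos: Arc<AtomicU64>,
) -> impl Stream<Item = Result<usize>> {
    let streams = ids.into_iter().map(move |id| {
        let fetch_nanos = fetch_nanos.clone();
        stream::once(async move {
            let start = Instant::now();
            let r = fetch_partition(id).await;
            fetch_nanos.fetch_add(start.elapsed().as_nanos() as u64, Ordering::Relaxed);
            r
        })
        .try_flatten()
    });
    // Sequential flatten preserves partition order, like the reader above.
    stream::iter(streams).flatten()
}

#[tokio::main]
async fn main() -> Result<()> {
    let nanos = Arc::new(AtomicU64::new(0));
    let rows: Vec<usize> =
        lazy_partitions(vec![1, 2], nanos.clone()).try_collect().await?;
    assert_eq!(rows, vec![10, 11, 20, 21]);
    println!("fetch took {}ns", nanos.load(Ordering::Relaxed));
    Ok(())
}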
diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
index 7a87406afebc8..b68816fea7917 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
@@ -23,6 +23,7 @@ use datafusion::physical_plan::expressions::PhysicalSortExpr;

 use std::any::Any;
+use std::future::Future;
 use std::iter::Iterator;
 use std::path::PathBuf;
 use std::sync::Arc;
@@ -49,10 +50,12 @@ use datafusion::physical_plan::metrics::{
 use datafusion::physical_plan::{
     DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
 };
-use futures::StreamExt;
+use futures::{StreamExt, TryFutureExt, TryStreamExt};

+use datafusion::arrow::error::ArrowError;
 use datafusion::execution::context::TaskContext;
 use datafusion::physical_plan::repartition::BatchPartitioner;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use log::{debug, info};

 /// ShuffleWriterExec represents a section of a query plan that has consistent partitioning and
@@ -137,149 +140,155 @@ impl ShuffleWriterExec {
         self.shuffle_output_partitioning.as_ref()
     }

-    pub async fn execute_shuffle_write(
+    pub fn execute_shuffle_write(
         &self,
         input_partition: usize,
         context: Arc<TaskContext>,
-    ) -> Result<Vec<ShuffleWritePartition>> {
-        let now = Instant::now();
-
-        let mut stream = self.plan.execute(input_partition, context).await?;
-
+    ) -> impl Future<Output = Result<Vec<ShuffleWritePartition>>> {
         let mut path = PathBuf::from(&self.work_dir);
         path.push(&self.job_id);
         path.push(&format!("{}", self.stage_id));

         let write_metrics = ShuffleWriteMetrics::new(input_partition, &self.metrics);
-
-        match &self.shuffle_output_partitioning {
-            None => {
-                let timer = write_metrics.write_time.timer();
-                path.push(&format!("{}", input_partition));
-                std::fs::create_dir_all(&path)?;
-                path.push("data.arrow");
-                let path = path.to_str().unwrap();
-                info!("Writing results to {}", path);
-
-                // stream results to disk
-                let stats = utils::write_stream_to_disk(
-                    &mut stream,
-                    path,
-                    &write_metrics.write_time,
-                )
-                .await
-                .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
-
-                write_metrics
-                    .input_rows
-                    .add(stats.num_rows.unwrap_or(0) as usize);
-                write_metrics
-                    .output_rows
-                    .add(stats.num_rows.unwrap_or(0) as usize);
-                timer.done();
-
-                info!(
-                    "Executed partition {} in {} seconds. Statistics: {}",
-                    input_partition,
-                    now.elapsed().as_secs(),
-                    stats
-                );
-
-                Ok(vec![ShuffleWritePartition {
-                    partition_id: input_partition as u64,
-                    path: path.to_owned(),
-                    num_batches: stats.num_batches.unwrap_or(0),
-                    num_rows: stats.num_rows.unwrap_or(0),
-                    num_bytes: stats.num_bytes.unwrap_or(0),
-                }])
-            }
-
-            Some(Partitioning::Hash(exprs, n)) => {
-                let num_output_partitions = *n;
-
-                // we won't necessary produce output for every possible partition, so we
-                // create writers on demand
-                let mut writers: Vec<Option<IPCWriter>> = vec![];
-                for _ in 0..num_output_partitions {
-                    writers.push(None);
+        let output_partitioning = self.shuffle_output_partitioning.clone();
+        let plan = self.plan.clone();
+
+        async move {
+            let now = Instant::now();
+            let mut stream = plan.execute(input_partition, context).await?;
+
+            match output_partitioning {
+                None => {
+                    let timer = write_metrics.write_time.timer();
+                    path.push(&format!("{}", input_partition));
+                    std::fs::create_dir_all(&path)?;
+                    path.push("data.arrow");
+                    let path = path.to_str().unwrap();
+                    info!("Writing results to {}", path);
+
+                    // stream results to disk
+                    let stats = utils::write_stream_to_disk(
+                        &mut stream,
+                        path,
+                        &write_metrics.write_time,
+                    )
+                    .await
+                    .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
+
+                    write_metrics
+                        .input_rows
+                        .add(stats.num_rows.unwrap_or(0) as usize);
+                    write_metrics
+                        .output_rows
+                        .add(stats.num_rows.unwrap_or(0) as usize);
+                    timer.done();
+
+                    info!(
+                        "Executed partition {} in {} seconds. Statistics: {}",
+                        input_partition,
+                        now.elapsed().as_secs(),
+                        stats
+                    );
+
+                    Ok(vec![ShuffleWritePartition {
+                        partition_id: input_partition as u64,
+                        path: path.to_owned(),
+                        num_batches: stats.num_batches.unwrap_or(0),
+                        num_rows: stats.num_rows.unwrap_or(0),
+                        num_bytes: stats.num_bytes.unwrap_or(0),
+                    }])
                 }
-
-                let mut partitioner = BatchPartitioner::try_new(
-                    Partitioning::Hash(exprs.clone(), *n),
-                    write_metrics.repart_time.clone(),
-                )?;
-
-                while let Some(result) = stream.next().await {
-                    let input_batch = result?;
-
-                    write_metrics.input_rows.add(input_batch.num_rows());
-
-                    partitioner.partition(
-                        input_batch,
-                        |output_partition, output_batch| {
-                            // write non-empty batch out
-
-                            // TODO optimize so we don't write or fetch empty partitions
-                            // if output_batch.num_rows() > 0 {
-                            let timer = write_metrics.write_time.timer();
-                            match &mut writers[output_partition] {
-                                Some(w) => {
-                                    w.write(&output_batch)?;
-                                }
-                                None => {
-                                    let mut path = path.clone();
-                                    path.push(&format!("{}", output_partition));
-                                    std::fs::create_dir_all(&path)?;
-
-                                    path.push(format!("data-{}.arrow", input_partition));
-                                    info!("Writing results to {:?}", path);
-
-                                    let mut writer =
-                                        IPCWriter::new(&path, stream.schema().as_ref())?;
-
-                                    writer.write(&output_batch)?;
-                                    writers[output_partition] = Some(writer);
-                                }
+
+                Some(Partitioning::Hash(exprs, num_output_partitions)) => {
+                    // we won't necessarily produce output for every possible partition, so we
+                    // create writers on demand
+                    let mut writers: Vec<Option<IPCWriter>> = vec![];
+                    for _ in 0..num_output_partitions {
+                        writers.push(None);
+                    }
+
+                    let mut partitioner = BatchPartitioner::try_new(
+                        Partitioning::Hash(exprs, num_output_partitions),
+                        write_metrics.repart_time.clone(),
+                    )?;
+
+                    while let Some(result) = stream.next().await {
+                        let input_batch = result?;
+
+                        write_metrics.input_rows.add(input_batch.num_rows());
+
+                        partitioner.partition(
+                            input_batch,
+                            |output_partition, output_batch| {
+                                // write non-empty batch out
+
+                                // TODO optimize so we don't write or fetch empty partitions
+                                // if output_batch.num_rows() > 0 {
+                                let timer = write_metrics.write_time.timer();
+                                match &mut writers[output_partition] {
+                                    Some(w) => {
+                                        w.write(&output_batch)?;
+                                    }
+                                    None => {
+                                        let mut path = path.clone();
+                                        path.push(&format!("{}", output_partition));
+                                        std::fs::create_dir_all(&path)?;
+
+                                        path.push(format!(
+                                            "data-{}.arrow",
+                                            input_partition
+                                        ));
+                                        info!("Writing results to {:?}", path);
+
+                                        let mut writer = IPCWriter::new(
+                                            &path,
+                                            stream.schema().as_ref(),
+                                        )?;
+
+                                        writer.write(&output_batch)?;
+                                        writers[output_partition] = Some(writer);
+                                    }
+                                }
+                                write_metrics.output_rows.add(output_batch.num_rows());
+                                timer.done();
+                                Ok(())
+                            },
+                        )?;
+                    }
+
+                    let mut part_locs = vec![];
+
+                    for (i, w) in writers.iter_mut().enumerate() {
+                        match w {
+                            Some(w) => {
+                                w.finish()?;
+                                info!(
+                                    "Finished writing shuffle partition {} at {:?}. Batches: {}. Rows: {}. Bytes: {}.",
+                                    i,
+                                    w.path(),
+                                    w.num_batches,
+                                    w.num_rows,
+                                    w.num_bytes
+                                );
+
+                                part_locs.push(ShuffleWritePartition {
+                                    partition_id: i as u64,
+                                    path: w.path().to_string_lossy().to_string(),
+                                    num_batches: w.num_batches,
+                                    num_rows: w.num_rows,
+                                    num_bytes: w.num_bytes,
+                                });
                             }
-                            write_metrics.output_rows.add(output_batch.num_rows());
-                            timer.done();
-                            Ok(())
-                        },
-                    )?;
-                }
-
-                let mut part_locs = vec![];
-
-                for (i, w) in writers.iter_mut().enumerate() {
-                    match w {
-                        Some(w) => {
-                            w.finish()?;
-                            info!(
-                                "Finished writing shuffle partition {} at {:?}. Batches: {}. Rows: {}. Bytes: {}.",
-                                i,
-                                w.path(),
-                                w.num_batches,
-                                w.num_rows,
-                                w.num_bytes
-                            );
-
-                            part_locs.push(ShuffleWritePartition {
-                                partition_id: i as u64,
-                                path: w.path().to_string_lossy().to_string(),
-                                num_batches: w.num_batches,
-                                num_rows: w.num_rows,
-                                num_bytes: w.num_bytes,
-                            });
+                            None => {}
                         }
-                        None => {}
                     }
+                    Ok(part_locs)
                 }
-                Ok(part_locs)
-            }

-            _ => Err(DataFusionError::Execution(
-                "Invalid shuffle partitioning scheme".to_owned(),
-            )),
+                _ => Err(DataFusionError::Execution(
+                    "Invalid shuffle partitioning scheme".to_owned(),
+                )),
+            }
         }
     }
 }
@@ -291,7 +300,7 @@ impl ExecutionPlan for ShuffleWriterExec {
     }

     fn schema(&self) -> SchemaRef {
-        self.plan.schema()
+        result_schema()
     }

     fn output_partitioning(&self) -> Partitioning {
@@ -332,50 +341,61 @@ impl ExecutionPlan for ShuffleWriterExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        let part_loc = self.execute_shuffle_write(partition, context).await?;
-
-        // build metadata result batch
-        let num_writers = part_loc.len();
-        let mut partition_builder = UInt32Builder::new(num_writers);
-        let mut path_builder = StringBuilder::new(num_writers);
-        let mut num_rows_builder = UInt64Builder::new(num_writers);
-        let mut num_batches_builder = UInt64Builder::new(num_writers);
-        let mut num_bytes_builder = UInt64Builder::new(num_writers);
-
-        for loc in &part_loc {
-            path_builder.append_value(loc.path.clone())?;
-            partition_builder.append_value(loc.partition_id as u32)?;
-            num_rows_builder.append_value(loc.num_rows)?;
-            num_batches_builder.append_value(loc.num_batches)?;
-            num_bytes_builder.append_value(loc.num_bytes)?;
-        }
+        let schema = result_schema();

-        // build arrays
-        let partition_num: ArrayRef = Arc::new(partition_builder.finish());
-        let path: ArrayRef = Arc::new(path_builder.finish());
-        let field_builders: Vec<Box<dyn ArrayBuilder>> = vec![
-            Box::new(num_rows_builder),
-            Box::new(num_batches_builder),
-            Box::new(num_bytes_builder),
-        ];
-        let mut stats_builder = StructBuilder::new(
-            PartitionStats::default().arrow_struct_fields(),
-            field_builders,
-        );
-        for _ in 0..num_writers {
-            stats_builder.append(true)?;
-        }
-        let stats = Arc::new(stats_builder.finish());
+        let schema_captured = schema.clone();
+        let fut_stream = self
+            .execute_shuffle_write(partition, context)
+            .and_then(|part_loc| async move {
+                // build metadata result batch
+                let num_writers = part_loc.len();
+                let mut partition_builder = UInt32Builder::new(num_writers);
+                let mut path_builder = StringBuilder::new(num_writers);
+                let mut num_rows_builder = UInt64Builder::new(num_writers);
+                let mut num_batches_builder = UInt64Builder::new(num_writers);
+                let mut num_bytes_builder = UInt64Builder::new(num_writers);
+
+                for loc in &part_loc {
+                    path_builder.append_value(loc.path.clone())?;
+                    partition_builder.append_value(loc.partition_id as u32)?;
+                    num_rows_builder.append_value(loc.num_rows)?;
+                    num_batches_builder.append_value(loc.num_batches)?;
+                    num_bytes_builder.append_value(loc.num_bytes)?;
+                }

-        // build result batch containing metadata
-        let schema = result_schema();
-        let batch =
-            RecordBatch::try_new(schema.clone(), vec![partition_num, path, stats])
-                .map_err(DataFusionError::ArrowError)?;
+                // build arrays
+                let partition_num: ArrayRef = Arc::new(partition_builder.finish());
+                let path: ArrayRef = Arc::new(path_builder.finish());
+                let field_builders: Vec<Box<dyn ArrayBuilder>> = vec![
+                    Box::new(num_rows_builder),
+                    Box::new(num_batches_builder),
+                    Box::new(num_bytes_builder),
+                ];
+                let mut stats_builder = StructBuilder::new(
+                    PartitionStats::default().arrow_struct_fields(),
+                    field_builders,
+                );
+                for _ in 0..num_writers {
+                    stats_builder.append(true)?;
+                }
+                let stats = Arc::new(stats_builder.finish());
+
+                // build result batch containing metadata
+                let batch = RecordBatch::try_new(
+                    schema_captured.clone(),
+                    vec![partition_num, path, stats],
+                )?;
+
+                debug!("RESULTS METADATA:\n{:?}", batch);

-        debug!("RESULTS METADATA:\n{:?}", batch);
+                MemoryStream::try_new(vec![batch], schema_captured, None)
+            })
+            .map_err(|e| ArrowError::ExternalError(Box::new(e)));

-        Ok(Box::pin(MemoryStream::try_new(vec![batch], schema, None)?))
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            schema,
+            futures::stream::once(fut_stream).try_flatten(),
+        )))
     }

     fn metrics(&self) -> Option<MetricsSet> {
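Note that `execute_shuffle_write` above is no longer an `async fn`: cheap setup (paths, metric handles, `Arc` clones of the plan and partitioning) happens eagerly at call time, and only the `async move` block, which runs the child plan and writes files, is deferred until the future is polled. A reduced sketch of that split, where the directory layout and `tokio::fs` call are illustrative stand-ins rather than Ballista's actual logic:

use std::future::Future;
use std::path::PathBuf;

fn execute_write(
    work_dir: String,
    job_id: String,
) -> impl Future<Output = std::io::Result<PathBuf>> {
    // Synchronous setup: runs when the function is called, performs no I/O.
    let mut path = PathBuf::from(work_dir);
    path.push(job_id);

    // Deferred work: runs only when the returned future is awaited.
    async move {
        tokio::fs::create_dir_all(&path).await?;
        Ok(path)
    }
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let fut = execute_write("/tmp/shuffle-demo".into(), "job-1".into()); // no I/O yet
    let path = fut.await?; // directory created here
    println!("wrote to {}", path.display());
    Ok(())
}

Because the returned future owns everything it needs (no borrow of `&self` survives into the `async move` block), it can be chained with combinators like the `and_then` used in `execute` above.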
diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs
index 85a557e437ae3..1418aecb31a26 100644
--- a/ballista/rust/core/src/utils.rs
+++ b/ballista/rust/core/src/utils.rs
@@ -32,10 +32,7 @@ use crate::config::BallistaConfig;
 use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec};
 use async_trait::async_trait;
 use datafusion::arrow::datatypes::Schema;
-use datafusion::arrow::error::Result as ArrowResult;
-use datafusion::arrow::{
-    datatypes::SchemaRef, ipc::writer::FileWriter, record_batch::RecordBatch,
-};
+use datafusion::arrow::{ipc::writer::FileWriter, record_batch::RecordBatch};
 use datafusion::error::DataFusionError;
 use datafusion::execution::context::{
     QueryPlanner, SessionConfig, SessionContext, SessionState,
@@ -55,7 +52,7 @@ use datafusion::physical_plan::hash_join::HashJoinExec;
 use datafusion::physical_plan::projection::ProjectionExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::{metrics, ExecutionPlan, RecordBatchStream};
-use futures::{Stream, StreamExt};
+use futures::StreamExt;

 /// Stream data to disk in Arrow IPC format
@@ -316,38 +313,3 @@ impl<T: 'static + AsLogicalPlan> QueryPlanner for BallistaQueryPlanner<T> {
         }
     }
 }
-
-pub struct WrappedStream {
-    stream: Pin<Box<dyn Stream<Item = ArrowResult<RecordBatch>> + Send>>,
-    schema: SchemaRef,
-}
-
-impl WrappedStream {
-    pub fn new(
-        stream: Pin<Box<dyn Stream<Item = ArrowResult<RecordBatch>> + Send>>,
-        schema: SchemaRef,
-    ) -> Self {
-        Self { stream, schema }
-    }
-}
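Making `RecordBatchStreamAdapter` public is what lets the Ballista operators above replace their hand-rolled `WrappedStream`, and forwarding `size_hint` matters because the trait default is `(0, None)`, which would hide the inner stream's hint from downstream combinators. A minimal parallel wrapper showing the same delegation, assuming the `futures` and `pin-project-lite` crates (`Labelled` is a made-up example type, not a DataFusion API):

use std::pin::Pin;
use std::task::{Context, Poll};

use futures::Stream;
use pin_project_lite::pin_project;

pin_project! {
    // Pairs a stream with extra metadata, the way the adapter pairs one
    // with a SchemaRef.
    pub struct Labelled<S> {
        pub label: &'static str,
        #[pin]
        inner: S,
    }
}

impl<S: Stream> Stream for Labelled<S> {
    type Item = S::Item;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Delegate polling to the pinned inner stream.
        self.project().inner.poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Without this override, the blanket default would report (0, None).
        self.inner.size_hint()
    }
}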
Arc::new(stats_builder.finish()); + let schema_captured = schema.clone(); + let fut_stream = self + .execute_shuffle_write(partition, context) + .and_then(|part_loc| async move { + // build metadata result batch + let num_writers = part_loc.len(); + let mut partition_builder = UInt32Builder::new(num_writers); + let mut path_builder = StringBuilder::new(num_writers); + let mut num_rows_builder = UInt64Builder::new(num_writers); + let mut num_batches_builder = UInt64Builder::new(num_writers); + let mut num_bytes_builder = UInt64Builder::new(num_writers); + + for loc in &part_loc { + path_builder.append_value(loc.path.clone())?; + partition_builder.append_value(loc.partition_id as u32)?; + num_rows_builder.append_value(loc.num_rows)?; + num_batches_builder.append_value(loc.num_batches)?; + num_bytes_builder.append_value(loc.num_bytes)?; + } - // build result batch containing metadata - let schema = result_schema(); - let batch = - RecordBatch::try_new(schema.clone(), vec![partition_num, path, stats]) - .map_err(DataFusionError::ArrowError)?; + // build arrays + let partition_num: ArrayRef = Arc::new(partition_builder.finish()); + let path: ArrayRef = Arc::new(path_builder.finish()); + let field_builders: Vec> = vec![ + Box::new(num_rows_builder), + Box::new(num_batches_builder), + Box::new(num_bytes_builder), + ]; + let mut stats_builder = StructBuilder::new( + PartitionStats::default().arrow_struct_fields(), + field_builders, + ); + for _ in 0..num_writers { + stats_builder.append(true)?; + } + let stats = Arc::new(stats_builder.finish()); + + // build result batch containing metadata + let batch = RecordBatch::try_new( + schema_captured.clone(), + vec![partition_num, path, stats], + )?; + + debug!("RESULTS METADATA:\n{:?}", batch); - debug!("RESULTS METADATA:\n{:?}", batch); + MemoryStream::try_new(vec![batch], schema_captured, None) + }) + .map_err(|e| ArrowError::ExternalError(Box::new(e))); - Ok(Box::pin(MemoryStream::try_new(vec![batch], schema, None)?)) + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::once(fut_stream).try_flatten(), + ))) } fn metrics(&self) -> Option { diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs index 85a557e437ae3..1418aecb31a26 100644 --- a/ballista/rust/core/src/utils.rs +++ b/ballista/rust/core/src/utils.rs @@ -32,10 +32,7 @@ use crate::config::BallistaConfig; use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec}; use async_trait::async_trait; use datafusion::arrow::datatypes::Schema; -use datafusion::arrow::error::Result as ArrowResult; -use datafusion::arrow::{ - datatypes::SchemaRef, ipc::writer::FileWriter, record_batch::RecordBatch, -}; +use datafusion::arrow::{ipc::writer::FileWriter, record_batch::RecordBatch}; use datafusion::error::DataFusionError; use datafusion::execution::context::{ QueryPlanner, SessionConfig, SessionContext, SessionState, @@ -55,7 +52,7 @@ use datafusion::physical_plan::hash_join::HashJoinExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{metrics, ExecutionPlan, RecordBatchStream}; -use futures::{Stream, StreamExt}; +use futures::StreamExt; /// Stream data to disk in Arrow IPC format @@ -316,38 +313,3 @@ impl QueryPlanner for BallistaQueryPlanner { } } } - -pub struct WrappedStream { - stream: Pin> + Send>>, - schema: SchemaRef, -} - -impl WrappedStream { - pub fn new( - stream: Pin> + Send>>, - schema: SchemaRef, - ) -> Self { - Self { stream, 
-
-impl RecordBatchStream for WrappedStream {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-}
-
-impl Stream for WrappedStream {
-    type Item = ArrowResult<RecordBatch>;
-
-    fn poll_next(
-        mut self: Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
-        self.stream.poll_next_unpin(cx)
-    }
-
-    fn size_hint(&self) -> (usize, Option<usize>) {
-        self.stream.size_hint()
-    }
-}
diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs
index 99209121ffa7b..06d670ff45ec9 100644
--- a/datafusion/core/src/physical_plan/stream.rs
+++ b/datafusion/core/src/physical_plan/stream.rs
@@ -78,7 +78,7 @@ impl RecordBatchStream for RecordBatchReceiverStream {
 pin_project! {
     /// Combines a [`Stream`] with a [`SchemaRef`] implementing
     /// [`RecordBatchStream`] for the combination
-    pub(crate) struct RecordBatchStreamAdapter<S> {
+    pub struct RecordBatchStreamAdapter<S> {
         schema: SchemaRef,

         #[pin]
@@ -88,7 +88,7 @@ pin_project! {

 impl<S> RecordBatchStreamAdapter<S> {
     /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream
-    pub(crate) fn new(schema: SchemaRef, stream: S) -> Self {
+    pub fn new(schema: SchemaRef, stream: S) -> Self {
         Self { schema, stream }
     }
 }
@@ -113,6 +113,10 @@ where
     ) -> std::task::Poll<Option<Self::Item>> {
         self.project().stream.poll_next(cx)
     }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.stream.size_hint()
+    }
 }

 impl<S> RecordBatchStream for RecordBatchStreamAdapter<S>

From 4b480f0fe32f3f11c85421287302fec154268ba2 Mon Sep 17 00:00:00 2001
From: Raphael Taylor-Davies
Date: Tue, 3 May 2022 16:07:22 +0100
Subject: [PATCH 2/2] Revert ShuffleWriterExec schema change

---
 ballista/rust/core/src/execution_plans/shuffle_writer.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
index b68816fea7917..f5c98b2001153 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
@@ -300,7 +300,7 @@ impl ExecutionPlan for ShuffleWriterExec {
     }

     fn schema(&self) -> SchemaRef {
-        result_schema()
+        self.plan.schema()
     }

     fn output_partitioning(&self) -> Partitioning {