diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index b0f86ec1f97..c83ca4d8de5 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -59,6 +59,7 @@ use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::projection_push_down::ProjectionPushDown;
 use crate::physical_optimizer::coalesce_batches::CoalesceBatches;
 use crate::physical_optimizer::merge_exec::AddMergeExec;
+use crate::physical_optimizer::repartition::Repartition;
 use crate::physical_plan::csv::CsvReadOptions;
 use crate::physical_plan::planner::DefaultPhysicalPlanner;
@@ -642,6 +643,7 @@ impl ExecutionConfig {
             ],
             physical_optimizers: vec![
                 Arc::new(CoalesceBatches::new()),
+                Arc::new(Repartition::new()),
                 Arc::new(AddMergeExec::new()),
             ],
             query_planner: Arc::new(DefaultQueryPlanner {}),
diff --git a/rust/datafusion/tests/user_defined_plan.rs b/rust/datafusion/tests/user_defined_plan.rs
index aae5c597d82..f9f24430104 100644
--- a/rust/datafusion/tests/user_defined_plan.rs
+++ b/rust/datafusion/tests/user_defined_plan.rs
@@ -58,7 +58,7 @@
 //! N elements, reducing the total amount of required buffer memory.
 //!
 
-use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
+use futures::{Stream, StreamExt};
 
 use arrow::{
     array::{Int64Array, StringArray},
@@ -180,6 +180,7 @@ async fn topk_plan() -> Result<()> {
 fn make_topk_context() -> ExecutionContext {
     let config = ExecutionConfig::new()
         .with_query_planner(Arc::new(TopKQueryPlanner {}))
+        .with_concurrency(48)
         .add_optimizer_rule(Arc::new(TopKOptimizerRule {}));
 
     ExecutionContext::with_config(config)
@@ -388,6 +389,7 @@ impl ExecutionPlan for TopKExec {
             input: self.input.execute(partition).await?,
             k: self.k,
             done: false,
+            state: BTreeMap::new(),
         }))
     }
 }
@@ -400,6 +402,8 @@ struct TopKReader {
     k: usize,
     /// Have we produced the output yet?
     done: bool,
+    /// Output
+    state: BTreeMap<i64, String>,
 }
 
 /// Keeps track of the revenue from customer_id and stores if it
@@ -432,7 +436,7 @@ fn accumulate_batch(
     input_batch: &RecordBatch,
     mut top_values: BTreeMap<i64, String>,
     k: &usize,
-) -> Result<BTreeMap<i64, String>> {
+) -> BTreeMap<i64, String> {
     let num_rows = input_batch.num_rows();
     // Assuming the input columns are
     // column[0]: customer_id / UTF8
@@ -457,7 +461,7 @@ fn accumulate_batch(
             k,
         );
     }
-    Ok(top_values)
+    top_values
 }
 
 impl Stream for TopKReader {
@@ -475,41 +479,29 @@ impl Stream for TopKReader {
         // take this as immutable
         let k = self.k;
         let schema = self.schema();
-        let top_values = self
-            .input
-            .as_mut()
-            // Hard coded implementation for sales / customer_id example as BTree
-            .try_fold(
-                BTreeMap::<i64, String>::new(),
-                move |top_values, batch| async move {
-                    accumulate_batch(&batch, top_values, &k)
-                        .map_err(DataFusionError::into_arrow_external_error)
-                },
-            );
-
-        let top_values = top_values.map(|top_values| match top_values {
-            Ok(top_values) => {
-                // make output by walking over the map backwards (so values are descending)
+        let poll = self.input.poll_next_unpin(cx);
+
+        match poll {
+            Poll::Ready(Some(Ok(batch))) => {
+                self.state = accumulate_batch(&batch, self.state.clone(), &k);
+                Poll::Ready(Some(Ok(RecordBatch::new_empty(schema))))
+            }
+            Poll::Ready(None) => {
+                self.done = true;
                 let (revenue, customer): (Vec<i64>, Vec<&String>) =
-                    top_values.iter().rev().unzip();
+                    self.state.iter().rev().unzip();
 
                 let customer: Vec<&str> = customer.iter().map(|&s| &**s).collect();
-                Ok(RecordBatch::try_new(
+                Poll::Ready(Some(RecordBatch::try_new(
                     schema,
                     vec![
                         Arc::new(StringArray::from(customer)),
                         Arc::new(Int64Array::from(revenue)),
                     ],
-                )?)
+                )))
             }
-            Err(e) => Err(e),
-        });
-        let mut top_values = Box::pin(top_values.into_stream());
-
-        top_values.poll_next_unpin(cx).map(|batch| {
-            self.done = true;
-            batch
-        })
+            other => other,
+        }
     }
 }
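Note on the TopKReader rewrite (editor's sketch, not part of the diff): instead of building a try_fold future over the entire input and polling that, poll_next now polls the input stream directly, folds each batch into self.state, and emits the accumulated top-k result once the input reports None. The self-contained Rust sketch below illustrates the same poll-driven fold under simplified assumptions: a plain stream of (customer, revenue) tuples instead of Arrow RecordBatches, and a hypothetical TopKFold type standing in for TopKReader.

use std::collections::BTreeMap;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{Stream, StreamExt};

/// Hypothetical stand-in for one input row: (customer_id, revenue).
type Row = (String, i64);

/// Folds the input stream into at most k entries keyed by revenue,
/// then emits one final sorted result when the input is exhausted.
struct TopKFold<S> {
    input: S,
    k: usize,
    state: BTreeMap<i64, String>,
    done: bool,
}

impl<S: Stream<Item = Row> + Unpin> Stream for TopKFold<S> {
    type Item = Vec<(String, i64)>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        if self.done {
            return Poll::Ready(None);
        }
        loop {
            match self.input.poll_next_unpin(cx) {
                // A new row arrived: insert it (a duplicate revenue overwrites
                // the previous customer, matching the BTreeMap keying in the
                // diff), then trim back to k entries by dropping the smallest
                // revenue, since BTreeMap iterates in ascending key order.
                Poll::Ready(Some((customer, revenue))) => {
                    self.state.insert(revenue, customer);
                    if self.state.len() > self.k {
                        let smallest = *self.state.keys().next().unwrap();
                        self.state.remove(&smallest);
                    }
                }
                // Input exhausted: walk the map backwards so revenues descend,
                // mirroring `self.state.iter().rev().unzip()` in the diff.
                Poll::Ready(None) => {
                    self.done = true;
                    let out: Vec<(String, i64)> = self
                        .state
                        .iter()
                        .rev()
                        .map(|(revenue, customer)| (customer.clone(), *revenue))
                        .collect();
                    return Poll::Ready(Some(out));
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

fn main() {
    let rows = futures::stream::iter(vec![
        ("alice".to_string(), 100),
        ("bob".to_string(), 300),
        ("carol".to_string(), 200),
    ]);
    let mut topk = TopKFold {
        input: rows,
        k: 2,
        state: BTreeMap::new(),
        done: false,
    };
    // Prints: Some([("bob", 300), ("carol", 200)])
    println!("{:?}", futures::executor::block_on(topk.next()));
}

One behavioral difference worth noting: the diff's implementation returns an empty RecordBatch for each input batch it consumes, so the stream stays live between batches, while this sketch simply loops until the input is exhausted. Both keep the accumulation incremental with memory bounded by k.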