diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index a1b43cf531f16..c9bf527091c0a 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -1296,7 +1296,7 @@ mod tests {
         assert_eq!("\
         Projection: #t1.c1 AS AAA, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\
         \n  Limit: skip=0, fetch=1\
-        \n    Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST\
+        \n    Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST, fetch=1\
         \n      Inner Join: #t1.c1 = #t2.c1\
         \n        TableScan: t1 projection=[c1, c2, c3]\
         \n        TableScan: t2 projection=[c1, c2, c3]",
diff --git a/datafusion/core/src/physical_optimizer/parallel_sort.rs b/datafusion/core/src/physical_optimizer/parallel_sort.rs
index 3361d8155f7fe..e3ca60cb5cf72 100644
--- a/datafusion/core/src/physical_optimizer/parallel_sort.rs
+++ b/datafusion/core/src/physical_optimizer/parallel_sort.rs
@@ -20,7 +20,6 @@ use crate::{
     error::Result,
     physical_optimizer::PhysicalOptimizerRule,
     physical_plan::{
-        limit::GlobalLimitExec,
         sorts::{sort::SortExec, sort_preserving_merge::SortPreservingMergeExec},
         with_new_children_if_necessary,
     },
@@ -55,31 +54,33 @@ impl PhysicalOptimizerRule for ParallelSort {
             .map(|child| self.optimize(child.clone(), config))
             .collect::<Result<Vec<_>>>()?;
         let plan = with_new_children_if_necessary(plan, children)?;
-        let children = plan.children();
         let plan_any = plan.as_any();
-        // GlobalLimitExec (SortExec preserve_partitioning=False)
-        //   -> GlobalLimitExec (SortExec preserve_partitioning=True)
-        let parallel_sort = plan_any.downcast_ref::<GlobalLimitExec>().is_some()
-            && children.len() == 1
-            && children[0].as_any().downcast_ref::<SortExec>().is_some()
-            && !children[0]
-                .as_any()
+        // SortExec (preserve_partitioning=False, fetch=Some(n))
+        //   -> SortPreservingMergeExec (SortExec preserve_partitioning=True, fetch=Some(n))
+        let parallel_sort = plan_any.downcast_ref::<SortExec>().is_some()
+            && plan_any
+                .downcast_ref::<SortExec>()
+                .unwrap()
+                .fetch()
+                .is_some()
+            && !plan_any
                 .downcast_ref::<SortExec>()
                 .unwrap()
                 .preserve_partitioning();
         Ok(if parallel_sort {
-            let sort = children[0].as_any().downcast_ref::<SortExec>().unwrap();
+            let sort = plan_any.downcast_ref::<SortExec>().unwrap();
             let new_sort = SortExec::new_with_partitioning(
                 sort.expr().to_vec(),
                 sort.input().clone(),
                 true,
+                sort.fetch(),
             );
             let merge = SortPreservingMergeExec::new(
                 sort.expr().to_vec(),
                 Arc::new(new_sort),
             );
-            with_new_children_if_necessary(plan, vec![Arc::new(merge)])?
+            Arc::new(merge)
         } else {
             plan.clone()
         })
diff --git a/datafusion/core/src/physical_optimizer/repartition.rs b/datafusion/core/src/physical_optimizer/repartition.rs
index 20aa59f0cd679..1d2b259086839 100644
--- a/datafusion/core/src/physical_optimizer/repartition.rs
+++ b/datafusion/core/src/physical_optimizer/repartition.rs
@@ -295,7 +295,7 @@ mod tests {
             expr: col("c1", &schema()).unwrap(),
             options: SortOptions::default(),
         }];
-        Arc::new(SortExec::try_new(sort_exprs, input).unwrap())
+        Arc::new(SortExec::try_new(sort_exprs, input, None).unwrap())
     }
 
     fn projection_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 20a819622d5dd..005a7943265aa 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -590,9 +590,9 @@ impl DefaultPhysicalPlanner {
                         })
                         .collect::<Result<Vec<_>>>()?;
                     Arc::new(if can_repartition {
-                        SortExec::new_with_partitioning(sort_keys, input_exec, true)
+                        SortExec::new_with_partitioning(sort_keys, input_exec, true, None)
                     } else {
-                        SortExec::try_new(sort_keys, input_exec)?
+                        SortExec::try_new(sort_keys, input_exec, None)?
                     })
                 };
@@ -815,7 +815,7 @@
                         physical_partitioning,
                     )?))
                 }
-                LogicalPlan::Sort(Sort { expr, input, .. }) => {
+                LogicalPlan::Sort(Sort { expr, input, fetch, .. }) => {
                     let physical_input = self.create_initial_plan(input, session_state).await?;
                     let input_schema = physical_input.as_ref().schema();
                     let input_dfschema = input.as_ref().schema();
@@ -841,7 +841,7 @@
                         )),
                     })
                     .collect::<Result<Vec<_>>>()?;
-                Ok(Arc::new(SortExec::try_new(sort_expr, physical_input)?))
+                Ok(Arc::new(SortExec::try_new(sort_expr, physical_input, *fetch)?))
             }
             LogicalPlan::Join(Join {
                 left,
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index cc05017853244..8e8457be2549a 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -82,6 +82,7 @@ struct ExternalSorter {
     runtime: Arc<RuntimeEnv>,
     metrics_set: CompositeMetricsSet,
     metrics: BaselineMetrics,
+    fetch: Option<usize>,
 }
 
 impl ExternalSorter {
@@ -92,6 +93,7 @@
         metrics_set: CompositeMetricsSet,
         session_config: Arc<SessionConfig>,
         runtime: Arc<RuntimeEnv>,
+        fetch: Option<usize>,
     ) -> Self {
         let metrics = metrics_set.new_intermediate_baseline(partition_id);
         Self {
@@ -104,6 +106,7 @@
             runtime,
             metrics_set,
             metrics,
+            fetch,
         }
     }
 
@@ -120,7 +123,7 @@ impl ExternalSorter {
             // NB timer records time taken on drop, so there are no
             // calls to `timer.done()` below.
             let _timer = tracking_metrics.elapsed_compute().timer();
-            let partial = sort_batch(input, self.schema.clone(), &self.expr)?;
+            let partial = sort_batch(input, self.schema.clone(), &self.expr, self.fetch)?;
             in_mem_batches.push(partial);
         }
         Ok(())
     }
@@ -657,6 +660,8 @@ pub struct SortExec {
     metrics_set: CompositeMetricsSet,
     /// Preserve partitions of input plan
     preserve_partitioning: bool,
+    /// Fetch highest/lowest n results
+    fetch: Option<usize>,
 }
 
 impl SortExec {
@@ -664,8 +669,9 @@
     pub fn try_new(
         expr: Vec<PhysicalSortExpr>,
         input: Arc<dyn ExecutionPlan>,
+        fetch: Option<usize>,
     ) -> Result<Self> {
-        Ok(Self::new_with_partitioning(expr, input, false))
+        Ok(Self::new_with_partitioning(expr, input, false, fetch))
     }
 
     /// Whether this `SortExec` preserves partitioning of the children
@@ -679,12 +685,14 @@
         expr: Vec<PhysicalSortExpr>,
         input: Arc<dyn ExecutionPlan>,
         preserve_partitioning: bool,
+        fetch: Option<usize>,
     ) -> Self {
         Self {
             expr,
             input,
             metrics_set: CompositeMetricsSet::new(),
             preserve_partitioning,
+            fetch,
         }
     }
@@ -697,6 +705,11 @@ impl SortExec {
     pub fn expr(&self) -> &[PhysicalSortExpr] {
         &self.expr
     }
+
+    /// If `Some(fetch)`, limits output to only the first "fetch" items
+    pub fn fetch(&self) -> Option<usize> {
+        self.fetch
+    }
 }
 
 impl ExecutionPlan for SortExec {
@@ -750,6 +763,7 @@ impl ExecutionPlan for SortExec {
             self.expr.clone(),
             children[0].clone(),
             self.preserve_partitioning,
+            self.fetch,
         )))
     }
@@ -778,6 +792,7 @@ impl ExecutionPlan for SortExec {
                     self.expr.clone(),
                     self.metrics_set.clone(),
                     context,
+                    self.fetch(),
                 )
                 .map_err(|e| ArrowError::ExternalError(Box::new(e))),
             )
@@ -816,14 +831,14 @@ fn sort_batch(
     batch: RecordBatch,
     schema: SchemaRef,
     expr: &[PhysicalSortExpr],
+    fetch: Option<usize>,
 ) -> ArrowResult<RecordBatch> {
-    // TODO: pushup the limit expression to sort
     let sort_columns = expr
         .iter()
         .map(|e| e.evaluate_to_sort_column(&batch))
        .collect::<Result<Vec<_>>>()?;
 
-    let indices = lexsort_to_indices(&sort_columns, None)?;
+    let indices = lexsort_to_indices(&sort_columns, fetch)?;
 
     // reorder all rows based on sorted indices
     let sorted_batch = RecordBatch::try_new(
@@ -870,6 +885,7 @@ async fn do_sort(
     expr: Vec<PhysicalSortExpr>,
     metrics_set: CompositeMetricsSet,
     context: Arc<TaskContext>,
+    fetch: Option<usize>,
 ) -> Result<SendableRecordBatchStream> {
     debug!(
         "Start do_sort for partition {} of context session_id {} and task_id {:?}",
@@ -887,6 +903,7 @@ async fn do_sort(
         metrics_set,
         Arc::new(context.session_config()),
         context.runtime_env(),
+        fetch,
     );
     context.runtime_env().register_requester(sorter.id());
     while let Some(batch) = input.next().await {
@@ -949,6 +966,7 @@ mod tests {
                 },
             ],
             Arc::new(CoalescePartitionsExec::new(csv)),
+            None,
         )?);
 
         let result = collect(sort_exec, task_ctx).await?;
@@ -1011,6 +1029,7 @@ mod tests {
                 },
             ],
             Arc::new(CoalescePartitionsExec::new(csv)),
+            None,
         )?);
 
         let task_ctx = session_ctx.task_ctx();
@@ -1083,6 +1102,7 @@ mod tests {
                 options: SortOptions::default(),
             }],
             input,
+            None,
         )?);
 
         let result: Vec<RecordBatch> = collect(sort_exec, task_ctx).await?;
@@ -1159,6 +1179,7 @@ mod tests {
                 },
             ],
             Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None)?),
+            None,
         )?);
 
         assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type());
@@ -1226,6 +1247,7 @@ mod tests {
                 options: SortOptions::default(),
             }],
             blocking_exec,
+            None,
         )?);
 
         let fut = collect(sort_exec, task_ctx);
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index dc788be3380e9..5db3c50e6c141 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -874,8 +874,12 @@ mod tests {
         sort: Vec<PhysicalSortExpr>,
         context: Arc<TaskContext>,
     ) -> RecordBatch {
-        let sort_exec =
-            Arc::new(SortExec::new_with_partitioning(sort.clone(), input, true));
+        let sort_exec = Arc::new(SortExec::new_with_partitioning(
+            sort.clone(),
+            input,
+            true,
+            None,
+        ));
         sorted_merge(sort_exec, sort, context).await
     }
 
@@ -885,7 +889,7 @@ mod tests {
         context: Arc<TaskContext>,
     ) -> RecordBatch {
         let merge = Arc::new(CoalescePartitionsExec::new(src));
-        let sort_exec = Arc::new(SortExec::try_new(sort, merge).unwrap());
+        let sort_exec = Arc::new(SortExec::try_new(sort, merge, None).unwrap());
         let mut result = collect(sort_exec, context).await.unwrap();
         assert_eq!(result.len(), 1);
         result.remove(0)
diff --git a/datafusion/core/tests/order_spill_fuzz.rs b/datafusion/core/tests/order_spill_fuzz.rs
index a2d47629b9270..faaa2ae0bb3d8 100644
--- a/datafusion/core/tests/order_spill_fuzz.rs
+++ b/datafusion/core/tests/order_spill_fuzz.rs
@@ -75,7 +75,7 @@ async fn run_sort(pool_size: usize, size_spill: Vec<(usize, bool)>) {
     }];
     let exec = MemoryExec::try_new(&input, schema, None).unwrap();
-    let sort = Arc::new(SortExec::try_new(sort, Arc::new(exec)).unwrap());
+    let sort = Arc::new(SortExec::try_new(sort, Arc::new(exec), None).unwrap());
 
     let runtime_config = RuntimeConfig::new().with_memory_manager(
         MemoryManagerConfig::try_new_limit(pool_size, 1.0).unwrap(),
diff --git a/datafusion/core/tests/user_defined_plan.rs b/datafusion/core/tests/user_defined_plan.rs
index 13ddb1eb8da14..c577e48e78000 100644
--- a/datafusion/core/tests/user_defined_plan.rs
+++ b/datafusion/core/tests/user_defined_plan.rs
@@ -299,6 +299,7 @@ impl OptimizerRule for TopKOptimizerRule {
             if let LogicalPlan::Sort(Sort {
                 ref expr,
                 ref input,
+                ..
             }) = **input
             {
                 if expr.len() == 1 {
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 0125291fda0d9..cb024fd361fbc 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -393,6 +393,7 @@ impl LogicalPlanBuilder {
             return Ok(Self::from(LogicalPlan::Sort(Sort {
                 expr: normalize_cols(exprs, &self.plan)?,
                 input: Arc::new(self.plan.clone()),
+                fetch: None,
             })));
         }
 
@@ -400,6 +401,7 @@ impl LogicalPlanBuilder {
         let sort_plan = LogicalPlan::Sort(Sort {
             expr: normalize_cols(exprs, &plan)?,
             input: Arc::new(plan.clone()),
+            fetch: None,
         });
         // remove pushed down sort columns
         let new_expr = schema
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index f6b2b1b74f6ab..049e6158ca8fb 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -806,7 +806,7 @@ impl LogicalPlan {
                     "Aggregate: groupBy=[{:?}], aggr=[{:?}]",
                     group_expr, aggr_expr
                 ),
-                LogicalPlan::Sort(Sort { expr, .. }) => {
+                LogicalPlan::Sort(Sort { expr, fetch, .. }) => {
                     write!(f, "Sort: ")?;
                     for (i, expr_item) in expr.iter().enumerate() {
                         if i > 0 {
@@ -814,6 +814,10 @@
                         }
                         write!(f, "{:?}", expr_item)?;
                     }
+                    if let Some(a) = fetch {
+                        write!(f, ", fetch={}", a)?;
+                    }
+
                     Ok(())
                 }
                 LogicalPlan::Join(Join {
@@ -1373,6 +1377,8 @@ pub struct Sort {
     pub expr: Vec<Expr>,
     /// The incoming logical plan
     pub input: Arc<LogicalPlan>,
+    /// Optional fetch limit
+    pub fetch: Option<usize>,
 }
 
 /// Join two logical plans on one or more join columns
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 65c4be24849b0..4ace809b811dd 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -420,9 +420,10 @@ pub fn from_plan(
             expr[group_expr.len()..].to_vec(),
             schema.clone(),
         )?)),
-        LogicalPlan::Sort(Sort { .. }) => Ok(LogicalPlan::Sort(Sort {
+        LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort {
            expr: expr.to_vec(),
            input: Arc::new(inputs[0].clone()),
+            fetch: *fetch,
        })),
        LogicalPlan::Join(Join {
            join_type,
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 978b79d375d3c..b17dd9492efc8 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -196,7 +196,7 @@ fn optimize(
                 schema.clone(),
             )?))
         }
-        LogicalPlan::Sort(Sort { expr, input }) => {
+        LogicalPlan::Sort(Sort { expr, input, fetch }) => {
             let arrays = to_arrays(expr, input, &mut expr_set)?;
 
             let (mut new_expr, new_input) = rewrite_expr(
@@ -210,6 +210,7 @@ fn optimize(
             Ok(LogicalPlan::Sort(Sort {
                 expr: pop_expr(&mut new_expr)?,
                 input: Arc::new(new_input),
+                fetch: *fetch,
             }))
         }
         LogicalPlan::Join { .. }
diff --git a/datafusion/optimizer/src/limit_push_down.rs b/datafusion/optimizer/src/limit_push_down.rs
index 65bdd64ee0c57..cacbbfc5b2b17 100644
--- a/datafusion/optimizer/src/limit_push_down.rs
+++ b/datafusion/optimizer/src/limit_push_down.rs
@@ -20,7 +20,9 @@
 use crate::{OptimizerConfig, OptimizerRule};
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::{
-    logical_plan::{Join, JoinType, Limit, LogicalPlan, Projection, TableScan, Union},
+    logical_plan::{
+        Join, JoinType, Limit, LogicalPlan, Projection, Sort, TableScan, Union,
+    },
     utils::from_plan,
 };
 use std::sync::Arc;
@@ -247,6 +249,25 @@ fn limit_push_down(
                 ),
             }
         }
+        (
+            LogicalPlan::Sort(Sort { expr, input, fetch }),
+            Ancestor::FromLimit {
+                skip: ancestor_skip,
+                fetch: Some(ancestor_fetch),
+                ..
+            },
+        ) => {
+            // Update Sort `fetch`, but simply recurse through children (sort should receive all input for sorting)
+            let input = push_down_children_limit(_optimizer, _optimizer_config, input)?;
+            let sort_fetch = ancestor_skip + ancestor_fetch;
+            let plan = LogicalPlan::Sort(Sort {
+                expr: expr.clone(),
+                input: Arc::new(input),
+                fetch: Some(fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch)),
+            });
+            Ok(plan)
+        }
         // For other nodes we can't push down the limit
         // But try to recurse and find other limit nodes to push down
         _ => push_down_children_limit(_optimizer, _optimizer_config, plan),
@@ -340,6 +361,8 @@ impl OptimizerRule for LimitPushDown {
 
 #[cfg(test)]
 mod test {
+    use std::vec;
+
     use super::*;
     use crate::test::*;
     use datafusion_expr::{
@@ -438,6 +461,44 @@ mod test {
         Ok(())
     }
 
+    #[test]
+    fn limit_push_down_sort() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .sort(vec![col("a")])?
+            .limit(0, Some(10))?
+            .build()?;
+
+        // Should push down limit to sort
+        let expected = "Limit: skip=0, fetch=10\
+        \n  Sort: #test.a, fetch=10\
+        \n    TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn limit_push_down_sort_skip() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .sort(vec![col("a")])?
+            .limit(5, Some(10))?
+            .build()?;
+
+        // Should push down limit to sort
+        let expected = "Limit: skip=5, fetch=10\
+        \n  Sort: #test.a, fetch=15\
+        \n    TableScan: test";
+
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
+
     #[test]
     fn multi_stage_limit_recurses_to_deeper_limit() -> Result<()> {
         let table_scan = test_table_scan()?;
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index baabc04cfff7a..893d107181948 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -125,6 +125,8 @@ message SelectionNode {
 message SortNode {
   LogicalPlanNode input = 1;
   repeated datafusion.LogicalExprNode expr = 2;
+  // Maximum number of highest/lowest rows to fetch; negative means no limit
+  int64 fetch = 3;
 }
 
 message RepartitionNode {
diff --git a/datafusion/proto/src/logical_plan.rs b/datafusion/proto/src/logical_plan.rs
index 45399554b46d9..8f9f8a9c46613 100644
--- a/datafusion/proto/src/logical_plan.rs
+++ b/datafusion/proto/src/logical_plan.rs
@@ -959,7 +959,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                     ))),
                 })
             }
-            LogicalPlan::Sort(Sort { input, expr }) => {
+            LogicalPlan::Sort(Sort { input, expr, fetch }) => {
                let input: protobuf::LogicalPlanNode =
                    protobuf::LogicalPlanNode::try_from_logical_plan(
                        input.as_ref(),
@@ -974,6 +974,7 @@
                        protobuf::SortNode {
                            input: Some(Box::new(input)),
                            expr: selection_expr,
+                            fetch: fetch.map(|f| f as i64).unwrap_or(-1i64),
                        },
                    ))),
                })
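For orientation (not part of the patch): a minimal sketch of how the new Sort `fetch` field is exercised end to end, mirroring the `limit_push_down_sort` test added above. The import paths and the `scan` input here are assumptions for illustration; with this change, the `LimitPushDown` rule is expected to rewrite the plan built below into `Limit: skip=0, fetch=10` over `Sort: ..., fetch=10`, the physical planner forwards that value to `SortExec::try_new(..., Some(10))`, and `sort_batch` in turn passes it to `lexsort_to_indices` so only the top rows are materialized.

// Sketch only: assumed import paths, not taken from the diff.
use datafusion_common::Result;
use datafusion_expr::{col, logical_plan::{LogicalPlan, LogicalPlanBuilder}};

// Build the equivalent of `SELECT ... ORDER BY a LIMIT 10` as a logical plan;
// `scan` is any existing table scan plan (e.g. the `test_table_scan()` helper
// used in the tests above).
fn sorted_top_10(scan: LogicalPlan) -> Result<LogicalPlan> {
    LogicalPlanBuilder::from(scan)
        .sort(vec![col("a")])?   // the builder still creates Sort { fetch: None }
        .limit(0, Some(10))?     // LimitPushDown rewrites this to Sort { fetch: Some(10) }
        .build()
}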