2 changes: 1 addition & 1 deletion datafusion/core/src/dataframe.rs
@@ -1296,7 +1296,7 @@ mod tests {
assert_eq!("\
Projection: #t1.c1 AS AAA, #t1.c2, #t1.c3, #t2.c1, #t2.c2, #t2.c3\
\n Limit: skip=0, fetch=1\
\n Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST\
\n Sort: #t1.c1 ASC NULLS FIRST, #t1.c2 ASC NULLS FIRST, #t1.c3 ASC NULLS FIRST, #t2.c1 ASC NULLS FIRST, #t2.c2 ASC NULLS FIRST, #t2.c3 ASC NULLS FIRST, fetch=1\
\n Inner Join: #t1.c1 = #t2.c1\
\n TableScan: t1 projection=[c1, c2, c3]\
\n TableScan: t2 projection=[c1, c2, c3]",
23 changes: 12 additions & 11 deletions datafusion/core/src/physical_optimizer/parallel_sort.rs
@@ -20,7 +20,6 @@ use crate::{
error::Result,
physical_optimizer::PhysicalOptimizerRule,
physical_plan::{
limit::GlobalLimitExec,
sorts::{sort::SortExec, sort_preserving_merge::SortPreservingMergeExec},
with_new_children_if_necessary,
},
@@ -55,31 +54,33 @@ impl PhysicalOptimizerRule for ParallelSort {
.map(|child| self.optimize(child.clone(), config))
.collect::<Result<Vec<_>>>()?;
let plan = with_new_children_if_necessary(plan, children)?;
let children = plan.children();
let plan_any = plan.as_any();
// GlobalLimitExec (SortExec preserve_partitioning=False)
// -> GlobalLimitExec (SortExec preserve_partitioning=True)
let parallel_sort = plan_any.downcast_ref::<GlobalLimitExec>().is_some()
&& children.len() == 1
&& children[0].as_any().downcast_ref::<SortExec>().is_some()
&& !children[0]
.as_any()
// SortExec (preserve_partitioning=False, fetch=Some(n))
// -> SortPreservingMergeExec (SortExec preserve_partitioning=True, fetch=Some(n))
let parallel_sort = plan_any.downcast_ref::<SortExec>().is_some()
@Dandandan (PR author) commented:

As we now have the pushdown, we can use fetch and support more than just a limit directly after the sort.
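
A rough before/after sketch of the rewrite this rule performs (plan shapes paraphrased from the code comments above, not actual EXPLAIN output):

    Before:  SortExec(fetch=Some(n), preserve_partitioning=false)
               <partitioned input>

    After:   SortPreservingMergeExec
               SortExec(fetch=Some(n), preserve_partitioning=true)
                 <partitioned input>

Each input partition is sorted and truncated to n rows in parallel, and the merge then combines the per-partition top-n streams into a single sorted stream.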

&& plan_any
.downcast_ref::<SortExec>()
.unwrap()
.fetch()
.is_some()
&& !plan_any
.downcast_ref::<SortExec>()
.unwrap()
.preserve_partitioning();

Ok(if parallel_sort {
let sort = children[0].as_any().downcast_ref::<SortExec>().unwrap();
let sort = plan_any.downcast_ref::<SortExec>().unwrap();
let new_sort = SortExec::new_with_partitioning(
sort.expr().to_vec(),
sort.input().clone(),
true,
sort.fetch(),
);
let merge = SortPreservingMergeExec::new(
sort.expr().to_vec(),
Arc::new(new_sort),
);
with_new_children_if_necessary(plan, vec![Arc::new(merge)])?
Arc::new(merge)
} else {
plan.clone()
})
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/repartition.rs
@@ -295,7 +295,7 @@ mod tests {
expr: col("c1", &schema()).unwrap(),
options: SortOptions::default(),
}];
Arc::new(SortExec::try_new(sort_exprs, input).unwrap())
Arc::new(SortExec::try_new(sort_exprs, input, None).unwrap())
}

fn projection_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
8 changes: 4 additions & 4 deletions datafusion/core/src/physical_plan/planner.rs
@@ -590,9 +590,9 @@ impl DefaultPhysicalPlanner {
})
.collect::<Result<Vec<_>>>()?;
Arc::new(if can_repartition {
SortExec::new_with_partitioning(sort_keys, input_exec, true)
SortExec::new_with_partitioning(sort_keys, input_exec, true, None)
} else {
SortExec::try_new(sort_keys, input_exec)?
SortExec::try_new(sort_keys, input_exec, None)?
})
};

@@ -815,7 +815,7 @@ impl DefaultPhysicalPlanner {
physical_partitioning,
)?) )
}
LogicalPlan::Sort(Sort { expr, input, .. }) => {
LogicalPlan::Sort(Sort { expr, input, fetch, .. }) => {
let physical_input = self.create_initial_plan(input, session_state).await?;
let input_schema = physical_input.as_ref().schema();
let input_dfschema = input.as_ref().schema();
@@ -841,7 +841,7 @@
)),
})
.collect::<Result<Vec<_>>>()?;
Ok(Arc::new(SortExec::try_new(sort_expr, physical_input)?))
Ok(Arc::new(SortExec::try_new(sort_expr, physical_input, *fetch)?))
}
LogicalPlan::Join(Join {
left,
30 changes: 26 additions & 4 deletions datafusion/core/src/physical_plan/sorts/sort.rs
@@ -82,6 +82,7 @@ struct ExternalSorter {
runtime: Arc<RuntimeEnv>,
metrics_set: CompositeMetricsSet,
metrics: BaselineMetrics,
fetch: Option<usize>,
}

impl ExternalSorter {
@@ -92,6 +93,7 @@ impl ExternalSorter {
metrics_set: CompositeMetricsSet,
session_config: Arc<SessionConfig>,
runtime: Arc<RuntimeEnv>,
fetch: Option<usize>,
) -> Self {
let metrics = metrics_set.new_intermediate_baseline(partition_id);
Self {
@@ -104,6 +106,7 @@
runtime,
metrics_set,
metrics,
fetch,
}
}

@@ -120,7 +123,7 @@ impl ExternalSorter {
// NB timer records time taken on drop, so there are no
// calls to `timer.done()` below.
let _timer = tracking_metrics.elapsed_compute().timer();
let partial = sort_batch(input, self.schema.clone(), &self.expr)?;
let partial = sort_batch(input, self.schema.clone(), &self.expr, self.fetch)?;
in_mem_batches.push(partial);
}
Ok(())
@@ -657,15 +660,18 @@ pub struct SortExec {
metrics_set: CompositeMetricsSet,
/// Preserve partitions of input plan
preserve_partitioning: bool,
/// Fetch highest/lowest n results
A contributor commented:

I see -- this seems like it now has the information plumbed to the SortExec to implement "TopK" within the physical operator's implementation. 👍

Very cool

fetch: Option<usize>,
}

impl SortExec {
/// Create a new sort execution plan
pub fn try_new(
expr: Vec<PhysicalSortExpr>,
input: Arc<dyn ExecutionPlan>,
fetch: Option<usize>,
) -> Result<Self> {
Ok(Self::new_with_partitioning(expr, input, false))
Ok(Self::new_with_partitioning(expr, input, false, fetch))
}

/// Whether this `SortExec` preserves partitioning of the children
@@ -679,12 +685,14 @@ impl SortExec {
expr: Vec<PhysicalSortExpr>,
input: Arc<dyn ExecutionPlan>,
preserve_partitioning: bool,
fetch: Option<usize>,
) -> Self {
Self {
expr,
input,
metrics_set: CompositeMetricsSet::new(),
preserve_partitioning,
fetch,
}
}

@@ -697,6 +705,11 @@
pub fn expr(&self) -> &[PhysicalSortExpr] {
&self.expr
}

/// If `Some(fetch)`, limits output to only the first "fetch" items
pub fn fetch(&self) -> Option<usize> {
self.fetch
}
}
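
A minimal usage sketch of the new constructor (assuming `sort_exprs: Vec<PhysicalSortExpr>` and `input: Arc<dyn ExecutionPlan>` are already in scope):

    // Sort `input`, but materialize only the top 5 rows of the result.
    let top5 = SortExec::try_new(sort_exprs, input, Some(5))?;
    assert_eq!(top5.fetch(), Some(5));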

impl ExecutionPlan for SortExec {
@@ -750,6 +763,7 @@ impl ExecutionPlan for SortExec {
self.expr.clone(),
children[0].clone(),
self.preserve_partitioning,
self.fetch,
)))
}

@@ -778,6 +792,7 @@
self.expr.clone(),
self.metrics_set.clone(),
context,
self.fetch(),
)
.map_err(|e| ArrowError::ExternalError(Box::new(e))),
)
@@ -816,14 +831,14 @@ fn sort_batch(
batch: RecordBatch,
schema: SchemaRef,
expr: &[PhysicalSortExpr],
fetch: Option<usize>,
) -> ArrowResult<BatchWithSortArray> {
// TODO: pushup the limit expression to sort
let sort_columns = expr
.iter()
.map(|e| e.evaluate_to_sort_column(&batch))
.collect::<Result<Vec<SortColumn>>>()?;

let indices = lexsort_to_indices(&sort_columns, None)?;
let indices = lexsort_to_indices(&sort_columns, fetch)?;
@Dandandan (PR author) commented on Sep 19, 2022:

The key optimization: this returns only n indices after the change.

A contributor replied:

nice

A contributor commented:

I wonder if this will effectively get us much of the benefit of a special TopK operator, as we don't have to copy the entire input -- we only copy up to the fetch limit, if specified.

Although I suppose SortExec still buffers all of its input, where a TopK operator could buffer only the top rows.

A contributor added:

In fact, I wonder if you could also apply the limit here:

https://github.com/apache/arrow-datafusion/blob/3a9e0d0/datafusion/core/src/physical_plan/sorts/sort.rs#L123-L124

as part of sorting each batch -- rather than keeping the entire input batch, we only need to keep at most fetch rows from each batch.

@Dandandan (PR author) replied:

lexsort_to_indices already returns only fetch indices per batch; this is used to take that number of indices per batch, throwing away the rest of the rows.

The remaining optimization, I think, is tweaking SortPreservingMergeStream to maintain only fetch records in the heap, instead of the top fetch records for each batch in the partition, as mentioned in #3516 (comment). After this I think we have a full TopK implementation that only needs to keep n rows in memory (per partition).

I would like to do this in a separate PR.

A contributor replied:

A separate PR is a great idea 👍

> lexsort_to_indices already returns only fetch indices per batch; this is used to take that number of indices per batch, throwing away the rest of the rows.

Right, the point I was trying to make is that there are 2 calls to lexsort_to_indices in sort.rs. I think this PR only pushed fetch to one of them. The second is https://github.com/apache/arrow-datafusion/blob/3a9e0d0/datafusion/core/src/physical_plan/sorts/sort.rs#L826 and I think it is correct to push fetch there too.

I was thinking if we applied fetch to the second call, we could get close to the same effect without changing SortPreservingMergeStream.

  • After this PR, sort buffers num_input_batches * input_batch_size rows.
  • Adding fetch to the other call to lexsort_to_indices would buffer num_input_batches * limit rows.
  • Extending SortPreservingMergeStream would allow us to buffer only limit rows.

So clearly extending SortPreservingMergeStream is optimal in terms of rows buffered, but it likely requires a bit more effort.
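
(For illustration with made-up numbers: 100 input batches of 8,192 rows each and limit=10 would buffer roughly 819,200, 1,000, and 10 rows per partition under the three approaches, respectively.)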

@Dandandan (PR author) replied on Sep 20, 2022:

Ah, I didn't look too much at the rest of the implementation. I think you're right that providing fetch to the other lexsort_to_indices would be beneficial as well. I will create an issue for this and a PR later.

@Dandandan (PR author) added:

I think the current change already buffers num_input_batches * limit rows, by the way, as the limit is applied before adding batches to the buffer. As far as I can see, adding fetch to the second lexsort_to_indices call will mainly reduce the output of the individual sorts to fetch rows - which is of course beneficial too, as it reduces the time to sort and again limits the input to take and the input to SortPreservingMergeExec.

@jychen7 commented on Apr 12, 2023:

> I think you're right that providing fetch to the other lexsort_to_indices would be beneficial as well. I will create an issue for this and a PR later.

For other readers: this is addressed by issue #3544 and fixed by PR #3545.
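
A standalone sketch of the fetch-limited kernel this change builds on, using the arrow crate directly (the array values are made up for illustration):

    use std::sync::Arc;
    use arrow::array::{ArrayRef, Int32Array};
    use arrow::compute::{lexsort_to_indices, take, SortColumn};

    fn top_n_example() -> arrow::error::Result<()> {
        let values: ArrayRef = Arc::new(Int32Array::from(vec![5, 1, 4, 2, 3]));
        let sort_columns = vec![SortColumn {
            values: values.clone(),
            options: None, // default sort options
        }];
        // With a limit, only the first 2 indices of the sorted order are returned.
        let indices = lexsort_to_indices(&sort_columns, Some(2))?;
        // `take` then copies just those 2 rows rather than the whole batch.
        let top2 = take(values.as_ref(), &indices, None)?;
        assert_eq!(top2.len(), 2); // the rows holding 1 and 2
        Ok(())
    }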


// reorder all rows based on sorted indices
let sorted_batch = RecordBatch::try_new(
@@ -870,6 +885,7 @@ async fn do_sort(
expr: Vec<PhysicalSortExpr>,
metrics_set: CompositeMetricsSet,
context: Arc<TaskContext>,
fetch: Option<usize>,
) -> Result<SendableRecordBatchStream> {
debug!(
"Start do_sort for partition {} of context session_id {} and task_id {:?}",
@@ -887,6 +903,7 @@
metrics_set,
Arc::new(context.session_config()),
context.runtime_env(),
fetch,
);
context.runtime_env().register_requester(sorter.id());
while let Some(batch) = input.next().await {
@@ -949,6 +966,7 @@ mod tests {
},
],
Arc::new(CoalescePartitionsExec::new(csv)),
None,
)?);

let result = collect(sort_exec, task_ctx).await?;
@@ -1011,6 +1029,7 @@
},
],
Arc::new(CoalescePartitionsExec::new(csv)),
None,
)?);

let task_ctx = session_ctx.task_ctx();
@@ -1083,6 +1102,7 @@
options: SortOptions::default(),
}],
input,
None,
)?);

let result: Vec<RecordBatch> = collect(sort_exec, task_ctx).await?;
@@ -1159,6 +1179,7 @@
},
],
Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None)?),
None,
)?);

assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type());
@@ -1226,6 +1247,7 @@
options: SortOptions::default(),
}],
blocking_exec,
None,
)?);

let fut = collect(sort_exec, task_ctx);
10 changes: 7 additions & 3 deletions datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -874,8 +874,12 @@ mod tests {
sort: Vec<PhysicalSortExpr>,
context: Arc<TaskContext>,
) -> RecordBatch {
let sort_exec =
Arc::new(SortExec::new_with_partitioning(sort.clone(), input, true));
let sort_exec = Arc::new(SortExec::new_with_partitioning(
sort.clone(),
input,
true,
None,
));
sorted_merge(sort_exec, sort, context).await
}

@@ -885,7 +889,7 @@
context: Arc<TaskContext>,
) -> RecordBatch {
let merge = Arc::new(CoalescePartitionsExec::new(src));
let sort_exec = Arc::new(SortExec::try_new(sort, merge).unwrap());
let sort_exec = Arc::new(SortExec::try_new(sort, merge, None).unwrap());
let mut result = collect(sort_exec, context).await.unwrap();
assert_eq!(result.len(), 1);
result.remove(0)
2 changes: 1 addition & 1 deletion datafusion/core/tests/order_spill_fuzz.rs
@@ -75,7 +75,7 @@ async fn run_sort(pool_size: usize, size_spill: Vec<(usize, bool)>) {
}];

let exec = MemoryExec::try_new(&input, schema, None).unwrap();
let sort = Arc::new(SortExec::try_new(sort, Arc::new(exec)).unwrap());
let sort = Arc::new(SortExec::try_new(sort, Arc::new(exec), None).unwrap());

let runtime_config = RuntimeConfig::new().with_memory_manager(
MemoryManagerConfig::try_new_limit(pool_size, 1.0).unwrap(),
1 change: 1 addition & 0 deletions datafusion/core/tests/user_defined_plan.rs
@@ -299,6 +299,7 @@ impl OptimizerRule for TopKOptimizerRule {
if let LogicalPlan::Sort(Sort {
ref expr,
ref input,
..
}) = **input
{
if expr.len() == 1 {
2 changes: 2 additions & 0 deletions datafusion/expr/src/logical_plan/builder.rs
@@ -393,13 +393,15 @@ impl LogicalPlanBuilder {
return Ok(Self::from(LogicalPlan::Sort(Sort {
expr: normalize_cols(exprs, &self.plan)?,
input: Arc::new(self.plan.clone()),
fetch: None,
})));
}

let plan = self.add_missing_columns(self.plan.clone(), &missing_cols)?;
let sort_plan = LogicalPlan::Sort(Sort {
expr: normalize_cols(exprs, &plan)?,
input: Arc::new(plan.clone()),
fetch: None,
});
// remove pushed down sort columns
let new_expr = schema
8 changes: 7 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
@@ -806,14 +806,18 @@ impl LogicalPlan {
"Aggregate: groupBy=[{:?}], aggr=[{:?}]",
group_expr, aggr_expr
),
LogicalPlan::Sort(Sort { expr, .. }) => {
LogicalPlan::Sort(Sort { expr, fetch, .. }) => {
write!(f, "Sort: ")?;
for (i, expr_item) in expr.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{:?}", expr_item)?;
}
if let Some(a) = fetch {
write!(f, ", fetch={}", a)?;
}

Ok(())
}
LogicalPlan::Join(Join {
@@ -1373,6 +1377,8 @@ pub struct Sort {
pub expr: Vec<Expr>,
/// The incoming logical plan
pub input: Arc<LogicalPlan>,
/// Optional fetch limit
pub fetch: Option<usize>,
}
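
A minimal sketch of constructing the new node shape by hand (assuming a `scan: LogicalPlan` is in scope; the import paths follow datafusion_expr but are illustrative):

    use std::sync::Arc;
    use datafusion_expr::{col, logical_plan::Sort, LogicalPlan};

    // Roughly "ORDER BY a ASC NULLS FIRST LIMIT 10", with the limit
    // pushed into the sort node itself.
    let top10 = LogicalPlan::Sort(Sort {
        expr: vec![col("a").sort(true, true)],
        input: Arc::new(scan),
        fetch: Some(10),
    });

Displayed, this prints as Sort: #a ASC NULLS FIRST, fetch=10, matching the expected plan updated in dataframe.rs above.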

/// Join two logical plans on one or more join columns
3 changes: 2 additions & 1 deletion datafusion/expr/src/utils.rs
@@ -420,9 +420,10 @@ pub fn from_plan(
expr[group_expr.len()..].to_vec(),
schema.clone(),
)?)),
LogicalPlan::Sort(Sort { .. }) => Ok(LogicalPlan::Sort(Sort {
LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort {
expr: expr.to_vec(),
input: Arc::new(inputs[0].clone()),
fetch: *fetch,
})),
LogicalPlan::Join(Join {
join_type,
3 changes: 2 additions & 1 deletion datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -196,7 +196,7 @@ fn optimize(
schema.clone(),
)?))
}
LogicalPlan::Sort(Sort { expr, input }) => {
LogicalPlan::Sort(Sort { expr, input, fetch }) => {
let arrays = to_arrays(expr, input, &mut expr_set)?;

let (mut new_expr, new_input) = rewrite_expr(
Expand All @@ -210,6 +210,7 @@ fn optimize(
Ok(LogicalPlan::Sort(Sort {
expr: pop_expr(&mut new_expr)?,
input: Arc::new(new_input),
fetch: *fetch,
}))
}
LogicalPlan::Join { .. }