From 56db3139c2957c1f8c7f114ca222f1199002eb3b Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 13 Dec 2022 16:38:27 +0300 Subject: [PATCH 01/17] Sort Removal rule initial commit --- datafusion/common/src/lib.rs | 10 + datafusion/core/src/execution/context.rs | 6 + datafusion/core/src/physical_optimizer/mod.rs | 1 + .../remove_unnecessary_sorts.rs | 871 ++++++++++++++++++ .../core/src/physical_optimizer/utils.rs | 87 ++ datafusion/core/tests/sql/explain_analyze.rs | 5 - datafusion/core/tests/sql/window.rs | 529 +++++++++++ datafusion/expr/src/window_frame.rs | 27 + .../physical-expr/src/aggregate/count.rs | 10 +- datafusion/physical-expr/src/aggregate/mod.rs | 15 + datafusion/physical-expr/src/aggregate/sum.rs | 10 +- .../physical-expr/src/window/aggregate.rs | 14 + .../physical-expr/src/window/built_in.rs | 14 + .../window/built_in_window_function_expr.rs | 16 +- .../physical-expr/src/window/lead_lag.rs | 14 + .../physical-expr/src/window/nth_value.rs | 25 + .../physical-expr/src/window/window_expr.rs | 23 +- 17 files changed, 1667 insertions(+), 10 deletions(-) create mode 100644 datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 60d69324913ba..c6812ed24f8ae 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -30,6 +30,7 @@ pub mod stats; mod table_reference; pub mod test_util; +use arrow::compute::SortOptions; pub use column::Column; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; pub use error::{field_not_found, DataFusionError, Result, SchemaError}; @@ -63,3 +64,12 @@ macro_rules! downcast_value { })? }}; } + +/// Compute the "reverse" of given `SortOptions`. +// TODO: If/when arrow supports `!` for `SortOptions`, we can remove this. +pub fn reverse_sort_options(options: SortOptions) -> SortOptions { + SortOptions { + descending: !options.descending, + nulls_first: !options.nulls_first, + } +} diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 759cc79a8bf81..e0ebb8e018287 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -99,6 +99,7 @@ use url::Url; use crate::catalog::listing_schema::ListingSchemaProvider; use crate::datasource::object_store::ObjectStoreUrl; +use crate::physical_optimizer::remove_unnecessary_sorts::RemoveUnnecessarySorts; use uuid::Uuid; use super::options::{ @@ -1596,6 +1597,11 @@ impl SessionState { // To make sure the SinglePartition is satisfied, run the BasicEnforcement again, originally it was the AddCoalescePartitionsExec here. physical_optimizers.push(Arc::new(BasicEnforcement::new())); + // `BasicEnforcement` stage conservatively inserts `SortExec`s before `WindowAggExec`s without + // a deep analysis of window frames. Such analysis may sometimes reveal that a `SortExec` is + // actually unnecessary. The rule below performs this analysis and removes such `SortExec`s. + physical_optimizers.push(Arc::new(RemoveUnnecessarySorts::new())); + SessionState { session_id, optimizer: Optimizer::new(&optimizer_config), diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 36b00a0e01bcd..a69aa16c343bd 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -24,6 +24,7 @@ pub mod enforcement; pub mod join_selection; pub mod optimizer; pub mod pruning; +pub mod remove_unnecessary_sorts; pub mod repartition; mod utils; diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs new file mode 100644 index 0000000000000..9c2e071dd8b84 --- /dev/null +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -0,0 +1,871 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Remove Unnecessary Sorts optimizer rule is used to for removing unnecessary SortExec's inserted to +//! physical plan. Produces a valid physical plan (in terms of Sorting requirement). Its input can be either +//! valid, or invalid physical plans (in terms of Sorting requirement) +use crate::error::Result; +use crate::physical_optimizer::utils::{ + add_sort_above_child, ordering_satisfy, ordering_satisfy_inner, +}; +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::rewrite::TreeNodeRewritable; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::windows::WindowAggExec; +use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use crate::prelude::SessionConfig; +use arrow::datatypes::SchemaRef; +use datafusion_common::{reverse_sort_options, DataFusionError}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use std::iter::zip; +use std::sync::Arc; + +/// As an example Assume we get +/// "SortExec: [nullable_col@0 ASC]", +/// " SortExec: [non_nullable_col@1 ASC]", somehow in the physical plan +/// The first Sort is unnecessary since, its result would be overwritten by another SortExec. We +/// remove first Sort from the physical plan +#[derive(Default)] +pub struct RemoveUnnecessarySorts {} + +impl RemoveUnnecessarySorts { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for RemoveUnnecessarySorts { + fn optimize( + &self, + plan: Arc, + _config: &SessionConfig, + ) -> Result> { + // Run a bottom-up process to adjust input key ordering recursively + let plan_requirements = PlanWithCorrespondingSort::new(plan); + let adjusted = plan_requirements.transform_up(&remove_unnecessary_sorts)?; + Ok(adjusted.plan) + } + + fn name(&self) -> &str { + "RemoveUnnecessarySorts" + } + + fn schema_check(&self) -> bool { + true + } +} + +fn remove_unnecessary_sorts( + requirements: PlanWithCorrespondingSort, +) -> Result> { + let mut new_children = requirements.plan.children().clone(); + let mut new_sort_onwards = requirements.sort_onwards.clone(); + for (idx, (child, sort_onward)) in new_children + .iter_mut() + .zip(new_sort_onwards.iter_mut()) + .enumerate() + { + let required_ordering = requirements.plan.required_input_ordering()[idx]; + let physical_ordering = child.output_ordering(); + match (required_ordering, physical_ordering) { + (Some(required_ordering), Some(physical_ordering)) => { + let is_ordering_satisfied = + ordering_satisfy_inner(physical_ordering, required_ordering, || { + child.equivalence_properties() + }); + if is_ordering_satisfied { + // can do analysis for sort removal + if !sort_onward.is_empty() { + let (_, sort_any) = sort_onward[0].clone(); + let sort_exec = sort_any + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Plan( + "First layer should start from SortExec".to_string(), + ) + })?; + let sort_output_ordering = sort_exec.output_ordering(); + let sort_input_ordering = sort_exec.input().output_ordering(); + // Do naive analysis, where a SortExec is already sorted according to desired Sorting + if ordering_satisfy( + sort_input_ordering, + sort_output_ordering, + || sort_exec.input().equivalence_properties(), + ) { + update_child_to_remove_unnecessary_sort(child, sort_onward)?; + } else if let Some(window_agg_exec) = + requirements.plan.as_any().downcast_ref::() + { + // For window expressions we can remove some Sorts when expression can be calculated in reverse order also. + if let Some(res) = + analyze_window_sort_removal(window_agg_exec, sort_onward)? + { + return Ok(Some(res)); + } + } + } + } else { + // During sort Removal we have invalidated ordering invariant fix it + // This is effectively moving sort above in the physical plan + update_child_to_remove_unnecessary_sort(child, sort_onward)?; + let sort_expr = required_ordering.to_vec(); + *child = add_sort_above_child(child, sort_expr)?; + // Since we have added Sort, we add it to the sort_onwards also. + sort_onward.push((idx, child.clone())) + } + } + (Some(required), None) => { + // Requirement is not satisfied We should add Sort to the plan. + let sort_expr = required.to_vec(); + *child = add_sort_above_child(child, sort_expr)?; + *sort_onward = vec![(idx, child.clone())]; + } + (None, Some(_)) => { + // Sort doesn't propagate to the layers above in the physical plan + if !requirements.plan.maintains_input_order() { + // Unnecessary Sort is added to the plan, we can remove unnecessary sort + update_child_to_remove_unnecessary_sort(child, sort_onward)?; + } + } + (None, None) => {} + } + } + if !requirements.plan.children().is_empty() { + let new_plan = requirements.plan.with_new_children(new_children)?; + for (idx, new_sort_onward) in new_sort_onwards + .iter_mut() + .enumerate() + .take(new_plan.children().len()) + { + let is_require_ordering = new_plan.required_input_ordering()[idx].is_none(); + //TODO: when `maintains_input_order` returns `Vec` use corresponding index + if new_plan.maintains_input_order() + && is_require_ordering + && !new_sort_onward.is_empty() + { + new_sort_onward.push((idx, new_plan.clone())); + } else if new_plan.as_any().is::() { + new_sort_onward.clear(); + new_sort_onward.push((idx, new_plan.clone())); + } else { + new_sort_onward.clear(); + } + } + Ok(Some(PlanWithCorrespondingSort { + plan: new_plan, + sort_onwards: new_sort_onwards, + })) + } else { + Ok(Some(requirements)) + } +} + +#[derive(Debug, Clone)] +struct PlanWithCorrespondingSort { + plan: Arc, + // For each child keeps a vector of `ExecutionPlan`s starting from SortExec till current plan + // first index of tuple(usize) is child index of plan (we need during updating plan above) + sort_onwards: Vec)>>, +} + +impl PlanWithCorrespondingSort { + pub fn new(plan: Arc) -> Self { + let children_len = plan.children().len(); + PlanWithCorrespondingSort { + plan, + sort_onwards: vec![vec![]; children_len], + } + } + + pub fn children(&self) -> Vec { + let plan_children = self.plan.children(); + plan_children + .into_iter() + .map(|child| { + let length = child.children().len(); + PlanWithCorrespondingSort { + plan: child, + sort_onwards: vec![vec![]; length], + } + }) + .collect() + } +} + +impl TreeNodeRewritable for PlanWithCorrespondingSort { + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if !children.is_empty() { + let new_children: Result> = + children.into_iter().map(transform).collect(); + let children_requirements = new_children?; + let children_plans = children_requirements + .iter() + .map(|elem| elem.plan.clone()) + .collect::>(); + let sort_onwards = children_requirements + .iter() + .map(|elem| { + if !elem.sort_onwards.is_empty() { + // TODO: redirect the true sort onwards to above (the one we keep ordering) + // this is possible when maintains_input_order returns vec + elem.sort_onwards[0].clone() + } else { + vec![] + } + }) + .collect::>(); + let plan = with_new_children_if_necessary(self.plan, children_plans)?; + Ok(PlanWithCorrespondingSort { plan, sort_onwards }) + } else { + Ok(self) + } + } +} + +/// Analyzes `WindowAggExec` to determine whether Sort can be removed +fn analyze_window_sort_removal( + window_agg_exec: &WindowAggExec, + sort_onward: &mut Vec<(usize, Arc)>, +) -> Result> { + // If empty immediately return we cannot do analysis + if sort_onward.is_empty() { + return Ok(None); + } + let (_, sort_any) = sort_onward[0].clone(); + let sort_exec = sort_any + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Plan("First layer should start from SortExec".to_string()) + })?; + let required_ordering = sort_exec.output_ordering().ok_or_else(|| { + DataFusionError::Plan("SortExec should have output ordering".to_string()) + })?; + let physical_ordering = sort_exec.input().output_ordering(); + let physical_ordering = if let Some(physical_ordering) = physical_ordering { + physical_ordering + } else { + // If there is no physical ordering, there is no way to remove Sorting, immediately return + return Ok(None); + }; + let window_expr = window_agg_exec.window_expr(); + let partition_keys = window_expr[0].partition_by().to_vec(); + let (can_skip_sorting, should_reverse) = can_skip_sort( + &partition_keys, + required_ordering, + &sort_exec.input().schema(), + physical_ordering, + )?; + let all_window_fns_reversible = + window_expr.iter().all(|e| e.is_window_fn_reversible()); + let is_reversal_blocking = should_reverse && !all_window_fns_reversible; + + if can_skip_sorting && !is_reversal_blocking { + let window_expr = if should_reverse { + window_expr + .iter() + .map(|e| e.get_reversed_expr()) + .collect::>>()? + } else { + window_expr.to_vec() + }; + let new_child = remove_corresponding_sort_from_sub_plan(sort_onward)?; + let new_schema = new_child.schema(); + let new_plan = Arc::new(WindowAggExec::try_new( + window_expr, + new_child, + new_schema, + window_agg_exec.partition_keys.clone(), + Some(physical_ordering.to_vec()), + )?); + Ok(Some(PlanWithCorrespondingSort::new(new_plan))) + } else { + Ok(None) + } +} +/// Updates child such that unnecessary sorting below it is removed +fn update_child_to_remove_unnecessary_sort( + child: &mut Arc, + sort_onwards: &mut Vec<(usize, Arc)>, +) -> Result<()> { + if !sort_onwards.is_empty() { + *child = remove_corresponding_sort_from_sub_plan(sort_onwards)?; + } + Ok(()) +} +/// Removes the sort from the plan in the `sort_onwards` +fn remove_corresponding_sort_from_sub_plan( + sort_onwards: &mut Vec<(usize, Arc)>, +) -> Result> { + let (sort_child_idx, sort_any) = sort_onwards[0].clone(); + let sort_exec = sort_any + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Plan("First layer should start from SortExec".to_string()) + })?; + let mut prev_layer = sort_exec.input().clone(); + let mut prev_layer_child_idx = sort_child_idx; + // We start from 1 hence since first one is sort and we are removing it + // from the plan + for (cur_layer_child_idx, cur_layer) in sort_onwards.iter().skip(1) { + let mut new_children = cur_layer.children(); + new_children[prev_layer_child_idx] = prev_layer; + prev_layer = cur_layer.clone().with_new_children(new_children)?; + prev_layer_child_idx = *cur_layer_child_idx; + } + // We have removed the corresponding sort hence empty the sort_onwards + sort_onwards.clear(); + Ok(prev_layer) +} + +#[derive(Debug)] +/// This structure stores extra column information required to remove unnecessary sorts. +pub struct ColumnInfo { + is_aligned: bool, + reverse: bool, + is_partition: bool, +} + +/// Compares physical ordering and required ordering of all `PhysicalSortExpr`s and returns a tuple. +/// The first element indicates whether these `PhysicalSortExpr`s can be removed from the physical plan. +/// The second element is a flag indicating whether we should reverse the sort direction in order to +/// remove physical sort expressions from the plan. +pub fn can_skip_sort( + partition_keys: &[Arc], + required: &[PhysicalSortExpr], + input_schema: &SchemaRef, + physical_ordering: &[PhysicalSortExpr], +) -> Result<(bool, bool)> { + if required.len() > physical_ordering.len() { + return Ok((false, false)); + } + let mut col_infos = vec![]; + for (sort_expr, physical_expr) in zip(required, physical_ordering) { + let column = sort_expr.expr.clone(); + let is_partition = partition_keys.iter().any(|e| e.eq(&column)); + let (is_aligned, reverse) = + check_alignment(input_schema, physical_expr, sort_expr); + col_infos.push(ColumnInfo { + is_aligned, + reverse, + is_partition, + }); + } + let partition_by_sections = col_infos + .iter() + .filter(|elem| elem.is_partition) + .collect::>(); + let (can_skip_partition_bys, should_reverse_partition_bys) = + if partition_by_sections.is_empty() { + (true, false) + } else { + let first_reverse = partition_by_sections[0].reverse; + let can_skip_partition_bys = partition_by_sections + .iter() + .all(|c| c.is_aligned && c.reverse == first_reverse); + (can_skip_partition_bys, first_reverse) + }; + let order_by_sections = col_infos + .iter() + .filter(|elem| !elem.is_partition) + .collect::>(); + let (can_skip_order_bys, should_reverse_order_bys) = if order_by_sections.is_empty() { + (true, false) + } else { + let first_reverse = order_by_sections[0].reverse; + let can_skip_order_bys = order_by_sections + .iter() + .all(|c| c.is_aligned && c.reverse == first_reverse); + (can_skip_order_bys, first_reverse) + }; + // TODO: We cannot skip partition by keys when sort direction is reversed, + // by propogating partition by sort direction to `WindowAggExec` we can skip + // these columns also. Add support for that (Use direction during partition range calculation). + let can_skip = + can_skip_order_bys && can_skip_partition_bys && !should_reverse_partition_bys; + Ok((can_skip, should_reverse_order_bys)) +} + +/// Compares `physical_ordering` and `required` ordering, returns a tuple +/// indicating (1) whether this column requires sorting, and (2) whether we +/// should reverse the window expression in order to avoid sorting. +fn check_alignment( + input_schema: &SchemaRef, + physical_ordering: &PhysicalSortExpr, + required: &PhysicalSortExpr, +) -> (bool, bool) { + if required.expr.eq(&physical_ordering.expr) { + let nullable = required.expr.nullable(input_schema).unwrap(); + let physical_opts = physical_ordering.options; + let required_opts = required.options; + let is_reversed = if nullable { + physical_opts == reverse_sort_options(required_opts) + } else { + // If the column is not nullable, NULLS FIRST/LAST is not important. + physical_opts.descending != required_opts.descending + }; + let can_skip = !nullable || is_reversed || (physical_opts == required_opts); + (can_skip, is_reversed) + } else { + (false, false) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::physical_plan::displayable; + use crate::physical_plan::filter::FilterExec; + use crate::physical_plan::memory::MemoryExec; + use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; + use crate::physical_plan::windows::create_window_expr; + use crate::prelude::SessionContext; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::Result; + use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; + use datafusion_physical_expr::expressions::{col, NotExpr}; + use datafusion_physical_expr::PhysicalSortExpr; + use std::sync::Arc; + + fn create_test_schema() -> Result { + let nullable_column = Field::new("nullable_col", DataType::Int32, true); + let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![nullable_column, non_nullable_column])); + + Ok(schema) + } + + #[tokio::test] + async fn test_is_column_aligned_nullable() -> Result<()> { + let schema = create_test_schema()?; + let params = vec![ + ((true, true), (false, false), (true, true)), + ((true, true), (false, true), (false, false)), + ((true, true), (true, false), (false, false)), + ((true, false), (false, true), (true, true)), + ((true, false), (false, false), (false, false)), + ((true, false), (true, true), (false, false)), + ]; + for ( + (physical_desc, physical_nulls_first), + (req_desc, req_nulls_first), + (is_aligned_expected, reverse_expected), + ) in params + { + let physical_ordering = PhysicalSortExpr { + expr: col("nullable_col", &schema)?, + options: SortOptions { + descending: physical_desc, + nulls_first: physical_nulls_first, + }, + }; + let required_ordering = PhysicalSortExpr { + expr: col("nullable_col", &schema)?, + options: SortOptions { + descending: req_desc, + nulls_first: req_nulls_first, + }, + }; + let (is_aligned, reverse) = + check_alignment(&schema, &physical_ordering, &required_ordering); + assert_eq!(is_aligned, is_aligned_expected); + assert_eq!(reverse, reverse_expected); + } + + Ok(()) + } + + #[tokio::test] + async fn test_is_column_aligned_non_nullable() -> Result<()> { + let schema = create_test_schema()?; + + let params = vec![ + ((true, true), (false, false), (true, true)), + ((true, true), (false, true), (true, true)), + ((true, true), (true, false), (true, false)), + ((true, false), (false, true), (true, true)), + ((true, false), (false, false), (true, true)), + ((true, false), (true, true), (true, false)), + ]; + for ( + (physical_desc, physical_nulls_first), + (req_desc, req_nulls_first), + (is_aligned_expected, reverse_expected), + ) in params + { + let physical_ordering = PhysicalSortExpr { + expr: col("non_nullable_col", &schema)?, + options: SortOptions { + descending: physical_desc, + nulls_first: physical_nulls_first, + }, + }; + let required_ordering = PhysicalSortExpr { + expr: col("non_nullable_col", &schema)?, + options: SortOptions { + descending: req_desc, + nulls_first: req_nulls_first, + }, + }; + let (is_aligned, reverse) = + check_alignment(&schema, &physical_ordering, &required_ordering); + assert_eq!(is_aligned, is_aligned_expected); + assert_eq!(reverse, reverse_expected); + } + + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("non_nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let sort_exec = Arc::new(SortExec::try_new(sort_exprs, source, None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let physical_plan = Arc::new(SortExec::try_new(sort_exprs, sort_exec, None)?) + as Arc; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "SortExec: [nullable_col@0 ASC]", + " SortExec: [non_nullable_col@1 ASC]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { vec!["SortExec: [nullable_col@0 ASC]"] }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("non_nullable_col", source.schema().as_ref()).unwrap(), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, None)?) + as Arc; + let window_agg_exec = Arc::new(WindowAggExec::try_new( + vec![create_window_expr( + &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), + &[col("non_nullable_col", &schema)?], + &[], + &sort_exprs, + Arc::new(WindowFrame::new(true)), + schema.as_ref(), + )?], + sort_exec.clone(), + sort_exec.schema(), + vec![], + Some(sort_exprs), + )?) as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("non_nullable_col", window_agg_exec.schema().as_ref()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let sort_exec = Arc::new(SortExec::try_new( + sort_exprs.clone(), + window_agg_exec, + None, + )?) as Arc; + // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before + let filter_exec = Arc::new(FilterExec::try_new( + Arc::new(NotExpr::new( + col("non_nullable_col", schema.as_ref()).unwrap(), + )), + sort_exec, + )?) as Arc; + // let filter_exec = sort_exec; + let window_agg_exec = Arc::new(WindowAggExec::try_new( + vec![create_window_expr( + &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), + &[col("non_nullable_col", &schema)?], + &[], + &sort_exprs, + Arc::new(WindowFrame::new(true)), + schema.as_ref(), + )?], + filter_exec.clone(), + filter_exec.schema(), + vec![], + Some(sort_exprs), + )?) as Arc; + let physical_plan = window_agg_exec; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", + " FilterExec: NOT non_nullable_col@1", + " SortExec: [non_nullable_col@2 ASC NULLS LAST]", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", + " SortExec: [non_nullable_col@1 DESC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]", + " FilterExec: NOT non_nullable_col@1", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", + " SortExec: [non_nullable_col@1 DESC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } + + #[tokio::test] + async fn test_add_required_sort() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let physical_plan = Arc::new(SortPreservingMergeExec::new(sort_exprs, source)) + as Arc; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { vec!["SortPreservingMergeExec: [nullable_col@0 ASC]"] }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort1() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, None)?) + as Arc; + let sort_preserving_merge_exec = + Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let sort_exec = Arc::new(SortExec::try_new( + sort_exprs.clone(), + sort_preserving_merge_exec, + None, + )?) as Arc; + let sort_preserving_merge_exec = + Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) + as Arc; + let physical_plan = sort_preserving_merge_exec; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } + + #[tokio::test] + async fn test_change_wrong_sorting() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![ + PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: col("non_nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }, + ]; + let sort_exec = Arc::new(SortExec::try_new( + vec![sort_exprs[0].clone()], + source, + None, + )?) as Arc; + let sort_preserving_merge_exec = + Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) + as Arc; + let physical_plan = sort_preserving_merge_exec; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 4aceb776d7d5b..7389da085d600 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -21,7 +21,12 @@ use super::optimizer::PhysicalOptimizerRule; use crate::execution::context::SessionConfig; use crate::error::Result; +use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use datafusion_physical_expr::{ + normalize_sort_expr_with_equivalence_properties, EquivalenceProperties, + PhysicalSortExpr, +}; use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke @@ -45,3 +50,85 @@ pub fn optimize_children( with_new_children_if_necessary(plan, children) } } + +/// Check the required ordering requirements are satisfied by the provided PhysicalSortExprs. +pub fn ordering_satisfy EquivalenceProperties>( + provided: Option<&[PhysicalSortExpr]>, + required: Option<&[PhysicalSortExpr]>, + equal_properties: F, +) -> bool { + match (provided, required) { + (_, None) => true, + (None, Some(_)) => false, + (Some(provided), Some(required)) => { + ordering_satisfy_inner(provided, required, equal_properties) + } + } +} + +pub fn ordering_satisfy_inner EquivalenceProperties>( + provided: &[PhysicalSortExpr], + required: &[PhysicalSortExpr], + equal_properties: F, +) -> bool { + if required.len() > provided.len() { + false + } else { + let fast_match = required + .iter() + .zip(provided.iter()) + .all(|(order1, order2)| order1.eq(order2)); + + if !fast_match { + let eq_properties = equal_properties(); + let eq_classes = eq_properties.classes(); + if !eq_classes.is_empty() { + let normalized_required_exprs = required + .iter() + .map(|e| { + normalize_sort_expr_with_equivalence_properties( + e.clone(), + eq_classes, + ) + }) + .collect::>(); + let normalized_provided_exprs = provided + .iter() + .map(|e| { + normalize_sort_expr_with_equivalence_properties( + e.clone(), + eq_classes, + ) + }) + .collect::>(); + normalized_required_exprs + .iter() + .zip(normalized_provided_exprs.iter()) + .all(|(order1, order2)| order1.eq(order2)) + } else { + fast_match + } + } else { + fast_match + } + } +} + +/// Util function to add SortExec above child +pub fn add_sort_above_child( + child: &Arc, + sort_expr: Vec, +) -> Result> { + let new_child = if child.output_partitioning().partition_count() > 1 { + Arc::new(SortExec::new_with_partitioning( + sort_expr, + child.clone(), + true, + None, + )) as Arc + } else { + Arc::new(SortExec::try_new(sort_expr, child.clone(), None)?) + as Arc + }; + Ok(new_child) +} diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 89112adae74a6..98ccb79c62634 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -62,11 +62,6 @@ async fn explain_analyze_baseline_metrics() { "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", "metrics=[output_rows=5, elapsed_compute=" ); - assert_metrics!( - &formatted, - "SortExec: [c1@0 ASC NULLS LAST]", - "metrics=[output_rows=5, elapsed_compute=" - ); assert_metrics!( &formatted, "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index b550e7f5dd60b..3a44493aaab6d 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1646,3 +1646,532 @@ async fn test_window_agg_sort() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn test_window_agg_sort_reversed_plan() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@0 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+-------------+-------------+", + "| c9 | sum1 | sum2 |", + "+------------+-------------+-------------+", + "| 4268716378 | 8498370520 | 24997484146 |", + "| 4229654142 | 12714811027 | 29012926487 |", + "| 4216440507 | 16858984380 | 28743001064 |", + "| 4144173353 | 20935849039 | 28472563256 |", + "| 4076864659 | 24997484146 | 28118515915 |", + "+------------+-------------+-------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_reversed_plan_builtin() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + FIRST_VALUE(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as fv1, + FIRST_VALUE(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as fv2, + LAG(c9, 2, 10101) OVER(ORDER BY c9 ASC) as lag1, + LAG(c9, 2, 10101) OVER(ORDER BY c9 DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lag2, + LEAD(c9, 2, 10101) OVER(ORDER BY c9 ASC) as lead1, + LEAD(c9, 2, 10101) OVER(ORDER BY c9 DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lead2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c9@6 as c9, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lead2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }]", + " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }]", + " SortExec: [c9@0 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+------------+------------+------------+------------+------------+------------+", + "| c9 | fv1 | fv2 | lag1 | lag2 | lead1 | lead2 |", + "+------------+------------+------------+------------+------------+------------+------------+", + "| 4268716378 | 4229654142 | 4268716378 | 4216440507 | 10101 | 10101 | 4216440507 |", + "| 4229654142 | 4216440507 | 4268716378 | 4144173353 | 10101 | 10101 | 4144173353 |", + "| 4216440507 | 4144173353 | 4229654142 | 4076864659 | 4268716378 | 4268716378 | 4076864659 |", + "| 4144173353 | 4076864659 | 4216440507 | 4061635107 | 4229654142 | 4229654142 | 4061635107 |", + "| 4076864659 | 4061635107 | 4144173353 | 4015442341 | 4216440507 | 4216440507 | 4015442341 |", + "+------------+------------+------------+------------+------------+------------+------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_non_reversed_plan() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + ROW_NUMBER() OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as rn1, + ROW_NUMBER() OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as rn2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // We cannot reverse each window function (ROW_NUMBER is not reversible) + let expected = { + vec![ + "ProjectionExec: expr=[c9@2 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as rn1, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@1 ASC NULLS LAST]", + " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@0 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------+-----+-----+", + "| c9 | rn1 | rn2 |", + "+-----------+-----+-----+", + "| 28774375 | 1 | 100 |", + "| 63044568 | 2 | 99 |", + "| 141047417 | 3 | 98 |", + "| 141680161 | 4 | 97 |", + "| 145294611 | 5 | 96 |", + "+-----------+-----+-----+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC, c1 ASC, c2 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(ORDER BY c9 DESC, c1 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2, + ROW_NUMBER() OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as rn2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // We cannot reverse each window function (ROW_NUMBER is not reversible) + let expected = { + vec![ + "ProjectionExec: expr=[c9@5 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as rn2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@4 DESC]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@2 ASC NULLS LAST,c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+-------------+-------------+-----+", + "| c9 | sum1 | sum2 | rn2 |", + "+------------+-------------+-------------+-----+", + "| 4268716378 | 8498370520 | 24997484146 | 1 |", + "| 4229654142 | 12714811027 | 29012926487 | 2 |", + "| 4216440507 | 16858984380 | 28743001064 | 3 |", + "| 4144173353 | 20935849039 | 28472563256 | 4 |", + "| 4076864659 | 24997484146 | 28118515915 | 5 |", + "+------------+-------------+-------------+-----+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_complex_plan() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_null_cases_csv(&ctx).await?; + let sql = "SELECT + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as a, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as b, + SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as c, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as d, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as e, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as f, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as g, + SUM(c1) OVER (ORDER BY c3) as h, + SUM(c1) OVER (ORDER BY c3 DESC) as i, + SUM(c1) OVER (ORDER BY c3 NULLS first) as j, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first) as k, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last) as l, + SUM(c1) OVER (ORDER BY c3, c2) as m, + SUM(c1) OVER (ORDER BY c3, c1 DESC) as n, + SUM(c1) OVER (ORDER BY c3 DESC, c1) as o, + SUM(c1) OVER (ORDER BY c3, c1 NULLs first) as p, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as a1, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as b1, + SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as c1, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as d1, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as e1, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as f1, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as g1, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as h1, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as j1, + SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as k1, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as l1, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as m1, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as n1, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as o1, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as h11, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as j11, + SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as k11, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as l11, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as m11, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as n11, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as o11 + FROM null_cases + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Unnecessary SortExecs are removed + let expected = { + vec![ + "ProjectionExec: expr=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as a, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as b, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as c, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as d, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@12 as e, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as f, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as g, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as i, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as j, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as k, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@13 as l, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as m, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@17 as n, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@18 as o, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as p, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as a1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as b1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as c1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as d1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@14 as e1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as f1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as g1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as j1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as k1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as l1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@13 as m1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as n1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as o1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as h11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as j11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as k11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as l11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@15 as m11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as n11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as o11]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@5 ASC NULLS LAST,c2@4 ASC NULLS LAST]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@3 DESC,c1@1 ASC NULLS LAST]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@2 ASC NULLS LAST,c1@0 ASC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> { + let config = SessionConfig::new().with_repartition_windows(false); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + SUM(c9) OVER(ORDER BY c1, c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c1@0 ASC,c9@1 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+-------------+-------------+", + "| c9 | sum1 | sum2 |", + "+------------+-------------+-------------+", + "| 4015442341 | 21907044499 | 21907044499 |", + "| 3998790955 | 24576419362 | 24576419362 |", + "| 3959216334 | 23063303501 | 23063303501 |", + "| 3717551163 | 21560567246 | 21560567246 |", + "| 3276123488 | 19815386638 | 19815386638 |", + "+------------+-------------+-------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_partitionby_reversed_plan() -> Result<()> { + let config = SessionConfig::new().with_repartition_windows(false); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c1@0 ASC,c9@1 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+-------------+-------------+", + "| c9 | sum1 | sum2 |", + "+------------+-------------+-------------+", + "| 4015442341 | 8014233296 | 21907044499 |", + "| 3998790955 | 11973449630 | 24576419362 |", + "| 3959216334 | 15691000793 | 23063303501 |", + "| 3717551163 | 18967124281 | 21560567246 |", + "| 3276123488 | 21907044499 | 19815386638 |", + "+------------+-------------+-------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_orderby_reversed_binary_expr() -> Result<()> { + let config = SessionConfig::new().with_repartition_windows(false); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c3, + SUM(c9) OVER(ORDER BY c3+c4 DESC, c9 DESC, c2 ASC) as sum1, + SUM(c9) OVER(ORDER BY c3+c4 ASC, c9 ASC ) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c3@3 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@0 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: \"SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int16(NULL)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: \"SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow }]", + " SortExec: [CAST(c3@1 AS Int16) + c4@2 DESC,c9@3 DESC,c2@0 ASC NULLS LAST]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+-------------+--------------+", + "| c3 | sum1 | sum2 |", + "+-----+-------------+--------------+", + "| -86 | 2861911482 | 222089770060 |", + "| 13 | 5075947208 | 219227858578 |", + "| 125 | 8701233618 | 217013822852 |", + "| 123 | 11293564174 | 213388536442 |", + "| 97 | 14767488750 | 210796205886 |", + "+-----+-------------+--------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_remove_unnecessary_sort_in_sub_query() -> Result<()> { + let config = SessionConfig::new(); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT count(*) as global_count FROM + (SELECT count(*), c1 + FROM aggregate_test_100 + WHERE c13 != 'C2GT5KVyOPZpgKVl110TyZO0NcJ434' + GROUP BY c1 + ORDER BY c1 ) AS a "; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Unnecessary Sort in the sub query is removed + let expected = { + vec![ + "ProjectionExec: expr=[COUNT(UInt8(1))@0 as global_count]", + " AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))]", + " CoalescePartitionsExec", + " AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]", + " RepartitionExec: partitioning=RoundRobinBatch(8)", + " CoalescePartitionsExec", + " AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8)", + " AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))]", + " CoalesceBatchesExec: target_batch_size=4096", + " FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", + " RepartitionExec: partitioning=RoundRobinBatch(8)", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+--------------+", + "| global_count |", + "+--------------+", + "| 5 |", + "+--------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 35790885e02f0..b8274ed1d1930 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -113,6 +113,33 @@ impl WindowFrame { } } } + + /// Get reversed window frame + pub fn reverse(&self) -> Self { + let start_bound = match &self.end_bound { + WindowFrameBound::Preceding(elem) => { + WindowFrameBound::Following(elem.clone()) + } + WindowFrameBound::Following(elem) => { + WindowFrameBound::Preceding(elem.clone()) + } + WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow, + }; + let end_bound = match &self.start_bound { + WindowFrameBound::Preceding(elem) => { + WindowFrameBound::Following(elem.clone()) + } + WindowFrameBound::Following(elem) => { + WindowFrameBound::Preceding(elem.clone()) + } + WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow, + }; + WindowFrame { + units: self.units, + start_bound, + end_bound, + } + } } /// There are five ways to describe starting and ending frame boundaries: diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 4721bf8f2301b..7fceaeedeab3c 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -36,7 +36,7 @@ use crate::expressions::format_state_name; /// COUNT aggregate expression /// Returns the amount of non-null values of the given expression. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Count { name: String, data_type: DataType, @@ -104,6 +104,14 @@ impl AggregateExpr for Count { ) -> Result> { Ok(Box::new(CountRowAccumulator::new(start_index))) } + + fn is_window_fn_reversible(&self) -> bool { + true + } + + fn reverse_expr(&self) -> Result> { + Ok(Arc::new(self.clone())) + } } #[derive(Debug)] diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index f6374687403ec..b0776886e69d5 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -101,4 +101,19 @@ pub trait AggregateExpr: Send + Sync + Debug { self ))) } + + /// Get whether window function is reversible + /// make `true` if `reverse_expr` is implemented + fn is_window_fn_reversible(&self) -> bool { + false + } + + /// Construct Reverse Expression + // Typically expression itself for aggregate functions + fn reverse_expr(&self) -> Result> { + Err(DataFusionError::NotImplemented(format!( + "reverse_expr hasn't been implemented for {:?} yet", + self + ))) + } } diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index b330455a1855b..11371f31d4c68 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -42,7 +42,7 @@ use arrow::compute::cast; use datafusion_row::accessor::RowAccessor; /// SUM aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Sum { name: String, data_type: DataType, @@ -132,6 +132,14 @@ impl AggregateExpr for Sum { self.data_type.clone(), ))) } + + fn is_window_fn_reversible(&self) -> bool { + true + } + + fn reverse_expr(&self) -> Result> { + Ok(Arc::new(self.clone())) + } } #[derive(Debug)] diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 52a43050b1cc8..1f0c286efc859 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -30,6 +30,7 @@ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::WindowFrame; +use crate::window::window_expr::reverse_order_bys; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use crate::{window::WindowExpr, AggregateExpr}; @@ -155,4 +156,17 @@ impl WindowExpr for AggregateWindowExpr { fn get_window_frame(&self) -> &Arc { &self.window_frame } + + fn is_window_fn_reversible(&self) -> bool { + self.aggregate.as_ref().is_window_fn_reversible() + } + + fn get_reversed_expr(&self) -> Result> { + Ok(Arc::new(AggregateWindowExpr::new( + self.aggregate.reverse_expr()?, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + ))) + } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 95bf01608b82f..8e16fc7e7a627 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -20,6 +20,7 @@ use super::window_frame_state::WindowFrameContext; use super::BuiltInWindowFunctionExpr; use super::WindowExpr; +use crate::window::window_expr::reverse_order_bys; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use arrow::compute::{concat, SortOptions}; use arrow::record_batch::RecordBatch; @@ -137,4 +138,17 @@ impl WindowExpr for BuiltInWindowExpr { fn get_window_frame(&self) -> &Arc { &self.window_frame } + + fn is_window_fn_reversible(&self) -> bool { + self.expr.as_ref().is_window_fn_reversible() + } + + fn get_reversed_expr(&self) -> Result> { + Ok(Arc::new(BuiltInWindowExpr::new( + self.expr.reverse_expr()?, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + ))) + } } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 7f7a27435c392..71b100b54e5e4 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -20,7 +20,7 @@ use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{DataFusionError, Result}; use std::any::Any; use std::sync::Arc; @@ -58,4 +58,18 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Create built-in window evaluator with a batch fn create_evaluator(&self) -> Result>; + + /// Get whether window function is reversible + /// make true if `reverse_expr` is implemented + fn is_window_fn_reversible(&self) -> bool { + false + } + + /// Construct Reverse Expression + fn reverse_expr(&self) -> Result> { + Err(DataFusionError::NotImplemented(format!( + "reverse_expr hasn't been implemented for {:?} yet", + self + ))) + } } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index c7fc73b9f1c1f..97076963f60d4 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -107,6 +107,20 @@ impl BuiltInWindowFunctionExpr for WindowShift { default_value: self.default_value.clone(), })) } + + fn is_window_fn_reversible(&self) -> bool { + true + } + + fn reverse_expr(&self) -> Result> { + Ok(Arc::new(Self { + name: self.name.clone(), + data_type: self.data_type.clone(), + shift_offset: -self.shift_offset, + expr: self.expr.clone(), + default_value: self.default_value.clone(), + })) + } } pub(crate) struct WindowShiftEvaluator { diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 63a2354c9e4c5..2942a797b8526 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -123,6 +123,31 @@ impl BuiltInWindowFunctionExpr for NthValue { fn create_evaluator(&self) -> Result> { Ok(Box::new(NthValueEvaluator { kind: self.kind })) } + + fn is_window_fn_reversible(&self) -> bool { + match self.kind { + NthValueKind::First | NthValueKind::Last => true, + NthValueKind::Nth(_) => false, + } + } + + fn reverse_expr(&self) -> Result> { + let reversed_kind = match self.kind { + NthValueKind::First => NthValueKind::Last, + NthValueKind::Last => NthValueKind::First, + NthValueKind::Nth(_) => { + return Err(DataFusionError::Execution( + "Cannot take reverse of NthValue".to_string(), + )) + } + }; + Ok(Arc::new(Self { + name: self.name.clone(), + expr: self.expr.clone(), + data_type: self.data_type.clone(), + kind: reversed_kind, + })) + } } /// Value evaluator for nth_value functions diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 209e0544f2fa6..b1af6cf49a0ab 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -20,7 +20,7 @@ use arrow::compute::kernels::partition::lexicographical_partition_ranges; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{reverse_sort_options, DataFusionError, Result}; use datafusion_expr::WindowFrame; use std::any::Any; use std::fmt::Debug; @@ -129,6 +129,25 @@ pub trait WindowExpr: Send + Sync + Debug { Ok((values, order_bys)) } - // Get window frame of this WindowExpr, None if absent + // Get window frame of this WindowExpr fn get_window_frame(&self) -> &Arc; + + /// Get whether window function can be reversed + fn is_window_fn_reversible(&self) -> bool; + + /// get reversed expression + fn get_reversed_expr(&self) -> Result>; +} + +/// Reverses the ORDER BY expression, which is useful during equivalent window +/// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into +/// 'ORDER BY a DESC, NULLS FIRST'. +pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec { + order_bys + .iter() + .map(|e| PhysicalSortExpr { + expr: e.expr.clone(), + options: reverse_sort_options(e.options), + }) + .collect() } From 343fafbd962183bb7c8d2126b3cdcd3f545e556d Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 13 Dec 2022 16:42:53 +0300 Subject: [PATCH 02/17] move ordering satisfy to the util --- .../src/physical_optimizer/enforcement.rs | 79 ++----------------- 1 file changed, 8 insertions(+), 71 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforcement.rs b/datafusion/core/src/physical_optimizer/enforcement.rs index 3110061c4f2cb..9c7314a1b33ca 100644 --- a/datafusion/core/src/physical_optimizer/enforcement.rs +++ b/datafusion/core/src/physical_optimizer/enforcement.rs @@ -20,6 +20,7 @@ //! use crate::config::OPT_TOP_DOWN_JOIN_KEY_REORDERING; use crate::error::Result; +use crate::physical_optimizer::utils::{add_sort_above_child, ordering_satisfy}; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -29,7 +30,6 @@ use crate::physical_plan::joins::{ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::rewrite::TreeNodeRewritable; -use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort::SortOptions; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::Partitioning; @@ -41,9 +41,8 @@ use datafusion_physical_expr::equivalence::EquivalenceProperties; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::expressions::NoOp; use datafusion_physical_expr::{ - expr_list_eq_strict_order, normalize_expr_with_equivalence_properties, - normalize_sort_expr_with_equivalence_properties, AggregateExpr, PhysicalExpr, - PhysicalSortExpr, + expr_list_eq_strict_order, normalize_expr_with_equivalence_properties, AggregateExpr, + PhysicalExpr, }; use std::collections::HashMap; use std::sync::Arc; @@ -885,14 +884,7 @@ fn ensure_distribution_and_ordering( Ok(child) } else { let sort_expr = required.unwrap().to_vec(); - if child.output_partitioning().partition_count() > 1 { - Ok(Arc::new(SortExec::new_with_partitioning( - sort_expr, child, true, None, - )) as Arc) - } else { - Ok(Arc::new(SortExec::try_new(sort_expr, child, None)?) - as Arc) - } + add_sort_above_child(&child, sort_expr) } }) .collect(); @@ -900,61 +892,6 @@ fn ensure_distribution_and_ordering( with_new_children_if_necessary(plan, new_children?) } -/// Check the required ordering requirements are satisfied by the provided PhysicalSortExprs. -fn ordering_satisfy EquivalenceProperties>( - provided: Option<&[PhysicalSortExpr]>, - required: Option<&[PhysicalSortExpr]>, - equal_properties: F, -) -> bool { - match (provided, required) { - (_, None) => true, - (None, Some(_)) => false, - (Some(provided), Some(required)) => { - if required.len() > provided.len() { - false - } else { - let fast_match = required - .iter() - .zip(provided.iter()) - .all(|(order1, order2)| order1.eq(order2)); - - if !fast_match { - let eq_properties = equal_properties(); - let eq_classes = eq_properties.classes(); - if !eq_classes.is_empty() { - let normalized_required_exprs = required - .iter() - .map(|e| { - normalize_sort_expr_with_equivalence_properties( - e.clone(), - eq_classes, - ) - }) - .collect::>(); - let normalized_provided_exprs = provided - .iter() - .map(|e| { - normalize_sort_expr_with_equivalence_properties( - e.clone(), - eq_classes, - ) - }) - .collect::>(); - normalized_required_exprs - .iter() - .zip(normalized_provided_exprs.iter()) - .all(|(order1, order2)| order1.eq(order2)) - } else { - fast_match - } - } else { - fast_match - } - } - } - } -} - #[derive(Debug, Clone)] struct JoinKeyPairs { left_keys: Vec>, @@ -1034,10 +971,10 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_expr::logical_plan::JoinType; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::binary; - use datafusion_physical_expr::expressions::lit; - use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr::{expressions, PhysicalExpr}; + use datafusion_physical_expr::{ + expressions, expressions::binary, expressions::lit, expressions::Column, + PhysicalExpr, PhysicalSortExpr, + }; use std::ops::Deref; use super::*; From dfb66836b9b328c296ae55b572c5770c09ba4abb Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 13 Dec 2022 17:06:40 +0300 Subject: [PATCH 03/17] update test and change repartition maintain_input_order impl --- datafusion/core/src/physical_plan/repartition.rs | 16 +++++++++++++++- datafusion/core/tests/sql/window.rs | 4 +++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs index 9492fb7497a62..12424c8587971 100644 --- a/datafusion/core/src/physical_plan/repartition.rs +++ b/datafusion/core/src/physical_plan/repartition.rs @@ -289,7 +289,21 @@ impl ExecutionPlan for RepartitionExec { } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None + if self.maintains_input_order() { + self.input().output_ordering() + } else { + None + } + } + + fn maintains_input_order(&self) -> bool { + let n_input = match self.input().output_partitioning() { + Partitioning::RoundRobinBatch(n) => n, + Partitioning::Hash(_, n) => n, + Partitioning::UnknownPartitioning(n) => n, + }; + // We preserve ordering when input partitioning is 1 + n_input <= 1 } fn equivalence_properties(&self) -> EquivalenceProperties { diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 3a44493aaab6d..085628051e17e 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -2119,7 +2119,9 @@ async fn test_window_agg_sort_orderby_reversed_binary_expr() -> Result<()> { #[tokio::test] async fn test_remove_unnecessary_sort_in_sub_query() -> Result<()> { - let config = SessionConfig::new(); + let config = SessionConfig::new() + .with_target_partitions(8) + .with_repartition_windows(true); let ctx = SessionContext::with_config(config); register_aggregate_csv(&ctx).await?; let sql = "SELECT count(*) as global_count FROM From 0a42315520940946cb2a80483439d4dcd46a7a10 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 13 Dec 2022 17:15:48 +0300 Subject: [PATCH 04/17] simplifications --- .../remove_unnecessary_sorts.rs | 86 ++++++++----------- 1 file changed, 37 insertions(+), 49 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 9c2e071dd8b84..322677a30428d 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -88,39 +88,7 @@ fn remove_unnecessary_sorts( ordering_satisfy_inner(physical_ordering, required_ordering, || { child.equivalence_properties() }); - if is_ordering_satisfied { - // can do analysis for sort removal - if !sort_onward.is_empty() { - let (_, sort_any) = sort_onward[0].clone(); - let sort_exec = sort_any - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Plan( - "First layer should start from SortExec".to_string(), - ) - })?; - let sort_output_ordering = sort_exec.output_ordering(); - let sort_input_ordering = sort_exec.input().output_ordering(); - // Do naive analysis, where a SortExec is already sorted according to desired Sorting - if ordering_satisfy( - sort_input_ordering, - sort_output_ordering, - || sort_exec.input().equivalence_properties(), - ) { - update_child_to_remove_unnecessary_sort(child, sort_onward)?; - } else if let Some(window_agg_exec) = - requirements.plan.as_any().downcast_ref::() - { - // For window expressions we can remove some Sorts when expression can be calculated in reverse order also. - if let Some(res) = - analyze_window_sort_removal(window_agg_exec, sort_onward)? - { - return Ok(Some(res)); - } - } - } - } else { + if !is_ordering_satisfied { // During sort Removal we have invalidated ordering invariant fix it // This is effectively moving sort above in the physical plan update_child_to_remove_unnecessary_sort(child, sort_onward)?; @@ -128,6 +96,29 @@ fn remove_unnecessary_sorts( *child = add_sort_above_child(child, sort_expr)?; // Since we have added Sort, we add it to the sort_onwards also. sort_onward.push((idx, child.clone())) + } else if is_ordering_satisfied && !sort_onward.is_empty() { + // can do analysis for sort removal + let (_, sort_any) = sort_onward[0].clone(); + let sort_exec = convert_to_sort_exec(&sort_any)?; + let sort_output_ordering = sort_exec.output_ordering(); + let sort_input_ordering = sort_exec.input().output_ordering(); + // Do naive analysis, where a SortExec is already sorted according to desired Sorting + if ordering_satisfy(sort_input_ordering, sort_output_ordering, || { + sort_exec.input().equivalence_properties() + }) { + update_child_to_remove_unnecessary_sort(child, sort_onward)?; + } else if let Some(window_agg_exec) = + requirements.plan.as_any().downcast_ref::() + { + // For window expressions we can remove some Sorts when expression can be calculated in reverse order also. + if let Some(res) = analyze_window_sort_removal( + window_agg_exec, + sort_exec, + sort_onward, + )? { + return Ok(Some(res)); + } + } } } (Some(required), None) => { @@ -245,19 +236,9 @@ impl TreeNodeRewritable for PlanWithCorrespondingSort { /// Analyzes `WindowAggExec` to determine whether Sort can be removed fn analyze_window_sort_removal( window_agg_exec: &WindowAggExec, + sort_exec: &SortExec, sort_onward: &mut Vec<(usize, Arc)>, ) -> Result> { - // If empty immediately return we cannot do analysis - if sort_onward.is_empty() { - return Ok(None); - } - let (_, sort_any) = sort_onward[0].clone(); - let sort_exec = sort_any - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Plan("First layer should start from SortExec".to_string()) - })?; let required_ordering = sort_exec.output_ordering().ok_or_else(|| { DataFusionError::Plan("SortExec should have output ordering".to_string()) })?; @@ -313,17 +294,24 @@ fn update_child_to_remove_unnecessary_sort( } Ok(()) } -/// Removes the sort from the plan in the `sort_onwards` -fn remove_corresponding_sort_from_sub_plan( - sort_onwards: &mut Vec<(usize, Arc)>, -) -> Result> { - let (sort_child_idx, sort_any) = sort_onwards[0].clone(); + +/// Convert dyn ExecutionPlan to SortExec (Assumes it is SortExec) +fn convert_to_sort_exec(sort_any: &Arc) -> Result<&SortExec> { let sort_exec = sort_any .as_any() .downcast_ref::() .ok_or_else(|| { DataFusionError::Plan("First layer should start from SortExec".to_string()) })?; + Ok(sort_exec) +} + +/// Removes the sort from the plan in the `sort_onwards` +fn remove_corresponding_sort_from_sub_plan( + sort_onwards: &mut Vec<(usize, Arc)>, +) -> Result> { + let (sort_child_idx, sort_any) = sort_onwards[0].clone(); + let sort_exec = convert_to_sort_exec(&sort_any)?; let mut prev_layer = sort_exec.input().clone(); let mut prev_layer_child_idx = sort_child_idx; // We start from 1 hence since first one is sort and we are removing it From c2a1593cd4fb4c8f831a7f40f6e6dd0c53acb695 Mon Sep 17 00:00:00 2001 From: Mustafa akur <106137913+mustafasrepo@users.noreply.github.com> Date: Thu, 15 Dec 2022 15:10:19 +0300 Subject: [PATCH 05/17] partition by refactor (#28) * partition by refactor * minor changes * Unnecessary tuple to Range conversion is removed * move transpose under common --- datafusion/common/src/lib.rs | 12 +++ .../remove_unnecessary_sorts.rs | 25 +++--- .../physical_plan/windows/window_agg_exec.rs | 79 +++++++++++++++- datafusion/core/tests/sql/window.rs | 54 +++++++++++ .../physical-expr/src/window/aggregate.rs | 89 +++++++++---------- .../physical-expr/src/window/built_in.rs | 55 +++++------- .../physical-expr/src/window/cume_dist.rs | 25 +++--- .../physical-expr/src/window/lead_lag.rs | 17 ++-- .../src/window/partition_evaluator.rs | 55 +----------- datafusion/physical-expr/src/window/rank.rs | 28 +++--- .../physical-expr/src/window/row_number.rs | 18 ++-- .../physical-expr/src/window/window_expr.rs | 20 +---- .../src/window/window_frame_state.rs | 15 ++-- 13 files changed, 262 insertions(+), 230 deletions(-) diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index c6812ed24f8ae..63683f5af0242 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -73,3 +73,15 @@ pub fn reverse_sort_options(options: SortOptions) -> SortOptions { nulls_first: !options.nulls_first, } } + +/// Transposes 2d vector +pub fn transpose(original: Vec>) -> Vec> { + assert!(!original.is_empty()); + let mut transposed = (0..original[0].len()).map(|_| vec![]).collect::>(); + for original_row in original { + for (item, transposed_row) in original_row.into_iter().zip(&mut transposed) { + transposed_row.push(item); + } + } + transposed +} diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 322677a30428d..ccaf93e0e60d3 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -364,16 +364,15 @@ pub fn can_skip_sort( .iter() .filter(|elem| elem.is_partition) .collect::>(); - let (can_skip_partition_bys, should_reverse_partition_bys) = - if partition_by_sections.is_empty() { - (true, false) - } else { - let first_reverse = partition_by_sections[0].reverse; - let can_skip_partition_bys = partition_by_sections - .iter() - .all(|c| c.is_aligned && c.reverse == first_reverse); - (can_skip_partition_bys, first_reverse) - }; + let can_skip_partition_bys = if partition_by_sections.is_empty() { + true + } else { + let first_reverse = partition_by_sections[0].reverse; + let can_skip_partition_bys = partition_by_sections + .iter() + .all(|c| c.is_aligned && c.reverse == first_reverse); + can_skip_partition_bys + }; let order_by_sections = col_infos .iter() .filter(|elem| !elem.is_partition) @@ -387,11 +386,7 @@ pub fn can_skip_sort( .all(|c| c.is_aligned && c.reverse == first_reverse); (can_skip_order_bys, first_reverse) }; - // TODO: We cannot skip partition by keys when sort direction is reversed, - // by propogating partition by sort direction to `WindowAggExec` we can skip - // these columns also. Add support for that (Use direction during partition range calculation). - let can_skip = - can_skip_order_bys && can_skip_partition_bys && !should_reverse_partition_bys; + let can_skip = can_skip_order_bys && can_skip_partition_bys; Ok((can_skip, should_reverse_order_bys)) } diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index 914e3e71dbad7..837f32ac69b53 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -28,19 +28,23 @@ use crate::physical_plan::{ ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, }; -use arrow::compute::concat_batches; +use arrow::compute::{ + concat, concat_batches, lexicographical_partition_ranges, SortColumn, +}; use arrow::{ array::ArrayRef, datatypes::{Schema, SchemaRef}, error::{ArrowError, Result as ArrowResult}, record_batch::RecordBatch, }; +use datafusion_common::{transpose, DataFusionError}; use datafusion_physical_expr::rewrite::TreeNodeRewritable; use datafusion_physical_expr::EquivalentClass; use futures::stream::Stream; use futures::{ready, StreamExt}; use log::debug; use std::any::Any; +use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -131,6 +135,25 @@ impl WindowAggExec { pub fn input_schema(&self) -> SchemaRef { self.input_schema.clone() } + + /// Get Partition Columns + pub fn partition_by_sort_keys(&self) -> Result> { + // All window exprs have same partition by hance we just use first one + let partition_by = self.window_expr()[0].partition_by(); + let mut partition_columns = vec![]; + for elem in partition_by { + if let Some(sort_keys) = &self.sort_keys { + for a in sort_keys { + if a.expr.eq(elem) { + partition_columns.push(a.clone()); + break; + } + } + } + } + assert_eq!(partition_by.len(), partition_columns.len()); + Ok(partition_columns) + } } impl ExecutionPlan for WindowAggExec { @@ -253,6 +276,7 @@ impl ExecutionPlan for WindowAggExec { self.window_expr.clone(), input, BaselineMetrics::new(&self.metrics, partition), + self.partition_by_sort_keys()?, )); Ok(stream) } @@ -337,6 +361,7 @@ pub struct WindowAggStream { batches: Vec, finished: bool, window_expr: Vec>, + partition_by_sort_keys: Vec, baseline_metrics: BaselineMetrics, } @@ -347,6 +372,7 @@ impl WindowAggStream { window_expr: Vec>, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, + partition_by_sort_keys: Vec, ) -> Self { Self { schema, @@ -355,6 +381,7 @@ impl WindowAggStream { finished: false, window_expr, baseline_metrics, + partition_by_sort_keys, } } @@ -369,8 +396,27 @@ impl WindowAggStream { let batch = concat_batches(&self.input.schema(), &self.batches)?; // calculate window cols - let mut columns = compute_window_aggregates(&self.window_expr, &batch) - .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + let partition_columns = self.partition_columns(&batch)?; + let partition_points = + self.evaluate_partition_points(batch.num_rows(), &partition_columns)?; + + let mut partition_results = vec![]; + for partition_point in partition_points { + let length = partition_point.end - partition_point.start; + partition_results.push( + compute_window_aggregates( + &self.window_expr, + &batch.slice(partition_point.start, length), + ) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?, + ) + } + let mut columns = transpose(partition_results) + .iter() + .map(|elems| concat(&elems.iter().map(|x| x.as_ref()).collect::>())) + .collect::>() + .into_iter() + .collect::>>()?; // combine with the original cols // note the setup of window aggregates is that they newly calculated window @@ -378,6 +424,33 @@ impl WindowAggStream { columns.extend_from_slice(batch.columns()); RecordBatch::try_new(self.schema.clone(), columns) } + + /// Get Partition Columns + pub fn partition_columns(&self, batch: &RecordBatch) -> Result> { + self.partition_by_sort_keys + .iter() + .map(|elem| elem.evaluate_to_sort_column(batch)) + .collect::>>() + } + + /// evaluate the partition points given the sort columns; if the sort columns are + /// empty then the result will be a single element vec of the whole column rows. + fn evaluate_partition_points( + &self, + num_rows: usize, + partition_columns: &[SortColumn], + ) -> Result>> { + if partition_columns.is_empty() { + Ok(vec![Range { + start: 0, + end: num_rows, + }]) + } else { + Ok(lexicographical_partition_ranges(partition_columns) + .map_err(DataFusionError::ArrowError)? + .collect::>()) + } + } } impl Stream for WindowAggStream { diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 085628051e17e..a5bd6a3b97c8e 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -2177,3 +2177,57 @@ async fn test_remove_unnecessary_sort_in_sub_query() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_window_agg_sort_orderby_reversed_partitionby_reversed_plan() -> Result<()> { + let config = SessionConfig::new().with_repartition_windows(false); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c3, + SUM(c9) OVER(ORDER BY c3 DESC, c9 DESC, c2 ASC) as sum1, + SUM(c9) OVER(PARTITION BY c3 ORDER BY c9 DESC ) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c3@3 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@0 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@1 DESC,c9@2 DESC,c2@0 ASC NULLS LAST]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+-------------+------------+", + "| c3 | sum1 | sum2 |", + "+-----+-------------+------------+", + "| 125 | 3625286410 | 3625286410 |", + "| 123 | 7192027599 | 3566741189 |", + "| 123 | 9784358155 | 6159071745 |", + "| 122 | 13845993262 | 4061635107 |", + "| 120 | 16676974334 | 2830981072 |", + "+-----+-------------+------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 1f0c286efc859..1268559fe5688 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::iter::IntoIterator; +use std::ops::Range; use std::sync::Arc; use arrow::array::Array; @@ -90,58 +91,50 @@ impl WindowExpr for AggregateWindowExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let partition_columns = self.partition_columns(batch)?; - let partition_points = - self.evaluate_partition_points(batch.num_rows(), &partition_columns)?; let sort_options: Vec = self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results: Vec = vec![]; - for partition_range in &partition_points { - let mut accumulator = self.aggregate.create_accumulator()?; - let length = partition_range.end - partition_range.start; - let (values, order_bys) = - self.get_values_orderbys(&batch.slice(partition_range.start, length))?; - - let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); - let mut last_range: (usize, usize) = (0, 0); - - // We iterate on each row to perform a running calculation. - // First, cur_range is calculated, then it is compared with last_range. - for i in 0..length { - let cur_range = window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - length, - i, - )?; - let value = if cur_range.0 == cur_range.1 { - // We produce None if the window is empty. - ScalarValue::try_from(self.aggregate.field()?.data_type())? - } else { - // Accumulate any new rows that have entered the window: - let update_bound = cur_range.1 - last_range.1; - if update_bound > 0 { - let update: Vec = values - .iter() - .map(|v| v.slice(last_range.1, update_bound)) - .collect(); - accumulator.update_batch(&update)? - } - // Remove rows that have now left the window: - let retract_bound = cur_range.0 - last_range.0; - if retract_bound > 0 { - let retract: Vec = values - .iter() - .map(|v| v.slice(last_range.0, retract_bound)) - .collect(); - accumulator.retract_batch(&retract)? - } - accumulator.evaluate()? - }; - row_wise_results.push(value); - last_range = cur_range; - } + + let mut accumulator = self.aggregate.create_accumulator()?; + let length = batch.num_rows(); + let (values, order_bys) = self.get_values_orderbys(batch)?; + + let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + let mut last_range = Range { start: 0, end: 0 }; + + // We iterate on each row to perform a running calculation. + // First, cur_range is calculated, then it is compared with last_range. + for i in 0..length { + let cur_range = + window_frame_ctx.calculate_range(&order_bys, &sort_options, length, i)?; + let value = if cur_range.end == cur_range.start { + // We produce None if the window is empty. + ScalarValue::try_from(self.aggregate.field()?.data_type())? + } else { + // Accumulate any new rows that have entered the window: + let update_bound = cur_range.end - last_range.end; + if update_bound > 0 { + let update: Vec = values + .iter() + .map(|v| v.slice(last_range.end, update_bound)) + .collect(); + accumulator.update_batch(&update)? + } + // Remove rows that have now left the window: + let retract_bound = cur_range.start - last_range.start; + if retract_bound > 0 { + let retract: Vec = values + .iter() + .map(|v| v.slice(last_range.start, retract_bound)) + .collect(); + accumulator.retract_batch(&retract)? + } + accumulator.evaluate()? + }; + row_wise_results.push(value); + last_range = cur_range; } + ScalarValue::iter_to_array(row_wise_results.into_iter()) } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 8e16fc7e7a627..0b1e5ee8f19cf 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -22,14 +22,13 @@ use super::BuiltInWindowFunctionExpr; use super::WindowExpr; use crate::window::window_expr::reverse_order_bys; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; -use arrow::compute::{concat, SortOptions}; +use arrow::compute::SortOptions; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::DataFusionError; use datafusion_common::Result; +use datafusion_common::ScalarValue; use datafusion_expr::WindowFrame; use std::any::Any; -use std::ops::Range; use std::sync::Arc; /// A window expr that takes the form of a built in window function @@ -92,47 +91,35 @@ impl WindowExpr for BuiltInWindowExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let evaluator = self.expr.create_evaluator()?; let num_rows = batch.num_rows(); - let partition_columns = self.partition_columns(batch)?; - let partition_points = - self.evaluate_partition_points(num_rows, &partition_columns)?; - - let results = if evaluator.uses_window_frame() { + if evaluator.uses_window_frame() { let sort_options: Vec = self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results = vec![]; - for partition_range in &partition_points { - let length = partition_range.end - partition_range.start; - let (values, order_bys) = self - .get_values_orderbys(&batch.slice(partition_range.start, length))?; - let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); - // We iterate on each row to calculate window frame range and and window function result - for idx in 0..length { - let range = window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - num_rows, - idx, - )?; - let range = Range { - start: range.0, - end: range.1, - }; - let value = evaluator.evaluate_inside_range(&values, range)?; - row_wise_results.push(value.to_array()); - } + + let length = batch.num_rows(); + let (values, order_bys) = self.get_values_orderbys(batch)?; + let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + // We iterate on each row to calculate window frame range and and window function result + for idx in 0..length { + let range = window_frame_ctx.calculate_range( + &order_bys, + &sort_options, + num_rows, + idx, + )?; + let value = evaluator.evaluate_inside_range(&values, range)?; + row_wise_results.push(value); } - row_wise_results + ScalarValue::iter_to_array(row_wise_results.into_iter()) } else if evaluator.include_rank() { let columns = self.sort_columns(batch)?; let sort_partition_points = self.evaluate_partition_points(num_rows, &columns)?; - evaluator.evaluate_with_rank(partition_points, sort_partition_points)? + evaluator.evaluate_with_rank(num_rows, &sort_partition_points) } else { let (values, _) = self.get_values_orderbys(batch)?; - evaluator.evaluate(&values, partition_points)? - }; - let results = results.iter().map(|i| i.as_ref()).collect::>(); - concat(&results).map_err(DataFusionError::ArrowError) + evaluator.evaluate(&values, num_rows) + } } fn get_window_frame(&self) -> &Arc { diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs index 4202058a3c5a1..45fe51178afcd 100644 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ b/datafusion/physical-expr/src/window/cume_dist.rs @@ -73,19 +73,19 @@ impl PartitionEvaluator for CumeDistEvaluator { true } - fn evaluate_partition_with_rank( + fn evaluate_with_rank( &self, - partition: Range, + num_rows: usize, ranks_in_partition: &[Range], ) -> Result { - let scaler = (partition.end - partition.start) as f64; + let scalar = num_rows as f64; let result = Float64Array::from_iter_values( ranks_in_partition .iter() .scan(0_u64, |acc, range| { let len = range.end - range.start; *acc += len as u64; - let value: f64 = (*acc as f64) / scaler; + let value: f64 = (*acc as f64) / scalar; let result = iter::repeat(value).take(len); Some(result) }) @@ -102,15 +102,14 @@ mod tests { fn test_i32_result( expr: &CumeDist, - partition: Range, + num_rows: usize, ranks: Vec>, expected: Vec, ) -> Result<()> { let result = expr .create_evaluator()? - .evaluate_with_rank(vec![partition], ranks)?; - assert_eq!(1, result.len()); - let result = as_float64_array(&result[0])?; + .evaluate_with_rank(num_rows, &ranks)?; + let result = as_float64_array(&result)?; let result = result.values(); assert_eq!(expected, result); Ok(()) @@ -121,19 +120,19 @@ mod tests { let r = cume_dist("arr".into()); let expected = vec![0.0; 0]; - test_i32_result(&r, 0..0, vec![], expected)?; + test_i32_result(&r, 0, vec![], expected)?; let expected = vec![1.0; 1]; - test_i32_result(&r, 0..1, vec![0..1], expected)?; + test_i32_result(&r, 1, vec![0..1], expected)?; let expected = vec![1.0; 2]; - test_i32_result(&r, 0..2, vec![0..2], expected)?; + test_i32_result(&r, 2, vec![0..2], expected)?; let expected = vec![0.5, 0.5, 1.0, 1.0]; - test_i32_result(&r, 0..4, vec![0..2, 2..4], expected)?; + test_i32_result(&r, 4, vec![0..2, 2..4], expected)?; let expected = vec![0.25, 0.5, 0.75, 1.0]; - test_i32_result(&r, 0..4, vec![0..1, 1..2, 2..3, 3..4], expected)?; + test_i32_result(&r, 4, vec![0..1, 1..2, 2..3, 3..4], expected)?; Ok(()) } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 97076963f60d4..f4c176262ae46 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -28,7 +28,6 @@ use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use std::any::Any; use std::ops::Neg; -use std::ops::Range; use std::sync::Arc; /// window shift expression @@ -178,15 +177,10 @@ fn shift_with_default_value( } impl PartitionEvaluator for WindowShiftEvaluator { - fn evaluate_partition( - &self, - values: &[ArrayRef], - partition: Range, - ) -> Result { + fn evaluate(&self, values: &[ArrayRef], _num_rows: usize) -> Result { // LEAD, LAG window functions take single column, values will have size 1 let value = &values[0]; - let value = value.slice(partition.start, partition.end - partition.start); - shift_with_default_value(&value, self.shift_offset, self.default_value.as_ref()) + shift_with_default_value(value, self.shift_offset, self.default_value.as_ref()) } } @@ -205,9 +199,10 @@ mod tests { let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; let values = expr.evaluate_args(&batch)?; - let result = expr.create_evaluator()?.evaluate(&values, vec![0..8])?; - assert_eq!(1, result.len()); - let result = as_int32_array(&result[0])?; + let result = expr + .create_evaluator()? + .evaluate(&values, batch.num_rows())?; + let result = as_int32_array(&result)?; assert_eq!(expected, *result); Ok(()) } diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index 1608758d61b38..86500441df5bc 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -22,23 +22,6 @@ use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; use std::ops::Range; -/// Given a partition range, and the full list of sort partition points, given that the sort -/// partition points are sorted using [partition columns..., order columns...], the split -/// boundaries would align (what's sorted on [partition columns...] would definitely be sorted -/// on finer columns), so this will use binary search to find ranges that are within the -/// partition range and return the valid slice. -pub(crate) fn find_ranges_in_range<'a>( - partition_range: &Range, - sort_partition_points: &'a [Range], -) -> &'a [Range] { - let start_idx = sort_partition_points - .partition_point(|sort_range| sort_range.start < partition_range.start); - let end_idx = start_idx - + sort_partition_points[start_idx..] - .partition_point(|sort_range| sort_range.end <= partition_range.end); - &sort_partition_points[start_idx..end_idx] -} - /// Partition evaluator pub trait PartitionEvaluator { /// Whether the evaluator should be evaluated with rank @@ -50,49 +33,17 @@ pub trait PartitionEvaluator { false } - /// evaluate the partition evaluator against the partitions - fn evaluate( - &self, - values: &[ArrayRef], - partition_points: Vec>, - ) -> Result> { - partition_points - .into_iter() - .map(|partition| self.evaluate_partition(values, partition)) - .collect() - } - - /// evaluate the partition evaluator against the partitions with rank information - fn evaluate_with_rank( - &self, - partition_points: Vec>, - sort_partition_points: Vec>, - ) -> Result> { - partition_points - .into_iter() - .map(|partition| { - let ranks_in_partition = - find_ranges_in_range(&partition, &sort_partition_points); - self.evaluate_partition_with_rank(partition, ranks_in_partition) - }) - .collect() - } - /// evaluate the partition evaluator against the partition - fn evaluate_partition( - &self, - _values: &[ArrayRef], - _partition: Range, - ) -> Result { + fn evaluate(&self, _values: &[ArrayRef], _num_rows: usize) -> Result { Err(DataFusionError::NotImplemented( "evaluate_partition is not implemented by default".into(), )) } /// evaluate the partition evaluator against the partition but with rank - fn evaluate_partition_with_rank( + fn evaluate_with_rank( &self, - _partition: Range, + _num_rows: usize, _ranks_in_partition: &[Range], ) -> Result { Err(DataFusionError::NotImplemented( diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 8ed0319a10b0f..87e01528de5a8 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -114,9 +114,9 @@ impl PartitionEvaluator for RankEvaluator { true } - fn evaluate_partition_with_rank( + fn evaluate_with_rank( &self, - partition: Range, + num_rows: usize, ranks_in_partition: &[Range], ) -> Result { // see https://www.postgresql.org/docs/current/functions-window.html @@ -132,7 +132,7 @@ impl PartitionEvaluator for RankEvaluator { )), RankType::Percent => { // Returns the relative rank of the current row, that is (rank - 1) / (total partition rows - 1). The value thus ranges from 0 to 1 inclusive. - let denominator = (partition.end - partition.start) as f64; + let denominator = num_rows as f64; Arc::new(Float64Array::from_iter_values( ranks_in_partition .iter() @@ -177,15 +177,14 @@ mod tests { fn test_f64_result( expr: &Rank, - range: Range, + num_rows: usize, ranks: Vec>, expected: Vec, ) -> Result<()> { let result = expr .create_evaluator()? - .evaluate_with_rank(vec![range], ranks)?; - assert_eq!(1, result.len()); - let result = as_float64_array(&result[0])?; + .evaluate_with_rank(num_rows, &ranks)?; + let result = as_float64_array(&result)?; let result = result.values(); assert_eq!(expected, result); Ok(()) @@ -196,11 +195,8 @@ mod tests { ranks: Vec>, expected: Vec, ) -> Result<()> { - let result = expr - .create_evaluator()? - .evaluate_with_rank(vec![0..8], ranks)?; - assert_eq!(1, result.len()); - let result = as_uint64_array(&result[0])?; + let result = expr.create_evaluator()?.evaluate_with_rank(8, &ranks)?; + let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(expected, result); Ok(()) @@ -228,19 +224,19 @@ mod tests { // empty case let expected = vec![0.0; 0]; - test_f64_result(&r, 0..0, vec![0..0; 0], expected)?; + test_f64_result(&r, 0, vec![0..0; 0], expected)?; // singleton case let expected = vec![0.0]; - test_f64_result(&r, 0..1, vec![0..1], expected)?; + test_f64_result(&r, 1, vec![0..1], expected)?; // uniform case let expected = vec![0.0; 7]; - test_f64_result(&r, 0..7, vec![0..7], expected)?; + test_f64_result(&r, 7, vec![0..7], expected)?; // non-trivial case let expected = vec![0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5]; - test_f64_result(&r, 0..7, vec![0..3, 3..7], expected)?; + test_f64_result(&r, 7, vec![0..3, 3..7], expected)?; Ok(()) } diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index f70d9ea379dd7..b27ac29d27640 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -24,7 +24,6 @@ use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; use std::any::Any; -use std::ops::Range; use std::sync::Arc; /// row_number expression @@ -69,12 +68,7 @@ impl BuiltInWindowFunctionExpr for RowNumber { pub(crate) struct NumRowsEvaluator {} impl PartitionEvaluator for NumRowsEvaluator { - fn evaluate_partition( - &self, - _values: &[ArrayRef], - partition: Range, - ) -> Result { - let num_rows = partition.end - partition.start; + fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result { Ok(Arc::new(UInt64Array::from_iter_values( 1..(num_rows as u64) + 1, ))) @@ -99,9 +93,8 @@ mod tests { let values = row_number.evaluate_args(&batch)?; let result = row_number .create_evaluator()? - .evaluate(&values, vec![0..8])?; - assert_eq!(1, result.len()); - let result = as_uint64_array(&result[0])?; + .evaluate(&values, batch.num_rows())?; + let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result); Ok(()) @@ -118,9 +111,8 @@ mod tests { let values = row_number.evaluate_args(&batch)?; let result = row_number .create_evaluator()? - .evaluate(&values, vec![0..8])?; - assert_eq!(1, result.len()); - let result = as_uint64_array(&result[0])?; + .evaluate(&values, batch.num_rows())?; + let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result); Ok(()) diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index b1af6cf49a0ab..bc35dd49b50d4 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -17,7 +17,7 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::compute::kernels::partition::lexicographical_partition_ranges; -use arrow::compute::kernels::sort::{SortColumn, SortOptions}; +use arrow::compute::kernels::sort::SortColumn; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{reverse_sort_options, DataFusionError, Result}; @@ -86,20 +86,6 @@ pub trait WindowExpr: Send + Sync + Debug { /// expressions that's from the window function's order by clause, empty if absent fn order_by(&self) -> &[PhysicalSortExpr]; - /// get partition columns that can be used for partitioning, empty if absent - fn partition_columns(&self, batch: &RecordBatch) -> Result> { - self.partition_by() - .iter() - .map(|expr| { - PhysicalSortExpr { - expr: expr.clone(), - options: SortOptions::default(), - } - .evaluate_to_sort_column(batch) - }) - .collect() - } - /// get order by columns, empty if absent fn order_by_columns(&self, batch: &RecordBatch) -> Result> { self.order_by() @@ -110,10 +96,8 @@ pub trait WindowExpr: Send + Sync + Debug { /// get sort columns that can be used for peer evaluation, empty if absent fn sort_columns(&self, batch: &RecordBatch) -> Result> { - let mut sort_columns = self.partition_columns(batch)?; let order_by_columns = self.order_by_columns(batch)?; - sort_columns.extend(order_by_columns); - Ok(sort_columns) + Ok(order_by_columns) } /// Get values columns(argument of Window Function) diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 307ea91440df9..b49bd3a22a78f 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -26,6 +26,7 @@ use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::cmp::min; use std::collections::VecDeque; use std::fmt::Debug; +use std::ops::Range; use std::sync::Arc; /// This object stores the window frame state for use in incremental calculations. @@ -68,7 +69,7 @@ impl<'a> WindowFrameContext<'a> { sort_options: &[SortOptions], length: usize, idx: usize, - ) -> Result<(usize, usize)> { + ) -> Result> { match *self { WindowFrameContext::Rows(window_frame) => { Self::calculate_range_rows(window_frame, length, idx) @@ -99,7 +100,7 @@ impl<'a> WindowFrameContext<'a> { window_frame: &Arc, length: usize, idx: usize, - ) -> Result<(usize, usize)> { + ) -> Result> { let start = match window_frame.start_bound { // UNBOUNDED PRECEDING WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0, @@ -152,7 +153,7 @@ impl<'a> WindowFrameContext<'a> { return Err(DataFusionError::Internal("Rows should be Uint".to_string())) } }; - Ok((start, end)) + Ok(Range { start, end }) } } @@ -171,7 +172,7 @@ impl WindowFrameStateRange { sort_options: &[SortOptions], length: usize, idx: usize, - ) -> Result<(usize, usize)> { + ) -> Result> { let start = match window_frame.start_bound { WindowFrameBound::Preceding(ref n) => { if n.is_null() { @@ -240,7 +241,7 @@ impl WindowFrameStateRange { } } }; - Ok((start, end)) + Ok(Range { start, end }) } /// This function does the heavy lifting when finding range boundaries. It is meant to be @@ -333,7 +334,7 @@ impl WindowFrameStateGroups { range_columns: &[ArrayRef], length: usize, idx: usize, - ) -> Result<(usize, usize)> { + ) -> Result> { if range_columns.is_empty() { return Err(DataFusionError::Execution( "GROUPS mode requires an ORDER BY clause".to_string(), @@ -399,7 +400,7 @@ impl WindowFrameStateGroups { )) } }; - Ok((start, end)) + Ok(Range { start, end }) } /// This function does the heavy lifting when finding group boundaries. It is meant to be From bf7bd11c9c822c77beb1010576c34d66b8ee6035 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 15 Dec 2022 16:06:03 +0300 Subject: [PATCH 06/17] Add naive sort removal rule --- .../remove_unnecessary_sorts.rs | 40 +++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index ccaf93e0e60d3..4c38f1f1e3501 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -73,6 +73,11 @@ impl PhysicalOptimizerRule for RemoveUnnecessarySorts { fn remove_unnecessary_sorts( requirements: PlanWithCorrespondingSort, ) -> Result> { + // Do analysis of naive SortRemoval at the beginning + // Remove Sorts that are already satisfied + if let Some(res) = analyze_immediate_sort_removal(&requirements)? { + return Ok(Some(res)); + } let mut new_children = requirements.plan.children().clone(); let mut new_sort_onwards = requirements.sort_onwards.clone(); for (idx, (child, sort_onward)) in new_children @@ -90,7 +95,6 @@ fn remove_unnecessary_sorts( }); if !is_ordering_satisfied { // During sort Removal we have invalidated ordering invariant fix it - // This is effectively moving sort above in the physical plan update_child_to_remove_unnecessary_sort(child, sort_onward)?; let sort_expr = required_ordering.to_vec(); *child = add_sort_above_child(child, sort_expr)?; @@ -144,10 +148,10 @@ fn remove_unnecessary_sorts( .enumerate() .take(new_plan.children().len()) { - let is_require_ordering = new_plan.required_input_ordering()[idx].is_none(); + let requires_ordering = new_plan.required_input_ordering()[idx].is_some(); //TODO: when `maintains_input_order` returns `Vec` use corresponding index if new_plan.maintains_input_order() - && is_require_ordering + && !requires_ordering && !new_sort_onward.is_empty() { new_sort_onward.push((idx, new_plan.clone())); @@ -155,6 +159,8 @@ fn remove_unnecessary_sorts( new_sort_onward.clear(); new_sort_onward.push((idx, new_plan.clone())); } else { + // These executors use SortExec, hence doesn't propagate + // sort above in the physical plan new_sort_onward.clear(); } } @@ -233,6 +239,33 @@ impl TreeNodeRewritable for PlanWithCorrespondingSort { } } +/// Analyzes `SortExec` to determine whether this Sort can be removed +fn analyze_immediate_sort_removal( + requirements: &PlanWithCorrespondingSort, +) -> Result> { + if let Some(sort_exec) = requirements.plan.as_any().downcast_ref::() { + if ordering_satisfy( + sort_exec.input().output_ordering(), + sort_exec.output_ordering(), + || sort_exec.input().equivalence_properties(), + ) { + // This sort is unnecessary we should remove it + let new_plan = sort_exec.input(); + // Since we know that Sort have exactly one child we can use first index safely + assert_eq!(requirements.sort_onwards.len(), 1); + let mut new_sort_onward = requirements.sort_onwards[0].to_vec(); + if !new_sort_onward.is_empty() { + new_sort_onward.pop(); + } + return Ok(Some(PlanWithCorrespondingSort { + plan: new_plan.clone(), + sort_onwards: vec![new_sort_onward], + })); + } + } + Ok(None) +} + /// Analyzes `WindowAggExec` to determine whether Sort can be removed fn analyze_window_sort_removal( window_agg_exec: &WindowAggExec, @@ -284,6 +317,7 @@ fn analyze_window_sort_removal( Ok(None) } } + /// Updates child such that unnecessary sorting below it is removed fn update_child_to_remove_unnecessary_sort( child: &mut Arc, From 4cb7258fb49c6b911ae652b436fcdf1e8797ea38 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 15 Dec 2022 17:07:56 +0300 Subject: [PATCH 07/17] Add todo for finer Sort removal handling --- .../core/src/physical_optimizer/remove_unnecessary_sorts.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 4c38f1f1e3501..e14a611db385a 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -106,6 +106,10 @@ fn remove_unnecessary_sorts( let sort_exec = convert_to_sort_exec(&sort_any)?; let sort_output_ordering = sort_exec.output_ordering(); let sort_input_ordering = sort_exec.input().output_ordering(); + // TODO: Once we can ensure required ordering propagates to above without changes + // (or with changes trackable) compare `sort_input_ordering` and and `required_ordering` + // this changes will enable us to remove (a,b) -> Sort -> (a,b,c) -> Required(a,b) Sort + // from the plan. With current implementation we cannot remove Sort from above configuration. // Do naive analysis, where a SortExec is already sorted according to desired Sorting if ordering_satisfy(sort_input_ordering, sort_output_ordering, || { sort_exec.input().equivalence_properties() From aa4f7393d1840bb0586c9a08f91b343d3f03ed14 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Mon, 19 Dec 2022 15:14:32 +0300 Subject: [PATCH 08/17] Refactors to improve readability and reduce nesting --- datafusion/common/src/lib.rs | 20 +++--- datafusion/core/src/execution/context.rs | 7 +- .../remove_unnecessary_sorts.rs | 11 +-- .../core/src/physical_optimizer/utils.rs | 67 ++++++++----------- .../core/src/physical_plan/repartition.rs | 2 +- .../physical_plan/windows/window_agg_exec.rs | 62 +++++++++-------- 6 files changed, 80 insertions(+), 89 deletions(-) diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 63683f5af0242..9911b1499a7b8 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -65,7 +65,7 @@ macro_rules! downcast_value { }}; } -/// Compute the "reverse" of given `SortOptions`. +/// Computes the "reverse" of given `SortOptions`. // TODO: If/when arrow supports `!` for `SortOptions`, we can remove this. pub fn reverse_sort_options(options: SortOptions) -> SortOptions { SortOptions { @@ -74,14 +74,18 @@ pub fn reverse_sort_options(options: SortOptions) -> SortOptions { } } -/// Transposes 2d vector +/// Transposes the given vector of vectors. pub fn transpose(original: Vec>) -> Vec> { - assert!(!original.is_empty()); - let mut transposed = (0..original[0].len()).map(|_| vec![]).collect::>(); - for original_row in original { - for (item, transposed_row) in original_row.into_iter().zip(&mut transposed) { - transposed_row.push(item); + match original.as_slice() { + [] => vec![], + [first, ..] => { + let mut result = (0..first.len()).map(|_| vec![]).collect::>(); + for row in original { + for (item, transposed_row) in row.into_iter().zip(&mut result) { + transposed_row.push(item); + } + } + result } } - transposed } diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 00de69eab7540..c503caca2edc9 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -1592,9 +1592,10 @@ impl SessionState { // To make sure the SinglePartition is satisfied, run the BasicEnforcement again, originally it was the AddCoalescePartitionsExec here. physical_optimizers.push(Arc::new(BasicEnforcement::new())); - // `BasicEnforcement` stage conservatively inserts `SortExec`s before `WindowAggExec`s without - // a deep analysis of window frames. Such analysis may sometimes reveal that a `SortExec` is - // actually unnecessary. The rule below performs this analysis and removes such `SortExec`s. + // `BasicEnforcement` stage conservatively inserts `SortExec`s to satisfy ordering requirements. + // However, a deeper analysis may sometimes reveal that such a `SortExec` is actually unnecessary. + // These cases typically arise when we have reversible `WindowAggExec`s or deep subqueries. The + // rule below performs this analysis and removes unnecessary `SortExec`s. physical_optimizers.push(Arc::new(RemoveUnnecessarySorts::new())); SessionState { diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index e14a611db385a..2f4e326c49f02 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -20,7 +20,7 @@ //! valid, or invalid physical plans (in terms of Sorting requirement) use crate::error::Result; use crate::physical_optimizer::utils::{ - add_sort_above_child, ordering_satisfy, ordering_satisfy_inner, + add_sort_above_child, ordering_satisfy, ordering_satisfy_concrete, }; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::rewrite::TreeNodeRewritable; @@ -89,10 +89,11 @@ fn remove_unnecessary_sorts( let physical_ordering = child.output_ordering(); match (required_ordering, physical_ordering) { (Some(required_ordering), Some(physical_ordering)) => { - let is_ordering_satisfied = - ordering_satisfy_inner(physical_ordering, required_ordering, || { - child.equivalence_properties() - }); + let is_ordering_satisfied = ordering_satisfy_concrete( + physical_ordering, + required_ordering, + || child.equivalence_properties(), + ); if !is_ordering_satisfied { // During sort Removal we have invalidated ordering invariant fix it update_child_to_remove_unnecessary_sort(child, sort_onward)?; diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 7389da085d600..94394ca527ee6 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -51,7 +51,7 @@ pub fn optimize_children( } } -/// Check the required ordering requirements are satisfied by the provided PhysicalSortExprs. +/// Checks whether given ordering requirements are satisfied by provided [PhysicalSortExpr]s. pub fn ordering_satisfy EquivalenceProperties>( provided: Option<&[PhysicalSortExpr]>, required: Option<&[PhysicalSortExpr]>, @@ -61,56 +61,43 @@ pub fn ordering_satisfy EquivalenceProperties>( (_, None) => true, (None, Some(_)) => false, (Some(provided), Some(required)) => { - ordering_satisfy_inner(provided, required, equal_properties) + ordering_satisfy_concrete(provided, required, equal_properties) } } } -pub fn ordering_satisfy_inner EquivalenceProperties>( +pub fn ordering_satisfy_concrete EquivalenceProperties>( provided: &[PhysicalSortExpr], required: &[PhysicalSortExpr], equal_properties: F, ) -> bool { if required.len() > provided.len() { false - } else { - let fast_match = required + } else if required + .iter() + .zip(provided.iter()) + .all(|(order1, order2)| order1.eq(order2)) + { + true + } else if let eq_classes @ [_, ..] = equal_properties().classes() { + let normalized_required_exprs = required .iter() - .zip(provided.iter()) - .all(|(order1, order2)| order1.eq(order2)); - - if !fast_match { - let eq_properties = equal_properties(); - let eq_classes = eq_properties.classes(); - if !eq_classes.is_empty() { - let normalized_required_exprs = required - .iter() - .map(|e| { - normalize_sort_expr_with_equivalence_properties( - e.clone(), - eq_classes, - ) - }) - .collect::>(); - let normalized_provided_exprs = provided - .iter() - .map(|e| { - normalize_sort_expr_with_equivalence_properties( - e.clone(), - eq_classes, - ) - }) - .collect::>(); - normalized_required_exprs - .iter() - .zip(normalized_provided_exprs.iter()) - .all(|(order1, order2)| order1.eq(order2)) - } else { - fast_match - } - } else { - fast_match - } + .map(|e| { + normalize_sort_expr_with_equivalence_properties(e.clone(), eq_classes) + }) + .collect::>(); + let normalized_provided_exprs = provided + .iter() + .map(|e| { + normalize_sort_expr_with_equivalence_properties(e.clone(), eq_classes) + }) + .collect::>(); + normalized_required_exprs + .iter() + .zip(normalized_provided_exprs.iter()) + .all(|(order1, order2)| order1.eq(order2)) + } else { + false } } diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs index 12424c8587971..f7005d113e306 100644 --- a/datafusion/core/src/physical_plan/repartition.rs +++ b/datafusion/core/src/physical_plan/repartition.rs @@ -297,12 +297,12 @@ impl ExecutionPlan for RepartitionExec { } fn maintains_input_order(&self) -> bool { + // We preserve ordering when input partitioning is 1 let n_input = match self.input().output_partitioning() { Partitioning::RoundRobinBatch(n) => n, Partitioning::Hash(_, n) => n, Partitioning::UnknownPartitioning(n) => n, }; - // We preserve ordering when input partitioning is 1 n_input <= 1 } diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index 837f32ac69b53..c709fe4942571 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -136,23 +136,25 @@ impl WindowAggExec { self.input_schema.clone() } - /// Get Partition Columns + /// Get partition keys pub fn partition_by_sort_keys(&self) -> Result> { - // All window exprs have same partition by hance we just use first one + let mut result = vec![]; + // All window exprs have the same partition by, so we just use the first one: let partition_by = self.window_expr()[0].partition_by(); - let mut partition_columns = vec![]; - for elem in partition_by { - if let Some(sort_keys) = &self.sort_keys { - for a in sort_keys { - if a.expr.eq(elem) { - partition_columns.push(a.clone()); - break; - } - } + let sort_keys = self + .sort_keys + .as_ref() + .map_or_else(|| &[] as &[PhysicalSortExpr], |v| v.as_slice()); + for item in partition_by { + if let Some(a) = sort_keys.iter().find(|&e| e.expr.eq(item)) { + result.push(a.clone()); + } else { + return Err(DataFusionError::Execution( + "Partition key not found in sort keys".to_string(), + )); } } - assert_eq!(partition_by.len(), partition_columns.len()); - Ok(partition_columns) + Ok(result) } } @@ -395,12 +397,16 @@ impl WindowAggStream { let batch = concat_batches(&self.input.schema(), &self.batches)?; - // calculate window cols - let partition_columns = self.partition_columns(&batch)?; + let partition_by_sort_keys = self + .partition_by_sort_keys + .iter() + .map(|elem| elem.evaluate_to_sort_column(&batch)) + .collect::>>()?; let partition_points = - self.evaluate_partition_points(batch.num_rows(), &partition_columns)?; + self.evaluate_partition_points(batch.num_rows(), &partition_by_sort_keys)?; let mut partition_results = vec![]; + // Calculate window cols for partition_point in partition_points { let length = partition_point.end - partition_point.start; partition_results.push( @@ -425,31 +431,23 @@ impl WindowAggStream { RecordBatch::try_new(self.schema.clone(), columns) } - /// Get Partition Columns - pub fn partition_columns(&self, batch: &RecordBatch) -> Result> { - self.partition_by_sort_keys - .iter() - .map(|elem| elem.evaluate_to_sort_column(batch)) - .collect::>>() - } - - /// evaluate the partition points given the sort columns; if the sort columns are - /// empty then the result will be a single element vec of the whole column rows. + /// Evaluates the partition points given the sort columns. If the sort columns are + /// empty, then the result will be a single element vector spanning the entire batch. fn evaluate_partition_points( &self, num_rows: usize, partition_columns: &[SortColumn], ) -> Result>> { - if partition_columns.is_empty() { - Ok(vec![Range { + Ok(if partition_columns.is_empty() { + vec![Range { start: 0, end: num_rows, - }]) + }] } else { - Ok(lexicographical_partition_ranges(partition_columns) + lexicographical_partition_ranges(partition_columns) .map_err(DataFusionError::ArrowError)? - .collect::>()) - } + .collect::>() + }) } } From 6309b013efa2974d7581eccd346f694c419d5e12 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 19 Dec 2022 15:57:15 +0300 Subject: [PATCH 09/17] reverse expr returns Option (no need for support check) --- .../remove_unnecessary_sorts.rs | 37 +++++++++---------- .../physical-expr/src/aggregate/count.rs | 8 +--- datafusion/physical-expr/src/aggregate/mod.rs | 15 ++------ datafusion/physical-expr/src/aggregate/sum.rs | 8 +--- .../physical-expr/src/window/aggregate.rs | 22 +++++------ .../physical-expr/src/window/built_in.rs | 22 +++++------ .../window/built_in_window_function_expr.rs | 15 ++------ .../physical-expr/src/window/lead_lag.rs | 8 +--- .../physical-expr/src/window/nth_value.rs | 17 ++------- .../physical-expr/src/window/window_expr.rs | 5 +-- 10 files changed, 55 insertions(+), 102 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 2f4e326c49f02..461873b9da70b 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -295,32 +295,29 @@ fn analyze_window_sort_removal( &sort_exec.input().schema(), physical_ordering, )?; - let all_window_fns_reversible = - window_expr.iter().all(|e| e.is_window_fn_reversible()); - let is_reversal_blocking = should_reverse && !all_window_fns_reversible; - - if can_skip_sorting && !is_reversal_blocking { - let window_expr = if should_reverse { + if can_skip_sorting { + let new_window_expr = if should_reverse { window_expr .iter() .map(|e| e.get_reversed_expr()) - .collect::>>()? + .collect::>>() } else { - window_expr.to_vec() + Some(window_expr.to_vec()) }; - let new_child = remove_corresponding_sort_from_sub_plan(sort_onward)?; - let new_schema = new_child.schema(); - let new_plan = Arc::new(WindowAggExec::try_new( - window_expr, - new_child, - new_schema, - window_agg_exec.partition_keys.clone(), - Some(physical_ordering.to_vec()), - )?); - Ok(Some(PlanWithCorrespondingSort::new(new_plan))) - } else { - Ok(None) + if let Some(window_expr) = new_window_expr { + let new_child = remove_corresponding_sort_from_sub_plan(sort_onward)?; + let new_schema = new_child.schema(); + let new_plan = Arc::new(WindowAggExec::try_new( + window_expr, + new_child, + new_schema, + window_agg_exec.partition_keys.clone(), + Some(physical_ordering.to_vec()), + )?); + return Ok(Some(PlanWithCorrespondingSort::new(new_plan))); + } } + Ok(None) } /// Updates child such that unnecessary sorting below it is removed diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index e5483495eb731..ea2fe4d15fb76 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -105,12 +105,8 @@ impl AggregateExpr for Count { Ok(Box::new(CountRowAccumulator::new(start_index))) } - fn is_window_fn_reversible(&self) -> bool { - true - } - - fn reverse_expr(&self) -> Result> { - Ok(Arc::new(self.clone())) + fn reverse_expr(&self) -> Option> { + Some(Arc::new(self.clone())) } } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index b0776886e69d5..325513293f5fd 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -102,18 +102,9 @@ pub trait AggregateExpr: Send + Sync + Debug { ))) } - /// Get whether window function is reversible - /// make `true` if `reverse_expr` is implemented - fn is_window_fn_reversible(&self) -> bool { - false - } - /// Construct Reverse Expression - // Typically expression itself for aggregate functions - fn reverse_expr(&self) -> Result> { - Err(DataFusionError::NotImplemented(format!( - "reverse_expr hasn't been implemented for {:?} yet", - self - ))) + // Typically expression itself for aggregate functions By default returns None + fn reverse_expr(&self) -> Option> { + None } } diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index c3bc4dbfaecdf..ebee61010018b 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -133,12 +133,8 @@ impl AggregateExpr for Sum { ))) } - fn is_window_fn_reversible(&self) -> bool { - true - } - - fn reverse_expr(&self) -> Result> { - Ok(Arc::new(self.clone())) + fn reverse_expr(&self) -> Option> { + Some(Arc::new(self.clone())) } } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 1268559fe5688..4d4509df028eb 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -150,16 +150,16 @@ impl WindowExpr for AggregateWindowExpr { &self.window_frame } - fn is_window_fn_reversible(&self) -> bool { - self.aggregate.as_ref().is_window_fn_reversible() - } - - fn get_reversed_expr(&self) -> Result> { - Ok(Arc::new(AggregateWindowExpr::new( - self.aggregate.reverse_expr()?, - &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), - Arc::new(self.window_frame.reverse()), - ))) + fn get_reversed_expr(&self) -> Option> { + if let Some(reverse_expr) = self.aggregate.reverse_expr() { + Some(Arc::new(AggregateWindowExpr::new( + reverse_expr, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + ))) + } else { + None + } } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 0b1e5ee8f19cf..8d6dc6008abbf 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -126,16 +126,16 @@ impl WindowExpr for BuiltInWindowExpr { &self.window_frame } - fn is_window_fn_reversible(&self) -> bool { - self.expr.as_ref().is_window_fn_reversible() - } - - fn get_reversed_expr(&self) -> Result> { - Ok(Arc::new(BuiltInWindowExpr::new( - self.expr.reverse_expr()?, - &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), - Arc::new(self.window_frame.reverse()), - ))) + fn get_reversed_expr(&self) -> Option> { + if let Some(reverse_expr) = self.expr.reverse_expr() { + Some(Arc::new(BuiltInWindowExpr::new( + reverse_expr, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + ))) + } else { + None + } } } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 71b100b54e5e4..e7553e566fa2b 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -20,7 +20,7 @@ use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::Result; use std::any::Any; use std::sync::Arc; @@ -59,17 +59,8 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Create built-in window evaluator with a batch fn create_evaluator(&self) -> Result>; - /// Get whether window function is reversible - /// make true if `reverse_expr` is implemented - fn is_window_fn_reversible(&self) -> bool { - false - } - /// Construct Reverse Expression - fn reverse_expr(&self) -> Result> { - Err(DataFusionError::NotImplemented(format!( - "reverse_expr hasn't been implemented for {:?} yet", - self - ))) + fn reverse_expr(&self) -> Option> { + None } } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index f4c176262ae46..e18815c4c3a62 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -107,12 +107,8 @@ impl BuiltInWindowFunctionExpr for WindowShift { })) } - fn is_window_fn_reversible(&self) -> bool { - true - } - - fn reverse_expr(&self) -> Result> { - Ok(Arc::new(Self { + fn reverse_expr(&self) -> Option> { + Some(Arc::new(Self { name: self.name.clone(), data_type: self.data_type.clone(), shift_offset: -self.shift_offset, diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 2942a797b8526..e998b47018a52 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -124,24 +124,13 @@ impl BuiltInWindowFunctionExpr for NthValue { Ok(Box::new(NthValueEvaluator { kind: self.kind })) } - fn is_window_fn_reversible(&self) -> bool { - match self.kind { - NthValueKind::First | NthValueKind::Last => true, - NthValueKind::Nth(_) => false, - } - } - - fn reverse_expr(&self) -> Result> { + fn reverse_expr(&self) -> Option> { let reversed_kind = match self.kind { NthValueKind::First => NthValueKind::Last, NthValueKind::Last => NthValueKind::First, - NthValueKind::Nth(_) => { - return Err(DataFusionError::Execution( - "Cannot take reverse of NthValue".to_string(), - )) - } + NthValueKind::Nth(_) => return None, }; - Ok(Arc::new(Self { + Some(Arc::new(Self { name: self.name.clone(), expr: self.expr.clone(), data_type: self.data_type.clone(), diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index bc35dd49b50d4..78c3935c567d0 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -116,11 +116,8 @@ pub trait WindowExpr: Send + Sync + Debug { // Get window frame of this WindowExpr fn get_window_frame(&self) -> &Arc; - /// Get whether window function can be reversed - fn is_window_fn_reversible(&self) -> bool; - /// get reversed expression - fn get_reversed_expr(&self) -> Result>; + fn get_reversed_expr(&self) -> Option>; } /// Reverses the ORDER BY expression, which is useful during equivalent window From 91629b8e5ea6830fc4bab1d0b8b59d5025199fea Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 20 Dec 2022 11:44:37 +0300 Subject: [PATCH 10/17] fix tests --- datafusion/core/src/execution/context.rs | 2 +- .../src/physical_optimizer/enforcement.rs | 2 +- datafusion/core/tests/sql/window.rs | 55 +++++++++---------- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index c03a9ad996cf3..8c7a8cbe82c49 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -99,8 +99,8 @@ use url::Url; use crate::catalog::listing_schema::ListingSchemaProvider; use crate::datasource::object_store::ObjectStoreUrl; -use crate::physical_optimizer::remove_unnecessary_sorts::RemoveUnnecessarySorts; use crate::execution::memory_pool::MemoryPool; +use crate::physical_optimizer::remove_unnecessary_sorts::RemoveUnnecessarySorts; use uuid::Uuid; use super::options::{ diff --git a/datafusion/core/src/physical_optimizer/enforcement.rs b/datafusion/core/src/physical_optimizer/enforcement.rs index 4a76573004b0f..3e74860b5be15 100644 --- a/datafusion/core/src/physical_optimizer/enforcement.rs +++ b/datafusion/core/src/physical_optimizer/enforcement.rs @@ -30,7 +30,7 @@ use crate::physical_plan::joins::{ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::rewrite::TreeNodeRewritable; -use crate::physical_plan::sorts::sort::SortOptions; +use crate::physical_plan::sorts::sort::{SortExec, SortOptions}; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::Partitioning; diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 3c922c0f2f657..d4c117e339a86 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1948,13 +1948,13 @@ async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { // We cannot reverse each window function (ROW_NUMBER is not reversible) let expected = { vec![ - "ProjectionExec: expr=[c9@5 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as rn2]", + "ProjectionExec: expr=[c9@5 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2]", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c9@4 DESC]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@4 ASC NULLS LAST,c1@2 ASC NULLS LAST,c2@3 ASC NULLS LAST]", + " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c9@2 ASC NULLS LAST,c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", + " SortExec: [c9@2 DESC,c1@0 DESC]", ] }; @@ -1969,15 +1969,15 @@ async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------------+-------------+-------------+-----+", - "| c9 | sum1 | sum2 | rn2 |", - "+------------+-------------+-------------+-----+", - "| 4268716378 | 8498370520 | 24997484146 | 1 |", - "| 4229654142 | 12714811027 | 29012926487 | 2 |", - "| 4216440507 | 16858984380 | 28743001064 | 3 |", - "| 4144173353 | 20935849039 | 28472563256 | 4 |", - "| 4076864659 | 24997484146 | 28118515915 | 5 |", - "+------------+-------------+-------------+-----+", + "+-----------+------------+-----------+-----+", + "| c9 | sum1 | sum2 | rn2 |", + "+-----------+------------+-----------+-----+", + "| 28774375 | 745354217 | 91818943 | 100 |", + "| 63044568 | 988558066 | 232866360 | 99 |", + "| 141047417 | 1285934966 | 374546521 | 98 |", + "| 141680161 | 1654839259 | 519841132 | 97 |", + "| 145294611 | 1980231675 | 745354217 | 96 |", + "+-----------+------------+-----------+-----+", ]; assert_batches_eq!(expected, &actual); @@ -2038,19 +2038,19 @@ async fn test_window_agg_complex_plan() -> Result<()> { // Unnecessary SortExecs are removed let expected = { vec![ - "ProjectionExec: expr=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as a, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as b, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as c, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as d, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@12 as e, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as f, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as g, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as i, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as j, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as k, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@13 as l, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as m, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@17 as n, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@18 as o, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as p, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as a1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as b1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as c1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as d1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@14 as e1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as f1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as g1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as j1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as k1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as l1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@13 as m1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as n1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as o1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as h11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as j11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as k11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as l11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@15 as m11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as n11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as o11]", + "ProjectionExec: expr=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as a, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as b, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@15 as c, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@7 as d, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@11 as e, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@15 as f, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@7 as g, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as i, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as j, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as k, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as l, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as m, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as n, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as o, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as p, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as a1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as b1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@17 as c1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@9 as d1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@13 as e1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@17 as f1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@9 as g1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as j1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as k1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as l1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as m1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as n1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as o1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as h11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as j11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@18 as k11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@10 as l11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@14 as m11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@18 as n11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@10 as o11]", " GlobalLimitExec: skip=0, fetch=5", " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", - " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", - " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", - " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", - " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", - " SortExec: [c3@5 ASC NULLS LAST,c2@4 ASC NULLS LAST]", - " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", - " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", - " SortExec: [c3@3 DESC,c1@1 ASC NULLS LAST]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@17 ASC NULLS LAST,c2@16 ASC NULLS LAST]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@16 ASC NULLS LAST,c1@14 ASC]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", - " SortExec: [c3@2 ASC NULLS LAST,c1@0 ASC]", + " SortExec: [c3@2 DESC,c1@0 ASC NULLS LAST]", ] }; @@ -2089,9 +2089,8 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> vec![ "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c1@0 ASC,c9@1 DESC]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", ] }; @@ -2146,7 +2145,7 @@ async fn test_window_agg_sort_partitionby_reversed_plan() -> Result<()> { " GlobalLimitExec: skip=0, fetch=5", " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c1@0 ASC,c9@1 DESC]", + " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", ] }; From ae451a4db27b5a9906e3366a2a5349d11dc13591 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 20 Dec 2022 15:22:05 +0300 Subject: [PATCH 11/17] partition by and order by no longer ends up at the same window group --- datafusion/core/src/physical_plan/planner.rs | 2 +- datafusion/core/tests/sql/window.rs | 22 ++++--- datafusion/expr/src/logical_plan/builder.rs | 2 +- datafusion/expr/src/utils.rs | 60 +++++++++++++------- 4 files changed, 56 insertions(+), 30 deletions(-) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 548a1dd36bd3d..ef7fa830a567a 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -582,7 +582,7 @@ impl DefaultPhysicalPlanner { let physical_input_schema = input_exec.schema(); let sort_keys = sort_keys .iter() - .map(|e| match e { + .map(|(e, _)| match e { Expr::Sort(expr::Sort { expr, asc, diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index d4c117e339a86..7d39a3c98db92 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1737,15 +1737,18 @@ async fn test_window_partition_by_order_by() -> Result<()> { let logical_plan = state.optimize(&plan)?; let physical_plan = state.create_physical_plan(&logical_plan).await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - // Only 1 SortExec was added let expected = { vec![ - "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as COUNT(UInt8(1))]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", - " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", + "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as COUNT(UInt8(1))]", + " WindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", + " SortExec: [c1@1 ASC NULLS LAST,c2@2 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 2)", - " RepartitionExec: partitioning=RoundRobinBatch(2)", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 1 }], 2)", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", + " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }, Column { name: \"c2\", index: 1 }], 2)", + " RepartitionExec: partitioning=RoundRobinBatch(2)", ] }; @@ -2087,10 +2090,11 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> // Only 1 SortExec was added let expected = { vec![ - "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", + "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum2]", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", ] }; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index bf2a1d0018675..eeb3215c4b6fd 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -257,7 +257,7 @@ impl LogicalPlanBuilder { // The sort_by() implementation here is a stable sort. // Note that by this rule if there's an empty over, it'll be at the top level groups.sort_by(|(key_a, _), (key_b, _)| { - for (first, second) in key_a.iter().zip(key_b.iter()) { + for ((first, _), (second, _)) in key_a.iter().zip(key_b.iter()) { let key_ordering = compare_sort_expr(first, second, plan.schema()); match key_ordering { Ordering::Less => { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 2577c3a1970ca..1ab4595d1c886 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -203,7 +203,7 @@ pub fn expand_qualified_wildcard( expand_wildcard(&qualifier_schema, plan) } -type WindowSortKey = Vec; +type WindowSortKey = Vec<(Expr, bool)>; /// Generate a sort key for a given window expr's partition_by and order_bu expr pub fn generate_sort_key( @@ -223,6 +223,7 @@ pub fn generate_sort_key( .collect::>>()?; let mut final_sort_keys = vec![]; + let mut is_partition_flag = vec![]; partition_by.iter().for_each(|e| { // By default, create sort key with ASC is true and NULLS LAST to be consistent with // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html @@ -231,18 +232,26 @@ pub fn generate_sort_key( let order_by_key = &order_by[pos]; if !final_sort_keys.contains(order_by_key) { final_sort_keys.push(order_by_key.clone()); + is_partition_flag.push(true); } } else if !final_sort_keys.contains(&e) { final_sort_keys.push(e); + is_partition_flag.push(true); } }); order_by.iter().for_each(|e| { if !final_sort_keys.contains(e) { final_sort_keys.push(e.clone()); + is_partition_flag.push(false); } }); - Ok(final_sort_keys) + let res = final_sort_keys + .into_iter() + .zip(is_partition_flag) + .map(|(lhs, rhs)| (lhs, rhs)) + .collect::>(); + Ok(res) } /// Compare the sort expr as PostgreSQL's common_prefix_cmp(): @@ -1027,9 +1036,13 @@ mod tests { let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs)?; - let key1 = vec![age_asc.clone(), name_desc.clone()]; + let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)]; let key2 = vec![]; - let key3 = vec![name_desc, age_asc, created_at_desc]; + let key3 = vec![ + (name_desc, false), + (age_asc, false), + (created_at_desc, false), + ]; let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![ (key1, vec![&max1, &min3]), @@ -1096,21 +1109,30 @@ mod tests { ]; let expected = vec![ - Expr::Sort(Sort { - expr: Box::new(col("age")), - asc: asc_, - nulls_first: nulls_first_, - }), - Expr::Sort(Sort { - expr: Box::new(col("name")), - asc: asc_, - nulls_first: nulls_first_, - }), - Expr::Sort(Sort { - expr: Box::new(col("created_at")), - asc: true, - nulls_first: false, - }), + ( + Expr::Sort(Sort { + expr: Box::new(col("age")), + asc: asc_, + nulls_first: nulls_first_, + }), + true, + ), + ( + Expr::Sort(Sort { + expr: Box::new(col("name")), + asc: asc_, + nulls_first: nulls_first_, + }), + true, + ), + ( + Expr::Sort(Sort { + expr: Box::new(col("created_at")), + asc: true, + nulls_first: false, + }), + true, + ), ]; let result = generate_sort_key(partition_by, order_by)?; assert_eq!(expected, result); From 0e73945ae601e33849aaddb9defdce20a4a75b9d Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 21 Dec 2022 14:14:42 +0300 Subject: [PATCH 12/17] Refactor to simplify code --- .../remove_unnecessary_sorts.rs | 309 +++++++++--------- 1 file changed, 153 insertions(+), 156 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 461873b9da70b..6a41ea4f285c9 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -15,9 +15,16 @@ // specific language governing permissions and limitations // under the License. -//! Remove Unnecessary Sorts optimizer rule is used to for removing unnecessary SortExec's inserted to -//! physical plan. Produces a valid physical plan (in terms of Sorting requirement). Its input can be either -//! valid, or invalid physical plans (in terms of Sorting requirement) +//! RemoveUnnecessarySorts optimizer rule inspects SortExec's in the given +//! physical plan and removes the ones it can prove unnecessary. The rule can +//! work on valid *and* invalid physical plans with respect to sorting +//! requirements, but always produces a valid physical plan in this sense. +//! +//! A non-realistic but easy to follow example: Assume that we somehow get the fragment +//! "SortExec: [nullable_col@0 ASC]", +//! " SortExec: [non_nullable_col@1 ASC]", +//! in the physical plan. The first sort is unnecessary since its result is overwritten +//! by another SortExec. Therefore, this rule removes it from the physical plan. use crate::error::Result; use crate::physical_optimizer::utils::{ add_sort_above_child, ordering_satisfy, ordering_satisfy_concrete, @@ -31,14 +38,12 @@ use crate::prelude::SessionConfig; use arrow::datatypes::SchemaRef; use datafusion_common::{reverse_sort_options, DataFusionError}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use itertools::izip; use std::iter::zip; use std::sync::Arc; -/// As an example Assume we get -/// "SortExec: [nullable_col@0 ASC]", -/// " SortExec: [non_nullable_col@1 ASC]", somehow in the physical plan -/// The first Sort is unnecessary since, its result would be overwritten by another SortExec. We -/// remove first Sort from the physical plan +/// This rule inspects SortExec's in the given physical plan and removes the +/// ones it can prove unnecessary. #[derive(Default)] pub struct RemoveUnnecessarySorts {} @@ -49,13 +54,77 @@ impl RemoveUnnecessarySorts { } } +/// This is a "data class" we use within the [RemoveUnnecessarySorts] rule +/// that tracks the closest `SortExec` descendant for every child of a plan. +#[derive(Debug, Clone)] +struct PlanWithCorrespondingSort { + plan: Arc, + // For every child, keep a vector of `ExecutionPlan`s starting from the + // closest `SortExec` till the current plan. The first index of the tuple is + // the child index of the plan -- we need this information as we make updates. + sort_onwards: Vec)>>, +} + +impl PlanWithCorrespondingSort { + pub fn new(plan: Arc) -> Self { + let length = plan.children().len(); + PlanWithCorrespondingSort { + plan, + sort_onwards: vec![vec![]; length], + } + } + + pub fn children(&self) -> Vec { + self.plan + .children() + .into_iter() + .map(|child| PlanWithCorrespondingSort::new(child)) + .collect() + } +} + +impl TreeNodeRewritable for PlanWithCorrespondingSort { + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if children.is_empty() { + Ok(self) + } else { + let children_requirements = children + .into_iter() + .map(transform) + .collect::>>()?; + let children_plans = children_requirements + .iter() + .map(|elem| elem.plan.clone()) + .collect::>(); + let sort_onwards = children_requirements + .iter() + .map(|item| { + if item.sort_onwards.is_empty() { + vec![] + } else { + // TODO: When `maintains_input_order` returns Vec, + // pass the order-enforcing sort upwards. + item.sort_onwards[0].clone() + } + }) + .collect::>(); + let plan = with_new_children_if_necessary(self.plan, children_plans)?; + Ok(PlanWithCorrespondingSort { plan, sort_onwards }) + } + } +} + impl PhysicalOptimizerRule for RemoveUnnecessarySorts { fn optimize( &self, plan: Arc, _config: &SessionConfig, ) -> Result> { - // Run a bottom-up process to adjust input key ordering recursively + // Execute a post-order traversal to adjust input key ordering: let plan_requirements = PlanWithCorrespondingSort::new(plan); let adjusted = plan_requirements.transform_up(&remove_unnecessary_sorts)?; Ok(adjusted.plan) @@ -73,19 +142,20 @@ impl PhysicalOptimizerRule for RemoveUnnecessarySorts { fn remove_unnecessary_sorts( requirements: PlanWithCorrespondingSort, ) -> Result> { - // Do analysis of naive SortRemoval at the beginning - // Remove Sorts that are already satisfied - if let Some(res) = analyze_immediate_sort_removal(&requirements)? { - return Ok(Some(res)); + // Perform naive analysis at the beginning -- remove already-satisfied sorts: + if let Some(result) = analyze_immediate_sort_removal(&requirements)? { + return Ok(Some(result)); } - let mut new_children = requirements.plan.children().clone(); - let mut new_sort_onwards = requirements.sort_onwards.clone(); - for (idx, (child, sort_onward)) in new_children - .iter_mut() - .zip(new_sort_onwards.iter_mut()) - .enumerate() + let plan = &requirements.plan; + let mut new_children = plan.children().clone(); + let mut new_onwards = requirements.sort_onwards.clone(); + for (idx, (child, sort_onwards, required_ordering)) in izip!( + new_children.iter_mut(), + new_onwards.iter_mut(), + plan.required_input_ordering() + ) + .enumerate() { - let required_ordering = requirements.plan.required_input_ordering()[idx]; let physical_ordering = child.output_ordering(); match (required_ordering, physical_ordering) { (Some(required_ordering), Some(physical_ordering)) => { @@ -95,202 +165,133 @@ fn remove_unnecessary_sorts( || child.equivalence_properties(), ); if !is_ordering_satisfied { - // During sort Removal we have invalidated ordering invariant fix it - update_child_to_remove_unnecessary_sort(child, sort_onward)?; + // Make sure we preserve the ordering requirements: + update_child_to_remove_unnecessary_sort(child, sort_onwards)?; let sort_expr = required_ordering.to_vec(); *child = add_sort_above_child(child, sort_expr)?; - // Since we have added Sort, we add it to the sort_onwards also. - sort_onward.push((idx, child.clone())) - } else if is_ordering_satisfied && !sort_onward.is_empty() { - // can do analysis for sort removal - let (_, sort_any) = sort_onward[0].clone(); + sort_onwards.push((idx, child.clone())) + } else if let [first, ..] = sort_onwards.as_slice() { + // The ordering requirement is met, we can analyze if there is an unnecessary sort: + let sort_any = first.1.clone(); let sort_exec = convert_to_sort_exec(&sort_any)?; let sort_output_ordering = sort_exec.output_ordering(); let sort_input_ordering = sort_exec.input().output_ordering(); - // TODO: Once we can ensure required ordering propagates to above without changes - // (or with changes trackable) compare `sort_input_ordering` and and `required_ordering` - // this changes will enable us to remove (a,b) -> Sort -> (a,b,c) -> Required(a,b) Sort - // from the plan. With current implementation we cannot remove Sort from above configuration. - // Do naive analysis, where a SortExec is already sorted according to desired Sorting + // Simple analysis: Does the input of the sort in question already satisfy the ordering requirements? if ordering_satisfy(sort_input_ordering, sort_output_ordering, || { sort_exec.input().equivalence_properties() }) { - update_child_to_remove_unnecessary_sort(child, sort_onward)?; + update_child_to_remove_unnecessary_sort(child, sort_onwards)?; } else if let Some(window_agg_exec) = requirements.plan.as_any().downcast_ref::() { - // For window expressions we can remove some Sorts when expression can be calculated in reverse order also. + // For window expressions, we can remove some sorts when we can + // calculate the result in reverse: if let Some(res) = analyze_window_sort_removal( window_agg_exec, sort_exec, - sort_onward, + sort_onwards, )? { return Ok(Some(res)); } } + // TODO: Once we can ensure that required ordering information propagates with + // necessary lineage information, compare `sort_input_ordering` and `required_ordering`. + // This will enable us to handle cases such as (a,b) -> Sort -> (a,b,c) -> Required(a,b). + // Currently, we can not remove such sorts. } } (Some(required), None) => { - // Requirement is not satisfied We should add Sort to the plan. + // Ordering requirement is not met, we should add a SortExec to the plan. let sort_expr = required.to_vec(); *child = add_sort_above_child(child, sort_expr)?; - *sort_onward = vec![(idx, child.clone())]; + *sort_onwards = vec![(idx, child.clone())]; } (None, Some(_)) => { - // Sort doesn't propagate to the layers above in the physical plan + // We have a SortExec whose effect may be neutralized by a order-imposing + // operator. In this case, remove this sort: if !requirements.plan.maintains_input_order() { - // Unnecessary Sort is added to the plan, we can remove unnecessary sort - update_child_to_remove_unnecessary_sort(child, sort_onward)?; + update_child_to_remove_unnecessary_sort(child, sort_onwards)?; } } (None, None) => {} } } - if !requirements.plan.children().is_empty() { + if plan.children().is_empty() { + Ok(Some(requirements)) + } else { let new_plan = requirements.plan.with_new_children(new_children)?; - for (idx, new_sort_onward) in new_sort_onwards + for (idx, (trace, required_ordering)) in new_onwards .iter_mut() + .zip(new_plan.required_input_ordering()) .enumerate() .take(new_plan.children().len()) { - let requires_ordering = new_plan.required_input_ordering()[idx].is_some(); - //TODO: when `maintains_input_order` returns `Vec` use corresponding index + // TODO: When `maintains_input_order` returns a `Vec`, use corresponding index. if new_plan.maintains_input_order() - && !requires_ordering - && !new_sort_onward.is_empty() + && required_ordering.is_none() + && !trace.is_empty() { - new_sort_onward.push((idx, new_plan.clone())); - } else if new_plan.as_any().is::() { - new_sort_onward.clear(); - new_sort_onward.push((idx, new_plan.clone())); + trace.push((idx, new_plan.clone())); } else { - // These executors use SortExec, hence doesn't propagate - // sort above in the physical plan - new_sort_onward.clear(); + trace.clear(); + if new_plan.as_any().is::() { + trace.push((idx, new_plan.clone())); + } } } Ok(Some(PlanWithCorrespondingSort { plan: new_plan, - sort_onwards: new_sort_onwards, + sort_onwards: new_onwards, })) - } else { - Ok(Some(requirements)) - } -} - -#[derive(Debug, Clone)] -struct PlanWithCorrespondingSort { - plan: Arc, - // For each child keeps a vector of `ExecutionPlan`s starting from SortExec till current plan - // first index of tuple(usize) is child index of plan (we need during updating plan above) - sort_onwards: Vec)>>, -} - -impl PlanWithCorrespondingSort { - pub fn new(plan: Arc) -> Self { - let children_len = plan.children().len(); - PlanWithCorrespondingSort { - plan, - sort_onwards: vec![vec![]; children_len], - } - } - - pub fn children(&self) -> Vec { - let plan_children = self.plan.children(); - plan_children - .into_iter() - .map(|child| { - let length = child.children().len(); - PlanWithCorrespondingSort { - plan: child, - sort_onwards: vec![vec![]; length], - } - }) - .collect() - } -} - -impl TreeNodeRewritable for PlanWithCorrespondingSort { - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { - let children = self.children(); - if !children.is_empty() { - let new_children: Result> = - children.into_iter().map(transform).collect(); - let children_requirements = new_children?; - let children_plans = children_requirements - .iter() - .map(|elem| elem.plan.clone()) - .collect::>(); - let sort_onwards = children_requirements - .iter() - .map(|elem| { - if !elem.sort_onwards.is_empty() { - // TODO: redirect the true sort onwards to above (the one we keep ordering) - // this is possible when maintains_input_order returns vec - elem.sort_onwards[0].clone() - } else { - vec![] - } - }) - .collect::>(); - let plan = with_new_children_if_necessary(self.plan, children_plans)?; - Ok(PlanWithCorrespondingSort { plan, sort_onwards }) - } else { - Ok(self) - } } } -/// Analyzes `SortExec` to determine whether this Sort can be removed +/// Analyzes a given `SortExec` to determine whether its input already has +/// a finer ordering than this `SortExec` enforces. fn analyze_immediate_sort_removal( requirements: &PlanWithCorrespondingSort, ) -> Result> { if let Some(sort_exec) = requirements.plan.as_any().downcast_ref::() { + // If this sort is unnecessary, we should remove it: if ordering_satisfy( sort_exec.input().output_ordering(), sort_exec.output_ordering(), || sort_exec.input().equivalence_properties(), ) { - // This sort is unnecessary we should remove it - let new_plan = sort_exec.input(); - // Since we know that Sort have exactly one child we can use first index safely - assert_eq!(requirements.sort_onwards.len(), 1); - let mut new_sort_onward = requirements.sort_onwards[0].to_vec(); - if !new_sort_onward.is_empty() { - new_sort_onward.pop(); + // Since we know that a `SortExec` has exactly one child, + // we can use the zero index safely: + let mut new_onwards = requirements.sort_onwards[0].to_vec(); + if !new_onwards.is_empty() { + new_onwards.pop(); } return Ok(Some(PlanWithCorrespondingSort { - plan: new_plan.clone(), - sort_onwards: vec![new_sort_onward], + plan: sort_exec.input().clone(), + sort_onwards: vec![new_onwards], })); } } Ok(None) } -/// Analyzes `WindowAggExec` to determine whether Sort can be removed +/// Analyzes a `WindowAggExec` to determine whether it may allow removing a sort. fn analyze_window_sort_removal( window_agg_exec: &WindowAggExec, sort_exec: &SortExec, sort_onward: &mut Vec<(usize, Arc)>, ) -> Result> { let required_ordering = sort_exec.output_ordering().ok_or_else(|| { - DataFusionError::Plan("SortExec should have output ordering".to_string()) + DataFusionError::Plan("A SortExec should have output ordering".to_string()) })?; let physical_ordering = sort_exec.input().output_ordering(); let physical_ordering = if let Some(physical_ordering) = physical_ordering { physical_ordering } else { - // If there is no physical ordering, there is no way to remove Sorting, immediately return + // If there is no physical ordering, there is no way to remove a sort -- immediately return: return Ok(None); }; let window_expr = window_agg_exec.window_expr(); - let partition_keys = window_expr[0].partition_by().to_vec(); let (can_skip_sorting, should_reverse) = can_skip_sort( - &partition_keys, + window_expr[0].partition_by(), required_ordering, &sort_exec.input().schema(), physical_ordering, @@ -320,7 +321,7 @@ fn analyze_window_sort_removal( Ok(None) } -/// Updates child such that unnecessary sorting below it is removed +/// Updates child to remove the unnecessary sorting below it. fn update_child_to_remove_unnecessary_sort( child: &mut Arc, sort_onwards: &mut Vec<(usize, Arc)>, @@ -331,34 +332,30 @@ fn update_child_to_remove_unnecessary_sort( Ok(()) } -/// Convert dyn ExecutionPlan to SortExec (Assumes it is SortExec) +/// Converts an [ExecutionPlan] trait object to a [SortExec] when possible. fn convert_to_sort_exec(sort_any: &Arc) -> Result<&SortExec> { - let sort_exec = sort_any - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Plan("First layer should start from SortExec".to_string()) - })?; - Ok(sort_exec) + sort_any.as_any().downcast_ref::().ok_or_else(|| { + DataFusionError::Plan("Given ExecutionPlan is not a SortExec".to_string()) + }) } -/// Removes the sort from the plan in the `sort_onwards` +/// Removes the sort from the plan in `sort_onwards`. fn remove_corresponding_sort_from_sub_plan( sort_onwards: &mut Vec<(usize, Arc)>, ) -> Result> { let (sort_child_idx, sort_any) = sort_onwards[0].clone(); let sort_exec = convert_to_sort_exec(&sort_any)?; let mut prev_layer = sort_exec.input().clone(); - let mut prev_layer_child_idx = sort_child_idx; - // We start from 1 hence since first one is sort and we are removing it - // from the plan - for (cur_layer_child_idx, cur_layer) in sort_onwards.iter().skip(1) { - let mut new_children = cur_layer.children(); - new_children[prev_layer_child_idx] = prev_layer; - prev_layer = cur_layer.clone().with_new_children(new_children)?; - prev_layer_child_idx = *cur_layer_child_idx; + let mut prev_child_idx = sort_child_idx; + // In the loop below, se start from 1 as the first one is a SortExec + // and we are removing it from the plan. + for (child_idx, layer) in sort_onwards.iter().skip(1) { + let mut children = layer.children(); + children[prev_child_idx] = prev_layer; + prev_layer = layer.clone().with_new_children(children)?; + prev_child_idx = *child_idx; } - // We have removed the corresponding sort hence empty the sort_onwards + // We have removed the sort, hence empty the sort_onwards: sort_onwards.clear(); Ok(prev_layer) } From 4f145dd642896c010c37f1a116baf194463f45b2 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 21 Dec 2022 14:36:45 +0300 Subject: [PATCH 13/17] Better comments, change method names --- .../core/src/physical_optimizer/remove_unnecessary_sorts.rs | 2 +- datafusion/physical-expr/src/aggregate/mod.rs | 6 ++++-- datafusion/physical-expr/src/window/aggregate.rs | 2 +- datafusion/physical-expr/src/window/built_in.rs | 2 +- datafusion/physical-expr/src/window/window_expr.rs | 6 +++--- 5 files changed, 10 insertions(+), 8 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 6a41ea4f285c9..593c31ded9a8b 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -300,7 +300,7 @@ fn analyze_window_sort_removal( let new_window_expr = if should_reverse { window_expr .iter() - .map(|e| e.get_reversed_expr()) + .map(|e| e.get_reverse_expr()) .collect::>>() } else { Some(window_expr.to_vec()) diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 325513293f5fd..41c541918e716 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -102,9 +102,11 @@ pub trait AggregateExpr: Send + Sync + Debug { ))) } - /// Construct Reverse Expression - // Typically expression itself for aggregate functions By default returns None + /// Construct an expression that calculates the aggregate in reverse. fn reverse_expr(&self) -> Option> { + // Typically the "reverse" expression is itself (e.g. SUM, COUNT). + // For aggregates that do not support calculation in reverse, + // returns None (which is the default value). None } } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 4d4509df028eb..8b2579e991f60 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -150,7 +150,7 @@ impl WindowExpr for AggregateWindowExpr { &self.window_frame } - fn get_reversed_expr(&self) -> Option> { + fn get_reverse_expr(&self) -> Option> { if let Some(reverse_expr) = self.aggregate.reverse_expr() { Some(Arc::new(AggregateWindowExpr::new( reverse_expr, diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 8d6dc6008abbf..a8d768d0357e4 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -126,7 +126,7 @@ impl WindowExpr for BuiltInWindowExpr { &self.window_frame } - fn get_reversed_expr(&self) -> Option> { + fn get_reverse_expr(&self) -> Option> { if let Some(reverse_expr) = self.expr.reverse_expr() { Some(Arc::new(BuiltInWindowExpr::new( reverse_expr, diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 78c3935c567d0..a718fa4cd3b36 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -113,11 +113,11 @@ pub trait WindowExpr: Send + Sync + Debug { Ok((values, order_bys)) } - // Get window frame of this WindowExpr + /// Get the window frame of this [WindowExpr]. fn get_window_frame(&self) -> &Arc; - /// get reversed expression - fn get_reversed_expr(&self) -> Option>; + /// Get the reverse expression of this [WindowExpr]. + fn get_reverse_expr(&self) -> Option>; } /// Reverses the ORDER BY expression, which is useful during equivalent window From 6b076211c4ed964d5afe938c4814d9490d5d8696 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 22 Dec 2022 11:01:25 +0300 Subject: [PATCH 14/17] Resolve errors introduced by syncing --- datafusion/core/tests/sql/window.rs | 62 +++++++------------- datafusion/physical-expr/src/window/ntile.rs | 9 +-- 2 files changed, 23 insertions(+), 48 deletions(-) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index c57b40626fc79..41278e1208b78 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1748,7 +1748,7 @@ async fn test_window_partition_by_order_by() -> Result<()> { let msg = format!("Creating logical plan for '{}'", sql); let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await.unwrap(); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); let expected = { vec![ @@ -1788,10 +1788,8 @@ async fn test_window_agg_sort_reversed_plan() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Only 1 SortExec was added let expected = { @@ -1846,10 +1844,8 @@ async fn test_window_agg_sort_reversed_plan_builtin() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Only 1 SortExec was added let expected = { @@ -1900,10 +1896,8 @@ async fn test_window_agg_sort_non_reversed_plan() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // We cannot reverse each window function (ROW_NUMBER is not reversible) let expected = { @@ -1956,10 +1950,8 @@ async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // We cannot reverse each window function (ROW_NUMBER is not reversible) let expected = { @@ -2046,10 +2038,8 @@ async fn test_window_agg_complex_plan() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Unnecessary SortExecs are removed let expected = { @@ -2095,10 +2085,8 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Only 1 SortExec was added let expected = { @@ -2150,10 +2138,8 @@ async fn test_window_agg_sort_partitionby_reversed_plan() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Only 1 SortExec was added let expected = { @@ -2204,10 +2190,8 @@ async fn test_window_agg_sort_orderby_reversed_binary_expr() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Only 1 SortExec was added let expected = { @@ -2261,10 +2245,8 @@ async fn test_remove_unnecessary_sort_in_sub_query() -> Result<()> { ORDER BY c1 ) AS a "; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Unnecessary Sort in the sub query is removed let expected = { @@ -2319,10 +2301,8 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_reversed_plan() -> Re LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Only 1 SortExec was added let expected = { diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index ed00c3c869550..f5844eccc63a8 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -26,7 +26,6 @@ use arrow::datatypes::Field; use arrow_schema::DataType; use datafusion_common::Result; use std::any::Any; -use std::ops::Range; use std::sync::Arc; #[derive(Debug)] @@ -70,12 +69,8 @@ pub(crate) struct NtileEvaluator { } impl PartitionEvaluator for NtileEvaluator { - fn evaluate_partition( - &self, - _values: &[ArrayRef], - partition: Range, - ) -> Result { - let num_rows = (partition.end - partition.start) as u64; + fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result { + let num_rows = num_rows as u64; let mut vec: Vec = Vec::new(); for i in 0..num_rows { let res = i * self.n / num_rows; From ba388cb1ea632bd9d451ba263f8685dcd6f9d54d Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 23 Dec 2022 10:58:20 +0300 Subject: [PATCH 15/17] address reviews --- datafusion/common/src/lib.rs | 16 ------------ .../core/src/physical_optimizer/utils.rs | 1 + datafusion/core/src/physical_plan/common.rs | 25 +++++++++++++++++++ .../core/src/physical_plan/repartition.rs | 7 +----- .../physical_plan/windows/window_agg_exec.rs | 14 ++++++----- datafusion/expr/src/window_frame.rs | 4 ++- datafusion/physical-expr/src/aggregate/mod.rs | 6 ++--- .../physical-expr/src/window/aggregate.rs | 14 +++++------ .../physical-expr/src/window/built_in.rs | 10 +++----- .../window/built_in_window_function_expr.rs | 3 ++- .../src/window/sliding_aggregate.rs | 14 +++++------ 11 files changed, 59 insertions(+), 55 deletions(-) diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 9911b1499a7b8..392fa3f25a673 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -73,19 +73,3 @@ pub fn reverse_sort_options(options: SortOptions) -> SortOptions { nulls_first: !options.nulls_first, } } - -/// Transposes the given vector of vectors. -pub fn transpose(original: Vec>) -> Vec> { - match original.as_slice() { - [] => vec![], - [first, ..] => { - let mut result = (0..first.len()).map(|_| vec![]).collect::>(); - for row in original { - for (item, transposed_row) in row.into_iter().zip(&mut result) { - transposed_row.push(item); - } - } - result - } - } -} diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 94394ca527ee6..8f1fe2d08213f 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -102,6 +102,7 @@ pub fn ordering_satisfy_concrete EquivalenceProperties>( } /// Util function to add SortExec above child +/// preserving the original partitioning pub fn add_sort_above_child( child: &Arc, sort_expr: Vec, diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs index b29dc0cb8c119..1c36014f20123 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -266,6 +266,22 @@ impl Drop for AbortOnDropMany { } } +/// Transposes the given vector of vectors. +pub fn transpose(original: Vec>) -> Vec> { + match original.as_slice() { + [] => vec![], + [first, ..] => { + let mut result = (0..first.len()).map(|_| vec![]).collect::>(); + for row in original { + for (item, transposed_row) in row.into_iter().zip(&mut result) { + transposed_row.push(item); + } + } + result + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -332,6 +348,15 @@ mod tests { assert_eq!(actual, expected); Ok(()) } + + #[test] + fn test_transpose() -> Result<()> { + let in_data = vec![vec![1, 2, 3], vec![4, 5, 6]]; + let transposed = transpose(in_data); + let expected = vec![vec![1, 4], vec![2, 5], vec![3, 6]]; + assert_eq!(expected, transposed); + Ok(()) + } } /// Write in Arrow IPC format. diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs index f7005d113e306..3dc0c6d337cc3 100644 --- a/datafusion/core/src/physical_plan/repartition.rs +++ b/datafusion/core/src/physical_plan/repartition.rs @@ -298,12 +298,7 @@ impl ExecutionPlan for RepartitionExec { fn maintains_input_order(&self) -> bool { // We preserve ordering when input partitioning is 1 - let n_input = match self.input().output_partitioning() { - Partitioning::RoundRobinBatch(n) => n, - Partitioning::Hash(_, n) => n, - Partitioning::UnknownPartitioning(n) => n, - }; - n_input <= 1 + self.input().output_partitioning().partition_count() <= 1 } fn equivalence_properties(&self) -> EquivalenceProperties { diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index c709fe4942571..d1ea0af69ad1a 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -19,6 +19,7 @@ use crate::error::Result; use crate::execution::context::TaskContext; +use crate::physical_plan::common::transpose; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, @@ -37,7 +38,7 @@ use arrow::{ error::{ArrowError, Result as ArrowResult}, record_batch::RecordBatch, }; -use datafusion_common::{transpose, DataFusionError}; +use datafusion_common::DataFusionError; use datafusion_physical_expr::rewrite::TreeNodeRewritable; use datafusion_physical_expr::EquivalentClass; use futures::stream::Stream; @@ -136,15 +137,16 @@ impl WindowAggExec { self.input_schema.clone() } - /// Get partition keys + /// Return the output sort order of partition keys: For example + /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a + // We are sure that partition by columns are always at the beginning of sort_keys + // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely + // to calculate partition separation points pub fn partition_by_sort_keys(&self) -> Result> { let mut result = vec![]; // All window exprs have the same partition by, so we just use the first one: let partition_by = self.window_expr()[0].partition_by(); - let sort_keys = self - .sort_keys - .as_ref() - .map_or_else(|| &[] as &[PhysicalSortExpr], |v| v.as_slice()); + let sort_keys = self.sort_keys.as_deref().unwrap_or(&[]); for item in partition_by { if let Some(a) = sort_keys.iter().find(|&e| e.expr.eq(item)) { result.push(a.clone()); diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index aa25ecb05a0dd..100ea8e1ded10 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -114,7 +114,9 @@ impl WindowFrame { } } - /// Get reversed window frame + /// Get reversed window frame. For example + /// `3 ROWS PRECEDING AND 2 ROWS FOLLOWING` --> + /// `2 ROWS PRECEDING AND 3 ROWS FOLLOWING` pub fn reverse(&self) -> Self { let start_bound = match &self.end_bound { WindowFrameBound::Preceding(elem) => { diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 85352189bf53b..947336596292c 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -104,10 +104,10 @@ pub trait AggregateExpr: Send + Sync + Debug { } /// Construct an expression that calculates the aggregate in reverse. + /// Typically the "reverse" expression is itself (e.g. SUM, COUNT). + /// For aggregates that do not support calculation in reverse, + /// returns None (which is the default value). fn reverse_expr(&self) -> Option> { - // Typically the "reverse" expression is itself (e.g. SUM, COUNT). - // For aggregates that do not support calculation in reverse, - // returns None (which is the default value). None } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 357a9a12e9b2e..5c46f38f220ff 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -143,25 +143,23 @@ impl WindowExpr for AggregateWindowExpr { } fn get_reverse_expr(&self) -> Option> { - if let Some(reverse_expr) = self.aggregate.reverse_expr() { + self.aggregate.reverse_expr().map(|reverse_expr| { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { - Some(Arc::new(AggregateWindowExpr::new( + Arc::new(AggregateWindowExpr::new( reverse_expr, &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), - ))) + )) as _ } else { - Some(Arc::new(SlidingAggregateWindowExpr::new( + Arc::new(SlidingAggregateWindowExpr::new( reverse_expr, &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), - ))) + )) as _ } - } else { - None - } + }) } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index a8d768d0357e4..9804432b2056d 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -127,15 +127,13 @@ impl WindowExpr for BuiltInWindowExpr { } fn get_reverse_expr(&self) -> Option> { - if let Some(reverse_expr) = self.expr.reverse_expr() { - Some(Arc::new(BuiltInWindowExpr::new( + self.expr.reverse_expr().map(|reverse_expr| { + Arc::new(BuiltInWindowExpr::new( reverse_expr, &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), - ))) - } else { - None - } + )) as _ + }) } } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index e7553e566fa2b..c358403fefdac 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -59,7 +59,8 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Create built-in window evaluator with a batch fn create_evaluator(&self) -> Result>; - /// Construct Reverse Expression + /// Construct Reverse Expression that produces the same result + /// on a reversed window. For example `lead(10)` --> `lag(10)` fn reverse_expr(&self) -> Option> { None } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 6e0fa8b1476d5..2a0fa86b7fe33 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -151,25 +151,23 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn get_reverse_expr(&self) -> Option> { - if let Some(reverse_expr) = self.aggregate.reverse_expr() { + self.aggregate.reverse_expr().map(|reverse_expr| { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { - Some(Arc::new(AggregateWindowExpr::new( + Arc::new(AggregateWindowExpr::new( reverse_expr, &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), - ))) + )) as _ } else { - Some(Arc::new(SlidingAggregateWindowExpr::new( + Arc::new(SlidingAggregateWindowExpr::new( reverse_expr, &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), - ))) + )) as _ } - } else { - None - } + }) } } From 572a1a4271651ec0f5223389c9d1ab694e476c1f Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 23 Dec 2022 13:23:39 +0300 Subject: [PATCH 16/17] address reviews --- datafusion/expr/src/utils.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 987ea4ade8204..ca06dfdb4aafe 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -204,6 +204,8 @@ pub fn expand_qualified_wildcard( expand_wildcard(&qualifier_schema, plan) } +/// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)") +/// if bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column type WindowSortKey = Vec<(Expr, bool)>; /// Generate a sort key for a given window expr's partition_by and order_bu expr From ef92d46b2a98358afe61927848c1d853e81761a1 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Mon, 26 Dec 2022 11:02:35 -0500 Subject: [PATCH 17/17] Rename to less confusing OptimizeSorts --- datafusion/core/src/execution/context.rs | 4 +-- datafusion/core/src/physical_optimizer/mod.rs | 2 +- ...unnecessary_sorts.rs => optimize_sorts.rs} | 34 +++++++++---------- 3 files changed, 20 insertions(+), 20 deletions(-) rename datafusion/core/src/physical_optimizer/{remove_unnecessary_sorts.rs => optimize_sorts.rs} (97%) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 6d82a84ba7c57..978bde2a2ed8b 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -100,7 +100,7 @@ use url::Url; use crate::catalog::listing_schema::ListingSchemaProvider; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::memory_pool::MemoryPool; -use crate::physical_optimizer::remove_unnecessary_sorts::RemoveUnnecessarySorts; +use crate::physical_optimizer::optimize_sorts::OptimizeSorts; use uuid::Uuid; use super::options::{ @@ -1585,7 +1585,7 @@ impl SessionState { // However, a deeper analysis may sometimes reveal that such a `SortExec` is actually unnecessary. // These cases typically arise when we have reversible `WindowAggExec`s or deep subqueries. The // rule below performs this analysis and removes unnecessary `SortExec`s. - physical_optimizers.push(Arc::new(RemoveUnnecessarySorts::new())); + physical_optimizers.push(Arc::new(OptimizeSorts::new())); let mut this = SessionState { session_id, diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index a69aa16c343bd..0fd0600fbe678 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -22,9 +22,9 @@ pub mod aggregate_statistics; pub mod coalesce_batches; pub mod enforcement; pub mod join_selection; +pub mod optimize_sorts; pub mod optimizer; pub mod pruning; -pub mod remove_unnecessary_sorts; pub mod repartition; mod utils; diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/optimize_sorts.rs similarity index 97% rename from datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs rename to datafusion/core/src/physical_optimizer/optimize_sorts.rs index 593c31ded9a8b..cb421b7b82fdd 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/optimize_sorts.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! RemoveUnnecessarySorts optimizer rule inspects SortExec's in the given -//! physical plan and removes the ones it can prove unnecessary. The rule can -//! work on valid *and* invalid physical plans with respect to sorting -//! requirements, but always produces a valid physical plan in this sense. +//! OptimizeSorts optimizer rule inspects [SortExec]s in the given physical +//! plan and removes the ones it can prove unnecessary. The rule can work on +//! valid *and* invalid physical plans with respect to sorting requirements, +//! but always produces a valid physical plan in this sense. //! //! A non-realistic but easy to follow example: Assume that we somehow get the fragment //! "SortExec: [nullable_col@0 ASC]", @@ -45,17 +45,17 @@ use std::sync::Arc; /// This rule inspects SortExec's in the given physical plan and removes the /// ones it can prove unnecessary. #[derive(Default)] -pub struct RemoveUnnecessarySorts {} +pub struct OptimizeSorts {} -impl RemoveUnnecessarySorts { +impl OptimizeSorts { #[allow(missing_docs)] pub fn new() -> Self { Self {} } } -/// This is a "data class" we use within the [RemoveUnnecessarySorts] rule -/// that tracks the closest `SortExec` descendant for every child of a plan. +/// This is a "data class" we use within the [OptimizeSorts] rule that +/// tracks the closest `SortExec` descendant for every child of a plan. #[derive(Debug, Clone)] struct PlanWithCorrespondingSort { plan: Arc, @@ -118,7 +118,7 @@ impl TreeNodeRewritable for PlanWithCorrespondingSort { } } -impl PhysicalOptimizerRule for RemoveUnnecessarySorts { +impl PhysicalOptimizerRule for OptimizeSorts { fn optimize( &self, plan: Arc, @@ -126,12 +126,12 @@ impl PhysicalOptimizerRule for RemoveUnnecessarySorts { ) -> Result> { // Execute a post-order traversal to adjust input key ordering: let plan_requirements = PlanWithCorrespondingSort::new(plan); - let adjusted = plan_requirements.transform_up(&remove_unnecessary_sorts)?; + let adjusted = plan_requirements.transform_up(&optimize_sorts)?; Ok(adjusted.plan) } fn name(&self) -> &str { - "RemoveUnnecessarySorts" + "OptimizeSorts" } fn schema_check(&self) -> bool { @@ -139,7 +139,7 @@ impl PhysicalOptimizerRule for RemoveUnnecessarySorts { } } -fn remove_unnecessary_sorts( +fn optimize_sorts( requirements: PlanWithCorrespondingSort, ) -> Result> { // Perform naive analysis at the beginning -- remove already-satisfied sorts: @@ -589,7 +589,7 @@ mod tests { expected, actual ); let optimized_physical_plan = - RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + OptimizeSorts::new().optimize(physical_plan, &conf)?; let formatted = displayable(optimized_physical_plan.as_ref()) .indent() .to_string(); @@ -690,7 +690,7 @@ mod tests { expected, actual ); let optimized_physical_plan = - RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + OptimizeSorts::new().optimize(physical_plan, &conf)?; let formatted = displayable(optimized_physical_plan.as_ref()) .indent() .to_string(); @@ -736,7 +736,7 @@ mod tests { expected, actual ); let optimized_physical_plan = - RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + OptimizeSorts::new().optimize(physical_plan, &conf)?; let formatted = displayable(optimized_physical_plan.as_ref()) .indent() .to_string(); @@ -803,7 +803,7 @@ mod tests { expected, actual ); let optimized_physical_plan = - RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + OptimizeSorts::new().optimize(physical_plan, &conf)?; let formatted = displayable(optimized_physical_plan.as_ref()) .indent() .to_string(); @@ -865,7 +865,7 @@ mod tests { expected, actual ); let optimized_physical_plan = - RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + OptimizeSorts::new().optimize(physical_plan, &conf)?; let formatted = displayable(optimized_physical_plan.as_ref()) .indent() .to_string();