From 56db3139c2957c1f8c7f114ca222f1199002eb3b Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 13 Dec 2022 16:38:27 +0300 Subject: [PATCH 01/50] Sort Removal rule initial commit --- datafusion/common/src/lib.rs | 10 + datafusion/core/src/execution/context.rs | 6 + datafusion/core/src/physical_optimizer/mod.rs | 1 + .../remove_unnecessary_sorts.rs | 871 ++++++++++++++++++ .../core/src/physical_optimizer/utils.rs | 87 ++ datafusion/core/tests/sql/explain_analyze.rs | 5 - datafusion/core/tests/sql/window.rs | 529 +++++++++++ datafusion/expr/src/window_frame.rs | 27 + .../physical-expr/src/aggregate/count.rs | 10 +- datafusion/physical-expr/src/aggregate/mod.rs | 15 + datafusion/physical-expr/src/aggregate/sum.rs | 10 +- .../physical-expr/src/window/aggregate.rs | 14 + .../physical-expr/src/window/built_in.rs | 14 + .../window/built_in_window_function_expr.rs | 16 +- .../physical-expr/src/window/lead_lag.rs | 14 + .../physical-expr/src/window/nth_value.rs | 25 + .../physical-expr/src/window/window_expr.rs | 23 +- 17 files changed, 1667 insertions(+), 10 deletions(-) create mode 100644 datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 60d69324913ba..c6812ed24f8ae 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -30,6 +30,7 @@ pub mod stats; mod table_reference; pub mod test_util; +use arrow::compute::SortOptions; pub use column::Column; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; pub use error::{field_not_found, DataFusionError, Result, SchemaError}; @@ -63,3 +64,12 @@ macro_rules! downcast_value { })? }}; } + +/// Compute the "reverse" of given `SortOptions`. +// TODO: If/when arrow supports `!` for `SortOptions`, we can remove this. 
+pub fn reverse_sort_options(options: SortOptions) -> SortOptions { + SortOptions { + descending: !options.descending, + nulls_first: !options.nulls_first, + } +} diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 759cc79a8bf81..e0ebb8e018287 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -99,6 +99,7 @@ use url::Url; use crate::catalog::listing_schema::ListingSchemaProvider; use crate::datasource::object_store::ObjectStoreUrl; +use crate::physical_optimizer::remove_unnecessary_sorts::RemoveUnnecessarySorts; use uuid::Uuid; use super::options::{ @@ -1596,6 +1597,11 @@ impl SessionState { // To make sure the SinglePartition is satisfied, run the BasicEnforcement again, originally it was the AddCoalescePartitionsExec here. physical_optimizers.push(Arc::new(BasicEnforcement::new())); + // `BasicEnforcement` stage conservatively inserts `SortExec`s before `WindowAggExec`s without + // a deep analysis of window frames. Such analysis may sometimes reveal that a `SortExec` is + // actually unnecessary. The rule below performs this analysis and removes such `SortExec`s. 
+ physical_optimizers.push(Arc::new(RemoveUnnecessarySorts::new())); + SessionState { session_id, optimizer: Optimizer::new(&optimizer_config), diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 36b00a0e01bcd..a69aa16c343bd 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -24,6 +24,7 @@ pub mod enforcement; pub mod join_selection; pub mod optimizer; pub mod pruning; +pub mod remove_unnecessary_sorts; pub mod repartition; mod utils; diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs new file mode 100644 index 0000000000000..9c2e071dd8b84 --- /dev/null +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -0,0 +1,871 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The Remove Unnecessary Sorts optimizer rule removes unnecessary `SortExec`s inserted into the +//! physical plan. It produces a valid physical plan (in terms of sorting requirements). Its input can be either +//! 
valid, or invalid physical plans (in terms of Sorting requirement) +use crate::error::Result; +use crate::physical_optimizer::utils::{ + add_sort_above_child, ordering_satisfy, ordering_satisfy_inner, +}; +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::rewrite::TreeNodeRewritable; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::windows::WindowAggExec; +use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use crate::prelude::SessionConfig; +use arrow::datatypes::SchemaRef; +use datafusion_common::{reverse_sort_options, DataFusionError}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use std::iter::zip; +use std::sync::Arc; + +/// As an example, assume we get +/// "SortExec: [nullable_col@0 ASC]", +/// " SortExec: [non_nullable_col@1 ASC]", somehow in the physical plan. +/// The inner Sort (the one executed first) is unnecessary, since its result would be overwritten by the outer SortExec. We +/// remove the inner Sort from the physical plan. +#[derive(Default)] +pub struct RemoveUnnecessarySorts {} + +impl RemoveUnnecessarySorts { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for RemoveUnnecessarySorts { + fn optimize( + &self, + plan: Arc, + _config: &SessionConfig, + ) -> Result> { + // Run a bottom-up process to adjust input key ordering recursively + let plan_requirements = PlanWithCorrespondingSort::new(plan); + let adjusted = plan_requirements.transform_up(&remove_unnecessary_sorts)?; + Ok(adjusted.plan) + } + + fn name(&self) -> &str { + "RemoveUnnecessarySorts" + } + + fn schema_check(&self) -> bool { + true + } +} + +fn remove_unnecessary_sorts( + requirements: PlanWithCorrespondingSort, +) -> Result> { + let mut new_children = requirements.plan.children().clone(); + let mut new_sort_onwards = requirements.sort_onwards.clone(); + for (idx, (child, sort_onward)) in new_children + .iter_mut() + .zip(new_sort_onwards.iter_mut()) + 
.enumerate() + { + let required_ordering = requirements.plan.required_input_ordering()[idx]; + let physical_ordering = child.output_ordering(); + match (required_ordering, physical_ordering) { + (Some(required_ordering), Some(physical_ordering)) => { + let is_ordering_satisfied = + ordering_satisfy_inner(physical_ordering, required_ordering, || { + child.equivalence_properties() + }); + if is_ordering_satisfied { + // can do analysis for sort removal + if !sort_onward.is_empty() { + let (_, sort_any) = sort_onward[0].clone(); + let sort_exec = sort_any + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Plan( + "First layer should start from SortExec".to_string(), + ) + })?; + let sort_output_ordering = sort_exec.output_ordering(); + let sort_input_ordering = sort_exec.input().output_ordering(); + // Do naive analysis, where a SortExec is already sorted according to desired Sorting + if ordering_satisfy( + sort_input_ordering, + sort_output_ordering, + || sort_exec.input().equivalence_properties(), + ) { + update_child_to_remove_unnecessary_sort(child, sort_onward)?; + } else if let Some(window_agg_exec) = + requirements.plan.as_any().downcast_ref::() + { + // For window expressions we can remove some Sorts when expression can be calculated in reverse order also. + if let Some(res) = + analyze_window_sort_removal(window_agg_exec, sort_onward)? + { + return Ok(Some(res)); + } + } + } + } else { + // During sort Removal we have invalidated ordering invariant fix it + // This is effectively moving sort above in the physical plan + update_child_to_remove_unnecessary_sort(child, sort_onward)?; + let sort_expr = required_ordering.to_vec(); + *child = add_sort_above_child(child, sort_expr)?; + // Since we have added Sort, we add it to the sort_onwards also. + sort_onward.push((idx, child.clone())) + } + } + (Some(required), None) => { + // Requirement is not satisfied We should add Sort to the plan. 
+ let sort_expr = required.to_vec(); + *child = add_sort_above_child(child, sort_expr)?; + *sort_onward = vec![(idx, child.clone())]; + } + (None, Some(_)) => { + // Sort doesn't propagate to the layers above in the physical plan + if !requirements.plan.maintains_input_order() { + // Unnecessary Sort is added to the plan, we can remove unnecessary sort + update_child_to_remove_unnecessary_sort(child, sort_onward)?; + } + } + (None, None) => {} + } + } + if !requirements.plan.children().is_empty() { + let new_plan = requirements.plan.with_new_children(new_children)?; + for (idx, new_sort_onward) in new_sort_onwards + .iter_mut() + .enumerate() + .take(new_plan.children().len()) + { + let is_require_ordering = new_plan.required_input_ordering()[idx].is_none(); + //TODO: when `maintains_input_order` returns `Vec` use corresponding index + if new_plan.maintains_input_order() + && is_require_ordering + && !new_sort_onward.is_empty() + { + new_sort_onward.push((idx, new_plan.clone())); + } else if new_plan.as_any().is::() { + new_sort_onward.clear(); + new_sort_onward.push((idx, new_plan.clone())); + } else { + new_sort_onward.clear(); + } + } + Ok(Some(PlanWithCorrespondingSort { + plan: new_plan, + sort_onwards: new_sort_onwards, + })) + } else { + Ok(Some(requirements)) + } +} + +#[derive(Debug, Clone)] +struct PlanWithCorrespondingSort { + plan: Arc, + // For each child keeps a vector of `ExecutionPlan`s starting from SortExec till current plan + // first index of tuple(usize) is child index of plan (we need during updating plan above) + sort_onwards: Vec)>>, +} + +impl PlanWithCorrespondingSort { + pub fn new(plan: Arc) -> Self { + let children_len = plan.children().len(); + PlanWithCorrespondingSort { + plan, + sort_onwards: vec![vec![]; children_len], + } + } + + pub fn children(&self) -> Vec { + let plan_children = self.plan.children(); + plan_children + .into_iter() + .map(|child| { + let length = child.children().len(); + PlanWithCorrespondingSort { + plan: 
child, + sort_onwards: vec![vec![]; length], + } + }) + .collect() + } +} + +impl TreeNodeRewritable for PlanWithCorrespondingSort { + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if !children.is_empty() { + let new_children: Result> = + children.into_iter().map(transform).collect(); + let children_requirements = new_children?; + let children_plans = children_requirements + .iter() + .map(|elem| elem.plan.clone()) + .collect::>(); + let sort_onwards = children_requirements + .iter() + .map(|elem| { + if !elem.sort_onwards.is_empty() { + // TODO: redirect the true sort onwards to above (the one we keep ordering) + // this is possible when maintains_input_order returns vec + elem.sort_onwards[0].clone() + } else { + vec![] + } + }) + .collect::>(); + let plan = with_new_children_if_necessary(self.plan, children_plans)?; + Ok(PlanWithCorrespondingSort { plan, sort_onwards }) + } else { + Ok(self) + } + } +} + +/// Analyzes `WindowAggExec` to determine whether Sort can be removed +fn analyze_window_sort_removal( + window_agg_exec: &WindowAggExec, + sort_onward: &mut Vec<(usize, Arc)>, +) -> Result> { + // If empty immediately return we cannot do analysis + if sort_onward.is_empty() { + return Ok(None); + } + let (_, sort_any) = sort_onward[0].clone(); + let sort_exec = sort_any + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Plan("First layer should start from SortExec".to_string()) + })?; + let required_ordering = sort_exec.output_ordering().ok_or_else(|| { + DataFusionError::Plan("SortExec should have output ordering".to_string()) + })?; + let physical_ordering = sort_exec.input().output_ordering(); + let physical_ordering = if let Some(physical_ordering) = physical_ordering { + physical_ordering + } else { + // If there is no physical ordering, there is no way to remove Sorting, immediately return + return Ok(None); + }; + let window_expr = 
window_agg_exec.window_expr(); + let partition_keys = window_expr[0].partition_by().to_vec(); + let (can_skip_sorting, should_reverse) = can_skip_sort( + &partition_keys, + required_ordering, + &sort_exec.input().schema(), + physical_ordering, + )?; + let all_window_fns_reversible = + window_expr.iter().all(|e| e.is_window_fn_reversible()); + let is_reversal_blocking = should_reverse && !all_window_fns_reversible; + + if can_skip_sorting && !is_reversal_blocking { + let window_expr = if should_reverse { + window_expr + .iter() + .map(|e| e.get_reversed_expr()) + .collect::>>()? + } else { + window_expr.to_vec() + }; + let new_child = remove_corresponding_sort_from_sub_plan(sort_onward)?; + let new_schema = new_child.schema(); + let new_plan = Arc::new(WindowAggExec::try_new( + window_expr, + new_child, + new_schema, + window_agg_exec.partition_keys.clone(), + Some(physical_ordering.to_vec()), + )?); + Ok(Some(PlanWithCorrespondingSort::new(new_plan))) + } else { + Ok(None) + } +} +/// Updates child such that unnecessary sorting below it is removed +fn update_child_to_remove_unnecessary_sort( + child: &mut Arc, + sort_onwards: &mut Vec<(usize, Arc)>, +) -> Result<()> { + if !sort_onwards.is_empty() { + *child = remove_corresponding_sort_from_sub_plan(sort_onwards)?; + } + Ok(()) +} +/// Removes the sort from the plan in the `sort_onwards` +fn remove_corresponding_sort_from_sub_plan( + sort_onwards: &mut Vec<(usize, Arc)>, +) -> Result> { + let (sort_child_idx, sort_any) = sort_onwards[0].clone(); + let sort_exec = sort_any + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Plan("First layer should start from SortExec".to_string()) + })?; + let mut prev_layer = sort_exec.input().clone(); + let mut prev_layer_child_idx = sort_child_idx; + // We start from 1 hence since first one is sort and we are removing it + // from the plan + for (cur_layer_child_idx, cur_layer) in sort_onwards.iter().skip(1) { + let mut new_children = cur_layer.children(); + 
new_children[prev_layer_child_idx] = prev_layer; + prev_layer = cur_layer.clone().with_new_children(new_children)?; + prev_layer_child_idx = *cur_layer_child_idx; + } + // We have removed the corresponding sort hence empty the sort_onwards + sort_onwards.clear(); + Ok(prev_layer) +} + +#[derive(Debug)] +/// This structure stores extra column information required to remove unnecessary sorts. +pub struct ColumnInfo { + is_aligned: bool, + reverse: bool, + is_partition: bool, +} + +/// Compares physical ordering and required ordering of all `PhysicalSortExpr`s and returns a tuple. +/// The first element indicates whether these `PhysicalSortExpr`s can be removed from the physical plan. +/// The second element is a flag indicating whether we should reverse the sort direction in order to +/// remove physical sort expressions from the plan. +pub fn can_skip_sort( + partition_keys: &[Arc], + required: &[PhysicalSortExpr], + input_schema: &SchemaRef, + physical_ordering: &[PhysicalSortExpr], +) -> Result<(bool, bool)> { + if required.len() > physical_ordering.len() { + return Ok((false, false)); + } + let mut col_infos = vec![]; + for (sort_expr, physical_expr) in zip(required, physical_ordering) { + let column = sort_expr.expr.clone(); + let is_partition = partition_keys.iter().any(|e| e.eq(&column)); + let (is_aligned, reverse) = + check_alignment(input_schema, physical_expr, sort_expr); + col_infos.push(ColumnInfo { + is_aligned, + reverse, + is_partition, + }); + } + let partition_by_sections = col_infos + .iter() + .filter(|elem| elem.is_partition) + .collect::>(); + let (can_skip_partition_bys, should_reverse_partition_bys) = + if partition_by_sections.is_empty() { + (true, false) + } else { + let first_reverse = partition_by_sections[0].reverse; + let can_skip_partition_bys = partition_by_sections + .iter() + .all(|c| c.is_aligned && c.reverse == first_reverse); + (can_skip_partition_bys, first_reverse) + }; + let order_by_sections = col_infos + .iter() + 
.filter(|elem| !elem.is_partition) + .collect::>(); + let (can_skip_order_bys, should_reverse_order_bys) = if order_by_sections.is_empty() { + (true, false) + } else { + let first_reverse = order_by_sections[0].reverse; + let can_skip_order_bys = order_by_sections + .iter() + .all(|c| c.is_aligned && c.reverse == first_reverse); + (can_skip_order_bys, first_reverse) + }; + // TODO: We cannot skip partition by keys when sort direction is reversed, + // by propagating partition by sort direction to `WindowAggExec` we can skip + // these columns also. Add support for that (Use direction during partition range calculation). + let can_skip = + can_skip_order_bys && can_skip_partition_bys && !should_reverse_partition_bys; + Ok((can_skip, should_reverse_order_bys)) +} + +/// Compares `physical_ordering` and `required` ordering, returns a tuple +/// indicating (1) whether sorting on this column can be skipped, and (2) whether we +/// should reverse the window expression in order to avoid sorting. +fn check_alignment( + input_schema: &SchemaRef, + physical_ordering: &PhysicalSortExpr, + required: &PhysicalSortExpr, +) -> (bool, bool) { + if required.expr.eq(&physical_ordering.expr) { + let nullable = required.expr.nullable(input_schema).unwrap(); + let physical_opts = physical_ordering.options; + let required_opts = required.options; + let is_reversed = if nullable { + physical_opts == reverse_sort_options(required_opts) + } else { + // If the column is not nullable, NULLS FIRST/LAST is not important. 
+ physical_opts.descending != required_opts.descending + }; + let can_skip = !nullable || is_reversed || (physical_opts == required_opts); + (can_skip, is_reversed) + } else { + (false, false) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::physical_plan::displayable; + use crate::physical_plan::filter::FilterExec; + use crate::physical_plan::memory::MemoryExec; + use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; + use crate::physical_plan::windows::create_window_expr; + use crate::prelude::SessionContext; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::Result; + use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; + use datafusion_physical_expr::expressions::{col, NotExpr}; + use datafusion_physical_expr::PhysicalSortExpr; + use std::sync::Arc; + + fn create_test_schema() -> Result { + let nullable_column = Field::new("nullable_col", DataType::Int32, true); + let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![nullable_column, non_nullable_column])); + + Ok(schema) + } + + #[tokio::test] + async fn test_is_column_aligned_nullable() -> Result<()> { + let schema = create_test_schema()?; + let params = vec![ + ((true, true), (false, false), (true, true)), + ((true, true), (false, true), (false, false)), + ((true, true), (true, false), (false, false)), + ((true, false), (false, true), (true, true)), + ((true, false), (false, false), (false, false)), + ((true, false), (true, true), (false, false)), + ]; + for ( + (physical_desc, physical_nulls_first), + (req_desc, req_nulls_first), + (is_aligned_expected, reverse_expected), + ) in params + { + let physical_ordering = PhysicalSortExpr { + expr: col("nullable_col", &schema)?, + options: SortOptions { + descending: physical_desc, + nulls_first: physical_nulls_first, + }, + }; + let required_ordering = 
PhysicalSortExpr { + expr: col("nullable_col", &schema)?, + options: SortOptions { + descending: req_desc, + nulls_first: req_nulls_first, + }, + }; + let (is_aligned, reverse) = + check_alignment(&schema, &physical_ordering, &required_ordering); + assert_eq!(is_aligned, is_aligned_expected); + assert_eq!(reverse, reverse_expected); + } + + Ok(()) + } + + #[tokio::test] + async fn test_is_column_aligned_non_nullable() -> Result<()> { + let schema = create_test_schema()?; + + let params = vec![ + ((true, true), (false, false), (true, true)), + ((true, true), (false, true), (true, true)), + ((true, true), (true, false), (true, false)), + ((true, false), (false, true), (true, true)), + ((true, false), (false, false), (true, true)), + ((true, false), (true, true), (true, false)), + ]; + for ( + (physical_desc, physical_nulls_first), + (req_desc, req_nulls_first), + (is_aligned_expected, reverse_expected), + ) in params + { + let physical_ordering = PhysicalSortExpr { + expr: col("non_nullable_col", &schema)?, + options: SortOptions { + descending: physical_desc, + nulls_first: physical_nulls_first, + }, + }; + let required_ordering = PhysicalSortExpr { + expr: col("non_nullable_col", &schema)?, + options: SortOptions { + descending: req_desc, + nulls_first: req_nulls_first, + }, + }; + let (is_aligned, reverse) = + check_alignment(&schema, &physical_ordering, &required_ordering); + assert_eq!(is_aligned, is_aligned_expected); + assert_eq!(reverse, reverse_expected); + } + + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) 
+ as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("non_nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let sort_exec = Arc::new(SortExec::try_new(sort_exprs, source, None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let physical_plan = Arc::new(SortExec::try_new(sort_exprs, sort_exec, None)?) + as Arc; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "SortExec: [nullable_col@0 ASC]", + " SortExec: [non_nullable_col@1 ASC]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { vec!["SortExec: [nullable_col@0 ASC]"] }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) 
+ as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("non_nullable_col", source.schema().as_ref()).unwrap(), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, None)?) + as Arc; + let window_agg_exec = Arc::new(WindowAggExec::try_new( + vec![create_window_expr( + &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), + &[col("non_nullable_col", &schema)?], + &[], + &sort_exprs, + Arc::new(WindowFrame::new(true)), + schema.as_ref(), + )?], + sort_exec.clone(), + sort_exec.schema(), + vec![], + Some(sort_exprs), + )?) as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("non_nullable_col", window_agg_exec.schema().as_ref()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let sort_exec = Arc::new(SortExec::try_new( + sort_exprs.clone(), + window_agg_exec, + None, + )?) as Arc; + // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before + let filter_exec = Arc::new(FilterExec::try_new( + Arc::new(NotExpr::new( + col("non_nullable_col", schema.as_ref()).unwrap(), + )), + sort_exec, + )?) as Arc; + // let filter_exec = sort_exec; + let window_agg_exec = Arc::new(WindowAggExec::try_new( + vec![create_window_expr( + &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), + &[col("non_nullable_col", &schema)?], + &[], + &sort_exprs, + Arc::new(WindowFrame::new(true)), + schema.as_ref(), + )?], + filter_exec.clone(), + filter_exec.schema(), + vec![], + Some(sort_exprs), + )?) 
as Arc; + let physical_plan = window_agg_exec; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", + " FilterExec: NOT non_nullable_col@1", + " SortExec: [non_nullable_col@2 ASC NULLS LAST]", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", + " SortExec: [non_nullable_col@1 DESC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]", + " FilterExec: NOT non_nullable_col@1", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", + " SortExec: [non_nullable_col@1 DESC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + 
Ok(()) + } + + #[tokio::test] + async fn test_add_required_sort() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let physical_plan = Arc::new(SortPreservingMergeExec::new(sort_exprs, source)) + as Arc; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { vec!["SortPreservingMergeExec: [nullable_col@0 ASC]"] }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort1() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) 
+ as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, None)?) + as Arc; + let sort_preserving_merge_exec = + Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let sort_exec = Arc::new(SortExec::try_new( + sort_exprs.clone(), + sort_preserving_merge_exec, + None, + )?) as Arc; + let sort_preserving_merge_exec = + Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) + as Arc; + let physical_plan = sort_preserving_merge_exec; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } + + #[tokio::test] + async fn test_change_wrong_sorting() -> Result<()> { + let session_ctx = SessionContext::new(); + let 
conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![ + PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: col("non_nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }, + ]; + let sort_exec = Arc::new(SortExec::try_new( + vec![sort_exprs[0].clone()], + source, + None, + )?) as Arc; + let sort_preserving_merge_exec = + Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) + as Arc; + let physical_plan = sort_preserving_merge_exec; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 4aceb776d7d5b..7389da085d600 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ 
-21,7 +21,12 @@ use super::optimizer::PhysicalOptimizerRule; use crate::execution::context::SessionConfig; use crate::error::Result; +use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use datafusion_physical_expr::{ + normalize_sort_expr_with_equivalence_properties, EquivalenceProperties, + PhysicalSortExpr, +}; use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke @@ -45,3 +50,85 @@ pub fn optimize_children( with_new_children_if_necessary(plan, children) } } + +/// Check the required ordering requirements are satisfied by the provided PhysicalSortExprs. +pub fn ordering_satisfy EquivalenceProperties>( + provided: Option<&[PhysicalSortExpr]>, + required: Option<&[PhysicalSortExpr]>, + equal_properties: F, +) -> bool { + match (provided, required) { + (_, None) => true, + (None, Some(_)) => false, + (Some(provided), Some(required)) => { + ordering_satisfy_inner(provided, required, equal_properties) + } + } +} + +pub fn ordering_satisfy_inner EquivalenceProperties>( + provided: &[PhysicalSortExpr], + required: &[PhysicalSortExpr], + equal_properties: F, +) -> bool { + if required.len() > provided.len() { + false + } else { + let fast_match = required + .iter() + .zip(provided.iter()) + .all(|(order1, order2)| order1.eq(order2)); + + if !fast_match { + let eq_properties = equal_properties(); + let eq_classes = eq_properties.classes(); + if !eq_classes.is_empty() { + let normalized_required_exprs = required + .iter() + .map(|e| { + normalize_sort_expr_with_equivalence_properties( + e.clone(), + eq_classes, + ) + }) + .collect::>(); + let normalized_provided_exprs = provided + .iter() + .map(|e| { + normalize_sort_expr_with_equivalence_properties( + e.clone(), + eq_classes, + ) + }) + .collect::>(); + normalized_required_exprs + .iter() + .zip(normalized_provided_exprs.iter()) + .all(|(order1, order2)| order1.eq(order2)) + } else { + fast_match + } + } else { + 
fast_match + } + } +} + +/// Util function to add SortExec above child +pub fn add_sort_above_child( + child: &Arc, + sort_expr: Vec, +) -> Result> { + let new_child = if child.output_partitioning().partition_count() > 1 { + Arc::new(SortExec::new_with_partitioning( + sort_expr, + child.clone(), + true, + None, + )) as Arc + } else { + Arc::new(SortExec::try_new(sort_expr, child.clone(), None)?) + as Arc + }; + Ok(new_child) +} diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 89112adae74a6..98ccb79c62634 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -62,11 +62,6 @@ async fn explain_analyze_baseline_metrics() { "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", "metrics=[output_rows=5, elapsed_compute=" ); - assert_metrics!( - &formatted, - "SortExec: [c1@0 ASC NULLS LAST]", - "metrics=[output_rows=5, elapsed_compute=" - ); assert_metrics!( &formatted, "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index b550e7f5dd60b..3a44493aaab6d 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1646,3 +1646,532 @@ async fn test_window_agg_sort() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn test_window_agg_sort_reversed_plan() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + 
let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@0 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+-------------+-------------+", + "| c9 | sum1 | sum2 |", + "+------------+-------------+-------------+", + "| 4268716378 | 8498370520 | 24997484146 |", + "| 4229654142 | 12714811027 | 29012926487 |", + "| 4216440507 | 16858984380 | 28743001064 |", + "| 4144173353 | 20935849039 | 28472563256 |", + "| 4076864659 | 24997484146 | 28118515915 |", + "+------------+-------------+-------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_reversed_plan_builtin() -> Result<()> { + 
let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + FIRST_VALUE(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as fv1, + FIRST_VALUE(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as fv2, + LAG(c9, 2, 10101) OVER(ORDER BY c9 ASC) as lag1, + LAG(c9, 2, 10101) OVER(ORDER BY c9 DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lag2, + LEAD(c9, 2, 10101) OVER(ORDER BY c9 ASC) as lead1, + LEAD(c9, 2, 10101) OVER(ORDER BY c9 DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lead2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c9@6 as c9, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lead2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: 
wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }]", + " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }]", + " SortExec: [c9@0 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let 
actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+------------+------------+------------+------------+------------+------------+", + "| c9 | fv1 | fv2 | lag1 | lag2 | lead1 | lead2 |", + "+------------+------------+------------+------------+------------+------------+------------+", + "| 4268716378 | 4229654142 | 4268716378 | 4216440507 | 10101 | 10101 | 4216440507 |", + "| 4229654142 | 4216440507 | 4268716378 | 4144173353 | 10101 | 10101 | 4144173353 |", + "| 4216440507 | 4144173353 | 4229654142 | 4076864659 | 4268716378 | 4268716378 | 4076864659 |", + "| 4144173353 | 4076864659 | 4216440507 | 4061635107 | 4229654142 | 4229654142 | 4061635107 |", + "| 4076864659 | 4061635107 | 4144173353 | 4015442341 | 4216440507 | 4216440507 | 4015442341 |", + "+------------+------------+------------+------------+------------+------------+------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_non_reversed_plan() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + ROW_NUMBER() OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as rn1, + ROW_NUMBER() OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as rn2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // We cannot reverse each window function (ROW_NUMBER is not reversible) + let expected = { + vec![ + "ProjectionExec: expr=[c9@2 as c9, 
ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as rn1, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@1 ASC NULLS LAST]", + " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@0 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------+-----+-----+", + "| c9 | rn1 | rn2 |", + "+-----------+-----+-----+", + "| 28774375 | 1 | 100 |", + "| 63044568 | 2 | 99 |", + "| 141047417 | 3 | 98 |", + "| 141680161 | 4 | 97 |", + "| 145294611 | 5 | 96 |", + "+-----------+-----+-----+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC, c1 ASC, c2 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(ORDER BY c9 DESC, c1 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2, + ROW_NUMBER() OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 
FOLLOWING) as rn2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // We cannot reverse each window function (ROW_NUMBER is not reversible) + let expected = { + vec![ + "ProjectionExec: expr=[c9@5 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as rn2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@4 DESC]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@2 ASC NULLS LAST,c1@0 ASC NULLS 
LAST,c2@1 ASC NULLS LAST]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+-------------+-------------+-----+", + "| c9 | sum1 | sum2 | rn2 |", + "+------------+-------------+-------------+-----+", + "| 4268716378 | 8498370520 | 24997484146 | 1 |", + "| 4229654142 | 12714811027 | 29012926487 | 2 |", + "| 4216440507 | 16858984380 | 28743001064 | 3 |", + "| 4144173353 | 20935849039 | 28472563256 | 4 |", + "| 4076864659 | 24997484146 | 28118515915 | 5 |", + "+------------+-------------+-------------+-----+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_complex_plan() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_null_cases_csv(&ctx).await?; + let sql = "SELECT + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as a, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as b, + SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as c, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as d, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as e, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as f, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as g, + SUM(c1) OVER (ORDER BY c3) as h, + SUM(c1) OVER (ORDER BY c3 DESC) as i, + SUM(c1) OVER (ORDER BY c3 NULLS first) as j, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first) as k, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last) as l, + SUM(c1) OVER (ORDER BY c3, c2) as m, + SUM(c1) OVER (ORDER BY c3, c1 DESC) as n, + 
SUM(c1) OVER (ORDER BY c3 DESC, c1) as o, + SUM(c1) OVER (ORDER BY c3, c1 NULLs first) as p, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as a1, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as b1, + SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as c1, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as d1, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as e1, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as f1, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as g1, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as h1, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as j1, + SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as k1, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as l1, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as m1, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as n1, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as o1, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as h11, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as j11, + SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as k11, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as l11, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as m11, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as n11, 
+ SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as o11 + FROM null_cases + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Unnecessary SortExecs are removed + let expected = { + vec![ + "ProjectionExec: expr=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as a, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as b, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as c, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as d, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@12 as e, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as f, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as g, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as i, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as j, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as k, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@13 as l, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, 
null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as m, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@17 as n, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@18 as o, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as p, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as a1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as b1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as c1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as d1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@14 as e1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as f1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as g1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as j1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as k1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as l1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@13 as m1, SUM(null_cases.c1) 
ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as n1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as o1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as h11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as j11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as k11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as l11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@15 as m11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as n11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as o11]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: 
\"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@5 ASC NULLS LAST,c2@4 ASC NULLS LAST]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: 
{} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@3 DESC,c1@1 ASC NULLS LAST]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@2 ASC NULLS LAST,c1@0 ASC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> { + let config = SessionConfig::new().with_repartition_windows(false); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + SUM(c9) OVER(ORDER BY c1, c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 
PRECEDING AND 5 FOLLOWING@1 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c1@0 ASC,c9@1 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+-------------+-------------+", + "| c9 | sum1 | sum2 |", + "+------------+-------------+-------------+", + "| 4015442341 | 21907044499 | 21907044499 |", + "| 3998790955 | 24576419362 | 24576419362 |", + "| 3959216334 | 23063303501 | 23063303501 |", + "| 3717551163 | 21560567246 | 21560567246 |", + "| 3276123488 | 19815386638 | 19815386638 |", + "+------------+-------------+-------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_partitionby_reversed_plan() -> Result<()> { + let config = SessionConfig::new().with_repartition_windows(false); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2 
+ FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c1@0 ASC,c9@1 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+-------------+-------------+", + "| c9 | sum1 | sum2 |", + "+------------+-------------+-------------+", + "| 4015442341 | 8014233296 | 21907044499 |", + "| 
3998790955 | 11973449630 | 24576419362 |", + "| 3959216334 | 15691000793 | 23063303501 |", + "| 3717551163 | 18967124281 | 21560567246 |", + "| 3276123488 | 21907044499 | 19815386638 |", + "+------------+-------------+-------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_orderby_reversed_binary_expr() -> Result<()> { + let config = SessionConfig::new().with_repartition_windows(false); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c3, + SUM(c9) OVER(ORDER BY c3+c4 DESC, c9 DESC, c2 ASC) as sum1, + SUM(c9) OVER(ORDER BY c3+c4 ASC, c9 ASC ) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c3@3 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@0 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: \"SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS 
LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int16(NULL)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: \"SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow }]", + " SortExec: [CAST(c3@1 AS Int16) + c4@2 DESC,c9@3 DESC,c2@0 ASC NULLS LAST]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+-------------+--------------+", + "| c3 | sum1 | sum2 |", + "+-----+-------------+--------------+", + "| -86 | 2861911482 | 222089770060 |", + "| 13 | 5075947208 | 219227858578 |", + "| 125 | 8701233618 | 217013822852 |", + "| 123 | 11293564174 | 213388536442 |", + "| 97 | 14767488750 | 210796205886 |", + "+-----+-------------+--------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_remove_unnecessary_sort_in_sub_query() -> Result<()> { + let config = SessionConfig::new(); + let ctx = SessionContext::with_config(config); + 
register_aggregate_csv(&ctx).await?; + let sql = "SELECT count(*) as global_count FROM + (SELECT count(*), c1 + FROM aggregate_test_100 + WHERE c13 != 'C2GT5KVyOPZpgKVl110TyZO0NcJ434' + GROUP BY c1 + ORDER BY c1 ) AS a "; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Unnecessary Sort in the sub query is removed + let expected = { + vec![ + "ProjectionExec: expr=[COUNT(UInt8(1))@0 as global_count]", + " AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))]", + " CoalescePartitionsExec", + " AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]", + " RepartitionExec: partitioning=RoundRobinBatch(8)", + " CoalescePartitionsExec", + " AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8)", + " AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))]", + " CoalesceBatchesExec: target_batch_size=4096", + " FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", + " RepartitionExec: partitioning=RoundRobinBatch(8)", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+--------------+", + "| global_count |", + "+--------------+", + "| 5 |", + "+--------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs 
index 35790885e02f0..b8274ed1d1930 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -113,6 +113,33 @@ impl WindowFrame { } } } + + /// Get reversed window frame + pub fn reverse(&self) -> Self { + let start_bound = match &self.end_bound { + WindowFrameBound::Preceding(elem) => { + WindowFrameBound::Following(elem.clone()) + } + WindowFrameBound::Following(elem) => { + WindowFrameBound::Preceding(elem.clone()) + } + WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow, + }; + let end_bound = match &self.start_bound { + WindowFrameBound::Preceding(elem) => { + WindowFrameBound::Following(elem.clone()) + } + WindowFrameBound::Following(elem) => { + WindowFrameBound::Preceding(elem.clone()) + } + WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow, + }; + WindowFrame { + units: self.units, + start_bound, + end_bound, + } + } } /// There are five ways to describe starting and ending frame boundaries: diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 4721bf8f2301b..7fceaeedeab3c 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -36,7 +36,7 @@ use crate::expressions::format_state_name; /// COUNT aggregate expression /// Returns the amount of non-null values of the given expression. 
-#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Count { name: String, data_type: DataType, @@ -104,6 +104,14 @@ impl AggregateExpr for Count { ) -> Result> { Ok(Box::new(CountRowAccumulator::new(start_index))) } + + fn is_window_fn_reversible(&self) -> bool { + true + } + + fn reverse_expr(&self) -> Result> { + Ok(Arc::new(self.clone())) + } } #[derive(Debug)] diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index f6374687403ec..b0776886e69d5 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -101,4 +101,19 @@ pub trait AggregateExpr: Send + Sync + Debug { self ))) } + + /// Get whether window function is reversible + /// make `true` if `reverse_expr` is implemented + fn is_window_fn_reversible(&self) -> bool { + false + } + + /// Construct Reverse Expression + // Typically expression itself for aggregate functions + fn reverse_expr(&self) -> Result> { + Err(DataFusionError::NotImplemented(format!( + "reverse_expr hasn't been implemented for {:?} yet", + self + ))) + } } diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index b330455a1855b..11371f31d4c68 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -42,7 +42,7 @@ use arrow::compute::cast; use datafusion_row::accessor::RowAccessor; /// SUM aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Sum { name: String, data_type: DataType, @@ -132,6 +132,14 @@ impl AggregateExpr for Sum { self.data_type.clone(), ))) } + + fn is_window_fn_reversible(&self) -> bool { + true + } + + fn reverse_expr(&self) -> Result> { + Ok(Arc::new(self.clone())) + } } #[derive(Debug)] diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 52a43050b1cc8..1f0c286efc859 100644 --- 
a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -30,6 +30,7 @@ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::WindowFrame; +use crate::window::window_expr::reverse_order_bys; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use crate::{window::WindowExpr, AggregateExpr}; @@ -155,4 +156,17 @@ impl WindowExpr for AggregateWindowExpr { fn get_window_frame(&self) -> &Arc { &self.window_frame } + + fn is_window_fn_reversible(&self) -> bool { + self.aggregate.as_ref().is_window_fn_reversible() + } + + fn get_reversed_expr(&self) -> Result> { + Ok(Arc::new(AggregateWindowExpr::new( + self.aggregate.reverse_expr()?, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + ))) + } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 95bf01608b82f..8e16fc7e7a627 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -20,6 +20,7 @@ use super::window_frame_state::WindowFrameContext; use super::BuiltInWindowFunctionExpr; use super::WindowExpr; +use crate::window::window_expr::reverse_order_bys; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use arrow::compute::{concat, SortOptions}; use arrow::record_batch::RecordBatch; @@ -137,4 +138,17 @@ impl WindowExpr for BuiltInWindowExpr { fn get_window_frame(&self) -> &Arc { &self.window_frame } + + fn is_window_fn_reversible(&self) -> bool { + self.expr.as_ref().is_window_fn_reversible() + } + + fn get_reversed_expr(&self) -> Result> { + Ok(Arc::new(BuiltInWindowExpr::new( + self.expr.reverse_expr()?, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + ))) + } } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs 
b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 7f7a27435c392..71b100b54e5e4 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -20,7 +20,7 @@ use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{DataFusionError, Result}; use std::any::Any; use std::sync::Arc; @@ -58,4 +58,18 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Create built-in window evaluator with a batch fn create_evaluator(&self) -> Result>; + + /// Get whether window function is reversible + /// make true if `reverse_expr` is implemented + fn is_window_fn_reversible(&self) -> bool { + false + } + + /// Construct Reverse Expression + fn reverse_expr(&self) -> Result> { + Err(DataFusionError::NotImplemented(format!( + "reverse_expr hasn't been implemented for {:?} yet", + self + ))) + } } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index c7fc73b9f1c1f..97076963f60d4 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -107,6 +107,20 @@ impl BuiltInWindowFunctionExpr for WindowShift { default_value: self.default_value.clone(), })) } + + fn is_window_fn_reversible(&self) -> bool { + true + } + + fn reverse_expr(&self) -> Result> { + Ok(Arc::new(Self { + name: self.name.clone(), + data_type: self.data_type.clone(), + shift_offset: -self.shift_offset, + expr: self.expr.clone(), + default_value: self.default_value.clone(), + })) + } } pub(crate) struct WindowShiftEvaluator { diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 63a2354c9e4c5..2942a797b8526 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ 
b/datafusion/physical-expr/src/window/nth_value.rs @@ -123,6 +123,31 @@ impl BuiltInWindowFunctionExpr for NthValue { fn create_evaluator(&self) -> Result> { Ok(Box::new(NthValueEvaluator { kind: self.kind })) } + + fn is_window_fn_reversible(&self) -> bool { + match self.kind { + NthValueKind::First | NthValueKind::Last => true, + NthValueKind::Nth(_) => false, + } + } + + fn reverse_expr(&self) -> Result> { + let reversed_kind = match self.kind { + NthValueKind::First => NthValueKind::Last, + NthValueKind::Last => NthValueKind::First, + NthValueKind::Nth(_) => { + return Err(DataFusionError::Execution( + "Cannot take reverse of NthValue".to_string(), + )) + } + }; + Ok(Arc::new(Self { + name: self.name.clone(), + expr: self.expr.clone(), + data_type: self.data_type.clone(), + kind: reversed_kind, + })) + } } /// Value evaluator for nth_value functions diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 209e0544f2fa6..b1af6cf49a0ab 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -20,7 +20,7 @@ use arrow::compute::kernels::partition::lexicographical_partition_ranges; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{reverse_sort_options, DataFusionError, Result}; use datafusion_expr::WindowFrame; use std::any::Any; use std::fmt::Debug; @@ -129,6 +129,25 @@ pub trait WindowExpr: Send + Sync + Debug { Ok((values, order_bys)) } - // Get window frame of this WindowExpr, None if absent + // Get window frame of this WindowExpr fn get_window_frame(&self) -> &Arc; + + /// Get whether window function can be reversed + fn is_window_fn_reversible(&self) -> bool; + + /// get reversed expression + fn get_reversed_expr(&self) -> Result>; +} + +/// Reverses the 
ORDER BY expression, which is useful during equivalent window +/// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into +/// 'ORDER BY a DESC, NULLS FIRST'. +pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec { + order_bys + .iter() + .map(|e| PhysicalSortExpr { + expr: e.expr.clone(), + options: reverse_sort_options(e.options), + }) + .collect() } From 343fafbd962183bb7c8d2126b3cdcd3f545e556d Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 13 Dec 2022 16:42:53 +0300 Subject: [PATCH 02/50] move ordering satisfy to the util --- .../src/physical_optimizer/enforcement.rs | 79 ++----------------- 1 file changed, 8 insertions(+), 71 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforcement.rs b/datafusion/core/src/physical_optimizer/enforcement.rs index 3110061c4f2cb..9c7314a1b33ca 100644 --- a/datafusion/core/src/physical_optimizer/enforcement.rs +++ b/datafusion/core/src/physical_optimizer/enforcement.rs @@ -20,6 +20,7 @@ //! 
use crate::config::OPT_TOP_DOWN_JOIN_KEY_REORDERING; use crate::error::Result; +use crate::physical_optimizer::utils::{add_sort_above_child, ordering_satisfy}; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -29,7 +30,6 @@ use crate::physical_plan::joins::{ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::rewrite::TreeNodeRewritable; -use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort::SortOptions; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::Partitioning; @@ -41,9 +41,8 @@ use datafusion_physical_expr::equivalence::EquivalenceProperties; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::expressions::NoOp; use datafusion_physical_expr::{ - expr_list_eq_strict_order, normalize_expr_with_equivalence_properties, - normalize_sort_expr_with_equivalence_properties, AggregateExpr, PhysicalExpr, - PhysicalSortExpr, + expr_list_eq_strict_order, normalize_expr_with_equivalence_properties, AggregateExpr, + PhysicalExpr, }; use std::collections::HashMap; use std::sync::Arc; @@ -885,14 +884,7 @@ fn ensure_distribution_and_ordering( Ok(child) } else { let sort_expr = required.unwrap().to_vec(); - if child.output_partitioning().partition_count() > 1 { - Ok(Arc::new(SortExec::new_with_partitioning( - sort_expr, child, true, None, - )) as Arc) - } else { - Ok(Arc::new(SortExec::try_new(sort_expr, child, None)?) - as Arc) - } + add_sort_above_child(&child, sort_expr) } }) .collect(); @@ -900,61 +892,6 @@ fn ensure_distribution_and_ordering( with_new_children_if_necessary(plan, new_children?) } -/// Check the required ordering requirements are satisfied by the provided PhysicalSortExprs. 
-fn ordering_satisfy EquivalenceProperties>( - provided: Option<&[PhysicalSortExpr]>, - required: Option<&[PhysicalSortExpr]>, - equal_properties: F, -) -> bool { - match (provided, required) { - (_, None) => true, - (None, Some(_)) => false, - (Some(provided), Some(required)) => { - if required.len() > provided.len() { - false - } else { - let fast_match = required - .iter() - .zip(provided.iter()) - .all(|(order1, order2)| order1.eq(order2)); - - if !fast_match { - let eq_properties = equal_properties(); - let eq_classes = eq_properties.classes(); - if !eq_classes.is_empty() { - let normalized_required_exprs = required - .iter() - .map(|e| { - normalize_sort_expr_with_equivalence_properties( - e.clone(), - eq_classes, - ) - }) - .collect::>(); - let normalized_provided_exprs = provided - .iter() - .map(|e| { - normalize_sort_expr_with_equivalence_properties( - e.clone(), - eq_classes, - ) - }) - .collect::>(); - normalized_required_exprs - .iter() - .zip(normalized_provided_exprs.iter()) - .all(|(order1, order2)| order1.eq(order2)) - } else { - fast_match - } - } else { - fast_match - } - } - } - } -} - #[derive(Debug, Clone)] struct JoinKeyPairs { left_keys: Vec>, @@ -1034,10 +971,10 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_expr::logical_plan::JoinType; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::binary; - use datafusion_physical_expr::expressions::lit; - use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr::{expressions, PhysicalExpr}; + use datafusion_physical_expr::{ + expressions, expressions::binary, expressions::lit, expressions::Column, + PhysicalExpr, PhysicalSortExpr, + }; use std::ops::Deref; use super::*; From dfb66836b9b328c296ae55b572c5770c09ba4abb Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 13 Dec 2022 17:06:40 +0300 Subject: [PATCH 03/50] update test and change repartition maintain_input_order impl --- 
datafusion/core/src/physical_plan/repartition.rs | 16 +++++++++++++++- datafusion/core/tests/sql/window.rs | 4 +++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs index 9492fb7497a62..12424c8587971 100644 --- a/datafusion/core/src/physical_plan/repartition.rs +++ b/datafusion/core/src/physical_plan/repartition.rs @@ -289,7 +289,21 @@ impl ExecutionPlan for RepartitionExec { } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None + if self.maintains_input_order() { + self.input().output_ordering() + } else { + None + } + } + + fn maintains_input_order(&self) -> bool { + let n_input = match self.input().output_partitioning() { + Partitioning::RoundRobinBatch(n) => n, + Partitioning::Hash(_, n) => n, + Partitioning::UnknownPartitioning(n) => n, + }; + // We preserve ordering when input partitioning is 1 + n_input <= 1 } fn equivalence_properties(&self) -> EquivalenceProperties { diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 3a44493aaab6d..085628051e17e 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -2119,7 +2119,9 @@ async fn test_window_agg_sort_orderby_reversed_binary_expr() -> Result<()> { #[tokio::test] async fn test_remove_unnecessary_sort_in_sub_query() -> Result<()> { - let config = SessionConfig::new(); + let config = SessionConfig::new() + .with_target_partitions(8) + .with_repartition_windows(true); let ctx = SessionContext::with_config(config); register_aggregate_csv(&ctx).await?; let sql = "SELECT count(*) as global_count FROM From 0a42315520940946cb2a80483439d4dcd46a7a10 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 13 Dec 2022 17:15:48 +0300 Subject: [PATCH 04/50] simplifications --- .../remove_unnecessary_sorts.rs | 86 ++++++++----------- 1 file changed, 37 insertions(+), 49 deletions(-) diff --git 
a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 9c2e071dd8b84..322677a30428d 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -88,39 +88,7 @@ fn remove_unnecessary_sorts( ordering_satisfy_inner(physical_ordering, required_ordering, || { child.equivalence_properties() }); - if is_ordering_satisfied { - // can do analysis for sort removal - if !sort_onward.is_empty() { - let (_, sort_any) = sort_onward[0].clone(); - let sort_exec = sort_any - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Plan( - "First layer should start from SortExec".to_string(), - ) - })?; - let sort_output_ordering = sort_exec.output_ordering(); - let sort_input_ordering = sort_exec.input().output_ordering(); - // Do naive analysis, where a SortExec is already sorted according to desired Sorting - if ordering_satisfy( - sort_input_ordering, - sort_output_ordering, - || sort_exec.input().equivalence_properties(), - ) { - update_child_to_remove_unnecessary_sort(child, sort_onward)?; - } else if let Some(window_agg_exec) = - requirements.plan.as_any().downcast_ref::() - { - // For window expressions we can remove some Sorts when expression can be calculated in reverse order also. - if let Some(res) = - analyze_window_sort_removal(window_agg_exec, sort_onward)? - { - return Ok(Some(res)); - } - } - } - } else { + if !is_ordering_satisfied { // During sort Removal we have invalidated ordering invariant fix it // This is effectively moving sort above in the physical plan update_child_to_remove_unnecessary_sort(child, sort_onward)?; @@ -128,6 +96,29 @@ fn remove_unnecessary_sorts( *child = add_sort_above_child(child, sort_expr)?; // Since we have added Sort, we add it to the sort_onwards also. 
sort_onward.push((idx, child.clone())) + } else if is_ordering_satisfied && !sort_onward.is_empty() { + // can do analysis for sort removal + let (_, sort_any) = sort_onward[0].clone(); + let sort_exec = convert_to_sort_exec(&sort_any)?; + let sort_output_ordering = sort_exec.output_ordering(); + let sort_input_ordering = sort_exec.input().output_ordering(); + // Do naive analysis, where a SortExec is already sorted according to desired Sorting + if ordering_satisfy(sort_input_ordering, sort_output_ordering, || { + sort_exec.input().equivalence_properties() + }) { + update_child_to_remove_unnecessary_sort(child, sort_onward)?; + } else if let Some(window_agg_exec) = + requirements.plan.as_any().downcast_ref::() + { + // For window expressions we can remove some Sorts when expression can be calculated in reverse order also. + if let Some(res) = analyze_window_sort_removal( + window_agg_exec, + sort_exec, + sort_onward, + )? { + return Ok(Some(res)); + } + } } } (Some(required), None) => { @@ -245,19 +236,9 @@ impl TreeNodeRewritable for PlanWithCorrespondingSort { /// Analyzes `WindowAggExec` to determine whether Sort can be removed fn analyze_window_sort_removal( window_agg_exec: &WindowAggExec, + sort_exec: &SortExec, sort_onward: &mut Vec<(usize, Arc)>, ) -> Result> { - // If empty immediately return we cannot do analysis - if sort_onward.is_empty() { - return Ok(None); - } - let (_, sort_any) = sort_onward[0].clone(); - let sort_exec = sort_any - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Plan("First layer should start from SortExec".to_string()) - })?; let required_ordering = sort_exec.output_ordering().ok_or_else(|| { DataFusionError::Plan("SortExec should have output ordering".to_string()) })?; @@ -313,17 +294,24 @@ fn update_child_to_remove_unnecessary_sort( } Ok(()) } -/// Removes the sort from the plan in the `sort_onwards` -fn remove_corresponding_sort_from_sub_plan( - sort_onwards: &mut Vec<(usize, Arc)>, -) -> Result> { - let 
(sort_child_idx, sort_any) = sort_onwards[0].clone(); + +/// Convert dyn ExecutionPlan to SortExec (Assumes it is SortExec) +fn convert_to_sort_exec(sort_any: &Arc) -> Result<&SortExec> { let sort_exec = sort_any .as_any() .downcast_ref::() .ok_or_else(|| { DataFusionError::Plan("First layer should start from SortExec".to_string()) })?; + Ok(sort_exec) +} + +/// Removes the sort from the plan in the `sort_onwards` +fn remove_corresponding_sort_from_sub_plan( + sort_onwards: &mut Vec<(usize, Arc)>, +) -> Result> { + let (sort_child_idx, sort_any) = sort_onwards[0].clone(); + let sort_exec = convert_to_sort_exec(&sort_any)?; let mut prev_layer = sort_exec.input().clone(); let mut prev_layer_child_idx = sort_child_idx; // We start from 1 hence since first one is sort and we are removing it From c2a1593cd4fb4c8f831a7f40f6e6dd0c53acb695 Mon Sep 17 00:00:00 2001 From: Mustafa akur <106137913+mustafasrepo@users.noreply.github.com> Date: Thu, 15 Dec 2022 15:10:19 +0300 Subject: [PATCH 05/50] partition by refactor (#28) * partition by refactor * minor changes * Unnecessary tuple to Range conversion is removed * move transpose under common --- datafusion/common/src/lib.rs | 12 +++ .../remove_unnecessary_sorts.rs | 25 +++--- .../physical_plan/windows/window_agg_exec.rs | 79 +++++++++++++++- datafusion/core/tests/sql/window.rs | 54 +++++++++++ .../physical-expr/src/window/aggregate.rs | 89 +++++++++---------- .../physical-expr/src/window/built_in.rs | 55 +++++------- .../physical-expr/src/window/cume_dist.rs | 25 +++--- .../physical-expr/src/window/lead_lag.rs | 17 ++-- .../src/window/partition_evaluator.rs | 55 +----------- datafusion/physical-expr/src/window/rank.rs | 28 +++--- .../physical-expr/src/window/row_number.rs | 18 ++-- .../physical-expr/src/window/window_expr.rs | 20 +---- .../src/window/window_frame_state.rs | 15 ++-- 13 files changed, 262 insertions(+), 230 deletions(-) diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 
c6812ed24f8ae..63683f5af0242 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -73,3 +73,15 @@ pub fn reverse_sort_options(options: SortOptions) -> SortOptions { nulls_first: !options.nulls_first, } } + +/// Transposes 2d vector +pub fn transpose(original: Vec>) -> Vec> { + assert!(!original.is_empty()); + let mut transposed = (0..original[0].len()).map(|_| vec![]).collect::>(); + for original_row in original { + for (item, transposed_row) in original_row.into_iter().zip(&mut transposed) { + transposed_row.push(item); + } + } + transposed +} diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 322677a30428d..ccaf93e0e60d3 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -364,16 +364,15 @@ pub fn can_skip_sort( .iter() .filter(|elem| elem.is_partition) .collect::>(); - let (can_skip_partition_bys, should_reverse_partition_bys) = - if partition_by_sections.is_empty() { - (true, false) - } else { - let first_reverse = partition_by_sections[0].reverse; - let can_skip_partition_bys = partition_by_sections - .iter() - .all(|c| c.is_aligned && c.reverse == first_reverse); - (can_skip_partition_bys, first_reverse) - }; + let can_skip_partition_bys = if partition_by_sections.is_empty() { + true + } else { + let first_reverse = partition_by_sections[0].reverse; + let can_skip_partition_bys = partition_by_sections + .iter() + .all(|c| c.is_aligned && c.reverse == first_reverse); + can_skip_partition_bys + }; let order_by_sections = col_infos .iter() .filter(|elem| !elem.is_partition) @@ -387,11 +386,7 @@ pub fn can_skip_sort( .all(|c| c.is_aligned && c.reverse == first_reverse); (can_skip_order_bys, first_reverse) }; - // TODO: We cannot skip partition by keys when sort direction is reversed, - // by propogating partition by sort direction 
to `WindowAggExec` we can skip - // these columns also. Add support for that (Use direction during partition range calculation). - let can_skip = - can_skip_order_bys && can_skip_partition_bys && !should_reverse_partition_bys; + let can_skip = can_skip_order_bys && can_skip_partition_bys; Ok((can_skip, should_reverse_order_bys)) } diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index 914e3e71dbad7..837f32ac69b53 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -28,19 +28,23 @@ use crate::physical_plan::{ ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, }; -use arrow::compute::concat_batches; +use arrow::compute::{ + concat, concat_batches, lexicographical_partition_ranges, SortColumn, +}; use arrow::{ array::ArrayRef, datatypes::{Schema, SchemaRef}, error::{ArrowError, Result as ArrowResult}, record_batch::RecordBatch, }; +use datafusion_common::{transpose, DataFusionError}; use datafusion_physical_expr::rewrite::TreeNodeRewritable; use datafusion_physical_expr::EquivalentClass; use futures::stream::Stream; use futures::{ready, StreamExt}; use log::debug; use std::any::Any; +use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -131,6 +135,25 @@ impl WindowAggExec { pub fn input_schema(&self) -> SchemaRef { self.input_schema.clone() } + + /// Get Partition Columns + pub fn partition_by_sort_keys(&self) -> Result> { + // All window exprs have same partition by hance we just use first one + let partition_by = self.window_expr()[0].partition_by(); + let mut partition_columns = vec![]; + for elem in partition_by { + if let Some(sort_keys) = &self.sort_keys { + for a in sort_keys { + if a.expr.eq(elem) { + partition_columns.push(a.clone()); + break; + } + } + } + } + 
assert_eq!(partition_by.len(), partition_columns.len()); + Ok(partition_columns) + } } impl ExecutionPlan for WindowAggExec { @@ -253,6 +276,7 @@ impl ExecutionPlan for WindowAggExec { self.window_expr.clone(), input, BaselineMetrics::new(&self.metrics, partition), + self.partition_by_sort_keys()?, )); Ok(stream) } @@ -337,6 +361,7 @@ pub struct WindowAggStream { batches: Vec, finished: bool, window_expr: Vec>, + partition_by_sort_keys: Vec, baseline_metrics: BaselineMetrics, } @@ -347,6 +372,7 @@ impl WindowAggStream { window_expr: Vec>, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, + partition_by_sort_keys: Vec, ) -> Self { Self { schema, @@ -355,6 +381,7 @@ impl WindowAggStream { finished: false, window_expr, baseline_metrics, + partition_by_sort_keys, } } @@ -369,8 +396,27 @@ impl WindowAggStream { let batch = concat_batches(&self.input.schema(), &self.batches)?; // calculate window cols - let mut columns = compute_window_aggregates(&self.window_expr, &batch) - .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + let partition_columns = self.partition_columns(&batch)?; + let partition_points = + self.evaluate_partition_points(batch.num_rows(), &partition_columns)?; + + let mut partition_results = vec![]; + for partition_point in partition_points { + let length = partition_point.end - partition_point.start; + partition_results.push( + compute_window_aggregates( + &self.window_expr, + &batch.slice(partition_point.start, length), + ) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?, + ) + } + let mut columns = transpose(partition_results) + .iter() + .map(|elems| concat(&elems.iter().map(|x| x.as_ref()).collect::>())) + .collect::>() + .into_iter() + .collect::>>()?; // combine with the original cols // note the setup of window aggregates is that they newly calculated window @@ -378,6 +424,33 @@ impl WindowAggStream { columns.extend_from_slice(batch.columns()); RecordBatch::try_new(self.schema.clone(), columns) } + + /// Get 
Partition Columns + pub fn partition_columns(&self, batch: &RecordBatch) -> Result> { + self.partition_by_sort_keys + .iter() + .map(|elem| elem.evaluate_to_sort_column(batch)) + .collect::>>() + } + + /// evaluate the partition points given the sort columns; if the sort columns are + /// empty then the result will be a single element vec of the whole column rows. + fn evaluate_partition_points( + &self, + num_rows: usize, + partition_columns: &[SortColumn], + ) -> Result>> { + if partition_columns.is_empty() { + Ok(vec![Range { + start: 0, + end: num_rows, + }]) + } else { + Ok(lexicographical_partition_ranges(partition_columns) + .map_err(DataFusionError::ArrowError)? + .collect::>()) + } + } } impl Stream for WindowAggStream { diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 085628051e17e..a5bd6a3b97c8e 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -2177,3 +2177,57 @@ async fn test_remove_unnecessary_sort_in_sub_query() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_window_agg_sort_orderby_reversed_partitionby_reversed_plan() -> Result<()> { + let config = SessionConfig::new().with_repartition_windows(false); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c3, + SUM(c9) OVER(ORDER BY c3 DESC, c9 DESC, c2 ASC) as sum1, + SUM(c9) OVER(PARTITION BY c3 ORDER BY c9 DESC ) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c3@3 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC 
NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@0 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@1 DESC,c9@2 DESC,c2@0 ASC NULLS LAST]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+-------------+------------+", + "| c3 | sum1 | sum2 |", + "+-----+-------------+------------+", + "| 125 | 3625286410 | 3625286410 |", + "| 123 | 7192027599 | 3566741189 |", + "| 123 | 9784358155 | 6159071745 |", + "| 122 | 13845993262 | 4061635107 |", + "| 120 | 16676974334 | 2830981072 |", + "+-----+-------------+------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 1f0c286efc859..1268559fe5688 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ 
b/datafusion/physical-expr/src/window/aggregate.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::iter::IntoIterator; +use std::ops::Range; use std::sync::Arc; use arrow::array::Array; @@ -90,58 +91,50 @@ impl WindowExpr for AggregateWindowExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let partition_columns = self.partition_columns(batch)?; - let partition_points = - self.evaluate_partition_points(batch.num_rows(), &partition_columns)?; let sort_options: Vec = self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results: Vec = vec![]; - for partition_range in &partition_points { - let mut accumulator = self.aggregate.create_accumulator()?; - let length = partition_range.end - partition_range.start; - let (values, order_bys) = - self.get_values_orderbys(&batch.slice(partition_range.start, length))?; - - let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); - let mut last_range: (usize, usize) = (0, 0); - - // We iterate on each row to perform a running calculation. - // First, cur_range is calculated, then it is compared with last_range. - for i in 0..length { - let cur_range = window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - length, - i, - )?; - let value = if cur_range.0 == cur_range.1 { - // We produce None if the window is empty. - ScalarValue::try_from(self.aggregate.field()?.data_type())? - } else { - // Accumulate any new rows that have entered the window: - let update_bound = cur_range.1 - last_range.1; - if update_bound > 0 { - let update: Vec = values - .iter() - .map(|v| v.slice(last_range.1, update_bound)) - .collect(); - accumulator.update_batch(&update)? - } - // Remove rows that have now left the window: - let retract_bound = cur_range.0 - last_range.0; - if retract_bound > 0 { - let retract: Vec = values - .iter() - .map(|v| v.slice(last_range.0, retract_bound)) - .collect(); - accumulator.retract_batch(&retract)? - } - accumulator.evaluate()? 
- }; - row_wise_results.push(value); - last_range = cur_range; - } + + let mut accumulator = self.aggregate.create_accumulator()?; + let length = batch.num_rows(); + let (values, order_bys) = self.get_values_orderbys(batch)?; + + let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + let mut last_range = Range { start: 0, end: 0 }; + + // We iterate on each row to perform a running calculation. + // First, cur_range is calculated, then it is compared with last_range. + for i in 0..length { + let cur_range = + window_frame_ctx.calculate_range(&order_bys, &sort_options, length, i)?; + let value = if cur_range.end == cur_range.start { + // We produce None if the window is empty. + ScalarValue::try_from(self.aggregate.field()?.data_type())? + } else { + // Accumulate any new rows that have entered the window: + let update_bound = cur_range.end - last_range.end; + if update_bound > 0 { + let update: Vec = values + .iter() + .map(|v| v.slice(last_range.end, update_bound)) + .collect(); + accumulator.update_batch(&update)? + } + // Remove rows that have now left the window: + let retract_bound = cur_range.start - last_range.start; + if retract_bound > 0 { + let retract: Vec = values + .iter() + .map(|v| v.slice(last_range.start, retract_bound)) + .collect(); + accumulator.retract_batch(&retract)? + } + accumulator.evaluate()? 
+ }; + row_wise_results.push(value); + last_range = cur_range; } + ScalarValue::iter_to_array(row_wise_results.into_iter()) } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 8e16fc7e7a627..0b1e5ee8f19cf 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -22,14 +22,13 @@ use super::BuiltInWindowFunctionExpr; use super::WindowExpr; use crate::window::window_expr::reverse_order_bys; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; -use arrow::compute::{concat, SortOptions}; +use arrow::compute::SortOptions; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::DataFusionError; use datafusion_common::Result; +use datafusion_common::ScalarValue; use datafusion_expr::WindowFrame; use std::any::Any; -use std::ops::Range; use std::sync::Arc; /// A window expr that takes the form of a built in window function @@ -92,47 +91,35 @@ impl WindowExpr for BuiltInWindowExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let evaluator = self.expr.create_evaluator()?; let num_rows = batch.num_rows(); - let partition_columns = self.partition_columns(batch)?; - let partition_points = - self.evaluate_partition_points(num_rows, &partition_columns)?; - - let results = if evaluator.uses_window_frame() { + if evaluator.uses_window_frame() { let sort_options: Vec = self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results = vec![]; - for partition_range in &partition_points { - let length = partition_range.end - partition_range.start; - let (values, order_bys) = self - .get_values_orderbys(&batch.slice(partition_range.start, length))?; - let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); - // We iterate on each row to calculate window frame range and and window function result - for idx in 0..length { - let range = window_frame_ctx.calculate_range( 
- &order_bys, - &sort_options, - num_rows, - idx, - )?; - let range = Range { - start: range.0, - end: range.1, - }; - let value = evaluator.evaluate_inside_range(&values, range)?; - row_wise_results.push(value.to_array()); - } + + let length = batch.num_rows(); + let (values, order_bys) = self.get_values_orderbys(batch)?; + let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + // We iterate on each row to calculate window frame range and and window function result + for idx in 0..length { + let range = window_frame_ctx.calculate_range( + &order_bys, + &sort_options, + num_rows, + idx, + )?; + let value = evaluator.evaluate_inside_range(&values, range)?; + row_wise_results.push(value); } - row_wise_results + ScalarValue::iter_to_array(row_wise_results.into_iter()) } else if evaluator.include_rank() { let columns = self.sort_columns(batch)?; let sort_partition_points = self.evaluate_partition_points(num_rows, &columns)?; - evaluator.evaluate_with_rank(partition_points, sort_partition_points)? + evaluator.evaluate_with_rank(num_rows, &sort_partition_points) } else { let (values, _) = self.get_values_orderbys(batch)?; - evaluator.evaluate(&values, partition_points)? 
- }; - let results = results.iter().map(|i| i.as_ref()).collect::>(); - concat(&results).map_err(DataFusionError::ArrowError) + evaluator.evaluate(&values, num_rows) + } } fn get_window_frame(&self) -> &Arc { diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs index 4202058a3c5a1..45fe51178afcd 100644 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ b/datafusion/physical-expr/src/window/cume_dist.rs @@ -73,19 +73,19 @@ impl PartitionEvaluator for CumeDistEvaluator { true } - fn evaluate_partition_with_rank( + fn evaluate_with_rank( &self, - partition: Range, + num_rows: usize, ranks_in_partition: &[Range], ) -> Result { - let scaler = (partition.end - partition.start) as f64; + let scalar = num_rows as f64; let result = Float64Array::from_iter_values( ranks_in_partition .iter() .scan(0_u64, |acc, range| { let len = range.end - range.start; *acc += len as u64; - let value: f64 = (*acc as f64) / scaler; + let value: f64 = (*acc as f64) / scalar; let result = iter::repeat(value).take(len); Some(result) }) @@ -102,15 +102,14 @@ mod tests { fn test_i32_result( expr: &CumeDist, - partition: Range, + num_rows: usize, ranks: Vec>, expected: Vec, ) -> Result<()> { let result = expr .create_evaluator()? 
- .evaluate_with_rank(vec![partition], ranks)?; - assert_eq!(1, result.len()); - let result = as_float64_array(&result[0])?; + .evaluate_with_rank(num_rows, &ranks)?; + let result = as_float64_array(&result)?; let result = result.values(); assert_eq!(expected, result); Ok(()) @@ -121,19 +120,19 @@ mod tests { let r = cume_dist("arr".into()); let expected = vec![0.0; 0]; - test_i32_result(&r, 0..0, vec![], expected)?; + test_i32_result(&r, 0, vec![], expected)?; let expected = vec![1.0; 1]; - test_i32_result(&r, 0..1, vec![0..1], expected)?; + test_i32_result(&r, 1, vec![0..1], expected)?; let expected = vec![1.0; 2]; - test_i32_result(&r, 0..2, vec![0..2], expected)?; + test_i32_result(&r, 2, vec![0..2], expected)?; let expected = vec![0.5, 0.5, 1.0, 1.0]; - test_i32_result(&r, 0..4, vec![0..2, 2..4], expected)?; + test_i32_result(&r, 4, vec![0..2, 2..4], expected)?; let expected = vec![0.25, 0.5, 0.75, 1.0]; - test_i32_result(&r, 0..4, vec![0..1, 1..2, 2..3, 3..4], expected)?; + test_i32_result(&r, 4, vec![0..1, 1..2, 2..3, 3..4], expected)?; Ok(()) } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 97076963f60d4..f4c176262ae46 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -28,7 +28,6 @@ use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use std::any::Any; use std::ops::Neg; -use std::ops::Range; use std::sync::Arc; /// window shift expression @@ -178,15 +177,10 @@ fn shift_with_default_value( } impl PartitionEvaluator for WindowShiftEvaluator { - fn evaluate_partition( - &self, - values: &[ArrayRef], - partition: Range, - ) -> Result { + fn evaluate(&self, values: &[ArrayRef], _num_rows: usize) -> Result { // LEAD, LAG window functions take single column, values will have size 1 let value = &values[0]; - let value = value.slice(partition.start, partition.end - partition.start); - 
shift_with_default_value(&value, self.shift_offset, self.default_value.as_ref()) + shift_with_default_value(value, self.shift_offset, self.default_value.as_ref()) } } @@ -205,9 +199,10 @@ mod tests { let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; let values = expr.evaluate_args(&batch)?; - let result = expr.create_evaluator()?.evaluate(&values, vec![0..8])?; - assert_eq!(1, result.len()); - let result = as_int32_array(&result[0])?; + let result = expr + .create_evaluator()? + .evaluate(&values, batch.num_rows())?; + let result = as_int32_array(&result)?; assert_eq!(expected, *result); Ok(()) } diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index 1608758d61b38..86500441df5bc 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -22,23 +22,6 @@ use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; use std::ops::Range; -/// Given a partition range, and the full list of sort partition points, given that the sort -/// partition points are sorted using [partition columns..., order columns...], the split -/// boundaries would align (what's sorted on [partition columns...] would definitely be sorted -/// on finer columns), so this will use binary search to find ranges that are within the -/// partition range and return the valid slice. -pub(crate) fn find_ranges_in_range<'a>( - partition_range: &Range, - sort_partition_points: &'a [Range], -) -> &'a [Range] { - let start_idx = sort_partition_points - .partition_point(|sort_range| sort_range.start < partition_range.start); - let end_idx = start_idx - + sort_partition_points[start_idx..] 
- .partition_point(|sort_range| sort_range.end <= partition_range.end); - &sort_partition_points[start_idx..end_idx] -} - /// Partition evaluator pub trait PartitionEvaluator { /// Whether the evaluator should be evaluated with rank @@ -50,49 +33,17 @@ pub trait PartitionEvaluator { false } - /// evaluate the partition evaluator against the partitions - fn evaluate( - &self, - values: &[ArrayRef], - partition_points: Vec>, - ) -> Result> { - partition_points - .into_iter() - .map(|partition| self.evaluate_partition(values, partition)) - .collect() - } - - /// evaluate the partition evaluator against the partitions with rank information - fn evaluate_with_rank( - &self, - partition_points: Vec>, - sort_partition_points: Vec>, - ) -> Result> { - partition_points - .into_iter() - .map(|partition| { - let ranks_in_partition = - find_ranges_in_range(&partition, &sort_partition_points); - self.evaluate_partition_with_rank(partition, ranks_in_partition) - }) - .collect() - } - /// evaluate the partition evaluator against the partition - fn evaluate_partition( - &self, - _values: &[ArrayRef], - _partition: Range, - ) -> Result { + fn evaluate(&self, _values: &[ArrayRef], _num_rows: usize) -> Result { Err(DataFusionError::NotImplemented( "evaluate_partition is not implemented by default".into(), )) } /// evaluate the partition evaluator against the partition but with rank - fn evaluate_partition_with_rank( + fn evaluate_with_rank( &self, - _partition: Range, + _num_rows: usize, _ranks_in_partition: &[Range], ) -> Result { Err(DataFusionError::NotImplemented( diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 8ed0319a10b0f..87e01528de5a8 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -114,9 +114,9 @@ impl PartitionEvaluator for RankEvaluator { true } - fn evaluate_partition_with_rank( + fn evaluate_with_rank( &self, - partition: Range, + num_rows: usize, 
ranks_in_partition: &[Range], ) -> Result { // see https://www.postgresql.org/docs/current/functions-window.html @@ -132,7 +132,7 @@ impl PartitionEvaluator for RankEvaluator { )), RankType::Percent => { // Returns the relative rank of the current row, that is (rank - 1) / (total partition rows - 1). The value thus ranges from 0 to 1 inclusive. - let denominator = (partition.end - partition.start) as f64; + let denominator = num_rows as f64; Arc::new(Float64Array::from_iter_values( ranks_in_partition .iter() @@ -177,15 +177,14 @@ mod tests { fn test_f64_result( expr: &Rank, - range: Range, + num_rows: usize, ranks: Vec>, expected: Vec, ) -> Result<()> { let result = expr .create_evaluator()? - .evaluate_with_rank(vec![range], ranks)?; - assert_eq!(1, result.len()); - let result = as_float64_array(&result[0])?; + .evaluate_with_rank(num_rows, &ranks)?; + let result = as_float64_array(&result)?; let result = result.values(); assert_eq!(expected, result); Ok(()) @@ -196,11 +195,8 @@ mod tests { ranks: Vec>, expected: Vec, ) -> Result<()> { - let result = expr - .create_evaluator()? 
- .evaluate_with_rank(vec![0..8], ranks)?; - assert_eq!(1, result.len()); - let result = as_uint64_array(&result[0])?; + let result = expr.create_evaluator()?.evaluate_with_rank(8, &ranks)?; + let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(expected, result); Ok(()) @@ -228,19 +224,19 @@ mod tests { // empty case let expected = vec![0.0; 0]; - test_f64_result(&r, 0..0, vec![0..0; 0], expected)?; + test_f64_result(&r, 0, vec![0..0; 0], expected)?; // singleton case let expected = vec![0.0]; - test_f64_result(&r, 0..1, vec![0..1], expected)?; + test_f64_result(&r, 1, vec![0..1], expected)?; // uniform case let expected = vec![0.0; 7]; - test_f64_result(&r, 0..7, vec![0..7], expected)?; + test_f64_result(&r, 7, vec![0..7], expected)?; // non-trivial case let expected = vec![0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5]; - test_f64_result(&r, 0..7, vec![0..3, 3..7], expected)?; + test_f64_result(&r, 7, vec![0..3, 3..7], expected)?; Ok(()) } diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index f70d9ea379dd7..b27ac29d27640 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -24,7 +24,6 @@ use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; use std::any::Any; -use std::ops::Range; use std::sync::Arc; /// row_number expression @@ -69,12 +68,7 @@ impl BuiltInWindowFunctionExpr for RowNumber { pub(crate) struct NumRowsEvaluator {} impl PartitionEvaluator for NumRowsEvaluator { - fn evaluate_partition( - &self, - _values: &[ArrayRef], - partition: Range, - ) -> Result { - let num_rows = partition.end - partition.start; + fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result { Ok(Arc::new(UInt64Array::from_iter_values( 1..(num_rows as u64) + 1, ))) @@ -99,9 +93,8 @@ mod tests { let values = row_number.evaluate_args(&batch)?; let result = 
row_number .create_evaluator()? - .evaluate(&values, vec![0..8])?; - assert_eq!(1, result.len()); - let result = as_uint64_array(&result[0])?; + .evaluate(&values, batch.num_rows())?; + let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result); Ok(()) @@ -118,9 +111,8 @@ mod tests { let values = row_number.evaluate_args(&batch)?; let result = row_number .create_evaluator()? - .evaluate(&values, vec![0..8])?; - assert_eq!(1, result.len()); - let result = as_uint64_array(&result[0])?; + .evaluate(&values, batch.num_rows())?; + let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result); Ok(()) diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index b1af6cf49a0ab..bc35dd49b50d4 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -17,7 +17,7 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::compute::kernels::partition::lexicographical_partition_ranges; -use arrow::compute::kernels::sort::{SortColumn, SortOptions}; +use arrow::compute::kernels::sort::SortColumn; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{reverse_sort_options, DataFusionError, Result}; @@ -86,20 +86,6 @@ pub trait WindowExpr: Send + Sync + Debug { /// expressions that's from the window function's order by clause, empty if absent fn order_by(&self) -> &[PhysicalSortExpr]; - /// get partition columns that can be used for partitioning, empty if absent - fn partition_columns(&self, batch: &RecordBatch) -> Result> { - self.partition_by() - .iter() - .map(|expr| { - PhysicalSortExpr { - expr: expr.clone(), - options: SortOptions::default(), - } - .evaluate_to_sort_column(batch) - }) - .collect() - } - /// get order by columns, empty if absent fn order_by_columns(&self, batch: 
&RecordBatch) -> Result> { self.order_by() @@ -110,10 +96,8 @@ pub trait WindowExpr: Send + Sync + Debug { /// get sort columns that can be used for peer evaluation, empty if absent fn sort_columns(&self, batch: &RecordBatch) -> Result> { - let mut sort_columns = self.partition_columns(batch)?; let order_by_columns = self.order_by_columns(batch)?; - sort_columns.extend(order_by_columns); - Ok(sort_columns) + Ok(order_by_columns) } /// Get values columns(argument of Window Function) diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 307ea91440df9..b49bd3a22a78f 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -26,6 +26,7 @@ use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::cmp::min; use std::collections::VecDeque; use std::fmt::Debug; +use std::ops::Range; use std::sync::Arc; /// This object stores the window frame state for use in incremental calculations. 
@@ -68,7 +69,7 @@ impl<'a> WindowFrameContext<'a> { sort_options: &[SortOptions], length: usize, idx: usize, - ) -> Result<(usize, usize)> { + ) -> Result> { match *self { WindowFrameContext::Rows(window_frame) => { Self::calculate_range_rows(window_frame, length, idx) @@ -99,7 +100,7 @@ impl<'a> WindowFrameContext<'a> { window_frame: &Arc, length: usize, idx: usize, - ) -> Result<(usize, usize)> { + ) -> Result> { let start = match window_frame.start_bound { // UNBOUNDED PRECEDING WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0, @@ -152,7 +153,7 @@ impl<'a> WindowFrameContext<'a> { return Err(DataFusionError::Internal("Rows should be Uint".to_string())) } }; - Ok((start, end)) + Ok(Range { start, end }) } } @@ -171,7 +172,7 @@ impl WindowFrameStateRange { sort_options: &[SortOptions], length: usize, idx: usize, - ) -> Result<(usize, usize)> { + ) -> Result> { let start = match window_frame.start_bound { WindowFrameBound::Preceding(ref n) => { if n.is_null() { @@ -240,7 +241,7 @@ impl WindowFrameStateRange { } } }; - Ok((start, end)) + Ok(Range { start, end }) } /// This function does the heavy lifting when finding range boundaries. It is meant to be @@ -333,7 +334,7 @@ impl WindowFrameStateGroups { range_columns: &[ArrayRef], length: usize, idx: usize, - ) -> Result<(usize, usize)> { + ) -> Result> { if range_columns.is_empty() { return Err(DataFusionError::Execution( "GROUPS mode requires an ORDER BY clause".to_string(), @@ -399,7 +400,7 @@ impl WindowFrameStateGroups { )) } }; - Ok((start, end)) + Ok(Range { start, end }) } /// This function does the heavy lifting when finding group boundaries. 
It is meant to be From bf7bd11c9c822c77beb1010576c34d66b8ee6035 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 15 Dec 2022 16:06:03 +0300 Subject: [PATCH 06/50] Add naive sort removal rule --- .../remove_unnecessary_sorts.rs | 40 +++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index ccaf93e0e60d3..4c38f1f1e3501 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -73,6 +73,11 @@ impl PhysicalOptimizerRule for RemoveUnnecessarySorts { fn remove_unnecessary_sorts( requirements: PlanWithCorrespondingSort, ) -> Result> { + // Do analysis of naive SortRemoval at the beginning + // Remove Sorts that are already satisfied + if let Some(res) = analyze_immediate_sort_removal(&requirements)? { + return Ok(Some(res)); + } let mut new_children = requirements.plan.children().clone(); let mut new_sort_onwards = requirements.sort_onwards.clone(); for (idx, (child, sort_onward)) in new_children @@ -90,7 +95,6 @@ fn remove_unnecessary_sorts( }); if !is_ordering_satisfied { // During sort Removal we have invalidated ordering invariant fix it - // This is effectively moving sort above in the physical plan update_child_to_remove_unnecessary_sort(child, sort_onward)?; let sort_expr = required_ordering.to_vec(); *child = add_sort_above_child(child, sort_expr)?; @@ -144,10 +148,10 @@ fn remove_unnecessary_sorts( .enumerate() .take(new_plan.children().len()) { - let is_require_ordering = new_plan.required_input_ordering()[idx].is_none(); + let requires_ordering = new_plan.required_input_ordering()[idx].is_some(); //TODO: when `maintains_input_order` returns `Vec` use corresponding index if new_plan.maintains_input_order() - && is_require_ordering + && !requires_ordering && !new_sort_onward.is_empty() { 
new_sort_onward.push((idx, new_plan.clone())); @@ -155,6 +159,8 @@ fn remove_unnecessary_sorts( new_sort_onward.clear(); new_sort_onward.push((idx, new_plan.clone())); } else { + // These executors use SortExec, hence doesn't propagate + // sort above in the physical plan new_sort_onward.clear(); } } @@ -233,6 +239,33 @@ impl TreeNodeRewritable for PlanWithCorrespondingSort { } } +/// Analyzes `SortExec` to determine whether this Sort can be removed +fn analyze_immediate_sort_removal( + requirements: &PlanWithCorrespondingSort, +) -> Result> { + if let Some(sort_exec) = requirements.plan.as_any().downcast_ref::() { + if ordering_satisfy( + sort_exec.input().output_ordering(), + sort_exec.output_ordering(), + || sort_exec.input().equivalence_properties(), + ) { + // This sort is unnecessary we should remove it + let new_plan = sort_exec.input(); + // Since we know that Sort have exactly one child we can use first index safely + assert_eq!(requirements.sort_onwards.len(), 1); + let mut new_sort_onward = requirements.sort_onwards[0].to_vec(); + if !new_sort_onward.is_empty() { + new_sort_onward.pop(); + } + return Ok(Some(PlanWithCorrespondingSort { + plan: new_plan.clone(), + sort_onwards: vec![new_sort_onward], + })); + } + } + Ok(None) +} + /// Analyzes `WindowAggExec` to determine whether Sort can be removed fn analyze_window_sort_removal( window_agg_exec: &WindowAggExec, @@ -284,6 +317,7 @@ fn analyze_window_sort_removal( Ok(None) } } + /// Updates child such that unnecessary sorting below it is removed fn update_child_to_remove_unnecessary_sort( child: &mut Arc, From 4cb7258fb49c6b911ae652b436fcdf1e8797ea38 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 15 Dec 2022 17:07:56 +0300 Subject: [PATCH 07/50] Add todo for finer Sort removal handling --- .../core/src/physical_optimizer/remove_unnecessary_sorts.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs 
b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 4c38f1f1e3501..e14a611db385a 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -106,6 +106,10 @@ fn remove_unnecessary_sorts( let sort_exec = convert_to_sort_exec(&sort_any)?; let sort_output_ordering = sort_exec.output_ordering(); let sort_input_ordering = sort_exec.input().output_ordering(); + // TODO: Once we can ensure required ordering propagates to above without changes + // (or with changes trackable) compare `sort_input_ordering` and and `required_ordering` + // this changes will enable us to remove (a,b) -> Sort -> (a,b,c) -> Required(a,b) Sort + // from the plan. With current implementation we cannot remove Sort from above configuration. // Do naive analysis, where a SortExec is already sorted according to desired Sorting if ordering_satisfy(sort_input_ordering, sort_output_ordering, || { sort_exec.input().equivalence_properties() From aa4f7393d1840bb0586c9a08f91b343d3f03ed14 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Mon, 19 Dec 2022 15:14:32 +0300 Subject: [PATCH 08/50] Refactors to improve readability and reduce nesting --- datafusion/common/src/lib.rs | 20 +++--- datafusion/core/src/execution/context.rs | 7 +- .../remove_unnecessary_sorts.rs | 11 +-- .../core/src/physical_optimizer/utils.rs | 67 ++++++++----------- .../core/src/physical_plan/repartition.rs | 2 +- .../physical_plan/windows/window_agg_exec.rs | 62 +++++++++-------- 6 files changed, 80 insertions(+), 89 deletions(-) diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 63683f5af0242..9911b1499a7b8 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -65,7 +65,7 @@ macro_rules! downcast_value { }}; } -/// Compute the "reverse" of given `SortOptions`. +/// Computes the "reverse" of given `SortOptions`. 
// TODO: If/when arrow supports `!` for `SortOptions`, we can remove this. pub fn reverse_sort_options(options: SortOptions) -> SortOptions { SortOptions { @@ -74,14 +74,18 @@ pub fn reverse_sort_options(options: SortOptions) -> SortOptions { } } -/// Transposes 2d vector +/// Transposes the given vector of vectors. pub fn transpose(original: Vec>) -> Vec> { - assert!(!original.is_empty()); - let mut transposed = (0..original[0].len()).map(|_| vec![]).collect::>(); - for original_row in original { - for (item, transposed_row) in original_row.into_iter().zip(&mut transposed) { - transposed_row.push(item); + match original.as_slice() { + [] => vec![], + [first, ..] => { + let mut result = (0..first.len()).map(|_| vec![]).collect::>(); + for row in original { + for (item, transposed_row) in row.into_iter().zip(&mut result) { + transposed_row.push(item); + } + } + result } } - transposed } diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 00de69eab7540..c503caca2edc9 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -1592,9 +1592,10 @@ impl SessionState { // To make sure the SinglePartition is satisfied, run the BasicEnforcement again, originally it was the AddCoalescePartitionsExec here. physical_optimizers.push(Arc::new(BasicEnforcement::new())); - // `BasicEnforcement` stage conservatively inserts `SortExec`s before `WindowAggExec`s without - // a deep analysis of window frames. Such analysis may sometimes reveal that a `SortExec` is - // actually unnecessary. The rule below performs this analysis and removes such `SortExec`s. + // `BasicEnforcement` stage conservatively inserts `SortExec`s to satisfy ordering requirements. + // However, a deeper analysis may sometimes reveal that such a `SortExec` is actually unnecessary. + // These cases typically arise when we have reversible `WindowAggExec`s or deep subqueries. 
The + // rule below performs this analysis and removes unnecessary `SortExec`s. physical_optimizers.push(Arc::new(RemoveUnnecessarySorts::new())); SessionState { diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index e14a611db385a..2f4e326c49f02 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -20,7 +20,7 @@ //! valid, or invalid physical plans (in terms of Sorting requirement) use crate::error::Result; use crate::physical_optimizer::utils::{ - add_sort_above_child, ordering_satisfy, ordering_satisfy_inner, + add_sort_above_child, ordering_satisfy, ordering_satisfy_concrete, }; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::rewrite::TreeNodeRewritable; @@ -89,10 +89,11 @@ fn remove_unnecessary_sorts( let physical_ordering = child.output_ordering(); match (required_ordering, physical_ordering) { (Some(required_ordering), Some(physical_ordering)) => { - let is_ordering_satisfied = - ordering_satisfy_inner(physical_ordering, required_ordering, || { - child.equivalence_properties() - }); + let is_ordering_satisfied = ordering_satisfy_concrete( + physical_ordering, + required_ordering, + || child.equivalence_properties(), + ); if !is_ordering_satisfied { // During sort Removal we have invalidated ordering invariant fix it update_child_to_remove_unnecessary_sort(child, sort_onward)?; diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 7389da085d600..94394ca527ee6 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -51,7 +51,7 @@ pub fn optimize_children( } } -/// Check the required ordering requirements are satisfied by the provided PhysicalSortExprs. 
+/// Checks whether given ordering requirements are satisfied by provided [PhysicalSortExpr]s. pub fn ordering_satisfy EquivalenceProperties>( provided: Option<&[PhysicalSortExpr]>, required: Option<&[PhysicalSortExpr]>, @@ -61,56 +61,43 @@ pub fn ordering_satisfy EquivalenceProperties>( (_, None) => true, (None, Some(_)) => false, (Some(provided), Some(required)) => { - ordering_satisfy_inner(provided, required, equal_properties) + ordering_satisfy_concrete(provided, required, equal_properties) } } } -pub fn ordering_satisfy_inner EquivalenceProperties>( +pub fn ordering_satisfy_concrete EquivalenceProperties>( provided: &[PhysicalSortExpr], required: &[PhysicalSortExpr], equal_properties: F, ) -> bool { if required.len() > provided.len() { false - } else { - let fast_match = required + } else if required + .iter() + .zip(provided.iter()) + .all(|(order1, order2)| order1.eq(order2)) + { + true + } else if let eq_classes @ [_, ..] = equal_properties().classes() { + let normalized_required_exprs = required .iter() - .zip(provided.iter()) - .all(|(order1, order2)| order1.eq(order2)); - - if !fast_match { - let eq_properties = equal_properties(); - let eq_classes = eq_properties.classes(); - if !eq_classes.is_empty() { - let normalized_required_exprs = required - .iter() - .map(|e| { - normalize_sort_expr_with_equivalence_properties( - e.clone(), - eq_classes, - ) - }) - .collect::>(); - let normalized_provided_exprs = provided - .iter() - .map(|e| { - normalize_sort_expr_with_equivalence_properties( - e.clone(), - eq_classes, - ) - }) - .collect::>(); - normalized_required_exprs - .iter() - .zip(normalized_provided_exprs.iter()) - .all(|(order1, order2)| order1.eq(order2)) - } else { - fast_match - } - } else { - fast_match - } + .map(|e| { + normalize_sort_expr_with_equivalence_properties(e.clone(), eq_classes) + }) + .collect::>(); + let normalized_provided_exprs = provided + .iter() + .map(|e| { + normalize_sort_expr_with_equivalence_properties(e.clone(), 
eq_classes) + }) + .collect::>(); + normalized_required_exprs + .iter() + .zip(normalized_provided_exprs.iter()) + .all(|(order1, order2)| order1.eq(order2)) + } else { + false } } diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs index 12424c8587971..f7005d113e306 100644 --- a/datafusion/core/src/physical_plan/repartition.rs +++ b/datafusion/core/src/physical_plan/repartition.rs @@ -297,12 +297,12 @@ impl ExecutionPlan for RepartitionExec { } fn maintains_input_order(&self) -> bool { + // We preserve ordering when input partitioning is 1 let n_input = match self.input().output_partitioning() { Partitioning::RoundRobinBatch(n) => n, Partitioning::Hash(_, n) => n, Partitioning::UnknownPartitioning(n) => n, }; - // We preserve ordering when input partitioning is 1 n_input <= 1 } diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index 837f32ac69b53..c709fe4942571 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -136,23 +136,25 @@ impl WindowAggExec { self.input_schema.clone() } - /// Get Partition Columns + /// Get partition keys pub fn partition_by_sort_keys(&self) -> Result> { - // All window exprs have same partition by hance we just use first one + let mut result = vec![]; + // All window exprs have the same partition by, so we just use the first one: let partition_by = self.window_expr()[0].partition_by(); - let mut partition_columns = vec![]; - for elem in partition_by { - if let Some(sort_keys) = &self.sort_keys { - for a in sort_keys { - if a.expr.eq(elem) { - partition_columns.push(a.clone()); - break; - } - } + let sort_keys = self + .sort_keys + .as_ref() + .map_or_else(|| &[] as &[PhysicalSortExpr], |v| v.as_slice()); + for item in partition_by { + if let Some(a) = sort_keys.iter().find(|&e| e.expr.eq(item)) { + 
result.push(a.clone()); + } else { + return Err(DataFusionError::Execution( + "Partition key not found in sort keys".to_string(), + )); } } - assert_eq!(partition_by.len(), partition_columns.len()); - Ok(partition_columns) + Ok(result) } } @@ -395,12 +397,16 @@ impl WindowAggStream { let batch = concat_batches(&self.input.schema(), &self.batches)?; - // calculate window cols - let partition_columns = self.partition_columns(&batch)?; + let partition_by_sort_keys = self + .partition_by_sort_keys + .iter() + .map(|elem| elem.evaluate_to_sort_column(&batch)) + .collect::>>()?; let partition_points = - self.evaluate_partition_points(batch.num_rows(), &partition_columns)?; + self.evaluate_partition_points(batch.num_rows(), &partition_by_sort_keys)?; let mut partition_results = vec![]; + // Calculate window cols for partition_point in partition_points { let length = partition_point.end - partition_point.start; partition_results.push( @@ -425,31 +431,23 @@ impl WindowAggStream { RecordBatch::try_new(self.schema.clone(), columns) } - /// Get Partition Columns - pub fn partition_columns(&self, batch: &RecordBatch) -> Result> { - self.partition_by_sort_keys - .iter() - .map(|elem| elem.evaluate_to_sort_column(batch)) - .collect::>>() - } - - /// evaluate the partition points given the sort columns; if the sort columns are - /// empty then the result will be a single element vec of the whole column rows. + /// Evaluates the partition points given the sort columns. If the sort columns are + /// empty, then the result will be a single element vector spanning the entire batch. fn evaluate_partition_points( &self, num_rows: usize, partition_columns: &[SortColumn], ) -> Result>> { - if partition_columns.is_empty() { - Ok(vec![Range { + Ok(if partition_columns.is_empty() { + vec![Range { start: 0, end: num_rows, - }]) + }] } else { - Ok(lexicographical_partition_ranges(partition_columns) + lexicographical_partition_ranges(partition_columns) .map_err(DataFusionError::ArrowError)? 
- .collect::>()) - } + .collect::>() + }) } } From 6309b013efa2974d7581eccd346f694c419d5e12 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 19 Dec 2022 15:57:15 +0300 Subject: [PATCH 09/50] reverse expr returns Option (no need for support check) --- .../remove_unnecessary_sorts.rs | 37 +++++++++---------- .../physical-expr/src/aggregate/count.rs | 8 +--- datafusion/physical-expr/src/aggregate/mod.rs | 15 ++------ datafusion/physical-expr/src/aggregate/sum.rs | 8 +--- .../physical-expr/src/window/aggregate.rs | 22 +++++------ .../physical-expr/src/window/built_in.rs | 22 +++++------ .../window/built_in_window_function_expr.rs | 15 ++------ .../physical-expr/src/window/lead_lag.rs | 8 +--- .../physical-expr/src/window/nth_value.rs | 17 ++------- .../physical-expr/src/window/window_expr.rs | 5 +-- 10 files changed, 55 insertions(+), 102 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 2f4e326c49f02..461873b9da70b 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -295,32 +295,29 @@ fn analyze_window_sort_removal( &sort_exec.input().schema(), physical_ordering, )?; - let all_window_fns_reversible = - window_expr.iter().all(|e| e.is_window_fn_reversible()); - let is_reversal_blocking = should_reverse && !all_window_fns_reversible; - - if can_skip_sorting && !is_reversal_blocking { - let window_expr = if should_reverse { + if can_skip_sorting { + let new_window_expr = if should_reverse { window_expr .iter() .map(|e| e.get_reversed_expr()) - .collect::>>()? 
+ .collect::>>() } else { - window_expr.to_vec() + Some(window_expr.to_vec()) }; - let new_child = remove_corresponding_sort_from_sub_plan(sort_onward)?; - let new_schema = new_child.schema(); - let new_plan = Arc::new(WindowAggExec::try_new( - window_expr, - new_child, - new_schema, - window_agg_exec.partition_keys.clone(), - Some(physical_ordering.to_vec()), - )?); - Ok(Some(PlanWithCorrespondingSort::new(new_plan))) - } else { - Ok(None) + if let Some(window_expr) = new_window_expr { + let new_child = remove_corresponding_sort_from_sub_plan(sort_onward)?; + let new_schema = new_child.schema(); + let new_plan = Arc::new(WindowAggExec::try_new( + window_expr, + new_child, + new_schema, + window_agg_exec.partition_keys.clone(), + Some(physical_ordering.to_vec()), + )?); + return Ok(Some(PlanWithCorrespondingSort::new(new_plan))); + } } + Ok(None) } /// Updates child such that unnecessary sorting below it is removed diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index e5483495eb731..ea2fe4d15fb76 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -105,12 +105,8 @@ impl AggregateExpr for Count { Ok(Box::new(CountRowAccumulator::new(start_index))) } - fn is_window_fn_reversible(&self) -> bool { - true - } - - fn reverse_expr(&self) -> Result> { - Ok(Arc::new(self.clone())) + fn reverse_expr(&self) -> Option> { + Some(Arc::new(self.clone())) } } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index b0776886e69d5..325513293f5fd 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -102,18 +102,9 @@ pub trait AggregateExpr: Send + Sync + Debug { ))) } - /// Get whether window function is reversible - /// make `true` if `reverse_expr` is implemented - fn is_window_fn_reversible(&self) -> bool { - false - } - /// Construct 
Reverse Expression - // Typically expression itself for aggregate functions - fn reverse_expr(&self) -> Result> { - Err(DataFusionError::NotImplemented(format!( - "reverse_expr hasn't been implemented for {:?} yet", - self - ))) + // Typically expression itself for aggregate functions By default returns None + fn reverse_expr(&self) -> Option> { + None } } diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index c3bc4dbfaecdf..ebee61010018b 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -133,12 +133,8 @@ impl AggregateExpr for Sum { ))) } - fn is_window_fn_reversible(&self) -> bool { - true - } - - fn reverse_expr(&self) -> Result> { - Ok(Arc::new(self.clone())) + fn reverse_expr(&self) -> Option> { + Some(Arc::new(self.clone())) } } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 1268559fe5688..4d4509df028eb 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -150,16 +150,16 @@ impl WindowExpr for AggregateWindowExpr { &self.window_frame } - fn is_window_fn_reversible(&self) -> bool { - self.aggregate.as_ref().is_window_fn_reversible() - } - - fn get_reversed_expr(&self) -> Result> { - Ok(Arc::new(AggregateWindowExpr::new( - self.aggregate.reverse_expr()?, - &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), - Arc::new(self.window_frame.reverse()), - ))) + fn get_reversed_expr(&self) -> Option> { + if let Some(reverse_expr) = self.aggregate.reverse_expr() { + Some(Arc::new(AggregateWindowExpr::new( + reverse_expr, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + ))) + } else { + None + } } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 
0b1e5ee8f19cf..8d6dc6008abbf 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -126,16 +126,16 @@ impl WindowExpr for BuiltInWindowExpr { &self.window_frame } - fn is_window_fn_reversible(&self) -> bool { - self.expr.as_ref().is_window_fn_reversible() - } - - fn get_reversed_expr(&self) -> Result> { - Ok(Arc::new(BuiltInWindowExpr::new( - self.expr.reverse_expr()?, - &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), - Arc::new(self.window_frame.reverse()), - ))) + fn get_reversed_expr(&self) -> Option> { + if let Some(reverse_expr) = self.expr.reverse_expr() { + Some(Arc::new(BuiltInWindowExpr::new( + reverse_expr, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + ))) + } else { + None + } } } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 71b100b54e5e4..e7553e566fa2b 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -20,7 +20,7 @@ use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::Result; use std::any::Any; use std::sync::Arc; @@ -59,17 +59,8 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Create built-in window evaluator with a batch fn create_evaluator(&self) -> Result>; - /// Get whether window function is reversible - /// make true if `reverse_expr` is implemented - fn is_window_fn_reversible(&self) -> bool { - false - } - /// Construct Reverse Expression - fn reverse_expr(&self) -> Result> { - Err(DataFusionError::NotImplemented(format!( - "reverse_expr hasn't been implemented for {:?} yet", - self - ))) + fn reverse_expr(&self) 
-> Option> { + None } } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index f4c176262ae46..e18815c4c3a62 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -107,12 +107,8 @@ impl BuiltInWindowFunctionExpr for WindowShift { })) } - fn is_window_fn_reversible(&self) -> bool { - true - } - - fn reverse_expr(&self) -> Result> { - Ok(Arc::new(Self { + fn reverse_expr(&self) -> Option> { + Some(Arc::new(Self { name: self.name.clone(), data_type: self.data_type.clone(), shift_offset: -self.shift_offset, diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 2942a797b8526..e998b47018a52 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -124,24 +124,13 @@ impl BuiltInWindowFunctionExpr for NthValue { Ok(Box::new(NthValueEvaluator { kind: self.kind })) } - fn is_window_fn_reversible(&self) -> bool { - match self.kind { - NthValueKind::First | NthValueKind::Last => true, - NthValueKind::Nth(_) => false, - } - } - - fn reverse_expr(&self) -> Result> { + fn reverse_expr(&self) -> Option> { let reversed_kind = match self.kind { NthValueKind::First => NthValueKind::Last, NthValueKind::Last => NthValueKind::First, - NthValueKind::Nth(_) => { - return Err(DataFusionError::Execution( - "Cannot take reverse of NthValue".to_string(), - )) - } + NthValueKind::Nth(_) => return None, }; - Ok(Arc::new(Self { + Some(Arc::new(Self { name: self.name.clone(), expr: self.expr.clone(), data_type: self.data_type.clone(), diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index bc35dd49b50d4..78c3935c567d0 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -116,11 +116,8 @@ pub trait WindowExpr: 
Send + Sync + Debug { // Get window frame of this WindowExpr fn get_window_frame(&self) -> &Arc; - /// Get whether window function can be reversed - fn is_window_fn_reversible(&self) -> bool; - /// get reversed expression - fn get_reversed_expr(&self) -> Result>; + fn get_reversed_expr(&self) -> Option>; } /// Reverses the ORDER BY expression, which is useful during equivalent window From 91629b8e5ea6830fc4bab1d0b8b59d5025199fea Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 20 Dec 2022 11:44:37 +0300 Subject: [PATCH 10/50] fix tests --- datafusion/core/src/execution/context.rs | 2 +- .../src/physical_optimizer/enforcement.rs | 2 +- datafusion/core/tests/sql/window.rs | 55 +++++++++---------- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index c03a9ad996cf3..8c7a8cbe82c49 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -99,8 +99,8 @@ use url::Url; use crate::catalog::listing_schema::ListingSchemaProvider; use crate::datasource::object_store::ObjectStoreUrl; -use crate::physical_optimizer::remove_unnecessary_sorts::RemoveUnnecessarySorts; use crate::execution::memory_pool::MemoryPool; +use crate::physical_optimizer::remove_unnecessary_sorts::RemoveUnnecessarySorts; use uuid::Uuid; use super::options::{ diff --git a/datafusion/core/src/physical_optimizer/enforcement.rs b/datafusion/core/src/physical_optimizer/enforcement.rs index 4a76573004b0f..3e74860b5be15 100644 --- a/datafusion/core/src/physical_optimizer/enforcement.rs +++ b/datafusion/core/src/physical_optimizer/enforcement.rs @@ -30,7 +30,7 @@ use crate::physical_plan::joins::{ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::rewrite::TreeNodeRewritable; -use crate::physical_plan::sorts::sort::SortOptions; +use crate::physical_plan::sorts::sort::{SortExec, 
SortOptions}; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::Partitioning; diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 3c922c0f2f657..d4c117e339a86 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1948,13 +1948,13 @@ async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { // We cannot reverse each window function (ROW_NUMBER is not reversible) let expected = { vec![ - "ProjectionExec: expr=[c9@5 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as rn2]", + "ProjectionExec: expr=[c9@5 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2]", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c9@4 DESC]", - " WindowAggExec: 
wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@4 ASC NULLS LAST,c1@2 ASC NULLS LAST,c2@3 ASC NULLS LAST]", + " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c9@2 ASC NULLS LAST,c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", + " SortExec: [c9@2 DESC,c1@0 DESC]", ] }; @@ -1969,15 +1969,15 @@ async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------------+-------------+-------------+-----+", - "| c9 | sum1 | sum2 | rn2 |", - "+------------+-------------+-------------+-----+", - "| 4268716378 | 8498370520 | 24997484146 | 1 |", - "| 4229654142 | 12714811027 | 29012926487 | 2 |", - "| 4216440507 | 16858984380 | 28743001064 | 3 |", - "| 4144173353 | 20935849039 | 28472563256 | 4 |", - "| 4076864659 | 24997484146 | 28118515915 | 5 |", - "+------------+-------------+-------------+-----+", + 
"+-----------+------------+-----------+-----+", + "| c9 | sum1 | sum2 | rn2 |", + "+-----------+------------+-----------+-----+", + "| 28774375 | 745354217 | 91818943 | 100 |", + "| 63044568 | 988558066 | 232866360 | 99 |", + "| 141047417 | 1285934966 | 374546521 | 98 |", + "| 141680161 | 1654839259 | 519841132 | 97 |", + "| 145294611 | 1980231675 | 745354217 | 96 |", + "+-----------+------------+-----------+-----+", ]; assert_batches_eq!(expected, &actual); @@ -2038,19 +2038,19 @@ async fn test_window_agg_complex_plan() -> Result<()> { // Unnecessary SortExecs are removed let expected = { vec![ - "ProjectionExec: expr=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as a, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as b, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as c, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as d, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@12 as e, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as f, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as g, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as i, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as j, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as k, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@13 
as l, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as m, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@17 as n, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@18 as o, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as p, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as a1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as b1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as c1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as d1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@14 as e1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as f1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as g1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as j1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as k1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as l1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 
UNBOUNDED PRECEDING AND CURRENT ROW@13 as m1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as n1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as o1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as h11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as j11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as k11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as l11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@15 as m11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as n11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as o11]", + "ProjectionExec: expr=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as a, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as b, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@15 as c, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@7 as d, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@11 as e, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@15 as f, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@7 as g, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS 
LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as i, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as j, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as k, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as l, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as m, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as n, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as o, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as p, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as a1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as b1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@17 as c1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@9 as d1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@13 as e1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@17 as f1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@9 as g1, SUM(null_cases.c1) 
ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as j1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as k1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as l1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as m1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as n1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as o1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as h11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as j11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@18 as k11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@10 as l11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@14 as m11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@18 as n11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@10 as o11]", " GlobalLimitExec: skip=0, fetch=5", " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): 
Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", - " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", - " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), 
end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", - " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", - " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { 
units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", - " SortExec: [c3@5 ASC NULLS LAST,c2@4 ASC NULLS LAST]", - " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", - " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", - " SortExec: [c3@3 DESC,c1@1 ASC NULLS LAST]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@17 ASC NULLS LAST,c2@16 ASC NULLS LAST]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@16 ASC NULLS LAST,c1@14 ASC]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: 
Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): 
Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", - " SortExec: [c3@2 ASC NULLS LAST,c1@0 ASC]", + " SortExec: [c3@2 DESC,c1@0 ASC NULLS LAST]", ] }; @@ -2089,9 +2089,8 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> vec![ "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: 
\"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c1@0 ASC,c9@1 DESC]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", ] }; @@ -2146,7 +2145,7 @@ async fn test_window_agg_sort_partitionby_reversed_plan() -> Result<()> { " GlobalLimitExec: skip=0, fetch=5", " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c1@0 ASC,c9@1 DESC]", + " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", ] }; From ae451a4db27b5a9906e3366a2a5349d11dc13591 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 20 Dec 2022 15:22:05 +0300 Subject: [PATCH 11/50] partition by and order by no longer ends up at the same window group --- datafusion/core/src/physical_plan/planner.rs | 2 +- 
datafusion/core/tests/sql/window.rs | 22 ++++--- datafusion/expr/src/logical_plan/builder.rs | 2 +- datafusion/expr/src/utils.rs | 60 +++++++++++++------- 4 files changed, 56 insertions(+), 30 deletions(-) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 548a1dd36bd3d..ef7fa830a567a 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -582,7 +582,7 @@ impl DefaultPhysicalPlanner { let physical_input_schema = input_exec.schema(); let sort_keys = sort_keys .iter() - .map(|e| match e { + .map(|(e, _)| match e { Expr::Sort(expr::Sort { expr, asc, diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index d4c117e339a86..7d39a3c98db92 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1737,15 +1737,18 @@ async fn test_window_partition_by_order_by() -> Result<()> { let logical_plan = state.optimize(&plan)?; let physical_plan = state.create_physical_plan(&logical_plan).await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - // Only 1 SortExec was added let expected = { vec![ - "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as COUNT(UInt8(1))]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 
0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", - " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", + "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as COUNT(UInt8(1))]", + " WindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", + " SortExec: [c1@1 ASC NULLS LAST,c2@2 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 2)", - " RepartitionExec: partitioning=RoundRobinBatch(2)", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 1 }], 2)", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", + " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }, Column { name: \"c2\", index: 1 }], 2)", + " RepartitionExec: partitioning=RoundRobinBatch(2)", ] }; @@ -2087,10 +2090,11 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> // Only 1 SortExec was added let expected = { vec![ - "ProjectionExec: expr=[c9@3 as c9, 
SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", + "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum2]", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", ] }; 
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index bf2a1d0018675..eeb3215c4b6fd 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -257,7 +257,7 @@ impl LogicalPlanBuilder { // The sort_by() implementation here is a stable sort. // Note that by this rule if there's an empty over, it'll be at the top level groups.sort_by(|(key_a, _), (key_b, _)| { - for (first, second) in key_a.iter().zip(key_b.iter()) { + for ((first, _), (second, _)) in key_a.iter().zip(key_b.iter()) { let key_ordering = compare_sort_expr(first, second, plan.schema()); match key_ordering { Ordering::Less => { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 2577c3a1970ca..1ab4595d1c886 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -203,7 +203,7 @@ pub fn expand_qualified_wildcard( expand_wildcard(&qualifier_schema, plan) } -type WindowSortKey = Vec; +type WindowSortKey = Vec<(Expr, bool)>; /// Generate a sort key for a given window expr's partition_by and order_bu expr pub fn generate_sort_key( @@ -223,6 +223,7 @@ pub fn generate_sort_key( .collect::>>()?; let mut final_sort_keys = vec![]; + let mut is_partition_flag = vec![]; partition_by.iter().for_each(|e| { // By default, create sort key with ASC is true and NULLS LAST to be consistent with // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html @@ -231,18 +232,26 @@ pub fn generate_sort_key( let order_by_key = &order_by[pos]; if !final_sort_keys.contains(order_by_key) { final_sort_keys.push(order_by_key.clone()); + is_partition_flag.push(true); } } else if !final_sort_keys.contains(&e) { final_sort_keys.push(e); + is_partition_flag.push(true); } }); order_by.iter().for_each(|e| { if !final_sort_keys.contains(e) { final_sort_keys.push(e.clone()); + is_partition_flag.push(false); } }); - Ok(final_sort_keys) + let res = final_sort_keys + 
.into_iter() + .zip(is_partition_flag) + .map(|(lhs, rhs)| (lhs, rhs)) + .collect::>(); + Ok(res) } /// Compare the sort expr as PostgreSQL's common_prefix_cmp(): @@ -1027,9 +1036,13 @@ mod tests { let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs)?; - let key1 = vec![age_asc.clone(), name_desc.clone()]; + let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)]; let key2 = vec![]; - let key3 = vec![name_desc, age_asc, created_at_desc]; + let key3 = vec![ + (name_desc, false), + (age_asc, false), + (created_at_desc, false), + ]; let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![ (key1, vec![&max1, &min3]), @@ -1096,21 +1109,30 @@ mod tests { ]; let expected = vec![ - Expr::Sort(Sort { - expr: Box::new(col("age")), - asc: asc_, - nulls_first: nulls_first_, - }), - Expr::Sort(Sort { - expr: Box::new(col("name")), - asc: asc_, - nulls_first: nulls_first_, - }), - Expr::Sort(Sort { - expr: Box::new(col("created_at")), - asc: true, - nulls_first: false, - }), + ( + Expr::Sort(Sort { + expr: Box::new(col("age")), + asc: asc_, + nulls_first: nulls_first_, + }), + true, + ), + ( + Expr::Sort(Sort { + expr: Box::new(col("name")), + asc: asc_, + nulls_first: nulls_first_, + }), + true, + ), + ( + Expr::Sort(Sort { + expr: Box::new(col("created_at")), + asc: true, + nulls_first: false, + }), + true, + ), ]; let result = generate_sort_key(partition_by, order_by)?; assert_eq!(expected, result); From 94c784b265d3e3c70624b06dffba242a03a2eb6c Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 15 Dec 2022 17:07:56 +0300 Subject: [PATCH 12/50] Bounded window exec --- datafusion/core/Cargo.toml | 1 + datafusion/core/benches/data_utils/mod.rs | 2 + datafusion/core/src/execution/context.rs | 4 + datafusion/core/src/physical_optimizer/mod.rs | 1 + .../remove_unnecessary_sorts.rs | 83 +- .../replace_window_with_bounded_impl.rs | 95 +++ datafusion/core/src/physical_plan/common.rs | 28 + 
.../windows/bounded_window_agg_exec.rs | 712 ++++++++++++++++++ .../core/src/physical_plan/windows/mod.rs | 2 + datafusion/core/tests/sql/window.rs | 26 +- datafusion/core/tests/window_fuzz.rs | 309 ++++++++ datafusion/expr/src/accumulator.rs | 8 +- datafusion/expr/src/window_frame.rs | 10 + datafusion/physical-expr/Cargo.toml | 1 + .../physical-expr/src/aggregate/count.rs | 4 + datafusion/physical-expr/src/aggregate/mod.rs | 6 + datafusion/physical-expr/src/aggregate/sum.rs | 4 + .../physical-expr/src/window/aggregate.rs | 159 +++- .../physical-expr/src/window/built_in.rs | 116 ++- .../window/built_in_window_function_expr.rs | 6 + .../physical-expr/src/window/cume_dist.rs | 5 + .../physical-expr/src/window/lead_lag.rs | 88 ++- datafusion/physical-expr/src/window/mod.rs | 6 + .../physical-expr/src/window/nth_value.rs | 52 +- .../src/window/partition_evaluator.rs | 50 +- datafusion/physical-expr/src/window/rank.rs | 77 +- .../physical-expr/src/window/row_number.rs | 45 +- .../physical-expr/src/window/window_expr.rs | 157 +++- test-utils/src/lib.rs | 5 +- 29 files changed, 2020 insertions(+), 42 deletions(-) create mode 100644 datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs create mode 100644 datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs create mode 100644 datafusion/core/tests/window_fuzz.rs diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 3a7313b5a07c2..3c20a912ca6f8 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -75,6 +75,7 @@ flate2 = { version = "1.0.24", optional = true } futures = "0.3" glob = "0.3.0" hashbrown = { version = "0.13", features = ["raw"] } +indexmap = "1.9.2" itertools = "0.10" lazy_static = { version = "^1.4.0" } log = "^0.4" diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index 575e1831c8380..082d0a258e60d 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ 
b/datafusion/core/benches/data_utils/mod.rs @@ -17,6 +17,7 @@ //! This module provides the in-memory table for more realistic benchmarking. +use arrow::array::Int32Array; use arrow::{ array::Float32Array, array::Float64Array, @@ -31,6 +32,7 @@ use datafusion::from_slice::FromSlice; use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; +use std::collections::BTreeMap; use std::sync::Arc; /// create an in-memory table given the partition len, array len, and batch size, diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index e0ebb8e018287..6b8d917eaa6f7 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -100,6 +100,7 @@ use url::Url; use crate::catalog::listing_schema::ListingSchemaProvider; use crate::datasource::object_store::ObjectStoreUrl; use crate::physical_optimizer::remove_unnecessary_sorts::RemoveUnnecessarySorts; +use crate::physical_optimizer::replace_window_with_bounded_impl::ReplaceWindowWithBoundedImpl; use uuid::Uuid; use super::options::{ @@ -1602,6 +1603,9 @@ impl SessionState { // actually unnecessary. The rule below performs this analysis and removes such `SortExec`s. 
physical_optimizers.push(Arc::new(RemoveUnnecessarySorts::new())); + // Replace WindowAggExec with BoundedWindowAggExec if conditions are met + physical_optimizers.push(Arc::new(ReplaceWindowWithBoundedImpl::new())); + SessionState { session_id, optimizer: Optimizer::new(&optimizer_config), diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index a69aa16c343bd..fe3bbf9bb8320 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -26,6 +26,7 @@ pub mod optimizer; pub mod pruning; pub mod remove_unnecessary_sorts; pub mod repartition; +pub mod replace_window_with_bounded_impl; mod utils; pub use optimizer::PhysicalOptimizerRule; diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 4c38f1f1e3501..58f3433013046 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -104,12 +104,13 @@ fn remove_unnecessary_sorts( // can do analysis for sort removal let (_, sort_any) = sort_onward[0].clone(); let sort_exec = convert_to_sort_exec(&sort_any)?; - let sort_output_ordering = sort_exec.output_ordering(); let sort_input_ordering = sort_exec.input().output_ordering(); // Do naive analysis, where a SortExec is already sorted according to desired Sorting - if ordering_satisfy(sort_input_ordering, sort_output_ordering, || { - sort_exec.input().equivalence_properties() - }) { + if ordering_satisfy( + sort_input_ordering, + Some(required_ordering), + || sort_exec.input().equivalence_properties(), + ) { update_child_to_remove_unnecessary_sort(child, sort_onward)?; } else if let Some(window_agg_exec) = requirements.plan.as_any().downcast_ref::() @@ -825,6 +826,80 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_remove_unnecessary_sort2() -> Result<()> { + let 
session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, None)?) + as Arc; + let sort_preserving_merge_exec = + Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) + as Arc; + let sort_exprs = vec![ + PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: col("non_nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }, + ]; + let sort_exec = Arc::new(SortExec::try_new( + sort_exprs.clone(), + sort_preserving_merge_exec, + None, + )?) as Arc; + let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new( + vec![sort_exprs[0].clone()], + sort_exec, + )) as Arc; + let physical_plan = sort_preserving_merge_exec; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " 
MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } + #[tokio::test] async fn test_change_wrong_sorting() -> Result<()> { let session_ctx = SessionContext::new(); diff --git a/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs b/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs new file mode 100644 index 0000000000000..886455904691b --- /dev/null +++ b/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! ReplaceWindowWithBoundedImpl optimizer that swaps `WindowAggExec` for `BoundedWindowAggExec` +//!
when no window frame uses GROUPS units and every window expression can run bounded + +use crate::physical_plan::windows::BoundedWindowAggExec; +use crate::physical_plan::windows::WindowAggExec; +use crate::{ + error::Result, + physical_optimizer::PhysicalOptimizerRule, + physical_plan::{ + coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec, + repartition::RepartitionExec, rewrite::TreeNodeRewritable, + }, +}; +use datafusion_expr::WindowFrameUnits; +use datafusion_physical_expr::window::WindowExpr; +use std::sync::Arc; + +/// Optimizer rule that replaces a `WindowAggExec` with a `BoundedWindowAggExec` when none of +/// its window frames use GROUPS units and all of its window expressions can run bounded +#[derive(Default)] +pub struct ReplaceWindowWithBoundedImpl {} + +impl ReplaceWindowWithBoundedImpl { + /// Create a new `ReplaceWindowWithBoundedImpl` rule + pub fn new() -> Self { + Self {} + } +} +impl PhysicalOptimizerRule for ReplaceWindowWithBoundedImpl { + fn optimize( + &self, + plan: Arc, + _config: &crate::execution::context::SessionConfig, + ) -> Result> { + plan.transform_up(&|plan| { + if let Some(window_agg_exec) = plan.as_any().downcast_ref::() { + // Inspect every window expression of this `WindowAggExec` node + let is_contains_groups = window_agg_exec + .window_expr() + .iter() + .any(is_window_frame_groups); + let can_run_bounded = window_agg_exec + .window_expr() + .iter() + .all(|elem| elem.can_run_bounded()); + // Only swap when no frame is in GROUPS mode and every window expression + // reports that it can run bounded + if !is_contains_groups && can_run_bounded { + // Rebuild the node as a `BoundedWindowAggExec` over the same input and keys + return Ok(Some(Arc::new(BoundedWindowAggExec::try_new( + window_agg_exec.window_expr().to_vec(), + window_agg_exec.input().clone(), + window_agg_exec.input().schema(), + window_agg_exec.partition_keys.clone(), + window_agg_exec.sort_keys.clone(), + )?))); + } + } + Ok(None) + }) + } + + fn name(&self) -> &str { + "replace_window_with_bounded_impl" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// Checks window
expression whether its frame uses GROUPS units +fn is_window_frame_groups(window_expr: &Arc) -> bool { + matches!(window_expr.get_window_frame().units, WindowFrameUnits::Groups) +} diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs index b4db3a32b5220..7cc331f076c7b 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -22,6 +22,7 @@ use crate::error::{DataFusionError, Result}; use crate::execution::context::TaskContext; use crate::physical_plan::metrics::MemTrackingMetrics; use crate::physical_plan::{displayable, ColumnStatistics, ExecutionPlan, Statistics}; +use arrow::compute::concat; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::ArrowError; use arrow::error::Result as ArrowResult; @@ -95,6 +96,33 @@ pub async fn collect(stream: SendableRecordBatchStream) -> Result ArrowResult> { + if batches.is_empty() { + Ok(None) + } else { + let columns = schema + .fields() + .iter() + .enumerate() + .map(|(i, _)| { + concat( + &batches + .iter() + .map(|batch| batch.column(i).as_ref()) + .collect::>(), + ) + }) + .collect::>>()?; + Ok(Some(RecordBatch::try_new(schema.clone(), columns)?)) + } +} + /// Recursively builds a list of files in a directory with a given extension pub fn build_checked_file_list(dir: &str, ext: &str) -> Result> { let mut filenames: Vec = Vec::new(); diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs new file mode 100644 index 0000000000000..ac4b6b42eb604 --- /dev/null +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -0,0 +1,712 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Stream and channel implementations for window function expressions. + +use crate::error::Result; +use crate::execution::context::TaskContext; +use crate::physical_plan::expressions::PhysicalSortExpr; +use crate::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, +}; +use crate::physical_plan::{ + ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, +}; +use arrow::array::Array; +use arrow::compute::{concat, lexicographical_partition_ranges, SortColumn}; +use arrow::{ + array::ArrayRef, + datatypes::{Schema, SchemaRef}, + error::Result as ArrowResult, + record_batch::RecordBatch, +}; +use datafusion_common::{DataFusionError, ScalarValue}; +use futures::stream::Stream; +use futures::{ready, StreamExt}; +use std::any::Any; +use std::cmp::min; +use std::collections::HashMap; +use std::ops::Range; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::physical_plan::common::combine_batches_with_ref; +use datafusion_physical_expr::window::{ + PartitionBatchState, PartitionBatches, PartitionKey, PartitionWindowAggStates, + WindowAggState, WindowState, +}; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; +use indexmap::IndexMap; +use log::debug; + +/// Window execution plan +#[derive(Debug)] +pub struct 
BoundedWindowAggExec { + /// Input plan + input: Arc, + /// Window function expressions + window_expr: Vec>, + /// Schema after the window is run + schema: SchemaRef, + /// Schema before the window + input_schema: SchemaRef, + /// Partition Keys + pub partition_keys: Vec>, + /// Sort Keys + pub sort_keys: Option>, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +impl BoundedWindowAggExec { + /// Create a new execution plan for window aggregates; the output schema is built from `input_schema` and `window_expr` + pub fn try_new( + window_expr: Vec>, + input: Arc, + input_schema: SchemaRef, + partition_keys: Vec>, + sort_keys: Option>, + ) -> Result { + let schema = create_schema(&input_schema, &window_expr)?; + let schema = Arc::new(schema); + Ok(Self { + input, + window_expr, + schema, + input_schema, + partition_keys, + sort_keys, + metrics: ExecutionPlanMetricsSet::new(), + }) + } + + /// Window expressions + pub fn window_expr(&self) -> &[Arc] { + &self.window_expr + } + + /// Input plan + pub fn input(&self) -> &Arc { + &self.input + } + + /// Get the input schema before any window functions are applied + pub fn input_schema(&self) -> SchemaRef { + self.input_schema.clone() + } + + /// Get the sort keys whose expressions equal the PARTITION BY expressions, in PARTITION BY order + pub fn partition_by_sort_keys(&self) -> Result> { + // All window exprs have the same PARTITION BY, hence we just use the first one (indexing panics if `window_expr` is empty) + let partition_by = self.window_expr()[0].partition_by(); + let mut partition_columns = vec![]; + for elem in partition_by { + if let Some(sort_keys) = &self.sort_keys { + for a in sort_keys { + if a.expr.eq(elem) { + partition_columns.push(a.clone()); + break; + } + } + } + } + assert_eq!(partition_by.len(), partition_columns.len()); + Ok(partition_columns) + } +} + +impl ExecutionPlan for BoundedWindowAggExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + /// Get the output partitioning of
this plan + fn output_partitioning(&self) -> Partitioning { + // Because we can have repartitioning using the partition keys, + // this would be either 1 or more than 1 depending on the presence of + // repartitioning + self.input.output_partitioning() + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + // This executor maintains input order, and has required input_ordering filled + // hence output_ordering would be `required_input_ordering` + self.required_input_ordering()[0] + } + + fn required_input_ordering(&self) -> Vec> { + let sort_keys = self.sort_keys.as_deref(); + vec![sort_keys] + } + + fn required_input_distribution(&self) -> Vec { + if self.partition_keys.is_empty() { + debug!("No partition defined for WindowAggExec!!!"); + vec![Distribution::SinglePartition] + } else { + // TODO: support PartitionCollections if there are no common partition columns in the window_expr + vec![Distribution::HashPartitioned(self.partition_keys.clone())] + } + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + self.input.equivalence_properties() + } + + fn maintains_input_order(&self) -> bool { + true + } + + fn relies_on_input_order(&self) -> bool { + true + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(BoundedWindowAggExec::try_new( + self.window_expr.clone(), + children[0].clone(), + self.input_schema.clone(), + self.partition_keys.clone(), + self.sort_keys.clone(), + )?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input = self.input.execute(partition, context)?; + let stream = Box::pin(SortedPartitionByBoundedWindowStream::new( + self.schema.clone(), + self.window_expr.clone(), + input, + BaselineMetrics::new(&self.metrics, partition), + self.partition_by_sort_keys()?, + )); + Ok(stream) + } + + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default => { + write!(f,
"BoundedWindowAggExec: ")?; + let g: Vec = self + .window_expr + .iter() + .map(|e| { + format!( + "{}: {:?}, frame: {:?}", + e.name().to_owned(), + e.field(), + e.get_window_frame() + ) + }) + .collect(); + write!(f, "wdw=[{}]", g.join(", "))?; + } + } + Ok(()) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Statistics { + let input_stat = self.input.statistics(); + let win_cols = self.window_expr.len(); + let input_cols = self.input_schema.fields().len(); + // TODO stats: some windowing functions will maintain invariants such as min, max... + let mut column_statistics = vec![ColumnStatistics::default(); win_cols]; + if let Some(input_col_stats) = input_stat.column_statistics { + column_statistics.extend(input_col_stats); + } else { + column_statistics.extend(vec![ColumnStatistics::default(); input_cols]); + } + Statistics { + is_exact: input_stat.is_exact, + num_rows: input_stat.num_rows, + column_statistics: Some(column_statistics), + // TODO stats: knowing the type of the new columns we can guess the output size + total_byte_size: None, + } + } +} + +/// Trait for updating state and calculating results for window functions. +/// Depending on whether the PARTITION BY columns are assumed sorted or unsorted, we may have +/// different implementations of these methods +pub trait PartitionByHandler { + /// Method to construct output columns from window_expression results + fn calculate_out_columns(&self) -> Result>>; + /// Given how many rows we emitted as results, + /// prune the sections of the state that are no longer needed + fn prune_state(&mut self, n_out: usize) -> Result<()>; + /// Method to update record batches for each partition + /// when new record batches are received + fn update_partition_batch(&mut self, record_batch: RecordBatch) -> Result<()>; +} + +fn create_schema( + input_schema: &Schema, + window_expr: &[Arc], +) -> Result { + let mut fields = Vec::with_capacity(input_schema.fields().len() + window_expr.len()); + for expr
in window_expr { + fields.push(expr.field()?); + } + fields.extend_from_slice(input_schema.fields()); + Ok(Schema::new(fields)) +} + +/// Stream for the window aggregation plan, +/// assuming the PARTITION BY columns are sorted (or there is no PARTITION BY expression) +pub struct SortedPartitionByBoundedWindowStream { + schema: SchemaRef, + input: SendableRecordBatchStream, + /// The record batch the executor receives as input (columns needed during aggregate results calculation) + input_buffer_record_batch: RecordBatch, + /// We separate `input_buffer_record_batch` according to different partitions (determined by PARTITION BY columns) + /// and store the resulting record batches per partition in `partition_batches`. + /// This variable is used during result calculation for each window_expression. + /// This enables us to use the same batch for different window_expressions (without copying) + // We could have kept record_batches for each window expression in the `PartitionWindowAggStates`. + // However, this would use more memory (on the order of the number of window expressions) + partition_batches: PartitionBatches, + /// Each executor can run multiple window expressions given + /// their PARTITION BY and ORDER BY sections are the same. + /// We keep the state of each window expression inside `window_agg_states` + window_agg_states: Vec, + finished: bool, + window_expr: Vec>, + baseline_metrics: BaselineMetrics, + partition_by_sort_keys: Vec, +} + +impl PartitionByHandler for SortedPartitionByBoundedWindowStream { + /// This method constructs output columns using the result of each window expression + fn calculate_out_columns(&self) -> Result>> { + let n_out = self.calculate_n_out_row(); + if n_out > 0 { + let mut out_columns = vec![]; + for partition_window_agg_states in self.window_agg_states.iter() { + out_columns.push(get_aggregate_result_out_column( + partition_window_agg_states, + n_out, + )?); + } + + let batch_to_show = self + .input_buffer_record_batch + .columns() + .iter() +
.map(|elem| elem.slice(0, n_out)) + .collect::>(); + out_columns.extend_from_slice(&batch_to_show); + + Ok(Some(out_columns)) + } else { + Ok(None) + } + } + + /// Prunes sections in the state that are no longer needed + fn prune_state(&mut self, n_out: usize) -> Result<()> { + self.prune_partition_batches()?; + self.prune_input_batch(n_out)?; + self.prune_out_columns(n_out)?; + + Ok(()) + } + + /// Splits `record_batch` into per-partition slices and merges each slice into its entry of + /// `partition_batches`; fails if there is no window expression + fn update_partition_batch(&mut self, record_batch: RecordBatch) -> Result<()> { + // Streaming needs at least one window expression; the expression itself is not used here, so only check for emptiness + if self.window_expr.is_empty() { + return Err(DataFusionError::Execution( + "window expr cannot be empty to support streaming".to_string(), + )); + } + + let partition_columns = self.partition_columns(&record_batch)?; + let num_rows = record_batch.num_rows(); + if num_rows > 0 { + let partition_points = + self.evaluate_partition_points(num_rows, &partition_columns)?; + for partition_range in partition_points { + let partition_row = partition_columns + .iter() + .map(|arr| { + ScalarValue::try_from_array(&arr.values, partition_range.start) + }) + .collect::>()?; + let partition_batch = record_batch.slice( + partition_range.start, + partition_range.end - partition_range.start, + ); + if let Some(partition_batch_state) = + self.partition_batches.get_mut(&partition_row) + { + let combined_partition_batch = combine_batches_with_ref( + &[&partition_batch_state.record_batch, &partition_batch], + self.input.schema(), + )?
+ .ok_or_else(|| { + DataFusionError::Execution( + "Should contain at least one entry".to_string(), + ) + })?; + partition_batch_state.record_batch = combined_partition_batch; + } else { + let partition_batch_state = PartitionBatchState { + record_batch: partition_batch, + is_end: false, + }; + self.partition_batches + .insert(partition_row.clone(), partition_batch_state); + }; + } + } + let n_partitions = self.partition_batches.len(); + for (idx, (_, partition_batch_state)) in + self.partition_batches.iter_mut().enumerate() + { + if idx < n_partitions - 1 { + partition_batch_state.is_end = true; + } + } + if self.input_buffer_record_batch.num_rows() == 0 { + self.input_buffer_record_batch = record_batch; + } else { + self.input_buffer_record_batch = combine_batches_with_ref( + &[&self.input_buffer_record_batch, &record_batch], + self.input.schema(), + )? + .ok_or_else(|| { + DataFusionError::Execution( + "Should contain at least one entry".to_string(), + ) + })?; + } + + Ok(()) + } +} + +impl Stream for SortedPartitionByBoundedWindowStream { + type Item = ArrowResult; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let poll = self.poll_next_inner(cx); + self.baseline_metrics.record_poll(poll) + } +} + +impl SortedPartitionByBoundedWindowStream { + /// Create a new `SortedPartitionByBoundedWindowStream` with one empty state map per window expression + pub fn new( + schema: SchemaRef, + window_expr: Vec>, + input: SendableRecordBatchStream, + baseline_metrics: BaselineMetrics, + partition_by_sort_keys: Vec, + ) -> Self { + let mut state = vec![]; + for _i in 0..window_expr.len() { + state.push(IndexMap::new()); + } + let empty_batch = RecordBatch::new_empty(schema.clone()); + Self { + schema, + input, + input_buffer_record_batch: empty_batch, + partition_batches: IndexMap::new(), + window_agg_states: state, + finished: false, + window_expr, + baseline_metrics, + partition_by_sort_keys, + } + } + + fn compute_aggregates(&mut self) -> ArrowResult { + // Evaluate every window expression against the buffered partition batches + for (idx,
cur_window_expr) in self.window_expr.iter().enumerate() { + cur_window_expr.evaluate_bounded( + &self.partition_batches, + &mut self.window_agg_states[idx], + )?; + } + + let columns_to_show = self.calculate_out_columns()?; + if let Some(columns_to_show) = columns_to_show { + let n_generated = columns_to_show[0].len(); + self.prune_state(n_generated)?; + RecordBatch::try_new(self.schema.clone(), columns_to_show) + } else { + Ok(RecordBatch::new_empty(self.schema.clone())) + } + } + + #[inline] + fn poll_next_inner( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.finished { + return Poll::Ready(None); + } + + let result = match ready!(self.input.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + self.update_partition_batch(batch)?; + self.compute_aggregates() + } + Some(Err(e)) => Err(e), + None => { + self.finished = true; + for (_, partition_batch_state) in self.partition_batches.iter_mut() { + partition_batch_state.is_end = true; + } + self.compute_aggregates() + } + }; + Poll::Ready(Some(result)) + } + + /// Method to calculate how many rows SortedPartitionByBoundedWindowStream can produce as output + fn calculate_n_out_row(&self) -> usize { + // Different window aggregators may produce results at different rates; + // we produce the overall batch result at the speed of the slowest one + self.window_agg_states + .iter() + .map(|window_agg_state| { + // The variable below stores how many elements are generated (can be displayed) for the current + // window expression + let mut cur_window_expr_out_result_len = 0; + // We iterate over window_agg_state; + // since it is an IndexMap, iteration is in insertion order. + // Hence we preserve sorting when partition columns are sorted + for (_, WindowState { state, ..
}) in window_agg_state.iter() { + cur_window_expr_out_result_len += state.out_col.len(); + // If we do not generate all results for the current partition, + // we do not generate results for the next partition; + // otherwise we would lose input ordering + if state.n_row_result_missing > 0 { + break; + } + } + cur_window_expr_out_result_len + }) + .min() + .unwrap_or(0) + } + + /// Prunes the sections of the record batch (for each partition) + /// that are no longer needed to calculate window function results + fn prune_partition_batches(&mut self) -> Result<()> { + // Remove partitions which we know have ended (is_end flag is true). + // The retain method keeps the remaining elements in insertion order, + // hence after removal we still preserve the ordering between partitions + self.partition_batches + .retain(|_, partition_batch_state| !partition_batch_state.is_end); + + // `self.partition_batches` data are used by all window expressions + // hence when removing from `self.partition_batches` we need to remove from the earliest range boundary + // among all window expressions. `n_prune_each_partition` stores the earliest range boundary information + // for each partition. This way we can delete no longer needed sections from `self.partition_batches`. + // For instance if window frame one uses [10, 20] and window frame 2 uses [5, 15] + // We prune only first 5 elements from corresponding record batch in `self.partition_batches` + // Calculate how many elements to prune for each partition_batch + let mut n_prune_each_partition: HashMap = HashMap::new(); + for window_agg_state in self.window_agg_states.iter_mut() { + window_agg_state.retain(|_, WindowState { state, .. }| !state.is_end); + for (partition_row, WindowState { state: value, ..
}) in window_agg_state { + if let Some(state) = n_prune_each_partition.get_mut(partition_row) { + if value.current_range_of_sliding_window.start < *state { + *state = value.current_range_of_sliding_window.start; + } + } else { + n_prune_each_partition.insert( + partition_row.clone(), + value.current_range_of_sliding_window.start, + ); + } + } + } + + let err = || DataFusionError::Execution("Expects to have partition".to_string()); + // Drop the leading rows of each partition batch that no window calculation needs any more + for (partition_row, n_prune) in n_prune_each_partition.iter() { + let partition_batch_state = self + .partition_batches + .get_mut(partition_row) + .ok_or_else(err)?; + let new_record_batch = partition_batch_state.record_batch.slice( + *n_prune, + partition_batch_state.record_batch.num_rows() - n_prune, + ); + partition_batch_state.record_batch = new_record_batch; + + // Update State indices, since we have pruned some rows from the beginning + for window_agg_state in self.window_agg_states.iter_mut() { + let window_state = + window_agg_state.get_mut(partition_row).ok_or_else(err)?; + let state = &mut window_state.state; + state.current_range_of_sliding_window = Range { + start: state.current_range_of_sliding_window.start - n_prune, + end: state.current_range_of_sliding_window.end - n_prune, + }; + state.last_calculated_index -= n_prune; + state.offset_pruned_rows += n_prune; + } + } + Ok(()) + } + + /// Prunes the section of the input batch whose aggregate results + /// are calculated and emitted as result + fn prune_input_batch(&mut self, n_out: usize) -> Result<()> { + let len_batch = self.input_buffer_record_batch.num_rows(); + let n_to_keep = len_batch - n_out; + let batch_to_keep = self + .input_buffer_record_batch + .columns() + .iter() + .map(|elem| elem.slice(n_out, n_to_keep)) + .collect::>(); + self.input_buffer_record_batch = + RecordBatch::try_new(self.input_buffer_record_batch.schema(), batch_to_keep)?; + Ok(()) + } + + /// Prunes
emitted parts from WindowAggState `out_col` field + fn prune_out_columns(&mut self, n_out: usize) -> Result<()> { + // We store generated columns for each window expression in the `out_col` field of the `WindowAggState`. + // Given how many rows are emitted to output, we remove these sections from the state + for partition_window_agg_states in self.window_agg_states.iter_mut() { + let mut running_length = 0; + // Remove a total of `n_out` entries from the `out_col` fields of `WindowAggState`. Iterates in + // insertion order, hence we preserve per partition ordering. Without emitting all results for a partition + // we do not generate results for another partition + for ( + _, + WindowState { + state: WindowAggState { out_col, .. }, + .. + }, + ) in partition_window_agg_states + { + if running_length < n_out { + let n_to_del = min(out_col.len(), n_out - running_length); + let n_to_keep = out_col.len() - n_to_del; + *out_col = out_col.slice(n_to_del, n_to_keep); + running_length += n_to_del; + } + } + } + Ok(()) + } + + /// Evaluate each of the `partition_by_sort_keys` against `batch`, yielding one sort column per key + pub fn partition_columns(&self, batch: &RecordBatch) -> Result> { + self.partition_by_sort_keys + .iter() + .map(|elem| elem.evaluate_to_sort_column(batch)) + .collect::>>() + } + + /// Evaluate the partition points given the sort columns; if the sort columns are + /// empty then the result will be a single element vec covering all `num_rows` rows. + fn evaluate_partition_points( + &self, + num_rows: usize, + partition_columns: &[SortColumn], + ) -> Result>> { + if partition_columns.is_empty() { + Ok(vec![Range { + start: 0, + end: num_rows, + }]) + } else { + Ok(lexicographical_partition_ranges(partition_columns) + .map_err(DataFusionError::ArrowError)?
+ .collect::>()) + } + } +} + +impl RecordBatchStream for SortedPartitionByBoundedWindowStream { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +/// Builds the output column for one window expression by taking the first `len_to_show` result +/// rows from the per-partition `out_col` buffers, in iteration order +fn get_aggregate_result_out_column( + partition_window_agg_states: &PartitionWindowAggStates, + len_to_show: usize, +) -> Result { + let mut slices = vec![]; + let mut running_length = 0; + // We assume that iteration order is according to insertion order + for ( + _, + WindowState { + state: WindowAggState { out_col, .. }, + .. + }, + ) in partition_window_agg_states + { + if running_length >= len_to_show { + break; + } + let n_to_use = min(len_to_show - running_length, out_col.len()); + running_length += n_to_use; + slices.push(out_col.slice(0, n_to_use)); + } + assert_eq!(running_length, len_to_show); + if slices.len() > 1 { + // Concatenate all slices in one pass; concatenating pairwise inside the loop would + // re-copy the accumulated array once per partition + let slices = slices.iter().map(|slice| slice.as_ref()).collect::<Vec<_>>(); + return Ok(concat(&slices)?); + } + // Zero or one slice: return it as is, or fail when there is nothing to show + slices + .pop() + .ok_or_else(|| DataFusionError::Execution("Should contain something".to_string())) +} diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 0f837e581141d..72494b3f20706 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -37,8 +37,10 @@ use datafusion_physical_expr::window::BuiltInWindowFunctionExpr; use std::convert::TryInto; use std::sync::Arc; +mod bounded_window_agg_exec; mod window_agg_exec; +pub use bounded_window_agg_exec::BoundedWindowAggExec; pub use datafusion_physical_expr::window::{ AggregateWindowExpr, BuiltInWindowExpr, WindowExpr, }; diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index a5bd6a3b97c8e..a78ba414b8be4 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1669,8 +1669,8 @@ async fn test_window_agg_sort_reversed_plan() -> Result<()> { vec![ "ProjectionExec:
expr=[c9@2 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c9@0 DESC]", ] }; @@ -1727,8 +1727,8 @@ async fn test_window_agg_sort_reversed_plan_builtin() -> Result<()> { vec![ "ProjectionExec: expr=[c9@6 as c9, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY 
[aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lead2]", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }]", - " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, 
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }]", + " BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: 
Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }]", " SortExec: [c9@0 DESC]", ] }; @@ -1781,9 +1781,9 @@ async fn test_window_agg_sort_non_reversed_plan() -> Result<()> { vec![ "ProjectionExec: expr=[c9@2 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as rn1, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2]", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c9@1 ASC NULLS LAST]", - " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c9@0 DESC]", ] }; @@ -1837,10 +1837,10 @@ async fn 
test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { vec![ "ProjectionExec: expr=[c9@5 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as rn2]", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c9@4 DESC]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: 
WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c9@2 ASC NULLS LAST,c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", ] }; @@ -1976,8 +1976,8 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> vec![ "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", 
data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c1@0 ASC,c9@1 DESC]", ] }; @@ -2031,8 +2031,8 @@ async fn test_window_agg_sort_partitionby_reversed_plan() -> Result<()> { vec![ "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c1@0 ASC,c9@1 DESC]", ] }; diff --git 
a/datafusion/core/tests/window_fuzz.rs b/datafusion/core/tests/window_fuzz.rs new file mode 100644 index 0000000000000..10e914ed9900b --- /dev/null +++ b/datafusion/core/tests/window_fuzz.rs @@ -0,0 +1,309 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array}; +use arrow::compute::{concat_batches, SortOptions}; +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::pretty_format_batches; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use tokio::runtime::Builder; + +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::windows::{ + create_window_expr, BoundedWindowAggExec, WindowAggExec, +}; +use datafusion::physical_plan::{collect, common}; +use datafusion_expr::{ + AggregateFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, +}; + +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_common::ScalarValue; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::PhysicalSortExpr; +use test_utils::add_empty_batches; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn single_order_by_test() { + let rt = Builder::new_multi_thread() + .worker_threads(8) + .build() + .unwrap(); + let n = 100; + let handles_low_cardinality = (1..n).map(|i| { + rt.spawn(run_window_test( + make_staggered_batches::(1000, i), + i, + vec!["a"], + vec![], + )) + }); + let handles_high_cardinality = (1..n).map(|i| { + rt.spawn(run_window_test( + make_staggered_batches::(1000, i), + i, + vec!["a"], + vec![], + )) + }); + let handles = handles_low_cardinality + .into_iter() + .chain(handles_high_cardinality.into_iter()) + .collect::>>(); + rt.block_on(async { + for handle in handles { + handle.await.unwrap(); + } + }); + } + + #[test] + fn order_by_with_partition_test() { + let rt = Builder::new_multi_thread() + .worker_threads(8) + .build() + .unwrap(); + let n = 100; + // Since the (a, b) pairs are sorted, we must not violate per-partition sorting: + // the partition should be field a and the order by should be field b + let handles_low_cardinality = (1..n).map(|i| { + rt.spawn(run_window_test( + make_staggered_batches::(1000, i), + i, + vec!["b"], + vec!["a"], + )) + }); + let 
handles_high_cardinality = (1..n).map(|i| { + rt.spawn(run_window_test( + make_staggered_batches::(1000, i), + i, + vec!["b"], + vec!["a"], + )) + }); + let handles = handles_low_cardinality + .into_iter() + .chain(handles_high_cardinality.into_iter()) + .collect::>>(); + rt.block_on(async { + for handle in handles { + handle.await.unwrap(); + } + }); + } +} + +/// Run the batch and the running window executors on the same input +/// and verify that the two outputs are equal +async fn run_window_test( + input1: Vec, + random_seed: u64, + orderby_columns: Vec<&str>, + partition_by_columns: Vec<&str>, +) { + let mut rng = StdRng::seed_from_u64(random_seed); + let session_config = SessionConfig::new().with_batch_size(50); + let ctx = SessionContext::with_config(session_config); + let schema = input1[0].schema(); + let preceding = rng.gen_range(0..50); + let following = rng.gen_range(0..50); + let rand_num = rng.gen_range(0..3); + let units = if rand_num < 1 { + WindowFrameUnits::Range + } else if rand_num < 2 { + WindowFrameUnits::Rows + } else { + // For now we do not support GROUPS in BoundedWindowAggExec implementation + // TODO: once GROUPS handling is available, use WindowFrameUnits::GROUPS in randomized tests also. 
+ WindowFrameUnits::Range + }; + let window_frame = match units { + // In RANGE queries, window frame boundaries should match the column type + WindowFrameUnits::Range => WindowFrame { + units, + start_bound: WindowFrameBound::Preceding(ScalarValue::Int32(Some(preceding))), + end_bound: WindowFrameBound::Following(ScalarValue::Int32(Some(following))), + }, + // In ROWS queries, window frame boundaries should be UInt64 + WindowFrameUnits::Rows => WindowFrame { + units, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some( + preceding as u64, + ))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some( + following as u64, + ))), + }, + // Once GROUPS support is added, construct the window frame for this case also + _ => todo!(), + }; + let mut orderby_exprs = vec![]; + for column in orderby_columns { + orderby_exprs.push(PhysicalSortExpr { + expr: col(column, &schema).unwrap(), + options: SortOptions::default(), + }) + } + let mut partitionby_exprs = vec![]; + for column in partition_by_columns { + partitionby_exprs.push(col(column, &schema).unwrap()); + } + let mut sort_keys = vec![]; + for partition_by_expr in &partitionby_exprs { + sort_keys.push(PhysicalSortExpr { + expr: partition_by_expr.clone(), + options: SortOptions::default(), + }) + } + for order_by_expr in &orderby_exprs { + sort_keys.push(order_by_expr.clone()) + } + + let concat_input_record = concat_batches(&schema, &input1).unwrap(); + let exec1 = Arc::new( + MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(), None).unwrap(), + ); + let usual_window_exec = Arc::new( + WindowAggExec::try_new( + vec![create_window_expr( + &WindowFunction::AggregateFunction(AggregateFunction::Sum), + "sum".to_owned(), + &[col("x", &schema).unwrap()], + &partitionby_exprs, + &orderby_exprs, + Arc::new(window_frame.clone()), + schema.as_ref(), + ) + .unwrap()], + exec1, + schema.clone(), + vec![], + Some(sort_keys.clone()), + ) + .unwrap(), + ); + let exec2 = + 
Arc::new(MemoryExec::try_new(&[input1.clone()], schema.clone(), None).unwrap()); + let running_window_exec = Arc::new( + BoundedWindowAggExec::try_new( + vec![create_window_expr( + &WindowFunction::AggregateFunction(AggregateFunction::Sum), + "sum".to_owned(), + &[col("x", &schema).unwrap()], + &partitionby_exprs, + &orderby_exprs, + Arc::new(window_frame.clone()), + schema.as_ref(), + ) + .unwrap()], + exec2, + schema.clone(), + vec![], + Some(sort_keys), + ) + .unwrap(), + ); + + let task_ctx = ctx.task_ctx(); + let collected_usual = collect(usual_window_exec, task_ctx.clone()).await.unwrap(); + + let collected_running = collect(running_window_exec, task_ctx.clone()) + .await + .unwrap(); + // compare + let usual_formatted = pretty_format_batches(&collected_usual).unwrap().to_string(); + let running_formatted = pretty_format_batches(&collected_running) + .unwrap() + .to_string(); + + let mut usual_formatted_sorted: Vec<&str> = usual_formatted.trim().lines().collect(); + usual_formatted_sorted.sort_unstable(); + + let mut running_formatted_sorted: Vec<&str> = + running_formatted.trim().lines().collect(); + running_formatted_sorted.sort_unstable(); + for (i, (usual_line, running_line)) in usual_formatted_sorted + .iter() + .zip(&running_formatted_sorted) + .enumerate() + { + assert_eq!((i, usual_line), (i, running_line)); + } +} + +/// Return randomly sized record batches with: +/// two sorted int32 columns 'a', 'b' ranged from 0..len / DISTINCT as columns +/// two random int32 columns 'x', 'y' as other columns +fn make_staggered_batches( + len: usize, + random_seed: u64, +) -> Vec { + // use a random number generator to pick a random sized output + let mut rng = StdRng::seed_from_u64(random_seed); + let mut input12: Vec<(i32, i32)> = vec![(0, 0); len]; + let mut input3: Vec = vec![0; len]; + let mut input4: Vec = vec![0; len]; + input12.iter_mut().for_each(|v| { + *v = ( + rng.gen_range(0..(len / DISTINCT)) as i32, + rng.gen_range(0..(len / DISTINCT)) as i32, + ) 
+ }); + rng.fill(&mut input3[..]); + rng.fill(&mut input4[..]); + input12.sort(); + let input1 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.0)); + let input2 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.1)); + let input3 = Int32Array::from_iter_values(input3.into_iter()); + let input4 = Int32Array::from_iter_values(input4.into_iter()); + + // split into several record batches + let mut remainder = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(input1) as ArrayRef), + ("b", Arc::new(input2) as ArrayRef), + ("x", Arc::new(input3) as ArrayRef), + ("y", Arc::new(input4) as ArrayRef), + ]) + .unwrap(); + + let mut batches = vec![]; + if STREAM { + while remainder.num_rows() > 0 { + let batch_size = rng.gen_range(0..50); + if remainder.num_rows() < batch_size { + break; + } + batches.push(remainder.slice(0, batch_size)); + remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); + } + } else { + while remainder.num_rows() > 0 { + let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + batches.push(remainder.slice(0, batch_size)); + remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); + } + } + add_empty_batches(batches, &mut rng) +} diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs index 5b8269ee28209..6e0ed0c6c718e 100644 --- a/datafusion/expr/src/accumulator.rs +++ b/datafusion/expr/src/accumulator.rs @@ -79,12 +79,18 @@ pub trait Accumulator: Send + Sync + Debug { /// Allocated means that for internal containers such as `Vec`, the `capacity` should be used /// not the `len` fn size(&self) -> usize; + + fn clone_dyn(&self) -> Result> { + Err(DataFusionError::NotImplemented( + "clone_dyn is not implemented by default for this accumulator, to use it in for cloning implement this method".into(), + )) + } } /// Representation of internal accumulator state. Accumulators can potentially have a mix of /// scalar and array values. 
It may be desirable to add custom aggregator states here as well /// in the future (perhaps `Custom(Box)`?). -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum AggregateState { /// Simple scalar value. Note that `ScalarValue::List` can be used to pass multiple /// values around diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index b8274ed1d1930..aa25ecb05a0dd 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -174,6 +174,16 @@ pub enum WindowFrameBound { Following(ScalarValue), } +impl WindowFrameBound { + pub fn is_unbounded(&self) -> bool { + match self { + WindowFrameBound::Preceding(elem) => elem.is_null(), + WindowFrameBound::CurrentRow => false, + WindowFrameBound::Following(elem) => elem.is_null(), + } + } +} + impl TryFrom for WindowFrameBound { type Error = DataFusionError; diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 6c202d55e4b8c..1bbb6d75373ac 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -51,6 +51,7 @@ datafusion-expr = { path = "../expr", version = "15.0.0" } datafusion-row = { path = "../row", version = "15.0.0" } half = { version = "2.1", default-features = false } hashbrown = { version = "0.13", features = ["raw"] } +indexmap = "1.9.2" itertools = { version = "0.10", features = ["use_std"] } lazy_static = { version = "^1.4.0" } md-5 = { version = "^0.10.0", optional = true } diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 7fceaeedeab3c..4871b53e7fa69 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -98,6 +98,10 @@ impl AggregateExpr for Count { true } + fn bounded_exec_supported(&self) -> bool { + true + } + fn create_row_accumulator( &self, start_index: usize, diff --git a/datafusion/physical-expr/src/aggregate/mod.rs 
b/datafusion/physical-expr/src/aggregate/mod.rs index b0776886e69d5..b43639ab59462 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -87,6 +87,12 @@ pub trait AggregateExpr: Send + Sync + Debug { false } + /// Specifies whether this aggregate function can run using bounded memory. + /// For this to be true, the accumulator should have `retract_batch` implemented. + fn bounded_exec_supported(&self) -> bool { + false + } + /// RowAccumulator to access/update row-based aggregation state in-place. /// Currently, row accumulator only supports states of fixed-sized type. /// diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 11371f31d4c68..7257f4f5a7133 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -123,6 +123,10 @@ impl AggregateExpr for Sum { ) } + fn bounded_exec_supported(&self) -> bool { + true + } + fn create_row_accumulator( &self, start_index: usize, diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 1268559fe5688..f068e1c5eb3aa 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -23,15 +23,18 @@ use std::ops::Range; use std::sync::Arc; use arrow::array::Array; -use arrow::compute::SortOptions; +use arrow::compute::{concat, SortOptions}; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::WindowFrame; +use datafusion_expr::{Accumulator, WindowFrame}; -use crate::window::window_expr::reverse_order_bys; +use crate::window::window_expr::{reverse_order_bys, WindowFn, WindowFunctionState}; +use crate::window::{ + PartitionBatches, PartitionWindowAggStates, WindowAggState, WindowState, +}; use crate::{expressions::PhysicalSortExpr, 
PhysicalExpr}; use crate::{window::WindowExpr, AggregateExpr}; @@ -138,6 +141,61 @@ impl WindowExpr for AggregateWindowExpr { ScalarValue::iter_to_array(row_wise_results.into_iter()) } + fn evaluate_bounded( + &self, + partition_batches: &PartitionBatches, + window_agg_state: &mut PartitionWindowAggStates, + ) -> Result<()> { + for (partition_row, partition_batch_state) in partition_batches.iter() { + if !window_agg_state.contains_key(partition_row) { + let accumulator = self.aggregate.create_accumulator()?; + let field = self.aggregate.field()?; + let out_type = field.data_type(); + // let out_type = &accumulator.out_type()?; + window_agg_state.insert( + partition_row.clone(), + WindowState { + state: WindowAggState::new( + out_type, + WindowFunctionState::AggregateState(vec![]), + )?, + window_fn: WindowFn::Aggregate(accumulator), + }, + ); + }; + let window_state = window_agg_state.get_mut(partition_row).unwrap(); + let accumulator = match &mut window_state.window_fn { + WindowFn::Aggregate(accumulator) => accumulator, + _ => unreachable!(), + }; + let mut state = &mut window_state.state; + state.is_end = partition_batch_state.is_end; + + let num_rows = partition_batch_state.record_batch.num_rows(); + + let mut idx = state.last_calculated_index; + let mut last_range = state.current_range_of_sliding_window.clone(); + let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + let out_col = self.get_result_column( + accumulator, + &partition_batch_state.record_batch, + &mut window_frame_ctx, + &mut last_range, + &mut idx, + state.is_end, + )?; + state.last_calculated_index = idx; + state.current_range_of_sliding_window = last_range.clone(); + + state.out_col = concat(&[&state.out_col, &out_col])?; + state.n_row_result_missing = num_rows - state.last_calculated_index; + + state.window_function_state = + WindowFunctionState::AggregateState(accumulator.state()?); + } + Ok(()) + } + fn partition_by(&self) -> &[Arc] { &self.partition_by } @@ -162,4 
+220,99 @@ impl WindowExpr for AggregateWindowExpr { Arc::new(self.window_frame.reverse()), ))) } + + fn can_run_bounded(&self) -> bool { + self.aggregate.bounded_exec_supported() + && !self.window_frame.start_bound.is_unbounded() + && !self.window_frame.end_bound.is_unbounded() + } +} + +impl AggregateWindowExpr { + /// For given range calculate accumulator result inside range on value_slice and + /// update accumulator state + fn get_aggregate_result_inside_range( + &self, + last_range: &Range, + cur_range: &Range, + value_slice: &[ArrayRef], + accumulator: &mut Box, + ) -> Result { + let value = if cur_range.start == cur_range.end { + // We produce None if the window is empty. + ScalarValue::try_from(self.aggregate.field()?.data_type())? + } else { + // Accumulate any new rows that have entered the window: + let update_bound = cur_range.end - last_range.end; + if update_bound > 0 { + let update: Vec = value_slice + .iter() + .map(|v| v.slice(last_range.end, update_bound)) + .collect(); + accumulator.update_batch(&update)? + } + // Remove rows that have now left the window: + let retract_bound = cur_range.start - last_range.start; + if retract_bound > 0 { + let retract: Vec = value_slice + .iter() + .map(|v| v.slice(last_range.start, retract_bound)) + .collect(); + accumulator.retract_batch(&retract)? + } + accumulator.evaluate()? + }; + Ok(value) + } + + fn get_result_column( + &self, + accumulator: &mut Box, + // values: &[ArrayRef], + // order_bys: &[ArrayRef], + record_batch: &RecordBatch, + window_frame_ctx: &mut WindowFrameContext, + last_range: &mut Range, + idx: &mut usize, + is_end: bool, + ) -> Result { + let (values, order_bys) = self.get_values_orderbys(record_batch)?; + // We iterate on each row to perform a running calculation. + // First, current_range_of_sliding_window is calculated, then it is compared with last_range. 
+ let length = values[0].len(); + let sort_options: Vec = + self.order_by.iter().map(|o| o.options).collect(); + let mut row_wise_results: Vec = vec![]; + let field = self.aggregate.field()?; + let out_type = field.data_type(); + while *idx < length { + let cur_range = window_frame_ctx.calculate_range( + &order_bys, + &sort_options, + length, + *idx, + )?; + // exit if range end index is length, need kind of flag to stop + if cur_range.end == length && !is_end { + break; + } + let value = self.get_aggregate_result_inside_range( + last_range, + &cur_range, + &values, + accumulator, + )?; + row_wise_results.push(value); + last_range.start = cur_range.start; + last_range.end = cur_range.end; + *idx += 1; + } + let out_col = if !row_wise_results.is_empty() { + ScalarValue::iter_to_array(row_wise_results.into_iter())? + } else { + let a = ScalarValue::try_from(out_type)?; + a.to_array_of_size(0) + }; + Ok(out_col) + } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 0b1e5ee8f19cf..75c9ed4fd0877 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -20,9 +20,14 @@ use super::window_frame_state::WindowFrameContext; use super::BuiltInWindowFunctionExpr; use super::WindowExpr; -use crate::window::window_expr::reverse_order_bys; +use crate::window::window_expr::{ + reverse_order_bys, BuiltinWindowState, WindowFn, WindowFunctionState, +}; +use crate::window::{ + PartitionBatches, PartitionWindowAggStates, WindowAggState, WindowState, +}; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; -use arrow::compute::SortOptions; +use arrow::compute::{concat, SortOptions}; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::Result; @@ -91,7 +96,7 @@ impl WindowExpr for BuiltInWindowExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let evaluator = self.expr.create_evaluator()?; let 
num_rows = batch.num_rows(); - if evaluator.uses_window_frame() { + if self.expr.uses_window_frame() { let sort_options: Vec = self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results = vec![]; @@ -122,6 +127,101 @@ impl WindowExpr for BuiltInWindowExpr { } } + /// evaluate the window function values against the batch + fn evaluate_bounded( + &self, + partition_batches: &PartitionBatches, + window_agg_state: &mut PartitionWindowAggStates, + ) -> Result<()> { + let sort_options: Vec = + self.order_by.iter().map(|o| o.options).collect(); + for (partition_row, partition_batch_state) in partition_batches.iter() { + if !window_agg_state.contains_key(partition_row) { + let evaluator = self.expr.create_evaluator()?; + let field = self.expr.field()?; + let out_type = field.data_type(); + window_agg_state.insert( + partition_row.clone(), + WindowState { + state: WindowAggState::new( + out_type, + WindowFunctionState::BuiltinWindowState( + BuiltinWindowState::Default, + ), + )?, + window_fn: WindowFn::Builtin(evaluator), + }, + ); + }; + let window_state = window_agg_state.get_mut(partition_row).unwrap(); + let evaluator = match &mut window_state.window_fn { + WindowFn::Builtin(evaluator) => evaluator, + _ => unreachable!(), + }; + let mut state = &mut window_state.state; + state.is_end = partition_batch_state.is_end; + let num_rows = partition_batch_state.record_batch.num_rows(); + + let columns = self.sort_columns(&partition_batch_state.record_batch)?; + let sort_partition_points = + self.evaluate_partition_points(num_rows, &columns)?; + let (values, order_bys) = + self.get_values_orderbys(&partition_batch_state.record_batch)?; + + // We iterate on each row to perform a running calculation. + // First, current_range_of_sliding_window is calculated, then it is compared with last_range. 
+ let mut row_wise_results: Vec = vec![]; + let mut last_range = state.current_range_of_sliding_window.clone(); + let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + for idx in state.last_calculated_index..num_rows { + state.current_range_of_sliding_window = if !self.expr.uses_window_frame() + { + evaluator.get_range(state, num_rows)? + } else { + window_frame_ctx.calculate_range( + &order_bys, + &sort_options, + num_rows, + idx, + )? + }; + evaluator.update_state(state, &order_bys, &sort_partition_points)?; + // exit if range end index is length, need kind of flag to stop + if state.current_range_of_sliding_window.end == num_rows + && !partition_batch_state.is_end + { + state.current_range_of_sliding_window = last_range.clone(); + break; + } + if state.current_range_of_sliding_window.start + == state.current_range_of_sliding_window.end + { + // We produce None if the window is empty. + row_wise_results + .push(ScalarValue::try_from(self.expr.field()?.data_type())?) + } else { + let res = evaluator.evaluate_bounded(&values)?; + row_wise_results.push(res); + } + last_range = state.current_range_of_sliding_window.clone(); + state.last_calculated_index = idx + 1; + } + state.current_range_of_sliding_window = last_range; + let out_col = if !row_wise_results.is_empty() { + ScalarValue::iter_to_array(row_wise_results.into_iter())? 
+ } else { + let a = ScalarValue::try_from(self.expr.field()?.data_type())?; + a.to_array_of_size(0) + }; + + state.out_col = concat(&[&state.out_col, &out_col])?; + state.n_row_result_missing = num_rows - state.last_calculated_index; + state.window_function_state = + WindowFunctionState::BuiltinWindowState(evaluator.state()?); + } + Ok(()) + } + fn get_window_frame(&self) -> &Arc { &self.window_frame } @@ -138,4 +238,14 @@ impl WindowExpr for BuiltInWindowExpr { Arc::new(self.window_frame.reverse()), ))) } + + fn can_run_bounded(&self) -> bool { + if self.expr.uses_window_frame() { + self.expr.bounded_exec_supported() + && !self.window_frame.start_bound.is_unbounded() + && !self.window_frame.end_bound.is_unbounded() + } else { + self.expr.bounded_exec_supported() + } + } } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 71b100b54e5e4..e8aacd5de2c45 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -72,4 +72,10 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { self ))) } + + fn bounded_exec_supported(&self) -> bool; + + fn uses_window_frame(&self) -> bool { + false + } } diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs index 45fe51178afcd..be48f0e9ca088 100644 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ b/datafusion/physical-expr/src/window/cume_dist.rs @@ -64,8 +64,13 @@ impl BuiltInWindowFunctionExpr for CumeDist { fn create_evaluator(&self) -> Result> { Ok(Box::new(CumeDistEvaluator {})) } + + fn bounded_exec_supported(&self) -> bool { + false + } } +#[derive(Debug)] pub(crate) struct CumeDistEvaluator; impl PartitionEvaluator for CumeDistEvaluator { diff --git a/datafusion/physical-expr/src/window/lead_lag.rs 
b/datafusion/physical-expr/src/window/lead_lag.rs index f4c176262ae46..7c9b3b810de7e 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -19,7 +19,8 @@ //! at runtime during query execution use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::BuiltInWindowFunctionExpr; +use crate::window::window_expr::{BuiltinWindowState, LeadLagState}; +use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::compute::cast; @@ -27,7 +28,8 @@ use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use std::any::Any; -use std::ops::Neg; +use std::cmp::min; +use std::ops::{Neg, Range}; use std::sync::Arc; /// window shift expression @@ -102,6 +104,7 @@ impl BuiltInWindowFunctionExpr for WindowShift { fn create_evaluator(&self) -> Result> { Ok(Box::new(WindowShiftEvaluator { + state: LeadLagState { idx: 0 }, shift_offset: self.shift_offset, default_value: self.default_value.clone(), })) @@ -111,6 +114,10 @@ impl BuiltInWindowFunctionExpr for WindowShift { true } + fn bounded_exec_supported(&self) -> bool { + true + } + fn reverse_expr(&self) -> Result> { Ok(Arc::new(Self { name: self.name.clone(), @@ -122,7 +129,9 @@ impl BuiltInWindowFunctionExpr for WindowShift { } } +#[derive(Debug)] pub(crate) struct WindowShiftEvaluator { + state: LeadLagState, shift_offset: i64, default_value: Option, } @@ -177,6 +186,63 @@ fn shift_with_default_value( } impl PartitionEvaluator for WindowShiftEvaluator { + fn state(&self) -> Result { + // If we do not use state we just return Default + Ok(BuiltinWindowState::LeadLag(self.state.clone())) + } + + fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> { + match &state { + BuiltinWindowState::LeadLag(lead_lag_state) => { + self.state = lead_lag_state.clone() + } + _ => self.state = LeadLagState::default(), + 
} + Ok(()) + } + + fn update_state( + &mut self, + state: &WindowAggState, + _range_columns: &[ArrayRef], + _sort_partition_points: &[Range], + ) -> Result<()> { + self.state.idx = state.last_calculated_index; + Ok(()) + } + + fn get_range(&self, state: &WindowAggState, n_rows: usize) -> Result> { + if self.shift_offset > 0 { + let start = if state.last_calculated_index > self.shift_offset as usize { + state.last_calculated_index - self.shift_offset as usize + } else { + 0 + }; + Ok(Range { + start, + end: state.last_calculated_index + 1, + }) + } else { + let end = state.last_calculated_index + (-self.shift_offset) as usize; + // let n_rows = self.values[0].len(); + let end = min(end, n_rows); + Ok(Range { + start: state.last_calculated_index, + end, + }) + } + } + + fn evaluate_bounded(&mut self, values: &[ArrayRef]) -> Result { + let dtype = values[0].data_type(); + let idx = self.state.idx as i64 - self.shift_offset; + if idx < 0 || idx as usize >= values[0].len() { + get_default_value(&self.default_value, dtype) + } else { + ScalarValue::try_from_array(&values[0], idx as usize) + } + } + fn evaluate(&self, values: &[ArrayRef], _num_rows: usize) -> Result { // LEAD, LAG window functions take single column, values will have size 1 let value = &values[0]; @@ -184,6 +250,24 @@ impl PartitionEvaluator for WindowShiftEvaluator { } } +fn get_default_value( + default_value: &Option, + dtype: &DataType, +) -> Result { + if let Some(val) = default_value { + match val { + ScalarValue::Int64(Some(val)) => { + ScalarValue::try_from_string(val.to_string(), dtype) + } + _ => Err(DataFusionError::Internal( + "Expects default value to have Int64 type".to_string(), + )), + } + } else { + Ok(ScalarValue::try_from(dtype)?) 
+ } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 40ed658ee38a2..722c65508c2cb 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -30,4 +30,10 @@ mod window_frame_state; pub use aggregate::AggregateWindowExpr; pub use built_in::BuiltInWindowExpr; pub use built_in_window_function_expr::BuiltInWindowFunctionExpr; +pub use window_expr::PartitionBatchState; +pub use window_expr::PartitionBatches; +pub use window_expr::PartitionKey; +pub use window_expr::PartitionWindowAggStates; +pub use window_expr::WindowAggState; pub use window_expr::WindowExpr; +pub use window_expr::WindowState; diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 2942a797b8526..8d9dfbbb483cd 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -19,7 +19,8 @@ //! 
that can evaluated at runtime during query execution use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::BuiltInWindowFunctionExpr; +use crate::window::window_expr::{BuiltinWindowState, NthValueState}; +use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; use crate::PhysicalExpr; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; @@ -121,7 +122,10 @@ impl BuiltInWindowFunctionExpr for NthValue { } fn create_evaluator(&self) -> Result> { - Ok(Box::new(NthValueEvaluator { kind: self.kind })) + Ok(Box::new(NthValueEvaluator { + state: NthValueState::default(), + kind: self.kind, + })) } fn is_window_fn_reversible(&self) -> bool { @@ -131,6 +135,14 @@ impl BuiltInWindowFunctionExpr for NthValue { } } + fn bounded_exec_supported(&self) -> bool { + true + } + + fn uses_window_frame(&self) -> bool { + true + } + fn reverse_expr(&self) -> Result> { let reversed_kind = match self.kind { NthValueKind::First => NthValueKind::Last, @@ -151,13 +163,45 @@ impl BuiltInWindowFunctionExpr for NthValue { } /// Value evaluator for nth_value functions +#[derive(Debug)] pub(crate) struct NthValueEvaluator { + state: NthValueState, kind: NthValueKind, } impl PartitionEvaluator for NthValueEvaluator { - fn uses_window_frame(&self) -> bool { - true + fn state(&self) -> Result { + // If we do not use state we just return Default + Ok(BuiltinWindowState::NthValue(self.state.clone())) + } + + fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> { + match &state { + BuiltinWindowState::NthValue(nth_value_state) => { + self.state = nth_value_state.clone(); + } + _ => self.state = NthValueState::default(), + } + Ok(()) + } + + fn update_state( + &mut self, + state: &WindowAggState, + _range_columns: &[ArrayRef], + _sort_partition_points: &[Range], + ) -> Result<()> { + // If we do not use state, update_state does nothing + self.state.range = state.current_range_of_sliding_window.clone(); + Ok(()) + } + + // fn 
uses_window_frame(&self) -> bool { + // true + // } + + fn evaluate_bounded(&mut self, values: &[ArrayRef]) -> Result { + self.evaluate_inside_range(values, self.state.range.clone()) } fn evaluate_inside_range( diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index 86500441df5bc..c7a261761d638 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -17,20 +17,51 @@ //! partition evaluation module +use crate::window::window_expr::BuiltinWindowState; +use crate::window::WindowAggState; use arrow::array::ArrayRef; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; +use std::fmt::Debug; use std::ops::Range; /// Partition evaluator -pub trait PartitionEvaluator { +pub trait PartitionEvaluator: Debug + Send + Sync { /// Whether the evaluator should be evaluated with rank fn include_rank(&self) -> bool { false } - fn uses_window_frame(&self) -> bool { - false + // fn uses_window_frame(&self) -> bool { + // false + // } + + /// Returns state of the Built-in Window Function + fn state(&self) -> Result { + // If we do not use state we just return Default + Ok(BuiltinWindowState::Default) + } + + /// Initializes state of the Built-in Window Function (useful for bounded memory implementation) + fn set_state(&mut self, _state: &BuiltinWindowState) -> Result<()> { + // If we do not use state, set_state does nothing + Ok(()) + } + + fn update_state( + &mut self, + _state: &WindowAggState, + _range_columns: &[ArrayRef], + _sort_partition_points: &[Range], + ) -> Result<()> { + // If we do not use state, update_state does nothing + Ok(()) + } + + fn get_range(&self, _state: &WindowAggState, _n_rows: usize) -> Result> { + Err(DataFusionError::NotImplemented( + "get_range is not implemented for this window function".to_string(), + )) } /// evaluate the partition evaluator against the 
partition @@ -40,6 +71,13 @@ pub trait PartitionEvaluator { )) } + /// evaluate window function result inside given range + fn evaluate_bounded(&mut self, _values: &[ArrayRef]) -> Result { + Err(DataFusionError::NotImplemented( + "evaluate_bounded is not implemented by default".into(), + )) + } + /// evaluate the partition evaluator against the partition but with rank fn evaluate_with_rank( &self, @@ -61,4 +99,10 @@ pub trait PartitionEvaluator { "evaluate_inside_range is not implemented by default".into(), )) } + + fn clone_dyn(&self) -> Result> { + Err(DataFusionError::NotImplemented( + "clone_dyn is not implemented by default for this evaluator, to use it in for cloning implement this method".into(), + )) + } } diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 87e01528de5a8..27b4495940d41 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -19,12 +19,13 @@ //! at runtime during query execution use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::BuiltInWindowFunctionExpr; +use crate::window::window_expr::{BuiltinWindowState, RankState}; +use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::array::{Float64Array, UInt64Array}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::Result; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use std::any::Any; use std::iter; use std::ops::Range; @@ -98,18 +99,90 @@ impl BuiltInWindowFunctionExpr for Rank { &self.name } + fn bounded_exec_supported(&self) -> bool { + matches!(self.rank_type, RankType::Basic | RankType::Dense) + } + fn create_evaluator(&self) -> Result> { Ok(Box::new(RankEvaluator { + state: RankState::default(), rank_type: self.rank_type, })) } } +#[derive(Debug)] pub(crate) struct RankEvaluator { + state: RankState, rank_type: RankType, } impl PartitionEvaluator 
for RankEvaluator { + fn get_range(&self, state: &WindowAggState, _n_rows: usize) -> Result> { + Ok(Range { + start: state.last_calculated_index, + end: state.last_calculated_index + 1, + }) + } + + fn state(&self) -> Result { + Ok(BuiltinWindowState::Rank(self.state.clone())) + } + + fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> { + match &state { + BuiltinWindowState::Rank(rank_state) => { + self.state = rank_state.clone(); + } + _ => self.state = RankState::default(), + } + Ok(()) + } + + fn update_state( + &mut self, + state: &WindowAggState, + range_columns: &[ArrayRef], + sort_partition_points: &[Range], + ) -> Result<()> { + // find range inside `sort_partition_points` containing `state.last_calculated_index` + let chunk_idx = sort_partition_points + .iter() + .position(|elem| { + elem.start <= state.last_calculated_index + && state.last_calculated_index < elem.end + }) + .ok_or_else(|| DataFusionError::Execution("Expects sort_partition_points to contain state.last_calculated_index".to_string()))?; + let cur_chunk = sort_partition_points[chunk_idx].clone(); + let mut last_rank_data = vec![]; + for column in range_columns { + last_rank_data.push(ScalarValue::try_from_array(column, cur_chunk.end - 1)?) 
+ } + if self.state.last_rank_data.is_empty() { + self.state.last_rank_data = last_rank_data; + self.state.last_rank_boundary = state.offset_pruned_rows + cur_chunk.start; + self.state.n_rank = sort_partition_points.len(); + } else if self.state.last_rank_data != last_rank_data { + self.state.last_rank_data = last_rank_data; + self.state.last_rank_boundary = state.offset_pruned_rows + cur_chunk.start; + self.state.n_rank += 1 + } + Ok(()) + } + + /// evaluate window function result inside given range + fn evaluate_bounded(&mut self, _values: &[ArrayRef]) -> Result { + match self.rank_type { + RankType::Basic => Ok(ScalarValue::UInt64(Some( + self.state.last_rank_boundary as u64 + 1, + ))), + RankType::Dense => Ok(ScalarValue::UInt64(Some(self.state.n_rank as u64))), + RankType::Percent => Err(DataFusionError::Execution( + "Cannot Run Percent_RANK in streaming case".to_string(), + )), + } + } + fn include_rank(&self) -> bool { true } diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index b27ac29d27640..efbcb688cabf9 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -18,12 +18,14 @@ //! 
Defines physical expression for `row_number` that can evaluated at runtime during query execution use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::BuiltInWindowFunctionExpr; +use crate::window::window_expr::{BuiltinWindowState, NumRowsState}; +use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; use crate::PhysicalExpr; use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::Result; +use datafusion_common::{Result, ScalarValue}; use std::any::Any; +use std::ops::Range; use std::sync::Arc; /// row_number expression @@ -62,12 +64,47 @@ impl BuiltInWindowFunctionExpr for RowNumber { fn create_evaluator(&self) -> Result> { Ok(Box::::default()) } + + fn bounded_exec_supported(&self) -> bool { + true + } } -#[derive(Default)] -pub(crate) struct NumRowsEvaluator {} +#[derive(Default, Debug)] +pub(crate) struct NumRowsEvaluator { + state: NumRowsState, +} impl PartitionEvaluator for NumRowsEvaluator { + fn state(&self) -> Result { + // If we do not use state we just return Default + Ok(BuiltinWindowState::NumRows(self.state.clone())) + } + + fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> { + match &state { + BuiltinWindowState::NumRows(num_rows_state) => { + self.state = num_rows_state.clone(); + } + _ => self.state = NumRowsState::default(), + } + Ok(()) + } + + fn get_range(&self, state: &WindowAggState, _n_rows: usize) -> Result> { + Ok(Range { + start: state.last_calculated_index, + end: state.last_calculated_index + 1, + }) + } + + /// evaluate window function result inside given range + fn evaluate_bounded(&mut self, _values: &[ArrayRef]) -> Result { + let n_row = self.state.n_rows as u64 + 1; + self.state.n_rows += 1; + Ok(ScalarValue::UInt64(Some(n_row))) + } + fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result { Ok(Arc::new(UInt64Array::from_iter_values( 1..(num_rows as u64) + 1, diff --git 
a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index bc35dd49b50d4..7700ed23fe5ec 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -15,14 +15,18 @@ // specific language governing permissions and limitations // under the License. +use crate::window::partition_evaluator::PartitionEvaluator; use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::compute::kernels::partition::lexicographical_partition_ranges; use arrow::compute::kernels::sort::SortColumn; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::{reverse_sort_options, DataFusionError, Result}; -use datafusion_expr::WindowFrame; +use arrow_schema::DataType; +use datafusion_common::{reverse_sort_options, DataFusionError, Result, ScalarValue}; +use datafusion_expr::{Accumulator, AggregateState, WindowFrame}; +use indexmap::IndexMap; use std::any::Any; +use std::fmt; use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; @@ -61,6 +65,17 @@ pub trait WindowExpr: Send + Sync + Debug { /// evaluate the window function values against the batch fn evaluate(&self, batch: &RecordBatch) -> Result; + /// evaluate the window function values against the batch + fn evaluate_bounded( + &self, + _partition_batches: &PartitionBatches, + _window_agg_state: &mut PartitionWindowAggStates, + ) -> Result<()> { + Err(DataFusionError::Internal( + "evaluate_bounded is not implemented".to_string(), + )) + } + /// evaluate the partition points given the sort columns; if the sort columns are /// empty then the result will be a single element vec of the whole column rows. 
fn evaluate_partition_points( @@ -121,6 +136,9 @@ pub trait WindowExpr: Send + Sync + Debug { /// get reversed expression fn get_reversed_expr(&self) -> Result>; + + /// get whether can run with bounded executor + fn can_run_bounded(&self) -> bool; } /// Reverses the ORDER BY expression, which is useful during equivalent window @@ -135,3 +153,138 @@ pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec), + Aggregate(Box), +} + +impl fmt::Debug for WindowFn { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFn::Builtin(builtin, ..) => { + write!(f, "partition evaluator: {:?}", builtin) + } + WindowFn::Aggregate(aggregate, ..) => { + write!(f, "accumulator: {:?}", aggregate) + } + } + } +} + +impl Clone for WindowFn { + fn clone(&self) -> Self { + match self { + WindowFn::Builtin(builtin) => WindowFn::Builtin(builtin.clone_dyn().unwrap()), + WindowFn::Aggregate(aggregate) => { + WindowFn::Aggregate(aggregate.clone_dyn().unwrap()) + } + } + } +} + +/// State for RANK(percent_rank, rank, dense_rank) +/// builtin window function +#[derive(Debug, Clone, Default)] +pub struct RankState { + /// The last values for rank as these values change, we increase n_rank + pub last_rank_data: Vec, + /// The index where last_rank_boundary is started + pub last_rank_boundary: usize, + /// Rank number kept from the start + pub n_rank: usize, +} + +/// State for 'ROW_NUMBER' builtin window function +#[derive(Debug, Clone, Default)] +pub struct NumRowsState { + pub n_rows: usize, +} + +#[derive(Debug, Clone, Default)] +pub struct NthValueState { + pub range: Range, +} + +#[derive(Debug, Clone, Default)] +pub struct LeadLagState { + pub idx: usize, +} + +#[derive(Debug, Clone, Default)] +pub enum BuiltinWindowState { + Rank(RankState), + NumRows(NumRowsState), + NthValue(NthValueState), + LeadLag(LeadLagState), + #[default] + Default, +} +#[derive(Debug, Clone)] +pub enum WindowFunctionState { + /// Different Aggregate functions may have different 
state definitions + /// In [Accumulator] trait, [fn state(&self) -> Result>] implementation + /// dictates that. + AggregateState(Vec), + /// BuiltinWindowState + BuiltinWindowState(BuiltinWindowState), +} + +#[derive(Debug, Clone)] +pub struct WindowAggState { + /// The range that we calculate the window function + pub current_range_of_sliding_window: Range, + /// The index of the last row that its result is calculated inside the partition record batch buffer. + pub last_calculated_index: usize, + /// The offset of the deleted row number + pub offset_pruned_rows: usize, + /// + pub window_function_state: WindowFunctionState, + // Keeps the results + pub out_col: ArrayRef, + pub n_row_result_missing: usize, + /// flag indicating whether we have received all data for this partition + pub is_end: bool, +} + +/// State for each unique partition determined according to PARTITION BY column(s) +#[derive(Debug, Clone)] +pub struct PartitionBatchState { + /// The record_batch belonging to current partition + pub record_batch: RecordBatch, + /// flag indicating whether we have received all data for this partition + pub is_end: bool, +} + +/// key for IndexMap for each unique partition +/// For instance, if window frame is OVER(PARTITION BY a,b) +/// PartitionKey would consist of unique [a,b] pairs +pub type PartitionKey = Vec; + +#[derive(Debug, Clone)] +pub struct WindowState { + pub state: WindowAggState, + pub window_fn: WindowFn, +} +pub type PartitionWindowAggStates = IndexMap; + +/// The IndexMap(Ordered HashMap) where record_batch is seperated for each partition +pub type PartitionBatches = IndexMap; + +impl WindowAggState { + pub fn new( + out_type: &DataType, + window_function_state: WindowFunctionState, + ) -> Result { + let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); + Ok(Self { + current_range_of_sliding_window: Range { start: 0, end: 0 }, + last_calculated_index: 0, + offset_pruned_rows: 0, + window_function_state, + out_col: 
empty_out_col, + n_row_result_missing: 0, + is_end: false, + }) + } +} diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 4002a49cf585b..dfd878275181c 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -50,7 +50,10 @@ pub fn partitions_to_sorted_vec(partitions: &[Vec]) -> Vec, rng: &mut StdRng) -> Vec { +pub fn add_empty_batches( + batches: Vec, + rng: &mut StdRng, +) -> Vec { let schema = batches[0].schema(); batches From 0068566b3965c2c182c1c1a0e23e92fd01a005e4 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 20 Dec 2022 16:25:46 +0300 Subject: [PATCH 13/50] solve merge problems --- datafusion/core/benches/data_utils/mod.rs | 2 -- .../replace_window_with_bounded_impl.rs | 22 ++++-------- .../windows/bounded_window_agg_exec.rs | 7 ---- datafusion/core/tests/sql/window.rs | 16 ++++----- datafusion/core/tests/window_fuzz.rs | 2 +- datafusion/expr/src/accumulator.rs | 35 ------------------- .../physical-expr/src/window/lead_lag.rs | 2 +- .../physical-expr/src/window/nth_value.rs | 9 +++-- .../src/window/partition_evaluator.rs | 2 +- .../physical-expr/src/window/window_expr.rs | 6 ++-- 10 files changed, 25 insertions(+), 78 deletions(-) diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index 082d0a258e60d..575e1831c8380 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -17,7 +17,6 @@ //! This module provides the in-memory table for more realistic benchmarking. 
-use arrow::array::Int32Array; use arrow::{ array::Float32Array, array::Float64Array, @@ -32,7 +31,6 @@ use datafusion::from_slice::FromSlice; use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; -use std::collections::BTreeMap; use std::sync::Arc; /// create an in-memory table given the partition len, array len, and batch size, diff --git a/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs b/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs index 886455904691b..0ff03130fddd1 100644 --- a/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs +++ b/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs @@ -21,12 +21,8 @@ use crate::physical_plan::windows::BoundedWindowAggExec; use crate::physical_plan::windows::WindowAggExec; use crate::{ - error::Result, - physical_optimizer::PhysicalOptimizerRule, - physical_plan::{ - coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec, - repartition::RepartitionExec, rewrite::TreeNodeRewritable, - }, + error::Result, physical_optimizer::PhysicalOptimizerRule, + physical_plan::rewrite::TreeNodeRewritable, }; use datafusion_expr::WindowFrameUnits; use datafusion_physical_expr::window::WindowExpr; @@ -51,19 +47,15 @@ impl PhysicalOptimizerRule for ReplaceWindowWithBoundedImpl { ) -> Result> { plan.transform_up(&|plan| { if let Some(window_agg_exec) = plan.as_any().downcast_ref::() { - // println!("do analysis for bounded impl"); let is_contains_groups = window_agg_exec .window_expr() .iter() - .any(|elem| is_window_frame_groups(elem)); + .any(is_window_frame_groups); let can_run_bounded = window_agg_exec .window_expr() .iter() .all(|elem| elem.can_run_bounded()); - // println!("is_contains_groups: {:?}", is_contains_groups); - // println!("can_run_bounded: {:?}", can_run_bounded); if !is_contains_groups && can_run_bounded { - // println!("changing with bounded"); return 
Ok(Some(Arc::new(BoundedWindowAggExec::try_new( window_agg_exec.window_expr().to_vec(), window_agg_exec.input().clone(), @@ -88,8 +80,8 @@ impl PhysicalOptimizerRule for ReplaceWindowWithBoundedImpl { /// Checks window expression whether it is GROUPS mode fn is_window_frame_groups(window_expr: &Arc) -> bool { - match window_expr.get_window_frame().units { - WindowFrameUnits::Groups => true, - _ => false, - } + matches!( + window_expr.get_window_frame().units, + WindowFrameUnits::Groups + ) } diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index ac4b6b42eb604..177071a2922fc 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -353,13 +353,6 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { } fn update_partition_batch(&mut self, record_batch: RecordBatch) -> Result<()> { - // all window expressions have same other than window frame boundaries hence we can use any one of the window expressions - let window_expr = self.window_expr.first().ok_or_else(|| { - DataFusionError::Execution( - "window expr cannot be empty to support streaming".to_string(), - ) - })?; - let partition_columns = self.partition_columns(&record_batch)?; let num_rows = record_batch.num_rows(); if num_rows > 0 { diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 93739291ab207..bb325a5533bca 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1740,11 +1740,11 @@ async fn test_window_partition_by_order_by() -> Result<()> { let expected = { vec![ "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) 
PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as COUNT(UInt8(1))]", - " WindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", " SortExec: [c1@1 ASC NULLS LAST,c2@2 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 1 }], 2)", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }, Column { name: \"c2\", index: 1 }], 2)", @@ -1953,11 +1953,11 @@ async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { vec![ "ProjectionExec: expr=[c9@5 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS 
BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2]", " GlobalLimitExec: skip=0, fetch=5", - " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c9@4 DESC]", - " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@4 ASC NULLS LAST,c1@2 ASC NULLS LAST,c2@3 ASC NULLS LAST]", + " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c9@2 ASC NULLS LAST,c1@0 ASC NULLS LAST,c2@1 ASC NULLS 
LAST]", + " SortExec: [c9@2 DESC,c1@0 DESC]", ] }; @@ -2094,7 +2094,7 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> " GlobalLimitExec: skip=0, fetch=5", " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c1@0 ASC,c9@1 DESC]", + " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", ] }; @@ -2149,7 +2149,7 @@ async fn test_window_agg_sort_partitionby_reversed_plan() -> Result<()> { " GlobalLimitExec: skip=0, fetch=5", " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c1@0 ASC,c9@1 DESC]", + " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", ] }; diff --git a/datafusion/core/tests/window_fuzz.rs b/datafusion/core/tests/window_fuzz.rs index 10e914ed9900b..106f95b62b8b1 100644 --- a/datafusion/core/tests/window_fuzz.rs +++ b/datafusion/core/tests/window_fuzz.rs @@ -25,11 +25,11 @@ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use 
tokio::runtime::Builder; +use datafusion::physical_plan::collect; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::windows::{ create_window_expr, BoundedWindowAggExec, WindowAggExec, }; -use datafusion::physical_plan::{collect, common}; use datafusion_expr::{ AggregateFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, }; diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs index bf4f1e1822430..131bd64c0343b 100644 --- a/datafusion/expr/src/accumulator.rs +++ b/datafusion/expr/src/accumulator.rs @@ -92,38 +92,3 @@ pub trait Accumulator: Send + Sync + Debug { )) } } - -/// Representation of internal accumulator state. Accumulators can potentially have a mix of -/// scalar and array values. It may be desirable to add custom aggregator states here as well -/// in the future (perhaps `Custom(Box)`?). -#[derive(Debug, Clone)] -pub enum AggregateState { - /// Simple scalar value. Note that `ScalarValue::List` can be used to pass multiple - /// values around - Scalar(ScalarValue), - /// Arrays can be used instead of `ScalarValue::List` and could potentially have better - /// performance with large data sets, although this has not been verified. It also allows - /// for use of arrow kernels with less overhead. - Array(ArrayRef), -} - -impl AggregateState { - /// Access the aggregate state as a scalar value. An error will occur if the - /// state is not a scalar value. - pub fn as_scalar(&self) -> Result<&ScalarValue> { - match &self { - Self::Scalar(v) => Ok(v), - _ => Err(DataFusionError::Internal( - "AggregateState is not a scalar aggregate".to_string(), - )), - } - } - - /// Access the aggregate state as an array value. 
- pub fn to_array(&self) -> ArrayRef { - match &self { - Self::Scalar(v) => v.to_array(), - Self::Array(array) => array.clone(), - } - } -} diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 9f6bf38d9b0c5..21b2136d287f8 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -113,7 +113,7 @@ impl BuiltInWindowFunctionExpr for WindowShift { fn bounded_exec_supported(&self) -> bool { true } - + fn reverse_expr(&self) -> Option> { Some(Arc::new(Self { name: self.name.clone(), diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 983080465d367..8fad8c6f383bc 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -128,7 +128,6 @@ impl BuiltInWindowFunctionExpr for NthValue { })) } - fn bounded_exec_supported(&self) -> bool { true } @@ -146,6 +145,10 @@ impl BuiltInWindowFunctionExpr for NthValue { kind: reversed_kind, })) } + + fn uses_window_frame(&self) -> bool { + true + } } /// Value evaluator for nth_value functions @@ -182,10 +185,6 @@ impl PartitionEvaluator for NthValueEvaluator { Ok(()) } - // fn uses_window_frame(&self) -> bool { - // true - // } - fn evaluate_bounded(&mut self, values: &[ArrayRef]) -> Result { self.evaluate_inside_range(values, self.state.range.clone()) } diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index c7a261761d638..371243a3c66fd 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -67,7 +67,7 @@ pub trait PartitionEvaluator: Debug + Send + Sync { /// evaluate the partition evaluator against the partition fn evaluate(&self, _values: &[ArrayRef], _num_rows: usize) -> Result { Err(DataFusionError::NotImplemented( - 
"evaluate_partition is not implemented by default".into(), + "evaluate is not implemented by default".into(), )) } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 45df4d1a6f4a5..fecc6c50a9413 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -23,7 +23,7 @@ use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use arrow_schema::DataType; use datafusion_common::{reverse_sort_options, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{Accumulator, AggregateState, WindowFrame}; +use datafusion_expr::{Accumulator, WindowFrame}; use indexmap::IndexMap; use std::any::Any; use std::fmt; @@ -220,9 +220,9 @@ pub enum BuiltinWindowState { #[derive(Debug, Clone)] pub enum WindowFunctionState { /// Different Aggregate functions may have different state definitions - /// In [Accumulator] trait, [fn state(&self) -> Result>] implementation + /// In [Accumulator] trait, [fn state(&self) -> Result>] implementation /// dictates that. - AggregateState(Vec), + AggregateState(Vec), /// BuiltinWindowState BuiltinWindowState(BuiltinWindowState), } From 0e73945ae601e33849aaddb9defdce20a4a75b9d Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 21 Dec 2022 14:14:42 +0300 Subject: [PATCH 14/50] Refactor to simplify code --- .../remove_unnecessary_sorts.rs | 309 +++++++++--------- 1 file changed, 153 insertions(+), 156 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 461873b9da70b..6a41ea4f285c9 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -15,9 +15,16 @@ // specific language governing permissions and limitations // under the License. -//! 
Remove Unnecessary Sorts optimizer rule is used to for removing unnecessary SortExec's inserted to -//! physical plan. Produces a valid physical plan (in terms of Sorting requirement). Its input can be either -//! valid, or invalid physical plans (in terms of Sorting requirement) +//! RemoveUnnecessarySorts optimizer rule inspects SortExec's in the given +//! physical plan and removes the ones it can prove unnecessary. The rule can +//! work on valid *and* invalid physical plans with respect to sorting +//! requirements, but always produces a valid physical plan in this sense. +//! +//! A non-realistic but easy to follow example: Assume that we somehow get the fragment +//! "SortExec: [nullable_col@0 ASC]", +//! " SortExec: [non_nullable_col@1 ASC]", +//! in the physical plan. The first sort is unnecessary since its result is overwritten +//! by another SortExec. Therefore, this rule removes it from the physical plan. use crate::error::Result; use crate::physical_optimizer::utils::{ add_sort_above_child, ordering_satisfy, ordering_satisfy_concrete, @@ -31,14 +38,12 @@ use crate::prelude::SessionConfig; use arrow::datatypes::SchemaRef; use datafusion_common::{reverse_sort_options, DataFusionError}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use itertools::izip; use std::iter::zip; use std::sync::Arc; -/// As an example Assume we get -/// "SortExec: [nullable_col@0 ASC]", -/// " SortExec: [non_nullable_col@1 ASC]", somehow in the physical plan -/// The first Sort is unnecessary since, its result would be overwritten by another SortExec. We -/// remove first Sort from the physical plan +/// This rule inspects SortExec's in the given physical plan and removes the +/// ones it can prove unnecessary. 
#[derive(Default)] pub struct RemoveUnnecessarySorts {} @@ -49,13 +54,77 @@ impl RemoveUnnecessarySorts { } } +/// This is a "data class" we use within the [RemoveUnnecessarySorts] rule +/// that tracks the closest `SortExec` descendant for every child of a plan. +#[derive(Debug, Clone)] +struct PlanWithCorrespondingSort { + plan: Arc, + // For every child, keep a vector of `ExecutionPlan`s starting from the + // closest `SortExec` till the current plan. The first index of the tuple is + // the child index of the plan -- we need this information as we make updates. + sort_onwards: Vec)>>, +} + +impl PlanWithCorrespondingSort { + pub fn new(plan: Arc) -> Self { + let length = plan.children().len(); + PlanWithCorrespondingSort { + plan, + sort_onwards: vec![vec![]; length], + } + } + + pub fn children(&self) -> Vec { + self.plan + .children() + .into_iter() + .map(|child| PlanWithCorrespondingSort::new(child)) + .collect() + } +} + +impl TreeNodeRewritable for PlanWithCorrespondingSort { + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if children.is_empty() { + Ok(self) + } else { + let children_requirements = children + .into_iter() + .map(transform) + .collect::>>()?; + let children_plans = children_requirements + .iter() + .map(|elem| elem.plan.clone()) + .collect::>(); + let sort_onwards = children_requirements + .iter() + .map(|item| { + if item.sort_onwards.is_empty() { + vec![] + } else { + // TODO: When `maintains_input_order` returns Vec, + // pass the order-enforcing sort upwards. 
+ item.sort_onwards[0].clone() + } + }) + .collect::>(); + let plan = with_new_children_if_necessary(self.plan, children_plans)?; + Ok(PlanWithCorrespondingSort { plan, sort_onwards }) + } + } +} + impl PhysicalOptimizerRule for RemoveUnnecessarySorts { fn optimize( &self, plan: Arc, _config: &SessionConfig, ) -> Result> { - // Run a bottom-up process to adjust input key ordering recursively + // Execute a post-order traversal to adjust input key ordering: let plan_requirements = PlanWithCorrespondingSort::new(plan); let adjusted = plan_requirements.transform_up(&remove_unnecessary_sorts)?; Ok(adjusted.plan) @@ -73,19 +142,20 @@ impl PhysicalOptimizerRule for RemoveUnnecessarySorts { fn remove_unnecessary_sorts( requirements: PlanWithCorrespondingSort, ) -> Result> { - // Do analysis of naive SortRemoval at the beginning - // Remove Sorts that are already satisfied - if let Some(res) = analyze_immediate_sort_removal(&requirements)? { - return Ok(Some(res)); + // Perform naive analysis at the beginning -- remove already-satisfied sorts: + if let Some(result) = analyze_immediate_sort_removal(&requirements)? 
{ + return Ok(Some(result)); } - let mut new_children = requirements.plan.children().clone(); - let mut new_sort_onwards = requirements.sort_onwards.clone(); - for (idx, (child, sort_onward)) in new_children - .iter_mut() - .zip(new_sort_onwards.iter_mut()) - .enumerate() + let plan = &requirements.plan; + let mut new_children = plan.children().clone(); + let mut new_onwards = requirements.sort_onwards.clone(); + for (idx, (child, sort_onwards, required_ordering)) in izip!( + new_children.iter_mut(), + new_onwards.iter_mut(), + plan.required_input_ordering() + ) + .enumerate() { - let required_ordering = requirements.plan.required_input_ordering()[idx]; let physical_ordering = child.output_ordering(); match (required_ordering, physical_ordering) { (Some(required_ordering), Some(physical_ordering)) => { @@ -95,202 +165,133 @@ fn remove_unnecessary_sorts( || child.equivalence_properties(), ); if !is_ordering_satisfied { - // During sort Removal we have invalidated ordering invariant fix it - update_child_to_remove_unnecessary_sort(child, sort_onward)?; + // Make sure we preserve the ordering requirements: + update_child_to_remove_unnecessary_sort(child, sort_onwards)?; let sort_expr = required_ordering.to_vec(); *child = add_sort_above_child(child, sort_expr)?; - // Since we have added Sort, we add it to the sort_onwards also. - sort_onward.push((idx, child.clone())) - } else if is_ordering_satisfied && !sort_onward.is_empty() { - // can do analysis for sort removal - let (_, sort_any) = sort_onward[0].clone(); + sort_onwards.push((idx, child.clone())) + } else if let [first, ..] 
= sort_onwards.as_slice() { + // The ordering requirement is met, we can analyze if there is an unnecessary sort: + let sort_any = first.1.clone(); let sort_exec = convert_to_sort_exec(&sort_any)?; let sort_output_ordering = sort_exec.output_ordering(); let sort_input_ordering = sort_exec.input().output_ordering(); - // TODO: Once we can ensure required ordering propagates to above without changes - // (or with changes trackable) compare `sort_input_ordering` and and `required_ordering` - // this changes will enable us to remove (a,b) -> Sort -> (a,b,c) -> Required(a,b) Sort - // from the plan. With current implementation we cannot remove Sort from above configuration. - // Do naive analysis, where a SortExec is already sorted according to desired Sorting + // Simple analysis: Does the input of the sort in question already satisfy the ordering requirements? if ordering_satisfy(sort_input_ordering, sort_output_ordering, || { sort_exec.input().equivalence_properties() }) { - update_child_to_remove_unnecessary_sort(child, sort_onward)?; + update_child_to_remove_unnecessary_sort(child, sort_onwards)?; } else if let Some(window_agg_exec) = requirements.plan.as_any().downcast_ref::() { - // For window expressions we can remove some Sorts when expression can be calculated in reverse order also. + // For window expressions, we can remove some sorts when we can + // calculate the result in reverse: if let Some(res) = analyze_window_sort_removal( window_agg_exec, sort_exec, - sort_onward, + sort_onwards, )? { return Ok(Some(res)); } } + // TODO: Once we can ensure that required ordering information propagates with + // necessary lineage information, compare `sort_input_ordering` and `required_ordering`. + // This will enable us to handle cases such as (a,b) -> Sort -> (a,b,c) -> Required(a,b). + // Currently, we can not remove such sorts. } } (Some(required), None) => { - // Requirement is not satisfied We should add Sort to the plan. 
+ // Ordering requirement is not met, we should add a SortExec to the plan. let sort_expr = required.to_vec(); *child = add_sort_above_child(child, sort_expr)?; - *sort_onward = vec![(idx, child.clone())]; + *sort_onwards = vec![(idx, child.clone())]; } (None, Some(_)) => { - // Sort doesn't propagate to the layers above in the physical plan + // We have a SortExec whose effect may be neutralized by a order-imposing + // operator. In this case, remove this sort: if !requirements.plan.maintains_input_order() { - // Unnecessary Sort is added to the plan, we can remove unnecessary sort - update_child_to_remove_unnecessary_sort(child, sort_onward)?; + update_child_to_remove_unnecessary_sort(child, sort_onwards)?; } } (None, None) => {} } } - if !requirements.plan.children().is_empty() { + if plan.children().is_empty() { + Ok(Some(requirements)) + } else { let new_plan = requirements.plan.with_new_children(new_children)?; - for (idx, new_sort_onward) in new_sort_onwards + for (idx, (trace, required_ordering)) in new_onwards .iter_mut() + .zip(new_plan.required_input_ordering()) .enumerate() .take(new_plan.children().len()) { - let requires_ordering = new_plan.required_input_ordering()[idx].is_some(); - //TODO: when `maintains_input_order` returns `Vec` use corresponding index + // TODO: When `maintains_input_order` returns a `Vec`, use corresponding index. 
if new_plan.maintains_input_order() - && !requires_ordering - && !new_sort_onward.is_empty() + && required_ordering.is_none() + && !trace.is_empty() { - new_sort_onward.push((idx, new_plan.clone())); - } else if new_plan.as_any().is::() { - new_sort_onward.clear(); - new_sort_onward.push((idx, new_plan.clone())); + trace.push((idx, new_plan.clone())); } else { - // These executors use SortExec, hence doesn't propagate - // sort above in the physical plan - new_sort_onward.clear(); + trace.clear(); + if new_plan.as_any().is::() { + trace.push((idx, new_plan.clone())); + } } } Ok(Some(PlanWithCorrespondingSort { plan: new_plan, - sort_onwards: new_sort_onwards, + sort_onwards: new_onwards, })) - } else { - Ok(Some(requirements)) - } -} - -#[derive(Debug, Clone)] -struct PlanWithCorrespondingSort { - plan: Arc, - // For each child keeps a vector of `ExecutionPlan`s starting from SortExec till current plan - // first index of tuple(usize) is child index of plan (we need during updating plan above) - sort_onwards: Vec)>>, -} - -impl PlanWithCorrespondingSort { - pub fn new(plan: Arc) -> Self { - let children_len = plan.children().len(); - PlanWithCorrespondingSort { - plan, - sort_onwards: vec![vec![]; children_len], - } - } - - pub fn children(&self) -> Vec { - let plan_children = self.plan.children(); - plan_children - .into_iter() - .map(|child| { - let length = child.children().len(); - PlanWithCorrespondingSort { - plan: child, - sort_onwards: vec![vec![]; length], - } - }) - .collect() - } -} - -impl TreeNodeRewritable for PlanWithCorrespondingSort { - fn map_children(self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { - let children = self.children(); - if !children.is_empty() { - let new_children: Result> = - children.into_iter().map(transform).collect(); - let children_requirements = new_children?; - let children_plans = children_requirements - .iter() - .map(|elem| elem.plan.clone()) - .collect::>(); - let sort_onwards = children_requirements 
- .iter() - .map(|elem| { - if !elem.sort_onwards.is_empty() { - // TODO: redirect the true sort onwards to above (the one we keep ordering) - // this is possible when maintains_input_order returns vec - elem.sort_onwards[0].clone() - } else { - vec![] - } - }) - .collect::>(); - let plan = with_new_children_if_necessary(self.plan, children_plans)?; - Ok(PlanWithCorrespondingSort { plan, sort_onwards }) - } else { - Ok(self) - } } } -/// Analyzes `SortExec` to determine whether this Sort can be removed +/// Analyzes a given `SortExec` to determine whether its input already has +/// a finer ordering than this `SortExec` enforces. fn analyze_immediate_sort_removal( requirements: &PlanWithCorrespondingSort, ) -> Result> { if let Some(sort_exec) = requirements.plan.as_any().downcast_ref::() { + // If this sort is unnecessary, we should remove it: if ordering_satisfy( sort_exec.input().output_ordering(), sort_exec.output_ordering(), || sort_exec.input().equivalence_properties(), ) { - // This sort is unnecessary we should remove it - let new_plan = sort_exec.input(); - // Since we know that Sort have exactly one child we can use first index safely - assert_eq!(requirements.sort_onwards.len(), 1); - let mut new_sort_onward = requirements.sort_onwards[0].to_vec(); - if !new_sort_onward.is_empty() { - new_sort_onward.pop(); + // Since we know that a `SortExec` has exactly one child, + // we can use the zero index safely: + let mut new_onwards = requirements.sort_onwards[0].to_vec(); + if !new_onwards.is_empty() { + new_onwards.pop(); } return Ok(Some(PlanWithCorrespondingSort { - plan: new_plan.clone(), - sort_onwards: vec![new_sort_onward], + plan: sort_exec.input().clone(), + sort_onwards: vec![new_onwards], })); } } Ok(None) } -/// Analyzes `WindowAggExec` to determine whether Sort can be removed +/// Analyzes a `WindowAggExec` to determine whether it may allow removing a sort. 
fn analyze_window_sort_removal( window_agg_exec: &WindowAggExec, sort_exec: &SortExec, sort_onward: &mut Vec<(usize, Arc)>, ) -> Result> { let required_ordering = sort_exec.output_ordering().ok_or_else(|| { - DataFusionError::Plan("SortExec should have output ordering".to_string()) + DataFusionError::Plan("A SortExec should have output ordering".to_string()) })?; let physical_ordering = sort_exec.input().output_ordering(); let physical_ordering = if let Some(physical_ordering) = physical_ordering { physical_ordering } else { - // If there is no physical ordering, there is no way to remove Sorting, immediately return + // If there is no physical ordering, there is no way to remove a sort -- immediately return: return Ok(None); }; let window_expr = window_agg_exec.window_expr(); - let partition_keys = window_expr[0].partition_by().to_vec(); let (can_skip_sorting, should_reverse) = can_skip_sort( - &partition_keys, + window_expr[0].partition_by(), required_ordering, &sort_exec.input().schema(), physical_ordering, @@ -320,7 +321,7 @@ fn analyze_window_sort_removal( Ok(None) } -/// Updates child such that unnecessary sorting below it is removed +/// Updates child to remove the unnecessary sorting below it. fn update_child_to_remove_unnecessary_sort( child: &mut Arc, sort_onwards: &mut Vec<(usize, Arc)>, @@ -331,34 +332,30 @@ fn update_child_to_remove_unnecessary_sort( Ok(()) } -/// Convert dyn ExecutionPlan to SortExec (Assumes it is SortExec) +/// Converts an [ExecutionPlan] trait object to a [SortExec] when possible. 
fn convert_to_sort_exec(sort_any: &Arc) -> Result<&SortExec> { - let sort_exec = sort_any - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Plan("First layer should start from SortExec".to_string()) - })?; - Ok(sort_exec) + sort_any.as_any().downcast_ref::().ok_or_else(|| { + DataFusionError::Plan("Given ExecutionPlan is not a SortExec".to_string()) + }) } -/// Removes the sort from the plan in the `sort_onwards` +/// Removes the sort from the plan in `sort_onwards`. fn remove_corresponding_sort_from_sub_plan( sort_onwards: &mut Vec<(usize, Arc)>, ) -> Result> { let (sort_child_idx, sort_any) = sort_onwards[0].clone(); let sort_exec = convert_to_sort_exec(&sort_any)?; let mut prev_layer = sort_exec.input().clone(); - let mut prev_layer_child_idx = sort_child_idx; - // We start from 1 hence since first one is sort and we are removing it - // from the plan - for (cur_layer_child_idx, cur_layer) in sort_onwards.iter().skip(1) { - let mut new_children = cur_layer.children(); - new_children[prev_layer_child_idx] = prev_layer; - prev_layer = cur_layer.clone().with_new_children(new_children)?; - prev_layer_child_idx = *cur_layer_child_idx; + let mut prev_child_idx = sort_child_idx; + // In the loop below, se start from 1 as the first one is a SortExec + // and we are removing it from the plan. 
+ for (child_idx, layer) in sort_onwards.iter().skip(1) { + let mut children = layer.children(); + children[prev_child_idx] = prev_layer; + prev_layer = layer.clone().with_new_children(children)?; + prev_child_idx = *child_idx; } - // We have removed the corresponding sort hence empty the sort_onwards + // We have removed the sort, hence empty the sort_onwards: sort_onwards.clear(); Ok(prev_layer) } From 4f145dd642896c010c37f1a116baf194463f45b2 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 21 Dec 2022 14:36:45 +0300 Subject: [PATCH 15/50] Better comments, change method names --- .../core/src/physical_optimizer/remove_unnecessary_sorts.rs | 2 +- datafusion/physical-expr/src/aggregate/mod.rs | 6 ++++-- datafusion/physical-expr/src/window/aggregate.rs | 2 +- datafusion/physical-expr/src/window/built_in.rs | 2 +- datafusion/physical-expr/src/window/window_expr.rs | 6 +++--- 5 files changed, 10 insertions(+), 8 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index 6a41ea4f285c9..593c31ded9a8b 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -300,7 +300,7 @@ fn analyze_window_sort_removal( let new_window_expr = if should_reverse { window_expr .iter() - .map(|e| e.get_reversed_expr()) + .map(|e| e.get_reverse_expr()) .collect::>>() } else { Some(window_expr.to_vec()) diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 325513293f5fd..41c541918e716 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -102,9 +102,11 @@ pub trait AggregateExpr: Send + Sync + Debug { ))) } - /// Construct Reverse Expression - // Typically expression itself for aggregate functions By default returns None + /// Construct an expression that 
calculates the aggregate in reverse. fn reverse_expr(&self) -> Option> { + // Typically the "reverse" expression is itself (e.g. SUM, COUNT). + // For aggregates that do not support calculation in reverse, + // returns None (which is the default value). None } } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 4d4509df028eb..8b2579e991f60 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -150,7 +150,7 @@ impl WindowExpr for AggregateWindowExpr { &self.window_frame } - fn get_reversed_expr(&self) -> Option> { + fn get_reverse_expr(&self) -> Option> { if let Some(reverse_expr) = self.aggregate.reverse_expr() { Some(Arc::new(AggregateWindowExpr::new( reverse_expr, diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 8d6dc6008abbf..a8d768d0357e4 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -126,7 +126,7 @@ impl WindowExpr for BuiltInWindowExpr { &self.window_frame } - fn get_reversed_expr(&self) -> Option> { + fn get_reverse_expr(&self) -> Option> { if let Some(reverse_expr) = self.expr.reverse_expr() { Some(Arc::new(BuiltInWindowExpr::new( reverse_expr, diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 78c3935c567d0..a718fa4cd3b36 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -113,11 +113,11 @@ pub trait WindowExpr: Send + Sync + Debug { Ok((values, order_bys)) } - // Get window frame of this WindowExpr + /// Get the window frame of this [WindowExpr]. fn get_window_frame(&self) -> &Arc; - /// get reversed expression - fn get_reversed_expr(&self) -> Option>; + /// Get the reverse expression of this [WindowExpr]. 
+ fn get_reverse_expr(&self) -> Option>; } /// Reverses the ORDER BY expression, which is useful during equivalent window From 838972c623d9e6137bc18693556eb48bbcd2cece Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Wed, 21 Dec 2022 17:40:08 +0300 Subject: [PATCH 16/50] resolve merge conflicts --- .../remove_unnecessary_sorts.rs | 75 +------------------ .../physical-expr/src/window/window_expr.rs | 1 - 2 files changed, 1 insertion(+), 75 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs index d8010c466edf9..593c31ded9a8b 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs @@ -174,6 +174,7 @@ fn remove_unnecessary_sorts( // The ordering requirement is met, we can analyze if there is an unnecessary sort: let sort_any = first.1.clone(); let sort_exec = convert_to_sort_exec(&sort_any)?; + let sort_output_ordering = sort_exec.output_ordering(); let sort_input_ordering = sort_exec.input().output_ordering(); // Simple analysis: Does the input of the sort in question already satisfy the ordering requirements? if ordering_satisfy(sort_input_ordering, sort_output_ordering, || { @@ -823,80 +824,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_remove_unnecessary_sort2() -> Result<()> { - let session_ctx = SessionContext::new(); - let conf = session_ctx.copied_config(); - let schema = create_test_schema()?; - let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) - as Arc; - let sort_exprs = vec![PhysicalSortExpr { - expr: col("nullable_col", schema.as_ref()).unwrap(), - options: SortOptions::default(), - }]; - let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, None)?) 
- as Arc; - let sort_preserving_merge_exec = - Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) - as Arc; - let sort_exprs = vec![ - PhysicalSortExpr { - expr: col("nullable_col", schema.as_ref()).unwrap(), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: col("non_nullable_col", schema.as_ref()).unwrap(), - options: SortOptions::default(), - }, - ]; - let sort_exec = Arc::new(SortExec::try_new( - sort_exprs.clone(), - sort_preserving_merge_exec, - None, - )?) as Arc; - let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new( - vec![sort_exprs[0].clone()], - sort_exec, - )) as Arc; - let physical_plan = sort_preserving_merge_exec; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let expected = { - vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: [nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", - ] - }; - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - let optimized_physical_plan = - RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; - let formatted = displayable(optimized_physical_plan.as_ref()) - .indent() - .to_string(); - let expected = { - vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: [nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", - ] - }; - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual - ); - Ok(()) - } - #[tokio::test] async fn test_change_wrong_sorting() -> Result<()> { let session_ctx = SessionContext::new(); diff --git a/datafusion/physical-expr/src/window/window_expr.rs 
b/datafusion/physical-expr/src/window/window_expr.rs index 51ec6b51f8550..567668800f8b0 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -136,7 +136,6 @@ pub trait WindowExpr: Send + Sync + Debug { /// Get the reverse expression of this [WindowExpr]. fn get_reverse_expr(&self) -> Option>; - } /// Reverses the ORDER BY expression, which is useful during equivalent window From 6b076211c4ed964d5afe938c4814d9490d5d8696 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 22 Dec 2022 11:01:25 +0300 Subject: [PATCH 17/50] Resolve errors introduced by syncing --- datafusion/core/tests/sql/window.rs | 62 +++++++------------- datafusion/physical-expr/src/window/ntile.rs | 9 +-- 2 files changed, 23 insertions(+), 48 deletions(-) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index c57b40626fc79..41278e1208b78 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1748,7 +1748,7 @@ async fn test_window_partition_by_order_by() -> Result<()> { let msg = format!("Creating logical plan for '{}'", sql); let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await.unwrap(); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); let expected = { vec![ @@ -1788,10 +1788,8 @@ async fn test_window_agg_sort_reversed_plan() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Only 1 SortExec 
was added let expected = { @@ -1846,10 +1844,8 @@ async fn test_window_agg_sort_reversed_plan_builtin() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Only 1 SortExec was added let expected = { @@ -1900,10 +1896,8 @@ async fn test_window_agg_sort_non_reversed_plan() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // We cannot reverse each window function (ROW_NUMBER is not reversible) let expected = { @@ -1956,10 +1950,8 @@ async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // We cannot reverse each window function (ROW_NUMBER is not reversible) let expected = { @@ -2046,10 +2038,8 @@ async fn test_window_agg_complex_plan() -> Result<()> { LIMIT 5"; let 
msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Unnecessary SortExecs are removed let expected = { @@ -2095,10 +2085,8 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Only 1 SortExec was added let expected = { @@ -2150,10 +2138,8 @@ async fn test_window_agg_sort_partitionby_reversed_plan() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Only 1 SortExec was added let expected = { @@ -2204,10 +2190,8 @@ async fn test_window_agg_sort_orderby_reversed_binary_expr() -> Result<()> { LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - 
let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Only 1 SortExec was added let expected = { @@ -2261,10 +2245,8 @@ async fn test_remove_unnecessary_sort_in_sub_query() -> Result<()> { ORDER BY c1 ) AS a "; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Unnecessary Sort in the sub query is removed let expected = { @@ -2319,10 +2301,8 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_reversed_plan() -> Re LIMIT 5"; let msg = format!("Creating logical plan for '{}'", sql); - let plan = ctx.create_logical_plan(sql).expect(&msg); - let state = ctx.state(); - let logical_plan = state.optimize(&plan)?; - let physical_plan = state.create_physical_plan(&logical_plan).await?; + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); // Only 1 SortExec was added let expected = { diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index ed00c3c869550..f5844eccc63a8 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -26,7 +26,6 @@ use arrow::datatypes::Field; use arrow_schema::DataType; use datafusion_common::Result; use std::any::Any; -use std::ops::Range; use std::sync::Arc; #[derive(Debug)] @@ -70,12 +69,8 @@ pub(crate) 
struct NtileEvaluator { } impl PartitionEvaluator for NtileEvaluator { - fn evaluate_partition( - &self, - _values: &[ArrayRef], - partition: Range, - ) -> Result { - let num_rows = (partition.end - partition.start) as u64; + fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result { + let num_rows = num_rows as u64; let mut vec: Vec = Vec::new(); for i in 0..num_rows { let res = i * self.n / num_rows; From a2d2229151c898790703e9018a0dbd2247260027 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 22 Dec 2022 11:13:27 +0300 Subject: [PATCH 18/50] remove set_state, make ntile debuggable --- datafusion/physical-expr/src/aggregate/mod.rs | 2 +- .../src/window/built_in_window_function_expr.rs | 4 +++- datafusion/physical-expr/src/window/cume_dist.rs | 4 ---- datafusion/physical-expr/src/window/lead_lag.rs | 11 ----------- datafusion/physical-expr/src/window/nth_value.rs | 10 ---------- datafusion/physical-expr/src/window/ntile.rs | 1 + .../physical-expr/src/window/partition_evaluator.rs | 10 ---------- datafusion/physical-expr/src/window/rank.rs | 10 ---------- datafusion/physical-expr/src/window/row_number.rs | 10 ---------- 9 files changed, 5 insertions(+), 57 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 719f7e6ff120b..5cb251a4d415d 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -87,7 +87,7 @@ pub trait AggregateExpr: Send + Sync + Debug { false } - /// Specifies whether this aggregate function can run suing bounded memory + /// Specifies whether this aggregate function can run using bounded memory /// To be true accumulator should have `retract_batch` implemented fn bounded_exec_supported(&self) -> bool { false diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 7cb2b8e8ed980..5ab7708974877 100644 --- 
a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -64,7 +64,9 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { None } - fn bounded_exec_supported(&self) -> bool; + fn bounded_exec_supported(&self) -> bool { + false + } fn uses_window_frame(&self) -> bool { false diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs index be48f0e9ca088..3abb91e06f6a2 100644 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ b/datafusion/physical-expr/src/window/cume_dist.rs @@ -64,10 +64,6 @@ impl BuiltInWindowFunctionExpr for CumeDist { fn create_evaluator(&self) -> Result> { Ok(Box::new(CumeDistEvaluator {})) } - - fn bounded_exec_supported(&self) -> bool { - false - } } #[derive(Debug)] diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 21b2136d287f8..3939292ccba42 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -187,16 +187,6 @@ impl PartitionEvaluator for WindowShiftEvaluator { Ok(BuiltinWindowState::LeadLag(self.state.clone())) } - fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> { - match &state { - BuiltinWindowState::LeadLag(lead_lag_state) => { - self.state = lead_lag_state.clone() - } - _ => self.state = LeadLagState::default(), - } - Ok(()) - } - fn update_state( &mut self, state: &WindowAggState, @@ -220,7 +210,6 @@ impl PartitionEvaluator for WindowShiftEvaluator { }) } else { let end = state.last_calculated_index + (-self.shift_offset) as usize; - // let n_rows = self.values[0].len(); let end = min(end, n_rows); Ok(Range { start: state.last_calculated_index, diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 8fad8c6f383bc..9836feec51ca7 100644 --- 
a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -164,16 +164,6 @@ impl PartitionEvaluator for NthValueEvaluator { Ok(BuiltinWindowState::NthValue(self.state.clone())) } - fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> { - match &state { - BuiltinWindowState::NthValue(nth_value_state) => { - self.state = nth_value_state.clone(); - } - _ => self.state = NthValueState::default(), - } - Ok(()) - } - fn update_state( &mut self, state: &WindowAggState, diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index f5844eccc63a8..b8365dba19d01 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -64,6 +64,7 @@ impl BuiltInWindowFunctionExpr for Ntile { } } +#[derive(Debug)] pub(crate) struct NtileEvaluator { n: u64, } diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index 371243a3c66fd..54681bb0efc0f 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -32,22 +32,12 @@ pub trait PartitionEvaluator: Debug + Send + Sync { false } - // fn uses_window_frame(&self) -> bool { - // false - // } - /// Returns state of the Built-in Window Function fn state(&self) -> Result { // If we do not use state we just return Default Ok(BuiltinWindowState::Default) } - /// Initializes state of the Built-in Window Function (useful for bounded memory implementation) - fn set_state(&mut self, _state: &BuiltinWindowState) -> Result<()> { - // If we do not use state, set_state does nothing - Ok(()) - } - fn update_state( &mut self, _state: &WindowAggState, diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 27b4495940d41..ee16a23c68673 100644 --- 
a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -129,16 +129,6 @@ impl PartitionEvaluator for RankEvaluator { Ok(BuiltinWindowState::Rank(self.state.clone())) } - fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> { - match &state { - BuiltinWindowState::Rank(rank_state) => { - self.state = rank_state.clone(); - } - _ => self.state = RankState::default(), - } - Ok(()) - } - fn update_state( &mut self, state: &WindowAggState, diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index efbcb688cabf9..26e81e90a360d 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -81,16 +81,6 @@ impl PartitionEvaluator for NumRowsEvaluator { Ok(BuiltinWindowState::NumRows(self.state.clone())) } - fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> { - match &state { - BuiltinWindowState::NumRows(num_rows_state) => { - self.state = num_rows_state.clone(); - } - _ => self.state = NumRowsState::default(), - } - Ok(()) - } - fn get_range(&self, state: &WindowAggState, _n_rows: usize) -> Result> { Ok(Range { start: state.last_calculated_index, From 63d77a6f60fbe85bdeb0ed8206ef5b225b517b46 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 22 Dec 2022 13:24:28 +0300 Subject: [PATCH 19/50] remove locked flag --- .github/workflows/rust.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c15615d59b082..9b96c126eb76a 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -64,7 +64,7 @@ jobs: - name: Check Cargo.lock for datafusion-cli run: | # If this test fails, try running `cargo update` in the `datafusion-cli` directory - cargo check --manifest-path datafusion-cli/Cargo.toml --locked + cargo check --manifest-path datafusion-cli/Cargo.toml # test the crate linux-test: From 
ba388cb1ea632bd9d451ba263f8685dcd6f9d54d Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 23 Dec 2022 10:58:20 +0300 Subject: [PATCH 20/50] address reviews --- datafusion/common/src/lib.rs | 16 ------------ .../core/src/physical_optimizer/utils.rs | 1 + datafusion/core/src/physical_plan/common.rs | 25 +++++++++++++++++++ .../core/src/physical_plan/repartition.rs | 7 +----- .../physical_plan/windows/window_agg_exec.rs | 14 ++++++----- datafusion/expr/src/window_frame.rs | 4 ++- datafusion/physical-expr/src/aggregate/mod.rs | 6 ++--- .../physical-expr/src/window/aggregate.rs | 14 +++++------ .../physical-expr/src/window/built_in.rs | 10 +++----- .../window/built_in_window_function_expr.rs | 3 ++- .../src/window/sliding_aggregate.rs | 14 +++++------ 11 files changed, 59 insertions(+), 55 deletions(-) diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 9911b1499a7b8..392fa3f25a673 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -73,19 +73,3 @@ pub fn reverse_sort_options(options: SortOptions) -> SortOptions { nulls_first: !options.nulls_first, } } - -/// Transposes the given vector of vectors. -pub fn transpose(original: Vec>) -> Vec> { - match original.as_slice() { - [] => vec![], - [first, ..] 
=> { - let mut result = (0..first.len()).map(|_| vec![]).collect::<Vec<_>>(); - for row in original { - for (item, transposed_row) in row.into_iter().zip(&mut result) { - transposed_row.push(item); - } - } - result - } - } -} diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 94394ca527ee6..8f1fe2d08213f 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -102,6 +102,7 @@ pub fn ordering_satisfy_concrete EquivalenceProperties>( } /// Util function to add SortExec above child +/// preserving the original partitioning pub fn add_sort_above_child( child: &Arc, sort_expr: Vec, diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs index b29dc0cb8c119..1c36014f20123 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -266,6 +266,22 @@ impl Drop for AbortOnDropMany { } } +/// Transposes the given vector of vectors. +pub fn transpose<T>(original: Vec<Vec<T>>) -> Vec<Vec<T>> { + match original.as_slice() { + [] => vec![], + [first, ..] => { + let mut result = (0..first.len()).map(|_| vec![]).collect::<Vec<_>>(); + for row in original { + for (item, transposed_row) in row.into_iter().zip(&mut result) { + transposed_row.push(item); + } + } + result + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -332,6 +348,15 @@ mod tests { assert_eq!(actual, expected); Ok(()) } + + #[test] + fn test_transpose() -> Result<()> { + let in_data = vec![vec![1, 2, 3], vec![4, 5, 6]]; + let transposed = transpose(in_data); + let expected = vec![vec![1, 4], vec![2, 5], vec![3, 6]]; + assert_eq!(expected, transposed); + Ok(()) + } } /// Write in Arrow IPC format. 
diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs index f7005d113e306..3dc0c6d337cc3 100644 --- a/datafusion/core/src/physical_plan/repartition.rs +++ b/datafusion/core/src/physical_plan/repartition.rs @@ -298,12 +298,7 @@ impl ExecutionPlan for RepartitionExec { fn maintains_input_order(&self) -> bool { // We preserve ordering when input partitioning is 1 - let n_input = match self.input().output_partitioning() { - Partitioning::RoundRobinBatch(n) => n, - Partitioning::Hash(_, n) => n, - Partitioning::UnknownPartitioning(n) => n, - }; - n_input <= 1 + self.input().output_partitioning().partition_count() <= 1 } fn equivalence_properties(&self) -> EquivalenceProperties { diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index c709fe4942571..d1ea0af69ad1a 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -19,6 +19,7 @@ use crate::error::Result; use crate::execution::context::TaskContext; +use crate::physical_plan::common::transpose; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, @@ -37,7 +38,7 @@ use arrow::{ error::{ArrowError, Result as ArrowResult}, record_batch::RecordBatch, }; -use datafusion_common::{transpose, DataFusionError}; +use datafusion_common::DataFusionError; use datafusion_physical_expr::rewrite::TreeNodeRewritable; use datafusion_physical_expr::EquivalentClass; use futures::stream::Stream; @@ -136,15 +137,16 @@ impl WindowAggExec { self.input_schema.clone() } - /// Get partition keys + /// Return the output sort order of partition keys: For example + /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a + // We are sure that partition by columns are always at the beginning of sort_keys + // 
Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely + // to calculate partition separation points pub fn partition_by_sort_keys(&self) -> Result> { let mut result = vec![]; // All window exprs have the same partition by, so we just use the first one: let partition_by = self.window_expr()[0].partition_by(); - let sort_keys = self - .sort_keys - .as_ref() - .map_or_else(|| &[] as &[PhysicalSortExpr], |v| v.as_slice()); + let sort_keys = self.sort_keys.as_deref().unwrap_or(&[]); for item in partition_by { if let Some(a) = sort_keys.iter().find(|&e| e.expr.eq(item)) { result.push(a.clone()); diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index aa25ecb05a0dd..100ea8e1ded10 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -114,7 +114,9 @@ impl WindowFrame { } } - /// Get reversed window frame + /// Get reversed window frame. For example + /// `3 ROWS PRECEDING AND 2 ROWS FOLLOWING` --> + /// `2 ROWS PRECEDING AND 3 ROWS FOLLOWING` pub fn reverse(&self) -> Self { let start_bound = match &self.end_bound { WindowFrameBound::Preceding(elem) => { diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 85352189bf53b..947336596292c 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -104,10 +104,10 @@ pub trait AggregateExpr: Send + Sync + Debug { } /// Construct an expression that calculates the aggregate in reverse. + /// Typically the "reverse" expression is itself (e.g. SUM, COUNT). + /// For aggregates that do not support calculation in reverse, + /// returns None (which is the default value). fn reverse_expr(&self) -> Option> { - // Typically the "reverse" expression is itself (e.g. SUM, COUNT). - // For aggregates that do not support calculation in reverse, - // returns None (which is the default value). 
None } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 357a9a12e9b2e..5c46f38f220ff 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -143,25 +143,23 @@ impl WindowExpr for AggregateWindowExpr { } fn get_reverse_expr(&self) -> Option> { - if let Some(reverse_expr) = self.aggregate.reverse_expr() { + self.aggregate.reverse_expr().map(|reverse_expr| { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { - Some(Arc::new(AggregateWindowExpr::new( + Arc::new(AggregateWindowExpr::new( reverse_expr, &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), - ))) + )) as _ } else { - Some(Arc::new(SlidingAggregateWindowExpr::new( + Arc::new(SlidingAggregateWindowExpr::new( reverse_expr, &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), - ))) + )) as _ } - } else { - None - } + }) } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index a8d768d0357e4..9804432b2056d 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -127,15 +127,13 @@ impl WindowExpr for BuiltInWindowExpr { } fn get_reverse_expr(&self) -> Option> { - if let Some(reverse_expr) = self.expr.reverse_expr() { - Some(Arc::new(BuiltInWindowExpr::new( + self.expr.reverse_expr().map(|reverse_expr| { + Arc::new(BuiltInWindowExpr::new( reverse_expr, &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), - ))) - } else { - None - } + )) as _ + }) } } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index e7553e566fa2b..c358403fefdac 100644 --- 
a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -59,7 +59,8 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Create built-in window evaluator with a batch fn create_evaluator(&self) -> Result>; - /// Construct Reverse Expression + /// Construct Reverse Expression that produces the same result + /// on a reversed window. For example `lead(10)` --> `lag(10)` fn reverse_expr(&self) -> Option> { None } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 6e0fa8b1476d5..2a0fa86b7fe33 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -151,25 +151,23 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn get_reverse_expr(&self) -> Option> { - if let Some(reverse_expr) = self.aggregate.reverse_expr() { + self.aggregate.reverse_expr().map(|reverse_expr| { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { - Some(Arc::new(AggregateWindowExpr::new( + Arc::new(AggregateWindowExpr::new( reverse_expr, &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), - ))) + )) as _ } else { - Some(Arc::new(SlidingAggregateWindowExpr::new( + Arc::new(SlidingAggregateWindowExpr::new( reverse_expr, &self.partition_by.clone(), &reverse_order_bys(&self.order_by), Arc::new(self.window_frame.reverse()), - ))) + )) as _ } - } else { - None - } + }) } } From 572a1a4271651ec0f5223389c9d1ab694e476c1f Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 23 Dec 2022 13:23:39 +0300 Subject: [PATCH 21/50] address reviews --- datafusion/expr/src/utils.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 987ea4ade8204..ca06dfdb4aafe 100644 
--- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -204,6 +204,8 @@ pub fn expand_qualified_wildcard( expand_wildcard(&qualifier_schema, plan) } +/// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)") +/// if bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column type WindowSortKey = Vec<(Expr, bool)>; /// Generate a sort key for a given window expr's partition_by and order_bu expr From af60aa96dcdacf754fad7839d5836b5a0ada512f Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 23 Dec 2022 17:22:18 +0300 Subject: [PATCH 22/50] Resolve merge conflict --- .../physical-expr/src/window/aggregate.rs | 153 +---------------- datafusion/physical-expr/src/window/mod.rs | 2 +- .../src/window/sliding_aggregate.rs | 161 +++++++++++++++++- 3 files changed, 161 insertions(+), 155 deletions(-) diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index ccbba9f057140..ead986eaaa839 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -23,18 +23,15 @@ use std::ops::Range; use std::sync::Arc; use arrow::array::Array; -use arrow::compute::{concat, SortOptions}; +use arrow::compute::SortOptions; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::{Accumulator, WindowFrame}; +use datafusion_expr::WindowFrame; -use crate::window::window_expr::{reverse_order_bys, WindowFn, WindowFunctionState}; -use crate::window::{ - PartitionBatches, PartitionWindowAggStates, WindowAggState, WindowState, -}; +use crate::window::window_expr::reverse_order_bys; use crate::window::SlidingAggregateWindowExpr; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use crate::{window::WindowExpr, AggregateExpr}; @@ -133,61 +130,6 @@ impl WindowExpr for 
AggregateWindowExpr { ScalarValue::iter_to_array(row_wise_results.into_iter()) } - fn evaluate_bounded( - &self, - partition_batches: &PartitionBatches, - window_agg_state: &mut PartitionWindowAggStates, - ) -> Result<()> { - for (partition_row, partition_batch_state) in partition_batches.iter() { - if !window_agg_state.contains_key(partition_row) { - let accumulator = self.aggregate.create_accumulator()?; - let field = self.aggregate.field()?; - let out_type = field.data_type(); - // let out_type = &accumulator.out_type()?; - window_agg_state.insert( - partition_row.clone(), - WindowState { - state: WindowAggState::new( - out_type, - WindowFunctionState::AggregateState(vec![]), - )?, - window_fn: WindowFn::Aggregate(accumulator), - }, - ); - }; - let window_state = window_agg_state.get_mut(partition_row).unwrap(); - let accumulator = match &mut window_state.window_fn { - WindowFn::Aggregate(accumulator) => accumulator, - _ => unreachable!(), - }; - let mut state = &mut window_state.state; - state.is_end = partition_batch_state.is_end; - - let num_rows = partition_batch_state.record_batch.num_rows(); - - let mut idx = state.last_calculated_index; - let mut last_range = state.current_range_of_sliding_window.clone(); - let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); - let out_col = self.get_result_column( - accumulator, - &partition_batch_state.record_batch, - &mut window_frame_ctx, - &mut last_range, - &mut idx, - state.is_end, - )?; - state.last_calculated_index = idx; - state.current_range_of_sliding_window = last_range.clone(); - - state.out_col = concat(&[&state.out_col, &out_col])?; - state.n_row_result_missing = num_rows - state.last_calculated_index; - - state.window_function_state = - WindowFunctionState::AggregateState(accumulator.state()?); - } - Ok(()) - } - fn partition_by(&self) -> &[Arc] { &self.partition_by } @@ -227,92 +169,3 @@ impl WindowExpr for AggregateWindowExpr { && !self.window_frame.end_bound.is_unbounded() } } - 
-impl AggregateWindowExpr { - /// For given range calculate accumulator result inside range on value_slice and - /// update accumulator state - fn get_aggregate_result_inside_range( - &self, - last_range: &Range, - cur_range: &Range, - value_slice: &[ArrayRef], - accumulator: &mut Box, - ) -> Result { - let value = if cur_range.start == cur_range.end { - // We produce None if the window is empty. - ScalarValue::try_from(self.aggregate.field()?.data_type())? - } else { - // Accumulate any new rows that have entered the window: - let update_bound = cur_range.end - last_range.end; - if update_bound > 0 { - let update: Vec = value_slice - .iter() - .map(|v| v.slice(last_range.end, update_bound)) - .collect(); - accumulator.update_batch(&update)? - } - // Remove rows that have now left the window: - let retract_bound = cur_range.start - last_range.start; - if retract_bound > 0 { - let retract: Vec = value_slice - .iter() - .map(|v| v.slice(last_range.start, retract_bound)) - .collect(); - accumulator.retract_batch(&retract)? - } - accumulator.evaluate()? - }; - Ok(value) - } - - fn get_result_column( - &self, - accumulator: &mut Box, - // values: &[ArrayRef], - // order_bys: &[ArrayRef], - record_batch: &RecordBatch, - window_frame_ctx: &mut WindowFrameContext, - last_range: &mut Range, - idx: &mut usize, - is_end: bool, - ) -> Result { - let (values, order_bys) = self.get_values_orderbys(record_batch)?; - // We iterate on each row to perform a running calculation. - // First, current_range_of_sliding_window is calculated, then it is compared with last_range. 
- let length = values[0].len(); - let sort_options: Vec = - self.order_by.iter().map(|o| o.options).collect(); - let mut row_wise_results: Vec = vec![]; - let field = self.aggregate.field()?; - let out_type = field.data_type(); - while *idx < length { - let cur_range = window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - length, - *idx, - )?; - // exit if range end index is length, need kind of flag to stop - if cur_range.end == length && !is_end { - break; - } - let value = self.get_aggregate_result_inside_range( - last_range, - &cur_range, - &values, - accumulator, - )?; - row_wise_results.push(value); - last_range.start = cur_range.start; - last_range.end = cur_range.end; - *idx += 1; - } - let out_col = if !row_wise_results.is_empty() { - ScalarValue::iter_to_array(row_wise_results.into_iter())? - } else { - let a = ScalarValue::try_from(out_type)?; - a.to_array_of_size(0) - }; - Ok(out_col) - } -} diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 0b773eb53075c..35036a6dbeb4c 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -32,11 +32,11 @@ mod window_frame_state; pub use aggregate::AggregateWindowExpr; pub use built_in::BuiltInWindowExpr; pub use built_in_window_function_expr::BuiltInWindowFunctionExpr; +pub use sliding_aggregate::SlidingAggregateWindowExpr; pub use window_expr::PartitionBatchState; pub use window_expr::PartitionBatches; pub use window_expr::PartitionKey; pub use window_expr::PartitionWindowAggStates; pub use window_expr::WindowAggState; -pub use sliding_aggregate::SlidingAggregateWindowExpr; pub use window_expr::WindowExpr; pub use window_expr::WindowState; diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 2a0fa86b7fe33..4ef47a236825e 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ 
b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -23,16 +23,19 @@ use std::ops::Range; use std::sync::Arc; use arrow::array::Array; -use arrow::compute::SortOptions; +use arrow::compute::{concat, SortOptions}; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::WindowFrame; +use datafusion_expr::{Accumulator, WindowFrame}; -use crate::window::window_expr::reverse_order_bys; -use crate::window::AggregateWindowExpr; +use crate::window::window_expr::{reverse_order_bys, WindowFn, WindowFunctionState}; +use crate::window::{ + AggregateWindowExpr, PartitionBatches, PartitionWindowAggStates, WindowAggState, + WindowState, +}; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use crate::{window::WindowExpr, AggregateExpr}; @@ -138,6 +141,61 @@ impl WindowExpr for SlidingAggregateWindowExpr { ScalarValue::iter_to_array(row_wise_results.into_iter()) } + fn evaluate_bounded( + &self, + partition_batches: &PartitionBatches, + window_agg_state: &mut PartitionWindowAggStates, + ) -> Result<()> { + for (partition_row, partition_batch_state) in partition_batches.iter() { + if !window_agg_state.contains_key(partition_row) { + let accumulator = self.aggregate.create_accumulator()?; + let field = self.aggregate.field()?; + let out_type = field.data_type(); + // let out_type = &accumulator.out_type()?; + window_agg_state.insert( + partition_row.clone(), + WindowState { + state: WindowAggState::new( + out_type, + WindowFunctionState::AggregateState(vec![]), + )?, + window_fn: WindowFn::Aggregate(accumulator), + }, + ); + }; + let window_state = window_agg_state.get_mut(partition_row).unwrap(); + let accumulator = match &mut window_state.window_fn { + WindowFn::Aggregate(accumulator) => accumulator, + _ => unreachable!(), + }; + let mut state = &mut window_state.state; + state.is_end = partition_batch_state.is_end; + + let num_rows = 
partition_batch_state.record_batch.num_rows(); + + let mut idx = state.last_calculated_index; + let mut last_range = state.current_range_of_sliding_window.clone(); + let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + let out_col = self.get_result_column( + accumulator, + &partition_batch_state.record_batch, + &mut window_frame_ctx, + &mut last_range, + &mut idx, + state.is_end, + )?; + state.last_calculated_index = idx; + state.current_range_of_sliding_window = last_range.clone(); + + state.out_col = concat(&[&state.out_col, &out_col])?; + state.n_row_result_missing = num_rows - state.last_calculated_index; + + state.window_function_state = + WindowFunctionState::AggregateState(accumulator.state()?); + } + Ok(()) + } + fn partition_by(&self) -> &[Arc] { &self.partition_by } @@ -170,4 +228,99 @@ impl WindowExpr for SlidingAggregateWindowExpr { } }) } + + fn can_run_bounded(&self) -> bool { + self.aggregate.bounded_exec_supported() + && !self.window_frame.start_bound.is_unbounded() + && !self.window_frame.end_bound.is_unbounded() + } +} + +impl SlidingAggregateWindowExpr { + /// For given range calculate accumulator result inside range on value_slice and + /// update accumulator state + fn get_aggregate_result_inside_range( + &self, + last_range: &Range, + cur_range: &Range, + value_slice: &[ArrayRef], + accumulator: &mut Box, + ) -> Result { + let value = if cur_range.start == cur_range.end { + // We produce None if the window is empty. + ScalarValue::try_from(self.aggregate.field()?.data_type())? + } else { + // Accumulate any new rows that have entered the window: + let update_bound = cur_range.end - last_range.end; + if update_bound > 0 { + let update: Vec = value_slice + .iter() + .map(|v| v.slice(last_range.end, update_bound)) + .collect(); + accumulator.update_batch(&update)? 
+ } + // Remove rows that have now left the window: + let retract_bound = cur_range.start - last_range.start; + if retract_bound > 0 { + let retract: Vec = value_slice + .iter() + .map(|v| v.slice(last_range.start, retract_bound)) + .collect(); + accumulator.retract_batch(&retract)? + } + accumulator.evaluate()? + }; + Ok(value) + } + + fn get_result_column( + &self, + accumulator: &mut Box, + // values: &[ArrayRef], + // order_bys: &[ArrayRef], + record_batch: &RecordBatch, + window_frame_ctx: &mut WindowFrameContext, + last_range: &mut Range, + idx: &mut usize, + is_end: bool, + ) -> Result { + let (values, order_bys) = self.get_values_orderbys(record_batch)?; + // We iterate on each row to perform a running calculation. + // First, current_range_of_sliding_window is calculated, then it is compared with last_range. + let length = values[0].len(); + let sort_options: Vec = + self.order_by.iter().map(|o| o.options).collect(); + let mut row_wise_results: Vec = vec![]; + let field = self.aggregate.field()?; + let out_type = field.data_type(); + while *idx < length { + let cur_range = window_frame_ctx.calculate_range( + &order_bys, + &sort_options, + length, + *idx, + )?; + // exit if range end index is length, need kind of flag to stop + if cur_range.end == length && !is_end { + break; + } + let value = self.get_aggregate_result_inside_range( + last_range, + &cur_range, + &values, + accumulator, + )?; + row_wise_results.push(value); + last_range.start = cur_range.start; + last_range.end = cur_range.end; + *idx += 1; + } + let out_col = if !row_wise_results.is_empty() { + ScalarValue::iter_to_array(row_wise_results.into_iter())? 
+ } else { + let a = ScalarValue::try_from(out_type)?; + a.to_array_of_size(0) + }; + Ok(out_col) + } } From ca711e4642ae21e64220a87562820102a68b81df Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 23 Dec 2022 17:34:56 +0300 Subject: [PATCH 23/50] address reviews --- .../replace_window_with_bounded_impl.rs | 30 ++++++++----------- .../windows/bounded_window_agg_exec.rs | 2 +- .../physical-expr/src/aggregate/count.rs | 2 +- datafusion/physical-expr/src/aggregate/mod.rs | 2 +- datafusion/physical-expr/src/aggregate/sum.rs | 2 +- .../physical-expr/src/window/aggregate.rs | 2 +- .../physical-expr/src/window/built_in.rs | 4 +-- .../window/built_in_window_function_expr.rs | 2 +- .../physical-expr/src/window/lead_lag.rs | 2 +- .../physical-expr/src/window/nth_value.rs | 2 +- datafusion/physical-expr/src/window/rank.rs | 2 +- .../physical-expr/src/window/row_number.rs | 2 +- .../src/window/sliding_aggregate.rs | 3 +- 13 files changed, 25 insertions(+), 32 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs b/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs index 0ff03130fddd1..e7143b1e7bd03 100644 --- a/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs +++ b/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! CoalesceBatches optimizer that groups batches together rows -//! in bigger batches to avoid overhead with small batches +//! ReplaceWindowWithBoundedImpl optimizer that replaces `WindowAggExec` +//! 
with `BoundedWindowAggExec` if window_expr can run using `BoundedWindowAggExec` use crate::physical_plan::windows::BoundedWindowAggExec; use crate::physical_plan::windows::WindowAggExec; @@ -25,11 +25,10 @@ use crate::{ physical_plan::rewrite::TreeNodeRewritable, }; use datafusion_expr::WindowFrameUnits; -use datafusion_physical_expr::window::WindowExpr; use std::sync::Arc; -/// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that -/// are produced by highly selective filters +/// Optimizer rule that introduces replaces `WindowAggExec` with `BoundedWindowAggExec` +/// to run executor with bounded memory. #[derive(Default)] pub struct ReplaceWindowWithBoundedImpl {} @@ -47,10 +46,13 @@ impl PhysicalOptimizerRule for ReplaceWindowWithBoundedImpl { ) -> Result> { plan.transform_up(&|plan| { if let Some(window_agg_exec) = plan.as_any().downcast_ref::() { - let is_contains_groups = window_agg_exec - .window_expr() - .iter() - .any(is_window_frame_groups); + let is_contains_groups = + window_agg_exec.window_expr().iter().any(|window_expr| { + matches!( + window_expr.get_window_frame().units, + WindowFrameUnits::Groups + ) + }); let can_run_bounded = window_agg_exec .window_expr() .iter() @@ -70,18 +72,10 @@ impl PhysicalOptimizerRule for ReplaceWindowWithBoundedImpl { } fn name(&self) -> &str { - "coalesce_batches" + "ReplaceWindowWithBoundedImpl" } fn schema_check(&self) -> bool { true } } - -/// Checks window expression whether it is GROUPS mode -fn is_window_frame_groups(window_expr: &Arc) -> bool { - matches!( - window_expr.get_window_frame().units, - WindowFrameUnits::Groups - ) -} diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 177071a2922fc..0877aae243bab 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ 
-388,7 +388,7 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { is_end: false, }; self.partition_batches - .insert(partition_row.clone(), partition_batch_state); + .insert(partition_row, partition_batch_state); }; } } diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 778d9f885835f..8ccf87ac2b1da 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -98,7 +98,7 @@ impl AggregateExpr for Count { true } - fn bounded_exec_supported(&self) -> bool { + fn supports_bounded_execution(&self) -> bool { true } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 2cd9f48422deb..372823e754f06 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -90,7 +90,7 @@ pub trait AggregateExpr: Send + Sync + Debug { /// Specifies whether this aggregate function can run using bounded memory /// To be true accumulator should have `retract_batch` implemented - fn bounded_exec_supported(&self) -> bool { + fn supports_bounded_execution(&self) -> bool { false } diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 7b3961e36e3eb..649ce67996a44 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -113,7 +113,7 @@ impl AggregateExpr for Sum { is_row_accumulator_support_dtype(&self.data_type) } - fn bounded_exec_supported(&self) -> bool { + fn supports_bounded_execution(&self) -> bool { true } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index ead986eaaa839..8e598a0ad3c76 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -164,7 +164,7 @@ impl WindowExpr for AggregateWindowExpr { 
} fn can_run_bounded(&self) -> bool { - self.aggregate.bounded_exec_supported() + self.aggregate.supports_bounded_execution() && !self.window_frame.start_bound.is_unbounded() && !self.window_frame.end_bound.is_unbounded() } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 543d07616e9b1..62fac81438ca1 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -239,11 +239,11 @@ impl WindowExpr for BuiltInWindowExpr { fn can_run_bounded(&self) -> bool { if self.expr.uses_window_frame() { - self.expr.bounded_exec_supported() + self.expr.supports_bounded_execution() && !self.window_frame.start_bound.is_unbounded() && !self.window_frame.end_bound.is_unbounded() } else { - self.expr.bounded_exec_supported() + self.expr.supports_bounded_execution() } } } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index e1e5a71553aa7..6f41ec599a83c 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -65,7 +65,7 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { None } - fn bounded_exec_supported(&self) -> bool { + fn supports_bounded_execution(&self) -> bool { false } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 3939292ccba42..a6b6e6272e615 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -110,7 +110,7 @@ impl BuiltInWindowFunctionExpr for WindowShift { })) } - fn bounded_exec_supported(&self) -> bool { + fn supports_bounded_execution(&self) -> bool { true } diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 
9836feec51ca7..2431200379aa6 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -128,7 +128,7 @@ impl BuiltInWindowFunctionExpr for NthValue { })) } - fn bounded_exec_supported(&self) -> bool { + fn supports_bounded_execution(&self) -> bool { true } diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index ee16a23c68673..92d0f6a5179e8 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -99,7 +99,7 @@ impl BuiltInWindowFunctionExpr for Rank { &self.name } - fn bounded_exec_supported(&self) -> bool { + fn supports_bounded_execution(&self) -> bool { matches!(self.rank_type, RankType::Basic | RankType::Dense) } diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index 26e81e90a360d..d099e8914d79e 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -65,7 +65,7 @@ impl BuiltInWindowFunctionExpr for RowNumber { Ok(Box::::default()) } - fn bounded_exec_supported(&self) -> bool { + fn supports_bounded_execution(&self) -> bool { true } } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 4ef47a236825e..85ab1076ef5f4 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -151,7 +151,6 @@ impl WindowExpr for SlidingAggregateWindowExpr { let accumulator = self.aggregate.create_accumulator()?; let field = self.aggregate.field()?; let out_type = field.data_type(); - // let out_type = &accumulator.out_type()?; window_agg_state.insert( partition_row.clone(), WindowState { @@ -230,7 +229,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn can_run_bounded(&self) -> bool { - 
self.aggregate.bounded_exec_supported() + self.aggregate.supports_bounded_execution() && !self.window_frame.start_bound.is_unbounded() && !self.window_frame.end_bound.is_unbounded() } From eb97a5cfd87df954db207c4d3bd5f95890741fc6 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 23 Dec 2022 18:13:49 +0300 Subject: [PATCH 24/50] address reviews --- datafusion/physical-expr/src/window/built_in.rs | 7 +++++-- datafusion/physical-expr/src/window/sliding_aggregate.rs | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 62fac81438ca1..4869814d46e1d 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -30,8 +30,8 @@ use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use arrow::compute::{concat, SortOptions}; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::Result; use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::WindowFrame; use std::any::Any; use std::sync::Arc; @@ -153,7 +153,10 @@ impl WindowExpr for BuiltInWindowExpr { }, ); }; - let window_state = window_agg_state.get_mut(partition_row).unwrap(); + let window_state = + window_agg_state.get_mut(partition_row).ok_or_else(|| { + DataFusionError::Execution("Cannot find state".to_string()) + })?; let evaluator = match &mut window_state.window_fn { WindowFn::Builtin(evaluator) => evaluator, _ => unreachable!(), diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 85ab1076ef5f4..7c5e956f39b0d 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -27,8 +27,8 @@ use arrow::compute::{concat, SortOptions}; use arrow::record_batch::RecordBatch; use 
arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::Result; use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{Accumulator, WindowFrame}; use crate::window::window_expr::{reverse_order_bys, WindowFn, WindowFunctionState}; @@ -162,7 +162,10 @@ impl WindowExpr for SlidingAggregateWindowExpr { }, ); }; - let window_state = window_agg_state.get_mut(partition_row).unwrap(); + let window_state = + window_agg_state.get_mut(partition_row).ok_or_else(|| { + DataFusionError::Execution("Cannot find state".to_string()) + })?; let accumulator = match &mut window_state.window_fn { WindowFn::Aggregate(accumulator) => accumulator, _ => unreachable!(), From 36394c0b44a6d29f925eed4371d6f05409bc50f9 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 26 Dec 2022 09:56:47 +0300 Subject: [PATCH 25/50] address reviews --- .../windows/bounded_window_agg_exec.rs | 61 +++++++++++-------- .../physical-expr/src/window/built_in.rs | 40 +++++------- .../physical-expr/src/window/lead_lag.rs | 11 ++-- .../physical-expr/src/window/nth_value.rs | 2 +- datafusion/physical-expr/src/window/rank.rs | 2 +- .../src/window/sliding_aggregate.rs | 7 +-- .../physical-expr/src/window/window_expr.rs | 8 ++- 7 files changed, 66 insertions(+), 65 deletions(-) diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 0877aae243bab..84ba1f9544001 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -111,23 +111,26 @@ impl BoundedWindowAggExec { self.input_schema.clone() } - /// Get Partition Columns + /// Return the output sort order of partition keys: For example + /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a + // We are sure that partition by columns are always at the beginning of sort_keys + // Hence 
returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely + // to calculate partition separation points pub fn partition_by_sort_keys(&self) -> Result> { - // All window exprs have same partition by hance we just use first one + let mut result = vec![]; + // All window exprs have the same partition by, so we just use the first one: let partition_by = self.window_expr()[0].partition_by(); - let mut partition_columns = vec![]; - for elem in partition_by { - if let Some(sort_keys) = &self.sort_keys { - for a in sort_keys { - if a.expr.eq(elem) { - partition_columns.push(a.clone()); - break; - } - } + let sort_keys = self.sort_keys.as_deref().unwrap_or(&[]); + for item in partition_by { + if let Some(a) = sort_keys.iter().find(|&e| e.expr.eq(item)) { + result.push(a.clone()); + } else { + return Err(DataFusionError::Execution( + "Partition key not found in sort keys".to_string(), + )); } } - assert_eq!(partition_by.len(), partition_columns.len()); - Ok(partition_columns) + Ok(result) } } @@ -343,12 +346,17 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { } } - /// prunes sections in the state that are no longer needed + /// Prunes sections in the state that are no longer needed + /// for calculating result. Determined by window frame boundaries + // For instance if `n_out` number of rows are calculated, we can remove first `n_out` rows from + // the `self.input_buffer_record_batch` fn prune_state(&mut self, n_out: usize) -> Result<()> { + // Prunes `self.partition_batches` self.prune_partition_batches()?; + // Prunes `self.input_buffer_record_batch` self.prune_input_batch(n_out)?; + // Prunes `self.window_agg_states` self.prune_out_columns(n_out)?; - Ok(()) } @@ -552,14 +560,12 @@ impl SortedPartitionByBoundedWindowStream { window_agg_state.retain(|_, WindowState { state, .. }| !state.is_end); for (partition_row, WindowState { state: value, .. 
}) in window_agg_state { if let Some(state) = n_prune_each_partition.get_mut(partition_row) { - if value.current_range_of_sliding_window.start < *state { - *state = value.current_range_of_sliding_window.start; + if value.window_frame_range.start < *state { + *state = value.window_frame_range.start; } } else { - n_prune_each_partition.insert( - partition_row.clone(), - value.current_range_of_sliding_window.start, - ); + n_prune_each_partition + .insert(partition_row.clone(), value.window_frame_range.start); } } } @@ -582,9 +588,9 @@ impl SortedPartitionByBoundedWindowStream { let window_state = window_agg_state.get_mut(partition_row).ok_or_else(err)?; let mut state = &mut window_state.state; - state.current_range_of_sliding_window = Range { - start: state.current_range_of_sliding_window.start - n_prune, - end: state.current_range_of_sliding_window.end - n_prune, + state.window_frame_range = Range { + start: state.window_frame_range.start - n_prune, + end: state.window_frame_range.end - n_prune, }; state.last_calculated_index -= n_prune; state.offset_pruned_rows += n_prune; @@ -700,6 +706,11 @@ fn get_aggregate_result_out_column( break; } } - assert_eq!(running_length, len_to_show); + if running_length != len_to_show { + return Err(DataFusionError::Execution(format!( + "Generated row number should be {}, it is {}", + len_to_show, running_length + ))); + } ret.ok_or_else(|| DataFusionError::Execution("Should contain something".to_string())) } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 4869814d46e1d..5633576e49265 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -136,10 +136,10 @@ impl WindowExpr for BuiltInWindowExpr { let sort_options: Vec = self.order_by.iter().map(|o| o.options).collect(); for (partition_row, partition_batch_state) in partition_batches.iter() { + let field = self.expr.field()?; + let out_type = 
field.data_type(); if !window_agg_state.contains_key(partition_row) { let evaluator = self.expr.create_evaluator()?; - let field = self.expr.field()?; - let out_type = field.data_type(); window_agg_state.insert( partition_row.clone(), WindowState { @@ -172,13 +172,11 @@ impl WindowExpr for BuiltInWindowExpr { self.get_values_orderbys(&partition_batch_state.record_batch)?; // We iterate on each row to perform a running calculation. - // First, current_range_of_sliding_window is calculated, then it is compared with last_range. let mut row_wise_results: Vec = vec![]; - let mut last_range = state.current_range_of_sliding_window.clone(); + let mut last_range = state.window_frame_range.clone(); let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); for idx in state.last_calculated_index..num_rows { - state.current_range_of_sliding_window = if !self.expr.uses_window_frame() - { + state.window_frame_range = if !self.expr.uses_window_frame() { evaluator.get_range(state, num_rows)? } else { window_frame_ctx.calculate_range( @@ -190,15 +188,13 @@ impl WindowExpr for BuiltInWindowExpr { }; evaluator.update_state(state, &order_bys, &sort_partition_points)?; // exit if range end index is length, need kind of flag to stop - if state.current_range_of_sliding_window.end == num_rows + if state.window_frame_range.end == num_rows && !partition_batch_state.is_end { - state.current_range_of_sliding_window = last_range.clone(); + state.window_frame_range = last_range.clone(); break; } - if state.current_range_of_sliding_window.start - == state.current_range_of_sliding_window.end - { + if state.window_frame_range.start == state.window_frame_range.end { // We produce None if the window is empty. row_wise_results .push(ScalarValue::try_from(self.expr.field()?.data_type())?) 
@@ -206,15 +202,14 @@ impl WindowExpr for BuiltInWindowExpr { let res = evaluator.evaluate_bounded(&values)?; row_wise_results.push(res); } - last_range = state.current_range_of_sliding_window.clone(); + last_range = state.window_frame_range.clone(); state.last_calculated_index = idx + 1; } - state.current_range_of_sliding_window = last_range; - let out_col = if !row_wise_results.is_empty() { - ScalarValue::iter_to_array(row_wise_results.into_iter())? + state.window_frame_range = last_range; + let out_col = if row_wise_results.is_empty() { + ScalarValue::try_from(out_type)?.to_array_of_size(0) } else { - let a = ScalarValue::try_from(self.expr.field()?.data_type())?; - a.to_array_of_size(0) + ScalarValue::iter_to_array(row_wise_results.into_iter())? }; state.out_col = concat(&[&state.out_col, &out_col])?; @@ -241,12 +236,9 @@ impl WindowExpr for BuiltInWindowExpr { } fn can_run_bounded(&self) -> bool { - if self.expr.uses_window_frame() { - self.expr.supports_bounded_execution() - && !self.window_frame.start_bound.is_unbounded() - && !self.window_frame.end_bound.is_unbounded() - } else { - self.expr.supports_bounded_execution() - } + self.expr.supports_bounded_execution() + && (!self.expr.uses_window_frame() + || !(self.window_frame.start_bound.is_unbounded() + || self.window_frame.end_bound.is_unbounded())) } } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index a6b6e6272e615..5f9f690deeaca 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -240,13 +240,12 @@ fn get_default_value( dtype: &DataType, ) -> Result { if let Some(val) = default_value { - match val { - ScalarValue::Int64(Some(val)) => { - ScalarValue::try_from_string(val.to_string(), dtype) - } - _ => Err(DataFusionError::Internal( + if let ScalarValue::Int64(Some(val)) = val { + ScalarValue::try_from_string(val.to_string(), dtype) + } else { + Err(DataFusionError::Internal( 
"Expects default value to have Int64 type".to_string(), - )), + )) } } else { Ok(ScalarValue::try_from(dtype)?) diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 2431200379aa6..0bc24830c6b5f 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -171,7 +171,7 @@ impl PartitionEvaluator for NthValueEvaluator { _sort_partition_points: &[Range], ) -> Result<()> { // If we do not use state, update_state does nothing - self.state.range = state.current_range_of_sliding_window.clone(); + self.state.range = state.window_frame_range.clone(); Ok(()) } diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 92d0f6a5179e8..4541bde818c35 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -168,7 +168,7 @@ impl PartitionEvaluator for RankEvaluator { ))), RankType::Dense => Ok(ScalarValue::UInt64(Some(self.state.n_rank as u64))), RankType::Percent => Err(DataFusionError::Execution( - "Cannot Run Percent_RANK in streaming case".to_string(), + "Can not execute Percent_RANK in a streaming fashion".to_string(), )), } } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 7c5e956f39b0d..ac32cf010f790 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -176,7 +176,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { let num_rows = partition_batch_state.record_batch.num_rows(); let mut idx = state.last_calculated_index; - let mut last_range = state.current_range_of_sliding_window.clone(); + let mut last_range = state.window_frame_range.clone(); let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); let out_col = self.get_result_column( accumulator, @@ -187,7 +187,7 
@@ impl WindowExpr for SlidingAggregateWindowExpr { state.is_end, )?; state.last_calculated_index = idx; - state.current_range_of_sliding_window = last_range.clone(); + state.window_frame_range = last_range.clone(); state.out_col = concat(&[&state.out_col, &out_col])?; state.n_row_result_missing = num_rows - state.last_calculated_index; @@ -278,8 +278,6 @@ impl SlidingAggregateWindowExpr { fn get_result_column( &self, accumulator: &mut Box, - // values: &[ArrayRef], - // order_bys: &[ArrayRef], record_batch: &RecordBatch, window_frame_ctx: &mut WindowFrameContext, last_range: &mut Range, @@ -288,7 +286,6 @@ impl SlidingAggregateWindowExpr { ) -> Result { let (values, order_bys) = self.get_values_orderbys(record_batch)?; // We iterate on each row to perform a running calculation. - // First, current_range_of_sliding_window is calculated, then it is compared with last_range. let length = values[0].len(); let sort_options: Vec = self.order_by.iter().map(|o| o.options).collect(); diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 567668800f8b0..0c78843f5ff1b 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -230,15 +230,17 @@ pub enum WindowFunctionState { #[derive(Debug, Clone)] pub struct WindowAggState { /// The range that we calculate the window function - pub current_range_of_sliding_window: Range, + pub window_frame_range: Range, /// The index of the last row that its result is calculated inside the partition record batch buffer. pub last_calculated_index: usize, /// The offset of the deleted row number pub offset_pruned_rows: usize, /// pub window_function_state: WindowFunctionState, - // Keeps the results + /// Stores the results calculated by window frame pub out_col: ArrayRef, + /// Keeps track of how many rows should be generated to be in sync with input record_batch. 
+ // (For each row in the input record batch we need to generate a window result). pub n_row_result_missing: usize, /// flag indicating whether we have received all data for this partition pub is_end: bool, @@ -275,7 +277,7 @@ impl WindowAggState { ) -> Result { let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0); Ok(Self { - current_range_of_sliding_window: Range { start: 0, end: 0 }, + window_frame_range: Range { start: 0, end: 0 }, last_calculated_index: 0, offset_pruned_rows: 0, window_function_state, From 8b3d37f828e75c6f772e469e9404f2a7bc00db8c Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 26 Dec 2022 15:10:23 +0300 Subject: [PATCH 26/50] Add new tests --- datafusion/core/tests/sql/window.rs | 273 ++++++++++++++++++++ datafusion/physical-expr/src/window/rank.rs | 2 +- 2 files changed, 274 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 9d6536d3db9f9..0a8334efa55d3 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -16,6 +16,8 @@ // under the License. 
use super::*; +use ::parquet::arrow::arrow_writer::ArrowWriter; +use ::parquet::file::properties::WriterProperties; /// for window functions without order by the first, last, and nth function call does not make sense #[tokio::test] @@ -2340,3 +2342,274 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_reversed_plan() -> Re Ok(()) } + +fn write_test_data_to_parquet(tmpdir: &TempDir, n_file: usize) -> Result<()> { + let ts_field = Field::new("ts", DataType::Int32, false); + let inc_field = Field::new("inc_col", DataType::Int32, false); + let desc_field = Field::new("desc_col", DataType::Int32, false); + let equal_field = Field::new("equal", DataType::Int32, false); + let nonmonothonic_inc_field = Field::new("nonmonothonic_inc", DataType::Int32, false); + + let unsorted_non_unique_field = + Field::new("unsorted_non_unique_field", DataType::Int32, false); + + let schema = Arc::new(Schema::new(vec![ + ts_field, + inc_field, + desc_field, + equal_field, + nonmonothonic_inc_field, + unsorted_non_unique_field, + ])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from_slice([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, + 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, + 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, + 93, 94, 95, 96, 97, 98, 99, 100, + ])), + Arc::new(Int32Array::from_slice([ + 1, 2, 6, 10, 14, 15, 19, 20, 21, 22, 24, 29, 34, 38, 41, 45, 48, 50, 51, + 56, 61, 63, 65, 67, 71, 72, 75, 76, 80, 84, 88, 93, 96, 97, 98, 101, 105, + 110, 115, 118, 123, 124, 129, 132, 137, 141, 144, 149, 151, 152, 153, + 155, 157, 161, 165, 169, 173, 178, 180, 184, 187, 188, 189, 191, 193, + 197, 200, 203, 208, 210, 215, 218, 222, 227, 228, 232, 234, 236, 240, + 242, 244, 245, 249, 252, 253, 254, 259, 260, 261, 265, 
267, 271, 276, + 277, 279, 284, 287, 290, 293, 296, + ])), + Arc::new(Int32Array::from_slice([ + 100, 99, 95, 91, 86, 83, 81, 78, 75, 70, 65, 63, 59, 56, 53, 49, 48, 45, + 43, 42, 38, 35, 30, 25, 22, 21, 16, 11, 9, 5, 0, -5, -6, -9, -10, -15, + -20, -24, -27, -28, -31, -32, -34, -36, -39, -42, -46, -49, -54, -56, + -61, -64, -69, -70, -75, -78, -79, -84, -86, -87, -90, -93, -96, -97, + -99, -100, -104, -105, -109, -112, -117, -118, -119, -121, -122, -126, + -130, -132, -133, -134, -138, -139, -142, -143, -147, -149, -153, -155, + -160, -162, -163, -165, -168, -171, -172, -174, -178, -183, -187, -191, + ])), + Arc::new(Int32Array::from_slice([ + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100, + ])), + Arc::new(Int32Array::from_slice([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, + 9, 9, 9, 9, + ])), + Arc::new(Int32Array::from_slice([ + 2, 1, 3, 1, 2, 5, 5, 3, 3, 3, 5, 1, 4, 2, 2, 5, 0, 3, 5, 2, 5, 3, 5, 5, + 3, 1, 0, 0, 5, 1, 2, 4, 5, 2, 4, 1, 0, 5, 2, 2, 4, 3, 5, 3, 4, 4, 1, 1, + 5, 3, 3, 5, 4, 1, 1, 2, 5, 4, 0, 2, 0, 4, 4, 3, 5, 4, 4, 4, 2, 5, 0, 3, + 0, 5, 1, 4, 5, 4, 5, 5, 2, 4, 5, 3, 2, 0, 1, 1, 1, 0, 2, 5, 1, 2, 5, 4, + 4, 2, 0, 4, + ])), + ], + )?; + let n_chunk = batch.num_rows() / n_file; + for i in 0..n_file { + let target_file = 
tmpdir.path().join(format!("{}.parquet", i)); + let file = File::create(target_file).unwrap(); + // Default writer properties + let props = WriterProperties::builder().build(); + let chunks_start = i * n_chunk; + let cur_batch = batch.slice(chunks_start, n_chunk); + // let chunks_end = chunks_start + n_chunk; + let mut writer = + ArrowWriter::try_new(file, cur_batch.schema(), Some(props)).unwrap(); + + writer.write(&cur_batch).expect("Writing batch"); + + // writer must be closed to write footer + writer.close().unwrap(); + } + Ok(()) +} + +async fn get_test_context(tmpdir: &TempDir) -> Result { + let session_config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::with_config(session_config); + + let parquet_read_options = ParquetReadOptions::default(); + // The sort order is specified (not actually correct in this case) + let file_sort_order = [col("ts")] + .into_iter() + .map(|e| { + let ascending = true; + let nulls_first = false; + e.sort(ascending, nulls_first) + }) + .collect::>(); + + let options_sort = parquet_read_options + .to_listing_options(&ctx.copied_config()) + .with_file_sort_order(Some(file_sort_order)); + + write_test_data_to_parquet(tmpdir, 1)?; + let provided_schema = None; + let sql_definition = None; + ctx.register_listing_table( + "annotated_data", + tmpdir.path().to_string_lossy(), + options_sort.clone(), + provided_schema, + sql_definition, + ) + .await + .unwrap(); + Ok(ctx) +} + +mod tests { + use super::*; + + #[tokio::test] + async fn test_source_sorted_aggregate() -> Result<()> { + let tmpdir = TempDir::new().unwrap(); + let ctx = get_test_context(&tmpdir).await?; + + let sql = "SELECT + SUM(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as sum1, + SUM(desc_col) OVER(ORDER BY ts RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING) as sum2, + SUM(inc_col) OVER(ORDER BY ts ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING) as sum3, + COUNT(*) OVER(ORDER BY ts RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING) as 
cnt1, + COUNT(*) OVER(ORDER BY ts ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as cnt2, + SUM(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING) as sumr1, + SUM(desc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING) as sumr2, + SUM(desc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sumr3, + COUNT(*) OVER(ORDER BY ts DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) as cntr1, + COUNT(*) OVER(ORDER BY ts DESC ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as cntr2, + SUM(desc_col) OVER(ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as sum4, + COUNT(*) OVER(ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as cnt3 + FROM annotated_data + ORDER BY ts DESC + LIMIT 5 + "; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, sum3@2 as sum3, cnt1@3 as cnt1, cnt2@4 as cnt2, sumr1@5 as sumr1, sumr2@6 as sumr2, sumr3@7 as sumr3, cntr1@8 as cntr1, cntr2@9 as cntr2, sum4@10 as sum4, cnt3@11 as cnt3]", + " GlobalLimitExec: skip=0, fetch=5", + " SortExec: [ts@12 DESC]", + " ProjectionExec: expr=[SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as sum1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@3 as sum2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@4 as sum3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@5 as cnt1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@6 as cnt2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS 
FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@7 as sumr1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@8 as sumr2, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@9 as sumr3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@10 as cntr1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@11 as cntr2, SUM(annotated_data.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@0 as sum4, COUNT(UInt8(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@1 as cnt3, ts@12 as ts]", + " BoundedWindowAggExec: wdw=[SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} 
}), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(8)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)) }]", + ] + }; + + let actual: Vec<&str> = 
formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------+-------+------+------+------+-------+-------+-------+-------+-------+-------+------+", + "| sum1 | sum2 | sum3 | cnt1 | cnt2 | sumr1 | sumr2 | sumr3 | cntr1 | cntr2 | sum4 | cnt3 |", + "+------+-------+------+------+------+-------+-------+-------+-------+-------+-------+------+", + "| 3085 | -1085 | 589 | 5 | 9 | 1450 | -1589 | -1085 | 3 | 2 | -1589 | 9 |", + "| 3346 | -1256 | 879 | 6 | 10 | 1729 | -1752 | -1256 | 4 | 3 | -1752 | 10 |", + "| 3310 | -1233 | 1166 | 7 | 10 | 1710 | -1723 | -1233 | 5 | 4 | -1723 | 10 |", + "| 3276 | -1211 | 1450 | 8 | 10 | 1693 | -1696 | -1211 | 6 | 5 | -1696 | 10 |", + "| 3240 | -1191 | 1729 | 9 | 10 | 1674 | -1668 | -1191 | 7 | 6 | -1668 | 10 |", + "+------+-------+------+------+------+-------+-------+-------+-------+-------+-------+------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) + } + + #[tokio::test] + async fn test_source_sorted_builtin() -> Result<()> { + let tmpdir = TempDir::new().unwrap(); + let ctx = get_test_context(&tmpdir).await?; + + let sql = "SELECT + FIRST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv1, + FIRST_VALUE(inc_col) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv2, + LAST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as lv1, + LAST_VALUE(inc_col) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lv2, + NTH_VALUE(inc_col, 5) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as nv1, + NTH_VALUE(inc_col, 5) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as nv2, + ROW_NUMBER() OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS rn1, + 
ROW_NUMBER() OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as rn2, + RANK() OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS rank1, + RANK() OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as rank2, + DENSE_RANK() OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS dense_rank1, + DENSE_RANK() OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as dense_rank2, + LAG(inc_col, 1, 1001) OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS lag1, + LAG(inc_col, 2, 1002) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lag2, + LEAD(inc_col, -1, 1001) OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS lead1, + LEAD(inc_col, 4, 1004) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lead2, + FIRST_VALUE(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as fvr1, + FIRST_VALUE(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as fvr2, + LAST_VALUE(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as lvr1, + LAST_VALUE(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lvr2, + LAG(inc_col, 1, 1001) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS lagr1, + LAG(inc_col, 2, 1002) OVER(ORDER BY ts DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lagr2, + LEAD(inc_col, -1, 1001) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS leadr1, + LEAD(inc_col, 4, 1004) OVER(ORDER BY ts DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as leadr2 + FROM annotated_data + ORDER BY ts DESC + LIMIT 5 + "; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "ProjectionExec: expr=[fv1@0 as fv1, fv2@1 as fv2, lv1@2 as lv1, lv2@3 
as lv2, nv1@4 as nv1, nv2@5 as nv2, rn1@6 as rn1, rn2@7 as rn2, rank1@8 as rank1, rank2@9 as rank2, dense_rank1@10 as dense_rank1, dense_rank2@11 as dense_rank2, lag1@12 as lag1, lag2@13 as lag2, lead1@14 as lead1, lead2@15 as lead2, fvr1@16 as fvr1, fvr2@17 as fvr2, lvr1@18 as lvr1, lvr2@19 as lvr2, lagr1@20 as lagr1, lagr2@21 as lagr2, leadr1@22 as leadr1, leadr2@23 as leadr2]", + " GlobalLimitExec: skip=0, fetch=5", + " SortExec: [ts@24 DESC]", + " ProjectionExec: expr=[FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@0 as fv1, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@1 as fv2, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lv1, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lv2, NTH_VALUE(annotated_data.inc_col,Int64(5)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as nv1, NTH_VALUE(annotated_data.inc_col,Int64(5)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as nv2, ROW_NUMBER() ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as rn1, ROW_NUMBER() ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as rn2, RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as rank1, RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as rank2, DENSE_RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as dense_rank2, 
LAG(annotated_data.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@12 as lag1, LAG(annotated_data.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lag2, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@14 as lead1, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as lead2, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as fvr1, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as fvr2, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@18 as lvr1, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as lvr2, LAG(annotated_data.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as lagr1, LAG(annotated_data.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as lagr2, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as leadr1, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as leadr2, ts@24 as ts]", + " BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: \"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { 
units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: \"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, NTH_VALUE(annotated_data.inc_col,Int64(5)): Ok(Field { name: \"NTH_VALUE(annotated_data.inc_col,Int64(5))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, NTH_VALUE(annotated_data.inc_col,Int64(5)): Ok(Field { name: \"NTH_VALUE(annotated_data.inc_col,Int64(5))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), 
end_bound: Following(UInt64(1)) }, RANK(): Ok(Field { name: \"RANK()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, RANK(): Ok(Field { name: \"RANK()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, DENSE_RANK(): Ok(Field { name: \"DENSE_RANK()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, DENSE_RANK(): Ok(Field { name: \"DENSE_RANK()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAG(annotated_data.inc_col,Int64(1),Int64(1001)): Ok(Field { name: \"LAG(annotated_data.inc_col,Int64(1),Int64(1001))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAG(annotated_data.inc_col,Int64(2),Int64(1002)): Ok(Field { name: \"LAG(annotated_data.inc_col,Int64(2),Int64(1002))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)): Ok(Field { name: \"LEAD(annotated_data.inc_col,Int64(-1),Int64(1001))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)): Ok(Field { name: 
\"LEAD(annotated_data.inc_col,Int64(4),Int64(1004))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: \"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: \"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAG(annotated_data.inc_col,Int64(1),Int64(1001)): Ok(Field { name: \"LAG(annotated_data.inc_col,Int64(1),Int64(1001))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAG(annotated_data.inc_col,Int64(2),Int64(1002)): Ok(Field { name: \"LAG(annotated_data.inc_col,Int64(2),Int64(1002))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: 
Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)): Ok(Field { name: \"LEAD(annotated_data.inc_col,Int64(-1),Int64(1001))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)): Ok(Field { name: \"LEAD(annotated_data.inc_col,Int64(4),Int64(1004))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+-----+-----+-----+-----+-----+-----+-----+-------+-------+-------------+-------------+------+------+-------+-------+------+------+------+------+-------+-------+--------+--------+", + "| fv1 | fv2 | lv1 | lv2 | nv1 | nv2 | rn1 | rn2 | rank1 | rank2 | dense_rank1 | dense_rank2 | lag1 | lag2 | lead1 | lead2 | fvr1 | fvr2 | lvr1 | lvr2 | lagr1 | lagr2 | leadr1 | leadr2 |", + "+-----+-----+-----+-----+-----+-----+-----+-----+-------+-------+-------------+-------------+------+------+-------+-------+------+------+------+------+-------+-------+--------+--------+", + "| 265 | 265 | 296 | 296 | 277 | 277 | 100 | 100 | 100 | 100 | 100 | 100 | 293 | 290 | 293 | 1004 | 296 | 296 | 293 | 293 | 1001 | 1002 | 1001 | 284 |", + "| 261 | 261 | 296 | 296 | 276 | 276 | 99 | 99 | 99 | 99 | 99 | 99 | 290 | 287 | 290 | 1004 | 296 | 296 | 290 | 290 | 296 | 1002 | 296 | 279 |", + "| 260 | 260 | 293 | 293 | 271 | 271 | 98 | 98 | 98 | 98 | 98 | 98 | 287 | 284 
| 287 | 1004 | 296 | 296 | 287 | 287 | 293 | 296 | 293 | 277 |", + "| 259 | 259 | 290 | 290 | 267 | 267 | 97 | 97 | 97 | 97 | 97 | 97 | 284 | 279 | 284 | 1004 | 296 | 296 | 284 | 284 | 290 | 293 | 290 | 276 |", + "| 254 | 254 | 287 | 287 | 265 | 265 | 96 | 96 | 96 | 96 | 96 | 96 | 279 | 277 | 279 | 296 | 296 | 296 | 279 | 279 | 287 | 290 | 287 | 271 |", + "+-----+-----+-----+-----+-----+-----+-----+-----+-------+-------+-------------+-------------+------+------+-------+-------+------+------+------+------+-------+-------+--------+--------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 4541bde818c35..e96d48108f149 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -151,7 +151,7 @@ impl PartitionEvaluator for RankEvaluator { if self.state.last_rank_data.is_empty() { self.state.last_rank_data = last_rank_data; self.state.last_rank_boundary = state.offset_pruned_rows + cur_chunk.start; - self.state.n_rank = sort_partition_points.len(); + self.state.n_rank = chunk_idx + 1; } else if self.state.last_rank_data != last_rank_data { self.state.last_rank_data = last_rank_data; self.state.last_rank_boundary = state.offset_pruned_rows + cur_chunk.start; From 25af93c8930593db509950f40f7f98782952ea17 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 26 Dec 2022 15:30:28 +0300 Subject: [PATCH 27/50] Update tests --- datafusion/core/tests/sql/window.rs | 113 ++++++++++------------------ 1 file changed, 39 insertions(+), 74 deletions(-) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 0a8334efa55d3..d4a6436a85c58 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -2347,73 +2347,38 @@ fn write_test_data_to_parquet(tmpdir: &TempDir, n_file: usize) -> Result<()> { let ts_field = Field::new("ts", DataType::Int32, 
false); let inc_field = Field::new("inc_col", DataType::Int32, false); let desc_field = Field::new("desc_col", DataType::Int32, false); - let equal_field = Field::new("equal", DataType::Int32, false); - let nonmonothonic_inc_field = Field::new("nonmonothonic_inc", DataType::Int32, false); - let unsorted_non_unique_field = - Field::new("unsorted_non_unique_field", DataType::Int32, false); - - let schema = Arc::new(Schema::new(vec![ - ts_field, - inc_field, - desc_field, - equal_field, - nonmonothonic_inc_field, - unsorted_non_unique_field, - ])); + let schema = Arc::new(Schema::new(vec![ts_field, inc_field, desc_field])); let batch = RecordBatch::try_new( schema, vec![ Arc::new(Int32Array::from_slice([ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, - 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, - 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, - 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, - 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, - 93, 94, 95, 96, 97, 98, 99, 100, - ])), - Arc::new(Int32Array::from_slice([ - 1, 2, 6, 10, 14, 15, 19, 20, 21, 22, 24, 29, 34, 38, 41, 45, 48, 50, 51, - 56, 61, 63, 65, 67, 71, 72, 75, 76, 80, 84, 88, 93, 96, 97, 98, 101, 105, - 110, 115, 118, 123, 124, 129, 132, 137, 141, 144, 149, 151, 152, 153, - 155, 157, 161, 165, 169, 173, 178, 180, 184, 187, 188, 189, 191, 193, - 197, 200, 203, 208, 210, 215, 218, 222, 227, 228, 232, 234, 236, 240, - 242, 244, 245, 249, 252, 253, 254, 259, 260, 261, 265, 267, 271, 276, - 277, 279, 284, 287, 290, 293, 296, - ])), - Arc::new(Int32Array::from_slice([ - 100, 99, 95, 91, 86, 83, 81, 78, 75, 70, 65, 63, 59, 56, 53, 49, 48, 45, - 43, 42, 38, 35, 30, 25, 22, 21, 16, 11, 9, 5, 0, -5, -6, -9, -10, -15, - -20, -24, -27, -28, -31, -32, -34, -36, -39, -42, -46, -49, -54, -56, - -61, -64, -69, -70, -75, -78, -79, -84, -86, -87, -90, -93, -96, -97, - -99, -100, -104, -105, -109, 
-112, -117, -118, -119, -121, -122, -126, - -130, -132, -133, -134, -138, -139, -142, -143, -147, -149, -153, -155, - -160, -162, -163, -165, -168, -171, -172, -174, -178, -183, -187, -191, - ])), - Arc::new(Int32Array::from_slice([ - 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, - 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, - 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, - 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, - 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, - 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, - 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, - 100, 100, + 1, 1, 5, 9, 10, 11, 16, 21, 22, 26, 26, 28, 31, 33, 38, 42, 47, 51, 53, + 53, 58, 63, 67, 68, 70, 72, 72, 76, 81, 85, 86, 88, 91, 96, 97, 98, 100, + 101, 102, 104, 104, 108, 112, 113, 113, 114, 114, 117, 122, 126, 131, + 131, 136, 136, 136, 139, 141, 146, 147, 147, 152, 154, 159, 161, 163, + 164, 167, 172, 173, 177, 180, 185, 186, 191, 195, 195, 199, 203, 207, + 210, 213, 218, 221, 224, 226, 230, 232, 235, 238, 238, 239, 244, 245, + 247, 250, 254, 258, 262, 264, 264, ])), Arc::new(Int32Array::from_slice([ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, - 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, - 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, + 1, 5, 10, 15, 20, 21, 26, 29, 30, 33, 37, 40, 43, 44, 45, 49, 51, 53, 58, + 61, 65, 70, 75, 78, 83, 88, 90, 91, 95, 97, 100, 105, 109, 111, 115, 119, + 120, 124, 126, 129, 131, 135, 140, 143, 144, 147, 148, 149, 151, 155, + 156, 159, 160, 163, 165, 170, 172, 177, 181, 182, 186, 187, 192, 196, + 197, 199, 203, 207, 209, 213, 214, 216, 219, 221, 222, 225, 226, 231, + 236, 237, 242, 245, 247, 248, 253, 254, 259, 261, 266, 
269, 272, 275, + 278, 283, 286, 289, 291, 296, 301, 305, ])), Arc::new(Int32Array::from_slice([ - 2, 1, 3, 1, 2, 5, 5, 3, 3, 3, 5, 1, 4, 2, 2, 5, 0, 3, 5, 2, 5, 3, 5, 5, - 3, 1, 0, 0, 5, 1, 2, 4, 5, 2, 4, 1, 0, 5, 2, 2, 4, 3, 5, 3, 4, 4, 1, 1, - 5, 3, 3, 5, 4, 1, 1, 2, 5, 4, 0, 2, 0, 4, 4, 3, 5, 4, 4, 4, 2, 5, 0, 3, - 0, 5, 1, 4, 5, 4, 5, 5, 2, 4, 5, 3, 2, 0, 1, 1, 1, 0, 2, 5, 1, 2, 5, 4, - 4, 2, 0, 4, + 100, 98, 93, 91, 86, 84, 81, 77, 75, 71, 70, 69, 64, 62, 59, 55, 50, 45, + 41, 40, 39, 36, 31, 28, 23, 22, 17, 13, 10, 6, 5, 2, 1, -1, -4, -5, -6, + -8, -12, -16, -17, -19, -24, -25, -29, -34, -37, -42, -47, -48, -49, -53, + -57, -58, -61, -65, -67, -68, -71, -73, -75, -76, -78, -83, -87, -91, + -95, -98, -101, -105, -106, -111, -114, -116, -120, -125, -128, -129, + -134, -139, -142, -143, -146, -150, -154, -158, -163, -168, -172, -176, + -181, -184, -189, -193, -196, -201, -203, -208, -210, -213, ])), ], )?; @@ -2493,7 +2458,7 @@ mod tests { SUM(desc_col) OVER(ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as sum4, COUNT(*) OVER(ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as cnt3 FROM annotated_data - ORDER BY ts DESC + ORDER BY inc_col DESC LIMIT 5 "; @@ -2505,8 +2470,8 @@ mod tests { vec![ "ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, sum3@2 as sum3, cnt1@3 as cnt1, cnt2@4 as cnt2, sumr1@5 as sumr1, sumr2@6 as sumr2, sumr3@7 as sumr3, cntr1@8 as cntr1, cntr2@9 as cntr2, sum4@10 as sum4, cnt3@11 as cnt3]", " GlobalLimitExec: skip=0, fetch=5", - " SortExec: [ts@12 DESC]", - " ProjectionExec: expr=[SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as sum1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@3 as sum2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@4 as sum3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 
FOLLOWING@5 as cnt1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@6 as cnt2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@7 as sumr1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@8 as sumr2, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@9 as sumr3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@10 as cntr1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@11 as cntr2, SUM(annotated_data.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@0 as sum4, COUNT(UInt8(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@1 as cnt3, ts@12 as ts]", + " SortExec: [inc_col@12 DESC]", + " ProjectionExec: expr=[SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as sum1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@3 as sum2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@4 as sum3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@5 as cnt1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@6 as cnt2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@7 as sumr1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@8 as sumr2, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@9 as sumr3, COUNT(UInt8(1)) ORDER BY 
[annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@10 as cntr1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@11 as cntr2, SUM(annotated_data.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@0 as sum4, COUNT(UInt8(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@1 as cnt3, inc_col@13 as inc_col]", " BoundedWindowAggExec: wdw=[SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }]", " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: 
Following(Int32(8)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }]", " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)) }]", @@ -2524,15 +2489,15 @@ mod tests { let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------+-------+------+------+------+-------+-------+-------+-------+-------+-------+------+", - "| sum1 | sum2 | sum3 | cnt1 | cnt2 | sumr1 | sumr2 | sumr3 | cntr1 | cntr2 | sum4 | cnt3 |", - 
"+------+-------+------+------+------+-------+-------+-------+-------+-------+-------+------+", - "| 3085 | -1085 | 589 | 5 | 9 | 1450 | -1589 | -1085 | 3 | 2 | -1589 | 9 |", - "| 3346 | -1256 | 879 | 6 | 10 | 1729 | -1752 | -1256 | 4 | 3 | -1752 | 10 |", - "| 3310 | -1233 | 1166 | 7 | 10 | 1710 | -1723 | -1233 | 5 | 4 | -1723 | 10 |", - "| 3276 | -1211 | 1450 | 8 | 10 | 1693 | -1696 | -1211 | 6 | 5 | -1696 | 10 |", - "| 3240 | -1191 | 1729 | 9 | 10 | 1674 | -1668 | -1191 | 7 | 6 | -1668 | 10 |", - "+------+-------+------+------+------+-------+-------+-------+-------+-------+-------+------+", + "+------+------+------+------+------+-------+-------+-------+-------+-------+-------+------+", + "| sum1 | sum2 | sum3 | cnt1 | cnt2 | sumr1 | sumr2 | sumr3 | cntr1 | cntr2 | sum4 | cnt3 |", + "+------+------+------+------+------+-------+-------+-------+-------+-------+-------+------+", + "| 1482 | -631 | 606 | 3 | 9 | 902 | -834 | -1231 | 3 | 2 | -1797 | 9 |", + "| 1482 | -631 | 902 | 3 | 10 | 902 | -834 | -1424 | 3 | 3 | -1978 | 10 |", + "| 876 | -411 | 1193 | 4 | 10 | 587 | -612 | -1400 | 3 | 4 | -1941 | 10 |", + "| 866 | -404 | 1482 | 5 | 10 | 580 | -600 | -1374 | 4 | 5 | -1903 | 10 |", + "| 1411 | -397 | 1768 | 4 | 10 | 575 | -590 | -1347 | 2 | 6 | -1863 | 10 |", + "+------+------+------+------+------+-------+-------+-------+-------+-------+-------+------+", ]; assert_batches_eq!(expected, &actual); Ok(()) @@ -2602,11 +2567,11 @@ mod tests { "+-----+-----+-----+-----+-----+-----+-----+-----+-------+-------+-------------+-------------+------+------+-------+-------+------+------+------+------+-------+-------+--------+--------+", "| fv1 | fv2 | lv1 | lv2 | nv1 | nv2 | rn1 | rn2 | rank1 | rank2 | dense_rank1 | dense_rank2 | lag1 | lag2 | lead1 | lead2 | fvr1 | fvr2 | lvr1 | lvr2 | lagr1 | lagr2 | leadr1 | leadr2 |", 
"+-----+-----+-----+-----+-----+-----+-----+-----+-------+-------+-------------+-------------+------+------+-------+-------+------+------+------+------+-------+-------+--------+--------+", - "| 265 | 265 | 296 | 296 | 277 | 277 | 100 | 100 | 100 | 100 | 100 | 100 | 293 | 290 | 293 | 1004 | 296 | 296 | 293 | 293 | 1001 | 1002 | 1001 | 284 |", - "| 261 | 261 | 296 | 296 | 276 | 276 | 99 | 99 | 99 | 99 | 99 | 99 | 290 | 287 | 290 | 1004 | 296 | 296 | 290 | 290 | 296 | 1002 | 296 | 279 |", - "| 260 | 260 | 293 | 293 | 271 | 271 | 98 | 98 | 98 | 98 | 98 | 98 | 287 | 284 | 287 | 1004 | 296 | 296 | 287 | 287 | 293 | 296 | 293 | 277 |", - "| 259 | 259 | 290 | 290 | 267 | 267 | 97 | 97 | 97 | 97 | 97 | 97 | 284 | 279 | 284 | 1004 | 296 | 296 | 284 | 284 | 290 | 293 | 290 | 276 |", - "| 254 | 254 | 287 | 287 | 265 | 265 | 96 | 96 | 96 | 96 | 96 | 96 | 279 | 277 | 279 | 296 | 296 | 296 | 279 | 279 | 287 | 290 | 287 | 271 |", + "| 289 | 266 | 305 | 305 | 305 | 278 | 99 | 99 | 99 | 99 | 86 | 86 | 296 | 291 | 296 | 1004 | 305 | 305 | 301 | 296 | 305 | 1002 | 305 | 286 |", + "| 289 | 269 | 305 | 305 | 305 | 283 | 100 | 100 | 99 | 99 | 86 | 86 | 301 | 296 | 301 | 1004 | 305 | 305 | 301 | 301 | 1001 | 1002 | 1001 | 289 |", + "| 289 | 261 | 296 | 301 | | 275 | 98 | 98 | 98 | 98 | 85 | 85 | 291 | 289 | 291 | 1004 | 305 | 305 | 296 | 291 | 301 | 305 | 301 | 283 |", + "| 286 | 259 | 291 | 296 | | 272 | 97 | 97 | 97 | 97 | 84 | 84 | 289 | 286 | 289 | 1004 | 305 | 305 | 291 | 289 | 296 | 301 | 296 | 278 |", + "| 275 | 254 | 289 | 291 | 289 | 269 | 96 | 96 | 96 | 96 | 83 | 83 | 286 | 283 | 286 | 305 | 305 | 305 | 289 | 286 | 291 | 296 | 291 | 275 |", "+-----+-----+-----+-----+-----+-----+-----+-----+-------+-------+-------------+-------------+------+------+-------+-------+------+------+------+------+-------+-------+--------+--------+", ]; assert_batches_eq!(expected, &actual); From 2b2b376077d9e43d1c9bdbcc6d8bdb3ae918a151 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 26 Dec 2022 
15:56:25 +0300 Subject: [PATCH 28/50] add support for bounded min max --- datafusion/core/tests/sql/window.rs | 40 ++++++++++++------- .../physical-expr/src/aggregate/min_max.rs | 20 +++++++++- .../src/window/sliding_aggregate.rs | 2 +- 3 files changed, 45 insertions(+), 17 deletions(-) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index d4a6436a85c58..cf64601cd3645 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -2448,11 +2448,23 @@ mod tests { SUM(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as sum1, SUM(desc_col) OVER(ORDER BY ts RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING) as sum2, SUM(inc_col) OVER(ORDER BY ts ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING) as sum3, + MIN(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as min1, + MIN(desc_col) OVER(ORDER BY ts RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING) as min2, + MIN(inc_col) OVER(ORDER BY ts ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING) as min3, + MAX(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as max1, + MAX(desc_col) OVER(ORDER BY ts RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING) as max2, + MAX(inc_col) OVER(ORDER BY ts ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING) as max3, COUNT(*) OVER(ORDER BY ts RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING) as cnt1, COUNT(*) OVER(ORDER BY ts ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as cnt2, SUM(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING) as sumr1, SUM(desc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING) as sumr2, SUM(desc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sumr3, + MIN(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as minr1, + MIN(desc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING) as minr2, + MIN(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING) as minr3, + MAX(inc_col) 
OVER(ORDER BY ts DESC RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as maxr1, + MAX(desc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING) as maxr2, + MAX(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING) as maxr3, COUNT(*) OVER(ORDER BY ts DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) as cntr1, COUNT(*) OVER(ORDER BY ts DESC ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as cntr2, SUM(desc_col) OVER(ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as sum4, @@ -2468,13 +2480,13 @@ mod tests { let formatted = displayable(physical_plan.as_ref()).indent().to_string(); let expected = { vec![ - "ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, sum3@2 as sum3, cnt1@3 as cnt1, cnt2@4 as cnt2, sumr1@5 as sumr1, sumr2@6 as sumr2, sumr3@7 as sumr3, cntr1@8 as cntr1, cntr2@9 as cntr2, sum4@10 as sum4, cnt3@11 as cnt3]", + "ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, sum3@2 as sum3, min1@3 as min1, min2@4 as min2, min3@5 as min3, max1@6 as max1, max2@7 as max2, max3@8 as max3, cnt1@9 as cnt1, cnt2@10 as cnt2, sumr1@11 as sumr1, sumr2@12 as sumr2, sumr3@13 as sumr3, minr1@14 as minr1, minr2@15 as minr2, minr3@16 as minr3, maxr1@17 as maxr1, maxr2@18 as maxr2, maxr3@19 as maxr3, cntr1@20 as cntr1, cntr2@21 as cntr2, sum4@22 as sum4, cnt3@23 as cnt3]", " GlobalLimitExec: skip=0, fetch=5", - " SortExec: [inc_col@12 DESC]", - " ProjectionExec: expr=[SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as sum1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@3 as sum2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@4 as sum3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@5 as cnt1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@6 as 
cnt2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@7 as sumr1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@8 as sumr2, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@9 as sumr3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@10 as cntr1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@11 as cntr2, SUM(annotated_data.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@0 as sum4, COUNT(UInt8(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@1 as cnt3, inc_col@13 as inc_col]", + " SortExec: [inc_col@24 DESC]", + " ProjectionExec: expr=[SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as sum1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@3 as sum2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@4 as sum3, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as min1, MIN(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@6 as min2, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@7 as min3, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as max1, MAX(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as max2, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as max3, 
COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@11 as cnt1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@12 as cnt2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@13 as sumr1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@14 as sumr2, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@15 as sumr3, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as minr1, MIN(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as minr2, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as minr3, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as maxr1, MAX(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as maxr2, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as maxr3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@22 as cntr1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@23 as cntr2, SUM(annotated_data.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@0 as sum4, COUNT(UInt8(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@1 as cnt3, inc_col@25 as inc_col]", " BoundedWindowAggExec: wdw=[SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }]", - " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(8)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }]", - " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)) }, 
SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)) }]", + " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, 
dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, MIN(annotated_data.desc_col): Ok(Field { name: \"MIN(annotated_data.desc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, MAX(annotated_data.inc_col): Ok(Field { name: \"MAX(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, MAX(annotated_data.desc_col): Ok(Field { name: \"MAX(annotated_data.desc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, MAX(annotated_data.inc_col): Ok(Field { name: \"MAX(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(8)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) 
}]", + " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, MIN(annotated_data.desc_col): Ok(Field { name: \"MIN(annotated_data.desc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, MAX(annotated_data.inc_col): Ok(Field { name: \"MAX(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, MAX(annotated_data.desc_col): Ok(Field { name: 
\"MAX(annotated_data.desc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)) }, MAX(annotated_data.inc_col): Ok(Field { name: \"MAX(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)) }]", ] }; @@ -2489,15 +2501,15 @@ mod tests { let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+------+------+------+------+------+-------+-------+-------+-------+-------+-------+------+", - "| sum1 | sum2 | sum3 | cnt1 | cnt2 | sumr1 | sumr2 | sumr3 | cntr1 | cntr2 | sum4 | cnt3 |", - "+------+------+------+------+------+-------+-------+-------+-------+-------+-------+------+", - "| 1482 | -631 | 606 | 3 | 9 | 902 | -834 | -1231 | 3 | 2 | -1797 | 9 |", - "| 1482 | -631 | 902 | 3 | 10 | 902 | -834 | -1424 | 3 | 3 | -1978 | 10 |", - "| 876 | -411 | 1193 | 4 | 10 | 587 | -612 | -1400 | 3 | 4 | -1941 | 10 |", - "| 866 | -404 | 1482 | 5 | 10 | 580 | -600 | -1374 | 4 | 5 | -1903 | 10 |", - "| 1411 | -397 | 1768 | 4 | 10 | 575 | -590 | -1347 | 2 | 6 | -1863 | 10 |", - "+------+------+------+------+------+-------+-------+-------+-------+-------+-------+------+", + 
"+------+------+------+------+------+------+------+------+------+------+------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+------+", + "| sum1 | sum2 | sum3 | min1 | min2 | min3 | max1 | max2 | max3 | cnt1 | cnt2 | sumr1 | sumr2 | sumr3 | minr1 | minr2 | minr3 | maxr1 | maxr2 | maxr3 | cntr1 | cntr2 | sum4 | cnt3 |", + "+------+------+------+------+------+------+------+------+------+------+------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+------+", + "| 1482 | -631 | 606 | 289 | -213 | 301 | 305 | -208 | 305 | 3 | 9 | 902 | -834 | -1231 | 301 | -213 | 269 | 305 | -210 | 305 | 3 | 2 | -1797 | 9 |", + "| 1482 | -631 | 902 | 289 | -213 | 296 | 305 | -208 | 305 | 3 | 10 | 902 | -834 | -1424 | 301 | -213 | 266 | 305 | -210 | 305 | 3 | 3 | -1978 | 10 |", + "| 876 | -411 | 1193 | 289 | -208 | 291 | 296 | -203 | 305 | 4 | 10 | 587 | -612 | -1400 | 296 | -213 | 261 | 305 | -208 | 301 | 3 | 4 | -1941 | 10 |", + "| 866 | -404 | 1482 | 286 | -203 | 289 | 291 | -201 | 305 | 5 | 10 | 580 | -600 | -1374 | 291 | -208 | 259 | 305 | -203 | 296 | 4 | 5 | -1903 | 10 |", + "| 1411 | -397 | 1768 | 275 | -201 | 286 | 289 | -196 | 305 | 4 | 10 | 575 | -590 | -1347 | 289 | -203 | 254 | 305 | -201 | 291 | 2 | 6 | -1863 | 10 |", + "+------+------+------+------+------+------+------+------+------+------+------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+------+", ]; assert_batches_eq!(expected, &actual); Ok(()) diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index a7bd6c360a904..bf4fd0868a64a 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -62,7 +62,7 @@ fn min_max_aggregate_data_type(input_type: DataType) -> DataType { } /// MAX aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub 
struct Max { name: String, data_type: DataType, @@ -124,6 +124,10 @@ impl AggregateExpr for Max { is_row_accumulator_support_dtype(&self.data_type) } + fn supports_bounded_execution(&self) -> bool { + true + } + fn create_row_accumulator( &self, start_index: usize, @@ -134,6 +138,10 @@ impl AggregateExpr for Max { ))) } + fn reverse_expr(&self) -> Option> { + Some(Arc::new(self.clone())) + } + fn create_sliding_accumulator(&self) -> Result> { Ok(Box::new(SlidingMaxAccumulator::try_new(&self.data_type)?)) } @@ -672,7 +680,7 @@ impl RowAccumulator for MaxRowAccumulator { } /// MIN aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Min { name: String, data_type: DataType, @@ -734,6 +742,10 @@ impl AggregateExpr for Min { is_row_accumulator_support_dtype(&self.data_type) } + fn supports_bounded_execution(&self) -> bool { + true + } + fn create_row_accumulator( &self, start_index: usize, @@ -744,6 +756,10 @@ impl AggregateExpr for Min { ))) } + fn reverse_expr(&self) -> Option> { + Some(Arc::new(self.clone())) + } + fn create_sliding_accumulator(&self) -> Result> { Ok(Box::new(SlidingMinAccumulator::try_new(&self.data_type)?)) } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index ac32cf010f790..64c7d20063642 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -148,7 +148,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { ) -> Result<()> { for (partition_row, partition_batch_state) in partition_batches.iter() { if !window_agg_state.contains_key(partition_row) { - let accumulator = self.aggregate.create_accumulator()?; + let accumulator = self.aggregate.create_sliding_accumulator()?; let field = self.aggregate.field()?; let out_type = field.data_type(); window_agg_state.insert( From 670fe327b9b92547ca906dac013427f5db28ec49 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 26 
Dec 2022 17:59:30 +0300 Subject: [PATCH 29/50] address reviews --- .../windows/bounded_window_agg_exec.rs | 2 +- datafusion/core/tests/window_fuzz.rs | 37 ++++++++++++++++--- .../physical-expr/src/window/built_in.rs | 2 +- .../src/window/sliding_aggregate.rs | 2 +- .../physical-expr/src/window/window_expr.rs | 4 +- 5 files changed, 37 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 84ba1f9544001..3201d5a3b0ed3 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -468,7 +468,7 @@ impl SortedPartitionByBoundedWindowStream { fn compute_aggregates(&mut self) -> ArrowResult { // calculate window cols for (idx, cur_window_expr) in self.window_expr.iter().enumerate() { - cur_window_expr.evaluate_bounded( + cur_window_expr.evaluate_stateful( &self.partition_batches, &mut self.window_agg_states[idx], )?; diff --git a/datafusion/core/tests/window_fuzz.rs b/datafusion/core/tests/window_fuzz.rs index 106f95b62b8b1..9b8a27cfadf89 100644 --- a/datafusion/core/tests/window_fuzz.rs +++ b/datafusion/core/tests/window_fuzz.rs @@ -21,6 +21,7 @@ use arrow::array::{ArrayRef, Int32Array}; use arrow::compute::{concat_batches, SortOptions}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; +use hashbrown::HashMap; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use tokio::runtime::Builder; @@ -31,7 +32,8 @@ use datafusion::physical_plan::windows::{ create_window_expr, BoundedWindowAggExec, WindowAggExec, }; use datafusion_expr::{ - AggregateFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunction, }; use datafusion::prelude::{SessionConfig, SessionContext}; @@ -123,10 +125,35 
@@ async fn run_window_test( orderby_columns: Vec<&str>, partition_by_columns: Vec<&str>, ) { + let mut func_name_to_window_fn = HashMap::new(); + func_name_to_window_fn.insert( + "sum", + WindowFunction::AggregateFunction(AggregateFunction::Sum), + ); + func_name_to_window_fn.insert( + "count", + WindowFunction::AggregateFunction(AggregateFunction::Count), + ); + func_name_to_window_fn.insert( + "min", + WindowFunction::AggregateFunction(AggregateFunction::Min), + ); + func_name_to_window_fn.insert( + "max", + WindowFunction::AggregateFunction(AggregateFunction::Max), + ); + func_name_to_window_fn.insert( + "row_number", + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber), + ); + let mut rng = StdRng::seed_from_u64(random_seed); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::with_config(session_config); let schema = input1[0].schema(); + let rand_fn_idx = rng.gen_range(0..func_name_to_window_fn.len()); + let fn_name = func_name_to_window_fn.keys().collect::>()[rand_fn_idx]; + let window_fn = func_name_to_window_fn.values().collect::>()[rand_fn_idx]; let preceding = rng.gen_range(0..50); let following = rng.gen_range(0..50); let rand_num = rng.gen_range(0..3); @@ -188,8 +215,8 @@ async fn run_window_test( let usual_window_exec = Arc::new( WindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Sum), - "sum".to_owned(), + window_fn, + fn_name.to_string(), &[col("x", &schema).unwrap()], &partitionby_exprs, &orderby_exprs, @@ -209,8 +236,8 @@ async fn run_window_test( let running_window_exec = Arc::new( BoundedWindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Sum), - "sum".to_owned(), + window_fn, + fn_name.to_string(), &[col("x", &schema).unwrap()], &partitionby_exprs, &orderby_exprs, diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 
5633576e49265..be70f29ceaf47 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -128,7 +128,7 @@ impl WindowExpr for BuiltInWindowExpr { } /// evaluate the window function values against the batch - fn evaluate_bounded( + fn evaluate_stateful( &self, partition_batches: &PartitionBatches, window_agg_state: &mut PartitionWindowAggStates, diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 64c7d20063642..77534c4d8b3ee 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -141,7 +141,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { ScalarValue::iter_to_array(row_wise_results.into_iter()) } - fn evaluate_bounded( + fn evaluate_stateful( &self, partition_batches: &PartitionBatches, window_agg_state: &mut PartitionWindowAggStates, diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 0c78843f5ff1b..394a3abe99878 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -66,13 +66,13 @@ pub trait WindowExpr: Send + Sync + Debug { fn evaluate(&self, batch: &RecordBatch) -> Result; /// evaluate the window function values against the batch - fn evaluate_bounded( + fn evaluate_stateful( &self, _partition_batches: &PartitionBatches, _window_agg_state: &mut PartitionWindowAggStates, ) -> Result<()> { Err(DataFusionError::Internal( - "evaluate_bounded is not implemented".to_string(), + "evaluate_stateful is not implemented".to_string(), )) } From 3ea9eed9e3087c12381a4dd08650681d0f67f1ae Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 27 Dec 2022 09:06:43 +0300 Subject: [PATCH 30/50] rename sort rule --- datafusion/core/src/execution/context.rs | 4 ++-- datafusion/core/src/physical_optimizer/mod.rs | 2 +- 
..._unnecessary_sorts.rs => optimize_sorts.rs} | 18 +++++++++--------- 3 files changed, 12 insertions(+), 12 deletions(-) rename datafusion/core/src/physical_optimizer/{remove_unnecessary_sorts.rs => optimize_sorts.rs} (98%) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index c692bff7b5ffd..324cbd0af165c 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -100,7 +100,7 @@ use url::Url; use crate::catalog::listing_schema::ListingSchemaProvider; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::memory_pool::MemoryPool; -use crate::physical_optimizer::remove_unnecessary_sorts::RemoveUnnecessarySorts; +use crate::physical_optimizer::optimize_sorts::OptimizeSorts; use crate::physical_optimizer::replace_window_with_bounded_impl::ReplaceWindowWithBoundedImpl; use uuid::Uuid; @@ -1609,7 +1609,7 @@ impl SessionState { // However, a deeper analysis may sometimes reveal that such a `SortExec` is actually unnecessary. // These cases typically arise when we have reversible `WindowAggExec`s or deep subqueries. The // rule below performs this analysis and removes unnecessary `SortExec`s. 
- physical_optimizers.push(Arc::new(RemoveUnnecessarySorts::new())); + physical_optimizers.push(Arc::new(OptimizeSorts::new())); // Replace WindowAggExec with BoundedWindowAggExec if conditions are met physical_optimizers.push(Arc::new(ReplaceWindowWithBoundedImpl::new())); diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index fe3bbf9bb8320..357b4665c0728 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -22,9 +22,9 @@ pub mod aggregate_statistics; pub mod coalesce_batches; pub mod enforcement; pub mod join_selection; +pub mod optimize_sorts; pub mod optimizer; pub mod pruning; -pub mod remove_unnecessary_sorts; pub mod repartition; pub mod replace_window_with_bounded_impl; mod utils; diff --git a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs b/datafusion/core/src/physical_optimizer/optimize_sorts.rs similarity index 98% rename from datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs rename to datafusion/core/src/physical_optimizer/optimize_sorts.rs index 593c31ded9a8b..b89a55f09fa2e 100644 --- a/datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs +++ b/datafusion/core/src/physical_optimizer/optimize_sorts.rs @@ -45,16 +45,16 @@ use std::sync::Arc; /// This rule inspects SortExec's in the given physical plan and removes the /// ones it can prove unnecessary. #[derive(Default)] -pub struct RemoveUnnecessarySorts {} +pub struct OptimizeSorts {} -impl RemoveUnnecessarySorts { +impl OptimizeSorts { #[allow(missing_docs)] pub fn new() -> Self { Self {} } } -/// This is a "data class" we use within the [RemoveUnnecessarySorts] rule +/// This is a "data class" we use within the [OptimizeSorts] rule /// that tracks the closest `SortExec` descendant for every child of a plan. 
#[derive(Debug, Clone)] struct PlanWithCorrespondingSort { @@ -118,7 +118,7 @@ impl TreeNodeRewritable for PlanWithCorrespondingSort { } } -impl PhysicalOptimizerRule for RemoveUnnecessarySorts { +impl PhysicalOptimizerRule for OptimizeSorts { fn optimize( &self, plan: Arc, @@ -589,7 +589,7 @@ mod tests { expected, actual ); let optimized_physical_plan = - RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + OptimizeSorts::new().optimize(physical_plan, &conf)?; let formatted = displayable(optimized_physical_plan.as_ref()) .indent() .to_string(); @@ -690,7 +690,7 @@ mod tests { expected, actual ); let optimized_physical_plan = - RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + OptimizeSorts::new().optimize(physical_plan, &conf)?; let formatted = displayable(optimized_physical_plan.as_ref()) .indent() .to_string(); @@ -736,7 +736,7 @@ mod tests { expected, actual ); let optimized_physical_plan = - RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + OptimizeSorts::new().optimize(physical_plan, &conf)?; let formatted = displayable(optimized_physical_plan.as_ref()) .indent() .to_string(); @@ -803,7 +803,7 @@ mod tests { expected, actual ); let optimized_physical_plan = - RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + OptimizeSorts::new().optimize(physical_plan, &conf)?; let formatted = displayable(optimized_physical_plan.as_ref()) .indent() .to_string(); @@ -865,7 +865,7 @@ mod tests { expected, actual ); let optimized_physical_plan = - RemoveUnnecessarySorts::new().optimize(physical_plan, &conf)?; + OptimizeSorts::new().optimize(physical_plan, &conf)?; let formatted = displayable(optimized_physical_plan.as_ref()) .indent() .to_string(); From ca666e940bf0d1b093c14bc0448eaf6f75b6545e Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 27 Dec 2022 09:33:11 +0300 Subject: [PATCH 31/50] Resolve merge conflicts --- datafusion/core/src/execution/context.rs | 2 +- datafusion/physical-expr/src/window/aggregate.rs | 1 
- datafusion/physical-expr/src/window/built_in.rs | 3 +-- .../src/window/built_in_window_function_expr.rs | 1 - datafusion/physical-expr/src/window/nth_value.rs | 14 -------------- datafusion/physical-expr/src/window/row_number.rs | 1 + 6 files changed, 3 insertions(+), 19 deletions(-) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 382e674cb1df0..780a321bd328d 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -1591,7 +1591,7 @@ impl SessionState { // Replace WindowAggExec with BoundedWindowAggExec if conditions are met physical_optimizers.push(Arc::new(ReplaceWindowWithBoundedImpl::new())); - SessionState { + let mut this = SessionState { session_id, optimizer: Optimizer::new(), physical_optimizers, diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index b8eae2597c518..8e598a0ad3c76 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -168,5 +168,4 @@ impl WindowExpr for AggregateWindowExpr { && !self.window_frame.start_bound.is_unbounded() && !self.window_frame.end_bound.is_unbounded() } - } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index e40509fdcab76..be70f29ceaf47 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -27,7 +27,7 @@ use crate::window::{ PartitionBatches, PartitionWindowAggStates, WindowAggState, WindowState, }; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; -use arrow::compute::SortOptions; +use arrow::compute::{concat, SortOptions}; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::ScalarValue; @@ -241,5 +241,4 @@ impl WindowExpr for BuiltInWindowExpr { || !(self.window_frame.start_bound.is_unbounded() || 
self.window_frame.end_bound.is_unbounded())) } - } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index f5c5a31466ed2..6f41ec599a83c 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -72,5 +72,4 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { fn uses_window_frame(&self) -> bool { false } - } diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index eb6c23afbc098..28a1a08c1810e 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -132,20 +132,6 @@ impl BuiltInWindowFunctionExpr for NthValue { true } - fn reverse_expr(&self) -> Option> { - let reversed_kind = match self.kind { - NthValueKind::First => NthValueKind::Last, - NthValueKind::Last => NthValueKind::First, - NthValueKind::Nth(_) => return None, - }; - Some(Arc::new(Self { - name: self.name.clone(), - expr: self.expr.clone(), - data_type: self.data_type.clone(), - kind: reversed_kind, - })) - } - fn uses_window_frame(&self) -> bool { true } diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index 325d0649a576b..d099e8914d79e 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -25,6 +25,7 @@ use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{Result, ScalarValue}; use std::any::Any; +use std::ops::Range; use std::sync::Arc; /// row_number expression From 73d99c6cf733dfcc385166d9155ac9e4d965d46f Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 27 Dec 2022 10:17:46 +0300 Subject: [PATCH 32/50] refactors --- .../windows/bounded_window_agg_exec.rs | 26 
+++++----- .../physical-expr/src/window/built_in.rs | 2 +- .../physical-expr/src/window/lead_lag.rs | 2 +- .../physical-expr/src/window/nth_value.rs | 2 +- .../src/window/partition_evaluator.rs | 4 +- datafusion/physical-expr/src/window/rank.rs | 2 +- .../physical-expr/src/window/row_number.rs | 2 +- .../src/window/sliding_aggregate.rs | 49 ++++--------------- 8 files changed, 29 insertions(+), 60 deletions(-) diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 3201d5a3b0ed3..e6c80d6603505 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -325,20 +325,20 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { let n_out = self.calculate_n_out_row(); if n_out > 0 { let mut out_columns = vec![]; - for partition_window_agg_states in self.window_agg_states.iter() { - out_columns.push(get_aggregate_result_out_column( - partition_window_agg_states, - n_out, - )?); - } - - let batch_to_show = self - .input_buffer_record_batch - .columns() + self.window_agg_states .iter() - .map(|elem| elem.slice(0, n_out)) - .collect::>(); - out_columns.extend_from_slice(&batch_to_show); + .map(|elem| get_aggregate_result_out_column(elem, n_out)) + .chain( + self.input_buffer_record_batch + .columns() + .iter() + .map(|elem| Ok(elem.slice(0, n_out))), + ) + .map(|elem| { + out_columns.push(elem?); + Ok(()) + }) + .collect::>>()?; Ok(Some(out_columns)) } else { diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index be70f29ceaf47..a43c3151813b5 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -199,7 +199,7 @@ impl WindowExpr for BuiltInWindowExpr { row_wise_results .push(ScalarValue::try_from(self.expr.field()?.data_type())?) 
} else { - let res = evaluator.evaluate_bounded(&values)?; + let res = evaluator.evaluate_stateful(&values)?; row_wise_results.push(res); } last_range = state.window_frame_range.clone(); diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 5f9f690deeaca..9ab257f971eae 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -218,7 +218,7 @@ impl PartitionEvaluator for WindowShiftEvaluator { } } - fn evaluate_bounded(&mut self, values: &[ArrayRef]) -> Result { + fn evaluate_stateful(&mut self, values: &[ArrayRef]) -> Result { let dtype = values[0].data_type(); let idx = self.state.idx as i64 - self.shift_offset; if idx < 0 || idx as usize >= values[0].len() { diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 28a1a08c1810e..c3c3b55d4e88f 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -175,7 +175,7 @@ impl PartitionEvaluator for NthValueEvaluator { Ok(()) } - fn evaluate_bounded(&mut self, values: &[ArrayRef]) -> Result { + fn evaluate_stateful(&mut self, values: &[ArrayRef]) -> Result { self.evaluate_inside_range(values, self.state.range.clone()) } diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index 54681bb0efc0f..723b357de78d6 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -62,9 +62,9 @@ pub trait PartitionEvaluator: Debug + Send + Sync { } /// evaluate window function result inside given range - fn evaluate_bounded(&mut self, _values: &[ArrayRef]) -> Result { + fn evaluate_stateful(&mut self, _values: &[ArrayRef]) -> Result { Err(DataFusionError::NotImplemented( - "evaluate_bounded is not implemented by default".into(), + 
"evaluate_stateful is not implemented by default".into(), )) } diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index e96d48108f149..04fe4403b2c59 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -161,7 +161,7 @@ impl PartitionEvaluator for RankEvaluator { } /// evaluate window function result inside given range - fn evaluate_bounded(&mut self, _values: &[ArrayRef]) -> Result { + fn evaluate_stateful(&mut self, _values: &[ArrayRef]) -> Result { match self.rank_type { RankType::Basic => Ok(ScalarValue::UInt64(Some( self.state.last_rank_boundary as u64 + 1, diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index d099e8914d79e..b9ff05e5cdc85 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -89,7 +89,7 @@ impl PartitionEvaluator for NumRowsEvaluator { } /// evaluate window function result inside given range - fn evaluate_bounded(&mut self, _values: &[ArrayRef]) -> Result { + fn evaluate_stateful(&mut self, _values: &[ArrayRef]) -> Result { let n_row = self.state.n_rows as u64 + 1; self.state.n_rows += 1; Ok(ScalarValue::UInt64(Some(n_row))) diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 77534c4d8b3ee..bc293e2880748 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -95,50 +95,19 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let sort_options: Vec = - self.order_by.iter().map(|o| o.options).collect(); - let mut row_wise_results: Vec = vec![]; - let mut accumulator = self.aggregate.create_sliding_accumulator()?; - let length = batch.num_rows(); - let (values, order_bys) = 
self.get_values_orderbys(batch)?; let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); let mut last_range = Range { start: 0, end: 0 }; - - // We iterate on each row to perform a running calculation. - // First, cur_range is calculated, then it is compared with last_range. - for i in 0..length { - let cur_range = - window_frame_ctx.calculate_range(&order_bys, &sort_options, length, i)?; - let value = if cur_range.start == cur_range.end { - // We produce None if the window is empty. - ScalarValue::try_from(self.aggregate.field()?.data_type())? - } else { - // Accumulate any new rows that have entered the window: - let update_bound = cur_range.end - last_range.end; - if update_bound > 0 { - let update: Vec = values - .iter() - .map(|v| v.slice(last_range.end, update_bound)) - .collect(); - accumulator.update_batch(&update)? - } - // Remove rows that have now left the window: - let retract_bound = cur_range.start - last_range.start; - if retract_bound > 0 { - let retract: Vec = values - .iter() - .map(|v| v.slice(last_range.start, retract_bound)) - .collect(); - accumulator.retract_batch(&retract)? - } - accumulator.evaluate()? 
- }; - row_wise_results.push(value); - last_range = cur_range; - } - ScalarValue::iter_to_array(row_wise_results.into_iter()) + let mut idx = 0; + self.get_result_column( + &mut accumulator, + batch, + &mut window_frame_ctx, + &mut last_range, + &mut idx, + true, + ) } fn evaluate_stateful( From 09c19425b8e3a3128bfee39daa5e2f4940f11964 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 27 Dec 2022 17:06:17 +0300 Subject: [PATCH 33/50] Update fuzzy tests + minor changes --- datafusion/core/tests/window_fuzz.rs | 117 ++++++++++++++---- .../physical-expr/src/window/built_in.rs | 18 +-- 2 files changed, 107 insertions(+), 28 deletions(-) diff --git a/datafusion/core/tests/window_fuzz.rs b/datafusion/core/tests/window_fuzz.rs index 9b8a27cfadf89..7c6a99432bf68 100644 --- a/datafusion/core/tests/window_fuzz.rs +++ b/datafusion/core/tests/window_fuzz.rs @@ -38,7 +38,7 @@ use datafusion_expr::{ use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::ScalarValue; -use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::expressions::{col, lit}; use datafusion_physical_expr::PhysicalSortExpr; use test_utils::add_empty_batches; @@ -118,42 +118,111 @@ mod tests { } /// Perform batch and running window same input -/// and verify two outputs are equal +/// and verify outputs of `WindowAggExec` and `BoundedWindowAggExec` are equal async fn run_window_test( input1: Vec, random_seed: u64, orderby_columns: Vec<&str>, partition_by_columns: Vec<&str>, ) { - let mut func_name_to_window_fn = HashMap::new(); - func_name_to_window_fn.insert( + let mut rng = StdRng::seed_from_u64(random_seed); + let schema = input1[0].schema(); + let mut args = vec![col("x", &schema).unwrap()]; + let mut window_fn_map = HashMap::new(); + // HashMap values consists of tuple first element is WindowFunction, second is additional argument + // window function requires if any. 
For most of the window functions additional argument is empty + window_fn_map.insert( "sum", - WindowFunction::AggregateFunction(AggregateFunction::Sum), + ( + WindowFunction::AggregateFunction(AggregateFunction::Sum), + vec![], + ), ); - func_name_to_window_fn.insert( + window_fn_map.insert( "count", - WindowFunction::AggregateFunction(AggregateFunction::Count), + ( + WindowFunction::AggregateFunction(AggregateFunction::Count), + vec![], + ), ); - func_name_to_window_fn.insert( + window_fn_map.insert( "min", - WindowFunction::AggregateFunction(AggregateFunction::Min), + ( + WindowFunction::AggregateFunction(AggregateFunction::Min), + vec![], + ), ); - func_name_to_window_fn.insert( + window_fn_map.insert( "max", - WindowFunction::AggregateFunction(AggregateFunction::Max), + ( + WindowFunction::AggregateFunction(AggregateFunction::Max), + vec![], + ), ); - func_name_to_window_fn.insert( + window_fn_map.insert( "row_number", - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber), + ( + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber), + vec![], + ), + ); + window_fn_map.insert( + "rank", + ( + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Rank), + vec![], + ), + ); + window_fn_map.insert( + "first_value", + ( + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + vec![], + ), + ); + window_fn_map.insert( + "last_value", + ( + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue), + vec![], + ), + ); + window_fn_map.insert( + "nth_value", + ( + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::NthValue), + vec![lit(ScalarValue::Int64(Some(rng.gen_range(1..10))))], + ), + ); + window_fn_map.insert( + "lead", + ( + WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + vec![ + lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), + lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), + ], + ), + ); + window_fn_map.insert( + "lag", + ( + 
WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag), + vec![ + lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), + lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), + ], + ), ); - let mut rng = StdRng::seed_from_u64(random_seed); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::with_config(session_config); - let schema = input1[0].schema(); - let rand_fn_idx = rng.gen_range(0..func_name_to_window_fn.len()); - let fn_name = func_name_to_window_fn.keys().collect::>()[rand_fn_idx]; - let window_fn = func_name_to_window_fn.values().collect::>()[rand_fn_idx]; + let rand_fn_idx = rng.gen_range(0..window_fn_map.len()); + let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; + let (window_fn, new_args) = window_fn_map.values().collect::>()[rand_fn_idx]; + for new_arg in new_args { + args.push(new_arg.clone()); + } let preceding = rng.gen_range(0..50); let following = rng.gen_range(0..50); let rand_num = rng.gen_range(0..3); @@ -217,7 +286,7 @@ async fn run_window_test( vec![create_window_expr( window_fn, fn_name.to_string(), - &[col("x", &schema).unwrap()], + &args, &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), @@ -238,7 +307,7 @@ async fn run_window_test( vec![create_window_expr( window_fn, fn_name.to_string(), - &[col("x", &schema).unwrap()], + &args, &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), @@ -276,7 +345,13 @@ async fn run_window_test( .zip(&running_formatted_sorted) .enumerate() { - assert_eq!((i, usual_line), (i, running_line)); + assert_eq!( + (i, usual_line), + (i, running_line), + "Inconsistent result for window_fn: {:?}, args:{:?}", + window_fn, + args + ); } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index a43c3151813b5..80618bd221efa 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -165,9 +165,6 @@ impl 
WindowExpr for BuiltInWindowExpr { state.is_end = partition_batch_state.is_end; let num_rows = partition_batch_state.record_batch.num_rows(); - let columns = self.sort_columns(&partition_batch_state.record_batch)?; - let sort_partition_points = - self.evaluate_partition_points(num_rows, &columns)?; let (values, order_bys) = self.get_values_orderbys(&partition_batch_state.record_batch)?; @@ -175,18 +172,25 @@ impl WindowExpr for BuiltInWindowExpr { let mut row_wise_results: Vec = vec![]; let mut last_range = state.window_frame_range.clone(); let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + let sort_partition_points = if evaluator.include_rank() { + let columns = self.sort_columns(&partition_batch_state.record_batch)?; + self.evaluate_partition_points(num_rows, &columns)? + } else { + vec![] + }; for idx in state.last_calculated_index..num_rows { - state.window_frame_range = if !self.expr.uses_window_frame() { - evaluator.get_range(state, num_rows)? - } else { + state.window_frame_range = if self.expr.uses_window_frame() { window_frame_ctx.calculate_range( &order_bys, &sort_options, num_rows, idx, )? + } else { + evaluator.get_range(state, num_rows)? }; evaluator.update_state(state, &order_bys, &sort_partition_points)?; + // exit if range end index is length, need kind of flag to stop if state.window_frame_range.end == num_rows && !partition_batch_state.is_end @@ -197,7 +201,7 @@ impl WindowExpr for BuiltInWindowExpr { if state.window_frame_range.start == state.window_frame_range.end { // We produce None if the window is empty. row_wise_results - .push(ScalarValue::try_from(self.expr.field()?.data_type())?) 
+ .push(ScalarValue::try_from(self.expr.field()?.data_type())?); } else { let res = evaluator.evaluate_stateful(&values)?; row_wise_results.push(res); From 39564d477542f1ea71077fa6a354ac5c89b94755 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 28 Dec 2022 23:35:41 -0500 Subject: [PATCH 34/50] Simplify code and improve comments --- datafusion/core/src/execution/context.rs | 6 +- datafusion/core/src/physical_optimizer/mod.rs | 2 +- ...ed_impl.rs => use_bounded_window_execs.rs} | 26 +- datafusion/core/src/physical_plan/common.rs | 8 +- .../windows/bounded_window_agg_exec.rs | 249 +++++++++--------- datafusion/expr/src/accumulator.rs | 2 +- datafusion/physical-expr/src/aggregate/mod.rs | 4 +- .../physical-expr/src/window/aggregate.rs | 2 +- .../physical-expr/src/window/built_in.rs | 37 ++- .../physical-expr/src/window/lead_lag.rs | 20 +- .../src/window/partition_evaluator.rs | 2 +- datafusion/physical-expr/src/window/rank.rs | 23 +- .../physical-expr/src/window/row_number.rs | 3 +- .../src/window/sliding_aggregate.rs | 21 +- .../physical-expr/src/window/window_expr.rs | 7 +- 15 files changed, 205 insertions(+), 207 deletions(-) rename datafusion/core/src/physical_optimizer/{replace_window_with_bounded_impl.rs => use_bounded_window_execs.rs} (77%) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 780a321bd328d..755517258a55e 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -101,7 +101,7 @@ use crate::catalog::listing_schema::ListingSchemaProvider; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::memory_pool::MemoryPool; use crate::physical_optimizer::optimize_sorts::OptimizeSorts; -use crate::physical_optimizer::replace_window_with_bounded_impl::ReplaceWindowWithBoundedImpl; +use crate::physical_optimizer::use_bounded_window_execs::UseBoundedWindowAggExec; use uuid::Uuid; use super::options::{ @@ -1588,8 +1588,8 @@ impl 
SessionState { // rule below performs this analysis and removes unnecessary `SortExec`s. physical_optimizers.push(Arc::new(OptimizeSorts::new())); - // Replace WindowAggExec with BoundedWindowAggExec if conditions are met - physical_optimizers.push(Arc::new(ReplaceWindowWithBoundedImpl::new())); + // Replace ordinary window executors with bounded-memory variants when possible: + physical_optimizers.push(Arc::new(UseBoundedWindowAggExec::new())); let mut this = SessionState { session_id, diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 357b4665c0728..0ef779653908d 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -26,7 +26,7 @@ pub mod optimize_sorts; pub mod optimizer; pub mod pruning; pub mod repartition; -pub mod replace_window_with_bounded_impl; +pub mod use_bounded_window_execs; mod utils; pub use optimizer::PhysicalOptimizerRule; diff --git a/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs b/datafusion/core/src/physical_optimizer/use_bounded_window_execs.rs similarity index 77% rename from datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs rename to datafusion/core/src/physical_optimizer/use_bounded_window_execs.rs index e7143b1e7bd03..8f6da52e818be 100644 --- a/datafusion/core/src/physical_optimizer/replace_window_with_bounded_impl.rs +++ b/datafusion/core/src/physical_optimizer/use_bounded_window_execs.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! ReplaceWindowWithBoundedImpl optimizer that replaces `WindowAggExec` -//! with `BoundedWindowAggExec` if window_expr can run using `BoundedWindowAggExec` +//! The [UseBoundedWindowAggExec] rule replaces [WindowAggExec]s with +//! [BoundedWindowAggExec]s if the window expression in question is +//! amenable to pipeline-friendly bounded memory execution. 
use crate::physical_plan::windows::BoundedWindowAggExec; use crate::physical_plan::windows::WindowAggExec; @@ -27,18 +28,19 @@ use crate::{ use datafusion_expr::WindowFrameUnits; use std::sync::Arc; -/// Optimizer rule that introduces replaces `WindowAggExec` with `BoundedWindowAggExec` -/// to run executor with bounded memory. +/// This rule checks whether [WindowAggExec]s in the query plan can be +/// replaced with [BoundedWindowAggExec]s, and replaces them whenever possible. #[derive(Default)] -pub struct ReplaceWindowWithBoundedImpl {} +pub struct UseBoundedWindowAggExec {} -impl ReplaceWindowWithBoundedImpl { +impl UseBoundedWindowAggExec { #[allow(missing_docs)] pub fn new() -> Self { Self {} } } -impl PhysicalOptimizerRule for ReplaceWindowWithBoundedImpl { + +impl PhysicalOptimizerRule for UseBoundedWindowAggExec { fn optimize( &self, plan: Arc, @@ -46,18 +48,18 @@ impl PhysicalOptimizerRule for ReplaceWindowWithBoundedImpl { ) -> Result> { plan.transform_up(&|plan| { if let Some(window_agg_exec) = plan.as_any().downcast_ref::() { - let is_contains_groups = + let contains_groups = window_agg_exec.window_expr().iter().any(|window_expr| { matches!( window_expr.get_window_frame().units, WindowFrameUnits::Groups ) }); - let can_run_bounded = window_agg_exec + let uses_bounded_memory = window_agg_exec .window_expr() .iter() - .all(|elem| elem.can_run_bounded()); - if !is_contains_groups && can_run_bounded { + .all(|elem| elem.uses_bounded_memory()); + if !contains_groups && uses_bounded_memory { return Ok(Some(Arc::new(BoundedWindowAggExec::try_new( window_agg_exec.window_expr().to_vec(), window_agg_exec.input().clone(), @@ -72,7 +74,7 @@ impl PhysicalOptimizerRule for ReplaceWindowWithBoundedImpl { } fn name(&self) -> &str { - "ReplaceWindowWithBoundedImpl" + "UseBoundedWindowAggExec" } fn schema_check(&self) -> bool { diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs index b188f79c024ba..df7a3dd6c535a 100644 
--- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -96,10 +96,10 @@ pub async fn collect(stream: SendableRecordBatchStream) -> Result ArrowResult> { diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index e6c80d6603505..a2f218c4669b0 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -16,6 +16,9 @@ // under the License. //! Stream and channel implementations for window function expressions. +//! The executor given here uses bounded memory (does not maintain all +//! the input data seen so far), which makes it appropriate when processing +//! infinite inputs. use crate::error::Result; use crate::execution::context::TaskContext; @@ -46,7 +49,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::physical_plan::common::combine_batches_with_ref; +use crate::physical_plan::common::combine_batches; use datafusion_physical_expr::window::{ PartitionBatchState, PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowAggState, WindowState, @@ -150,15 +153,16 @@ impl ExecutionPlan for BoundedWindowAggExec { /// Get the output partitioning of this plan fn output_partitioning(&self) -> Partitioning { - // because we can have repartitioning using the partition keys - // this would be either 1 or more than 1 depending on the presense of - // repartitioning + // As we can have repartitioning using the partition keys, this can + // be either one or more than one, depending on the presence of + // repartitioning. 
self.input.output_partitioning() } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - // This executor maintains input order, and has required input_ordering filled - // hence output_ordering would be `required_input_ordering` + // This executor maintains input order, and it has a required input + // ordering. Therefore, output_ordering would be the same with + // `required_input_ordering`. self.required_input_ordering()[0] } @@ -269,20 +273,6 @@ impl ExecutionPlan for BoundedWindowAggExec { } } -/// Trait for updating state, calculate results for window functions -/// According to partition by column assumptions Sorted/Unsorted we may have different -/// implementations for these fields -pub trait PartitionByHandler { - /// Method to construct output columns from window_expression results - fn calculate_out_columns(&self) -> Result>>; - /// Given how many rows we emitted as results - /// prune no longer needed sections from the state - fn prune_state(&mut self, n_out: usize) -> Result<()>; - /// method to update record batches for each partition - /// when new record batches are received - fn update_partition_batch(&mut self, record_batch: RecordBatch) -> Result<()>; -} - fn create_schema( input_schema: &Schema, window_expr: &[Arc], @@ -295,36 +285,54 @@ fn create_schema( Ok(Schema::new(fields)) } +/// This trait defines the interface for updating the state and calculating +/// results for window functions. Depending on the partitioning scheme, one +/// may have different implementations for the functions within. +pub trait PartitionByHandler { + /// Constructs output columns from window_expression results. + fn calculate_out_columns(&self) -> Result>>; + /// Prunes the window state to remove any unnecessary information + /// given how many rows we emitted so far. + fn prune_state(&mut self, n_out: usize) -> Result<()>; + /// Updates record batches for each partition when new batches are + /// received. 
+ fn update_partition_batch(&mut self, record_batch: RecordBatch) -> Result<()>; +} + /// stream for window aggregation plan /// assuming partition by column is sorted (or without PARTITION BY expression) pub struct SortedPartitionByBoundedWindowStream { schema: SchemaRef, input: SendableRecordBatchStream, - /// The record_batch executor receives as input (columns needed during aggregate results calculation) + /// The record batch executor receives as input (i.e. the columns needed + /// while calculating aggregation results). input_buffer_record_batch: RecordBatch, - /// we separate `input_buffer_record_batch` according to different partitions (determined by PARTITION BY columns) - /// and store the result record_batches per partition base in the `partition_batches`. - /// This variable is used during result calculation for each window_expression - /// This enables us to use same batch for different window_expressions (without copying) - // We may have keep record_batches for each window expression in the `PartitionWindowAggStates` - // However, this would use more memory (on the order of window_expression number) + /// We separate `input_buffer_record_batch` based on partitions (as + /// determined by PARTITION BY columns) and store them per partition + /// in `partition_batches`. We use this variable when calculating results + /// for each window expression. This enables us to use the same batch for + /// different window expressions without copying. + // Note that we could keep record batches for each window expression in + // `PartitionWindowAggStates`. However, this would use more memory (as + // many times as the number of window expressions). 
partition_batches: PartitionBatches, - /// Each executor can run multiple window expressions given - /// their PARTITION BY and ORDER BY sections are same - /// We keep state of the each window expression inside `window_agg_states` + /// An executor can run multiple window expressions if the PARTITION BY + /// and ORDER BY sections are same. We keep state of the each window + /// expression inside `window_agg_states`. window_agg_states: Vec, finished: bool, window_expr: Vec>, - baseline_metrics: BaselineMetrics, partition_by_sort_keys: Vec, + baseline_metrics: BaselineMetrics, } impl PartitionByHandler for SortedPartitionByBoundedWindowStream { /// This method constructs output columns using the result of each window expression fn calculate_out_columns(&self) -> Result>> { let n_out = self.calculate_n_out_row(); - if n_out > 0 { - let mut out_columns = vec![]; + if n_out == 0 { + Ok(None) + } else { self.window_agg_states .iter() .map(|elem| get_aggregate_result_out_column(elem, n_out)) @@ -334,28 +342,21 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { .iter() .map(|elem| Ok(elem.slice(0, n_out))), ) - .map(|elem| { - out_columns.push(elem?); - Ok(()) - }) - .collect::>>()?; - - Ok(Some(out_columns)) - } else { - Ok(None) + .collect::>>() + .map(Some) } } - /// Prunes sections in the state that are no longer needed - /// for calculating result. Determined by window frame boundaries - // For instance if `n_out` number of rows are calculated, we can remove first `n_out` rows from - // the `self.input_buffer_record_batch` + /// Prunes sections of the state that are no longer needed when calculating + /// results (as determined by window frame boundaries). + // For instance, if `n_out` number of rows are calculated, we can remove + // first `n_out` rows from `self.input_buffer_record_batch`. 
fn prune_state(&mut self, n_out: usize) -> Result<()> { - // Prunes `self.partition_batches` + // Prune `self.partition_batches`: self.prune_partition_batches()?; - // Prunes `self.input_buffer_record_batch` + // Prune `self.input_buffer_record_batch`: self.prune_input_batch(n_out)?; - // Prunes `self.window_agg_states` + // Prune `self.window_agg_states`: self.prune_out_columns(n_out)?; Ok(()) } @@ -380,7 +381,7 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { if let Some(partition_batch_state) = self.partition_batches.get_mut(&partition_row) { - let combined_partition_batch = combine_batches_with_ref( + partition_batch_state.record_batch = combine_batches( &[&partition_batch_state.record_batch, &partition_batch], self.input.schema(), )? @@ -389,7 +390,6 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { "Should contain at least one entry".to_string(), ) })?; - partition_batch_state.record_batch = combined_partition_batch; } else { let partition_batch_state = PartitionBatchState { record_batch: partition_batch, @@ -404,14 +404,13 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { for (idx, (_, partition_batch_state)) in self.partition_batches.iter_mut().enumerate() { - if idx < n_partitions - 1 { - partition_batch_state.is_end = true; - } + partition_batch_state.is_end |= idx < n_partitions - 1; } - if self.input_buffer_record_batch.num_rows() == 0 { - self.input_buffer_record_batch = record_batch; + self.input_buffer_record_batch = if self.input_buffer_record_batch.num_rows() == 0 + { + record_batch } else { - self.input_buffer_record_batch = combine_batches_with_ref( + combine_batches( &[&self.input_buffer_record_batch, &record_batch], self.input.schema(), )? @@ -419,8 +418,8 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { DataFusionError::Execution( "Should contain at least one entry".to_string(), ) - })?; - } + })? 
+ }; Ok(()) } @@ -439,7 +438,7 @@ impl Stream for SortedPartitionByBoundedWindowStream { } impl SortedPartitionByBoundedWindowStream { - /// Create a new WindowAggStream + /// Create a new BoundedWindowAggStream pub fn new( schema: SchemaRef, window_expr: Vec>, @@ -447,10 +446,7 @@ impl SortedPartitionByBoundedWindowStream { baseline_metrics: BaselineMetrics, partition_by_sort_keys: Vec, ) -> Self { - let mut state = vec![]; - for _i in 0..window_expr.len() { - state.push(IndexMap::new()); - } + let state = window_expr.iter().map(|_| IndexMap::new()).collect(); let empty_batch = RecordBatch::new_empty(schema.clone()); Self { schema, @@ -467,20 +463,20 @@ impl SortedPartitionByBoundedWindowStream { fn compute_aggregates(&mut self) -> ArrowResult { // calculate window cols - for (idx, cur_window_expr) in self.window_expr.iter().enumerate() { - cur_window_expr.evaluate_stateful( - &self.partition_batches, - &mut self.window_agg_states[idx], - )?; + for (cur_window_expr, state) in + self.window_expr.iter().zip(&mut self.window_agg_states) + { + cur_window_expr.evaluate_stateful(&self.partition_batches, state)?; } + let schema = self.schema.clone(); let columns_to_show = self.calculate_out_columns()?; if let Some(columns_to_show) = columns_to_show { let n_generated = columns_to_show[0].len(); self.prune_state(n_generated)?; - RecordBatch::try_new(self.schema.clone(), columns_to_show) + RecordBatch::try_new(schema, columns_to_show) } else { - Ok(RecordBatch::new_empty(self.schema.clone())) + Ok(RecordBatch::new_empty(schema)) } } @@ -510,24 +506,25 @@ impl SortedPartitionByBoundedWindowStream { Poll::Ready(Some(result)) } - /// Method to calculate how many rows SortedPartitionByBoundedWindowStream can produce as output + /// Calculates how many rows [SortedPartitionByBoundedWindowStream] + /// can produce as output. 
fn calculate_n_out_row(&self) -> usize { - // different window aggregators may produce results with different rates - // we produce the overall batch result with the same speed as slowest one + // Different window aggregators may produce results with different rates. + // We produce the overall batch result with the same speed as slowest one. self.window_agg_states .iter() .map(|window_agg_state| { - // below variable stores how many elements are generated (can displayed) for current - // window expression + // Store how many elements are generated for the current + // window expression: let mut cur_window_expr_out_result_len = 0; - // We iterate over window_agg_state - // since it is IndexMap, iteration is over insertion order - // Hence we preserve sorting when partition columns are sorted + // We iterate over `window_agg_state`, which is an IndexMap. + // Iterations follow the insertion order, hence we preserve + // sorting when partition columns are sorted. for (_, WindowState { state, .. }) in window_agg_state.iter() { cur_window_expr_out_result_len += state.out_col.len(); - // If we do not generate all results for current partition - // we do not generate result for next partition - // otherwise we will lose input ordering + // If we do not generate all results for the current + // partition, we do not generate results for next + // partition -- otherwise we will lose input ordering. if state.n_row_result_missing > 0 { break; } @@ -539,22 +536,26 @@ impl SortedPartitionByBoundedWindowStream { .unwrap_or(0) } - /// prunes the sections of the record batch (for each partition) - /// we no longer need to calculate window function result + /// Prunes the sections of the record batch (for each partition) + /// that we no longer need to calculate the window function result. 
fn prune_partition_batches(&mut self) -> Result<()> { - // Remove partitions which we know that ended (is_end flag is true) - // Retain method keep the remaining elements in the insertion order - // Hence after removal we still preserve ordering in between partitions + // Remove partitions which we know already ended (is_end flag is true). + // Since the retain method preserves insertion order, we still have + // ordering in between partitions after removal. self.partition_batches .retain(|_, partition_batch_state| !partition_batch_state.is_end); - // `self.partition_batches` data are used by all window expressions - // hence when removing from `self.partition_batches` we need to remove from the earliest range boundary - // among all window expressions. `n_prune_each_partition` fill the earliest range boundary information - // for each partition. By this way we can delete no longer needed sections from the `self.partition_batches`. - // For instance if window frame one uses [10, 20] and window frame 2 uses [5, 15] - // We prune only first 5 elements from corresponding record batch in `self.partition_batches` - // Calculate how many element to prune for each partition_batch + // The data in `self.partition_batches` is used by all window expressions. + // Therefore, when removing from `self.partition_batches`, we need to remove + // from the earliest range boundary among all window expressions. Variable + // `n_prune_each_partition` fill the earliest range boundary information for + // each partition. This way, we can delete the no-longer-needed sections from + // `self.partition_batches`. + // For instance, if window frame one uses [10, 20] and window frame two uses + // [5, 15]; we only prune the first 5 elements from the corresponding record + // batch in `self.partition_batches`. 
+ + // Calculate how many elements to prune for each partition batch let mut n_prune_each_partition: HashMap = HashMap::new(); for window_agg_state in self.window_agg_states.iter_mut() { window_agg_state.retain(|_, WindowState { state, .. }| !state.is_end); @@ -571,19 +572,17 @@ impl SortedPartitionByBoundedWindowStream { } let err = || DataFusionError::Execution("Expects to have partition".to_string()); - // Retracts no longer needed parts during window calculations from partition batch + // Retract no longer needed parts during window calculations from partition batch: for (partition_row, n_prune) in n_prune_each_partition.iter() { let partition_batch_state = self .partition_batches .get_mut(partition_row) .ok_or_else(err)?; - let new_record_batch = partition_batch_state.record_batch.slice( - *n_prune, - partition_batch_state.record_batch.num_rows() - n_prune, - ); - partition_batch_state.record_batch = new_record_batch; + let batch = &partition_batch_state.record_batch; + partition_batch_state.record_batch = + batch.slice(*n_prune, batch.num_rows() - n_prune); - // Update State indices, since we have pruned some rows from the beginning + // Update state indices since we have pruned some rows from the beginning: for window_agg_state in self.window_agg_states.iter_mut() { let window_state = window_agg_state.get_mut(partition_row).ok_or_else(err)?; @@ -600,10 +599,9 @@ impl SortedPartitionByBoundedWindowStream { } /// Prunes the section of the input batch whose aggregate results - /// are calculated and emitted as result + /// are calculated and emitted. 
fn prune_input_batch(&mut self, n_out: usize) -> Result<()> { - let len_batch = self.input_buffer_record_batch.num_rows(); - let n_to_keep = len_batch - n_out; + let n_to_keep = self.input_buffer_record_batch.num_rows() - n_out; let batch_to_keep = self .input_buffer_record_batch .columns() @@ -615,15 +613,17 @@ impl SortedPartitionByBoundedWindowStream { Ok(()) } - /// Prunes emitted parts from WindowAggState `out_col` field + /// Prunes emitted parts from WindowAggState `out_col` field. fn prune_out_columns(&mut self, n_out: usize) -> Result<()> { - // We store generated columns for each window expression in `out_col` field of the `WindowAggState` - // Given how many rows are emitted to output we remove these sections from state + // We store generated columns for each window expression in the `out_col` + // field of `WindowAggState`. Given how many rows are emitted, we remove + // these sections from state. for partition_window_agg_states in self.window_agg_states.iter_mut() { let mut running_length = 0; - // Remove total of `n_out` entries from `out_col` field of `WindowAggState`. Iterates in the - // insertion order. Hence we preserve per partition ordering. Without emitting all results for a partition - // we do not generate result for another partition + // Remove `n_out` entries from the `out_col` field of `WindowAggState`. + // Preserve per partition ordering by iterating in the order of insertion. + // Do not generate a result for a new partition without emitting all results + // for the current partition. 
for ( _, WindowState { @@ -647,7 +647,7 @@ impl SortedPartitionByBoundedWindowStream { pub fn partition_columns(&self, batch: &RecordBatch) -> Result> { self.partition_by_sort_keys .iter() - .map(|elem| elem.evaluate_to_sort_column(batch)) + .map(|e| e.evaluate_to_sort_column(batch)) .collect::>>() } @@ -658,16 +658,16 @@ impl SortedPartitionByBoundedWindowStream { num_rows: usize, partition_columns: &[SortColumn], ) -> Result>> { - if partition_columns.is_empty() { - Ok(vec![Range { + Ok(if partition_columns.is_empty() { + vec![Range { start: 0, end: num_rows, - }]) + }] } else { - Ok(lexicographical_partition_ranges(partition_columns) + lexicographical_partition_ranges(partition_columns) .map_err(DataFusionError::ArrowError)? - .collect::>()) - } + .collect::>() + }) } } @@ -683,7 +683,7 @@ fn get_aggregate_result_out_column( partition_window_agg_states: &PartitionWindowAggStates, len_to_show: usize, ) -> Result { - let mut ret = None; + let mut result = None; let mut running_length = 0; // We assume that iteration order is according to insertion order for ( @@ -696,12 +696,12 @@ fn get_aggregate_result_out_column( { if running_length < len_to_show { let n_to_use = min(len_to_show - running_length, out_col.len()); - running_length += n_to_use; let slice_to_use = out_col.slice(0, n_to_use); - ret = match ret { - Some(ret) => Some(concat(&[&ret, &slice_to_use])?), - None => Some(slice_to_use), - } + result = Some(match result { + Some(arr) => concat(&[&arr, &slice_to_use])?, + None => slice_to_use, + }); + running_length += n_to_use; } else { break; } @@ -712,5 +712,6 @@ fn get_aggregate_result_out_column( len_to_show, running_length ))); } - ret.ok_or_else(|| DataFusionError::Execution("Should contain something".to_string())) + result + .ok_or_else(|| DataFusionError::Execution("Should contain something".to_string())) } diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs index 131bd64c0343b..5f23576649867 100644 --- 
a/datafusion/expr/src/accumulator.rs +++ b/datafusion/expr/src/accumulator.rs @@ -88,7 +88,7 @@ pub trait Accumulator: Send + Sync + Debug { fn clone_dyn(&self) -> Result> { Err(DataFusionError::NotImplemented( - "clone_dyn is not implemented by default for this accumulator, to use it in for cloning implement this method".into(), + "clone_dyn is not implemented by default for this accumulator".into(), )) } } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 372823e754f06..065ae7799a48b 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -88,8 +88,8 @@ pub trait AggregateExpr: Send + Sync + Debug { false } - /// Specifies whether this aggregate function can run using bounded memory - /// To be true accumulator should have `retract_batch` implemented + /// Specifies whether this aggregate function can run using bounded memory. + /// Any accumulator returning "true" needs to implement `retract_batch`. 
fn supports_bounded_execution(&self) -> bool { false } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 8e598a0ad3c76..3317e2abdeeff 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -163,7 +163,7 @@ impl WindowExpr for AggregateWindowExpr { }) } - fn can_run_bounded(&self) -> bool { + fn uses_bounded_memory(&self) -> bool { self.aggregate.supports_bounded_execution() && !self.window_frame.start_bound.is_unbounded() && !self.window_frame.end_bound.is_unbounded() diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 80618bd221efa..c918bdda9e5ed 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -127,17 +127,17 @@ impl WindowExpr for BuiltInWindowExpr { } } - /// evaluate the window function values against the batch + /// Evaluate the window function against the batch. This function facilitates + /// stateful, bounded-memory implementations. 
fn evaluate_stateful( &self, partition_batches: &PartitionBatches, window_agg_state: &mut PartitionWindowAggStates, ) -> Result<()> { - let sort_options: Vec = - self.order_by.iter().map(|o| o.options).collect(); + let field = self.expr.field()?; + let out_type = field.data_type(); + let sort_options = self.order_by.iter().map(|o| o.options).collect::>(); for (partition_row, partition_batch_state) in partition_batches.iter() { - let field = self.expr.field()?; - let out_type = field.data_type(); if !window_agg_state.contains_key(partition_row) { let evaluator = self.expr.create_evaluator()?; window_agg_state.insert( @@ -163,13 +163,12 @@ impl WindowExpr for BuiltInWindowExpr { }; let mut state = &mut window_state.state; state.is_end = partition_batch_state.is_end; - let num_rows = partition_batch_state.record_batch.num_rows(); let (values, order_bys) = self.get_values_orderbys(&partition_batch_state.record_batch)?; // We iterate on each row to perform a running calculation. - let mut row_wise_results: Vec = vec![]; + let num_rows = partition_batch_state.record_batch.num_rows(); let mut last_range = state.window_frame_range.clone(); let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); let sort_partition_points = if evaluator.include_rank() { @@ -178,6 +177,7 @@ impl WindowExpr for BuiltInWindowExpr { } else { vec![] }; + let mut row_wise_results: Vec = vec![]; for idx in state.last_calculated_index..num_rows { state.window_frame_range = if self.expr.uses_window_frame() { window_frame_ctx.calculate_range( @@ -185,28 +185,27 @@ impl WindowExpr for BuiltInWindowExpr { &sort_options, num_rows, idx, - )? + ) } else { - evaluator.get_range(state, num_rows)? 
- }; + evaluator.get_range(state, num_rows) + }?; evaluator.update_state(state, &order_bys, &sort_partition_points)?; - // exit if range end index is length, need kind of flag to stop + // Exit if range end index is length, need kind of flag to stop if state.window_frame_range.end == num_rows && !partition_batch_state.is_end { state.window_frame_range = last_range.clone(); break; } - if state.window_frame_range.start == state.window_frame_range.end { + let frame_range = &state.window_frame_range; + row_wise_results.push(if frame_range.start == frame_range.end { // We produce None if the window is empty. - row_wise_results - .push(ScalarValue::try_from(self.expr.field()?.data_type())?); + ScalarValue::try_from(out_type) } else { - let res = evaluator.evaluate_stateful(&values)?; - row_wise_results.push(res); - } - last_range = state.window_frame_range.clone(); + evaluator.evaluate_stateful(&values) + }?); + last_range = frame_range.clone(); state.last_calculated_index = idx + 1; } state.window_frame_range = last_range; @@ -239,7 +238,7 @@ impl WindowExpr for BuiltInWindowExpr { }) } - fn can_run_bounded(&self) -> bool { + fn uses_bounded_memory(&self) -> bool { self.expr.supports_bounded_execution() && (!self.expr.uses_window_frame() || !(self.window_frame.start_bound.is_unbounded() diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 9ab257f971eae..fc815a220af0e 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -199,8 +199,9 @@ impl PartitionEvaluator for WindowShiftEvaluator { fn get_range(&self, state: &WindowAggState, n_rows: usize) -> Result> { if self.shift_offset > 0 { - let start = if state.last_calculated_index > self.shift_offset as usize { - state.last_calculated_index - self.shift_offset as usize + let offset = self.shift_offset as usize; + let start = if state.last_calculated_index > offset { + state.last_calculated_index - 
offset } else { 0 }; @@ -209,8 +210,8 @@ impl PartitionEvaluator for WindowShiftEvaluator { end: state.last_calculated_index + 1, }) } else { - let end = state.last_calculated_index + (-self.shift_offset) as usize; - let end = min(end, n_rows); + let offset = (-self.shift_offset) as usize; + let end = min(state.last_calculated_index + offset, n_rows); Ok(Range { start: state.last_calculated_index, end, @@ -219,12 +220,13 @@ impl PartitionEvaluator for WindowShiftEvaluator { } fn evaluate_stateful(&mut self, values: &[ArrayRef]) -> Result { - let dtype = values[0].data_type(); + let array = &values[0]; + let dtype = array.data_type(); let idx = self.state.idx as i64 - self.shift_offset; - if idx < 0 || idx as usize >= values[0].len() { + if idx < 0 || idx as usize >= array.len() { get_default_value(&self.default_value, dtype) } else { - ScalarValue::try_from_array(&values[0], idx as usize) + ScalarValue::try_from_array(array, idx as usize) } } @@ -239,8 +241,8 @@ fn get_default_value( default_value: &Option, dtype: &DataType, ) -> Result { - if let Some(val) = default_value { - if let ScalarValue::Int64(Some(val)) = val { + if let Some(value) = default_value { + if let ScalarValue::Int64(Some(val)) = value { ScalarValue::try_from_string(val.to_string(), dtype) } else { Err(DataFusionError::Internal( diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index 723b357de78d6..679bfd791128b 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -92,7 +92,7 @@ pub trait PartitionEvaluator: Debug + Send + Sync { fn clone_dyn(&self) -> Result> { Err(DataFusionError::NotImplemented( - "clone_dyn is not implemented by default for this evaluator, to use it in for cloning implement this method".into(), + "clone_dyn is not implemented by default for this evaluator".into(), )) } } diff --git 
a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 04fe4403b2c59..ead9d44535ba1 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -143,19 +143,16 @@ impl PartitionEvaluator for RankEvaluator { && state.last_calculated_index < elem.end }) .ok_or_else(|| DataFusionError::Execution("Expects sort_partition_points to contain state.last_calculated_index".to_string()))?; - let cur_chunk = sort_partition_points[chunk_idx].clone(); - let mut last_rank_data = vec![]; - for column in range_columns { - last_rank_data.push(ScalarValue::try_from_array(column, cur_chunk.end - 1)?) - } - if self.state.last_rank_data.is_empty() { - self.state.last_rank_data = last_rank_data; - self.state.last_rank_boundary = state.offset_pruned_rows + cur_chunk.start; - self.state.n_rank = chunk_idx + 1; - } else if self.state.last_rank_data != last_rank_data { + let chunk = &sort_partition_points[chunk_idx]; + let last_rank_data = range_columns + .iter() + .map(|c| ScalarValue::try_from_array(c, chunk.end - 1)) + .collect::>>()?; + let empty = self.state.last_rank_data.is_empty(); + if empty || self.state.last_rank_data != last_rank_data { self.state.last_rank_data = last_rank_data; - self.state.last_rank_boundary = state.offset_pruned_rows + cur_chunk.start; - self.state.n_rank += 1 + self.state.last_rank_boundary = state.offset_pruned_rows + chunk.start; + self.state.n_rank = 1 + if empty { chunk_idx } else { self.state.n_rank }; } Ok(()) } @@ -168,7 +165,7 @@ impl PartitionEvaluator for RankEvaluator { ))), RankType::Dense => Ok(ScalarValue::UInt64(Some(self.state.n_rank as u64))), RankType::Percent => Err(DataFusionError::Execution( - "Can not execute Percent_RANK in a streaming fashion".to_string(), + "Can not execute PERCENT_RANK in a streaming fashion".to_string(), )), } } diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs 
index b9ff05e5cdc85..c858a5724a202 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -90,9 +90,8 @@ impl PartitionEvaluator for NumRowsEvaluator { /// evaluate window function result inside given range fn evaluate_stateful(&mut self, _values: &[ArrayRef]) -> Result { - let n_row = self.state.n_rows as u64 + 1; self.state.n_rows += 1; - Ok(ScalarValue::UInt64(Some(n_row))) + Ok(ScalarValue::UInt64(Some(self.state.n_rows as u64))) } fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result { diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index bc293e2880748..e7a4bfae154cf 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -115,11 +115,11 @@ impl WindowExpr for SlidingAggregateWindowExpr { partition_batches: &PartitionBatches, window_agg_state: &mut PartitionWindowAggStates, ) -> Result<()> { + let field = self.aggregate.field()?; + let out_type = field.data_type(); for (partition_row, partition_batch_state) in partition_batches.iter() { if !window_agg_state.contains_key(partition_row) { let accumulator = self.aggregate.create_sliding_accumulator()?; - let field = self.aggregate.field()?; - let out_type = field.data_type(); window_agg_state.insert( partition_row.clone(), WindowState { @@ -142,8 +142,6 @@ impl WindowExpr for SlidingAggregateWindowExpr { let mut state = &mut window_state.state; state.is_end = partition_batch_state.is_end; - let num_rows = partition_batch_state.record_batch.num_rows(); - let mut idx = state.last_calculated_index; let mut last_range = state.window_frame_range.clone(); let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); @@ -159,6 +157,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { state.window_frame_range = last_range.clone(); state.out_col = concat(&[&state.out_col, 
&out_col])?; + let num_rows = partition_batch_state.record_batch.num_rows(); state.n_row_result_missing = num_rows - state.last_calculated_index; state.window_function_state = @@ -200,7 +199,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { }) } - fn can_run_bounded(&self) -> bool { + fn uses_bounded_memory(&self) -> bool { self.aggregate.supports_bounded_execution() && !self.window_frame.start_bound.is_unbounded() && !self.window_frame.end_bound.is_unbounded() @@ -268,7 +267,7 @@ impl SlidingAggregateWindowExpr { length, *idx, )?; - // exit if range end index is length, need kind of flag to stop + // Exit if range end index is length, need kind of flag to stop if cur_range.end == length && !is_end { break; } @@ -283,12 +282,10 @@ impl SlidingAggregateWindowExpr { last_range.end = cur_range.end; *idx += 1; } - let out_col = if !row_wise_results.is_empty() { - ScalarValue::iter_to_array(row_wise_results.into_iter())? + Ok(if row_wise_results.is_empty() { + ScalarValue::try_from(out_type)?.to_array_of_size(0) } else { - let a = ScalarValue::try_from(out_type)?; - a.to_array_of_size(0) - }; - Ok(out_col) + ScalarValue::iter_to_array(row_wise_results.into_iter())? + }) } } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 394a3abe99878..de060d6b6b9b9 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -131,8 +131,9 @@ pub trait WindowExpr: Send + Sync + Debug { /// Get the window frame of this [WindowExpr]. fn get_window_frame(&self) -> &Arc; - /// get whether can run with bounded executor - fn can_run_bounded(&self) -> bool; + /// Return a flag indicating whether this [WindowExpr] can run with + /// bounded memory. + fn uses_bounded_memory(&self) -> bool; /// Get the reverse expression of this [WindowExpr]. 
fn get_reverse_expr(&self) -> Option>; @@ -267,7 +268,7 @@ pub struct WindowState { } pub type PartitionWindowAggStates = IndexMap; -/// The IndexMap(Ordered HashMap) where record_batch is seperated for each partition +/// The IndexMap (i.e. an ordered HashMap) where record batches are separated for each partition. pub type PartitionBatches = IndexMap; impl WindowAggState { From 8ac384782936e12381041884c172753bfb4e31aa Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 28 Dec 2022 23:53:58 -0500 Subject: [PATCH 35/50] Fix imports, make create_schema more functional --- datafusion/core/src/execution/context.rs | 4 ++-- .../src/physical_optimizer/use_bounded_window_execs.rs | 4 ++-- .../src/physical_plan/windows/bounded_window_agg_exec.rs | 8 ++++---- .../core/src/physical_plan/windows/window_agg_exec.rs | 8 ++++---- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index c1ce4e1725030..23d74d6a6a723 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -101,9 +101,9 @@ use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::memory_pool::MemoryPool; use crate::physical_optimizer::global_sort_selection::GlobalSortSelection; use crate::physical_optimizer::optimize_sorts::OptimizeSorts; -use crate::physical_optimizer::use_bounded_window_execs::UseBoundedWindowAggExec; use crate::physical_optimizer::pipeline_checker::PipelineChecker; use crate::physical_optimizer::pipeline_fixer::PipelineFixer; +use crate::physical_optimizer::use_bounded_window_execs::UseBoundedWindowAggExec; use uuid::Uuid; use super::options::{ @@ -1611,7 +1611,7 @@ impl SessionState { // The rule below performs this analysis and removes unnecessary sorts. 
physical_optimizers.push(Arc::new(OptimizeSorts::new())); // Replace ordinary window executors with bounded-memory variants when possible: - physical_optimizers.push(Arc::new(UseBoundedWindowAggExec::new())); + physical_optimizers.push(Arc::new(UseBoundedWindowAggExec::new())); // The CoalesceBatches rule will not influence the distribution and ordering of the whole // plan tree. Therefore, to avoid influencing other rules, it should be run at last. if config diff --git a/datafusion/core/src/physical_optimizer/use_bounded_window_execs.rs b/datafusion/core/src/physical_optimizer/use_bounded_window_execs.rs index 8f6da52e818be..8110d208233a7 100644 --- a/datafusion/core/src/physical_optimizer/use_bounded_window_execs.rs +++ b/datafusion/core/src/physical_optimizer/use_bounded_window_execs.rs @@ -22,7 +22,7 @@ use crate::physical_plan::windows::BoundedWindowAggExec; use crate::physical_plan::windows::WindowAggExec; use crate::{ - error::Result, physical_optimizer::PhysicalOptimizerRule, + config::ConfigOptions, error::Result, physical_optimizer::PhysicalOptimizerRule, physical_plan::rewrite::TreeNodeRewritable, }; use datafusion_expr::WindowFrameUnits; @@ -44,7 +44,7 @@ impl PhysicalOptimizerRule for UseBoundedWindowAggExec { fn optimize( &self, plan: Arc, - _config: &crate::execution::context::SessionConfig, + _config: &ConfigOptions, ) -> Result> { plan.transform_up(&|plan| { if let Some(window_agg_exec) = plan.as_any().downcast_ref::() { diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index a2f218c4669b0..2d83558ef1d96 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -277,10 +277,10 @@ fn create_schema( input_schema: &Schema, window_expr: &[Arc], ) -> Result { - let mut fields = Vec::with_capacity(input_schema.fields().len() + window_expr.len()); - for 
expr in window_expr { - fields.push(expr.field()?); - } + let mut fields = window_expr + .iter() + .map(|e| e.field()) + .collect::>>()?; fields.extend_from_slice(input_schema.fields()); Ok(Schema::new(fields)) } diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index 23ec2d179f6bf..0dfba118932e0 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -353,10 +353,10 @@ fn create_schema( input_schema: &Schema, window_expr: &[Arc], ) -> Result { - let mut fields = Vec::with_capacity(input_schema.fields().len() + window_expr.len()); - for expr in window_expr { - fields.push(expr.field()?); - } + let mut fields = window_expr + .iter() + .map(|e| e.field()) + .collect::>>()?; fields.extend_from_slice(input_schema.fields()); Ok(Schema::new(fields)) } From 3349edf18b82c545a64681a5cd60621a3b8e3c61 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 29 Dec 2022 09:16:26 +0300 Subject: [PATCH 36/50] address reviews --- .../windows/bounded_window_agg_exec.rs | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 2d83558ef1d96..43727b8cc99bb 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -27,8 +27,8 @@ use crate::physical_plan::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, }; use crate::physical_plan::{ - ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, + Column, ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, + Partitioning, RecordBatchStream, SendableRecordBatchStream, 
Statistics, WindowExpr, }; use arrow::array::Array; use arrow::compute::{concat, lexicographical_partition_ranges, SortColumn}; @@ -54,7 +54,7 @@ use datafusion_physical_expr::window::{ PartitionBatchState, PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowAggState, WindowState, }; -use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; +use datafusion_physical_expr::{EquivalenceProperties, EquivalentClass, PhysicalExpr}; use indexmap::IndexMap; use log::debug; @@ -182,17 +182,36 @@ impl ExecutionPlan for BoundedWindowAggExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - self.input.equivalence_properties() + // Although WindowAggExec does not change the equivalence properties from the input, but can not return the equivalence properties + // from the input directly, need to adjust the column index to align with the new schema. + let window_expr_len = self.window_expr.len(); + let mut new_properties = EquivalenceProperties::new(self.schema()); + let new_eq_classes = self + .input + .equivalence_properties() + .classes() + .iter() + .map(|prop| { + let new_head = Column::new( + prop.head().name(), + window_expr_len + prop.head().index(), + ); + let new_others = prop + .others() + .iter() + .map(|col| Column::new(col.name(), window_expr_len + col.index())) + .collect::>(); + EquivalentClass::new(new_head, new_others) + }) + .collect::>(); + new_properties.extend(new_eq_classes); + new_properties } fn maintains_input_order(&self) -> bool { true } - fn relies_on_input_order(&self) -> bool { - true - } - fn with_new_children( self: Arc, children: Vec>, @@ -267,7 +286,6 @@ impl ExecutionPlan for BoundedWindowAggExec { is_exact: input_stat.is_exact, num_rows: input_stat.num_rows, column_statistics: Some(column_statistics), - // TODO stats: knowing the type of the new columns we can guess the output size total_byte_size: None, } } @@ -348,7 +366,10 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { } /// 
Prunes sections of the state that are no longer needed when calculating - /// results (as determined by window frame boundaries). + /// results (as determined by window frame boundaries and number of results generated). + // For instance, if first `n` (not necessarily same with `n_out`) elements are no longer needed to + // calculate window expression result (outside the window frame boundary) we retract first `n` elements + // from `self.partition_batches` in corresponding partition. // For instance, if `n_out` number of rows are calculated, we can remove // first `n_out` rows from `self.input_buffer_record_batch`. fn prune_state(&mut self, n_out: usize) -> Result<()> { From 701c43ec0dab0ca4d38bd002ae58fbe816f4de1d Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 29 Dec 2022 10:17:48 +0300 Subject: [PATCH 37/50] undo yml change --- .github/workflows/rust.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 9b96c126eb76a..c15615d59b082 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -64,7 +64,7 @@ jobs: - name: Check Cargo.lock for datafusion-cli run: | # If this test fails, try running `cargo update` in the `datafusion-cli` directory - cargo check --manifest-path datafusion-cli/Cargo.toml + cargo check --manifest-path datafusion-cli/Cargo.toml --locked # test the crate linux-test: From 3b523b4475eb9c18b349758ac5bd06b763ca0788 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 29 Dec 2022 10:26:51 +0300 Subject: [PATCH 38/50] minor change to pass from CI --- .github/workflows/rust.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c15615d59b082..e3a58a47abf11 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -64,7 +64,7 @@ jobs: - name: Check Cargo.lock for datafusion-cli run: | # If this test fails, try running `cargo update` in the 
`datafusion-cli` directory - cargo check --manifest-path datafusion-cli/Cargo.toml --locked + cargo check --manifest-path datafusion-cli/Cargo.toml # test the crate linux-test: From 15d416ac805e14e120aaf1db1157aca6a40efbf9 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 29 Dec 2022 16:28:21 +0300 Subject: [PATCH 39/50] resolve merge conflicts --- .../windows/bounded_window_agg_exec.rs | 66 +++++++------------ .../physical_plan/windows/window_agg_exec.rs | 11 +++- datafusion/core/tests/sql/window.rs | 14 ++-- 3 files changed, 40 insertions(+), 51 deletions(-) diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 43727b8cc99bb..4b26ee9054a6b 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -27,8 +27,8 @@ use crate::physical_plan::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, }; use crate::physical_plan::{ - Column, ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, - Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, + ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, }; use arrow::array::Array; use arrow::compute::{concat, lexicographical_partition_ranges, SortColumn}; @@ -54,7 +54,7 @@ use datafusion_physical_expr::window::{ PartitionBatchState, PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowAggState, WindowState, }; -use datafusion_physical_expr::{EquivalenceProperties, EquivalentClass, PhysicalExpr}; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; use indexmap::IndexMap; use log::debug; @@ -159,11 +159,12 @@ impl ExecutionPlan for BoundedWindowAggExec { self.input.output_partitioning() } + fn unbounded_output(&self, children: &[bool]) -> 
Result { + Ok(children[0]) + } + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - // This executor maintains input order, and it has a required input - // ordering. Therefore, output_ordering would be the same with - // `required_input_ordering`. - self.required_input_ordering()[0] + self.input().output_ordering() } fn required_input_ordering(&self) -> Vec> { @@ -173,7 +174,7 @@ impl ExecutionPlan for BoundedWindowAggExec { fn required_input_distribution(&self) -> Vec { if self.partition_keys.is_empty() { - debug!("No partition defined for WindowAggExec!!!"); + debug!("No partition defined for BoundedWindowAggExec!!!"); vec![Distribution::SinglePartition] } else { //TODO support PartitionCollections if there is no common partition columns in the window_expr @@ -182,30 +183,7 @@ impl ExecutionPlan for BoundedWindowAggExec { } fn equivalence_properties(&self) -> EquivalenceProperties { - // Although WindowAggExec does not change the equivalence properties from the input, but can not return the equivalence properties - // from the input directly, need to adjust the column index to align with the new schema. 
- let window_expr_len = self.window_expr.len(); - let mut new_properties = EquivalenceProperties::new(self.schema()); - let new_eq_classes = self - .input - .equivalence_properties() - .classes() - .iter() - .map(|prop| { - let new_head = Column::new( - prop.head().name(), - window_expr_len + prop.head().index(), - ); - let new_others = prop - .others() - .iter() - .map(|col| Column::new(col.name(), window_expr_len + col.index())) - .collect::>(); - EquivalentClass::new(new_head, new_others) - }) - .collect::>(); - new_properties.extend(new_eq_classes); - new_properties + self.input().equivalence_properties() } fn maintains_input_order(&self) -> bool { @@ -276,12 +254,13 @@ impl ExecutionPlan for BoundedWindowAggExec { let win_cols = self.window_expr.len(); let input_cols = self.input_schema.fields().len(); // TODO stats: some windowing function will maintain invariants such as min, max... - let mut column_statistics = vec![ColumnStatistics::default(); win_cols]; + let mut column_statistics = Vec::with_capacity(win_cols + input_cols); if let Some(input_col_stats) = input_stat.column_statistics { column_statistics.extend(input_col_stats); } else { column_statistics.extend(vec![ColumnStatistics::default(); input_cols]); } + column_statistics.extend(vec![ColumnStatistics::default(); win_cols]); Statistics { is_exact: input_stat.is_exact, num_rows: input_stat.num_rows, @@ -295,11 +274,16 @@ fn create_schema( input_schema: &Schema, window_expr: &[Arc], ) -> Result { - let mut fields = window_expr + let mut fields = Vec::with_capacity(input_schema.fields().len() + window_expr.len()); + fields.extend_from_slice(input_schema.fields()); + // append results to the schema + window_expr .iter() - .map(|e| e.field()) + .map(|e| { + fields.push(e.field()?); + Ok(()) + }) .collect::>>()?; - fields.extend_from_slice(input_schema.fields()); Ok(Schema::new(fields)) } @@ -351,14 +335,14 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { if n_out == 0 { Ok(None) } 
else { - self.window_agg_states + self.input_buffer_record_batch + .columns() .iter() - .map(|elem| get_aggregate_result_out_column(elem, n_out)) + .map(|elem| Ok(elem.slice(0, n_out))) .chain( - self.input_buffer_record_batch - .columns() + self.window_agg_states .iter() - .map(|elem| Ok(elem.slice(0, n_out))), + .map(|elem| get_aggregate_result_out_column(elem, n_out)), ) .collect::>>() .map(Some) diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index 2dd0d23d12004..2605219c740a1 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -275,11 +275,16 @@ fn create_schema( input_schema: &Schema, window_expr: &[Arc], ) -> Result { - let mut fields = window_expr + let mut fields = Vec::with_capacity(input_schema.fields().len() + window_expr.len()); + fields.extend_from_slice(input_schema.fields()); + // append results to the schema + window_expr .iter() - .map(|e| e.field()) + .map(|e| { + fields.push(e.field()?); + Ok(()) + }) .collect::>>()?; - fields.extend_from_slice(input_schema.fields()); Ok(Schema::new(fields)) } diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index a0cf2bcda2c6b..afe01b468d664 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1759,11 +1759,11 @@ async fn test_window_partition_by_order_by() -> Result<()> { let formatted = displayable(physical_plan.as_ref()).indent().to_string(); let expected = { vec![ - "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as 
COUNT(UInt8(1))]", + "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(UInt8(1))]", " BoundedWindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", - " SortExec: [c1@1 ASC NULLS LAST,c2@2 ASC NULLS LAST]", + " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 1 }], 2)", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 2)", " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", @@ -1915,7 +1915,7 @@ async fn test_window_agg_sort_non_reversed_plan() -> Result<()> { " RepartitionExec: partitioning=RoundRobinBatch(2)", " GlobalLimitExec: skip=0, fetch=5", " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c9@1 ASC NULLS LAST]", + " SortExec: [c9@0 ASC NULLS LAST]", " BoundedWindowAggExec: wdw=[ROW_NUMBER(): 
Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c9@0 DESC]", ] @@ -1970,7 +1970,7 @@ async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { " RepartitionExec: partitioning=RoundRobinBatch(2)", " GlobalLimitExec: skip=0, fetch=5", " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " SortExec: [c9@4 ASC NULLS LAST,c1@2 ASC NULLS LAST,c2@3 ASC NULLS LAST]", + " SortExec: [c9@2 ASC NULLS LAST,c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c9@2 DESC,c1@0 DESC]", @@ -2505,7 +2505,7 @@ mod tests { "ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, sum3@2 as sum3, min1@3 as min1, min2@4 as min2, min3@5 as min3, max1@6 as max1, max2@7 as max2, max3@8 as max3, cnt1@9 as cnt1, cnt2@10 as cnt2, sumr1@11 as sumr1, sumr2@12 as sumr2, sumr3@13 as sumr3, minr1@14 as minr1, minr2@15 as minr2, minr3@16 as minr3, maxr1@17 as maxr1, maxr2@18 as maxr2, maxr3@19 as maxr3, cntr1@20 as cntr1, cntr2@21 as cntr2, sum4@22 as sum4, cnt3@23 as cnt3]", " 
GlobalLimitExec: skip=0, fetch=5", " SortExec: [inc_col@24 DESC]", - " ProjectionExec: expr=[SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as sum1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@3 as sum2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@4 as sum3, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as min1, MIN(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@6 as min2, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@7 as min3, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as max1, MAX(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as max2, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as max3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@11 as cnt1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@12 as cnt2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@13 as sumr1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@14 as sumr2, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@15 as sumr3, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as minr1, MIN(annotated_data.desc_col) 
ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as minr2, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as minr3, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as maxr1, MAX(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as maxr2, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as maxr3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@22 as cntr1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@23 as cntr2, SUM(annotated_data.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@0 as sum4, COUNT(UInt8(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@1 as cnt3, inc_col@25 as inc_col]", + " ProjectionExec: expr=[SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as sum1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@15 as sum2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as sum3, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as min1, MIN(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@18 as min2, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@19 as min3, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@20 as max1, MAX(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC 
NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@21 as max2, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as max3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@23 as cnt1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@24 as cnt2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@3 as sumr1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@4 as sumr2, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as sumr3, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@6 as minr1, MIN(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@7 as minr2, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as minr3, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as maxr1, MAX(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@10 as maxr2, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@11 as maxr3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@12 as cntr1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@13 as cntr2, SUM(annotated_data.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@25 as sum4, COUNT(UInt8(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@26 as cnt3, inc_col@1 as inc_col]", " 
BoundedWindowAggExec: wdw=[SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }]", " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, MIN(annotated_data.desc_col): Ok(Field { name: \"MIN(annotated_data.desc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, MIN(annotated_data.inc_col): Ok(Field { name: 
\"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, MAX(annotated_data.inc_col): Ok(Field { name: \"MAX(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, MAX(annotated_data.desc_col): Ok(Field { name: \"MAX(annotated_data.desc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, MAX(annotated_data.inc_col): Ok(Field { name: \"MAX(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(8)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }]", " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, MIN(annotated_data.desc_col): Ok(Field { name: \"MIN(annotated_data.desc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, MAX(annotated_data.inc_col): Ok(Field { name: \"MAX(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, MAX(annotated_data.desc_col): Ok(Field { name: \"MAX(annotated_data.desc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)) }, MAX(annotated_data.inc_col): Ok(Field { name: \"MAX(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, 
COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)) }]", @@ -2581,7 +2581,7 @@ mod tests { "ProjectionExec: expr=[fv1@0 as fv1, fv2@1 as fv2, lv1@2 as lv1, lv2@3 as lv2, nv1@4 as nv1, nv2@5 as nv2, rn1@6 as rn1, rn2@7 as rn2, rank1@8 as rank1, rank2@9 as rank2, dense_rank1@10 as dense_rank1, dense_rank2@11 as dense_rank2, lag1@12 as lag1, lag2@13 as lag2, lead1@14 as lead1, lead2@15 as lead2, fvr1@16 as fvr1, fvr2@17 as fvr2, lvr1@18 as lvr1, lvr2@19 as lvr2, lagr1@20 as lagr1, lagr2@21 as lagr2, leadr1@22 as leadr1, leadr2@23 as leadr2]", " GlobalLimitExec: skip=0, fetch=5", " SortExec: [ts@24 DESC]", - " ProjectionExec: expr=[FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@0 as fv1, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@1 as fv2, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lv1, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lv2, NTH_VALUE(annotated_data.inc_col,Int64(5)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as nv1, NTH_VALUE(annotated_data.inc_col,Int64(5)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as nv2, ROW_NUMBER() ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as rn1, 
ROW_NUMBER() ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as rn2, RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as rank1, RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as rank2, DENSE_RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as dense_rank2, LAG(annotated_data.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@12 as lag1, LAG(annotated_data.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lag2, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@14 as lead1, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as lead2, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as fvr1, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as fvr2, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@18 as lvr1, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as lvr2, LAG(annotated_data.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as lagr1, LAG(annotated_data.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as lagr2, 
LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as leadr1, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as leadr2, ts@24 as ts]", + " ProjectionExec: expr=[FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data.inc_col,Int64(5)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data.inc_col,Int64(5)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, ROW_NUMBER() ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, ROW_NUMBER() ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, DENSE_RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, 
LAG(annotated_data.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2, ts@0 as ts]", " BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: \"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: 
\"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, NTH_VALUE(annotated_data.inc_col,Int64(5)): Ok(Field { name: \"NTH_VALUE(annotated_data.inc_col,Int64(5))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, NTH_VALUE(annotated_data.inc_col,Int64(5)): Ok(Field { name: \"NTH_VALUE(annotated_data.inc_col,Int64(5))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, RANK(): Ok(Field { name: \"RANK()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, 
metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, RANK(): Ok(Field { name: \"RANK()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, DENSE_RANK(): Ok(Field { name: \"DENSE_RANK()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, DENSE_RANK(): Ok(Field { name: \"DENSE_RANK()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAG(annotated_data.inc_col,Int64(1),Int64(1001)): Ok(Field { name: \"LAG(annotated_data.inc_col,Int64(1),Int64(1001))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAG(annotated_data.inc_col,Int64(2),Int64(1002)): Ok(Field { name: \"LAG(annotated_data.inc_col,Int64(2),Int64(1002))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)): Ok(Field { name: \"LEAD(annotated_data.inc_col,Int64(-1),Int64(1001))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)): Ok(Field { name: \"LEAD(annotated_data.inc_col,Int64(4),Int64(1004))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: 
WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }]", " BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: \"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: \"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAG(annotated_data.inc_col,Int64(1),Int64(1001)): Ok(Field { name: \"LAG(annotated_data.inc_col,Int64(1),Int64(1001))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAG(annotated_data.inc_col,Int64(2),Int64(1002)): Ok(Field { name: \"LAG(annotated_data.inc_col,Int64(2),Int64(1002))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)): Ok(Field { name: 
\"LEAD(annotated_data.inc_col,Int64(-1),Int64(1001))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)): Ok(Field { name: \"LEAD(annotated_data.inc_col,Int64(4),Int64(1004))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }]", ] From 9ceb13759d6127fa57979f0d8914a1f6e97c2782 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Thu, 29 Dec 2022 18:17:10 +0300 Subject: [PATCH 40/50] rename some members --- .../windows/bounded_window_agg_exec.rs | 52 +++++++++---------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 4b26ee9054a6b..e728c586ac71a 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -308,7 +308,7 @@ pub struct SortedPartitionByBoundedWindowStream { input: SendableRecordBatchStream, /// The record batch executor receives as input (i.e. the columns needed /// while calculating aggregation results). - input_buffer_record_batch: RecordBatch, + input_buffer: RecordBatch, /// We separate `input_buffer_record_batch` based on partitions (as /// determined by PARTITION BY columns) and store them per partition /// in `partition_batches`. We use this variable when calculating results @@ -317,7 +317,7 @@ pub struct SortedPartitionByBoundedWindowStream { // Note that we could keep record batches for each window expression in // `PartitionWindowAggStates`. However, this would use more memory (as // many times as the number of window expressions). 
- partition_batches: PartitionBatches, + partition_buffers: PartitionBatches, /// An executor can run multiple window expressions if the PARTITION BY /// and ORDER BY sections are same. We keep state of the each window /// expression inside `window_agg_states`. @@ -335,7 +335,7 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { if n_out == 0 { Ok(None) } else { - self.input_buffer_record_batch + self.input_buffer .columns() .iter() .map(|elem| Ok(elem.slice(0, n_out))) @@ -384,7 +384,7 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { partition_range.end - partition_range.start, ); if let Some(partition_batch_state) = - self.partition_batches.get_mut(&partition_row) + self.partition_buffers.get_mut(&partition_row) { partition_batch_state.record_batch = combine_batches( &[&partition_batch_state.record_batch, &partition_batch], @@ -400,30 +400,26 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { record_batch: partition_batch, is_end: false, }; - self.partition_batches + self.partition_buffers .insert(partition_row, partition_batch_state); }; } } - let n_partitions = self.partition_batches.len(); + let n_partitions = self.partition_buffers.len(); for (idx, (_, partition_batch_state)) in - self.partition_batches.iter_mut().enumerate() + self.partition_buffers.iter_mut().enumerate() { partition_batch_state.is_end |= idx < n_partitions - 1; } - self.input_buffer_record_batch = if self.input_buffer_record_batch.num_rows() == 0 - { + self.input_buffer = if self.input_buffer.num_rows() == 0 { record_batch } else { - combine_batches( - &[&self.input_buffer_record_batch, &record_batch], - self.input.schema(), - )? - .ok_or_else(|| { - DataFusionError::Execution( - "Should contain at least one entry".to_string(), - ) - })? + combine_batches(&[&self.input_buffer, &record_batch], self.input.schema())? + .ok_or_else(|| { + DataFusionError::Execution( + "Should contain at least one entry".to_string(), + ) + })? 
}; Ok(()) @@ -456,8 +452,8 @@ impl SortedPartitionByBoundedWindowStream { Self { schema, input, - input_buffer_record_batch: empty_batch, - partition_batches: IndexMap::new(), + input_buffer: empty_batch, + partition_buffers: IndexMap::new(), window_agg_states: state, finished: false, window_expr, @@ -471,7 +467,7 @@ impl SortedPartitionByBoundedWindowStream { for (cur_window_expr, state) in self.window_expr.iter().zip(&mut self.window_agg_states) { - cur_window_expr.evaluate_stateful(&self.partition_batches, state)?; + cur_window_expr.evaluate_stateful(&self.partition_buffers, state)?; } let schema = self.schema.clone(); @@ -502,7 +498,7 @@ impl SortedPartitionByBoundedWindowStream { Some(Err(e)) => Err(e), None => { self.finished = true; - for (_, partition_batch_state) in self.partition_batches.iter_mut() { + for (_, partition_batch_state) in self.partition_buffers.iter_mut() { partition_batch_state.is_end = true; } self.compute_aggregates() @@ -547,7 +543,7 @@ impl SortedPartitionByBoundedWindowStream { // Remove partitions which we know already ended (is_end flag is true). // Since the retain method preserves insertion order, we still have // ordering in between partitions after removal. - self.partition_batches + self.partition_buffers .retain(|_, partition_batch_state| !partition_batch_state.is_end); // The data in `self.partition_batches` is used by all window expressions. @@ -580,7 +576,7 @@ impl SortedPartitionByBoundedWindowStream { // Retract no longer needed parts during window calculations from partition batch: for (partition_row, n_prune) in n_prune_each_partition.iter() { let partition_batch_state = self - .partition_batches + .partition_buffers .get_mut(partition_row) .ok_or_else(err)?; let batch = &partition_batch_state.record_batch; @@ -606,15 +602,15 @@ impl SortedPartitionByBoundedWindowStream { /// Prunes the section of the input batch whose aggregate results /// are calculated and emitted. 
fn prune_input_batch(&mut self, n_out: usize) -> Result<()> { - let n_to_keep = self.input_buffer_record_batch.num_rows() - n_out; + let n_to_keep = self.input_buffer.num_rows() - n_out; let batch_to_keep = self - .input_buffer_record_batch + .input_buffer .columns() .iter() .map(|elem| elem.slice(n_out, n_to_keep)) .collect::>(); - self.input_buffer_record_batch = - RecordBatch::try_new(self.input_buffer_record_batch.schema(), batch_to_keep)?; + self.input_buffer = + RecordBatch::try_new(self.input_buffer.schema(), batch_to_keep)?; Ok(()) } From 8b9aa6f343b732b827b255cbc0301fd63267cd55 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Fri, 30 Dec 2022 13:55:48 +0300 Subject: [PATCH 41/50] Move rule to physical planning --- datafusion/core/src/execution/context.rs | 3 - datafusion/core/src/physical_optimizer/mod.rs | 1 - .../src/physical_optimizer/optimize_sorts.rs | 55 +++++++++--- .../physical_optimizer/pipeline_checker.rs | 4 +- .../use_bounded_window_execs.rs | 83 ------------------- datafusion/core/src/physical_plan/planner.rs | 31 +++++-- .../physical-expr/src/window/aggregate.rs | 5 +- .../physical-expr/src/window/built_in.rs | 6 +- .../src/window/sliding_aggregate.rs | 5 +- 9 files changed, 80 insertions(+), 113 deletions(-) delete mode 100644 datafusion/core/src/physical_optimizer/use_bounded_window_execs.rs diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index fd75ef5522561..5203ac1d85af9 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -103,7 +103,6 @@ use crate::physical_optimizer::global_sort_selection::GlobalSortSelection; use crate::physical_optimizer::optimize_sorts::OptimizeSorts; use crate::physical_optimizer::pipeline_checker::PipelineChecker; use crate::physical_optimizer::pipeline_fixer::PipelineFixer; -use crate::physical_optimizer::use_bounded_window_execs::UseBoundedWindowAggExec; use uuid::Uuid; use super::options::{ @@ -1610,8 +1609,6 @@ 
impl SessionState { // These cases typically arise when we have reversible window expressions or deep subqueries. // The rule below performs this analysis and removes unnecessary sorts. physical_optimizers.push(Arc::new(OptimizeSorts::new())); - // Replace ordinary window executors with bounded-memory variants when possible: - physical_optimizers.push(Arc::new(UseBoundedWindowAggExec::new())); // The CoalesceBatches rule will not influence the distribution and ordering of the whole // plan tree. Therefore, to avoid influencing other rules, it should be run at last. if config diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 1aa192d0819f3..fb07d54b99d98 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -28,7 +28,6 @@ pub mod optimizer; pub mod pipeline_checker; pub mod pruning; pub mod repartition; -pub mod use_bounded_window_execs; mod utils; pub mod pipeline_fixer; diff --git a/datafusion/core/src/physical_optimizer/optimize_sorts.rs b/datafusion/core/src/physical_optimizer/optimize_sorts.rs index ed827c14e1360..14044dd68ee46 100644 --- a/datafusion/core/src/physical_optimizer/optimize_sorts.rs +++ b/datafusion/core/src/physical_optimizer/optimize_sorts.rs @@ -33,10 +33,11 @@ use crate::physical_optimizer::utils::{ use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::rewrite::TreeNodeRewritable; use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::windows::WindowAggExec; +use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use arrow::datatypes::SchemaRef; use datafusion_common::{reverse_sort_options, DataFusionError}; +use datafusion_physical_expr::window::WindowExpr; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use itertools::izip; use std::iter::zip; @@ -187,7 +188,23 @@ fn 
optimize_sorts( // For window expressions, we can remove some sorts when we can // calculate the result in reverse: if let Some(res) = analyze_window_sort_removal( - window_agg_exec, + window_agg_exec.window_expr(), + &window_agg_exec.partition_keys, + sort_exec, + sort_onwards, + )? { + return Ok(Some(res)); + } + } else if let Some(bounded_window_agg_exec) = requirements + .plan + .as_any() + .downcast_ref::() + { + // For window expressions, we can remove some sorts when we can + // calculate the result in reverse: + if let Some(res) = analyze_window_sort_removal( + bounded_window_agg_exec.window_expr(), + &bounded_window_agg_exec.partition_keys, sort_exec, sort_onwards, )? { @@ -273,9 +290,10 @@ fn analyze_immediate_sort_removal( Ok(None) } -/// Analyzes a `WindowAggExec` to determine whether it may allow removing a sort. +/// Analyzes a `WindowAggExec` or `BoundedWindowAggExec` to determine whether it may allow removing a sort. fn analyze_window_sort_removal( - window_agg_exec: &WindowAggExec, + window_expr: &[Arc], + partition_keys: &[Arc], sort_exec: &SortExec, sort_onward: &mut Vec<(usize, Arc)>, ) -> Result> { @@ -289,7 +307,6 @@ fn analyze_window_sort_removal( // If there is no physical ordering, there is no way to remove a sort -- immediately return: return Ok(None); }; - let window_expr = window_agg_exec.window_expr(); let (can_skip_sorting, should_reverse) = can_skip_sort( window_expr[0].partition_by(), required_ordering, @@ -308,13 +325,27 @@ fn analyze_window_sort_removal( if let Some(window_expr) = new_window_expr { let new_child = remove_corresponding_sort_from_sub_plan(sort_onward)?; let new_schema = new_child.schema(); - let new_plan = Arc::new(WindowAggExec::try_new( - window_expr, - new_child, - new_schema, - window_agg_exec.partition_keys.clone(), - Some(physical_ordering.to_vec()), - )?); + + let uses_bounded_memory = + window_expr.iter().all(|elem| elem.uses_bounded_memory()); + // If all window exprs can run with bounded memory choose 
bounded window variant + let new_plan = if uses_bounded_memory { + Arc::new(BoundedWindowAggExec::try_new( + window_expr, + new_child, + new_schema, + partition_keys.to_vec(), + Some(physical_ordering.to_vec()), + )?) as _ + } else { + Arc::new(WindowAggExec::try_new( + window_expr, + new_child, + new_schema, + partition_keys.to_vec(), + Some(physical_ordering.to_vec()), + )?) as _ + }; return Ok(Some(PlanWithCorrespondingSort::new(new_plan))); } } diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index c35ef29f20c9b..96f0b0ff6932d 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -301,7 +301,7 @@ mod sql_tests { let case = QueryCase { sql: "SELECT c9, - SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1 + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 FROM test LIMIT 5".to_string(), cases: vec![Arc::new(test1), Arc::new(test2)], @@ -325,7 +325,7 @@ mod sql_tests { let case = QueryCase { sql: "SELECT c9, - SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1 + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 FROM test".to_string(), cases: vec![Arc::new(test1), Arc::new(test2)], error_operator: "Window Error".to_string() diff --git a/datafusion/core/src/physical_optimizer/use_bounded_window_execs.rs b/datafusion/core/src/physical_optimizer/use_bounded_window_execs.rs deleted file mode 100644 index 8110d208233a7..0000000000000 --- a/datafusion/core/src/physical_optimizer/use_bounded_window_execs.rs +++ /dev/null @@ -1,83 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! The [UseBoundedWindowAggExec] rule replaces [WindowAggExec]s with -//! [BoundedWindowAggExec]s if the window expression in question is -//! amenable to pipeline-friendly bounded memory execution. - -use crate::physical_plan::windows::BoundedWindowAggExec; -use crate::physical_plan::windows::WindowAggExec; -use crate::{ - config::ConfigOptions, error::Result, physical_optimizer::PhysicalOptimizerRule, - physical_plan::rewrite::TreeNodeRewritable, -}; -use datafusion_expr::WindowFrameUnits; -use std::sync::Arc; - -/// This rule checks whether [WindowAggExec]s in the query plan can be -/// replaced with [BoundedWindowAggExec]s, and replaces them whenever possible. 
-#[derive(Default)] -pub struct UseBoundedWindowAggExec {} - -impl UseBoundedWindowAggExec { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl PhysicalOptimizerRule for UseBoundedWindowAggExec { - fn optimize( - &self, - plan: Arc, - _config: &ConfigOptions, - ) -> Result> { - plan.transform_up(&|plan| { - if let Some(window_agg_exec) = plan.as_any().downcast_ref::() { - let contains_groups = - window_agg_exec.window_expr().iter().any(|window_expr| { - matches!( - window_expr.get_window_frame().units, - WindowFrameUnits::Groups - ) - }); - let uses_bounded_memory = window_agg_exec - .window_expr() - .iter() - .all(|elem| elem.uses_bounded_memory()); - if !contains_groups && uses_bounded_memory { - return Ok(Some(Arc::new(BoundedWindowAggExec::try_new( - window_agg_exec.window_expr().to_vec(), - window_agg_exec.input().clone(), - window_agg_exec.input().schema(), - window_agg_exec.partition_keys.clone(), - window_agg_exec.sort_keys.clone(), - )?))); - } - } - Ok(None) - }) - } - - fn name(&self) -> &str { - "UseBoundedWindowAggExec" - } - - fn schema_check(&self) -> bool { - true - } -} diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 5b001f01678ec..9cfdb3ee7033b 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -50,7 +50,7 @@ use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::windows::WindowAggExec; +use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{joins::utils as join_utils, Partitioning}; use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; use crate::{ @@ -617,13 +617,28 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - 
Ok(Arc::new(WindowAggExec::try_new( - window_expr, - input_exec, - physical_input_schema, - physical_partition_keys, - physical_sort_keys, - )?)) + let uses_bounded_memory = window_expr + .iter() + .all(|elem| elem.uses_bounded_memory()); + // If all window exprs can run with bounded memory choose bounded window variant + if uses_bounded_memory { + Ok(Arc::new(BoundedWindowAggExec::try_new( + window_expr, + input_exec, + physical_input_schema, + physical_partition_keys, + physical_sort_keys, + )?)) + } + else { + Ok(Arc::new(WindowAggExec::try_new( + window_expr, + input_exec, + physical_input_schema, + physical_partition_keys, + physical_sort_keys, + )?)) + } } LogicalPlan::Aggregate(Aggregate { input, diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 3317e2abdeeff..e9e37dda129a8 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -29,7 +29,7 @@ use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::WindowFrame; +use datafusion_expr::{WindowFrame, WindowFrameUnits}; use crate::window::window_expr::reverse_order_bys; use crate::window::SlidingAggregateWindowExpr; @@ -164,8 +164,11 @@ impl WindowExpr for AggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { + // Currently groups queries cannot use run with bounded memory + let is_group = matches!(self.window_frame.units, WindowFrameUnits::Groups); self.aggregate.supports_bounded_execution() && !self.window_frame.start_bound.is_unbounded() && !self.window_frame.end_bound.is_unbounded() + && !is_group } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index c918bdda9e5ed..1ee90f7946cd1 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -32,7 +32,7 @@ use 
arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::WindowFrame; +use datafusion_expr::{WindowFrame, WindowFrameUnits}; use std::any::Any; use std::sync::Arc; @@ -242,6 +242,8 @@ impl WindowExpr for BuiltInWindowExpr { self.expr.supports_bounded_execution() && (!self.expr.uses_window_frame() || !(self.window_frame.start_bound.is_unbounded() - || self.window_frame.end_bound.is_unbounded())) + || self.window_frame.end_bound.is_unbounded() + // Currently groups queries cannot use run with bounded memory + || matches!(self.window_frame.units, WindowFrameUnits::Groups))) } } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index e7a4bfae154cf..d3ce34ca5488f 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -29,7 +29,7 @@ use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{Accumulator, WindowFrame}; +use datafusion_expr::{Accumulator, WindowFrame, WindowFrameUnits}; use crate::window::window_expr::{reverse_order_bys, WindowFn, WindowFunctionState}; use crate::window::{ @@ -200,9 +200,12 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { + // Currently groups queries cannot use run with bounded memory + let is_group = matches!(self.window_frame.units, WindowFrameUnits::Groups); self.aggregate.supports_bounded_execution() && !self.window_frame.start_bound.is_unbounded() && !self.window_frame.end_bound.is_unbounded() + && !is_group } } From e13d6e0f43857707f1eac2e66aa2be37c6f75b53 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Fri, 30 Dec 2022 09:59:08 -0500 Subject: [PATCH 42/50] Minor stylistic/comment changes --- 
.../src/physical_optimizer/optimize_sorts.rs | 33 +++++++++---------- datafusion/core/src/physical_plan/planner.rs | 20 +++++------ .../physical-expr/src/window/aggregate.rs | 5 ++- .../physical-expr/src/window/built_in.rs | 4 +-- .../src/window/sliding_aggregate.rs | 5 ++- 5 files changed, 32 insertions(+), 35 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/optimize_sorts.rs b/datafusion/core/src/physical_optimizer/optimize_sorts.rs index 14044dd68ee46..a666775b093a0 100644 --- a/datafusion/core/src/physical_optimizer/optimize_sorts.rs +++ b/datafusion/core/src/physical_optimizer/optimize_sorts.rs @@ -182,33 +182,32 @@ fn optimize_sorts( sort_exec.input().equivalence_properties() }) { update_child_to_remove_unnecessary_sort(child, sort_onwards)?; - } else if let Some(window_agg_exec) = + } + // For window expressions, we can remove some sorts when we can + // calculate the result in reverse: + else if let Some(exec) = requirements.plan.as_any().downcast_ref::() { - // For window expressions, we can remove some sorts when we can - // calculate the result in reverse: - if let Some(res) = analyze_window_sort_removal( - window_agg_exec.window_expr(), - &window_agg_exec.partition_keys, + if let Some(result) = analyze_window_sort_removal( + exec.window_expr(), + &exec.partition_keys, sort_exec, sort_onwards, )? { - return Ok(Some(res)); + return Ok(Some(result)); } - } else if let Some(bounded_window_agg_exec) = requirements + } else if let Some(exec) = requirements .plan .as_any() .downcast_ref::() { - // For window expressions, we can remove some sorts when we can - // calculate the result in reverse: - if let Some(res) = analyze_window_sort_removal( - bounded_window_agg_exec.window_expr(), - &bounded_window_agg_exec.partition_keys, + if let Some(result) = analyze_window_sort_removal( + exec.window_expr(), + &exec.partition_keys, sort_exec, sort_onwards, )? 
{ - return Ok(Some(res)); + return Ok(Some(result)); } } // TODO: Once we can ensure that required ordering information propagates with @@ -290,7 +289,8 @@ fn analyze_immediate_sort_removal( Ok(None) } -/// Analyzes a `WindowAggExec` or `BoundedWindowAggExec` to determine whether it may allow removing a sort. +/// Analyzes a [WindowAggExec] or a [BoundedWindowAggExec] to determine whether +/// it may allow removing a sort. fn analyze_window_sort_removal( window_expr: &[Arc], partition_keys: &[Arc], @@ -326,8 +326,7 @@ fn analyze_window_sort_removal( let new_child = remove_corresponding_sort_from_sub_plan(sort_onward)?; let new_schema = new_child.schema(); - let uses_bounded_memory = - window_expr.iter().all(|elem| elem.uses_bounded_memory()); + let uses_bounded_memory = window_expr.iter().all(|e| e.uses_bounded_memory()); // If all window exprs can run with bounded memory choose bounded window variant let new_plan = if uses_bounded_memory { Arc::new(BoundedWindowAggExec::try_new( diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 9cfdb3ee7033b..e03bff04668e9 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -619,26 +619,26 @@ impl DefaultPhysicalPlanner { let uses_bounded_memory = window_expr .iter() - .all(|elem| elem.uses_bounded_memory()); - // If all window exprs can run with bounded memory choose bounded window variant - if uses_bounded_memory { - Ok(Arc::new(BoundedWindowAggExec::try_new( + .all(|e| e.uses_bounded_memory()); + // If all window expressions can run with bounded memory, + // choose the bounded window variant: + Ok(if uses_bounded_memory { + Arc::new(BoundedWindowAggExec::try_new( window_expr, input_exec, physical_input_schema, physical_partition_keys, physical_sort_keys, - )?)) - } - else { - Ok(Arc::new(WindowAggExec::try_new( + )?) 
+ } else { + Arc::new(WindowAggExec::try_new( window_expr, input_exec, physical_input_schema, physical_partition_keys, physical_sort_keys, - )?)) - } + )?) + }) } LogicalPlan::Aggregate(Aggregate { input, diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index e9e37dda129a8..df61e7cc8fbbd 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -164,11 +164,10 @@ impl WindowExpr for AggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { - // Currently groups queries cannot use run with bounded memory - let is_group = matches!(self.window_frame.units, WindowFrameUnits::Groups); + // NOTE: Currently, groups queries do not support the bounded memory variant. self.aggregate.supports_bounded_execution() && !self.window_frame.start_bound.is_unbounded() && !self.window_frame.end_bound.is_unbounded() - && !is_group + && !matches!(self.window_frame.units, WindowFrameUnits::Groups) } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 1ee90f7946cd1..f0484b790fbc6 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -239,11 +239,11 @@ impl WindowExpr for BuiltInWindowExpr { } fn uses_bounded_memory(&self) -> bool { + // NOTE: Currently, groups queries do not support the bounded memory variant. 
self.expr.supports_bounded_execution() && (!self.expr.uses_window_frame() || !(self.window_frame.start_bound.is_unbounded() || self.window_frame.end_bound.is_unbounded() - // Currently groups queries cannot use run with bounded memory - || matches!(self.window_frame.units, WindowFrameUnits::Groups))) + || matches!(self.window_frame.units, WindowFrameUnits::Groups))) } } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index d3ce34ca5488f..587c313e31bd7 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -200,12 +200,11 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn uses_bounded_memory(&self) -> bool { - // Currently groups queries cannot use run with bounded memory - let is_group = matches!(self.window_frame.units, WindowFrameUnits::Groups); + // NOTE: Currently, groups queries do not support the bounded memory variant. 
self.aggregate.supports_bounded_execution() && !self.window_frame.start_bound.is_unbounded() && !self.window_frame.end_bound.is_unbounded() - && !is_group + && !matches!(self.window_frame.units, WindowFrameUnits::Groups) } } From d97a1ad4396f7eea3a704001300fe5f755e90ce8 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Sat, 31 Dec 2022 10:52:34 -0500 Subject: [PATCH 43/50] Simplify batch-merging utility functions --- datafusion/core/src/physical_plan/common.rs | 38 +++++++++++++------ .../windows/bounded_window_agg_exec.rs | 21 +++------- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs index 9a12cb37942b3..f08652f7ca582 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -96,31 +96,45 @@ pub async fn collect(stream: SendableRecordBatchStream) -> Result ArrowResult { + let columns = (0..schema.fields.len()) + .map(|index| { + let first_column = first.column(index).as_ref(); + let second_column = second.column(index).as_ref(); + concat(&[first_column, second_column]) + }) + .collect::>>()?; + RecordBatch::try_new(schema, columns) +} + +/// Merge a slice of record batch references into a single record batch, or /// return None if the slice itself is empty. All the record batches inside the /// slice must have the same schema. 
-pub fn combine_batches( +pub fn merge_multiple_batches( batches: &[&RecordBatch], schema: SchemaRef, ) -> ArrowResult> { - if batches.is_empty() { - Ok(None) + Ok(if batches.is_empty() { + None } else { - let columns = schema - .fields() - .iter() - .enumerate() - .map(|(i, _)| { + let columns = (0..schema.fields.len()) + .map(|index| { concat( &batches .iter() - .map(|batch| batch.column(i).as_ref()) + .map(|batch| batch.column(index).as_ref()) .collect::>(), ) }) .collect::>>()?; - Ok(Some(RecordBatch::try_new(schema.clone(), columns)?)) - } + Some(RecordBatch::try_new(schema, columns)?) + }) } /// Recursively builds a list of files in a directory with a given extension diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index e728c586ac71a..2e3d00dd0a221 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -49,7 +49,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::physical_plan::common::combine_batches; +use crate::physical_plan::common::merge_batches; use datafusion_physical_expr::window::{ PartitionBatchState, PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowAggState, WindowState, @@ -386,15 +386,11 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { if let Some(partition_batch_state) = self.partition_buffers.get_mut(&partition_row) { - partition_batch_state.record_batch = combine_batches( - &[&partition_batch_state.record_batch, &partition_batch], + partition_batch_state.record_batch = merge_batches( + &partition_batch_state.record_batch, + &partition_batch, self.input.schema(), - )? 
- .ok_or_else(|| { - DataFusionError::Execution( - "Should contain at least one entry".to_string(), - ) - })?; + )?; } else { let partition_batch_state = PartitionBatchState { record_batch: partition_batch, @@ -414,12 +410,7 @@ impl PartitionByHandler for SortedPartitionByBoundedWindowStream { self.input_buffer = if self.input_buffer.num_rows() == 0 { record_batch } else { - combine_batches(&[&self.input_buffer, &record_batch], self.input.schema())? - .ok_or_else(|| { - DataFusionError::Execution( - "Should contain at least one entry".to_string(), - ) - })? + merge_batches(&self.input_buffer, &record_batch, self.input.schema())? }; Ok(()) From 29007ea8853555346fec353cbb2fb9d04d3bb637 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Mon, 2 Jan 2023 17:47:20 -0500 Subject: [PATCH 44/50] Remove unnecessary clones, simplify code --- .../windows/bounded_window_agg_exec.rs | 10 +++------- .../physical_plan/windows/window_agg_exec.rs | 10 +++------- datafusion/expr/src/accumulator.rs | 6 ------ .../src/window/partition_evaluator.rs | 6 ------ .../physical-expr/src/window/window_expr.rs | 19 ++++--------------- 5 files changed, 10 insertions(+), 41 deletions(-) diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 2e3d00dd0a221..44c49408d7b24 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -277,13 +277,9 @@ fn create_schema( let mut fields = Vec::with_capacity(input_schema.fields().len() + window_expr.len()); fields.extend_from_slice(input_schema.fields()); // append results to the schema - window_expr - .iter() - .map(|e| { - fields.push(e.field()?); - Ok(()) - }) - .collect::>>()?; + for expr in window_expr { + fields.push(expr.field()?); + } Ok(Schema::new(fields)) } diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs 
b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index 2605219c740a1..bd413ad8eac9a 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -278,13 +278,9 @@ fn create_schema( let mut fields = Vec::with_capacity(input_schema.fields().len() + window_expr.len()); fields.extend_from_slice(input_schema.fields()); // append results to the schema - window_expr - .iter() - .map(|e| { - fields.push(e.field()?); - Ok(()) - }) - .collect::>>()?; + for expr in window_expr { + fields.push(expr.field()?); + } Ok(Schema::new(fields)) } diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr/src/accumulator.rs index 899854c9db7ec..7e941d0cff97f 100644 --- a/datafusion/expr/src/accumulator.rs +++ b/datafusion/expr/src/accumulator.rs @@ -85,10 +85,4 @@ pub trait Accumulator: Send + Sync + Debug { /// Allocated means that for internal containers such as `Vec`, the `capacity` should be used /// not the `len` fn size(&self) -> usize; - - fn clone_dyn(&self) -> Result> { - Err(DataFusionError::NotImplemented( - "clone_dyn is not implemented by default for this accumulator".into(), - )) - } } diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index 679bfd791128b..4238dde09cf8f 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -89,10 +89,4 @@ pub trait PartitionEvaluator: Debug + Send + Sync { "evaluate_inside_range is not implemented by default".into(), )) } - - fn clone_dyn(&self) -> Result> { - Err(DataFusionError::NotImplemented( - "clone_dyn is not implemented by default for this evaluator".into(), - )) - } } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index de060d6b6b9b9..2e1dd86d4592e 100644 --- 
a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -170,17 +170,6 @@ impl fmt::Debug for WindowFn { } } -impl Clone for WindowFn { - fn clone(&self) -> Self { - match self { - WindowFn::Builtin(builtin) => WindowFn::Builtin(builtin.clone_dyn().unwrap()), - WindowFn::Aggregate(aggregate) => { - WindowFn::Aggregate(aggregate.clone_dyn().unwrap()) - } - } - } -} - /// State for RANK(percent_rank, rank, dense_rank) /// builtin window function #[derive(Debug, Clone, Default)] @@ -218,7 +207,7 @@ pub enum BuiltinWindowState { #[default] Default, } -#[derive(Debug, Clone)] +#[derive(Debug)] pub enum WindowFunctionState { /// Different Aggregate functions may have different state definitions /// In [Accumulator] trait, [fn state(&self) -> Result>] implementation @@ -228,7 +217,7 @@ pub enum WindowFunctionState { BuiltinWindowState(BuiltinWindowState), } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct WindowAggState { /// The range that we calculate the window function pub window_frame_range: Range, @@ -248,7 +237,7 @@ pub struct WindowAggState { } /// State for each unique partition determined according to PARTITION BY column(s) -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct PartitionBatchState { /// The record_batch belonging to current partition pub record_batch: RecordBatch, @@ -261,7 +250,7 @@ pub struct PartitionBatchState { /// PartitionKey would consist of unique [a,b] pairs pub type PartitionKey = Vec; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct WindowState { pub state: WindowAggState, pub window_fn: WindowFn, From ac2f2489c7d4f3d1fb33b9587afbe2cad99e9f84 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Tue, 3 Jan 2023 19:02:41 +0300 Subject: [PATCH 45/50] update cargo lock file --- .github/workflows/rust.yml | 3 ++- datafusion-cli/Cargo.lock | 14 ++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml 
index e3a58a47abf11..e25d5c0365664 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -64,7 +64,8 @@ jobs: - name: Check Cargo.lock for datafusion-cli run: | # If this test fails, try running `cargo update` in the `datafusion-cli` directory - cargo check --manifest-path datafusion-cli/Cargo.toml + # then commit updated Cargo.lock file + cargo check --manifest-path datafusion-cli/Cargo.toml --locked # test the crate linux-test: diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index decdaf2574fe1..367728e1c0ae9 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -495,9 +495,9 @@ dependencies = [ [[package]] name = "comfy-table" -version = "6.1.3" +version = "6.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e621e7e86c46fd8a14c32c6ae3cb95656621b4743a27d0cffedb831d46e7ad21" +checksum = "6e7b787b0dc42e8111badfdbe4c3059158ccb2db8780352fa1b01e8ccf45cc4d" dependencies = [ "strum", "strum_macros", @@ -673,6 +673,7 @@ dependencies = [ "futures", "glob", "hashbrown 0.13.1", + "indexmap", "itertools", "lazy_static", "log", @@ -764,6 +765,7 @@ dependencies = [ "datafusion-row", "half", "hashbrown 0.13.1", + "indexmap", "itertools", "lazy_static", "md-5", @@ -1114,9 +1116,9 @@ dependencies = [ [[package]] name = "half" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad6a9459c9c30b177b925162351f97e7d967c7ea8bab3b8352805327daf45554" +checksum = "6c467d36af040b7b2681f5fddd27427f6da8d3d072f575a265e181d2f8e8d157" dependencies = [ "crunchy", "num-traits", @@ -1718,9 +1720,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860" +checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" [[package]] name = "ordered-float" 
From 0ca388990bfade3708f1cf5c4117f207fc2e5164 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Wed, 4 Jan 2023 10:23:56 +0300 Subject: [PATCH 46/50] address reviews --- .../windows/bounded_window_agg_exec.rs | 2 +- datafusion/core/tests/window_fuzz.rs | 106 +++++++----------- .../src/window/partition_evaluator.rs | 2 +- .../physical-expr/src/window/window_expr.rs | 25 ++--- 4 files changed, 49 insertions(+), 86 deletions(-) diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 44c49408d7b24..5ed6a112c82f5 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -128,7 +128,7 @@ impl BoundedWindowAggExec { if let Some(a) = sort_keys.iter().find(|&e| e.expr.eq(item)) { result.push(a.clone()); } else { - return Err(DataFusionError::Execution( + return Err(DataFusionError::Internal( "Partition key not found in sort keys".to_string(), )); } diff --git a/datafusion/core/tests/window_fuzz.rs b/datafusion/core/tests/window_fuzz.rs index 7c6a99432bf68..471484af218d1 100644 --- a/datafusion/core/tests/window_fuzz.rs +++ b/datafusion/core/tests/window_fuzz.rs @@ -24,7 +24,6 @@ use arrow::util::pretty::pretty_format_batches; use hashbrown::HashMap; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; -use tokio::runtime::Builder; use datafusion::physical_plan::collect; use datafusion::physical_plan::memory::MemoryExec; @@ -46,74 +45,48 @@ use test_utils::add_empty_batches; mod tests { use super::*; - #[test] - fn single_order_by_test() { - let rt = Builder::new_multi_thread() - .worker_threads(8) - .build() - .unwrap(); + #[tokio::test(flavor = "multi_thread", worker_threads = 8)] + async fn single_order_by_test() { let n = 100; - let handles_low_cardinality = (1..n).map(|i| { - rt.spawn(run_window_test( - make_staggered_batches::(1000, i), - i, - vec!["a"], - vec![], - 
)) - }); - let handles_high_cardinality = (1..n).map(|i| { - rt.spawn(run_window_test( - make_staggered_batches::(1000, i), - i, - vec!["a"], - vec![], - )) - }); - let handles = handles_low_cardinality - .into_iter() - .chain(handles_high_cardinality.into_iter()) - .collect::>>(); - rt.block_on(async { - for handle in handles { - handle.await.unwrap(); + let distincts = vec![1, 100]; + for distinct in distincts { + let mut handles = Vec::new(); + for i in 1..n { + let job = tokio::spawn(run_window_test( + make_staggered_batches::(1000, distinct, i), + i, + vec!["a"], + vec![], + )); + handles.push(job); } - }); + for job in handles { + job.await.unwrap(); + } + } } - #[test] - fn order_by_with_partition_test() { - let rt = Builder::new_multi_thread() - .worker_threads(8) - .build() - .unwrap(); + #[tokio::test(flavor = "multi_thread", worker_threads = 8)] + async fn order_by_with_partition_test() { let n = 100; - // since we have sorted pairs (a,b) to not violate per partition soring - // partition should be field a, order by should be field b - let handles_low_cardinality = (1..n).map(|i| { - rt.spawn(run_window_test( - make_staggered_batches::(1000, i), - i, - vec!["b"], - vec!["a"], - )) - }); - let handles_high_cardinality = (1..n).map(|i| { - rt.spawn(run_window_test( - make_staggered_batches::(1000, i), - i, - vec!["b"], - vec!["a"], - )) - }); - let handles = handles_low_cardinality - .into_iter() - .chain(handles_high_cardinality.into_iter()) - .collect::>>(); - rt.block_on(async { - for handle in handles { - handle.await.unwrap(); + let distincts = vec![1, 100]; + for distinct in distincts { + // since we have sorted pairs (a,b) to not violate per partition soring + // partition should be field a, order by should be field b + let mut handles = Vec::new(); + for i in 1..n { + let job = tokio::spawn(run_window_test( + make_staggered_batches::(1000, distinct, i), + i, + vec!["b"], + vec!["a"], + )); + handles.push(job); } - }); + for job in handles { + 
job.await.unwrap(); + } + } } } @@ -358,8 +331,9 @@ async fn run_window_test( /// Return randomly sized record batches with: /// two sorted int32 columns 'a', 'b' ranged from 0..len / DISTINCT as columns /// two random int32 columns 'x', 'y' as other columns -fn make_staggered_batches( +fn make_staggered_batches( len: usize, + distinct: usize, random_seed: u64, ) -> Vec { // use a random number generator to pick a random sized output @@ -369,8 +343,8 @@ fn make_staggered_batches( let mut input4: Vec = vec![0; len]; input12.iter_mut().for_each(|v| { *v = ( - rng.gen_range(0..(len / DISTINCT)) as i32, - rng.gen_range(0..(len / DISTINCT)) as i32, + rng.gen_range(0..(len / distinct)) as i32, + rng.gen_range(0..(len / distinct)) as i32, ) }); rng.fill(&mut input3[..]); diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index 4238dde09cf8f..e6cead76d13d2 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -26,7 +26,7 @@ use std::fmt::Debug; use std::ops::Range; /// Partition evaluator -pub trait PartitionEvaluator: Debug + Send + Sync { +pub trait PartitionEvaluator: Debug + Send { /// Whether the evaluator should be evaluated with rank fn include_rank(&self) -> bool { false diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 2e1dd86d4592e..656b6723b0d69 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -26,7 +26,6 @@ use datafusion_common::{reverse_sort_options, DataFusionError, Result, ScalarVal use datafusion_expr::{Accumulator, WindowFrame}; use indexmap::IndexMap; use std::any::Any; -use std::fmt; use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; @@ -71,9 +70,10 @@ pub trait WindowExpr: Send + Sync + Debug { _partition_batches: 
&PartitionBatches, _window_agg_state: &mut PartitionWindowAggStates, ) -> Result<()> { - Err(DataFusionError::Internal( - "evaluate_stateful is not implemented".to_string(), - )) + Err(DataFusionError::Internal(format!( + "evaluate_stateful is not implemented for {}", + self.name() + ))) } /// evaluate the partition points given the sort columns; if the sort columns are @@ -152,24 +152,12 @@ pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec), Aggregate(Box), } -impl fmt::Debug for WindowFn { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFn::Builtin(builtin, ..) => { - write!(f, "partition evaluator: {:?}", builtin) - } - WindowFn::Aggregate(aggregate, ..) => { - write!(f, "accumulator: {:?}", aggregate) - } - } - } -} - /// State for RANK(percent_rank, rank, dense_rank) /// builtin window function #[derive(Debug, Clone, Default)] @@ -225,7 +213,8 @@ pub struct WindowAggState { pub last_calculated_index: usize, /// The offset of the deleted row number pub offset_pruned_rows: usize, - /// + /// State of the window function, required to calculate its result + // For instance, for ROW_NUMBER we keep the row index counter to generate correct result pub window_function_state: WindowFunctionState, /// Stores the results calculated by window frame pub out_col: ArrayRef, From 1e764dd8ebce430fd95324532271340c4e00d146 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Wed, 4 Jan 2023 11:49:01 +0300 Subject: [PATCH 47/50] update comments --- datafusion/core/src/execution/context.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 70882d8fa47d5..369960a350416 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -1480,14 +1480,14 @@ impl SessionState { // Since the transformations it applies may alter output partitioning properties of operators // (e.g. 
by swapping hash join sides), this rule runs before BasicEnforcement. Arc::new(PipelineFixer::new()), - // It's for adding essential repartition and local sorting operator to satisfy the - // required distribution and local sort. + // BasicEnforcement is for adding essential repartition and local sorting operators + // to satisfy the required distribution and local sort requirements. // Please make sure that the whole plan tree is determined. Arc::new(BasicEnforcement::new()), - // `BasicEnforcement` stage conservatively inserts `SortExec`s to satisfy ordering requirements. - // However, a deeper analysis may sometimes reveal that such a `SortExec` is actually unnecessary. - // These cases typically arise when we have reversible `WindowAggExec`s or deep subqueries. The - // rule below performs this analysis and removes unnecessary `SortExec`s. + // The BasicEnforcement stage conservatively inserts sorts to satisfy ordering requirements. + // However, a deeper analysis may sometimes reveal that such a sort is actually unnecessary. + // These cases typically arise when we have reversible window expressions or deep subqueries. + // The rule below performs this analysis and removes unnecessary sorts. Arc::new(OptimizeSorts::new()), // It will not influence the distribution and ordering of the whole plan tree. // Therefore, to avoid influencing other rules, it should be run at last. 
From 28d68bbcc18c321d6993612e842c3aa39f843a18 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Wed, 4 Jan 2023 13:00:33 +0300 Subject: [PATCH 48/50] resolve linter error --- datafusion/core/src/catalog/listing_schema.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog/listing_schema.rs index dcc15cb06471e..265e08f7a6a3a 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog/listing_schema.rs @@ -237,7 +237,7 @@ mod tests { path: Path::new("/file"), is_dir: true, }; - assert!(table_path.to_string().expect("table path").ends_with("/")); + assert!(table_path.to_string().expect("table path").ends_with('/')); } #[test] @@ -246,6 +246,6 @@ mod tests { path: Path::new("/file"), is_dir: false, }; - assert!(!table_path.to_string().expect("table_path").ends_with("/")); + assert!(!table_path.to_string().expect("table_path").ends_with('/')); } } From c4b61c57128a39fbf890ff4e4344cc3d4e954900 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Wed, 4 Jan 2023 11:06:00 -0500 Subject: [PATCH 49/50] Tidy up comments after final review --- .github/workflows/rust.yml | 2 +- datafusion/core/src/execution/context.rs | 31 ++++++++++++------------ 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index e25d5c0365664..c4d8cd53306de 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -64,7 +64,7 @@ jobs: - name: Check Cargo.lock for datafusion-cli run: | # If this test fails, try running `cargo update` in the `datafusion-cli` directory - # then commit updated Cargo.lock file + # and check in the updated Cargo.lock file. 
cargo check --manifest-path datafusion-cli/Cargo.toml --locked # test the crate diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 369960a350416..42f5ae18ef305 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -1453,26 +1453,25 @@ impl SessionState { // We need to take care of the rule ordering. They may influence each other. let physical_optimizers: Vec> = vec![ Arc::new(AggregateStatistics::new()), - // - In order to increase the parallelism, it will change the output partitioning - // of some operators in the plan tree, which will influence other rules. - // Therefore, it should be run as soon as possible. - // - The reason to make it optional is - // - it's not used for the distributed engine, Ballista. - // - it's conflicted with some parts of the BasicEnforcement, since it will - // introduce additional repartitioning while the BasicEnforcement aims at - // reducing unnecessary repartitioning. + // In order to increase the parallelism, the Repartition rule will change the + // output partitioning of some operators in the plan tree, which will influence + // other rules. Therefore, it should run as soon as possible. It is optional because: + // - It's not used for the distributed engine, Ballista. + // - It's conflicted with some parts of the BasicEnforcement, since it will + // introduce additional repartitioning while the BasicEnforcement aims at + // reducing unnecessary repartitioning. Arc::new(Repartition::new()), - //- Currently it will depend on the partition number to decide whether to change the - // single node sort to parallel local sort and merge. Therefore, it should be run - // after the Repartition. - // - Since it will change the output ordering of some operators, it should be run + // - Currently it will depend on the partition number to decide whether to change the + // single node sort to parallel local sort and merge. 
Therefore, GlobalSortSelection + // should run after the Repartition. + // - Since it will change the output ordering of some operators, it should run // before JoinSelection and BasicEnforcement, which may depend on that. Arc::new(GlobalSortSelection::new()), - // Statistics-base join selection will change the Auto mode to real join implementation, + // Statistics-based join selection will change the Auto mode to a real join implementation, // like collect left, or hash join, or future sort merge join, which will // influence the BasicEnforcement to decide whether to add additional repartition // and local sort to meet the distribution and ordering requirements. - // Therefore, it should be run before BasicEnforcement + // Therefore, it should run before BasicEnforcement. Arc::new(JoinSelection::new()), // If the query is processing infinite inputs, the PipelineFixer rule applies the // necessary transformations to make the query runnable (if it is not already runnable). @@ -1489,8 +1488,8 @@ impl SessionState { // These cases typically arise when we have reversible window expressions or deep subqueries. // The rule below performs this analysis and removes unnecessary sorts. Arc::new(OptimizeSorts::new()), - // It will not influence the distribution and ordering of the whole plan tree. - // Therefore, to avoid influencing other rules, it should be run at last. + // The CoalesceBatches rule will not influence the distribution and ordering of the + // whole plan tree. Therefore, to avoid influencing other rules, it should run last. Arc::new(CoalesceBatches::new()), // The PipelineChecker rule will reject non-runnable query plans that use // pipeline-breaking operators on infinite input(s). 
The rule generates a From 8463418d83876300eef38ec9a4513d9cebaad9bf Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 4 Jan 2023 15:08:50 -0500 Subject: [PATCH 50/50] Minor: add link to upstream ticket --- datafusion/core/src/physical_plan/common.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs index f08652f7ca582..d01ed5e2bbd09 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -98,6 +98,8 @@ pub async fn collect(stream: SendableRecordBatchStream) -> Result