diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c15615d59b082..c4d8cd53306de 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -64,6 +64,7 @@ jobs: - name: Check Cargo.lock for datafusion-cli run: | # If this test fails, try running `cargo update` in the `datafusion-cli` directory + # and check in the updated Cargo.lock file. cargo check --manifest-path datafusion-cli/Cargo.toml --locked # test the crate diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index b78d5af9c937d..be55ac735664d 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -673,6 +673,7 @@ dependencies = [ "futures", "glob", "hashbrown 0.13.1", + "indexmap", "itertools", "lazy_static", "log", @@ -764,6 +765,7 @@ dependencies = [ "datafusion-row", "half", "hashbrown 0.13.1", + "indexmap", "itertools", "lazy_static", "md-5", diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 31b85cb184136..44c25f593c4a7 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -75,6 +75,7 @@ flate2 = { version = "1.0.24", optional = true } futures = "0.3" glob = "0.3.0" hashbrown = { version = "0.13", features = ["raw"] } +indexmap = "1.9.2" itertools = "0.10" lazy_static = { version = "^1.4.0" } log = "^0.4" diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog/listing_schema.rs index dcc15cb06471e..265e08f7a6a3a 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog/listing_schema.rs @@ -237,7 +237,7 @@ mod tests { path: Path::new("/file"), is_dir: true, }; - assert!(table_path.to_string().expect("table path").ends_with("/")); + assert!(table_path.to_string().expect("table path").ends_with('/')); } #[test] @@ -246,6 +246,6 @@ mod tests { path: Path::new("/file"), is_dir: false, }; - assert!(!table_path.to_string().expect("table_path").ends_with("/")); + assert!(!table_path.to_string().expect("table_path").ends_with('/')); } } diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 70882d8fa47d5..42f5ae18ef305 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -1453,26 +1453,25 @@ impl SessionState { // We need to take care of the rule ordering. They may influence each other. let physical_optimizers: Vec> = vec![ Arc::new(AggregateStatistics::new()), - // - In order to increase the parallelism, it will change the output partitioning - // of some operators in the plan tree, which will influence other rules. - // Therefore, it should be run as soon as possible. - // - The reason to make it optional is - // - it's not used for the distributed engine, Ballista. - // - it's conflicted with some parts of the BasicEnforcement, since it will - // introduce additional repartitioning while the BasicEnforcement aims at - // reducing unnecessary repartitioning. + // In order to increase the parallelism, the Repartition rule will change the + // output partitioning of some operators in the plan tree, which will influence + // other rules. Therefore, it should run as soon as possible. It is optional because: + // - It's not used for the distributed engine, Ballista. + // - It's conflicted with some parts of the BasicEnforcement, since it will + // introduce additional repartitioning while the BasicEnforcement aims at + // reducing unnecessary repartitioning. 
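For context on the `ends_with` change in the `listing_schema.rs` hunk above: both pattern forms compile, but Clippy's `single_char_pattern` lint prefers a `char` for a one-character check. A minimal standalone illustration (the path below is made up):

```rust
fn main() {
    let table_path = "/data/my_table/";
    // Preferred: a char pattern for a single-character check.
    assert!(table_path.ends_with('/'));
    // Equivalent behavior, but trips Clippy's `single_char_pattern` lint.
    assert!(table_path.ends_with("/"));
}
```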
Arc::new(Repartition::new()), - //- Currently it will depend on the partition number to decide whether to change the - // single node sort to parallel local sort and merge. Therefore, it should be run - // after the Repartition. - // - Since it will change the output ordering of some operators, it should be run + // - Currently it will depend on the partition number to decide whether to change the + // single node sort to parallel local sort and merge. Therefore, GlobalSortSelection + // should run after the Repartition. + // - Since it will change the output ordering of some operators, it should run // before JoinSelection and BasicEnforcement, which may depend on that. Arc::new(GlobalSortSelection::new()), - // Statistics-base join selection will change the Auto mode to real join implementation, + // Statistics-based join selection will change the Auto mode to a real join implementation, // like collect left, or hash join, or future sort merge join, which will // influence the BasicEnforcement to decide whether to add additional repartition // and local sort to meet the distribution and ordering requirements. - // Therefore, it should be run before BasicEnforcement + // Therefore, it should run before BasicEnforcement. Arc::new(JoinSelection::new()), // If the query is processing infinite inputs, the PipelineFixer rule applies the // necessary transformations to make the query runnable (if it is not already runnable). @@ -1480,17 +1479,17 @@ impl SessionState { // Since the transformations it applies may alter output partitioning properties of operators // (e.g. by swapping hash join sides), this rule runs before BasicEnforcement. Arc::new(PipelineFixer::new()), - // It's for adding essential repartition and local sorting operator to satisfy the - // required distribution and local sort. + // BasicEnforcement is for adding essential repartition and local sorting operators + // to satisfy the required distribution and local sort requirements. // Please make sure that the whole plan tree is determined. Arc::new(BasicEnforcement::new()), - // `BasicEnforcement` stage conservatively inserts `SortExec`s to satisfy ordering requirements. - // However, a deeper analysis may sometimes reveal that such a `SortExec` is actually unnecessary. - // These cases typically arise when we have reversible `WindowAggExec`s or deep subqueries. The - // rule below performs this analysis and removes unnecessary `SortExec`s. + // The BasicEnforcement stage conservatively inserts sorts to satisfy ordering requirements. + // However, a deeper analysis may sometimes reveal that such a sort is actually unnecessary. + // These cases typically arise when we have reversible window expressions or deep subqueries. + // The rule below performs this analysis and removes unnecessary sorts. Arc::new(OptimizeSorts::new()), - // It will not influence the distribution and ordering of the whole plan tree. - // Therefore, to avoid influencing other rules, it should be run at last. + // The CoalesceBatches rule will not influence the distribution and ordering of the + // whole plan tree. Therefore, to avoid influencing other rules, it should run last. Arc::new(CoalesceBatches::new()), // The PipelineChecker rule will reject non-runnable query plans that use // pipeline-breaking operators on infinite input(s). 
The rule generates a // diagnostic error message when this happens. It makes no changes to the given // query plan, i.e. it only acts as a final gatekeeping rule. Arc::new(PipelineChecker::new()), ]; diff --git a/datafusion/core/src/physical_optimizer/optimize_sorts.rs b/datafusion/core/src/physical_optimizer/optimize_sorts.rs index a47026cc773d4..17b27bfa7d0ae 100644 --- a/datafusion/core/src/physical_optimizer/optimize_sorts.rs +++ b/datafusion/core/src/physical_optimizer/optimize_sorts.rs @@ -33,10 +33,11 @@ use crate::physical_optimizer::utils::{ use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::rewrite::TreeNodeRewritable; use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::windows::WindowAggExec; +use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; use arrow::datatypes::SchemaRef; use datafusion_common::{reverse_sort_options, DataFusionError}; +use datafusion_physical_expr::window::WindowExpr; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use itertools::izip; use std::iter::zip; @@ -181,17 +182,32 @@ fn optimize_sorts( sort_exec.input().equivalence_properties() }) { update_child_to_remove_unnecessary_sort(child, sort_onwards)?; - } else if let Some(window_agg_exec) = + } + // For window expressions, we can remove some sorts when we can + // calculate the result in reverse: + else if let Some(exec) = requirements.plan.as_any().downcast_ref::<WindowAggExec>() { - // For window expressions, we can remove some sorts when we can - // calculate the result in reverse: - if let Some(res) = analyze_window_sort_removal( - window_agg_exec, + if let Some(result) = analyze_window_sort_removal( + exec.window_expr(), + &exec.partition_keys, + sort_exec, + sort_onwards, + )? { + return Ok(Some(result)); + } + } else if let Some(exec) = requirements + .plan + .as_any() + .downcast_ref::<BoundedWindowAggExec>() + { + if let Some(result) = analyze_window_sort_removal( + exec.window_expr(), + &exec.partition_keys, sort_exec, sort_onwards, )? { - return Ok(Some(res)); + return Ok(Some(result)); } } // TODO: Once we can ensure that required ordering information propagates with @@ -273,9 +289,11 @@ fn analyze_immediate_sort_removal( Ok(None) } -/// Analyzes a `WindowAggExec` to determine whether it may allow removing a sort. +/// Analyzes a [WindowAggExec] or a [BoundedWindowAggExec] to determine whether +/// it may allow removing a sort.
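For orientation before the function body below: both this optimizer rule and the planner hunk later in the diff gate on `uses_bounded_memory()`, choosing `BoundedWindowAggExec` only when every window expression can run with bounded memory. A minimal sketch of that gate, with a stand-in trait in place of DataFusion's `WindowExpr`:

```rust
// Stand-in for the relevant part of datafusion_physical_expr's WindowExpr.
trait WindowExprLike {
    fn uses_bounded_memory(&self) -> bool;
}

// Choose the bounded operator only if *every* expression supports it.
// Note that `all` on an empty iterator returns true.
fn can_use_bounded_variant(exprs: &[Box<dyn WindowExprLike>]) -> bool {
    exprs.iter().all(|e| e.uses_bounded_memory())
}
```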
fn analyze_window_sort_removal( - window_agg_exec: &WindowAggExec, + window_expr: &[Arc], + partition_keys: &[Arc], sort_exec: &SortExec, sort_onward: &mut Vec<(usize, Arc)>, ) -> Result> { @@ -289,7 +307,6 @@ fn analyze_window_sort_removal( // If there is no physical ordering, there is no way to remove a sort -- immediately return: return Ok(None); }; - let window_expr = window_agg_exec.window_expr(); let (can_skip_sorting, should_reverse) = can_skip_sort( window_expr[0].partition_by(), required_ordering, @@ -308,13 +325,26 @@ fn analyze_window_sort_removal( if let Some(window_expr) = new_window_expr { let new_child = remove_corresponding_sort_from_sub_plan(sort_onward)?; let new_schema = new_child.schema(); - let new_plan = Arc::new(WindowAggExec::try_new( - window_expr, - new_child, - new_schema, - window_agg_exec.partition_keys.clone(), - Some(physical_ordering.to_vec()), - )?); + + let uses_bounded_memory = window_expr.iter().all(|e| e.uses_bounded_memory()); + // If all window exprs can run with bounded memory choose bounded window variant + let new_plan = if uses_bounded_memory { + Arc::new(BoundedWindowAggExec::try_new( + window_expr, + new_child, + new_schema, + partition_keys.to_vec(), + Some(physical_ordering.to_vec()), + )?) as _ + } else { + Arc::new(WindowAggExec::try_new( + window_expr, + new_child, + new_schema, + partition_keys.to_vec(), + Some(physical_ordering.to_vec()), + )?) as _ + }; return Ok(Some(PlanWithCorrespondingSort::new(new_plan))); } } diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index c35ef29f20c9b..96f0b0ff6932d 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -301,7 +301,7 @@ mod sql_tests { let case = QueryCase { sql: "SELECT c9, - SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1 + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 FROM test LIMIT 5".to_string(), cases: vec![Arc::new(test1), Arc::new(test2)], @@ -325,7 +325,7 @@ mod sql_tests { let case = QueryCase { sql: "SELECT c9, - SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1 + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 FROM test".to_string(), cases: vec![Arc::new(test1), Arc::new(test2)], error_operator: "Window Error".to_string() diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs index 7ea8dfe35bc6d..d01ed5e2bbd09 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -22,6 +22,7 @@ use crate::error::{DataFusionError, Result}; use crate::execution::context::TaskContext; use crate::physical_plan::metrics::MemTrackingMetrics; use crate::physical_plan::{displayable, ColumnStatistics, ExecutionPlan, Statistics}; +use arrow::compute::concat; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::ArrowError; use arrow::error::Result as ArrowResult; @@ -95,6 +96,51 @@ pub async fn collect(stream: SendableRecordBatchStream) -> Result ArrowResult { + let columns = (0..schema.fields.len()) + .map(|index| { + let first_column = first.column(index).as_ref(); + let second_column = second.column(index).as_ref(); + concat(&[first_column, second_column]) + }) + .collect::>>()?; + RecordBatch::try_new(schema, columns) +} + +/// Merge a slice of record 
batch references into a single record batch, or +/// return None if the slice itself is empty. All the record batches inside the +/// slice must have the same schema. +/// +/// Can use concat_batches after https://github.com/apache/arrow-rs/issues/3456 +pub fn merge_multiple_batches( + batches: &[&RecordBatch], + schema: SchemaRef, +) -> ArrowResult> { + Ok(if batches.is_empty() { + None + } else { + let columns = (0..schema.fields.len()) + .map(|index| { + concat( + &batches + .iter() + .map(|batch| batch.column(index).as_ref()) + .collect::>(), + ) + }) + .collect::>>()?; + Some(RecordBatch::try_new(schema, columns)?) + }) +} + /// Recursively builds a list of files in a directory with a given extension pub fn build_checked_file_list(dir: &str, ext: &str) -> Result> { let mut filenames: Vec = Vec::new(); diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 978ed195bb062..217ab2aa41caf 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -47,7 +47,7 @@ use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::windows::WindowAggExec; +use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{joins::utils as join_utils, Partitioning}; use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; use crate::{ @@ -614,13 +614,28 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - Ok(Arc::new(WindowAggExec::try_new( - window_expr, - input_exec, - physical_input_schema, - physical_partition_keys, - physical_sort_keys, - )?)) + let uses_bounded_memory = window_expr + .iter() + .all(|e| e.uses_bounded_memory()); + // If all window expressions can run with bounded memory, + // choose the bounded window variant: + Ok(if uses_bounded_memory { + Arc::new(BoundedWindowAggExec::try_new( + window_expr, + input_exec, + physical_input_schema, + physical_partition_keys, + physical_sort_keys, + )?) + } else { + Arc::new(WindowAggExec::try_new( + window_expr, + input_exec, + physical_input_schema, + physical_partition_keys, + physical_sort_keys, + )?) + }) } LogicalPlan::Aggregate(Aggregate { input, diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs new file mode 100644 index 0000000000000..5ed6a112c82f5 --- /dev/null +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -0,0 +1,705 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Stream and channel implementations for window function expressions. +//! The executor given here uses bounded memory (does not maintain all +//! the input data seen so far), which makes it appropriate when processing +//! infinite inputs. + +use crate::error::Result; +use crate::execution::context::TaskContext; +use crate::physical_plan::expressions::PhysicalSortExpr; +use crate::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, +}; +use crate::physical_plan::{ + ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, +}; +use arrow::array::Array; +use arrow::compute::{concat, lexicographical_partition_ranges, SortColumn}; +use arrow::{ + array::ArrayRef, + datatypes::{Schema, SchemaRef}, + error::Result as ArrowResult, + record_batch::RecordBatch, +}; +use datafusion_common::{DataFusionError, ScalarValue}; +use futures::stream::Stream; +use futures::{ready, StreamExt}; +use std::any::Any; +use std::cmp::min; +use std::collections::HashMap; +use std::ops::Range; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::physical_plan::common::merge_batches; +use datafusion_physical_expr::window::{ + PartitionBatchState, PartitionBatches, PartitionKey, PartitionWindowAggStates, + WindowAggState, WindowState, +}; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; +use indexmap::IndexMap; +use log::debug; + +/// Window execution plan +#[derive(Debug)] +pub struct BoundedWindowAggExec { + /// Input plan + input: Arc, + /// Window function expression + window_expr: Vec>, + /// Schema after the window is run + schema: SchemaRef, + /// Schema before the window + input_schema: SchemaRef, + /// Partition Keys + pub partition_keys: Vec>, + /// Sort Keys + pub sort_keys: Option>, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +impl BoundedWindowAggExec { + /// Create a new execution plan for window aggregates + pub fn try_new( + window_expr: Vec>, + input: Arc, + input_schema: SchemaRef, + partition_keys: Vec>, + sort_keys: Option>, + ) -> Result { + let schema = create_schema(&input_schema, &window_expr)?; + let schema = Arc::new(schema); + Ok(Self { + input, + window_expr, + schema, + input_schema, + partition_keys, + sort_keys, + metrics: ExecutionPlanMetricsSet::new(), + }) + } + + /// Window expressions + pub fn window_expr(&self) -> &[Arc] { + &self.window_expr + } + + /// Input plan + pub fn input(&self) -> &Arc { + &self.input + } + + /// Get the input schema before any window functions are applied + pub fn input_schema(&self) -> SchemaRef { + self.input_schema.clone() + } + + /// Return the output sort order of partition keys: For example + /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a + // We are sure that partition by columns are always at the beginning of sort_keys + // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely + // to calculate partition separation points + pub fn partition_by_sort_keys(&self) -> Result> { + let mut result = vec![]; + // All window exprs have the same partition by, so we just use the first one: + let partition_by = self.window_expr()[0].partition_by(); + let sort_keys = self.sort_keys.as_deref().unwrap_or(&[]); + for item in partition_by { + if let Some(a) = sort_keys.iter().find(|&e| e.expr.eq(item)) { + result.push(a.clone()); + } else { + return Err(DataFusionError::Internal( + "Partition key not 
found in sort keys".to_string(), + )); + } + } + Ok(result) + } +} + +impl ExecutionPlan for BoundedWindowAggExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + // As we can have repartitioning using the partition keys, this can + // be either one or more than one, depending on the presence of + // repartitioning. + self.input.output_partitioning() + } + + fn unbounded_output(&self, children: &[bool]) -> Result { + Ok(children[0]) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.input().output_ordering() + } + + fn required_input_ordering(&self) -> Vec> { + let sort_keys = self.sort_keys.as_deref(); + vec![sort_keys] + } + + fn required_input_distribution(&self) -> Vec { + if self.partition_keys.is_empty() { + debug!("No partition defined for BoundedWindowAggExec!!!"); + vec![Distribution::SinglePartition] + } else { + //TODO support PartitionCollections if there is no common partition columns in the window_expr + vec![Distribution::HashPartitioned(self.partition_keys.clone())] + } + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + self.input().equivalence_properties() + } + + fn maintains_input_order(&self) -> bool { + true + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(BoundedWindowAggExec::try_new( + self.window_expr.clone(), + children[0].clone(), + self.input_schema.clone(), + self.partition_keys.clone(), + self.sort_keys.clone(), + )?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input = self.input.execute(partition, context)?; + let stream = Box::pin(SortedPartitionByBoundedWindowStream::new( + self.schema.clone(), + self.window_expr.clone(), + input, + BaselineMetrics::new(&self.metrics, partition), + self.partition_by_sort_keys()?, + )); + Ok(stream) + } + + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default => { + write!(f, "BoundedWindowAggExec: ")?; + let g: Vec = self + .window_expr + .iter() + .map(|e| { + format!( + "{}: {:?}, frame: {:?}", + e.name().to_owned(), + e.field(), + e.get_window_frame() + ) + }) + .collect(); + write!(f, "wdw=[{}]", g.join(", "))?; + } + } + Ok(()) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Statistics { + let input_stat = self.input.statistics(); + let win_cols = self.window_expr.len(); + let input_cols = self.input_schema.fields().len(); + // TODO stats: some windowing function will maintain invariants such as min, max... 
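The `statistics()` body continuing below pads per-column statistics: input column statistics are carried over when known, and the appended window output columns get defaults. A self-contained sketch of that shape, using a stand-in `ColStats` type in place of DataFusion's `ColumnStatistics`:

```rust
#[derive(Clone, Default, Debug)]
struct ColStats; // stand-in for ColumnStatistics

// Keep known input column stats (or defaults if unknown), then append
// defaults for the window output columns, which have no stats yet.
fn pad_column_stats(
    input: Option<Vec<ColStats>>,
    input_cols: usize,
    win_cols: usize,
) -> Vec<ColStats> {
    let mut out = Vec::with_capacity(input_cols + win_cols);
    match input {
        Some(stats) => out.extend(stats),
        None => out.extend(vec![ColStats::default(); input_cols]),
    }
    out.extend(vec![ColStats::default(); win_cols]);
    out
}
```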
+ let mut column_statistics = Vec::with_capacity(win_cols + input_cols); + if let Some(input_col_stats) = input_stat.column_statistics { + column_statistics.extend(input_col_stats); + } else { + column_statistics.extend(vec![ColumnStatistics::default(); input_cols]); + } + column_statistics.extend(vec![ColumnStatistics::default(); win_cols]); + Statistics { + is_exact: input_stat.is_exact, + num_rows: input_stat.num_rows, + column_statistics: Some(column_statistics), + total_byte_size: None, + } + } +} + +fn create_schema( + input_schema: &Schema, + window_expr: &[Arc], +) -> Result { + let mut fields = Vec::with_capacity(input_schema.fields().len() + window_expr.len()); + fields.extend_from_slice(input_schema.fields()); + // append results to the schema + for expr in window_expr { + fields.push(expr.field()?); + } + Ok(Schema::new(fields)) +} + +/// This trait defines the interface for updating the state and calculating +/// results for window functions. Depending on the partitioning scheme, one +/// may have different implementations for the functions within. +pub trait PartitionByHandler { + /// Constructs output columns from window_expression results. + fn calculate_out_columns(&self) -> Result>>; + /// Prunes the window state to remove any unnecessary information + /// given how many rows we emitted so far. + fn prune_state(&mut self, n_out: usize) -> Result<()>; + /// Updates record batches for each partition when new batches are + /// received. + fn update_partition_batch(&mut self, record_batch: RecordBatch) -> Result<()>; +} + +/// stream for window aggregation plan +/// assuming partition by column is sorted (or without PARTITION BY expression) +pub struct SortedPartitionByBoundedWindowStream { + schema: SchemaRef, + input: SendableRecordBatchStream, + /// The record batch executor receives as input (i.e. the columns needed + /// while calculating aggregation results). + input_buffer: RecordBatch, + /// We separate `input_buffer_record_batch` based on partitions (as + /// determined by PARTITION BY columns) and store them per partition + /// in `partition_batches`. We use this variable when calculating results + /// for each window expression. This enables us to use the same batch for + /// different window expressions without copying. + // Note that we could keep record batches for each window expression in + // `PartitionWindowAggStates`. However, this would use more memory (as + // many times as the number of window expressions). + partition_buffers: PartitionBatches, + /// An executor can run multiple window expressions if the PARTITION BY + /// and ORDER BY sections are same. We keep state of the each window + /// expression inside `window_agg_states`. + window_agg_states: Vec, + finished: bool, + window_expr: Vec>, + partition_by_sort_keys: Vec, + baseline_metrics: BaselineMetrics, +} + +impl PartitionByHandler for SortedPartitionByBoundedWindowStream { + /// This method constructs output columns using the result of each window expression + fn calculate_out_columns(&self) -> Result>> { + let n_out = self.calculate_n_out_row(); + if n_out == 0 { + Ok(None) + } else { + self.input_buffer + .columns() + .iter() + .map(|elem| Ok(elem.slice(0, n_out))) + .chain( + self.window_agg_states + .iter() + .map(|elem| get_aggregate_result_out_column(elem, n_out)), + ) + .collect::>>() + .map(Some) + } + } + + /// Prunes sections of the state that are no longer needed when calculating + /// results (as determined by window frame boundaries and number of results generated). 
+ // For instance, if first `n` (not necessarily same with `n_out`) elements are no longer needed to + // calculate window expression result (outside the window frame boundary) we retract first `n` elements + // from `self.partition_batches` in corresponding partition. + // For instance, if `n_out` number of rows are calculated, we can remove + // first `n_out` rows from `self.input_buffer_record_batch`. + fn prune_state(&mut self, n_out: usize) -> Result<()> { + // Prune `self.partition_batches`: + self.prune_partition_batches()?; + // Prune `self.input_buffer_record_batch`: + self.prune_input_batch(n_out)?; + // Prune `self.window_agg_states`: + self.prune_out_columns(n_out)?; + Ok(()) + } + + fn update_partition_batch(&mut self, record_batch: RecordBatch) -> Result<()> { + let partition_columns = self.partition_columns(&record_batch)?; + let num_rows = record_batch.num_rows(); + if num_rows > 0 { + let partition_points = + self.evaluate_partition_points(num_rows, &partition_columns)?; + for partition_range in partition_points { + let partition_row = partition_columns + .iter() + .map(|arr| { + ScalarValue::try_from_array(&arr.values, partition_range.start) + }) + .collect::>()?; + let partition_batch = record_batch.slice( + partition_range.start, + partition_range.end - partition_range.start, + ); + if let Some(partition_batch_state) = + self.partition_buffers.get_mut(&partition_row) + { + partition_batch_state.record_batch = merge_batches( + &partition_batch_state.record_batch, + &partition_batch, + self.input.schema(), + )?; + } else { + let partition_batch_state = PartitionBatchState { + record_batch: partition_batch, + is_end: false, + }; + self.partition_buffers + .insert(partition_row, partition_batch_state); + }; + } + } + let n_partitions = self.partition_buffers.len(); + for (idx, (_, partition_batch_state)) in + self.partition_buffers.iter_mut().enumerate() + { + partition_batch_state.is_end |= idx < n_partitions - 1; + } + self.input_buffer = if self.input_buffer.num_rows() == 0 { + record_batch + } else { + merge_batches(&self.input_buffer, &record_batch, self.input.schema())? 
+ }; + + Ok(()) + } +} + +impl Stream for SortedPartitionByBoundedWindowStream { + type Item = ArrowResult; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let poll = self.poll_next_inner(cx); + self.baseline_metrics.record_poll(poll) + } +} + +impl SortedPartitionByBoundedWindowStream { + /// Create a new BoundedWindowAggStream + pub fn new( + schema: SchemaRef, + window_expr: Vec>, + input: SendableRecordBatchStream, + baseline_metrics: BaselineMetrics, + partition_by_sort_keys: Vec, + ) -> Self { + let state = window_expr.iter().map(|_| IndexMap::new()).collect(); + let empty_batch = RecordBatch::new_empty(schema.clone()); + Self { + schema, + input, + input_buffer: empty_batch, + partition_buffers: IndexMap::new(), + window_agg_states: state, + finished: false, + window_expr, + baseline_metrics, + partition_by_sort_keys, + } + } + + fn compute_aggregates(&mut self) -> ArrowResult { + // calculate window cols + for (cur_window_expr, state) in + self.window_expr.iter().zip(&mut self.window_agg_states) + { + cur_window_expr.evaluate_stateful(&self.partition_buffers, state)?; + } + + let schema = self.schema.clone(); + let columns_to_show = self.calculate_out_columns()?; + if let Some(columns_to_show) = columns_to_show { + let n_generated = columns_to_show[0].len(); + self.prune_state(n_generated)?; + RecordBatch::try_new(schema, columns_to_show) + } else { + Ok(RecordBatch::new_empty(schema)) + } + } + + #[inline] + fn poll_next_inner( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.finished { + return Poll::Ready(None); + } + + let result = match ready!(self.input.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + self.update_partition_batch(batch)?; + self.compute_aggregates() + } + Some(Err(e)) => Err(e), + None => { + self.finished = true; + for (_, partition_batch_state) in self.partition_buffers.iter_mut() { + partition_batch_state.is_end = true; + } + self.compute_aggregates() + } + }; + Poll::Ready(Some(result)) + } + + /// Calculates how many rows [SortedPartitionByBoundedWindowStream] + /// can produce as output. + fn calculate_n_out_row(&self) -> usize { + // Different window aggregators may produce results with different rates. + // We produce the overall batch result with the same speed as slowest one. + self.window_agg_states + .iter() + .map(|window_agg_state| { + // Store how many elements are generated for the current + // window expression: + let mut cur_window_expr_out_result_len = 0; + // We iterate over `window_agg_state`, which is an IndexMap. + // Iterations follow the insertion order, hence we preserve + // sorting when partition columns are sorted. + for (_, WindowState { state, .. }) in window_agg_state.iter() { + cur_window_expr_out_result_len += state.out_col.len(); + // If we do not generate all results for the current + // partition, we do not generate results for next + // partition -- otherwise we will lose input ordering. + if state.n_row_result_missing > 0 { + break; + } + } + cur_window_expr_out_result_len + }) + .into_iter() + .min() + .unwrap_or(0) + } + + /// Prunes the sections of the record batch (for each partition) + /// that we no longer need to calculate the window function result. + fn prune_partition_batches(&mut self) -> Result<()> { + // Remove partitions which we know already ended (is_end flag is true). + // Since the retain method preserves insertion order, we still have + // ordering in between partitions after removal. 
+ self.partition_buffers + .retain(|_, partition_batch_state| !partition_batch_state.is_end); + + // The data in `self.partition_batches` is used by all window expressions. + // Therefore, when removing from `self.partition_batches`, we need to remove + // from the earliest range boundary among all window expressions. Variable + // `n_prune_each_partition` fill the earliest range boundary information for + // each partition. This way, we can delete the no-longer-needed sections from + // `self.partition_batches`. + // For instance, if window frame one uses [10, 20] and window frame two uses + // [5, 15]; we only prune the first 5 elements from the corresponding record + // batch in `self.partition_batches`. + + // Calculate how many elements to prune for each partition batch + let mut n_prune_each_partition: HashMap = HashMap::new(); + for window_agg_state in self.window_agg_states.iter_mut() { + window_agg_state.retain(|_, WindowState { state, .. }| !state.is_end); + for (partition_row, WindowState { state: value, .. }) in window_agg_state { + if let Some(state) = n_prune_each_partition.get_mut(partition_row) { + if value.window_frame_range.start < *state { + *state = value.window_frame_range.start; + } + } else { + n_prune_each_partition + .insert(partition_row.clone(), value.window_frame_range.start); + } + } + } + + let err = || DataFusionError::Execution("Expects to have partition".to_string()); + // Retract no longer needed parts during window calculations from partition batch: + for (partition_row, n_prune) in n_prune_each_partition.iter() { + let partition_batch_state = self + .partition_buffers + .get_mut(partition_row) + .ok_or_else(err)?; + let batch = &partition_batch_state.record_batch; + partition_batch_state.record_batch = + batch.slice(*n_prune, batch.num_rows() - n_prune); + + // Update state indices since we have pruned some rows from the beginning: + for window_agg_state in self.window_agg_states.iter_mut() { + let window_state = + window_agg_state.get_mut(partition_row).ok_or_else(err)?; + let mut state = &mut window_state.state; + state.window_frame_range = Range { + start: state.window_frame_range.start - n_prune, + end: state.window_frame_range.end - n_prune, + }; + state.last_calculated_index -= n_prune; + state.offset_pruned_rows += n_prune; + } + } + Ok(()) + } + + /// Prunes the section of the input batch whose aggregate results + /// are calculated and emitted. + fn prune_input_batch(&mut self, n_out: usize) -> Result<()> { + let n_to_keep = self.input_buffer.num_rows() - n_out; + let batch_to_keep = self + .input_buffer + .columns() + .iter() + .map(|elem| elem.slice(n_out, n_to_keep)) + .collect::>(); + self.input_buffer = + RecordBatch::try_new(self.input_buffer.schema(), batch_to_keep)?; + Ok(()) + } + + /// Prunes emitted parts from WindowAggState `out_col` field. + fn prune_out_columns(&mut self, n_out: usize) -> Result<()> { + // We store generated columns for each window expression in the `out_col` + // field of `WindowAggState`. Given how many rows are emitted, we remove + // these sections from state. + for partition_window_agg_states in self.window_agg_states.iter_mut() { + let mut running_length = 0; + // Remove `n_out` entries from the `out_col` field of `WindowAggState`. + // Preserve per partition ordering by iterating in the order of insertion. + // Do not generate a result for a new partition without emitting all results + // for the current partition. + for ( + _, + WindowState { + state: WindowAggState { out_col, .. }, + .. 
+ }, + ) in partition_window_agg_states + { + if running_length < n_out { + let n_to_del = min(out_col.len(), n_out - running_length); + let n_to_keep = out_col.len() - n_to_del; + *out_col = out_col.slice(n_to_del, n_to_keep); + running_length += n_to_del; + } + } + } + Ok(()) + } + + /// Get Partition Columns + pub fn partition_columns(&self, batch: &RecordBatch) -> Result> { + self.partition_by_sort_keys + .iter() + .map(|e| e.evaluate_to_sort_column(batch)) + .collect::>>() + } + + /// evaluate the partition points given the sort columns; if the sort columns are + /// empty then the result will be a single element vec of the whole column rows. + fn evaluate_partition_points( + &self, + num_rows: usize, + partition_columns: &[SortColumn], + ) -> Result>> { + Ok(if partition_columns.is_empty() { + vec![Range { + start: 0, + end: num_rows, + }] + } else { + lexicographical_partition_ranges(partition_columns) + .map_err(DataFusionError::ArrowError)? + .collect::>() + }) + } +} + +impl RecordBatchStream for SortedPartitionByBoundedWindowStream { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +/// Calculates the section we can show results for expression +fn get_aggregate_result_out_column( + partition_window_agg_states: &PartitionWindowAggStates, + len_to_show: usize, +) -> Result { + let mut result = None; + let mut running_length = 0; + // We assume that iteration order is according to insertion order + for ( + _, + WindowState { + state: WindowAggState { out_col, .. }, + .. + }, + ) in partition_window_agg_states + { + if running_length < len_to_show { + let n_to_use = min(len_to_show - running_length, out_col.len()); + let slice_to_use = out_col.slice(0, n_to_use); + result = Some(match result { + Some(arr) => concat(&[&arr, &slice_to_use])?, + None => slice_to_use, + }); + running_length += n_to_use; + } else { + break; + } + } + if running_length != len_to_show { + return Err(DataFusionError::Execution(format!( + "Generated row number should be {}, it is {}", + len_to_show, running_length + ))); + } + result + .ok_or_else(|| DataFusionError::Execution("Should contain something".to_string())) +} diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index bbf6c91821f4d..2d7aa0494d323 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -39,8 +39,10 @@ use datafusion_physical_expr::window::{ use std::convert::TryInto; use std::sync::Arc; +mod bounded_window_agg_exec; mod window_agg_exec; +pub use bounded_window_agg_exec::BoundedWindowAggExec; pub use datafusion_physical_expr::window::{ AggregateWindowExpr, BuiltInWindowExpr, WindowExpr, }; diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index d3a2043f1e8ac..5ca49cff2883a 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -16,6 +16,8 @@ // under the License. 
use super::*; +use ::parquet::arrow::arrow_writer::ArrowWriter; +use ::parquet::file::properties::WriterProperties; /// for window functions without order by the first, last, and nth function call does not make sense #[tokio::test] @@ -1757,11 +1759,11 @@ async fn test_window_partition_by_order_by() -> Result<()> { let expected = { vec![ "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(UInt8(1))]", - " WindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 2)", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }, Column { name: \"c2\", index: 1 }], 2)", @@ -1800,8 +1802,8 @@ async fn test_window_agg_sort_reversed_plan() -> Result<()> { "ProjectionExec: expr=[c9@0 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", " RepartitionExec: partitioning=RoundRobinBatch(2)", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: 
WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c9@0 DESC]", ] }; @@ -1856,8 +1858,8 @@ async fn test_window_agg_sort_reversed_plan_builtin() -> Result<()> { "ProjectionExec: expr=[c9@0 as c9, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2]", " RepartitionExec: partitioning=RoundRobinBatch(2)", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }]", - " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: 
UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }]", + " BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }]", " SortExec: [c9@0 DESC]", ] }; @@ -1908,9 +1910,9 @@ async fn test_window_agg_sort_non_reversed_plan() -> Result<()> { "ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as rn1, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2]", " RepartitionExec: partitioning=RoundRobinBatch(2)", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c9@0 ASC NULLS LAST]", - " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c9@0 DESC]", ] }; @@ -1962,10 +1964,10 @@ async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { "ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9) ORDER BY 
[aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as rn2]", " RepartitionExec: partitioning=RoundRobinBatch(2)", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c9@2 ASC NULLS LAST,c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", - " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c9@2 DESC,c1@0 DESC]", ] }; @@ -2099,8 +2101,8 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> "ProjectionExec: expr=[c9@1 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2]", " RepartitionExec: partitioning=RoundRobinBatch(2)", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " 
BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", ] }; @@ -2154,8 +2156,8 @@ async fn test_window_agg_sort_partitionby_reversed_plan() -> Result<()> { "ProjectionExec: expr=[c9@1 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum2]", " RepartitionExec: partitioning=RoundRobinBatch(2)", " GlobalLimitExec: skip=0, fetch=5", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", ] }; @@ -2351,3 +2353,251 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_reversed_plan() -> Re Ok(()) } + +fn write_test_data_to_parquet(tmpdir: &TempDir, n_file: usize) -> Result<()> { + let ts_field = Field::new("ts", DataType::Int32, false); + let inc_field = Field::new("inc_col", DataType::Int32, false); + let desc_field = Field::new("desc_col", DataType::Int32, false); + + let schema = Arc::new(Schema::new(vec![ts_field, inc_field, desc_field])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from_slice([ + 1, 1, 5, 9, 10, 11, 16, 21, 22, 26, 26, 28, 31, 33, 38, 42, 47, 51, 53, + 53, 58, 63, 67, 68, 70, 72, 72, 76, 81, 85, 86, 88, 91, 96, 97, 98, 100, + 101, 102, 104, 104, 108, 112, 113, 113, 114, 114, 117, 122, 126, 131, + 131, 136, 136, 136, 139, 141, 146, 147, 147, 152, 154, 159, 161, 163, + 164, 167, 172, 173, 177, 180, 185, 186, 191, 195, 195, 199, 203, 207, + 210, 213, 218, 221, 224, 226, 230, 232, 235, 238, 238, 239, 244, 245, + 247, 250, 254, 258, 262, 264, 264, + ])), + Arc::new(Int32Array::from_slice([ + 1, 5, 10, 15, 20, 21, 26, 
29, 30, 33, 37, 40, 43, 44, 45, 49, 51, 53, 58, + 61, 65, 70, 75, 78, 83, 88, 90, 91, 95, 97, 100, 105, 109, 111, 115, 119, + 120, 124, 126, 129, 131, 135, 140, 143, 144, 147, 148, 149, 151, 155, + 156, 159, 160, 163, 165, 170, 172, 177, 181, 182, 186, 187, 192, 196, + 197, 199, 203, 207, 209, 213, 214, 216, 219, 221, 222, 225, 226, 231, + 236, 237, 242, 245, 247, 248, 253, 254, 259, 261, 266, 269, 272, 275, + 278, 283, 286, 289, 291, 296, 301, 305, + ])), + Arc::new(Int32Array::from_slice([ + 100, 98, 93, 91, 86, 84, 81, 77, 75, 71, 70, 69, 64, 62, 59, 55, 50, 45, + 41, 40, 39, 36, 31, 28, 23, 22, 17, 13, 10, 6, 5, 2, 1, -1, -4, -5, -6, + -8, -12, -16, -17, -19, -24, -25, -29, -34, -37, -42, -47, -48, -49, -53, + -57, -58, -61, -65, -67, -68, -71, -73, -75, -76, -78, -83, -87, -91, + -95, -98, -101, -105, -106, -111, -114, -116, -120, -125, -128, -129, + -134, -139, -142, -143, -146, -150, -154, -158, -163, -168, -172, -176, + -181, -184, -189, -193, -196, -201, -203, -208, -210, -213, + ])), + ], + )?; + let n_chunk = batch.num_rows() / n_file; + for i in 0..n_file { + let target_file = tmpdir.path().join(format!("{}.parquet", i)); + let file = File::create(target_file).unwrap(); + // Default writer properties + let props = WriterProperties::builder().build(); + let chunks_start = i * n_chunk; + let cur_batch = batch.slice(chunks_start, n_chunk); + // let chunks_end = chunks_start + n_chunk; + let mut writer = + ArrowWriter::try_new(file, cur_batch.schema(), Some(props)).unwrap(); + + writer.write(&cur_batch).expect("Writing batch"); + + // writer must be closed to write footer + writer.close().unwrap(); + } + Ok(()) +} + +async fn get_test_context(tmpdir: &TempDir) -> Result { + let session_config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::with_config(session_config); + + let parquet_read_options = ParquetReadOptions::default(); + // The sort order is specified (not actually correct in this case) + let file_sort_order = [col("ts")] + .into_iter() + .map(|e| { + let ascending = true; + let nulls_first = false; + e.sort(ascending, nulls_first) + }) + .collect::>(); + + let options_sort = parquet_read_options + .to_listing_options(&ctx.copied_config()) + .with_file_sort_order(Some(file_sort_order)); + + write_test_data_to_parquet(tmpdir, 1)?; + let provided_schema = None; + let sql_definition = None; + ctx.register_listing_table( + "annotated_data", + tmpdir.path().to_string_lossy(), + options_sort.clone(), + provided_schema, + sql_definition, + ) + .await + .unwrap(); + Ok(ctx) +} + +mod tests { + use super::*; + + #[tokio::test] + async fn test_source_sorted_aggregate() -> Result<()> { + let tmpdir = TempDir::new().unwrap(); + let ctx = get_test_context(&tmpdir).await?; + + let sql = "SELECT + SUM(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as sum1, + SUM(desc_col) OVER(ORDER BY ts RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING) as sum2, + SUM(inc_col) OVER(ORDER BY ts ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING) as sum3, + MIN(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as min1, + MIN(desc_col) OVER(ORDER BY ts RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING) as min2, + MIN(inc_col) OVER(ORDER BY ts ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING) as min3, + MAX(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as max1, + MAX(desc_col) OVER(ORDER BY ts RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING) as max2, + MAX(inc_col) OVER(ORDER BY ts ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING) as max3, + COUNT(*) 
OVER(ORDER BY ts RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING) as cnt1, + COUNT(*) OVER(ORDER BY ts ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as cnt2, + SUM(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING) as sumr1, + SUM(desc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING) as sumr2, + SUM(desc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sumr3, + MIN(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as minr1, + MIN(desc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING) as minr2, + MIN(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING) as minr3, + MAX(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING) as maxr1, + MAX(desc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING) as maxr2, + MAX(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING) as maxr3, + COUNT(*) OVER(ORDER BY ts DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) as cntr1, + COUNT(*) OVER(ORDER BY ts DESC ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as cntr2, + SUM(desc_col) OVER(ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as sum4, + COUNT(*) OVER(ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING) as cnt3 + FROM annotated_data + ORDER BY inc_col DESC + LIMIT 5 + "; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, sum3@2 as sum3, min1@3 as min1, min2@4 as min2, min3@5 as min3, max1@6 as max1, max2@7 as max2, max3@8 as max3, cnt1@9 as cnt1, cnt2@10 as cnt2, sumr1@11 as sumr1, sumr2@12 as sumr2, sumr3@13 as sumr3, minr1@14 as minr1, minr2@15 as minr2, minr3@16 as minr3, maxr1@17 as maxr1, maxr2@18 as maxr2, maxr3@19 as maxr3, cntr1@20 as cntr1, cntr2@21 as cntr2, sum4@22 as sum4, cnt3@23 as cnt3]", + " GlobalLimitExec: skip=0, fetch=5", + " SortExec: [inc_col@24 DESC]", + " ProjectionExec: expr=[SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as sum1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@15 as sum2, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as sum3, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as min1, MIN(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@18 as min2, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@19 as min3, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@20 as max1, MAX(annotated_data.desc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@21 as max2, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as max3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@23 as cnt1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@24 as cnt2, SUM(annotated_data.inc_col) 
ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@3 as sumr1, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@4 as sumr2, SUM(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as sumr3, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@6 as minr1, MIN(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@7 as minr2, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as minr3, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as maxr1, MAX(annotated_data.desc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@10 as maxr2, MAX(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@11 as maxr3, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@12 as cntr1, COUNT(UInt8(1)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@13 as cntr2, SUM(annotated_data.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@25 as sum4, COUNT(UInt8(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@26 as cnt3, inc_col@1 as inc_col]", + " BoundedWindowAggExec: wdw=[SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, MIN(annotated_data.desc_col): Ok(Field { name: \"MIN(annotated_data.desc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", 
data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, MAX(annotated_data.inc_col): Ok(Field { name: \"MAX(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, MAX(annotated_data.desc_col): Ok(Field { name: \"MAX(annotated_data.desc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)) }, MAX(annotated_data.inc_col): Ok(Field { name: \"MAX(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(8)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)) }, SUM(annotated_data.desc_col): Ok(Field { name: \"SUM(annotated_data.desc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, MIN(annotated_data.desc_col): Ok(Field { name: \"MIN(annotated_data.desc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, MAX(annotated_data.inc_col): Ok(Field { name: \"MAX(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, MAX(annotated_data.desc_col): Ok(Field { name: \"MAX(annotated_data.desc_col)\", data_type: Int32, nullable: true, dict_id: 0, 
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)) }, MAX(annotated_data.inc_col): Ok(Field { name: \"MAX(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)) }]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------+------+------+------+------+------+------+------+------+------+------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+------+", + "| sum1 | sum2 | sum3 | min1 | min2 | min3 | max1 | max2 | max3 | cnt1 | cnt2 | sumr1 | sumr2 | sumr3 | minr1 | minr2 | minr3 | maxr1 | maxr2 | maxr3 | cntr1 | cntr2 | sum4 | cnt3 |", + "+------+------+------+------+------+------+------+------+------+------+------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+------+", + "| 1482 | -631 | 606 | 289 | -213 | 301 | 305 | -208 | 305 | 3 | 9 | 902 | -834 | -1231 | 301 | -213 | 269 | 305 | -210 | 305 | 3 | 2 | -1797 | 9 |", + "| 1482 | -631 | 902 | 289 | -213 | 296 | 305 | -208 | 305 | 3 | 10 | 902 | -834 | -1424 | 301 | -213 | 266 | 305 | -210 | 305 | 3 | 3 | -1978 | 10 |", + "| 876 | -411 | 1193 | 289 | -208 | 291 | 296 | -203 | 305 | 4 | 10 | 587 | -612 | -1400 | 296 | -213 | 261 | 305 | -208 | 301 | 3 | 4 | -1941 | 10 |", + "| 866 | -404 | 1482 | 286 | -203 | 289 | 291 | -201 | 305 | 5 | 10 | 580 | -600 | -1374 | 291 | -208 | 259 | 305 | -203 | 296 | 4 | 5 | -1903 | 10 |", + "| 1411 | -397 | 1768 | 275 | -201 | 286 | 289 | -196 | 305 | 4 | 10 | 575 | -590 | -1347 | 289 | -203 | 254 | 305 | -201 | 291 | 2 | 6 | -1863 | 10 |", + "+------+------+------+------+------+------+------+------+------+------+------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+-------+------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) + } + + #[tokio::test] + async fn test_source_sorted_builtin() -> Result<()> { + let tmpdir = TempDir::new().unwrap(); + let ctx = get_test_context(&tmpdir).await?; + + let sql = "SELECT + FIRST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv1, + FIRST_VALUE(inc_col) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as fv2, + LAST_VALUE(inc_col) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as lv1, + LAST_VALUE(inc_col) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lv2, + NTH_VALUE(inc_col, 5) OVER(ORDER BY ts RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as nv1, + NTH_VALUE(inc_col, 5) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 
FOLLOWING) as nv2, + ROW_NUMBER() OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS rn1, + ROW_NUMBER() OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as rn2, + RANK() OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS rank1, + RANK() OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as rank2, + DENSE_RANK() OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS dense_rank1, + DENSE_RANK() OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as dense_rank2, + LAG(inc_col, 1, 1001) OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS lag1, + LAG(inc_col, 2, 1002) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lag2, + LEAD(inc_col, -1, 1001) OVER(ORDER BY ts RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS lead1, + LEAD(inc_col, 4, 1004) OVER(ORDER BY ts ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lead2, + FIRST_VALUE(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as fvr1, + FIRST_VALUE(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as fvr2, + LAST_VALUE(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 10 PRECEDING and 1 FOLLOWING) as lvr1, + LAST_VALUE(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lvr2, + LAG(inc_col, 1, 1001) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS lagr1, + LAG(inc_col, 2, 1002) OVER(ORDER BY ts DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lagr2, + LEAD(inc_col, -1, 1001) OVER(ORDER BY ts DESC RANGE BETWEEN 1 PRECEDING and 10 FOLLOWING) AS leadr1, + LEAD(inc_col, 4, 1004) OVER(ORDER BY ts DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as leadr2 + FROM annotated_data + ORDER BY ts DESC + LIMIT 5 + "; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "ProjectionExec: expr=[fv1@0 as fv1, fv2@1 as fv2, lv1@2 as lv1, lv2@3 as lv2, nv1@4 as nv1, nv2@5 as nv2, rn1@6 as rn1, rn2@7 as rn2, rank1@8 as rank1, rank2@9 as rank2, dense_rank1@10 as dense_rank1, dense_rank2@11 as dense_rank2, lag1@12 as lag1, lag2@13 as lag2, lead1@14 as lead1, lead2@15 as lead2, fvr1@16 as fvr1, fvr2@17 as fvr2, lvr1@18 as lvr1, lvr2@19 as lvr2, lagr1@20 as lagr1, lagr2@21 as lagr2, leadr1@22 as leadr1, leadr2@23 as leadr2]", + " GlobalLimitExec: skip=0, fetch=5", + " SortExec: [ts@24 DESC]", + " ProjectionExec: expr=[FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data.inc_col,Int64(5)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data.inc_col,Int64(5)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, ROW_NUMBER() ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, ROW_NUMBER() ORDER BY [annotated_data.ts ASC 
NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, DENSE_RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2, ts@0 as ts]", + " BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: \"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: \"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, NTH_VALUE(annotated_data.inc_col,Int64(5)): Ok(Field { name: \"NTH_VALUE(annotated_data.inc_col,Int64(5))\", 
data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, NTH_VALUE(annotated_data.inc_col,Int64(5)): Ok(Field { name: \"NTH_VALUE(annotated_data.inc_col,Int64(5))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, RANK(): Ok(Field { name: \"RANK()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, RANK(): Ok(Field { name: \"RANK()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, DENSE_RANK(): Ok(Field { name: \"DENSE_RANK()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, DENSE_RANK(): Ok(Field { name: \"DENSE_RANK()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LAG(annotated_data.inc_col,Int64(1),Int64(1001)): Ok(Field { name: \"LAG(annotated_data.inc_col,Int64(1),Int64(1001))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAG(annotated_data.inc_col,Int64(2),Int64(1002)): Ok(Field { name: \"LAG(annotated_data.inc_col,Int64(2),Int64(1002))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)): Ok(Field { name: \"LEAD(annotated_data.inc_col,Int64(-1),Int64(1001))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)): Ok(Field { name: \"LEAD(annotated_data.inc_col,Int64(4),Int64(1004))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }]", + " BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: \"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: 
\"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LAG(annotated_data.inc_col,Int64(1),Int64(1001)): Ok(Field { name: \"LAG(annotated_data.inc_col,Int64(1),Int64(1001))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LAG(annotated_data.inc_col,Int64(2),Int64(1002)): Ok(Field { name: \"LAG(annotated_data.inc_col,Int64(2),Int64(1002))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }, LEAD(annotated_data.inc_col,Int64(-1),Int64(1001)): Ok(Field { name: \"LEAD(annotated_data.inc_col,Int64(-1),Int64(1001))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)) }, LEAD(annotated_data.inc_col,Int64(4),Int64(1004)): Ok(Field { name: \"LEAD(annotated_data.inc_col,Int64(4),Int64(1004))\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)) }]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+-----+-----+-----+-----+-----+-----+-----+-------+-------+-------------+-------------+------+------+-------+-------+------+------+------+------+-------+-------+--------+--------+", + "| fv1 | fv2 | lv1 | lv2 | nv1 | nv2 | rn1 | rn2 | rank1 | rank2 | dense_rank1 | dense_rank2 | lag1 | lag2 | lead1 | lead2 | fvr1 | fvr2 | lvr1 | lvr2 | lagr1 | lagr2 | leadr1 | leadr2 |", + "+-----+-----+-----+-----+-----+-----+-----+-----+-------+-------+-------------+-------------+------+------+-------+-------+------+------+------+------+-------+-------+--------+--------+", + "| 289 | 266 | 305 | 305 | 305 | 278 | 99 | 99 | 99 | 99 | 86 | 86 | 296 | 291 | 296 | 1004 | 305 | 305 | 301 | 296 | 305 | 1002 | 305 | 286 |", + "| 289 | 269 | 305 | 305 | 305 | 283 | 100 | 100 | 99 | 99 | 86 | 86 | 301 | 296 | 301 | 1004 | 305 | 305 | 301 | 301 | 1001 | 1002 | 1001 | 289 |", + "| 289 | 261 | 296 | 301 | | 275 | 98 | 98 | 98 | 98 | 85 | 85 | 291 | 289 | 291 | 1004 | 305 | 305 | 296 | 291 | 301 | 305 | 301 | 283 |", + "| 286 | 259 | 291 | 296 | | 272 | 97 | 97 | 97 | 97 | 84 | 84 | 289 | 286 | 289 | 1004 | 305 | 305 | 291 | 289 | 296 | 301 | 
296 | 278 |",
+            "| 275 | 254 | 289 | 291 | 289 | 269 | 96 | 96 | 96 | 96 | 83 | 83 | 286 | 283 | 286 | 305 | 305 | 305 | 289 | 286 | 291 | 296 | 291 | 275 |",
+            "+-----+-----+-----+-----+-----+-----+-----+-----+-------+-------+-------------+-------------+------+------+-------+-------+------+------+------+------+-------+-------+--------+--------+",
+        ];
+        assert_batches_eq!(expected, &actual);
+        Ok(())
+    }
+}
diff --git a/datafusion/core/tests/window_fuzz.rs b/datafusion/core/tests/window_fuzz.rs
new file mode 100644
index 0000000000000..471484af218d1
--- /dev/null
+++ b/datafusion/core/tests/window_fuzz.rs
@@ -0,0 +1,385 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, Int32Array};
+use arrow::compute::{concat_batches, SortOptions};
+use arrow::record_batch::RecordBatch;
+use arrow::util::pretty::pretty_format_batches;
+use hashbrown::HashMap;
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
+
+use datafusion::physical_plan::collect;
+use datafusion::physical_plan::memory::MemoryExec;
+use datafusion::physical_plan::windows::{
+    create_window_expr, BoundedWindowAggExec, WindowAggExec,
+};
+use datafusion_expr::{
+    AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound,
+    WindowFrameUnits, WindowFunction,
+};
+
+use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_common::ScalarValue;
+use datafusion_physical_expr::expressions::{col, lit};
+use datafusion_physical_expr::PhysicalSortExpr;
+use test_utils::add_empty_batches;
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
+    async fn single_order_by_test() {
+        let n = 100;
+        let distincts = vec![1, 100];
+        for distinct in distincts {
+            let mut handles = Vec::new();
+            for i in 1..n {
+                let job = tokio::spawn(run_window_test(
+                    make_staggered_batches::<true>(1000, distinct, i),
+                    i,
+                    vec!["a"],
+                    vec![],
+                ));
+                handles.push(job);
+            }
+            for job in handles {
+                job.await.unwrap();
+            }
+        }
+    }
+
+    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
+    async fn order_by_with_partition_test() {
+        let n = 100;
+        let distincts = vec![1, 100];
+        for distinct in distincts {
+            // Since the generator emits sorted (a, b) pairs, partitioning by
+            // field "a" and ordering by field "b" does not violate the
+            // per-partition sorting requirement.
+            let mut handles = Vec::new();
+            for i in 1..n {
+                let job = tokio::spawn(run_window_test(
+                    make_staggered_batches::<true>(1000, distinct, i),
+                    i,
+                    vec!["b"],
+                    vec!["a"],
+                ));
+                handles.push(job);
+            }
+            for job in handles {
+                job.await.unwrap();
+            }
+        }
+    }
+}
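For readers new to the harness, here is a minimal sketch (not part of the patch; the seed and column choices are arbitrary) of what a single iteration of the loops above amounts to: generate deterministic pseudo-random batches, then ask run_window_test to compare the batch and running implementations on them.

#[tokio::test]
async fn single_iteration_sketch() {
    // STREAM = true slices the input into small batches, mimicking streaming input.
    let batches = make_staggered_batches::<true>(1000, 100, 42);
    // Order by column "a", no PARTITION BY; the seed (42) deterministically picks
    // the window function and the frame bounds inside run_window_test.
    run_window_test(batches, 42, vec!["a"], vec![]).await;
}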
+
+/// Runs a batch (`WindowAggExec`) and a running (`BoundedWindowAggExec`) window
+/// query on the same input and verifies that their outputs are equal.
+async fn run_window_test(
+    input1: Vec<RecordBatch>,
+    random_seed: u64,
+    orderby_columns: Vec<&str>,
+    partition_by_columns: Vec<&str>,
+) {
+    let mut rng = StdRng::seed_from_u64(random_seed);
+    let schema = input1[0].schema();
+    let mut args = vec![col("x", &schema).unwrap()];
+    let mut window_fn_map = HashMap::new();
+    // Each map value is a tuple: the WindowFunction itself and any additional
+    // arguments it requires. For most window functions the argument list is empty.
+    window_fn_map.insert(
+        "sum",
+        (
+            WindowFunction::AggregateFunction(AggregateFunction::Sum),
+            vec![],
+        ),
+    );
+    window_fn_map.insert(
+        "count",
+        (
+            WindowFunction::AggregateFunction(AggregateFunction::Count),
+            vec![],
+        ),
+    );
+    window_fn_map.insert(
+        "min",
+        (
+            WindowFunction::AggregateFunction(AggregateFunction::Min),
+            vec![],
+        ),
+    );
+    window_fn_map.insert(
+        "max",
+        (
+            WindowFunction::AggregateFunction(AggregateFunction::Max),
+            vec![],
+        ),
+    );
+    window_fn_map.insert(
+        "row_number",
+        (
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber),
+            vec![],
+        ),
+    );
+    window_fn_map.insert(
+        "rank",
+        (
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Rank),
+            vec![],
+        ),
+    );
+    window_fn_map.insert(
+        "first_value",
+        (
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
+            vec![],
+        ),
+    );
+    window_fn_map.insert(
+        "last_value",
+        (
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue),
+            vec![],
+        ),
+    );
+    window_fn_map.insert(
+        "nth_value",
+        (
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::NthValue),
+            vec![lit(ScalarValue::Int64(Some(rng.gen_range(1..10))))],
+        ),
+    );
+    window_fn_map.insert(
+        "lead",
+        (
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
+            vec![
+                lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
+                lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))),
+            ],
+        ),
+    );
+    window_fn_map.insert(
+        "lag",
+        (
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag),
+            vec![
+                lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
+                lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))),
+            ],
+        ),
+    );
+
+    let session_config = SessionConfig::new().with_batch_size(50);
+    let ctx = SessionContext::with_config(session_config);
+    let rand_fn_idx = rng.gen_range(0..window_fn_map.len());
+    let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
+    let (window_fn, new_args) = window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
+    for new_arg in new_args {
+        args.push(new_arg.clone());
+    }
+    let preceding = rng.gen_range(0..50);
+    let following = rng.gen_range(0..50);
+    let rand_num = rng.gen_range(0..3);
+    let units = if rand_num < 1 {
+        WindowFrameUnits::Range
+    } else if rand_num < 2 {
+        WindowFrameUnits::Rows
+    } else {
+        // For now we do not support GROUPS in the BoundedWindowAggExec implementation.
+        // TODO: once GROUPS handling is available, use WindowFrameUnits::GROUPS in randomized tests also.
+        WindowFrameUnits::Range
+    };
+    let window_frame = match units {
+        // In RANGE queries, window frame boundaries should match the ORDER BY column type
+        WindowFrameUnits::Range => WindowFrame {
+            units,
+            start_bound: WindowFrameBound::Preceding(ScalarValue::Int32(Some(preceding))),
+            end_bound: WindowFrameBound::Following(ScalarValue::Int32(Some(following))),
+        },
+        // In ROWS queries, window frame boundaries should be UInt64
+        WindowFrameUnits::Rows => WindowFrame {
+            units,
+            start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(
+                preceding as u64,
+            ))),
+            end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(
+                following as u64,
+            ))),
+        },
+        // Once GROUPS support is added, construct the window frame for this case as well
+        _ => todo!(),
+    };
+    let mut orderby_exprs = vec![];
+    for column in orderby_columns {
+        orderby_exprs.push(PhysicalSortExpr {
+            expr: col(column, &schema).unwrap(),
+            options: SortOptions::default(),
+        })
+    }
+    let mut partitionby_exprs = vec![];
+    for column in partition_by_columns {
+        partitionby_exprs.push(col(column, &schema).unwrap());
+    }
+    let mut sort_keys = vec![];
+    for partition_by_expr in &partitionby_exprs {
+        sort_keys.push(PhysicalSortExpr {
+            expr: partition_by_expr.clone(),
+            options: SortOptions::default(),
+        })
+    }
+    for order_by_expr in &orderby_exprs {
+        sort_keys.push(order_by_expr.clone())
+    }
+
+    let concat_input_record = concat_batches(&schema, &input1).unwrap();
+    let exec1 = Arc::new(
+        MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(), None).unwrap(),
+    );
+    let usual_window_exec = Arc::new(
+        WindowAggExec::try_new(
+            vec![create_window_expr(
+                window_fn,
+                fn_name.to_string(),
+                &args,
+                &partitionby_exprs,
+                &orderby_exprs,
+                Arc::new(window_frame.clone()),
+                schema.as_ref(),
+            )
+            .unwrap()],
+            exec1,
+            schema.clone(),
+            vec![],
+            Some(sort_keys.clone()),
+        )
+        .unwrap(),
+    );
+    let exec2 =
+        Arc::new(MemoryExec::try_new(&[input1.clone()], schema.clone(), None).unwrap());
+    let running_window_exec = Arc::new(
+        BoundedWindowAggExec::try_new(
+            vec![create_window_expr(
+                window_fn,
+                fn_name.to_string(),
+                &args,
+                &partitionby_exprs,
+                &orderby_exprs,
+                Arc::new(window_frame.clone()),
+                schema.as_ref(),
+            )
+            .unwrap()],
+            exec2,
+            schema.clone(),
+            vec![],
+            Some(sort_keys),
+        )
+        .unwrap(),
+    );
+
+    let task_ctx = ctx.task_ctx();
+    let collected_usual = collect(usual_window_exec, task_ctx.clone()).await.unwrap();
+
+    let collected_running = collect(running_window_exec, task_ctx.clone())
+        .await
+        .unwrap();
+    // compare
+    let usual_formatted = pretty_format_batches(&collected_usual).unwrap().to_string();
+    let running_formatted = pretty_format_batches(&collected_running)
+        .unwrap()
+        .to_string();
+
+    let mut usual_formatted_sorted: Vec<&str> = usual_formatted.trim().lines().collect();
+    usual_formatted_sorted.sort_unstable();
+
+    let mut running_formatted_sorted: Vec<&str> =
+        running_formatted.trim().lines().collect();
+    running_formatted_sorted.sort_unstable();
+    for (i, (usual_line, running_line)) in usual_formatted_sorted
+        .iter()
+        .zip(&running_formatted_sorted)
+        .enumerate()
+    {
+        assert_eq!(
+            (i, usual_line),
+            (i, running_line),
+            "Inconsistent result for window_fn: {:?}, args:{:?}",
+            window_fn,
+            args
+        );
+    }
+}
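The tests pass "a" (or "b" partitioned by "a") as sort keys because the generator below emits rows already sorted lexicographically on (a, b), so PARTITION BY a ORDER BY b sees each partition's rows in order. A quick property check of that invariant, as a hedged sketch (the helper name is illustrative, not part of the patch):

// True when `pairs` is sorted lexicographically, i.e. safe to use as
// (PARTITION BY first, ORDER BY second) without an extra sort.
fn is_lex_sorted(pairs: &[(i32, i32)]) -> bool {
    pairs.windows(2).all(|w| w[0] <= w[1])
}

#[test]
fn staggered_pairs_are_sorted() {
    let mut pairs = vec![(1, 5), (1, 7), (2, 0)];
    assert!(is_lex_sorted(&pairs));
    pairs.push((1, 9));
    assert!(!is_lex_sorted(&pairs));
}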
+
+/// Return randomly sized record batches with:
+/// two sorted int32 columns 'a', 'b' ranging over 0..len/distinct, and
+/// two random int32 columns 'x', 'y' as the other columns
+fn make_staggered_batches<const STREAM: bool>(
+    len: usize,
+    distinct: usize,
+    random_seed: u64,
+) -> Vec<RecordBatch> {
+    // use a random number generator to pick a random sized output
+    let mut rng = StdRng::seed_from_u64(random_seed);
+    let mut input12: Vec<(i32, i32)> = vec![(0, 0); len];
+    let mut input3: Vec<i32> = vec![0; len];
+    let mut input4: Vec<i32> = vec![0; len];
+    input12.iter_mut().for_each(|v| {
+        *v = (
+            rng.gen_range(0..(len / distinct)) as i32,
+            rng.gen_range(0..(len / distinct)) as i32,
+        )
+    });
+    rng.fill(&mut input3[..]);
+    rng.fill(&mut input4[..]);
+    input12.sort();
+    let input1 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.0));
+    let input2 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.1));
+    let input3 = Int32Array::from_iter_values(input3.into_iter());
+    let input4 = Int32Array::from_iter_values(input4.into_iter());
+
+    // split into several record batches
+    let mut remainder = RecordBatch::try_from_iter(vec![
+        ("a", Arc::new(input1) as ArrayRef),
+        ("b", Arc::new(input2) as ArrayRef),
+        ("x", Arc::new(input3) as ArrayRef),
+        ("y", Arc::new(input4) as ArrayRef),
+    ])
+    .unwrap();
+
+    let mut batches = vec![];
+    if STREAM {
+        while remainder.num_rows() > 0 {
+            let batch_size = rng.gen_range(0..50);
+            if remainder.num_rows() < batch_size {
+                break;
+            }
+            batches.push(remainder.slice(0, batch_size));
+            remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size);
+        }
+    } else {
+        while remainder.num_rows() > 0 {
+            let batch_size = rng.gen_range(0..remainder.num_rows() + 1);
+            batches.push(remainder.slice(0, batch_size));
+            remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size);
+        }
+    }
+    add_empty_batches(batches, &mut rng)
+}
diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml
index 094d233a90017..5aede03fd2e02 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -51,6 +51,7 @@ datafusion-expr = { path = "../expr", version = "15.0.0" }
 datafusion-row = { path = "../row", version = "15.0.0" }
 half = { version = "2.1", default-features = false }
 hashbrown = { version = "0.13", features = ["raw"] }
+indexmap = "1.9.2"
 itertools = { version = "0.10", features = ["use_std"] }
 lazy_static = { version = "^1.4.0" }
 md-5 = { version = "^0.10.0", optional = true }
diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs
index 813952117af11..8ccf87ac2b1da 100644
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ b/datafusion/physical-expr/src/aggregate/count.rs
@@ -98,6 +98,10 @@ impl AggregateExpr for Count {
         true
     }
 
+    fn supports_bounded_execution(&self) -> bool {
+        true
+    }
+
     fn create_row_accumulator(
         &self,
         start_index: usize,
diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs
index a7bd6c360a904..bf4fd0868a64a 100644
--- a/datafusion/physical-expr/src/aggregate/min_max.rs
+++ b/datafusion/physical-expr/src/aggregate/min_max.rs
@@ -62,7 +62,7 @@ fn min_max_aggregate_data_type(input_type: DataType) -> DataType {
 }
 
 /// MAX aggregate expression
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Max {
     name: String,
     data_type: DataType,
@@ -124,6 +124,10 @@ impl AggregateExpr for Max {
         is_row_accumulator_support_dtype(&self.data_type)
     }
 
+    fn supports_bounded_execution(&self) -> bool {
+        true
+    }
+
     fn create_row_accumulator(
         &self,
         start_index: usize,
@@ -134,6 +138,10 @@ impl AggregateExpr for Max {
         )))
     }
 
+    fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
+        Some(Arc::new(self.clone()))
+    }
+
     fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
         Ok(Box::new(SlidingMaxAccumulator::try_new(&self.data_type)?))
     }
 
@@ -672,7 +680,7 @@ impl RowAccumulator for MaxRowAccumulator {
 }
 
 /// MIN aggregate expression
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Min {
     name: String,
     data_type: DataType,
@@ -734,6 +742,10 @@ impl AggregateExpr for Min {
         is_row_accumulator_support_dtype(&self.data_type)
     }
 
+    fn supports_bounded_execution(&self) -> bool {
+        true
+    }
+
     fn create_row_accumulator(
         &self,
         start_index: usize,
@@ -744,6 +756,10 @@ impl AggregateExpr for Min {
         )))
     }
 
+    fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
+        Some(Arc::new(self.clone()))
+    }
+
     fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
         Ok(Box::new(SlidingMinAccumulator::try_new(&self.data_type)?))
     }
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs
index 528a5cc73f403..c42a5c03b3060 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -88,6 +88,12 @@ pub trait AggregateExpr: Send + Sync + Debug {
         false
     }
 
+    /// Specifies whether this aggregate function can run using bounded memory.
+    /// Any accumulator returning "true" needs to implement `retract_batch`.
+    fn supports_bounded_execution(&self) -> bool {
+        false
+    }
+
     /// RowAccumulator to access/update row-based aggregation state in-place.
     /// Currently, row accumulator only supports states of fixed-sized type.
     ///
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index 5de9a9296bc92..8f78abfd5e919 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -113,6 +113,10 @@ impl AggregateExpr for Sum {
         is_row_accumulator_support_dtype(&self.data_type)
     }
 
+    fn supports_bounded_execution(&self) -> bool {
+        true
+    }
+
     fn create_row_accumulator(
         &self,
         start_index: usize,
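The `retract_batch` requirement mentioned in the trait comment above is what makes bounded execution possible: as the frame slides, the accumulator adds entering rows and retracts exiting ones instead of recomputing the whole frame. A minimal, self-contained sketch of the idea (plain Rust, not DataFusion's actual Accumulator API):

// Toy sliding SUM: O(1) state no matter how long the input is.
struct SlidingSum {
    sum: i64,
}

impl SlidingSum {
    fn update(&mut self, v: i64) {
        self.sum += v;
    }
    // The inverse of `update`; this is the role `retract_batch` plays.
    fn retract(&mut self, v: i64) {
        self.sum -= v;
    }
}

// SUM over "ROWS BETWEEN `preceding` PRECEDING AND `following` FOLLOWING".
fn sliding_sums(data: &[i64], preceding: usize, following: usize) -> Vec<i64> {
    let mut acc = SlidingSum { sum: 0 };
    let (mut start, mut end) = (0usize, 0usize);
    let mut results = Vec::with_capacity(data.len());
    for idx in 0..data.len() {
        let new_start = idx.saturating_sub(preceding);
        let new_end = (idx + following + 1).min(data.len());
        while end < new_end {
            acc.update(data[end]); // row enters the frame
            end += 1;
        }
        while start < new_start {
            acc.retract(data[start]); // row leaves the frame
            start += 1;
        }
        results.push(acc.sum);
    }
    results
}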
+        self.aggregate.supports_bounded_execution()
+            && !self.window_frame.start_bound.is_unbounded()
+            && !self.window_frame.end_bound.is_unbounded()
+            && !matches!(self.window_frame.units, WindowFrameUnits::Groups)
+    }
 }
diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs
index 9804432b2056d..f0484b790fbc6 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -20,14 +20,19 @@
 use super::window_frame_state::WindowFrameContext;
 use super::BuiltInWindowFunctionExpr;
 use super::WindowExpr;
-use crate::window::window_expr::reverse_order_bys;
+use crate::window::window_expr::{
+    reverse_order_bys, BuiltinWindowState, WindowFn, WindowFunctionState,
+};
+use crate::window::{
+    PartitionBatches, PartitionWindowAggStates, WindowAggState, WindowState,
+};
 use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
-use arrow::compute::SortOptions;
+use arrow::compute::{concat, SortOptions};
 use arrow::record_batch::RecordBatch;
 use arrow::{array::ArrayRef, datatypes::Field};
-use datafusion_common::Result;
 use datafusion_common::ScalarValue;
-use datafusion_expr::WindowFrame;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::{WindowFrame, WindowFrameUnits};
 use std::any::Any;
 use std::sync::Arc;
 
@@ -91,7 +96,7 @@ impl WindowExpr for BuiltInWindowExpr {
     fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
         let evaluator = self.expr.create_evaluator()?;
         let num_rows = batch.num_rows();
-        if evaluator.uses_window_frame() {
+        if self.expr.uses_window_frame() {
             let sort_options: Vec<SortOptions> =
                 self.order_by.iter().map(|o| o.options).collect();
             let mut row_wise_results = vec![];
@@ -122,6 +127,102 @@ impl WindowExpr for BuiltInWindowExpr {
         }
     }
 
+    /// Evaluate the window function against the batch. This function facilitates
+    /// stateful, bounded-memory implementations.
+    fn evaluate_stateful(
+        &self,
+        partition_batches: &PartitionBatches,
+        window_agg_state: &mut PartitionWindowAggStates,
+    ) -> Result<()> {
+        let field = self.expr.field()?;
+        let out_type = field.data_type();
+        let sort_options = self.order_by.iter().map(|o| o.options).collect::<Vec<_>>();
+        for (partition_row, partition_batch_state) in partition_batches.iter() {
+            if !window_agg_state.contains_key(partition_row) {
+                let evaluator = self.expr.create_evaluator()?;
+                window_agg_state.insert(
+                    partition_row.clone(),
+                    WindowState {
+                        state: WindowAggState::new(
+                            out_type,
+                            WindowFunctionState::BuiltinWindowState(
+                                BuiltinWindowState::Default,
+                            ),
+                        )?,
+                        window_fn: WindowFn::Builtin(evaluator),
+                    },
+                );
+            };
+            let window_state =
+                window_agg_state.get_mut(partition_row).ok_or_else(|| {
+                    DataFusionError::Execution("Cannot find state".to_string())
+                })?;
+            let evaluator = match &mut window_state.window_fn {
+                WindowFn::Builtin(evaluator) => evaluator,
+                _ => unreachable!(),
+            };
+            let mut state = &mut window_state.state;
+            state.is_end = partition_batch_state.is_end;
+
+            let (values, order_bys) =
+                self.get_values_orderbys(&partition_batch_state.record_batch)?;
+
+            // We iterate on each row to perform a running calculation.
+            let num_rows = partition_batch_state.record_batch.num_rows();
+            let mut last_range = state.window_frame_range.clone();
+            let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
+            let sort_partition_points = if evaluator.include_rank() {
+                let columns = self.sort_columns(&partition_batch_state.record_batch)?;
+                self.evaluate_partition_points(num_rows, &columns)?
+            } else {
+                vec![]
+            };
+            let mut row_wise_results: Vec<ScalarValue> = vec![];
+            for idx in state.last_calculated_index..num_rows {
+                state.window_frame_range = if self.expr.uses_window_frame() {
+                    window_frame_ctx.calculate_range(
+                        &order_bys,
+                        &sort_options,
+                        num_rows,
+                        idx,
+                    )
+                } else {
+                    evaluator.get_range(state, num_rows)
+                }?;
+                evaluator.update_state(state, &order_bys, &sort_partition_points)?;
+
+                // If the frame reaches the end of the rows we currently have, but this
+                // is not the partition's final batch, stop and wait for more input.
+                if state.window_frame_range.end == num_rows
+                    && !partition_batch_state.is_end
+                {
+                    state.window_frame_range = last_range.clone();
+                    break;
+                }
+                let frame_range = &state.window_frame_range;
+                row_wise_results.push(if frame_range.start == frame_range.end {
+                    // We produce None if the window is empty.
+                    ScalarValue::try_from(out_type)
+                } else {
+                    evaluator.evaluate_stateful(&values)
+                }?);
+                last_range = frame_range.clone();
+                state.last_calculated_index = idx + 1;
+            }
+            state.window_frame_range = last_range;
+            let out_col = if row_wise_results.is_empty() {
+                ScalarValue::try_from(out_type)?.to_array_of_size(0)
+            } else {
+                ScalarValue::iter_to_array(row_wise_results.into_iter())?
+            };
+
+            state.out_col = concat(&[&state.out_col, &out_col])?;
+            state.n_row_result_missing = num_rows - state.last_calculated_index;
+            state.window_function_state =
+                WindowFunctionState::BuiltinWindowState(evaluator.state()?);
+        }
+        Ok(())
+    }
+
     fn get_window_frame(&self) -> &Arc<WindowFrame> {
         &self.window_frame
     }
@@ -136,4 +237,13 @@ impl WindowExpr for BuiltInWindowExpr {
             )) as _
         })
     }
+
+    fn uses_bounded_memory(&self) -> bool {
+        // NOTE: Currently, groups queries do not support the bounded memory variant.
+        self.expr.supports_bounded_execution()
+            && (!self.expr.uses_window_frame()
+                || !(self.window_frame.start_bound.is_unbounded()
+                    || self.window_frame.end_bound.is_unbounded()
+                    || matches!(self.window_frame.units, WindowFrameUnits::Groups)))
+    }
 }
diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
index c358403fefdac..6f41ec599a83c 100644
--- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
+++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
@@ -64,4 +64,12 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug {
     fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
         None
     }
+
+    fn supports_bounded_execution(&self) -> bool {
+        false
+    }
+
+    fn uses_window_frame(&self) -> bool {
+        false
+    }
 }
diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs
index 45fe51178afcd..3abb91e06f6a2 100644
--- a/datafusion/physical-expr/src/window/cume_dist.rs
+++ b/datafusion/physical-expr/src/window/cume_dist.rs
@@ -66,6 +66,7 @@ impl BuiltInWindowFunctionExpr for CumeDist {
     }
 }
 
+#[derive(Debug)]
 pub(crate) struct CumeDistEvaluator;
 
 impl PartitionEvaluator for CumeDistEvaluator {
diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs
index e18815c4c3a62..fc815a220af0e 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/physical-expr/src/window/lead_lag.rs
@@ -19,7 +19,8 @@
 //! at runtime during query execution
 
 use crate::window::partition_evaluator::PartitionEvaluator;
-use crate::window::BuiltInWindowFunctionExpr;
+use crate::window::window_expr::{BuiltinWindowState, LeadLagState};
+use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
 use crate::PhysicalExpr;
 use arrow::array::ArrayRef;
 use arrow::compute::cast;
@@ -27,7 +28,8 @@ use arrow::datatypes::{DataType, Field};
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
 use std::any::Any;
-use std::ops::Neg;
+use std::cmp::min;
+use std::ops::{Neg, Range};
 use std::sync::Arc;
 
 /// window shift expression
@@ -102,11 +104,16 @@ impl BuiltInWindowFunctionExpr for WindowShift {
 
     fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
         Ok(Box::new(WindowShiftEvaluator {
+            state: LeadLagState { idx: 0 },
             shift_offset: self.shift_offset,
             default_value: self.default_value.clone(),
         }))
     }
 
+    fn supports_bounded_execution(&self) -> bool {
+        true
+    }
+
     fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
         Some(Arc::new(Self {
             name: self.name.clone(),
@@ -118,7 +125,9 @@ impl BuiltInWindowFunctionExpr for WindowShift {
     }
 }
 
+#[derive(Debug)]
 pub(crate) struct WindowShiftEvaluator {
+    state: LeadLagState,
     shift_offset: i64,
     default_value: Option<ScalarValue>,
 }
@@ -173,6 +182,54 @@ fn shift_with_default_value(
 }
 
 impl PartitionEvaluator for WindowShiftEvaluator {
+    fn state(&self) -> Result<BuiltinWindowState> {
+        // Return the LEAD/LAG-specific state
+        Ok(BuiltinWindowState::LeadLag(self.state.clone()))
+    }
+
+    fn update_state(
+        &mut self,
+        state: &WindowAggState,
+        _range_columns: &[ArrayRef],
+        _sort_partition_points: &[Range<usize>],
+    ) -> Result<()> {
+        self.state.idx = state.last_calculated_index;
+        Ok(())
+    }
+
+    fn get_range(&self, state: &WindowAggState, n_rows: usize) -> Result<Range<usize>> {
+        if self.shift_offset > 0 {
+            let offset = self.shift_offset as usize;
+            let start = if state.last_calculated_index > offset {
+                state.last_calculated_index - offset
+            } else {
+                0
+            };
+            Ok(Range {
+                start,
+                end: state.last_calculated_index + 1,
+            })
+        } else {
+            let offset = (-self.shift_offset) as usize;
+            let end = min(state.last_calculated_index + offset, n_rows);
+            Ok(Range {
+                start: state.last_calculated_index,
+                end,
+            })
+        }
+    }
+
+    fn evaluate_stateful(&mut self, values: &[ArrayRef]) -> Result<ScalarValue> {
+        let array = &values[0];
+        let dtype = array.data_type();
+        let idx = self.state.idx as i64 - self.shift_offset;
+        if idx < 0 || idx as usize >= array.len() {
+            get_default_value(&self.default_value, dtype)
+        } else {
+            ScalarValue::try_from_array(array, idx as usize)
+        }
+    }
+
     fn evaluate(&self, values: &[ArrayRef], _num_rows: usize) -> Result<ArrayRef> {
         // LEAD, LAG window functions take single column, values will have size 1
         let value = &values[0];
@@ -180,6 +237,23 @@ impl PartitionEvaluator for WindowShiftEvaluator {
     }
 }
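The `get_range` logic above is easy to sanity-check in isolation: LAG(k) at row i only needs row i - k, and LEAD(k) only needs row i + k, clamped to the partition. A standalone sketch of the same arithmetic (the function name is illustrative, not part of the patch):

// (start, end) of the input slice needed at `last_idx`; positive offsets look
// back (LAG), non-positive offsets look ahead (LEAD), mirroring WindowShiftEvaluator.
fn shift_range(last_idx: usize, shift_offset: i64, n_rows: usize) -> (usize, usize) {
    if shift_offset > 0 {
        let start = last_idx.saturating_sub(shift_offset as usize);
        (start, last_idx + 1)
    } else {
        let end = (last_idx + (-shift_offset) as usize).min(n_rows);
        (last_idx, end)
    }
}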
+fn get_default_value(
+    default_value: &Option<ScalarValue>,
+    dtype: &DataType,
+) -> Result<ScalarValue> {
+    if let Some(value) = default_value {
+        if let ScalarValue::Int64(Some(val)) = value {
+            ScalarValue::try_from_string(val.to_string(), dtype)
+        } else {
+            Err(DataFusionError::Internal(
+                "Expects default value to have Int64 type".to_string(),
+            ))
+        }
+    } else {
+        Ok(ScalarValue::try_from(dtype)?)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs
index ffbce598e6ad0..35036a6dbeb4c 100644
--- a/datafusion/physical-expr/src/window/mod.rs
+++ b/datafusion/physical-expr/src/window/mod.rs
@@ -33,4 +33,10 @@ pub use aggregate::AggregateWindowExpr;
 pub use built_in::BuiltInWindowExpr;
 pub use built_in_window_function_expr::BuiltInWindowFunctionExpr;
 pub use sliding_aggregate::SlidingAggregateWindowExpr;
+pub use window_expr::PartitionBatchState;
+pub use window_expr::PartitionBatches;
+pub use window_expr::PartitionKey;
+pub use window_expr::PartitionWindowAggStates;
+pub use window_expr::WindowAggState;
 pub use window_expr::WindowExpr;
+pub use window_expr::WindowState;
diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs
index e998b47018a52..c3c3b55d4e88f 100644
--- a/datafusion/physical-expr/src/window/nth_value.rs
+++ b/datafusion/physical-expr/src/window/nth_value.rs
@@ -19,7 +19,8 @@
 //! that can evaluated at runtime during query execution
 
 use crate::window::partition_evaluator::PartitionEvaluator;
-use crate::window::BuiltInWindowFunctionExpr;
+use crate::window::window_expr::{BuiltinWindowState, NthValueState};
+use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
 use crate::PhysicalExpr;
 use arrow::array::{Array, ArrayRef};
 use arrow::datatypes::{DataType, Field};
@@ -121,7 +122,18 @@ impl BuiltInWindowFunctionExpr for NthValue {
     }
 
     fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
-        Ok(Box::new(NthValueEvaluator { kind: self.kind }))
+        Ok(Box::new(NthValueEvaluator {
+            state: NthValueState::default(),
+            kind: self.kind,
+        }))
+    }
+
+    fn supports_bounded_execution(&self) -> bool {
+        true
+    }
+
+    fn uses_window_frame(&self) -> bool {
+        true
     }
 
     fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
@@ -140,13 +152,31 @@
 }
 
 /// Value evaluator for nth_value functions
+#[derive(Debug)]
 pub(crate) struct NthValueEvaluator {
+    state: NthValueState,
     kind: NthValueKind,
 }
 
 impl PartitionEvaluator for NthValueEvaluator {
-    fn uses_window_frame(&self) -> bool {
-        true
+    fn state(&self) -> Result<BuiltinWindowState> {
+        // Return the NTH_VALUE-specific state
+        Ok(BuiltinWindowState::NthValue(self.state.clone()))
+    }
+
+    fn update_state(
+        &mut self,
+        state: &WindowAggState,
+        _range_columns: &[ArrayRef],
+        _sort_partition_points: &[Range<usize>],
+    ) -> Result<()> {
+        // Store the current frame range so evaluate_stateful can use it
+        self.state.range = state.window_frame_range.clone();
+        Ok(())
+    }
+
+    fn evaluate_stateful(&mut self, values: &[ArrayRef]) -> Result<ScalarValue> {
+        self.evaluate_inside_range(values, self.state.range.clone())
     }
 
     fn evaluate_inside_range(
diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs
index f5844eccc63a8..b8365dba19d01 100644
--- a/datafusion/physical-expr/src/window/ntile.rs
+++ b/datafusion/physical-expr/src/window/ntile.rs
@@ -64,6 +64,7 @@ impl BuiltInWindowFunctionExpr for Ntile {
     }
 }
 
+#[derive(Debug)]
 pub(crate) struct NtileEvaluator {
     n: u64,
 }
diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs
index 86500441df5bc..e6cead76d13d2 100644
--- a/datafusion/physical-expr/src/window/partition_evaluator.rs
+++ b/datafusion/physical-expr/src/window/partition_evaluator.rs
@@ -17,26 +17,54 @@
partition evaluation module +use crate::window::window_expr::BuiltinWindowState; +use crate::window::WindowAggState; use arrow::array::ArrayRef; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; +use std::fmt::Debug; use std::ops::Range; /// Partition evaluator -pub trait PartitionEvaluator { +pub trait PartitionEvaluator: Debug + Send { /// Whether the evaluator should be evaluated with rank fn include_rank(&self) -> bool { false } - fn uses_window_frame(&self) -> bool { - false + /// Returns state of the Built-in Window Function + fn state(&self) -> Result { + // If we do not use state we just return Default + Ok(BuiltinWindowState::Default) + } + + fn update_state( + &mut self, + _state: &WindowAggState, + _range_columns: &[ArrayRef], + _sort_partition_points: &[Range], + ) -> Result<()> { + // If we do not use state, update_state does nothing + Ok(()) + } + + fn get_range(&self, _state: &WindowAggState, _n_rows: usize) -> Result> { + Err(DataFusionError::NotImplemented( + "get_range is not implemented for this window function".to_string(), + )) } /// evaluate the partition evaluator against the partition fn evaluate(&self, _values: &[ArrayRef], _num_rows: usize) -> Result { Err(DataFusionError::NotImplemented( - "evaluate_partition is not implemented by default".into(), + "evaluate is not implemented by default".into(), + )) + } + + /// evaluate window function result inside given range + fn evaluate_stateful(&mut self, _values: &[ArrayRef]) -> Result { + Err(DataFusionError::NotImplemented( + "evaluate_stateful is not implemented by default".into(), )) } diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 87e01528de5a8..ead9d44535ba1 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -19,12 +19,13 @@ //! 
at runtime during query execution use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::BuiltInWindowFunctionExpr; +use crate::window::window_expr::{BuiltinWindowState, RankState}; +use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; use crate::PhysicalExpr; use arrow::array::ArrayRef; use arrow::array::{Float64Array, UInt64Array}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::Result; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use std::any::Any; use std::iter; use std::ops::Range; @@ -98,18 +99,77 @@ impl BuiltInWindowFunctionExpr for Rank { &self.name } + fn supports_bounded_execution(&self) -> bool { + matches!(self.rank_type, RankType::Basic | RankType::Dense) + } + fn create_evaluator(&self) -> Result> { Ok(Box::new(RankEvaluator { + state: RankState::default(), rank_type: self.rank_type, })) } } +#[derive(Debug)] pub(crate) struct RankEvaluator { + state: RankState, rank_type: RankType, } impl PartitionEvaluator for RankEvaluator { + fn get_range(&self, state: &WindowAggState, _n_rows: usize) -> Result> { + Ok(Range { + start: state.last_calculated_index, + end: state.last_calculated_index + 1, + }) + } + + fn state(&self) -> Result { + Ok(BuiltinWindowState::Rank(self.state.clone())) + } + + fn update_state( + &mut self, + state: &WindowAggState, + range_columns: &[ArrayRef], + sort_partition_points: &[Range], + ) -> Result<()> { + // find range inside `sort_partition_points` containing `state.last_calculated_index` + let chunk_idx = sort_partition_points + .iter() + .position(|elem| { + elem.start <= state.last_calculated_index + && state.last_calculated_index < elem.end + }) + .ok_or_else(|| DataFusionError::Execution("Expects sort_partition_points to contain state.last_calculated_index".to_string()))?; + let chunk = &sort_partition_points[chunk_idx]; + let last_rank_data = range_columns + .iter() + .map(|c| ScalarValue::try_from_array(c, chunk.end - 1)) + .collect::>>()?; + let empty = self.state.last_rank_data.is_empty(); + if empty || self.state.last_rank_data != last_rank_data { + self.state.last_rank_data = last_rank_data; + self.state.last_rank_boundary = state.offset_pruned_rows + chunk.start; + self.state.n_rank = 1 + if empty { chunk_idx } else { self.state.n_rank }; + } + Ok(()) + } + + /// evaluate window function result inside given range + fn evaluate_stateful(&mut self, _values: &[ArrayRef]) -> Result { + match self.rank_type { + RankType::Basic => Ok(ScalarValue::UInt64(Some( + self.state.last_rank_boundary as u64 + 1, + ))), + RankType::Dense => Ok(ScalarValue::UInt64(Some(self.state.n_rank as u64))), + RankType::Percent => Err(DataFusionError::Execution( + "Can not execute PERCENT_RANK in a streaming fashion".to_string(), + )), + } + } + fn include_rank(&self) -> bool { true } diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index b27ac29d27640..c858a5724a202 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -18,12 +18,14 @@ //! 
Defines physical expression for `row_number` that can evaluated at runtime during query execution use crate::window::partition_evaluator::PartitionEvaluator; -use crate::window::BuiltInWindowFunctionExpr; +use crate::window::window_expr::{BuiltinWindowState, NumRowsState}; +use crate::window::{BuiltInWindowFunctionExpr, WindowAggState}; use crate::PhysicalExpr; use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::Result; +use datafusion_common::{Result, ScalarValue}; use std::any::Any; +use std::ops::Range; use std::sync::Arc; /// row_number expression @@ -62,12 +64,36 @@ impl BuiltInWindowFunctionExpr for RowNumber { fn create_evaluator(&self) -> Result> { Ok(Box::::default()) } + + fn supports_bounded_execution(&self) -> bool { + true + } } -#[derive(Default)] -pub(crate) struct NumRowsEvaluator {} +#[derive(Default, Debug)] +pub(crate) struct NumRowsEvaluator { + state: NumRowsState, +} impl PartitionEvaluator for NumRowsEvaluator { + fn state(&self) -> Result { + // If we do not use state we just return Default + Ok(BuiltinWindowState::NumRows(self.state.clone())) + } + + fn get_range(&self, state: &WindowAggState, _n_rows: usize) -> Result> { + Ok(Range { + start: state.last_calculated_index, + end: state.last_calculated_index + 1, + }) + } + + /// evaluate window function result inside given range + fn evaluate_stateful(&mut self, _values: &[ArrayRef]) -> Result { + self.state.n_rows += 1; + Ok(ScalarValue::UInt64(Some(self.state.n_rows as u64))) + } + fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result { Ok(Arc::new(UInt64Array::from_iter_values( 1..(num_rows as u64) + 1, diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 2a0fa86b7fe33..587c313e31bd7 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -23,16 +23,19 @@ use std::ops::Range; use std::sync::Arc; use arrow::array::Array; -use arrow::compute::SortOptions; +use arrow::compute::{concat, SortOptions}; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::WindowFrame; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::{Accumulator, WindowFrame, WindowFrameUnits}; -use crate::window::window_expr::reverse_order_bys; -use crate::window::AggregateWindowExpr; +use crate::window::window_expr::{reverse_order_bys, WindowFn, WindowFunctionState}; +use crate::window::{ + AggregateWindowExpr, PartitionBatches, PartitionWindowAggStates, WindowAggState, + WindowState, +}; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use crate::{window::WindowExpr, AggregateExpr}; @@ -92,50 +95,75 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let sort_options: Vec = - self.order_by.iter().map(|o| o.options).collect(); - let mut row_wise_results: Vec = vec![]; - let mut accumulator = self.aggregate.create_sliding_accumulator()?; - let length = batch.num_rows(); - let (values, order_bys) = self.get_values_orderbys(batch)?; let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); let mut last_range = Range { start: 0, end: 0 }; + let mut idx = 0; + self.get_result_column( + &mut accumulator, + batch, + &mut window_frame_ctx, + &mut last_range, + &mut idx, + true, + ) + } - // We iterate on each 
row to perform a running calculation. - // First, cur_range is calculated, then it is compared with last_range. - for i in 0..length { - let cur_range = - window_frame_ctx.calculate_range(&order_bys, &sort_options, length, i)?; - let value = if cur_range.start == cur_range.end { - // We produce None if the window is empty. - ScalarValue::try_from(self.aggregate.field()?.data_type())? - } else { - // Accumulate any new rows that have entered the window: - let update_bound = cur_range.end - last_range.end; - if update_bound > 0 { - let update: Vec = values - .iter() - .map(|v| v.slice(last_range.end, update_bound)) - .collect(); - accumulator.update_batch(&update)? - } - // Remove rows that have now left the window: - let retract_bound = cur_range.start - last_range.start; - if retract_bound > 0 { - let retract: Vec = values - .iter() - .map(|v| v.slice(last_range.start, retract_bound)) - .collect(); - accumulator.retract_batch(&retract)? - } - accumulator.evaluate()? + fn evaluate_stateful( + &self, + partition_batches: &PartitionBatches, + window_agg_state: &mut PartitionWindowAggStates, + ) -> Result<()> { + let field = self.aggregate.field()?; + let out_type = field.data_type(); + for (partition_row, partition_batch_state) in partition_batches.iter() { + if !window_agg_state.contains_key(partition_row) { + let accumulator = self.aggregate.create_sliding_accumulator()?; + window_agg_state.insert( + partition_row.clone(), + WindowState { + state: WindowAggState::new( + out_type, + WindowFunctionState::AggregateState(vec![]), + )?, + window_fn: WindowFn::Aggregate(accumulator), + }, + ); }; - row_wise_results.push(value); - last_range = cur_range; + let window_state = + window_agg_state.get_mut(partition_row).ok_or_else(|| { + DataFusionError::Execution("Cannot find state".to_string()) + })?; + let accumulator = match &mut window_state.window_fn { + WindowFn::Aggregate(accumulator) => accumulator, + _ => unreachable!(), + }; + let mut state = &mut window_state.state; + state.is_end = partition_batch_state.is_end; + + let mut idx = state.last_calculated_index; + let mut last_range = state.window_frame_range.clone(); + let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + let out_col = self.get_result_column( + accumulator, + &partition_batch_state.record_batch, + &mut window_frame_ctx, + &mut last_range, + &mut idx, + state.is_end, + )?; + state.last_calculated_index = idx; + state.window_frame_range = last_range.clone(); + + state.out_col = concat(&[&state.out_col, &out_col])?; + let num_rows = partition_batch_state.record_batch.num_rows(); + state.n_row_result_missing = num_rows - state.last_calculated_index; + + state.window_function_state = + WindowFunctionState::AggregateState(accumulator.state()?); } - ScalarValue::iter_to_array(row_wise_results.into_iter()) + Ok(()) } fn partition_by(&self) -> &[Arc] { @@ -170,4 +198,96 @@ impl WindowExpr for SlidingAggregateWindowExpr { } }) } + + fn uses_bounded_memory(&self) -> bool { + // NOTE: Currently, groups queries do not support the bounded memory variant. 
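+        // In addition, the aggregate itself must support incremental (sliding)
+        // evaluation, and the window frame must be finite on both ends: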
+ self.aggregate.supports_bounded_execution() + && !self.window_frame.start_bound.is_unbounded() + && !self.window_frame.end_bound.is_unbounded() + && !matches!(self.window_frame.units, WindowFrameUnits::Groups) + } +} + +impl SlidingAggregateWindowExpr { + /// For given range calculate accumulator result inside range on value_slice and + /// update accumulator state + fn get_aggregate_result_inside_range( + &self, + last_range: &Range, + cur_range: &Range, + value_slice: &[ArrayRef], + accumulator: &mut Box, + ) -> Result { + let value = if cur_range.start == cur_range.end { + // We produce None if the window is empty. + ScalarValue::try_from(self.aggregate.field()?.data_type())? + } else { + // Accumulate any new rows that have entered the window: + let update_bound = cur_range.end - last_range.end; + if update_bound > 0 { + let update: Vec = value_slice + .iter() + .map(|v| v.slice(last_range.end, update_bound)) + .collect(); + accumulator.update_batch(&update)? + } + // Remove rows that have now left the window: + let retract_bound = cur_range.start - last_range.start; + if retract_bound > 0 { + let retract: Vec = value_slice + .iter() + .map(|v| v.slice(last_range.start, retract_bound)) + .collect(); + accumulator.retract_batch(&retract)? + } + accumulator.evaluate()? + }; + Ok(value) + } + + fn get_result_column( + &self, + accumulator: &mut Box, + record_batch: &RecordBatch, + window_frame_ctx: &mut WindowFrameContext, + last_range: &mut Range, + idx: &mut usize, + is_end: bool, + ) -> Result { + let (values, order_bys) = self.get_values_orderbys(record_batch)?; + // We iterate on each row to perform a running calculation. + let length = values[0].len(); + let sort_options: Vec = + self.order_by.iter().map(|o| o.options).collect(); + let mut row_wise_results: Vec = vec![]; + let field = self.aggregate.field()?; + let out_type = field.data_type(); + while *idx < length { + let cur_range = window_frame_ctx.calculate_range( + &order_bys, + &sort_options, + length, + *idx, + )?; + // Exit if range end index is length, need kind of flag to stop + if cur_range.end == length && !is_end { + break; + } + let value = self.get_aggregate_result_inside_range( + last_range, + &cur_range, + &values, + accumulator, + )?; + row_wise_results.push(value); + last_range.start = cur_range.start; + last_range.end = cur_range.end; + *idx += 1; + } + Ok(if row_wise_results.is_empty() { + ScalarValue::try_from(out_type)?.to_array_of_size(0) + } else { + ScalarValue::iter_to_array(row_wise_results.into_iter())? + }) + } } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index a718fa4cd3b36..656b6723b0d69 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. 
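The update/retract loop factored into `get_aggregate_result_inside_range` above is the core of sliding-window aggregation: as the frame advances, only rows entering the frame are accumulated and only rows leaving it are retracted, so each input row is fed to the accumulator at most twice. Below is a minimal self-contained sketch of the same idea; `SlidingSum` is a hypothetical stand-in for a DataFusion `Accumulator` with `retract_batch` support, not the crate's actual API.

```rust
// Minimal sketch of the update/retract pattern, assuming a toy accumulator.
struct SlidingSum {
    sum: i64,
}

impl SlidingSum {
    /// Accumulate rows that just entered the frame.
    fn update_batch(&mut self, values: &[i64]) {
        self.sum += values.iter().sum::<i64>();
    }
    /// Remove rows that just left the frame.
    fn retract_batch(&mut self, values: &[i64]) {
        self.sum -= values.iter().sum::<i64>();
    }
}

fn main() {
    let values = [1i64, 2, 3, 4, 5];
    let mut acc = SlidingSum { sum: 0 };
    let mut last = 0..0;
    // Emulate `ROWS BETWEEN 1 PRECEDING AND CURRENT ROW`.
    for i in 0..values.len() {
        let cur = i.saturating_sub(1)..i + 1;
        // Feed only the delta between the previous frame and the current one.
        acc.update_batch(&values[last.end..cur.end]);
        acc.retract_batch(&values[last.start..cur.start]);
        println!("row {i}: frame {cur:?} -> sum {}", acc.sum);
        last = cur;
    }
}
```

Running this prints the running two-row sums 1, 3, 5, 7, 9, matching what the per-row evaluate path would produce while touching each value a constant number of times.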
+use crate::window::partition_evaluator::PartitionEvaluator; use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::compute::kernels::partition::lexicographical_partition_ranges; use arrow::compute::kernels::sort::SortColumn; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::{reverse_sort_options, DataFusionError, Result}; -use datafusion_expr::WindowFrame; +use arrow_schema::DataType; +use datafusion_common::{reverse_sort_options, DataFusionError, Result, ScalarValue}; +use datafusion_expr::{Accumulator, WindowFrame}; +use indexmap::IndexMap; use std::any::Any; use std::fmt::Debug; use std::ops::Range; @@ -61,6 +64,18 @@ pub trait WindowExpr: Send + Sync + Debug { /// evaluate the window function values against the batch fn evaluate(&self, batch: &RecordBatch) -> Result; + /// evaluate the window function values against the batch + fn evaluate_stateful( + &self, + _partition_batches: &PartitionBatches, + _window_agg_state: &mut PartitionWindowAggStates, + ) -> Result<()> { + Err(DataFusionError::Internal(format!( + "evaluate_stateful is not implemented for {}", + self.name() + ))) + } + /// evaluate the partition points given the sort columns; if the sort columns are /// empty then the result will be a single element vec of the whole column rows. fn evaluate_partition_points( @@ -116,6 +131,10 @@ pub trait WindowExpr: Send + Sync + Debug { /// Get the window frame of this [WindowExpr]. fn get_window_frame(&self) -> &Arc; + /// Return a flag indicating whether this [WindowExpr] can run with + /// bounded memory. + fn uses_bounded_memory(&self) -> bool; + /// Get the reverse expression of this [WindowExpr]. fn get_reverse_expr(&self) -> Option>; } @@ -132,3 +151,118 @@ pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec), + Aggregate(Box), +} + +/// State for RANK(percent_rank, rank, dense_rank) +/// builtin window function +#[derive(Debug, Clone, Default)] +pub struct RankState { + /// The last values for rank as these values change, we increase n_rank + pub last_rank_data: Vec, + /// The index where last_rank_boundary is started + pub last_rank_boundary: usize, + /// Rank number kept from the start + pub n_rank: usize, +} + +/// State for 'ROW_NUMBER' builtin window function +#[derive(Debug, Clone, Default)] +pub struct NumRowsState { + pub n_rows: usize, +} + +#[derive(Debug, Clone, Default)] +pub struct NthValueState { + pub range: Range, +} + +#[derive(Debug, Clone, Default)] +pub struct LeadLagState { + pub idx: usize, +} + +#[derive(Debug, Clone, Default)] +pub enum BuiltinWindowState { + Rank(RankState), + NumRows(NumRowsState), + NthValue(NthValueState), + LeadLag(LeadLagState), + #[default] + Default, +} +#[derive(Debug)] +pub enum WindowFunctionState { + /// Different Aggregate functions may have different state definitions + /// In [Accumulator] trait, [fn state(&self) -> Result>] implementation + /// dictates that. + AggregateState(Vec), + /// BuiltinWindowState + BuiltinWindowState(BuiltinWindowState), +} + +#[derive(Debug)] +pub struct WindowAggState { + /// The range that we calculate the window function + pub window_frame_range: Range, + /// The index of the last row that its result is calculated inside the partition record batch buffer. 
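+    /// (Results are complete for rows `0..last_calculated_index`; stateful
+    /// execution resumes from this index when the next batch arrives.)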
+    pub last_calculated_index: usize,
+    /// Number of rows that have been pruned (removed) from the start of the
+    /// partition's record batch buffer
+    pub offset_pruned_rows: usize,
+    /// State of the window function, required to calculate its result
+    // For instance, for ROW_NUMBER we keep the row index counter to generate the correct result
+    pub window_function_state: WindowFunctionState,
+    /// Stores the results calculated by the window frame so far
+    pub out_col: ArrayRef,
+    /// Keeps track of how many rows should still be generated to stay in sync with the input record_batch
+    // (for each row in the input record batch we need to generate a window result)
+    pub n_row_result_missing: usize,
+    /// Flag indicating whether we have received all data for this partition
+    pub is_end: bool,
+}
+
+/// State for each unique partition determined according to the PARTITION BY column(s)
+#[derive(Debug)]
+pub struct PartitionBatchState {
+    /// The record batch belonging to the current partition
+    pub record_batch: RecordBatch,
+    /// Flag indicating whether we have received all data for this partition
+    pub is_end: bool,
+}
+
+/// Key of the IndexMap entry for each unique partition.
+/// For instance, if the window definition is OVER(PARTITION BY a, b),
+/// a PartitionKey consists of a unique [a, b] pair.
+pub type PartitionKey = Vec<ScalarValue>;
+
+#[derive(Debug)]
+pub struct WindowState {
+    pub state: WindowAggState,
+    pub window_fn: WindowFn,
+}
+pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
+
+/// The IndexMap (i.e. an ordered HashMap) where record batches are separated for each partition.
+pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
+
+impl WindowAggState {
+    pub fn new(
+        out_type: &DataType,
+        window_function_state: WindowFunctionState,
+    ) -> Result<Self> {
+        let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0);
+        Ok(Self {
+            window_frame_range: Range { start: 0, end: 0 },
+            last_calculated_index: 0,
+            offset_pruned_rows: 0,
+            window_function_state,
+            out_col: empty_out_col,
+            n_row_result_missing: 0,
+            is_end: false,
+        })
+    }
+}
diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs
index 4002a49cf585b..dfd878275181c 100644
--- a/test-utils/src/lib.rs
+++ b/test-utils/src/lib.rs
@@ -50,7 +50,10 @@ pub fn partitions_to_sorted_vec(partitions: &[Vec<RecordBatch>]) -> Vec<RecordB
-pub fn add_empty_batches(batches: Vec<RecordBatch>, rng: &mut StdRng) -> Vec<RecordBatch> {
+pub fn add_empty_batches(
+    batches: Vec<RecordBatch>,
+    rng: &mut StdRng,
+) -> Vec<RecordBatch> {
     let schema = batches[0].schema();
     batches