From 1ab77083c37a370a33ccc04b78274f7c46d3fdf4 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Thu, 15 Dec 2022 19:29:11 +0800 Subject: [PATCH 1/8] Avoid generate duplicate sort Keys from Window Expressions, fix bug when decide Window Expressions ordering --- datafusion/core/src/physical_plan/planner.rs | 4 +- datafusion/core/tests/sql/window.rs | 115 ++++++++++- datafusion/expr/src/logical_plan/builder.rs | 32 ++- datafusion/expr/src/utils.rs | 198 ++++++++++++++++++- datafusion/sql/src/planner.rs | 1 - 5 files changed, 328 insertions(+), 22 deletions(-) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 5b39d71f073de..bad4193880ec8 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -563,12 +563,12 @@ impl DefaultPhysicalPlanner { } _ => unreachable!(), }; - let sort_keys = get_sort_keys(&window_expr[0]); + let sort_keys = get_sort_keys(&window_expr[0])?; if window_expr.len() > 1 { debug_assert!( window_expr[1..] .iter() - .all(|expr| get_sort_keys(expr) == sort_keys), + .all(|expr| get_sort_keys(expr).unwrap() == sort_keys), "all window expressions shall have the same sort keys, as guaranteed by logical planning" ); } diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index b550e7f5dd60b..f0cf63fa692af 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1642,7 +1642,120 @@ async fn test_window_agg_sort() -> Result<()> { assert_eq!( expected, actual_trim_last, "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected, actual + expected, actual_trim_last + ); + Ok(()) +} + +#[tokio::test] +async fn over_order_by_sort_keys_sorting_prefix_compacting() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + + let sql = "SELECT c2, MAX(c9) OVER (ORDER BY c2), SUM(c9) OVER (), MIN(c9) OVER (ORDER BY c2, c9) from aggregate_test_100"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c2@3 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as MAX(aggregate_test_100.c9), SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@0 as SUM(aggregate_test_100.c9), MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as MIN(aggregate_test_100.c9)]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }]", + " WindowAggExec: wdw=[MAX(aggregate_test_100.c9): Ok(Field { name: \"MAX(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]", + " WindowAggExec: wdw=[MIN(aggregate_test_100.c9): Ok(Field { name: \"MIN(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]", + " SortExec: [c2@0 ASC NULLS LAST,c9@1 ASC NULLS LAST]" + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual_trim_last + ); + Ok(()) +} + +/// FIXME: for now we are not detecting prefix of sorting keys in order to re-arrange with global and save one SortExec +#[tokio::test] +async fn over_order_by_sort_keys_sorting_global_order_compacting() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + + let sql = "SELECT c2, MAX(c9) OVER (ORDER BY c9, c2), SUM(c9) OVER (), MIN(c9) OVER (ORDER BY c2, c9) from aggregate_test_100 ORDER BY c2"; + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "SortExec: [c2@0 ASC NULLS LAST]", + " CoalescePartitionsExec", + " ProjectionExec: expr=[c2@3 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as MAX(aggregate_test_100.c9), SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@0 as SUM(aggregate_test_100.c9), MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as MIN(aggregate_test_100.c9)]", + " RepartitionExec: partitioning=RoundRobinBatch(10)", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }]", + " WindowAggExec: wdw=[MAX(aggregate_test_100.c9): Ok(Field { name: \"MAX(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]", + " SortExec: [c9@2 ASC NULLS LAST,c2@1 ASC NULLS LAST]", + " WindowAggExec: wdw=[MIN(aggregate_test_100.c9): Ok(Field { name: \"MIN(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]", + " SortExec: [c2@0 ASC NULLS LAST,c9@1 ASC NULLS LAST]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual_trim_last + ); + Ok(()) +} + +#[tokio::test] +async fn test_window_partition_by_order_by() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + + let sql = "SELECT \ + SUM(c4) OVER(PARTITION BY c1, c2 ORDER BY c2 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + COUNT(*) OVER(PARTITION BY c1 ORDER BY c2 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) \ + FROM aggregate_test_100"; + + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(sql).expect(&msg); + let state = ctx.state(); + let logical_plan = state.optimize(&plan)?; + let physical_plan = state.create_physical_plan(&logical_plan).await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as COUNT(UInt8(1))]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", + " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 10)", + " RepartitionExec: partitioning=RoundRobinBatch(10)", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual_trim_last ); Ok(()) } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 1c3b813f70f71..45901e55813e0 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -21,7 +21,7 @@ use crate::expr_rewriter::{ coerce_plan_expr_for_schema, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs, }; use crate::type_coercion::binary::comparison_coercion; -use crate::utils::{columnize_expr, exprlist_to_fields, from_plan}; +use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan}; use crate::{and, binary_expr, Operator}; use crate::{ logical_plan::{ @@ -42,6 +42,7 @@ use datafusion_common::{ ToDFSchema, }; use std::any::Any; +use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; use std::sync::Arc; @@ -249,16 +250,29 @@ impl LogicalPlanBuilder { ) -> Result { let mut plan = input; let mut groups = group_window_expr_by_sort_keys(&window_exprs)?; - // sort by sort_key len descending, so that more deeply sorted plans gets nested further - // down as children; to further mimic the behavior of PostgreSQL, we want stable sort - // and a reverse so that tieing sort keys are reversed in order; note that by this rule - // if there's an empty over, it'll be at the top level - groups.sort_by(|(key_a, _), (key_b, _)| key_a.len().cmp(&key_b.len())); - groups.reverse(); + // To align with the behavior of PostgreSQL, we want the sort_keys sorted as same rule as PostgreSQL that first + // we compare the sort key themselves and if one window's sort keys are a prefix of another + // put the window with more sort keys first. so more deeply sorted plans gets nested further down as children. + // The sort_by() implementation here is a stable sort. + // Note that by this rule if there's an empty over, it'll be at the top level + groups.sort_by(|(key_a, _), (key_b, _)| { + for (first, second) in key_a.iter().zip(key_b.iter()) { + let key_ordering = compare_sort_expr(first, second, plan.schema()); + match key_ordering { + Ordering::Less => { + return Ordering::Less; + } + Ordering::Greater => { + return Ordering::Greater; + } + Ordering::Equal => {} + } + } + key_b.len().cmp(&key_a.len()) + }); for (_, exprs) in groups { let window_exprs = exprs.into_iter().cloned().collect::>(); - // the partition and sort itself is done at physical level, see physical_planner's - // fn create_initial_plan + // the partition and sort itself is done at physical level, see the BasicEnforcement rule plan = LogicalPlanBuilder::from(plan) .window(window_exprs)? .build()?; diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index e920915c088be..fb9b6aed620df 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -30,6 +30,7 @@ use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; +use std::cmp::Ordering; use std::collections::HashSet; use std::sync::Arc; @@ -202,20 +203,113 @@ pub fn expand_qualified_wildcard( type WindowSortKey = Vec; /// Generate a sort key for a given window expr's partition_by and order_bu expr -pub fn generate_sort_key(partition_by: &[Expr], order_by: &[Expr]) -> WindowSortKey { - let mut sort_key = vec![]; +pub fn generate_sort_key( + partition_by: &[Expr], + order_by: &[Expr], +) -> Result { + let normalized_order_by_keys = order_by + .iter() + .map(|e| match e { + Expr::Sort { + expr, + asc: _, + nulls_first: _, + } => Ok(Expr::Sort { + expr: expr.clone(), + asc: true, + nulls_first: false, + }), + _ => Err(DataFusionError::Plan( + "Order by only accepts sort expressions".to_string(), + )), + }) + .collect::>>()?; + + let mut final_sort_keys = vec![]; partition_by.iter().for_each(|e| { - let e = e.clone().sort(true, true); - if !sort_key.contains(&e) { - sort_key.push(e); + // By default, create sort key with ASC is true and NULLS LAST to be consistent with + // postgres rule: https://www.postgresql.org/docs/current/queries-order.html + let e = e.clone().sort(true, false); + if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) { + let order_by_key = &order_by[pos]; + if !final_sort_keys.contains(order_by_key) { + final_sort_keys.push(order_by_key.clone()); + } + } else if !final_sort_keys.contains(&e) { + final_sort_keys.push(e); } }); + order_by.iter().for_each(|e| { - if !sort_key.contains(e) { - sort_key.push(e.clone()); + if !final_sort_keys.contains(e) { + final_sort_keys.push(e.clone()); } }); - sort_key + Ok(final_sort_keys) +} + +/// Compare the sort expr as PostgreSQL's common_prefix_cmp(): +/// https://github.com/postgres/postgres/blob/master/src/backend/optimizer/plan/planner.c +pub fn compare_sort_expr( + sort_expr_a: &Expr, + sort_expr_b: &Expr, + schema: &DFSchemaRef, +) -> Ordering { + match (sort_expr_a, sort_expr_b) { + ( + Expr::Sort { + expr: expr_a, + asc: asc_a, + nulls_first: nulls_first_a, + }, + Expr::Sort { + expr: expr_b, + asc: asc_b, + nulls_first: nulls_first_b, + }, + ) => { + let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema); + let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema); + for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) { + match idx_a.cmp(idx_b) { + Ordering::Less => { + return Ordering::Less; + } + Ordering::Greater => { + return Ordering::Greater; + } + Ordering::Equal => {} + } + } + match ref_indexes_a.len().cmp(&ref_indexes_b.len()) { + Ordering::Less => return Ordering::Greater, + Ordering::Greater => { + return Ordering::Less; + } + Ordering::Equal => {} + } + match (asc_a, asc_b) { + (true, false) => { + return Ordering::Greater; + } + (false, true) => { + return Ordering::Less; + } + _ => {} + } + match (nulls_first_a, nulls_first_b) { + (true, false) => { + return Ordering::Less; + } + (false, true) => { + return Ordering::Greater; + } + _ => {} + } + Ordering::Equal + } + _ => Ordering::Equal, + } } /// group a slice of window expression expr by their order by expressions @@ -225,7 +319,7 @@ pub fn group_window_expr_by_sort_keys( let mut result = vec![]; window_expr.iter().try_for_each(|expr| match expr { Expr::WindowFunction { partition_by, order_by, .. } => { - let sort_key = generate_sort_key(partition_by, order_by); + let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( |group: &&mut (WindowSortKey, Vec<&Expr>)| matches!(group, (key, _) if *key == sort_key), ) { @@ -755,6 +849,48 @@ pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result { } } +/// Recursively walk an expression tree, collecting the column indexes +/// referenced in the expression +struct ColumnIndexesCollector<'a> { + schema: &'a DFSchemaRef, + indexes: Vec, +} + +impl ExpressionVisitor for ColumnIndexesCollector<'_> { + fn pre_visit(mut self, expr: &Expr) -> Result> + where + Self: ExpressionVisitor, + { + match expr { + Expr::Column(qc) => { + if let Ok(idx) = self.schema.index_of_column(qc) { + self.indexes.push(idx); + } + } + Expr::Literal(_) => { + self.indexes.push(std::usize::MAX); + } + _ => {} + } + Ok(Recursion::Continue(self)) + } +} + +pub(crate) fn find_column_indexes_referenced_by_expr( + e: &Expr, + schema: &DFSchemaRef, +) -> Vec { + // As the `ExpressionVisitor` impl above always returns Ok, this + // "can't" error + let ColumnIndexesCollector { indexes, .. } = e + .accept(ColumnIndexesCollector { + schema, + indexes: vec![], + }) + .expect("Unexpected error"); + indexes +} + /// can this data type be used in hash join equal conditions?? /// data types here come from function 'equal_rows', if more data types are supported /// in equal_rows(hash join), add those data types here to generate join logical plan. @@ -984,4 +1120,48 @@ mod tests { assert_eq!(expected, result); Ok(()) } + + #[test] + fn avoid_generate_duplicate_sort_keys() -> Result<()> { + let asc_or_desc = [true, false]; + let nulls_first_or_last = [true, false]; + let partition_by = &[col("age"), col("name"), col("created_at")]; + for asc_ in asc_or_desc { + for nulls_first_ in nulls_first_or_last { + let order_by = &[ + Expr::Sort { + expr: Box::new(col("age")), + asc: asc_, + nulls_first: nulls_first_, + }, + Expr::Sort { + expr: Box::new(col("name")), + asc: asc_, + nulls_first: nulls_first_, + }, + ]; + + let expected = vec![ + Expr::Sort { + expr: Box::new(col("age")), + asc: asc_, + nulls_first: nulls_first_, + }, + Expr::Sort { + expr: Box::new(col("name")), + asc: asc_, + nulls_first: nulls_first_, + }, + Expr::Sort { + expr: Box::new(col("created_at")), + asc: true, + nulls_first: false, + }, + ]; + let result = generate_sort_key(partition_by, order_by)?; + assert_eq!(expected, result); + } + } + Ok(()) + } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 6cf279eda4178..ae40b7b6c2a6c 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -5167,7 +5167,6 @@ mod tests { /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) /// ``` /// - /// FIXME: for now we are not detecting prefix of sorting keys in order to save one sort exec phase #[test] fn over_order_by_sort_keys_sorting_prefix_compacting() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; From f5b6dfb1fa50591e95ebc624f964062688aa27d1 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Thu, 15 Dec 2022 19:35:23 +0800 Subject: [PATCH 2/8] fix test comment --- datafusion/core/tests/sql/window.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index f0cf63fa692af..b544f751d6556 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1695,7 +1695,7 @@ async fn over_order_by_sort_keys_sorting_global_order_compacting() -> Result<()> let logical_plan = state.optimize(&plan)?; let physical_plan = state.create_physical_plan(&logical_plan).await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - // Only 1 SortExec was added + // 3 SortExec are added let expected = { vec![ "SortExec: [c2@0 ASC NULLS LAST]", From 43687ec3cda52b2acb4450e4de6ab0a5c6884b18 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Thu, 15 Dec 2022 20:05:48 +0800 Subject: [PATCH 3/8] fix UT --- datafusion/core/tests/sql/window.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index b544f751d6556..72af7ae8a7c51 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1649,7 +1649,7 @@ async fn test_window_agg_sort() -> Result<()> { #[tokio::test] async fn over_order_by_sort_keys_sorting_prefix_compacting() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(2)); register_aggregate_csv(&ctx).await?; let sql = "SELECT c2, MAX(c9) OVER (ORDER BY c2), SUM(c9) OVER (), MIN(c9) OVER (ORDER BY c2, c9) from aggregate_test_100"; @@ -1685,7 +1685,7 @@ async fn over_order_by_sort_keys_sorting_prefix_compacting() -> Result<()> { /// FIXME: for now we are not detecting prefix of sorting keys in order to re-arrange with global and save one SortExec #[tokio::test] async fn over_order_by_sort_keys_sorting_global_order_compacting() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(2)); register_aggregate_csv(&ctx).await?; let sql = "SELECT c2, MAX(c9) OVER (ORDER BY c9, c2), SUM(c9) OVER (), MIN(c9) OVER (ORDER BY c2, c9) from aggregate_test_100 ORDER BY c2"; @@ -1701,7 +1701,7 @@ async fn over_order_by_sort_keys_sorting_global_order_compacting() -> Result<()> "SortExec: [c2@0 ASC NULLS LAST]", " CoalescePartitionsExec", " ProjectionExec: expr=[c2@3 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as MAX(aggregate_test_100.c9), SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@0 as SUM(aggregate_test_100.c9), MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as MIN(aggregate_test_100.c9)]", - " RepartitionExec: partitioning=RoundRobinBatch(10)", + " RepartitionExec: partitioning=RoundRobinBatch(2)", " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }]", " WindowAggExec: wdw=[MAX(aggregate_test_100.c9): Ok(Field { name: \"MAX(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]", " SortExec: [c9@2 ASC NULLS LAST,c2@1 ASC NULLS LAST]", @@ -1723,7 +1723,7 @@ async fn over_order_by_sort_keys_sorting_global_order_compacting() -> Result<()> #[tokio::test] async fn test_window_partition_by_order_by() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(2)); register_aggregate_csv(&ctx).await?; let sql = "SELECT \ @@ -1744,8 +1744,8 @@ async fn test_window_partition_by_order_by() -> Result<()> { " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 10)", - " RepartitionExec: partitioning=RoundRobinBatch(10)", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 2)", + " RepartitionExec: partitioning=RoundRobinBatch(2)", ] }; From c2d49c10d702557bbc765e72b4fa6ae4ff7cc942 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Thu, 15 Dec 2022 20:38:58 +0800 Subject: [PATCH 4/8] fix UT --- datafusion/core/tests/sql/select.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 35d8dc6d6e089..fe4e5271f93d4 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -650,7 +650,7 @@ async fn query_on_string_dictionary() -> Result<()> { ]) .unwrap(); - let ctx = SessionContext::new(); + let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(4)); ctx.register_batch("test", batch)?; // Basic SELECT From 952a30ba93334b9acc9b0513d6f4e198a8291e01 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Fri, 16 Dec 2022 02:44:16 +0800 Subject: [PATCH 5/8] Fix create Sort Columns from Partition Columns in WindowExpr, add more UTs for Null String sort testing --- datafusion/core/tests/sql/select.rs | 84 ++++++++++++++++++- .../physical-expr/src/window/window_expr.rs | 17 +++- 2 files changed, 97 insertions(+), 4 deletions(-) diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index fe4e5271f93d4..e8636f32f3db9 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -650,7 +650,7 @@ async fn query_on_string_dictionary() -> Result<()> { ]) .unwrap(); - let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(4)); + let ctx = SessionContext::new(); ctx.register_batch("test", batch)?; // Basic SELECT @@ -835,6 +835,88 @@ async fn query_on_string_dictionary() -> Result<()> { Ok(()) } +#[tokio::test] +async fn sort_on_window_null_string() -> Result<()> { + let d1: DictionaryArray = + vec![Some("one"), None, Some("three")].into_iter().collect(); + let d2: StringArray = vec![Some("ONE"), None, Some("THREE")].into_iter().collect(); + let d3: LargeStringArray = + vec![Some("One"), None, Some("Three")].into_iter().collect(); + + let batch = RecordBatch::try_from_iter(vec![ + ("d1", Arc::new(d1) as ArrayRef), + ("d2", Arc::new(d2) as ArrayRef), + ("d3", Arc::new(d3) as ArrayRef), + ]) + .unwrap(); + + let ctx = SessionContext::with_config(SessionConfig::new().with_target_partitions(2)); + ctx.register_batch("test", batch)?; + + let sql = + "SELECT d1, row_number() OVER (partition by d1) as rn1 FROM test order by d1 asc"; + + let actual = execute_to_batches(&ctx, sql).await; + // NULLS LAST + let expected = vec![ + "+-------+-----+", + "| d1 | rn1 |", + "+-------+-----+", + "| one | 1 |", + "| three | 1 |", + "| | 1 |", + "+-------+-----+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = "SELECT d2, row_number() OVER (partition by d2) as rn1 FROM test"; + let actual = execute_to_batches(&ctx, sql).await; + // NULLS LAST + let expected = vec![ + "+-------+-----+", + "| d2 | rn1 |", + "+-------+-----+", + "| ONE | 1 |", + "| THREE | 1 |", + "| | 1 |", + "+-------+-----+", + ]; + assert_batches_eq!(expected, &actual); + + let sql = + "SELECT d2, row_number() OVER (partition by d2 order by d2 desc) as rn1 FROM test"; + + let actual = execute_to_batches(&ctx, sql).await; + // NULLS FIRST + let expected = vec![ + "+-------+-----+", + "| d2 | rn1 |", + "+-------+-----+", + "| | 1 |", + "| THREE | 1 |", + "| ONE | 1 |", + "+-------+-----+", + ]; + assert_batches_eq!(expected, &actual); + + // FIXME sort on LargeUtf8 String has bug. + // let sql = + // "SELECT d3, row_number() OVER (partition by d3) as rn1 FROM test"; + // let actual = execute_to_batches(&ctx, sql).await; + // let expected = vec![ + // "+-------+-----+", + // "| d3 | rn1 |", + // "+-------+-----+", + // "| | 1 |", + // "| One | 1 |", + // "| Three | 1 |", + // "+-------+-----+", + // ]; + // assert_batches_eq!(expected, &actual); + + Ok(()) +} + #[tokio::test] async fn filter_with_time32second() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 209e0544f2fa6..2fbc6e2c4c8e4 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -91,9 +91,20 @@ pub trait WindowExpr: Send + Sync + Debug { self.partition_by() .iter() .map(|expr| { - PhysicalSortExpr { - expr: expr.clone(), - options: SortOptions::default(), + if let Some(idx) = + self.order_by().iter().position(|key| key.expr.eq(expr)) + { + self.order_by()[idx].clone() + } else { + // When ASC is true, by default NULLS LAST to be consistent with PostgreSQL's rule: + // https://www.postgresql.org/docs/current/queries-order.html + PhysicalSortExpr { + expr: expr.clone(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + } } .evaluate_to_sort_column(batch) }) From 55653418aa77ef1a373428a1e3d2c6546581d231 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Fri, 16 Dec 2022 11:05:37 +0800 Subject: [PATCH 6/8] fix clippy check --- datafusion/expr/src/field_util.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs index 629f3952d86e4..94efeda806483 100644 --- a/datafusion/expr/src/field_util.rs +++ b/datafusion/expr/src/field_util.rs @@ -28,7 +28,7 @@ use datafusion_common::{DataFusionError, Result, ScalarValue}; pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { match (data_type, key) { (DataType::List(lt), ScalarValue::Int64(Some(i))) => { - Ok(Field::new(&i.to_string(), lt.data_type().clone(), true)) + Ok(Field::new(i.to_string(), lt.data_type().clone(), true)) } (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { if s.is_empty() { From 2b0ac0fee625e4b36079c2f9f0dd94cd13d5109f Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Sun, 18 Dec 2022 21:35:05 +0800 Subject: [PATCH 7/8] tiny change --- datafusion/expr/src/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index fb9b6aed620df..ce06031d13b5e 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -228,7 +228,7 @@ pub fn generate_sort_key( let mut final_sort_keys = vec![]; partition_by.iter().for_each(|e| { // By default, create sort key with ASC is true and NULLS LAST to be consistent with - // postgres rule: https://www.postgresql.org/docs/current/queries-order.html + // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html let e = e.clone().sort(true, false); if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) { let order_by_key = &order_by[pos]; From a7f96676c8adec1614ea72e3125de8638af056f2 Mon Sep 17 00:00:00 2001 From: "mingmwang@ebay.com" Date: Mon, 19 Dec 2022 10:34:20 +0800 Subject: [PATCH 8/8] merge with upstream, fix issue --- datafusion/expr/src/utils.rs | 47 ++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 276204f3f37fb..2577c3a1970ca 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -17,13 +17,14 @@ //! Expression utilities +use crate::expr::Sort; use crate::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use crate::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; use crate::logical_plan::builder::build_join_schema; use crate::logical_plan::{ Aggregate, Analyze, CreateMemoryTable, CreateView, Distinct, Extension, Filter, Join, - Limit, Partitioning, Prepare, Projection, Repartition, Sort, Subquery, SubqueryAlias, - Union, Values, Window, + Limit, Partitioning, Prepare, Projection, Repartition, Sort as SortPlan, Subquery, + SubqueryAlias, Union, Values, Window, }; use crate::{ Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, TableScan, TryCast, @@ -212,15 +213,9 @@ pub fn generate_sort_key( let normalized_order_by_keys = order_by .iter() .map(|e| match e { - Expr::Sort { - expr, - asc: _, - nulls_first: _, - } => Ok(Expr::Sort { - expr: expr.clone(), - asc: true, - nulls_first: false, - }), + Expr::Sort(Sort { expr, .. }) => { + Ok(Expr::Sort(Sort::new(expr.clone(), true, false))) + } _ => Err(DataFusionError::Plan( "Order by only accepts sort expressions".to_string(), )), @@ -259,16 +254,16 @@ pub fn compare_sort_expr( ) -> Ordering { match (sort_expr_a, sort_expr_b) { ( - Expr::Sort { + Expr::Sort(Sort { expr: expr_a, asc: asc_a, nulls_first: nulls_first_a, - }, - Expr::Sort { + }), + Expr::Sort(Sort { expr: expr_b, asc: asc_b, nulls_first: nulls_first_b, - }, + }), ) => { let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema); let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema); @@ -558,7 +553,7 @@ pub fn from_plan( expr[group_expr.len()..].to_vec(), schema.clone(), )?)), - LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort { + LogicalPlan::Sort(SortPlan { fetch, .. }) => Ok(LogicalPlan::Sort(SortPlan { expr: expr.to_vec(), input: Arc::new(inputs[0].clone()), fetch: *fetch, @@ -1088,34 +1083,34 @@ mod tests { for asc_ in asc_or_desc { for nulls_first_ in nulls_first_or_last { let order_by = &[ - Expr::Sort { + Expr::Sort(Sort { expr: Box::new(col("age")), asc: asc_, nulls_first: nulls_first_, - }, - Expr::Sort { + }), + Expr::Sort(Sort { expr: Box::new(col("name")), asc: asc_, nulls_first: nulls_first_, - }, + }), ]; let expected = vec![ - Expr::Sort { + Expr::Sort(Sort { expr: Box::new(col("age")), asc: asc_, nulls_first: nulls_first_, - }, - Expr::Sort { + }), + Expr::Sort(Sort { expr: Box::new(col("name")), asc: asc_, nulls_first: nulls_first_, - }, - Expr::Sort { + }), + Expr::Sort(Sort { expr: Box::new(col("created_at")), asc: true, nulls_first: false, - }, + }), ]; let result = generate_sort_key(partition_by, order_by)?; assert_eq!(expected, result);