From eb093a87d362ce6b45f4657dbcf2398420506a20 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 7 Jan 2023 13:39:32 -0500 Subject: [PATCH 1/2] Implement retract_batch for AvgAccumulator, Add avg to custom window frame tests --- datafusion/core/tests/sql/window.rs | 38 ++++++++++--------- .../physical-expr/src/aggregate/average.rs | 13 +++++++ 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 5ca49cff2883a..0c3ecfa59ba9d 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -503,21 +503,22 @@ async fn window_frame_rows_preceding() -> Result<()> { register_aggregate_csv(&ctx).await?; let sql = "SELECT \ SUM(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + AVG(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ COUNT(*) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\ FROM aggregate_test_100 \ ORDER BY c9 \ LIMIT 5"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+----------------------------+-----------------+", - "| SUM(aggregate_test_100.c4) | COUNT(UInt8(1)) |", - "+----------------------------+-----------------+", - "| -48302 | 3 |", - "| 11243 | 3 |", - "| -51311 | 3 |", - "| -2391 | 3 |", - "| 46756 | 3 |", - "+----------------------------+-----------------+", + "+----------------------------+----------------------------+-----------------+", + "| SUM(aggregate_test_100.c4) | AVG(aggregate_test_100.c4) | COUNT(UInt8(1)) |", + "+----------------------------+----------------------------+-----------------+", + "| -48302 | -16100.666666666666 | 3 |", + "| 11243 | 3747.6666666666665 | 3 |", + "| -51311 | -17103.666666666668 | 3 |", + "| -2391 | -797 | 3 |", + "| 46756 | 15585.333333333334 | 3 |", + "+----------------------------+----------------------------+-----------------+", ]; assert_batches_eq!(expected, &actual); Ok(()) @@ -529,21 +530,22 @@ async fn 
window_frame_rows_preceding_with_partition_unique_order_by() -> Result< register_aggregate_csv(&ctx).await?; let sql = "SELECT \ SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + AVG(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ COUNT(*) OVER(PARTITION BY c2 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\ FROM aggregate_test_100 \ ORDER BY c9 \ LIMIT 5"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+----------------------------+-----------------+", - "| SUM(aggregate_test_100.c4) | COUNT(UInt8(1)) |", - "+----------------------------+-----------------+", - "| -38611 | 2 |", - "| 17547 | 2 |", - "| -1301 | 2 |", - "| 26638 | 3 |", - "| 26861 | 3 |", - "+----------------------------+-----------------+", + "+----------------------------+----------------------------+-----------------+", + "| SUM(aggregate_test_100.c4) | AVG(aggregate_test_100.c4) | COUNT(UInt8(1)) |", + "+----------------------------+----------------------------+-----------------+", + "| -38611 | -19305.5 | 2 |", + "| 17547 | 8773.5 | 2 |", + "| -1301 | -650.5 | 2 |", + "| 26638 | 13319 | 3 |", + "| 26861 | 8953.666666666666 | 3 |", + "+----------------------------+----------------------------+-----------------+", ]; assert_batches_eq!(expected, &actual); Ok(()) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 12f84ca1f798e..91e2ab9150235 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -37,6 +37,7 @@ use datafusion_common::{downcast_value, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use datafusion_row::accessor::RowAccessor; +use crate::aggregate::sum::sum_batch; /// AVG aggregate expression #[derive(Debug)] @@ -119,6 +120,10 @@ impl AggregateExpr for Avg { self.data_type.clone(), ))) } + + fn 
create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Ok(Box::new(AvgAccumulator::try_new(&self.data_type)?)) + } } /// An accumulator to compute the average @@ -154,6 +159,14 @@ impl Accumulator for AvgAccumulator { Ok(()) } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + self.count -= (values.len() - values.data().null_count()) as u64; + let delta = sum_batch(values, &self.sum.get_datatype())?; + self.sum = self.sum.sub(&delta)?; + Ok(()) + } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { let counts = downcast_value!(states[0], UInt64Array); // counts are summed From 830a6d0efa0d89072ffdfaca2a5a8fceadc9a78d Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 7 Jan 2023 13:48:33 -0500 Subject: [PATCH 2/2] fmt --- datafusion/physical-expr/src/aggregate/average.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 91e2ab9150235..216bd56af8ef0 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -25,6 +25,7 @@ use crate::aggregate::row_accumulator::{ is_row_accumulator_support_dtype, RowAccumulator, }; use crate::aggregate::sum; +use crate::aggregate::sum::sum_batch; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use arrow::compute; @@ -37,7 +38,6 @@ use datafusion_common::{downcast_value, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use datafusion_row::accessor::RowAccessor; -use crate::aggregate::sum::sum_batch; /// AVG aggregate expression #[derive(Debug)]