From 927fc3c1c07aed6ec05c3fe47c1bdeced33298a0 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 7 Jan 2023 14:42:08 -0500 Subject: [PATCH 1/2] Wire up retract_batch for Stddev/StddevPop/Variance/VariancePop to --- datafusion/physical-expr/src/aggregate/stddev.rs | 12 ++++++++++++ datafusion/physical-expr/src/aggregate/variance.rs | 10 ++++++++++ 2 files changed, 22 insertions(+) diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index 05dc56cff2c6f..dc9a4c796acd8 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -73,6 +73,10 @@ impl AggregateExpr for Stddev { Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) } + fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) + } + fn state_fields(&self) -> Result<Vec<Field>> { Ok(vec![ Field::new( @@ -128,6 +132,10 @@ impl AggregateExpr for StddevPop { Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) } + fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) + } + fn state_fields(&self) -> Result<Vec<Field>> { Ok(vec![ Field::new( @@ -184,6 +192,10 @@ impl Accumulator for StddevAccumulator { self.variance.update_batch(values) } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.variance.retract_batch(values) + } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { self.variance.merge_batch(states) } diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index d1ccea7e1d7a5..ccac9f8ee0875 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -79,6 +79,10 @@ impl AggregateExpr for Variance { Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) } + fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> { + 
Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) + } + fn state_fields(&self) -> Result<Vec<Field>> { Ok(vec![ Field::new( @@ -136,6 +140,12 @@ impl AggregateExpr for VariancePop { )?)) } + fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Ok(Box::new(VarianceAccumulator::try_new( + StatsType::Population, + )?)) + } + fn state_fields(&self) -> Result<Vec<Field>> { Ok(vec![ Field::new( From effc4728912684f63b2e4ce837ae1b1816796480 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 7 Jan 2023 15:18:53 -0500 Subject: [PATCH 2/2] Add test for Stddev/StddevPop/Variance/VariancePop with window frame --- datafusion/core/tests/sql/window.rs | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index 5ca49cff2883a..c1b43f4b2824f 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -523,6 +523,34 @@ async fn window_frame_rows_preceding() -> Result<()> { Ok(()) } +#[tokio::test] +async fn window_frame_rows_preceding_stddev_variance() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT \ + VAR(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + VAR_POP(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + STDDEV(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\ + STDDEV_POP(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\ + FROM aggregate_test_100 \ + ORDER BY c9 \ + LIMIT 5"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+---------------------------------+------------------------------------+-------------------------------+----------------------------------+", + "| VARIANCE(aggregate_test_100.c4) | VARIANCEPOP(aggregate_test_100.c4) | STDDEV(aggregate_test_100.c4) | STDDEVPOP(aggregate_test_100.c4) |", + 
"+---------------------------------+------------------------------------+-------------------------------+----------------------------------+", + "| 46721.33333333174 | 31147.555555554496 | 216.15118166073427 | 176.4867007894773 |", + "| 2639429.333333332 | 1759619.5555555548 | 1624.6320609089714 | 1326.5065229977404 |", + "| 746202.3333333324 | 497468.2222222216 | 863.8300372951455 | 705.3142719541563 |", + "| 768422.9999999981 | 512281.9999999988 | 876.5973990378925 | 715.7387791645767 |", + "| 66526.3333333288 | 44350.88888888587 | 257.9269922542594 | 210.5965073045749 |", + "+---------------------------------+------------------------------------+-------------------------------+----------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn window_frame_rows_preceding_with_partition_unique_order_by() -> Result<()> { let ctx = SessionContext::new();