diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs index d6e28f18d3558..2c85b90491521 100644 --- a/datafusion/src/physical_plan/expressions/stddev.rs +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -25,8 +25,7 @@ use crate::physical_plan::{ expressions::variance::VarianceAccumulator, Accumulator, AggregateExpr, PhysicalExpr, }; use crate::scalar::ScalarValue; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; use super::{format_state_name, StatsType}; @@ -216,8 +215,8 @@ impl Accumulator for StddevAccumulator { fn state(&self) -> Result<Vec<ScalarValue>> { Ok(vec![ ScalarValue::from(self.variance.get_count()), - self.variance.get_mean(), - self.variance.get_m2(), + ScalarValue::from(self.variance.get_mean()), + ScalarValue::from(self.variance.get_m2()), ]) } @@ -229,6 +228,14 @@ impl Accumulator for StddevAccumulator { self.variance.merge(states) } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.variance.update_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.variance.merge_batch(states) + } + fn evaluate(&self) -> Result<ScalarValue> { let variance = self.variance.evaluate()?; match variance { diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 3f592b00fd4ef..75164405e537f 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -23,8 +23,13 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; use crate::scalar::ScalarValue; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; +use arrow::array::Float64Array; +use arrow::{ + array::{ArrayRef, UInt64Array}, + compute::cast, + datatypes::DataType, + datatypes::Field, +}; use 
super::{format_state_name, StatsType}; @@ -209,8 +214,8 @@ impl AggregateExpr for VariancePop { #[derive(Debug)] pub struct VarianceAccumulator { - m2: ScalarValue, - mean: ScalarValue, + m2: f64, + mean: f64, count: u64, stats_type: StatsType, } @@ -219,9 +224,9 @@ impl VarianceAccumulator { /// Creates a new `VarianceAccumulator` pub fn try_new(s_type: StatsType) -> Result<Self> { Ok(Self { - m2: ScalarValue::from(0 as f64), - mean: ScalarValue::from(0 as f64), - count: 0, + m2: 0_f64, + mean: 0_f64, + count: 0_u64, stats_type: s_type, }) } @@ -230,12 +235,12 @@ impl VarianceAccumulator { self.count } - pub fn get_mean(&self) -> ScalarValue { - self.mean.clone() + pub fn get_mean(&self) -> f64 { + self.mean } - pub fn get_m2(&self) -> ScalarValue { - self.m2.clone() + pub fn get_m2(&self) -> f64 { + self.m2 } } @@ -243,80 +248,174 @@ impl Accumulator for VarianceAccumulator { fn state(&self) -> Result<Vec<ScalarValue>> { Ok(vec![ ScalarValue::from(self.count), - self.mean.clone(), - self.m2.clone(), + ScalarValue::from(self.mean), + ScalarValue::from(self.m2), ]) } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &cast(&values[0], &DataType::Float64)?; + let arr = values.as_any().downcast_ref::<Float64Array>().unwrap(); + + for i in 0..arr.len() { + let value = arr.value(i); + + if value == 0_f64 && values.is_null(i) { + continue; + } + let new_count = self.count + 1; + let delta1 = value - self.mean; + let new_mean = delta1 / new_count as f64 + self.mean; + let delta2 = value - new_mean; + let new_m2 = self.m2 + delta1 * delta2; + + self.count += 1; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = states[0].as_any().downcast_ref::<UInt64Array>().unwrap(); + let means = states[1].as_any().downcast_ref::<Float64Array>().unwrap(); + let m2s = states[2].as_any().downcast_ref::<Float64Array>().unwrap(); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0_u64 { + continue; + } + let new_count = 
self.count + c; + let new_mean = self.mean * self.count as f64 / new_count as f64 + + means.value(i) * c as f64 / new_count as f64; + let delta = self.mean - means.value(i); + let new_m2 = self.m2 + + m2s.value(i) + + delta * delta * self.count as f64 * c as f64 / new_count as f64; + + self.count = new_count; + self.mean = new_mean; + self.m2 = new_m2; + } + Ok(()) + } + fn update(&mut self, values: &[ScalarValue]) -> Result<()> { let values = &values[0]; let is_empty = values.is_null(); + let mean = ScalarValue::from(self.mean); + let m2 = ScalarValue::from(self.m2); if !is_empty { let new_count = self.count + 1; - let delta1 = ScalarValue::add(values, &self.mean.arithmetic_negate())?; + let delta1 = ScalarValue::add(values, &mean.arithmetic_negate())?; let new_mean = ScalarValue::add( &ScalarValue::div(&delta1, &ScalarValue::from(new_count as f64))?, - &self.mean, + &mean, )?; let delta2 = ScalarValue::add(values, &new_mean.arithmetic_negate())?; let tmp = ScalarValue::mul(&delta1, &delta2)?; - let new_m2 = ScalarValue::add(&self.m2, &tmp)?; + let new_m2 = ScalarValue::add(&m2, &tmp)?; self.count += 1; - self.mean = new_mean; - self.m2 = new_m2; + + if let ScalarValue::Float64(Some(c)) = new_mean { + self.mean = c; + } else { + unreachable!() + }; + if let ScalarValue::Float64(Some(m)) = new_m2 { + self.m2 = m; + } else { + unreachable!() + }; } Ok(()) } fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - let count = &states[0]; - let mean = &states[1]; - let m2 = &states[2]; + let count; + let mean; + let m2; let mut new_count: u64 = self.count; - // counts are summed - if let ScalarValue::UInt64(Some(c)) = count { - if *c == 0_u64 { - return Ok(()); - } + if let ScalarValue::UInt64(Some(c)) = states[0] { + count = c; + } else { + unreachable!() + }; - if self.count == 0 { - self.count = *c; - self.mean = mean.clone(); - self.m2 = m2.clone(); - return Ok(()); - } - new_count += c + if count == 0_u64 { + return Ok(()); + } + + if let 
ScalarValue::Float64(Some(m)) = states[1] { + mean = m; } else { unreachable!() }; + if let ScalarValue::Float64(Some(n)) = states[2] { + m2 = n; + } else { + unreachable!() + }; + + if self.count == 0 { + self.count = count; + self.mean = mean; + self.m2 = m2; + return Ok(()); + } - let new_mean = ScalarValue::div( - &ScalarValue::add(&self.mean, mean)?, - &ScalarValue::from(2_f64), + new_count += count; + + let mean1 = ScalarValue::from(self.mean); + let mean2 = ScalarValue::from(mean); + + let new_mean = ScalarValue::add( + &ScalarValue::div( + &ScalarValue::mul(&mean1, &ScalarValue::from(self.count))?, + &ScalarValue::from(new_count as f64), + )?, + &ScalarValue::div( + &ScalarValue::mul(&mean2, &ScalarValue::from(count))?, + &ScalarValue::from(new_count as f64), + )?, )?; - let delta = ScalarValue::add(&mean.arithmetic_negate(), &self.mean)?; + + let delta = ScalarValue::add(&mean2.arithmetic_negate(), &mean1)?; let delta_sqrt = ScalarValue::mul(&delta, &delta)?; let new_m2 = ScalarValue::add( &ScalarValue::add( &ScalarValue::mul( &delta_sqrt, &ScalarValue::div( - &ScalarValue::mul(&ScalarValue::from(self.count), count)?, + &ScalarValue::mul( + &ScalarValue::from(self.count), + &ScalarValue::from(count), + )?, &ScalarValue::from(new_count as f64), )?, )?, - &self.m2, + &ScalarValue::from(self.m2), )?, - m2, + &ScalarValue::from(m2), )?; self.count = new_count; - self.mean = new_mean; - self.m2 = new_m2; + if let ScalarValue::Float64(Some(c)) = new_mean { + self.mean = c; + } else { + unreachable!() + }; + if let ScalarValue::Float64(Some(m)) = new_m2 { + self.m2 = m; + } else { + unreachable!() + }; Ok(()) } @@ -339,17 +438,10 @@ impl Accumulator for VarianceAccumulator { )); } - match self.m2 { - ScalarValue::Float64(e) => { - if self.count == 0 { - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64(e.map(|f| f / count as f64))) - } - } - _ => Err(DataFusionError::Internal( - "M2 should be f64 for variance".to_string(), - )), + if self.count 
== 0 { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(Some(self.m2 / count as f64))) } } }