diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt index 7a1b012b84108..611fd75eff9ee 100644 --- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt +++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt @@ -1091,3 +1091,15 @@ query U SELECT ARRAY_AGG([1]); ---- [[1]] + +# variance_single_value +query RRRR +select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq; +---- +NULL 0 NULL 0 + +# variance_two_values +query RRRR +select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0), (3.0)) as sq; +---- +2 1 1.4142135623730951 1 diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index 05dc56cff2c6f..4c9e46644a746 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -307,8 +307,8 @@ mod tests { "bla".to_string(), DataType::Float64, )); - let actual = aggregate(&batch, agg); - assert!(actual.is_err()); + let actual = aggregate(&batch, agg).unwrap(); + assert_eq!(actual, ScalarValue::Float64(None)); Ok(()) } @@ -341,9 +341,8 @@ mod tests { "bla".to_string(), DataType::Float64, )); - let actual = aggregate(&batch, agg); - assert!(actual.is_err()); - + let actual = aggregate(&batch, agg).unwrap(); + assert_eq!(actual, ScalarValue::Float64(None)); Ok(()) } diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index d1ccea7e1d7a5..2895137447114 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -286,17 +286,17 @@ impl Accumulator for VarianceAccumulator { } }; - if count <= 1 { - return Err(DataFusionError::Internal( - "At least two values are needed to calculate variance".to_string(), - )); - } - - if self.count == 0 { - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64(Some(self.m2 / count as f64))) - } + Ok(ScalarValue::Float64(match self.count { + 0 => None, + 1 => { + if let StatsType::Population = self.stats_type { + Some(0.0) + } else { + None + } + } + _ => Some(self.m2 / count as f64), + })) } fn size(&self) -> usize { @@ -382,8 +382,8 @@ mod tests { "bla".to_string(), DataType::Float64, )); - let actual = aggregate(&batch, agg); - assert!(actual.is_err()); + let actual = aggregate(&batch, agg).unwrap(); + assert_eq!(actual, ScalarValue::Float64(None)); Ok(()) } @@ -416,8 +416,8 @@ mod tests { "bla".to_string(), DataType::Float64, )); - let actual = aggregate(&batch, agg); - assert!(actual.is_err()); + let actual = aggregate(&batch, agg).unwrap(); + assert_eq!(actual, ScalarValue::Float64(None)); Ok(()) }