From 0332f7feb6d0f1da38a7e24b5f0db96fe041ad13 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Mon, 3 Jan 2022 21:35:20 -0800 Subject: [PATCH 01/22] Initial implementation of variance --- .../src/physical_plan/expressions/mod.rs | 2 + .../src/physical_plan/expressions/stddev.rs | 408 ++++++++++++++++ .../src/physical_plan/expressions/variance.rs | 441 ++++++++++++++++++ 3 files changed, 851 insertions(+) create mode 100644 datafusion/src/physical_plan/expressions/stddev.rs create mode 100644 datafusion/src/physical_plan/expressions/variance.rs diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 134c6d89ac4f1..46c168926d205 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -28,6 +28,8 @@ use arrow::record_batch::RecordBatch; mod approx_distinct; mod array_agg; mod average; +mod stddev; +mod variance; #[macro_use] mod binary; mod case; diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs new file mode 100644 index 0000000000000..92da6bd4ca557 --- /dev/null +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -0,0 +1,408 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::convert::TryFrom; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use crate::scalar::{ + ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128, +}; +use arrow::compute; +use arrow::datatypes::DataType; +use arrow::{ + array::{ArrayRef, UInt64Array}, + datatypes::Field, +}; + +use super::{format_state_name, sum}; + +/// STDDEV (standard deviation) aggregate expression +#[derive(Debug)] +pub struct Stddev { + name: String, + expr: Arc, + data_type: DataType, +} + +/// function return type of an standard deviation +pub fn stddev_return_type(arg_type: &DataType) -> Result { + match arg_type { + DataType::Decimal(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). 
+ // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4); + let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4); + Ok(DataType::Decimal(new_precision, new_scale)) + } + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(DataFusionError::Plan(format!( + "STDDEV does not support {:?}", + other + ))), + } +} + +pub(crate) fn is_stddev_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal(_, _) + ) +} + +impl Stddev { + /// Create a new STDDEV aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of stddev just support FLOAT64 and Decimal data type. 
+ assert!(matches!( + data_type, + DataType::Float64 | DataType::Decimal(_, _) + )); + Self { + name: name.into(), + expr, + data_type, + } + } +} + +impl AggregateExpr for Stddev { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(StddevAccumulator::try_new( + // stddev is f64 or decimal + &self.data_type, + )?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "sum"), + self.data_type.clone(), + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +/// An accumulator to compute the average +#[derive(Debug)] +pub struct StddevAccumulator { + // sum is used for null + sum: ScalarValue, + count: u64, +} + +impl StddevAccumulator { + /// Creates a new `StddevAccumulator` + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + sum: ScalarValue::try_from(datatype)?, + count: 0, + }) + } +} + +impl Accumulator for StddevAccumulator { + fn state(&self) -> Result> { + Ok(vec![ScalarValue::from(self.count), self.sum.clone()]) + } + + fn update(&mut self, values: &[ScalarValue]) -> Result<()> { + let values = &values[0]; + + self.count += (!values.is_null()) as u64; + self.sum = sum::sum(&self.sum, values)?; + + Ok(()) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + + self.count += (values.len() - values.data().null_count()) as u64; + self.sum = sum::sum(&self.sum, &sum::sum_batch(values)?)?; + Ok(()) + } + + fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { + let count = &states[0]; + // counts are summed + if let ScalarValue::UInt64(Some(c)) = count { + self.count += c 
+ } else { + unreachable!() + }; + + // sums are summed + self.sum = sum::sum(&self.sum, &states[1])?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = states[0].as_any().downcast_ref::().unwrap(); + // counts are summed + self.count += compute::sum(counts).unwrap_or(0); + + // sums are summed + self.sum = sum::sum(&self.sum, &sum::sum_batch(&states[1])?)?; + Ok(()) + } + + fn evaluate(&self) -> Result { + match self.sum { + ScalarValue::Float64(e) => { + Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) + } + ScalarValue::Decimal128(value, precision, scale) => { + Ok(match value { + None => ScalarValue::Decimal128(None, precision, scale), + // TODO add the checker for overflow the precision + Some(v) => ScalarValue::Decimal128( + Some(v / self.count as i128), + precision, + scale, + ), + }) + } + _ => Err(DataFusionError::Internal( + "Sum should be f64 on average".to_string(), + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::physical_plan::expressions::col; + use crate::{error::Result, generic_test_op}; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + + #[test] + fn test_stddev_return_data_type() -> Result<()> { + let data_type = DataType::Decimal(10, 5); + let result_type = stddev_return_type(&data_type)?; + assert_eq!(DataType::Decimal(14, 9), result_type); + + let data_type = DataType::Decimal(36, 10); + let result_type = stddev_return_type(&data_type)?; + assert_eq!(DataType::Decimal(38, 14), result_type); + Ok(()) + } + + #[test] + fn stddev_decimal() -> Result<()> { + // test agg + let mut decimal_builder = DecimalBuilder::new(6, 10, 0); + for i in 1..7 { + decimal_builder.append_value(i as i128)?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + + generic_test_op!( + array, + DataType::Decimal(10, 0), + Stddev, + ScalarValue::Decimal128(Some(35000), 14, 4), + DataType::Decimal(14, 4) + ) + } + + #[test] + fn 
stddev_decimal_with_nulls() -> Result<()> { + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + if i == 2 { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(i)?; + } + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Stddev, + ScalarValue::Decimal128(Some(32500), 14, 4), + DataType::Decimal(14, 4) + ) + } + + #[test] + fn stddev_decimal_all_nulls() -> Result<()> { + // test agg + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for _i in 1..6 { + decimal_builder.append_null()?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Stddev, + ScalarValue::Decimal128(None, 14, 4), + DataType::Decimal(14, 4) + ) + } + + #[test] + fn stddev_i32() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + generic_test_op!( + a, + DataType::Int32, + Stddev, + ScalarValue::from(3_f64), + DataType::Float64 + ) + } + + #[test] + fn stddev_i32_with_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + Some(5), + ])); + generic_test_op!( + a, + DataType::Int32, + Stddev, + ScalarValue::from(3.25f64), + DataType::Float64 + ) + } + + #[test] + fn stddev_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); + generic_test_op!( + a, + DataType::Int32, + Stddev, + ScalarValue::Float64(None), + DataType::Float64 + ) + } + + #[test] + fn stddev_u32() -> Result<()> { + let a: ArrayRef = + Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); + generic_test_op!( + a, + DataType::UInt32, + Stddev, + ScalarValue::from(3.0f64), + DataType::Float64 + ) + } + + #[test] + fn stddev_f32() -> Result<()> { + let a: ArrayRef = + Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); + generic_test_op!( + a, + DataType::Float32, 
+ Stddev, + ScalarValue::from(3_f64), + DataType::Float64 + ) + } + + #[test] + fn stddev_f64() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + generic_test_op!( + a, + DataType::Float64, + Stddev, + ScalarValue::from(3_f64), + DataType::Float64 + ) + } + + fn aggregate( + batch: &RecordBatch, + agg: Arc, + ) -> Result { + let mut accum = agg.create_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + accum.update_batch(&values)?; + accum.evaluate() + } +} diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs new file mode 100644 index 0000000000000..6a916b1e1b5c2 --- /dev/null +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -0,0 +1,441 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::convert::TryFrom; +use std::sync::Arc; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use crate::scalar::{ + ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128, +}; +use arrow::array::*; +use arrow::compute; +use arrow::datatypes::DataType; +use arrow::datatypes::Field; + +use super::{format_state_name, sum}; + +/// STDDEV (standard deviation) aggregate expression +#[derive(Debug)] +pub struct Variance { + name: String, + expr: Arc, + data_type: DataType, +} + +/// function return type of an standard deviation +pub fn variance_return_type(arg_type: &DataType) -> Result { + match arg_type { + DataType::Decimal(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4); + let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4); + Ok(DataType::Decimal(new_precision, new_scale)) + } + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(DataFusionError::Plan(format!( + "STDDEV does not support {:?}", + other + ))), + } +} + +pub(crate) fn is_variance_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal(_, _) + ) +} + +impl Variance { + /// 
Create a new STDDEV aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of variance just support FLOAT64 and Decimal data type. + assert!(matches!( + data_type, + DataType::Float64 | DataType::Decimal(_, _) + )); + Self { + name: name.into(), + expr, + data_type, + } + } +} + +impl AggregateExpr for Variance { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(VarianceAccumulator::try_new( + // variance is f64 or decimal + &self.data_type, + )?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "sum"), + self.data_type.clone(), + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +/// An accumulator to compute variance +#[derive(Debug)] +pub struct VarianceAccumulator { + m2: ScalarValue, + mean: ScalarValue, + count: u64, +} + +impl VarianceAccumulator { + /// Creates a new `VarianceAccumulator` + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + m2: ScalarValue::try_from(datatype)?, + mean: ScalarValue::try_from(datatype)?, + count: 0, + }) + } + + // TODO: There should be a generic implementation of ScalarValue arithmetic somewhere + // There is also a similar function in averate.rs + fn div(lhs: &ScalarValue, rhs: u64) -> Result { + match lhs { + ScalarValue::Float64(e) => { + Ok(ScalarValue::Float64(e.map(|f| f / rhs as f64))) + } + _ => Err(DataFusionError::Internal( + "Numerator should be f64 to calculate variance".to_string(), + )), + } + } + + // TODO: There should be a generic implementation of ScalarValue arithmetic somewhere + // This is only 
used to calculate multiplications of deltas which are guarenteed to be f64 + // Assumption in this function is lhs and rhs are not none values and are the same data type + fn mul(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { + match (lhs, rhs) { + (ScalarValue::Float64(f1), + ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() * f2.unwrap()))) + } + _ => Err(DataFusionError::Internal( + "Delta should be f64 to calculate variance".to_string(), + )), + } + } +} + +impl Accumulator for VarianceAccumulator { + fn state(&self) -> Result> { + Ok(vec![ScalarValue::from(self.count), self.mean.clone(), self.m2.clone()]) + } + + fn update(&mut self, values: &[ScalarValue]) -> Result<()> { + let values = &values[0]; + let is_empty = values.is_null(); + + if !is_empty { + let delta1 = sum::sum(values, &self.mean.arithmetic_negate())?; + let sum = sum::sum(&self.mean, values)?; + let new_mean = VarianceAccumulator::div(&sum, 2)?; + let delta2 = sum::sum(values, &self.mean.arithmetic_negate())?; + let tmp = VarianceAccumulator::mul(&delta1, &delta2)?; + let new_m2 = sum::sum(&self.m2, &tmp)?; + self.count += 1; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { + let count = &states[0]; + let mean = &states[1]; + let m2 = &states[2]; + let mut new_count: u64 = self.count; + // counts are summed + if let ScalarValue::UInt64(Some(c)) = count { + new_count += c + } else { + unreachable!() + }; + let new_mean = + VarianceAccumulator::div( + &sum::sum( + &self.mean, + mean)?, + 2)?; + let delta = sum::sum(&mean.arithmetic_negate(), &self.mean)?; + let delta_sqrt = VarianceAccumulator::mul(&delta, &delta)?; + let new_m2 = + sum::sum( + &sum::sum( + &VarianceAccumulator::mul( + &delta_sqrt, + &VarianceAccumulator::div( + &VarianceAccumulator::mul( + &ScalarValue::from(self.count), + count)?, + new_count)?)?, + &self.m2)?, + &m2)?; + + self.count = new_count; + self.mean = new_mean; + 
self.m2 = new_m2; + + Ok(()) + } + + fn evaluate(&self) -> Result { + match self.m2 { + ScalarValue::Float64(e) => { + Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) + } + _ => Err(DataFusionError::Internal( + "M2 should be f64 for variance".to_string(), + )), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::physical_plan::expressions::col; + use crate::{error::Result, generic_test_op}; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + + #[test] + fn test_variance_return_data_type() -> Result<()> { + let data_type = DataType::Decimal(10, 5); + let result_type = variance_return_type(&data_type)?; + assert_eq!(DataType::Decimal(14, 9), result_type); + + let data_type = DataType::Decimal(36, 10); + let result_type = variance_return_type(&data_type)?; + assert_eq!(DataType::Decimal(38, 14), result_type); + Ok(()) + } + + #[test] + fn variance_decimal() -> Result<()> { + // test agg + let mut decimal_builder = DecimalBuilder::new(6, 10, 0); + for i in 1..7 { + decimal_builder.append_value(i as i128)?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + + generic_test_op!( + array, + DataType::Decimal(10, 0), + Variance, + ScalarValue::Decimal128(Some(35000), 14, 4), + DataType::Decimal(14, 4) + ) + } + + #[test] + fn variance_decimal_with_nulls() -> Result<()> { + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for i in 1..6 { + if i == 2 { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(i)?; + } + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + generic_test_op!( + array, + DataType::Decimal(10, 0), + Variance, + ScalarValue::Decimal128(Some(32500), 14, 4), + DataType::Decimal(14, 4) + ) + } + + #[test] + fn variance_decimal_all_nulls() -> Result<()> { + // test agg + let mut decimal_builder = DecimalBuilder::new(5, 10, 0); + for _i in 1..6 { + decimal_builder.append_null()?; + } + let array: ArrayRef = Arc::new(decimal_builder.finish()); + 
generic_test_op!( + array, + DataType::Decimal(10, 0), + Variance, + ScalarValue::Decimal128(None, 14, 4), + DataType::Decimal(14, 4) + ) + } + + #[test] + fn variance_i32() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + generic_test_op!( + a, + DataType::Int32, + Variance, + ScalarValue::from(3_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_i32_with_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + Some(5), + ])); + generic_test_op!( + a, + DataType::Int32, + Variance, + ScalarValue::from(3.25f64), + DataType::Float64 + ) + } + + #[test] + fn variance_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); + generic_test_op!( + a, + DataType::Int32, + Variance, + ScalarValue::Float64(None), + DataType::Float64 + ) + } + + #[test] + fn variance_u32() -> Result<()> { + let a: ArrayRef = + Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); + generic_test_op!( + a, + DataType::UInt32, + Variance, + ScalarValue::from(3.0f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f32() -> Result<()> { + let a: ArrayRef = + Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); + generic_test_op!( + a, + DataType::Float32, + Variance, + ScalarValue::from(3_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f64() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + generic_test_op!( + a, + DataType::Float64, + Variance, + ScalarValue::from(3_f64), + DataType::Float64 + ) + } + + fn aggregate( + batch: &RecordBatch, + agg: Arc, + ) -> Result { + let mut accum = agg.create_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + accum.update_batch(&values)?; + accum.evaluate() + } +} From 
ba1140be882f81fa2ac9d786feb5c212d442d4a4 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Mon, 3 Jan 2022 22:18:59 -0800 Subject: [PATCH 02/22] get simple f64 type tests working --- .../src/physical_plan/expressions/variance.rs | 76 ++++++++++++------- 1 file changed, 47 insertions(+), 29 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 6a916b1e1b5c2..d365a8939909a 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -158,8 +158,8 @@ impl VarianceAccumulator { /// Creates a new `VarianceAccumulator` pub fn try_new(datatype: &DataType) -> Result { Ok(Self { - m2: ScalarValue::try_from(datatype)?, - mean: ScalarValue::try_from(datatype)?, + m2: ScalarValue::from(0 as f64), + mean: ScalarValue::from(0 as f64), count: 0, }) } @@ -203,10 +203,14 @@ impl Accumulator for VarianceAccumulator { let is_empty = values.is_null(); if !is_empty { + let new_count = self.count + 1; let delta1 = sum::sum(values, &self.mean.arithmetic_negate())?; let sum = sum::sum(&self.mean, values)?; - let new_mean = VarianceAccumulator::div(&sum, 2)?; - let delta2 = sum::sum(values, &self.mean.arithmetic_negate())?; + let new_mean = sum::sum( + &VarianceAccumulator::div(&delta1, new_count)?, + &self.mean)?; + //let new_mean = VarianceAccumulator::div(&sum, 2)?; + let delta2 = sum::sum(values, &new_mean.arithmetic_negate())?; let tmp = VarianceAccumulator::mul(&delta1, &delta2)?; let new_m2 = sum::sum(&self.m2, &tmp)?; self.count += 1; @@ -276,6 +280,45 @@ mod tests { use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; + + #[test] + fn variance_f64_1() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 2_f64])); + generic_test_op!( + a, + DataType::Float64, + Variance, + ScalarValue::from(0.25_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f64() -> Result<()> { + let a: 
ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + generic_test_op!( + a, + DataType::Float64, + Variance, + ScalarValue::from(2_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_i32() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + generic_test_op!( + a, + DataType::Int32, + Variance, + ScalarValue::from(2_f64), + DataType::Float64 + ) + } + #[test] fn test_variance_return_data_type() -> Result<()> { let data_type = DataType::Decimal(10, 5); @@ -343,18 +386,6 @@ mod tests { ) } - #[test] - fn variance_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Int32, - Variance, - ScalarValue::from(3_f64), - DataType::Float64 - ) - } - #[test] fn variance_i32_with_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![ @@ -411,19 +442,6 @@ mod tests { ) } - #[test] - fn variance_f64() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!( - a, - DataType::Float64, - Variance, - ScalarValue::from(3_f64), - DataType::Float64 - ) - } - fn aggregate( batch: &RecordBatch, agg: Arc, From ae2cc929cea63e9541500aaf6a05e8f7ffbaf376 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Tue, 4 Jan 2022 21:14:34 -0800 Subject: [PATCH 03/22] add math functions to ScalarValue, some tests --- .../src/physical_plan/expressions/variance.rs | 121 +++----- datafusion/src/scalar.rs | 277 ++++++++++++++++++ 2 files changed, 323 insertions(+), 75 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index d365a8939909a..0b04977f0fba0 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -31,7 +31,7 @@ use arrow::compute; use arrow::datatypes::DataType; use arrow::datatypes::Field; -use 
super::{format_state_name, sum}; +use super::format_state_name; /// STDDEV (standard deviation) aggregate expression #[derive(Debug)] @@ -163,34 +163,6 @@ impl VarianceAccumulator { count: 0, }) } - - // TODO: There should be a generic implementation of ScalarValue arithmetic somewhere - // There is also a similar function in averate.rs - fn div(lhs: &ScalarValue, rhs: u64) -> Result { - match lhs { - ScalarValue::Float64(e) => { - Ok(ScalarValue::Float64(e.map(|f| f / rhs as f64))) - } - _ => Err(DataFusionError::Internal( - "Numerator should be f64 to calculate variance".to_string(), - )), - } - } - - // TODO: There should be a generic implementation of ScalarValue arithmetic somewhere - // This is only used to calculate multiplications of deltas which are guarenteed to be f64 - // Assumption in this function is lhs and rhs are not none values and are the same data type - fn mul(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { - match (lhs, rhs) { - (ScalarValue::Float64(f1), - ScalarValue::Float64(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() * f2.unwrap()))) - } - _ => Err(DataFusionError::Internal( - "Delta should be f64 to calculate variance".to_string(), - )), - } - } } impl Accumulator for VarianceAccumulator { @@ -204,15 +176,14 @@ impl Accumulator for VarianceAccumulator { if !is_empty { let new_count = self.count + 1; - let delta1 = sum::sum(values, &self.mean.arithmetic_negate())?; - let sum = sum::sum(&self.mean, values)?; - let new_mean = sum::sum( - &VarianceAccumulator::div(&delta1, new_count)?, + let delta1 = ScalarValue::add(values, &self.mean.arithmetic_negate())?; + let new_mean = ScalarValue::add( + &ScalarValue::div(&delta1, &ScalarValue::from(new_count as f64))?, &self.mean)?; - //let new_mean = VarianceAccumulator::div(&sum, 2)?; - let delta2 = sum::sum(values, &new_mean.arithmetic_negate())?; - let tmp = VarianceAccumulator::mul(&delta1, &delta2)?; - let new_m2 = sum::sum(&self.m2, &tmp)?; + let delta2 = ScalarValue::add(values, 
&new_mean.arithmetic_negate())?; + let tmp = ScalarValue::mul(&delta1, &delta2)?; + + let new_m2 = ScalarValue::add(&self.m2, &tmp)?; self.count += 1; self.mean = new_mean; self.m2 = new_m2; @@ -233,23 +204,23 @@ impl Accumulator for VarianceAccumulator { unreachable!() }; let new_mean = - VarianceAccumulator::div( - &sum::sum( + ScalarValue::div( + &ScalarValue::add( &self.mean, mean)?, - 2)?; - let delta = sum::sum(&mean.arithmetic_negate(), &self.mean)?; - let delta_sqrt = VarianceAccumulator::mul(&delta, &delta)?; + &ScalarValue::from(2 as f64))?; + let delta = ScalarValue::add(&mean.arithmetic_negate(), &self.mean)?; + let delta_sqrt = ScalarValue::mul(&delta, &delta)?; let new_m2 = - sum::sum( - &sum::sum( - &VarianceAccumulator::mul( + ScalarValue::add( + &ScalarValue::add( + &ScalarValue::mul( &delta_sqrt, - &VarianceAccumulator::div( - &VarianceAccumulator::mul( + &ScalarValue::div( + &ScalarValue::mul( &ScalarValue::from(self.count), count)?, - new_count)?)?, + &ScalarValue::from(new_count as f64))?)?, &self.m2)?, &m2)?; @@ -295,7 +266,7 @@ mod tests { } #[test] - fn variance_f64() -> Result<()> { + fn variance_f64_2() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); generic_test_op!( @@ -319,6 +290,32 @@ mod tests { ) } + #[test] + fn variance_u32() -> Result<()> { + let a: ArrayRef = + Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); + generic_test_op!( + a, + DataType::UInt32, + Variance, + ScalarValue::from(2.0f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f32() -> Result<()> { + let a: ArrayRef = + Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); + generic_test_op!( + a, + DataType::Float32, + Variance, + ScalarValue::from(2_f64), + DataType::Float64 + ) + } + #[test] fn test_variance_return_data_type() -> Result<()> { let data_type = DataType::Decimal(10, 5); @@ -416,32 +413,6 @@ mod tests { ) } - #[test] - fn variance_u32() -> Result<()> 
{ - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!( - a, - DataType::UInt32, - Variance, - ScalarValue::from(3.0f64), - DataType::Float64 - ) - } - - #[test] - fn variance_f32() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!( - a, - DataType::Float32, - Variance, - ScalarValue::from(3_f64), - DataType::Float64 - ) - } - fn aggregate( batch: &RecordBatch, agg: Arc, diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index cdcf11eccea27..8de1621c982dc 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -526,6 +526,283 @@ macro_rules! eq_array_primitive { } impl ScalarValue { + + /// Return true if the value is numeric + pub fn is_numeric(&self) -> bool { + match self { + ScalarValue::Float32(_) | + ScalarValue::Float64(_) | + ScalarValue::Decimal128(_, _, _) | + ScalarValue::Int8(_) | + ScalarValue::Int16(_) | + ScalarValue::Int32(_) | + ScalarValue::Int64(_) | + ScalarValue::UInt8(_) | + ScalarValue::UInt16(_) | + ScalarValue::UInt32(_) | + ScalarValue::UInt64(_) => { + true + } + _ => false + } + } + + /// Add two numeric ScalarValues + pub fn add(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(DataFusionError::Internal( + format!( + "Division is only supported on numeric types, \ + here has {:?} and {:?}", + lhs.get_datatype(), rhs.get_datatype() + ))); + } + + // TODO: Finding a good way to support operation between different types without + // writing a hige match block. 
+ match (lhs, rhs) { + (ScalarValue::Decimal128(v1, u1, s1), _) | + (_, ScalarValue::Decimal128(v1, u1, s1)) => { + Err(DataFusionError::Internal( + format!( + "Division with Decimals are not supported for now" + ))) + }, + // f64 / _ + (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() + f2.unwrap()))) + }, + // f32 / _ + (ScalarValue::Float32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap() as f64))) + }, + // i64 / _ + (ScalarValue::Int64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { + Ok(ScalarValue::Int64(Some(f1.unwrap() + f2.unwrap()))) + }, + // i32 / _ + (ScalarValue::Int32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => { + Ok(ScalarValue::Int64(Some(f1.unwrap() as i64 + f2.unwrap() as i64))) + }, + // i16 / _ + (ScalarValue::Int16(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => { + Ok(ScalarValue::Int32(Some(f1.unwrap() as i32 + f2.unwrap() as i32))) + }, + // i8 / _ + (ScalarValue::Int8(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => { + Ok(ScalarValue::Int16(Some(f1.unwrap() as i16 + f2.unwrap() as i16))) + }, + // u64 / _ + (ScalarValue::UInt64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => { + Ok(ScalarValue::UInt64(Some(f1.unwrap() as u64 + f2.unwrap() as 
u64))) + }, + // u32 / _ + (ScalarValue::UInt32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => { + Ok(ScalarValue::UInt64(Some(f1.unwrap() as u64 + f2.unwrap() as u64))) + }, + // u16 / _ + (ScalarValue::UInt16(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) => { + Ok(ScalarValue::UInt32(Some(f1.unwrap() as u32 + f2.unwrap() as u32))) + }, + // u8 / _ + (ScalarValue::UInt8(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) + }, + (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { + Ok(ScalarValue::UInt16(Some(f1.unwrap() as u16 / f2.unwrap() as u16))) + }, + + _ => Err(DataFusionError::Internal( + format!( + "Addition only support calculation with the same type or f64 as one of the numbers for now, here has {:?} and {:?}", + lhs.get_datatype(), rhs.get_datatype() + ))), + } + } + + /// Multiply two numeric ScalarValues + pub fn mul(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(DataFusionError::Internal( + format!( + "Multiplication is only supported on numeric types, \ + here has {:?} and {:?}", + lhs.get_datatype(), rhs.get_datatype() + ))); + } + + // TODO: Finding a good way to support operation between different types without + // writing a hige match block. 
+ match (lhs, rhs) { + // f64 / _ + (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() * f2.unwrap()))) + }, + // f32 / _ + (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 * f2.unwrap() as f64))) + }, + // i64 / _ + (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { + Ok(ScalarValue::Int64(Some(f1.unwrap() * f2.unwrap()))) + }, + // i32 / _ + (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => { + Ok(ScalarValue::Int64(Some(f1.unwrap() as i64 * f2.unwrap() as i64))) + }, + // i16 / _ + (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => { + Ok(ScalarValue::Int32(Some(f1.unwrap() as i32 * f2.unwrap() as i32))) + }, + // i8 / _ + (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => { + Ok(ScalarValue::Int16(Some(f1.unwrap() as i16 * f2.unwrap() as i16))) + }, + // u64 / _ + (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => { + Ok(ScalarValue::UInt64(Some(f1.unwrap() as u64 * f2.unwrap() as u64))) + }, + // u32 / _ + (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => { + Ok(ScalarValue::UInt64(Some(f1.unwrap() as u64 * f2.unwrap() as u64))) + }, + // u16 / _ + (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) => { + Ok(ScalarValue::UInt32(Some(f1.unwrap() as u32 * f2.unwrap() as u32))) + }, + // u8 / _ + (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { + Ok(ScalarValue::UInt16(Some(f1.unwrap() as u16 * f2.unwrap() as u16))) + }, + _ => Err(DataFusionError::Internal( + format!( + "Multiplication only support f64 for now, here has {:?} and {:?}", + lhs.get_datatype(), rhs.get_datatype() + ))), + } + } + + /// Division between two numeric ScalarValues + pub fn div(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(DataFusionError::Internal( + format!( + "Division is only supported on numeric types, \ + here has {:?} and {:?}", + lhs.get_datatype(), rhs.get_datatype() + ))); + } + + // TODO: 
Finding a good way to support operation between different types without + // writing a hige match block. + match (lhs, rhs) { + (ScalarValue::Decimal128(v1, u1, s1), _) | + (_, ScalarValue::Decimal128(v1, u1, s1)) => { + Err(DataFusionError::Internal( + format!( + "Division with Decimals are not supported for now" + ))) + }, + // f64 / _ + (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() / f2.unwrap()))) + }, + // f32 / _ + (ScalarValue::Float32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64/ f2.unwrap()))) + }, + (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64/ f2.unwrap() as f64))) + }, + // i64 / _ + (ScalarValue::Int64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // i32 / _ + (ScalarValue::Int32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // i16 / _ + (ScalarValue::Int16(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // i8 / _ + (ScalarValue::Int8(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // u64 / _ + (ScalarValue::UInt64(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, 
+ (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // u32 / _ + (ScalarValue::UInt32(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // u16 / _ + (ScalarValue::UInt16(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + // u8 / _ + (ScalarValue::UInt8(f1), ScalarValue::Float64(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap()))) + }, + (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { + Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) + }, + + _ => Err(DataFusionError::Internal( + format!( + "Division only support calculation with the same type or f64 as denominator for now, here has {:?} and {:?}", + lhs.get_datatype(), rhs.get_datatype() + ))), + } + } + /// Create a decimal Scalar from value/precision and scale. 
pub fn try_new_decimal128( value: i128, From 70b11165ba9514bb6918d4b4e621730daf73724b Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Tue, 4 Jan 2022 22:02:40 -0800 Subject: [PATCH 04/22] add to expressions and tests --- datafusion/src/physical_plan/aggregates.rs | 85 ++++++++++++++++++- .../coercion_rule/aggregate_rule.rs | 13 ++- .../src/physical_plan/expressions/mod.rs | 6 +- .../src/physical_plan/expressions/variance.rs | 3 - 4 files changed, 99 insertions(+), 8 deletions(-) diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index e9f9696a56e8c..75ed669795343 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -35,7 +35,7 @@ use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_t use crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use expressions::{avg_return_type, sum_return_type}; +use expressions::{avg_return_type, sum_return_type, variance_return_type, Variance}; use std::{fmt, str::FromStr, sync::Arc}; /// the implementation of an aggregate function @@ -64,6 +64,8 @@ pub enum AggregateFunction { ApproxDistinct, /// array_agg ArrayAgg, + /// Variance (Population) + Variance, } impl fmt::Display for AggregateFunction { @@ -84,6 +86,7 @@ impl FromStr for AggregateFunction { "sum" => AggregateFunction::Sum, "approx_distinct" => AggregateFunction::ApproxDistinct, "array_agg" => AggregateFunction::ArrayAgg, + "variance" => AggregateFunction::Variance, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -116,6 +119,7 @@ pub fn return_type( Ok(coerced_data_types[0].clone()) } AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), + AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]), AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), 
AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new( "item", @@ -211,6 +215,16 @@ pub fn create_aggregate_expr( return Err(DataFusionError::NotImplemented( "AVG(DISTINCT) aggregations are not available".to_string(), )); + }, + (AggregateFunction::Variance, false) => Arc::new(expressions::Variance::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), + (AggregateFunction::Variance, true) => { + return Err(DataFusionError::NotImplemented( + "VARIANCE(DISTINCT) aggregations are not available".to_string(), + )); } }) } @@ -256,7 +270,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature { .collect::>(); Signature::uniform(1, valid, Volatility::Immutable) } - AggregateFunction::Avg | AggregateFunction::Sum => { + AggregateFunction::Avg | AggregateFunction::Sum | AggregateFunction::Variance => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } } @@ -450,6 +464,47 @@ mod tests { Ok(()) } + #[test] + fn test_variance_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Variance]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Variance => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; + } + } + Ok(()) + } + #[test] fn test_min_max() -> Result<()> { let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?; @@ 
-544,4 +599,30 @@ mod tests { let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]); assert!(observed.is_err()); } + + #[test] + fn test_variance_return_type() -> Result<()> { + let observed = return_type(&AggregateFunction::Variance, &[DataType::Float32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Variance, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Variance, &[DataType::Int32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Variance, &[DataType::UInt32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Variance, &[DataType::Int64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_variance_no_utf8() { + let observed = return_type(&AggregateFunction::Variance, &[DataType::Utf8]); + assert!(observed.is_err()); + } } diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index e76e4a6b023e0..944c942d2f0ef 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -21,7 +21,7 @@ use crate::arrow::datatypes::Schema; use crate::error::{DataFusionError, Result}; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::{ - is_avg_support_arg_type, is_sum_support_arg_type, try_cast, + is_avg_support_arg_type, is_sum_support_arg_type, is_variance_support_arg_type, try_cast, }; use crate::physical_plan::functions::{Signature, TypeSignature}; use crate::physical_plan::PhysicalExpr; @@ -86,6 +86,17 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::Variance => { + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + // smallint, int, bigint, real, 
double precision, decimal, or interval + if !is_variance_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } } } diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 46c168926d205..1ea6a77f6e149 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -28,8 +28,6 @@ use arrow::record_batch::RecordBatch; mod approx_distinct; mod array_agg; mod average; -mod stddev; -mod variance; #[macro_use] mod binary; mod case; @@ -54,6 +52,8 @@ mod rank; mod row_number; mod sum; mod try_cast; +mod stddev; +mod variance; /// Module with some convenient methods used in expression building pub mod helpers { @@ -88,6 +88,8 @@ pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; +pub(crate) use variance::is_variance_support_arg_type; +pub use variance::{variance_return_type, Variance}; pub use try_cast::{try_cast, TryCastExpr}; /// returns the name of the state diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 0b04977f0fba0..db83717dd3fa3 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -18,7 +18,6 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; -use std::convert::TryFrom; use std::sync::Arc; use crate::error::{DataFusionError, Result}; @@ -26,8 +25,6 @@ use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; use crate::scalar::{ ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128, }; -use arrow::array::*; -use arrow::compute; use arrow::datatypes::DataType; use arrow::datatypes::Field; From 522a960474a3242a6343fd340baab5989e3ef28e Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Tue, 4 Jan 2022 22:43:23 -0800 Subject: [PATCH 05/22] add more tests --- datafusion/src/physical_plan/expressions/variance.rs | 9 +++++++-- datafusion/tests/sql/aggregates.rs | 12 ++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index db83717dd3fa3..eed57b2ed326f 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -59,7 +59,7 @@ pub fn variance_return_type(arg_type: &DataType) -> Result { | DataType::Float32 | DataType::Float64 => Ok(DataType::Float64), other => Err(DataFusionError::Plan(format!( - "STDDEV does not support {:?}", + "VARIANCE does not support {:?}", other ))), } @@ -127,7 +127,12 @@ impl AggregateExpr for Variance { true, ), Field::new( - &format_state_name(&self.name, "sum"), + &format_state_name(&self.name, "mean"), + self.data_type.clone(), + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), self.data_type.clone(), true, ), diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 243d0084d890e..99538fa79c8c9 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -49,6 +49,18 @@ async fn csv_query_avg() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_variance() -> Result<()> { + let 
mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT variance(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1.8675"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_external_table_count() { let mut ctx = ExecutionContext::new(); From fafab18ba5ce01017dee7ffb6fd184a99411757e Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Wed, 5 Jan 2022 21:20:34 -0800 Subject: [PATCH 06/22] add test for ScalarValue add --- datafusion/src/scalar.rs | 108 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 100 insertions(+), 8 deletions(-) diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 8de1621c982dc..1be91998f1da4 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -552,7 +552,7 @@ impl ScalarValue { if !lhs.is_numeric() || !rhs.is_numeric() { return Err(DataFusionError::Internal( format!( - "Division is only supported on numeric types, \ + "Addition only supports numeric types, \ here has {:?} and {:?}", lhs.get_datatype(), rhs.get_datatype() ))); @@ -560,12 +560,13 @@ impl ScalarValue { // TODO: Finding a good way to support operation between different types without // writing a hige match block. 
+ // TODO: Add support for decimal types match (lhs, rhs) { - (ScalarValue::Decimal128(v1, u1, s1), _) | - (_, ScalarValue::Decimal128(v1, u1, s1)) => { + (ScalarValue::Decimal128(_, _, _), _) | + (_, ScalarValue::Decimal128(_, _, _)) => { Err(DataFusionError::Internal( format!( - "Division with Decimals are not supported for now" + "Addition with Decimals are not supported for now" ))) }, // f64 / _ @@ -633,7 +634,7 @@ impl ScalarValue { Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 + f2.unwrap()))) }, (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { - Ok(ScalarValue::UInt16(Some(f1.unwrap() as u16 / f2.unwrap() as u16))) + Ok(ScalarValue::UInt16(Some(f1.unwrap() as u16 + f2.unwrap() as u16))) }, _ => Err(DataFusionError::Internal( @@ -657,7 +658,15 @@ impl ScalarValue { // TODO: Finding a good way to support operation between different types without // writing a hige match block. + // TODO: Add support for decimal type match (lhs, rhs) { + (ScalarValue::Decimal128(_, _, _), _) | + (_, ScalarValue::Decimal128(_, _, _)) => { + Err(DataFusionError::Internal( + format!( + "Multiplication with Decimals are not supported for now" + ))) + }, // f64 / _ (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { Ok(ScalarValue::Float64(Some(f1.unwrap() * f2.unwrap()))) @@ -718,10 +727,11 @@ impl ScalarValue { } // TODO: Finding a good way to support operation between different types without - // writing a hige match block. + // writing a hige match block. 
+ // TODO: Add support for decimal types match (lhs, rhs) { - (ScalarValue::Decimal128(v1, u1, s1), _) | - (_, ScalarValue::Decimal128(v1, u1, s1)) => { + (ScalarValue::Decimal128(_, _, _), _) | + (_, ScalarValue::Decimal128(_, _, _)) => { Err(DataFusionError::Internal( format!( "Division with Decimals are not supported for now" @@ -3359,3 +3369,85 @@ mod tests { ); } } + +#[test] + fn scalar_addition() { + let v1 = &ScalarValue::from(1 as i64); + let v2 = &ScalarValue::from(2 as i64); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as i64)); + + let v1 = &ScalarValue::from(100 as i64); + let v2 = &ScalarValue::from(-32 as i64); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(68 as i64)); + + let v1 = &ScalarValue::from(-102 as i64); + let v2 = &ScalarValue::from(32 as i64); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(-70 as i64)); + + let v1 = &ScalarValue::from(1); + let v2 = &ScalarValue::from(2); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as i64)); + + let v1 = &ScalarValue::from(std::i32::MAX); + let v2 = &ScalarValue::from(std::i32::MAX); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(std::i32::MAX as i64 * 2)); + + let v1 = &ScalarValue::from(1 as i16); + let v2 = &ScalarValue::from(2 as i16); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as i32)); + + let v1 = &ScalarValue::from(std::i16::MAX); + let v2 = &ScalarValue::from(std::i16::MAX); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(std::i16::MAX as i32 * 2)); + + let v1 = &ScalarValue::from(1 as i8); + let v2 = &ScalarValue::from(2 as i8); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as i16)); + + let v1 = &ScalarValue::from(std::i8::MAX); + let v2 = &ScalarValue::from(std::i8::MAX); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(std::i8::MAX as i16 * 2)); + + let v1 = &ScalarValue::from(1 as u64); + let v2 = 
&ScalarValue::from(2 as u64); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as u64)); + + let v1 = &ScalarValue::from(1 as u32); + let v2 = &ScalarValue::from(2 as u32); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as u64)); + + let v1 = &ScalarValue::from(std::u32::MAX); + let v2 = &ScalarValue::from(std::u32::MAX); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(std::u32::MAX as u64 * 2)); + + let v1 = &ScalarValue::from(1 as u16); + let v2 = &ScalarValue::from(2 as u16); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as u32)); + + let v1 = &ScalarValue::from(std::u16::MAX); + let v2 = &ScalarValue::from(std::u16::MAX); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(std::u16::MAX as u32 * 2)); + + let v1 = &ScalarValue::from(1 as u8); + let v2 = &ScalarValue::from(2 as u8); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as u16)); + + let v1 = &ScalarValue::from(std::u8::MAX); + let v2 = &ScalarValue::from(std::u8::MAX); + assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(std::u8::MAX as u16 * 2)); + + let v1 = &ScalarValue::from(1); + let v2 = &ScalarValue::from(2 as u16); + let actual = ScalarValue::add(v1, v2).is_err(); + assert_eq!(actual, true); + + let v1 = &ScalarValue::from(1); + let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); + let actual = ScalarValue::add(v1, v2).is_err(); + assert_eq!(actual, true); + + let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); + let v2 = &ScalarValue::from(2); + let actual = ScalarValue::add(v1, v2).is_err(); + assert_eq!(actual, true); + } From 031e8c03c6d80adbcc1c7bc0b7a6a9537b435f92 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Wed, 5 Jan 2022 22:33:53 -0800 Subject: [PATCH 07/22] add tests for scalar arithmetic --- datafusion/src/scalar.rs | 161 +++++++++++++++++++++++---------------- 1 file changed, 97 insertions(+), 64 deletions(-) diff --git a/datafusion/src/scalar.rs 
b/datafusion/src/scalar.rs index 1be91998f1da4..06ec7a226de34 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -3368,86 +3368,119 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_owned())) ); } -} - -#[test] - fn scalar_addition() { - let v1 = &ScalarValue::from(1 as i64); - let v2 = &ScalarValue::from(2 as i64); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as i64)); - - let v1 = &ScalarValue::from(100 as i64); - let v2 = &ScalarValue::from(-32 as i64); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(68 as i64)); - - let v1 = &ScalarValue::from(-102 as i64); - let v2 = &ScalarValue::from(32 as i64); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(-70 as i64)); - - let v1 = &ScalarValue::from(1); - let v2 = &ScalarValue::from(2); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as i64)); - - let v1 = &ScalarValue::from(std::i32::MAX); - let v2 = &ScalarValue::from(std::i32::MAX); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(std::i32::MAX as i64 * 2)); - - let v1 = &ScalarValue::from(1 as i16); - let v2 = &ScalarValue::from(2 as i16); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as i32)); - let v1 = &ScalarValue::from(std::i16::MAX); - let v2 = &ScalarValue::from(std::i16::MAX); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(std::i16::MAX as i32 * 2)); - - let v1 = &ScalarValue::from(1 as i8); - let v2 = &ScalarValue::from(2 as i8); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as i16)); - - let v1 = &ScalarValue::from(std::i8::MAX); - let v2 = &ScalarValue::from(std::i8::MAX); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(std::i8::MAX as i16 * 2)); - - let v1 = &ScalarValue::from(1 as u64); - let v2 = &ScalarValue::from(2 as u64); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as u64)); + macro_rules! 
test_scalar_op { + ($OP:ident, $LHS:expr, $LHS_TYPE:ident, $RHS:expr, $RHS_TYPE:ident, $RESULT:expr, $RESULT_TYPE:ident) => {{ + let v1 = &ScalarValue::from($LHS as $LHS_TYPE); + let v2 = &ScalarValue::from($RHS as $RHS_TYPE); + assert_eq!(ScalarValue::$OP(v1, v2).unwrap(), ScalarValue::from($RESULT as $RESULT_TYPE)); + }}; + } - let v1 = &ScalarValue::from(1 as u32); - let v2 = &ScalarValue::from(2 as u32); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as u64)); + macro_rules! test_scalar_op_err { + ($OP:ident, $LHS:expr, $LHS_TYPE:ident, $RHS:expr, $RHS_TYPE:ident) => {{ + let v1 = &ScalarValue::from($LHS as $LHS_TYPE); + let v2 = &ScalarValue::from($RHS as $RHS_TYPE); + let actual = ScalarValue::add(v1, v2).is_err(); + assert_eq!(actual, true); + }}; + } - let v1 = &ScalarValue::from(std::u32::MAX); - let v2 = &ScalarValue::from(std::u32::MAX); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(std::u32::MAX as u64 * 2)); + #[test] + fn scalar_addition() { - let v1 = &ScalarValue::from(1 as u16); - let v2 = &ScalarValue::from(2 as u16); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as u32)); + test_scalar_op!(add, 1, f64, 2, f64, 3, f64); + test_scalar_op!(add, 1, f32, 2, f32, 3, f64); + test_scalar_op!(add, 1, i64, 2, i64, 3, i64); + test_scalar_op!(add, 100, i64, -32, i64, 68, i64); + test_scalar_op!(add, -102, i64, 32, i64, -70, i64); + test_scalar_op!(add, 1, i32, 2, i32, 3, i64); + test_scalar_op!(add, std::i32::MAX, i32, std::i32::MAX, i32, std::i32::MAX as i64 * 2, i64); + test_scalar_op!(add, 1, i16, 2, i16, 3, i32); + test_scalar_op!(add, std::i16::MAX, i16, std::i16::MAX, i16, std::i16::MAX as i32 * 2, i32); + test_scalar_op!(add, 1, i8, 2, i8, 3, i16); + test_scalar_op!(add, std::i8::MAX, i8, std::i8::MAX, i8, std::i8::MAX as i16 * 2, i16); + test_scalar_op!(add, 1, u64, 2, u64, 3, u64); + test_scalar_op!(add, 1, u32, 2, u32, 3, u64); + test_scalar_op!(add, std::u32::MAX, u32, std::u32::MAX, 
u32, std::u32::MAX as u64 * 2, u64); + test_scalar_op!(add, 1, u16, 2, u16, 3, u32); + test_scalar_op!(add, std::u16::MAX, u16, std::u16::MAX, u16, std::u16::MAX as u32 * 2, u32); + test_scalar_op!(add, 1, u8, 2, u8, 3, u16); + test_scalar_op!(add, std::u8::MAX, u8, std::u8::MAX, u8, std::u8::MAX as u16 * 2, u16); + test_scalar_op_err!(add, 1, i32, 2, u16); + test_scalar_op_err!(add, 1, i32, 2, u16); - let v1 = &ScalarValue::from(std::u16::MAX); - let v2 = &ScalarValue::from(std::u16::MAX); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(std::u16::MAX as u32 * 2)); + let v1 = &ScalarValue::from(1); + let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); + let actual = ScalarValue::add(v1, v2).is_err(); + assert_eq!(actual, true); - let v1 = &ScalarValue::from(1 as u8); - let v2 = &ScalarValue::from(2 as u8); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(3 as u16)); + let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); + let v2 = &ScalarValue::from(2); + let actual = ScalarValue::add(v1, v2).is_err(); + assert_eq!(actual, true); + } - let v1 = &ScalarValue::from(std::u8::MAX); - let v2 = &ScalarValue::from(std::u8::MAX); - assert_eq!(ScalarValue::add(v1, v2).unwrap(), ScalarValue::from(std::u8::MAX as u16 * 2)); + #[test] + fn scalar_multiplication() { + + test_scalar_op!(mul, 1, f64, 2, f64, 2, f64); + test_scalar_op!(mul, 1, f32, 2, f32, 2, f64); + test_scalar_op!(mul, 15, i64, 2, i64, 30, i64); + test_scalar_op!(mul, 100, i64, -32, i64, -3200, i64); + test_scalar_op!(mul, -1.1, f64, 2, f64, -2.2, f64); + test_scalar_op!(mul, 1, i32, 2, i32, 2, i64); + test_scalar_op!(mul, std::i32::MAX, i32, std::i32::MAX, i32, std::i32::MAX as i64 * std::i32::MAX as i64, i64); + test_scalar_op!(mul, 1, i16, 2, i16, 2, i32); + test_scalar_op!(mul, std::i16::MAX, i16, std::i16::MAX, i16, std::i16::MAX as i32 * std::i16::MAX as i32, i32); + test_scalar_op!(mul, 1, i8, 2, i8, 2, i16); + test_scalar_op!(mul, std::i8::MAX, i8, std::i8::MAX, i8, 
std::i8::MAX as i16 * std::i8::MAX as i16, i16); + test_scalar_op!(mul, 1, u64, 2, u64, 2, u64); + test_scalar_op!(mul, 1, u32, 2, u32, 2, u64); + test_scalar_op!(mul, std::u32::MAX, u32, std::u32::MAX, u32, std::u32::MAX as u64 * std::u32::MAX as u64, u64); + test_scalar_op!(mul, 1, u16, 2, u16, 2, u32); + test_scalar_op!(mul, std::u16::MAX, u16, std::u16::MAX, u16, std::u16::MAX as u32 * std::u16::MAX as u32, u32); + test_scalar_op!(mul, 1, u8, 2, u8, 2, u16); + test_scalar_op!(mul, std::u8::MAX, u8, std::u8::MAX, u8, std::u8::MAX as u16 * std::u8::MAX as u16, u16); + test_scalar_op_err!(mul, 1, i32, 2, u16); + test_scalar_op_err!(mul, 1, i32, 2, u16); let v1 = &ScalarValue::from(1); - let v2 = &ScalarValue::from(2 as u16); - let actual = ScalarValue::add(v1, v2).is_err(); + let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); + let actual = ScalarValue::mul(v1, v2).is_err(); + assert_eq!(actual, true); + + let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); + let v2 = &ScalarValue::from(2); + let actual = ScalarValue::mul(v1, v2).is_err(); assert_eq!(actual, true); + } + + #[test] + fn scalar_division() { + + test_scalar_op!(div, 1, f64, 2, f64, 0.5, f64); + test_scalar_op!(div, 1, f32, 2, f32, 0.5, f64); + test_scalar_op!(div, 15, i64, 2, i64, 7.5, f64); + test_scalar_op!(div, 100, i64, -2, i64, -50, f64); + test_scalar_op!(div, 1, i32, 2, i32, 0.5, f64); + test_scalar_op!(div, 1, i16, 2, i16, 0.5, f64); + test_scalar_op!(div, 1, i8, 2, i8, 0.5, f64); + test_scalar_op!(div, 1, u64, 2, u64, 0.5, f64); + test_scalar_op!(div, 1, u32, 2, u32, 0.5, f64); + test_scalar_op!(div, 1, u16, 2, u16, 0.5, f64); + test_scalar_op!(div, 1, u8, 2, u8, 0.5, f64); + test_scalar_op_err!(div, 1, i32, 2, u16); + test_scalar_op_err!(div, 1, i32, 2, u16); let v1 = &ScalarValue::from(1); let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); - let actual = ScalarValue::add(v1, v2).is_err(); + let actual = ScalarValue::div(v1, v2).is_err(); assert_eq!(actual, true); let v1 = 
&ScalarValue::Decimal128(Some(1), 0, 0); let v2 = &ScalarValue::from(2); - let actual = ScalarValue::add(v1, v2).is_err(); + let actual = ScalarValue::div(v1, v2).is_err(); assert_eq!(actual, true); } +} \ No newline at end of file From d2a27550e45f2b659e546703aff23de90329aa65 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Wed, 5 Jan 2022 22:53:34 -0800 Subject: [PATCH 08/22] add test, finish variance --- .../src/physical_plan/expressions/variance.rs | 63 ++----------------- 1 file changed, 6 insertions(+), 57 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index eed57b2ed326f..0c44a21f26da6 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -236,7 +236,11 @@ impl Accumulator for VarianceAccumulator { fn evaluate(&self) -> Result { match self.m2 { ScalarValue::Float64(e) => { - Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) + if self.count == 0 { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) + } } _ => Err(DataFusionError::Internal( "M2 should be f64 for variance".to_string(), @@ -330,61 +334,6 @@ mod tests { Ok(()) } - #[test] - fn variance_decimal() -> Result<()> { - // test agg - let mut decimal_builder = DecimalBuilder::new(6, 10, 0); - for i in 1..7 { - decimal_builder.append_value(i as i128)?; - } - let array: ArrayRef = Arc::new(decimal_builder.finish()); - - generic_test_op!( - array, - DataType::Decimal(10, 0), - Variance, - ScalarValue::Decimal128(Some(35000), 14, 4), - DataType::Decimal(14, 4) - ) - } - - #[test] - fn variance_decimal_with_nulls() -> Result<()> { - let mut decimal_builder = DecimalBuilder::new(5, 10, 0); - for i in 1..6 { - if i == 2 { - decimal_builder.append_null()?; - } else { - decimal_builder.append_value(i)?; - } - } - let array: ArrayRef = Arc::new(decimal_builder.finish()); - generic_test_op!( - array, 
- DataType::Decimal(10, 0), - Variance, - ScalarValue::Decimal128(Some(32500), 14, 4), - DataType::Decimal(14, 4) - ) - } - - #[test] - fn variance_decimal_all_nulls() -> Result<()> { - // test agg - let mut decimal_builder = DecimalBuilder::new(5, 10, 0); - for _i in 1..6 { - decimal_builder.append_null()?; - } - let array: ArrayRef = Arc::new(decimal_builder.finish()); - generic_test_op!( - array, - DataType::Decimal(10, 0), - Variance, - ScalarValue::Decimal128(None, 14, 4), - DataType::Decimal(14, 4) - ) - } - #[test] fn variance_i32_with_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![ @@ -398,7 +347,7 @@ mod tests { a, DataType::Int32, Variance, - ScalarValue::from(3.25f64), + ScalarValue::from(2.1875f64), DataType::Float64 ) } From b3729b82160bc9d1297d7cfb0a381f79a2bde375 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Wed, 5 Jan 2022 23:04:31 -0800 Subject: [PATCH 09/22] fix warnings --- ballista/rust/core/proto/ballista.proto | 1 + ballista/rust/core/src/serde/logical_plan/to_proto.rs | 3 +++ ballista/rust/core/src/serde/mod.rs | 1 + datafusion/src/physical_plan/aggregates.rs | 4 ++-- datafusion/src/physical_plan/expressions/variance.rs | 7 ++----- 5 files changed, 9 insertions(+), 7 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 493fb97b82b16..0650ce3fb61b1 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -169,6 +169,7 @@ enum AggregateFunction { COUNT = 4; APPROX_DISTINCT = 5; ARRAY_AGG = 6; + VARIANCE=7; } message AggregateExprNode { diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 47b5df47cd730..236d1ae028337 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1134,6 +1134,8 @@ impl TryInto for &Expr { AggregateFunction::Sum => 
protobuf::AggregateFunction::Sum, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, AggregateFunction::Count => protobuf::AggregateFunction::Count, + AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + }; let arg = &args[0]; @@ -1364,6 +1366,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Count => Self::Count, AggregateFunction::ApproxDistinct => Self::ApproxDistinct, AggregateFunction::ArrayAgg => Self::ArrayAgg, + AggregateFunction::Variance => Self::Variance, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index f5442c40e660f..f98135d4ff7c2 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -119,6 +119,7 @@ impl From for AggregateFunction { AggregateFunction::ApproxDistinct } protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg, + protobuf::AggregateFunction::Variance => AggregateFunction::Variance, } } } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 75ed669795343..d3ae100c1aa48 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -35,7 +35,7 @@ use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_t use crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use expressions::{avg_return_type, sum_return_type, variance_return_type, Variance}; +use expressions::{avg_return_type, sum_return_type, variance_return_type}; use std::{fmt, str::FromStr, sync::Arc}; /// the implementation of an aggregate function @@ -281,7 +281,7 @@ mod tests { use super::*; use crate::error::Result; use crate::physical_plan::expressions::{ - ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum, + ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum, Variance, }; #[test] diff --git 
a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 0c44a21f26da6..828e99655da5b 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -113,10 +113,7 @@ impl AggregateExpr for Variance { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(VarianceAccumulator::try_new( - // variance is f64 or decimal - &self.data_type, - )?)) + Ok(Box::new(VarianceAccumulator::try_new()?)) } fn state_fields(&self) -> Result> { @@ -158,7 +155,7 @@ pub struct VarianceAccumulator { impl VarianceAccumulator { /// Creates a new `VarianceAccumulator` - pub fn try_new(datatype: &DataType) -> Result { + pub fn try_new() -> Result { Ok(Self { m2: ScalarValue::from(0 as f64), mean: ScalarValue::from(0 as f64), From 48a1485dee735e801f9a23354bf67406b8bf2f69 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Wed, 5 Jan 2022 23:21:25 -0800 Subject: [PATCH 10/22] add more sql tests --- datafusion/tests/sql/aggregates.rs | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 99538fa79c8c9..ff3f22f9f0bbd 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -50,7 +50,7 @@ async fn csv_query_avg() -> Result<()> { } #[tokio::test] -async fn csv_query_variance() -> Result<()> { +async fn csv_query_variance_1() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT variance(c2) FROM aggregate_test_100"; @@ -61,6 +61,30 @@ async fn csv_query_variance() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_variance_2() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT variance(c6) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected 
= vec![vec!["26156334342021890000000000000000000000"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_variance_3() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT variance(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.09234223721582163"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_external_table_count() { let mut ctx = ExecutionContext::new(); From 9c0131114f08f22701aa5cc75337e1ffcd00a69b Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Thu, 6 Jan 2022 12:42:39 -0800 Subject: [PATCH 11/22] add stddev and tests --- .../src/physical_plan/expressions/stddev.rs | 249 ++++++------------ .../src/physical_plan/expressions/variance.rs | 50 ++-- 2 files changed, 101 insertions(+), 198 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs index 92da6bd4ca557..101d469cf8dc8 100644 --- a/datafusion/src/physical_plan/expressions/stddev.rs +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -18,22 +18,15 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; -use std::convert::TryFrom; use std::sync::Arc; use crate::error::{DataFusionError, Result}; -use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; -use crate::scalar::{ - ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128, -}; -use arrow::compute; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr, expressions::variance::VarianceAccumulator}; +use crate::scalar::ScalarValue; use arrow::datatypes::DataType; -use arrow::{ - array::{ArrayRef, UInt64Array}, - datatypes::Field, -}; +use arrow::datatypes::Field; -use super::{format_state_name, sum}; +use super::format_state_name; /// STDDEV (standard deviation) aggregate expression #[derive(Debug)] @@ -43,16 +36,9 @@ pub struct Stddev { data_type: DataType, } -/// function return type of an standard deviation +/// function return type of standard deviation pub fn stddev_return_type(arg_type: &DataType) -> Result { match arg_type { - DataType::Decimal(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 - let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4); - let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4); - Ok(DataType::Decimal(new_precision, new_scale)) - } DataType::Int8 | DataType::Int16 | DataType::Int32 @@ -83,7 +69,6 @@ pub(crate) fn is_stddev_support_arg_type(arg_type: &DataType) -> bool { | DataType::Int64 | DataType::Float32 | DataType::Float64 - | DataType::Decimal(_, _) ) } @@ -97,7 +82,7 @@ impl Stddev { // the result of stddev just support FLOAT64 and Decimal data type. 
assert!(matches!( data_type, - DataType::Float64 | DataType::Decimal(_, _) + DataType::Float64 )); Self { name: name.into(), @@ -118,10 +103,7 @@ impl AggregateExpr for Stddev { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(StddevAccumulator::try_new( - // stddev is f64 or decimal - &self.data_type, - )?)) + Ok(Box::new(StddevAccumulator::try_new()?)) } fn state_fields(&self) -> Result> { @@ -132,8 +114,13 @@ impl AggregateExpr for Stddev { true, ), Field::new( - &format_state_name(&self.name, "sum"), - self.data_type.clone(), + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, true, ), ]) @@ -151,85 +138,43 @@ impl AggregateExpr for Stddev { /// An accumulator to compute the average #[derive(Debug)] pub struct StddevAccumulator { - // sum is used for null - sum: ScalarValue, - count: u64, + variance: VarianceAccumulator, } impl StddevAccumulator { /// Creates a new `StddevAccumulator` - pub fn try_new(datatype: &DataType) -> Result { + pub fn try_new() -> Result { Ok(Self { - sum: ScalarValue::try_from(datatype)?, - count: 0, + variance: VarianceAccumulator::try_new()?, }) } } impl Accumulator for StddevAccumulator { fn state(&self) -> Result> { - Ok(vec![ScalarValue::from(self.count), self.sum.clone()]) + Ok(vec![ScalarValue::from(self.variance.get_count()), self.variance.get_mean(), self.variance.get_m2()]) } fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - let values = &values[0]; - - self.count += (!values.is_null()) as u64; - self.sum = sum::sum(&self.sum, values)?; - - Ok(()) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - - self.count += (values.len() - values.data().null_count()) as u64; - self.sum = sum::sum(&self.sum, &sum::sum_batch(values)?)?; - Ok(()) + self.variance.update(values) } fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - let count = &states[0]; - // 
counts are summed - if let ScalarValue::UInt64(Some(c)) = count { - self.count += c - } else { - unreachable!() - }; - - // sums are summed - self.sum = sum::sum(&self.sum, &states[1])?; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = states[0].as_any().downcast_ref::().unwrap(); - // counts are summed - self.count += compute::sum(counts).unwrap_or(0); - - // sums are summed - self.sum = sum::sum(&self.sum, &sum::sum_batch(&states[1])?)?; - Ok(()) + self.variance.merge(states) } fn evaluate(&self) -> Result { - match self.sum { + let variance = self.variance.evaluate()?; + match variance { ScalarValue::Float64(e) => { - Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) - } - ScalarValue::Decimal128(value, precision, scale) => { - Ok(match value { - None => ScalarValue::Decimal128(None, precision, scale), - // TODO add the checker for overflow the precision - Some(v) => ScalarValue::Decimal128( - Some(v / self.count as i128), - precision, - scale, - ), - }) + if e == None { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(e.map(|f| f.sqrt()))) + } } _ => Err(DataFusionError::Internal( - "Sum should be f64 on average".to_string(), + "Variance should be f64".to_string(), )), } } @@ -244,69 +189,28 @@ mod tests { use arrow::{array::*, datatypes::*}; #[test] - fn test_stddev_return_data_type() -> Result<()> { - let data_type = DataType::Decimal(10, 5); - let result_type = stddev_return_type(&data_type)?; - assert_eq!(DataType::Decimal(14, 9), result_type); - - let data_type = DataType::Decimal(36, 10); - let result_type = stddev_return_type(&data_type)?; - assert_eq!(DataType::Decimal(38, 14), result_type); - Ok(()) - } - - #[test] - fn stddev_decimal() -> Result<()> { - // test agg - let mut decimal_builder = DecimalBuilder::new(6, 10, 0); - for i in 1..7 { - decimal_builder.append_value(i as i128)?; - } - let array: ArrayRef = Arc::new(decimal_builder.finish()); - - generic_test_op!( - array, 
- DataType::Decimal(10, 0), - Stddev, - ScalarValue::Decimal128(Some(35000), 14, 4), - DataType::Decimal(14, 4) - ) - } - - #[test] - fn stddev_decimal_with_nulls() -> Result<()> { - let mut decimal_builder = DecimalBuilder::new(5, 10, 0); - for i in 1..6 { - if i == 2 { - decimal_builder.append_null()?; - } else { - decimal_builder.append_value(i)?; - } - } - let array: ArrayRef = Arc::new(decimal_builder.finish()); + fn stddev_f64_1() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 2_f64])); generic_test_op!( - array, - DataType::Decimal(10, 0), + a, + DataType::Float64, Stddev, - ScalarValue::Decimal128(Some(32500), 14, 4), - DataType::Decimal(14, 4) + ScalarValue::from(0.5_f64), + DataType::Float64 ) } #[test] - fn stddev_decimal_all_nulls() -> Result<()> { - // test agg - let mut decimal_builder = DecimalBuilder::new(5, 10, 0); - for _i in 1..6 { - decimal_builder.append_null()?; - } - let array: ArrayRef = Arc::new(decimal_builder.finish()); + fn stddev_f64_2() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); generic_test_op!( - array, - DataType::Decimal(10, 0), + a, + DataType::Float64, Stddev, - ScalarValue::Decimal128(None, 14, 4), - DataType::Decimal(14, 4) + ScalarValue::from(1.4142135623730951_f64), + DataType::Float64 ) } @@ -317,76 +221,75 @@ mod tests { a, DataType::Int32, Stddev, - ScalarValue::from(3_f64), + ScalarValue::from(1.4142135623730951_f64), DataType::Float64 ) } #[test] - fn stddev_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); + fn stddev_u32() -> Result<()> { + let a: ArrayRef = + Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); generic_test_op!( a, - DataType::Int32, + DataType::UInt32, Stddev, - ScalarValue::from(3.25f64), + ScalarValue::from(1.4142135623730951f64), DataType::Float64 ) } #[test] - fn stddev_i32_all_nulls() -> 
Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); + fn stddev_f32() -> Result<()> { + let a: ArrayRef = + Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); generic_test_op!( a, - DataType::Int32, + DataType::Float32, Stddev, - ScalarValue::Float64(None), + ScalarValue::from(1.4142135623730951_f64), DataType::Float64 ) } #[test] - fn stddev_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!( - a, - DataType::UInt32, - Stddev, - ScalarValue::from(3.0f64), - DataType::Float64 - ) + fn test_stddev_return_data_type() -> Result<()> { + let data_type = DataType::Float64; + let result_type = stddev_return_type(&data_type)?; + assert_eq!(DataType::Float64, result_type); + + let data_type = DataType::Decimal(36, 10); + let result_type = stddev_return_type(&data_type).is_err(); + assert_eq!(true, result_type); + Ok(()) } #[test] - fn stddev_f32() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); + fn stddev_i32_with_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + Some(5), + ])); generic_test_op!( a, - DataType::Float32, + DataType::Int32, Stddev, - ScalarValue::from(3_f64), + ScalarValue::from(1.479019945774904), DataType::Float64 ) } #[test] - fn stddev_f64() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + fn stddev_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); generic_test_op!( a, - DataType::Float64, + DataType::Int32, Stddev, - ScalarValue::from(3_f64), + ScalarValue::Float64(None), DataType::Float64 ) } diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 828e99655da5b..0de249251a49e 100644 --- 
a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -22,32 +22,22 @@ use std::sync::Arc; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; -use crate::scalar::{ - ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128, -}; +use crate::scalar::ScalarValue; use arrow::datatypes::DataType; use arrow::datatypes::Field; use super::format_state_name; -/// STDDEV (standard deviation) aggregate expression +/// VARIANCE aggregate expression #[derive(Debug)] pub struct Variance { name: String, expr: Arc, - data_type: DataType, } -/// function return type of an standard deviation +/// function return type of variance pub fn variance_return_type(arg_type: &DataType) -> Result { match arg_type { - DataType::Decimal(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 - let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4); - let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4); - Ok(DataType::Decimal(new_precision, new_scale)) - } DataType::Int8 | DataType::Int16 | DataType::Int32 @@ -78,26 +68,24 @@ pub(crate) fn is_variance_support_arg_type(arg_type: &DataType) -> bool { | DataType::Int64 | DataType::Float32 | DataType::Float64 - | DataType::Decimal(_, _) ) } impl Variance { - /// Create a new STDDEV aggregate function + /// Create a new VARIANCE aggregate function pub fn new( expr: Arc, name: impl Into, data_type: DataType, ) -> Self { - // the result of variance just support FLOAT64 and Decimal data type. + // the result of variance just support FLOAT64 data type. 
assert!(matches!( data_type, - DataType::Float64 | DataType::Decimal(_, _) + DataType::Float64 )); Self { name: name.into(), expr, - data_type, } } } @@ -109,7 +97,7 @@ impl AggregateExpr for Variance { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + Ok(Field::new(&self.name, DataType::Float64,true)) } fn create_accumulator(&self) -> Result> { @@ -125,12 +113,12 @@ impl AggregateExpr for Variance { ), Field::new( &format_state_name(&self.name, "mean"), - self.data_type.clone(), + DataType::Float64, true, ), Field::new( &format_state_name(&self.name, "m2"), - self.data_type.clone(), + DataType::Float64, true, ), ]) @@ -162,6 +150,18 @@ impl VarianceAccumulator { count: 0, }) } + + pub fn get_count(&self) -> u64 { + self.count + } + + pub fn get_mean(&self) -> ScalarValue { + self.mean.clone() + } + + pub fn get_m2(&self) -> ScalarValue { + self.m2.clone() + } } impl Accumulator for VarianceAccumulator { @@ -321,13 +321,13 @@ mod tests { #[test] fn test_variance_return_data_type() -> Result<()> { - let data_type = DataType::Decimal(10, 5); + let data_type = DataType::Float64; let result_type = variance_return_type(&data_type)?; - assert_eq!(DataType::Decimal(14, 9), result_type); + assert_eq!(DataType::Float64, result_type); let data_type = DataType::Decimal(36, 10); - let result_type = variance_return_type(&data_type)?; - assert_eq!(DataType::Decimal(38, 14), result_type); + let result_type = variance_return_type(&data_type).is_err(); + assert_eq!(true, result_type); Ok(()) } From b2747f5bc70dde5207c985a2d04d0e6004e00e47 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Thu, 6 Jan 2022 12:58:57 -0800 Subject: [PATCH 12/22] add the hooks and expression --- ballista/rust/core/proto/ballista.proto | 1 + .../core/src/serde/logical_plan/to_proto.rs | 3 +- ballista/rust/core/src/serde/mod.rs | 1 + datafusion/src/physical_plan/aggregates.rs | 89 ++++++++++++++++++- .../coercion_rule/aggregate_rule.rs | 17 +++- 
.../src/physical_plan/expressions/mod.rs | 2 + 6 files changed, 105 insertions(+), 8 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 0650ce3fb61b1..9e9d6a497b575 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -170,6 +170,7 @@ enum AggregateFunction { APPROX_DISTINCT = 5; ARRAY_AGG = 6; VARIANCE=7; + STDDEV=8; } message AggregateExprNode { diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 236d1ae028337..aa073893f3728 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1135,7 +1135,7 @@ impl TryInto for &Expr { AggregateFunction::Avg => protobuf::AggregateFunction::Avg, AggregateFunction::Count => protobuf::AggregateFunction::Count, AggregateFunction::Variance => protobuf::AggregateFunction::Variance, - + AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, }; let arg = &args[0]; @@ -1367,6 +1367,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::ApproxDistinct => Self::ApproxDistinct, AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Variance => Self::Variance, + AggregateFunction::Stddev => Self::Stddev, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index f98135d4ff7c2..62a7ed4eae8b1 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -120,6 +120,7 @@ impl From for AggregateFunction { } protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg, protobuf::AggregateFunction::Variance => AggregateFunction::Variance, + protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev, } } } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index d3ae100c1aa48..df0f8b6fcecc0 100644 
--- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -35,7 +35,7 @@ use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_t use crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use expressions::{avg_return_type, sum_return_type, variance_return_type}; +use expressions::{avg_return_type, sum_return_type, variance_return_type, stddev_return_type}; use std::{fmt, str::FromStr, sync::Arc}; /// the implementation of an aggregate function @@ -66,6 +66,8 @@ pub enum AggregateFunction { ArrayAgg, /// Variance (Population) Variance, + /// Standard Deviation (Population) + Stddev, } impl fmt::Display for AggregateFunction { @@ -87,6 +89,7 @@ impl FromStr for AggregateFunction { "approx_distinct" => AggregateFunction::ApproxDistinct, "array_agg" => AggregateFunction::ArrayAgg, "variance" => AggregateFunction::Variance, + "stddev" => AggregateFunction::Stddev, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -120,6 +123,7 @@ pub fn return_type( } AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]), + AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]), AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new( "item", @@ -225,7 +229,17 @@ pub fn create_aggregate_expr( return Err(DataFusionError::NotImplemented( "VARIANCE(DISTINCT) aggregations are not available".to_string(), )); - } + }, + (AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), + (AggregateFunction::Stddev, true) => { + return Err(DataFusionError::NotImplemented( + "VARIANCE(DISTINCT) aggregations are not available".to_string(), + )); + }, 
}) } @@ -270,7 +284,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature { .collect::>(); Signature::uniform(1, valid, Volatility::Immutable) } - AggregateFunction::Avg | AggregateFunction::Sum | AggregateFunction::Variance => { + AggregateFunction::Avg | AggregateFunction::Sum | AggregateFunction::Variance | AggregateFunction::Stddev=> { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } } @@ -281,7 +295,7 @@ mod tests { use super::*; use crate::error::Result; use crate::physical_plan::expressions::{ - ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum, Variance, + ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum, Variance, Stddev, }; #[test] @@ -505,6 +519,47 @@ mod tests { Ok(()) } + #[test] + fn test_stddev_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Stddev]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Stddev => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; + } + } + Ok(()) + } + #[test] fn test_min_max() -> Result<()> { let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?; @@ -625,4 +680,30 @@ mod tests { let observed = return_type(&AggregateFunction::Variance, &[DataType::Utf8]); assert!(observed.is_err()); } + + #[test] + fn test_stddev_return_type() -> Result<()> { + let observed = 
return_type(&AggregateFunction::Stddev, &[DataType::Float32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Int32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Stddev, &[DataType::UInt32])?; + assert_eq!(DataType::Float64, observed); + + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Int64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_stddev_no_utf8() { + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Utf8]); + assert!(observed.is_err()); + } } diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index 944c942d2f0ef..7d8d8bf7dc5f5 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -21,7 +21,11 @@ use crate::arrow::datatypes::Schema; use crate::error::{DataFusionError, Result}; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::{ - is_avg_support_arg_type, is_sum_support_arg_type, is_variance_support_arg_type, try_cast, + is_avg_support_arg_type, + is_sum_support_arg_type, + is_variance_support_arg_type, + try_cast, + is_stddev_support_arg_type, }; use crate::physical_plan::functions::{Signature, TypeSignature}; use crate::physical_plan::PhysicalExpr; @@ -87,8 +91,6 @@ pub(crate) fn coerce_types( Ok(input_types.to_vec()) } AggregateFunction::Variance => { - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - // smallint, int, bigint, real, double precision, decimal, or interval if !is_variance_support_arg_type(&input_types[0]) { return Err(DataFusionError::Plan(format!( "The function {:?} does not 
support inputs of type {:?}.", @@ -97,6 +99,15 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::Stddev => { + if !is_stddev_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } } } diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 1ea6a77f6e149..6d90fb493cd96 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -90,6 +90,8 @@ pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; pub(crate) use variance::is_variance_support_arg_type; pub use variance::{variance_return_type, Variance}; +pub(crate) use stddev::is_stddev_support_arg_type; +pub use stddev::{stddev_return_type, Stddev}; pub use try_cast::{try_cast, TryCastExpr}; /// returns the name of the state From d393cb9a0c3c2ec4ee6f2721599da476ccb1beb8 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Thu, 6 Jan 2022 13:04:29 -0800 Subject: [PATCH 13/22] add more tests --- datafusion/tests/sql/aggregates.rs | 36 ++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index ff3f22f9f0bbd..f4e5d4afcad90 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -85,6 +85,42 @@ async fn csv_query_variance_3() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_stddev_1() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1.3665650368716449"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_2() -> Result<()> 
{ + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev(c6) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["5114326382039172000"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_3() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.30387865541334363"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_external_table_count() { let mut ctx = ExecutionContext::new(); From 0e534a4f4d18bb8b55fc126943f4ce80fa4c3b65 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Thu, 6 Jan 2022 18:55:46 -0800 Subject: [PATCH 14/22] fix lint and clipy --- datafusion/src/physical_plan/aggregates.rs | 55 ++-- .../coercion_rule/aggregate_rule.rs | 7 +- .../src/physical_plan/expressions/mod.rs | 8 +- .../src/physical_plan/expressions/stddev.rs | 29 +- .../src/physical_plan/expressions/variance.rs | 60 ++-- datafusion/src/scalar.rs | 295 +++++++++++------- 6 files changed, 264 insertions(+), 190 deletions(-) diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index df0f8b6fcecc0..3008d12271643 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -35,7 +35,9 @@ use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_t use crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use expressions::{avg_return_type, sum_return_type, variance_return_type, stddev_return_type}; +use expressions::{ + avg_return_type, stddev_return_type, sum_return_type, variance_return_type, +}; use 
std::{fmt, str::FromStr, sync::Arc}; /// the implementation of an aggregate function @@ -219,7 +221,7 @@ pub fn create_aggregate_expr( return Err(DataFusionError::NotImplemented( "AVG(DISTINCT) aggregations are not available".to_string(), )); - }, + } (AggregateFunction::Variance, false) => Arc::new(expressions::Variance::new( coerced_phy_exprs[0].clone(), name, @@ -229,7 +231,7 @@ pub fn create_aggregate_expr( return Err(DataFusionError::NotImplemented( "VARIANCE(DISTINCT) aggregations are not available".to_string(), )); - }, + } (AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new( coerced_phy_exprs[0].clone(), name, @@ -239,7 +241,7 @@ pub fn create_aggregate_expr( return Err(DataFusionError::NotImplemented( "VARIANCE(DISTINCT) aggregations are not available".to_string(), )); - }, + } }) } @@ -284,7 +286,10 @@ pub fn signature(fun: &AggregateFunction) -> Signature { .collect::>(); Signature::uniform(1, valid, Volatility::Immutable) } - AggregateFunction::Avg | AggregateFunction::Sum | AggregateFunction::Variance | AggregateFunction::Stddev=> { + AggregateFunction::Avg + | AggregateFunction::Sum + | AggregateFunction::Variance + | AggregateFunction::Stddev => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } } @@ -295,7 +300,7 @@ mod tests { use super::*; use crate::error::Result; use crate::physical_plan::expressions::{ - ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum, Variance, Stddev, + ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Stddev, Sum, Variance, }; #[test] @@ -503,17 +508,14 @@ mod tests { &input_schema, "c1", )?; - match fun { - AggregateFunction::Variance => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; + if fun == AggregateFunction::Variance { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", 
result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } } } Ok(()) @@ -544,17 +546,14 @@ mod tests { &input_schema, "c1", )?; - match fun { - AggregateFunction::Stddev => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; + if fun == AggregateFunction::Variance { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } } } Ok(()) diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index 7d8d8bf7dc5f5..d5d547fc859c7 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -21,11 +21,8 @@ use crate::arrow::datatypes::Schema; use crate::error::{DataFusionError, Result}; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::{ - is_avg_support_arg_type, - is_sum_support_arg_type, - is_variance_support_arg_type, - try_cast, - is_stddev_support_arg_type, + is_avg_support_arg_type, is_stddev_support_arg_type, is_sum_support_arg_type, + is_variance_support_arg_type, try_cast, }; use crate::physical_plan::functions::{Signature, TypeSignature}; use crate::physical_plan::PhysicalExpr; diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 6d90fb493cd96..079ffd39a892c 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -50,9 +50,9 @@ mod nth_value; mod nullif; mod rank; mod row_number; +mod stddev; mod sum; mod try_cast; -mod stddev; mod variance; 
/// Module with some convenient methods used in expression building @@ -86,13 +86,13 @@ pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; +pub(crate) use stddev::is_stddev_support_arg_type; +pub use stddev::{stddev_return_type, Stddev}; pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; +pub use try_cast::{try_cast, TryCastExpr}; pub(crate) use variance::is_variance_support_arg_type; pub use variance::{variance_return_type, Variance}; -pub(crate) use stddev::is_stddev_support_arg_type; -pub use stddev::{stddev_return_type, Stddev}; -pub use try_cast::{try_cast, TryCastExpr}; /// returns the name of the state pub fn format_state_name(name: &str, state_name: &str) -> String { diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs index 101d469cf8dc8..ac8d7954743d0 100644 --- a/datafusion/src/physical_plan/expressions/stddev.rs +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -21,7 +21,9 @@ use std::any::Any; use std::sync::Arc; use crate::error::{DataFusionError, Result}; -use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr, expressions::variance::VarianceAccumulator}; +use crate::physical_plan::{ + expressions::variance::VarianceAccumulator, Accumulator, AggregateExpr, PhysicalExpr, +}; use crate::scalar::ScalarValue; use arrow::datatypes::DataType; use arrow::datatypes::Field; @@ -80,10 +82,7 @@ impl Stddev { data_type: DataType, ) -> Self { // the result of stddev just support FLOAT64 and Decimal data type. 
- assert!(matches!( - data_type, - DataType::Float64 - )); + assert!(matches!(data_type, DataType::Float64)); Self { name: name.into(), expr, @@ -152,7 +151,11 @@ impl StddevAccumulator { impl Accumulator for StddevAccumulator { fn state(&self) -> Result> { - Ok(vec![ScalarValue::from(self.variance.get_count()), self.variance.get_mean(), self.variance.get_m2()]) + Ok(vec![ + ScalarValue::from(self.variance.get_count()), + self.variance.get_mean(), + self.variance.get_m2(), + ]) } fn update(&mut self, values: &[ScalarValue]) -> Result<()> { @@ -190,8 +193,7 @@ mod tests { #[test] fn stddev_f64_1() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64])); + let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64])); generic_test_op!( a, DataType::Float64, @@ -209,7 +211,7 @@ mod tests { a, DataType::Float64, Stddev, - ScalarValue::from(1.4142135623730951_f64), + ScalarValue::from(std::f64::consts::SQRT_2), DataType::Float64 ) } @@ -221,7 +223,7 @@ mod tests { a, DataType::Int32, Stddev, - ScalarValue::from(1.4142135623730951_f64), + ScalarValue::from(std::f64::consts::SQRT_2), DataType::Float64 ) } @@ -234,7 +236,7 @@ mod tests { a, DataType::UInt32, Stddev, - ScalarValue::from(1.4142135623730951f64), + ScalarValue::from(std::f64::consts::SQRT_2), DataType::Float64 ) } @@ -247,7 +249,7 @@ mod tests { a, DataType::Float32, Stddev, - ScalarValue::from(1.4142135623730951_f64), + ScalarValue::from(std::f64::consts::SQRT_2), DataType::Float64 ) } @@ -259,8 +261,7 @@ mod tests { assert_eq!(DataType::Float64, result_type); let data_type = DataType::Decimal(36, 10); - let result_type = stddev_return_type(&data_type).is_err(); - assert_eq!(true, result_type); + assert!(!stddev_return_type(&data_type).is_err()); Ok(()) } diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 0de249251a49e..8699136f83e5c 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs 
+++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -79,10 +79,7 @@ impl Variance { data_type: DataType, ) -> Self { // the result of variance just support FLOAT64 data type. - assert!(matches!( - data_type, - DataType::Float64 - )); + assert!(matches!(data_type, DataType::Float64)); Self { name: name.into(), expr, @@ -97,7 +94,7 @@ impl AggregateExpr for Variance { } fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64,true)) + Ok(Field::new(&self.name, DataType::Float64, true)) } fn create_accumulator(&self) -> Result> { @@ -166,7 +163,11 @@ impl VarianceAccumulator { impl Accumulator for VarianceAccumulator { fn state(&self) -> Result> { - Ok(vec![ScalarValue::from(self.count), self.mean.clone(), self.m2.clone()]) + Ok(vec![ + ScalarValue::from(self.count), + self.mean.clone(), + self.m2.clone(), + ]) } fn update(&mut self, values: &[ScalarValue]) -> Result<()> { @@ -178,7 +179,8 @@ impl Accumulator for VarianceAccumulator { let delta1 = ScalarValue::add(values, &self.mean.arithmetic_negate())?; let new_mean = ScalarValue::add( &ScalarValue::div(&delta1, &ScalarValue::from(new_count as f64))?, - &self.mean)?; + &self.mean, + )?; let delta2 = ScalarValue::add(values, &new_mean.arithmetic_negate())?; let tmp = ScalarValue::mul(&delta1, &delta2)?; @@ -187,7 +189,7 @@ impl Accumulator for VarianceAccumulator { self.mean = new_mean; self.m2 = new_m2; } - + Ok(()) } @@ -202,26 +204,25 @@ impl Accumulator for VarianceAccumulator { } else { unreachable!() }; - let new_mean = - ScalarValue::div( - &ScalarValue::add( - &self.mean, - mean)?, - &ScalarValue::from(2 as f64))?; + let new_mean = ScalarValue::div( + &ScalarValue::add(&self.mean, mean)?, + &ScalarValue::from(2_f64), + )?; let delta = ScalarValue::add(&mean.arithmetic_negate(), &self.mean)?; let delta_sqrt = ScalarValue::mul(&delta, &delta)?; - let new_m2 = - ScalarValue::add( - &ScalarValue::add( - &ScalarValue::mul( - &delta_sqrt, - &ScalarValue::div( - &ScalarValue::mul( - 
&ScalarValue::from(self.count), - count)?, - &ScalarValue::from(new_count as f64))?)?, - &self.m2)?, - &m2)?; + let new_m2 = ScalarValue::add( + &ScalarValue::add( + &ScalarValue::mul( + &delta_sqrt, + &ScalarValue::div( + &ScalarValue::mul(&ScalarValue::from(self.count), count)?, + &ScalarValue::from(new_count as f64), + )?, + )?, + &self.m2, + )?, + m2, + )?; self.count = new_count; self.mean = new_mean; @@ -254,11 +255,9 @@ mod tests { use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; - #[test] fn variance_f64_1() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64])); + let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64])); generic_test_op!( a, DataType::Float64, @@ -326,8 +325,7 @@ mod tests { assert_eq!(DataType::Float64, result_type); let data_type = DataType::Decimal(36, 10); - let result_type = variance_return_type(&data_type).is_err(); - assert_eq!(true, result_type); + assert!(!variance_return_type(&data_type).is_err()); Ok(()) } diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 06ec7a226de34..747530a249d9b 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -526,48 +526,43 @@ macro_rules! 
eq_array_primitive { } impl ScalarValue { - /// Return true if the value is numeric pub fn is_numeric(&self) -> bool { - match self { - ScalarValue::Float32(_) | - ScalarValue::Float64(_) | - ScalarValue::Decimal128(_, _, _) | - ScalarValue::Int8(_) | - ScalarValue::Int16(_) | - ScalarValue::Int32(_) | - ScalarValue::Int64(_) | - ScalarValue::UInt8(_) | - ScalarValue::UInt16(_) | - ScalarValue::UInt32(_) | - ScalarValue::UInt64(_) => { - true - } - _ => false - } - } + matches!(self, + ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + ) + } /// Add two numeric ScalarValues pub fn add(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { if !lhs.is_numeric() || !rhs.is_numeric() { - return Err(DataFusionError::Internal( - format!( - "Addition only supports numeric types, \ + return Err(DataFusionError::Internal(format!( + "Addition only supports numeric types, \ here has {:?} and {:?}", - lhs.get_datatype(), rhs.get_datatype() + lhs.get_datatype(), + rhs.get_datatype() ))); } // TODO: Finding a good way to support operation between different types without - // writing a hige match block. + // writing a hige match block. 
// TODO: Add support for decimal types match (lhs, rhs) { - (ScalarValue::Decimal128(_, _, _), _) | + (ScalarValue::Decimal128(_, _, _), _) | (_, ScalarValue::Decimal128(_, _, _)) => { Err(DataFusionError::Internal( - format!( - "Addition with Decimals are not supported for now" - ))) + "Addition with Decimals are not supported for now".to_string() + )) }, // f64 / _ (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { @@ -636,7 +631,6 @@ impl ScalarValue { (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { Ok(ScalarValue::UInt16(Some(f1.unwrap() as u16 + f2.unwrap() as u16))) }, - _ => Err(DataFusionError::Internal( format!( "Addition only support calculation with the same type or f64 as one of the numbers for now, here has {:?} and {:?}", @@ -648,69 +642,66 @@ impl ScalarValue { /// Multiply two numeric ScalarValues pub fn mul(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { if !lhs.is_numeric() || !rhs.is_numeric() { - return Err(DataFusionError::Internal( - format!( - "Multiplication is only supported on numeric types, \ + return Err(DataFusionError::Internal(format!( + "Multiplication is only supported on numeric types, \ here has {:?} and {:?}", - lhs.get_datatype(), rhs.get_datatype() + lhs.get_datatype(), + rhs.get_datatype() ))); } // TODO: Finding a good way to support operation between different types without - // writing a hige match block. + // writing a hige match block. 
// TODO: Add support for decimal type match (lhs, rhs) { - (ScalarValue::Decimal128(_, _, _), _) | - (_, ScalarValue::Decimal128(_, _, _)) => { - Err(DataFusionError::Internal( - format!( - "Multiplication with Decimals are not supported for now" - ))) - }, + (ScalarValue::Decimal128(_, _, _), _) + | (_, ScalarValue::Decimal128(_, _, _)) => Err(DataFusionError::Internal( + "Multiplication with Decimals are not supported for now".to_string() + )), // f64 / _ (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { Ok(ScalarValue::Float64(Some(f1.unwrap() * f2.unwrap()))) - }, + } // f32 / _ - (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => { - Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 * f2.unwrap() as f64))) - }, + (ScalarValue::Float32(f1), ScalarValue::Float32(f2)) => Ok( + ScalarValue::Float64(Some(f1.unwrap() as f64 * f2.unwrap() as f64)), + ), // i64 / _ (ScalarValue::Int64(f1), ScalarValue::Int64(f2)) => { Ok(ScalarValue::Int64(Some(f1.unwrap() * f2.unwrap()))) - }, + } // i32 / _ - (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => { - Ok(ScalarValue::Int64(Some(f1.unwrap() as i64 * f2.unwrap() as i64))) - }, + (ScalarValue::Int32(f1), ScalarValue::Int32(f2)) => Ok(ScalarValue::Int64( + Some(f1.unwrap() as i64 * f2.unwrap() as i64), + )), // i16 / _ - (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => { - Ok(ScalarValue::Int32(Some(f1.unwrap() as i32 * f2.unwrap() as i32))) - }, + (ScalarValue::Int16(f1), ScalarValue::Int16(f2)) => Ok(ScalarValue::Int32( + Some(f1.unwrap() as i32 * f2.unwrap() as i32), + )), // i8 / _ - (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => { - Ok(ScalarValue::Int16(Some(f1.unwrap() as i16 * f2.unwrap() as i16))) - }, + (ScalarValue::Int8(f1), ScalarValue::Int8(f2)) => Ok(ScalarValue::Int16( + Some(f1.unwrap() as i16 * f2.unwrap() as i16), + )), // u64 / _ - (ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => { - Ok(ScalarValue::UInt64(Some(f1.unwrap() as u64 * f2.unwrap() as u64))) - }, + 
(ScalarValue::UInt64(f1), ScalarValue::UInt64(f2)) => Ok( + ScalarValue::UInt64(Some(f1.unwrap() as u64 * f2.unwrap() as u64)), + ), // u32 / _ - (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => { - Ok(ScalarValue::UInt64(Some(f1.unwrap() as u64 * f2.unwrap() as u64))) - }, + (ScalarValue::UInt32(f1), ScalarValue::UInt32(f2)) => Ok( + ScalarValue::UInt64(Some(f1.unwrap() as u64 * f2.unwrap() as u64)), + ), // u16 / _ - (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) => { - Ok(ScalarValue::UInt32(Some(f1.unwrap() as u32 * f2.unwrap() as u32))) - }, + (ScalarValue::UInt16(f1), ScalarValue::UInt16(f2)) => Ok( + ScalarValue::UInt32(Some(f1.unwrap() as u32 * f2.unwrap() as u32)), + ), // u8 / _ - (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { - Ok(ScalarValue::UInt16(Some(f1.unwrap() as u16 * f2.unwrap() as u16))) - }, - _ => Err(DataFusionError::Internal( - format!( + (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => Ok(ScalarValue::UInt16( + Some(f1.unwrap() as u16 * f2.unwrap() as u16), + )), + _ => Err(DataFusionError::Internal(format!( "Multiplication only support f64 for now, here has {:?} and {:?}", - lhs.get_datatype(), rhs.get_datatype() + lhs.get_datatype(), + rhs.get_datatype() ))), } } @@ -718,24 +709,23 @@ impl ScalarValue { /// Division between two numeric ScalarValues pub fn div(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { if !lhs.is_numeric() || !rhs.is_numeric() { - return Err(DataFusionError::Internal( - format!( - "Division is only supported on numeric types, \ + return Err(DataFusionError::Internal(format!( + "Division is only supported on numeric types, \ here has {:?} and {:?}", - lhs.get_datatype(), rhs.get_datatype() + lhs.get_datatype(), + rhs.get_datatype() ))); } // TODO: Finding a good way to support operation between different types without // writing a hige match block. 
- // TODO: Add support for decimal types + // TODO: Add support for decimal types match (lhs, rhs) { - (ScalarValue::Decimal128(_, _, _), _) | + (ScalarValue::Decimal128(_, _, _), _) | (_, ScalarValue::Decimal128(_, _, _)) => { Err(DataFusionError::Internal( - format!( - "Division with Decimals are not supported for now" - ))) + "Division with Decimals are not supported for now".to_string() + )) }, // f64 / _ (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { @@ -804,7 +794,6 @@ impl ScalarValue { (ScalarValue::UInt8(f1), ScalarValue::UInt8(f2)) => { Ok(ScalarValue::Float64(Some(f1.unwrap() as f64 / f2.unwrap() as f64))) }, - _ => Err(DataFusionError::Internal( format!( "Division only support calculation with the same type or f64 as denominator for now, here has {:?} and {:?}", @@ -3373,7 +3362,10 @@ mod tests { ($OP:ident, $LHS:expr, $LHS_TYPE:ident, $RHS:expr, $RHS_TYPE:ident, $RESULT:expr, $RESULT_TYPE:ident) => {{ let v1 = &ScalarValue::from($LHS as $LHS_TYPE); let v2 = &ScalarValue::from($RHS as $RHS_TYPE); - assert_eq!(ScalarValue::$OP(v1, v2).unwrap(), ScalarValue::from($RESULT as $RESULT_TYPE)); + assert_eq!( + ScalarValue::$OP(v1, v2).unwrap(), + ScalarValue::from($RESULT as $RESULT_TYPE) + ); }}; } @@ -3388,77 +3380,166 @@ mod tests { #[test] fn scalar_addition() { - test_scalar_op!(add, 1, f64, 2, f64, 3, f64); test_scalar_op!(add, 1, f32, 2, f32, 3, f64); test_scalar_op!(add, 1, i64, 2, i64, 3, i64); test_scalar_op!(add, 100, i64, -32, i64, 68, i64); test_scalar_op!(add, -102, i64, 32, i64, -70, i64); test_scalar_op!(add, 1, i32, 2, i32, 3, i64); - test_scalar_op!(add, std::i32::MAX, i32, std::i32::MAX, i32, std::i32::MAX as i64 * 2, i64); + test_scalar_op!( + add, + std::i32::MAX, + i32, + std::i32::MAX, + i32, + std::i32::MAX as i64 * 2, + i64 + ); test_scalar_op!(add, 1, i16, 2, i16, 3, i32); - test_scalar_op!(add, std::i16::MAX, i16, std::i16::MAX, i16, std::i16::MAX as i32 * 2, i32); + test_scalar_op!( + add, + std::i16::MAX, + i16, + 
std::i16::MAX, + i16, + std::i16::MAX as i32 * 2, + i32 + ); test_scalar_op!(add, 1, i8, 2, i8, 3, i16); - test_scalar_op!(add, std::i8::MAX, i8, std::i8::MAX, i8, std::i8::MAX as i16 * 2, i16); + test_scalar_op!( + add, + std::i8::MAX, + i8, + std::i8::MAX, + i8, + std::i8::MAX as i16 * 2, + i16 + ); test_scalar_op!(add, 1, u64, 2, u64, 3, u64); test_scalar_op!(add, 1, u32, 2, u32, 3, u64); - test_scalar_op!(add, std::u32::MAX, u32, std::u32::MAX, u32, std::u32::MAX as u64 * 2, u64); + test_scalar_op!( + add, + std::u32::MAX, + u32, + std::u32::MAX, + u32, + std::u32::MAX as u64 * 2, + u64 + ); test_scalar_op!(add, 1, u16, 2, u16, 3, u32); - test_scalar_op!(add, std::u16::MAX, u16, std::u16::MAX, u16, std::u16::MAX as u32 * 2, u32); + test_scalar_op!( + add, + std::u16::MAX, + u16, + std::u16::MAX, + u16, + std::u16::MAX as u32 * 2, + u32 + ); test_scalar_op!(add, 1, u8, 2, u8, 3, u16); - test_scalar_op!(add, std::u8::MAX, u8, std::u8::MAX, u8, std::u8::MAX as u16 * 2, u16); + test_scalar_op!( + add, + std::u8::MAX, + u8, + std::u8::MAX, + u8, + std::u8::MAX as u16 * 2, + u16 + ); test_scalar_op_err!(add, 1, i32, 2, u16); test_scalar_op_err!(add, 1, i32, 2, u16); let v1 = &ScalarValue::from(1); let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); - let actual = ScalarValue::add(v1, v2).is_err(); - assert_eq!(actual, true); + assert!(!ScalarValue::add(v1, v2).is_err()); let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); let v2 = &ScalarValue::from(2); - let actual = ScalarValue::add(v1, v2).is_err(); - assert_eq!(actual, true); + assert!(ScalarValue::add(v1, v2).is_err()); } #[test] fn scalar_multiplication() { - test_scalar_op!(mul, 1, f64, 2, f64, 2, f64); test_scalar_op!(mul, 1, f32, 2, f32, 2, f64); test_scalar_op!(mul, 15, i64, 2, i64, 30, i64); test_scalar_op!(mul, 100, i64, -32, i64, -3200, i64); test_scalar_op!(mul, -1.1, f64, 2, f64, -2.2, f64); test_scalar_op!(mul, 1, i32, 2, i32, 2, i64); - test_scalar_op!(mul, std::i32::MAX, i32, std::i32::MAX, i32, 
std::i32::MAX as i64 * std::i32::MAX as i64, i64); + test_scalar_op!( + mul, + std::i32::MAX, + i32, + std::i32::MAX, + i32, + std::i32::MAX as i64 * std::i32::MAX as i64, + i64 + ); test_scalar_op!(mul, 1, i16, 2, i16, 2, i32); - test_scalar_op!(mul, std::i16::MAX, i16, std::i16::MAX, i16, std::i16::MAX as i32 * std::i16::MAX as i32, i32); + test_scalar_op!( + mul, + std::i16::MAX, + i16, + std::i16::MAX, + i16, + std::i16::MAX as i32 * std::i16::MAX as i32, + i32 + ); test_scalar_op!(mul, 1, i8, 2, i8, 2, i16); - test_scalar_op!(mul, std::i8::MAX, i8, std::i8::MAX, i8, std::i8::MAX as i16 * std::i8::MAX as i16, i16); + test_scalar_op!( + mul, + std::i8::MAX, + i8, + std::i8::MAX, + i8, + std::i8::MAX as i16 * std::i8::MAX as i16, + i16 + ); test_scalar_op!(mul, 1, u64, 2, u64, 2, u64); test_scalar_op!(mul, 1, u32, 2, u32, 2, u64); - test_scalar_op!(mul, std::u32::MAX, u32, std::u32::MAX, u32, std::u32::MAX as u64 * std::u32::MAX as u64, u64); + test_scalar_op!( + mul, + std::u32::MAX, + u32, + std::u32::MAX, + u32, + std::u32::MAX as u64 * std::u32::MAX as u64, + u64 + ); test_scalar_op!(mul, 1, u16, 2, u16, 2, u32); - test_scalar_op!(mul, std::u16::MAX, u16, std::u16::MAX, u16, std::u16::MAX as u32 * std::u16::MAX as u32, u32); + test_scalar_op!( + mul, + std::u16::MAX, + u16, + std::u16::MAX, + u16, + std::u16::MAX as u32 * std::u16::MAX as u32, + u32 + ); test_scalar_op!(mul, 1, u8, 2, u8, 2, u16); - test_scalar_op!(mul, std::u8::MAX, u8, std::u8::MAX, u8, std::u8::MAX as u16 * std::u8::MAX as u16, u16); + test_scalar_op!( + mul, + std::u8::MAX, + u8, + std::u8::MAX, + u8, + std::u8::MAX as u16 * std::u8::MAX as u16, + u16 + ); test_scalar_op_err!(mul, 1, i32, 2, u16); test_scalar_op_err!(mul, 1, i32, 2, u16); let v1 = &ScalarValue::from(1); let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); - let actual = ScalarValue::mul(v1, v2).is_err(); - assert_eq!(actual, true); + assert!(!ScalarValue::mul(v1, v2).is_err()); let v1 = &ScalarValue::Decimal128(Some(1), 0, 
0); let v2 = &ScalarValue::from(2); - let actual = ScalarValue::mul(v1, v2).is_err(); - assert_eq!(actual, true); + assert!(!ScalarValue::mul(v1, v2).is_err()); } #[test] fn scalar_division() { - test_scalar_op!(div, 1, f64, 2, f64, 0.5, f64); test_scalar_op!(div, 1, f32, 2, f32, 0.5, f64); test_scalar_op!(div, 15, i64, 2, i64, 7.5, f64); @@ -3475,12 +3556,10 @@ mod tests { let v1 = &ScalarValue::from(1); let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); - let actual = ScalarValue::div(v1, v2).is_err(); - assert_eq!(actual, true); + assert!(!ScalarValue::div(v1, v2).is_err()); let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); let v2 = &ScalarValue::from(2); - let actual = ScalarValue::div(v1, v2).is_err(); - assert_eq!(actual, true); + assert!(!ScalarValue::div(v1, v2).is_err()); } -} \ No newline at end of file +} From 4c3d58c00ed448536479eb7ea609f2f1714ee0d0 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Fri, 7 Jan 2022 16:09:21 -0800 Subject: [PATCH 15/22] address comments and fix test errors --- datafusion/src/physical_plan/aggregates.rs | 2 +- datafusion/src/physical_plan/expressions/mod.rs | 6 ++---- datafusion/src/physical_plan/expressions/stddev.rs | 14 +++++++++++++- .../src/physical_plan/expressions/variance.rs | 2 +- datafusion/src/scalar.rs | 14 +++++++------- 5 files changed, 24 insertions(+), 14 deletions(-) diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 3008d12271643..71cc817eadf69 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -239,7 +239,7 @@ pub fn create_aggregate_expr( )), (AggregateFunction::Stddev, true) => { return Err(DataFusionError::NotImplemented( - "VARIANCE(DISTINCT) aggregations are not available".to_string(), + "STDDEV(DISTINCT) aggregations are not available".to_string(), )); } }) diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 
079ffd39a892c..e7d42d3439ece 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -86,13 +86,11 @@ pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; -pub(crate) use stddev::is_stddev_support_arg_type; -pub use stddev::{stddev_return_type, Stddev}; +pub (crate )use stddev::{stddev_return_type, is_stddev_support_arg_type, Stddev}; pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; -pub(crate) use variance::is_variance_support_arg_type; -pub use variance::{variance_return_type, Variance}; +pub (crate) use variance::{variance_return_type, is_variance_support_arg_type, Variance}; /// returns the name of the state pub fn format_state_name(name: &str, state_name: &str) -> String { diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs index ac8d7954743d0..01d80e8a56f9f 100644 --- a/datafusion/src/physical_plan/expressions/stddev.rs +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -39,7 +39,7 @@ pub struct Stddev { } /// function return type of standard deviation -pub fn stddev_return_type(arg_type: &DataType) -> Result { +pub(crate) fn stddev_return_type(arg_type: &DataType) -> Result { match arg_type { DataType::Int8 | DataType::Int16 @@ -205,6 +205,18 @@ mod tests { #[test] fn stddev_f64_2() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); + generic_test_op!( + a, + DataType::Float64, + Stddev, + ScalarValue::from(0.7760297817881877), + DataType::Float64 + ) + } + + #[test] + fn stddev_f64_3() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); generic_test_op!( diff --git a/datafusion/src/physical_plan/expressions/variance.rs 
b/datafusion/src/physical_plan/expressions/variance.rs index 8699136f83e5c..98e2813aff261 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -36,7 +36,7 @@ pub struct Variance { } /// function return type of variance -pub fn variance_return_type(arg_type: &DataType) -> Result { +pub(crate) fn variance_return_type(arg_type: &DataType) -> Result { match arg_type { DataType::Int8 | DataType::Int16 diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 747530a249d9b..5f66c2eed41b3 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -3373,8 +3373,8 @@ mod tests { ($OP:ident, $LHS:expr, $LHS_TYPE:ident, $RHS:expr, $RHS_TYPE:ident) => {{ let v1 = &ScalarValue::from($LHS as $LHS_TYPE); let v2 = &ScalarValue::from($RHS as $RHS_TYPE); - let actual = ScalarValue::add(v1, v2).is_err(); - assert_eq!(actual, true); + let actual = ScalarValue::$OP(v1, v2).is_err(); + assert!(actual); }}; } @@ -3451,7 +3451,7 @@ mod tests { let v1 = &ScalarValue::from(1); let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); - assert!(!ScalarValue::add(v1, v2).is_err()); + assert!(ScalarValue::add(v1, v2).is_err()); let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); let v2 = &ScalarValue::from(2); @@ -3531,11 +3531,11 @@ mod tests { let v1 = &ScalarValue::from(1); let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); - assert!(!ScalarValue::mul(v1, v2).is_err()); + assert!(ScalarValue::mul(v1, v2).is_err()); let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); let v2 = &ScalarValue::from(2); - assert!(!ScalarValue::mul(v1, v2).is_err()); + assert!(ScalarValue::mul(v1, v2).is_err()); } #[test] @@ -3556,10 +3556,10 @@ mod tests { let v1 = &ScalarValue::from(1); let v2 = &ScalarValue::Decimal128(Some(2), 0, 0); - assert!(!ScalarValue::div(v1, v2).is_err()); + assert!(ScalarValue::div(v1, v2).is_err()); let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); let v2 = &ScalarValue::from(2); - 
assert!(!ScalarValue::div(v1, v2).is_err()); + assert!(ScalarValue::div(v1, v2).is_err()); } } From cf74208c7ba527d9959b82a2026eb6d878220765 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Fri, 7 Jan 2022 16:13:54 -0800 Subject: [PATCH 16/22] address comments --- datafusion/src/physical_plan/expressions/variance.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 98e2813aff261..8c981fcac54b3 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -131,6 +131,15 @@ impl AggregateExpr for Variance { } /// An accumulator to compute variance +/// The algrithm used is an online implementation and numerically stable. It is based on this paper: +/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". +/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. +/// +/// The algorithm has been analyzed here: +/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". +/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. 
+ + #[derive(Debug)] pub struct VarianceAccumulator { m2: ScalarValue, From 1c5a30333e19d3fea3ba9d3b537320ecba69abab Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Fri, 7 Jan 2022 18:44:46 -0800 Subject: [PATCH 17/22] add population and sample for variance and stddev --- ballista/rust/core/proto/ballista.proto | 4 +- .../core/src/serde/logical_plan/to_proto.rs | 4 + ballista/rust/core/src/serde/mod.rs | 2 + datafusion/src/physical_plan/aggregates.rs | 118 +++++++++++++++++- .../coercion_rule/aggregate_rule.rs | 18 +++ .../src/physical_plan/expressions/mod.rs | 6 +- .../src/physical_plan/expressions/stats.rs | 25 ++++ .../src/physical_plan/expressions/stddev.rs | 98 +++++++++++++-- .../src/physical_plan/expressions/variance.rs | 110 +++++++++++++++- datafusion/tests/sql/aggregates.rs | 60 ++++++++- 10 files changed, 415 insertions(+), 30 deletions(-) create mode 100644 datafusion/src/physical_plan/expressions/stats.rs diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 9e9d6a497b575..aa7b6a9f900fe 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -170,7 +170,9 @@ enum AggregateFunction { APPROX_DISTINCT = 5; ARRAY_AGG = 6; VARIANCE=7; - STDDEV=8; + VARIANCE_POP=8; + STDDEV=9; + STDDEV_POP=10; } message AggregateExprNode { diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index aa073893f3728..6eebd96cad587 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1135,7 +1135,9 @@ impl TryInto for &Expr { AggregateFunction::Avg => protobuf::AggregateFunction::Avg, AggregateFunction::Count => protobuf::AggregateFunction::Count, AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + AggregateFunction::VariancePop => protobuf::AggregateFunction::VariancePop, AggregateFunction::Stddev => 
protobuf::AggregateFunction::Stddev, + AggregateFunction::StddevPop => protobuf::AggregateFunction::StddevPop, }; let arg = &args[0]; @@ -1367,7 +1369,9 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::ApproxDistinct => Self::ApproxDistinct, AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Variance => Self::Variance, + AggregateFunction::VariancePop => Self::VariancePop, AggregateFunction::Stddev => Self::Stddev, + AggregateFunction::StddevPop => Self::StddevPop, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 62a7ed4eae8b1..fd3b57b3deda1 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -120,7 +120,9 @@ impl From for AggregateFunction { } protobuf::AggregateFunction::ArrayAgg => AggregateFunction::ArrayAgg, protobuf::AggregateFunction::Variance => AggregateFunction::Variance, + protobuf::AggregateFunction::VariancePop => AggregateFunction::VariancePop, protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev, + protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop, } } } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 71cc817eadf69..2b0ddb9a4966b 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -66,10 +66,14 @@ pub enum AggregateFunction { ApproxDistinct, /// array_agg ArrayAgg, - /// Variance (Population) + /// Variance (Sample) Variance, - /// Standard Deviation (Population) + /// Variance (Population) + VariancePop, + /// Standard Deviation (Sample) Stddev, + /// Standard Deviation (Population) + StddevPop, } impl fmt::Display for AggregateFunction { @@ -90,8 +94,12 @@ impl FromStr for AggregateFunction { "sum" => AggregateFunction::Sum, "approx_distinct" => AggregateFunction::ApproxDistinct, "array_agg" => AggregateFunction::ArrayAgg, - "variance" => 
AggregateFunction::Variance, + "var" => AggregateFunction::Variance, + "var_samp" => AggregateFunction::Variance, + "var_pop" => AggregateFunction::VariancePop, "stddev" => AggregateFunction::Stddev, + "stddev_samp" => AggregateFunction::Stddev, + "stddev_pop" => AggregateFunction::StddevPop, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -125,7 +133,9 @@ pub fn return_type( } AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]), + AggregateFunction::VariancePop => variance_return_type(&coerced_data_types[0]), AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]), + AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]), AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new( "item", @@ -229,7 +239,17 @@ pub fn create_aggregate_expr( )), (AggregateFunction::Variance, true) => { return Err(DataFusionError::NotImplemented( - "VARIANCE(DISTINCT) aggregations are not available".to_string(), + "VAR(DISTINCT) aggregations are not available".to_string(), + )); + } + (AggregateFunction::VariancePop, false) => Arc::new(expressions::VariancePop::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), + (AggregateFunction::VariancePop, true) => { + return Err(DataFusionError::NotImplemented( + "VAR_POP(DISTINCT) aggregations are not available".to_string(), )); } (AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new( @@ -242,6 +262,16 @@ pub fn create_aggregate_expr( "STDDEV(DISTINCT) aggregations are not available".to_string(), )); } + (AggregateFunction::StddevPop, false) => Arc::new(expressions::StddevPop::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), + (AggregateFunction::StddevPop, true) => { + return Err(DataFusionError::NotImplemented( + "STDDEV_POP(DISTINCT) aggregations 
are not available".to_string(), + )); + } }) } @@ -289,7 +319,9 @@ pub fn signature(fun: &AggregateFunction) -> Signature { AggregateFunction::Avg | AggregateFunction::Sum | AggregateFunction::Variance - | AggregateFunction::Stddev => { + | AggregateFunction::VariancePop + | AggregateFunction::Stddev + | AggregateFunction::StddevPop => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } } @@ -521,6 +553,44 @@ mod tests { Ok(()) } + #[test] + fn test_var_pop_expr() -> Result<()> { + let funcs = vec![AggregateFunction::VariancePop]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::Variance { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + #[test] fn test_stddev_expr() -> Result<()> { let funcs = vec![AggregateFunction::Stddev]; @@ -559,6 +629,44 @@ mod tests { Ok(()) } + #[test] + fn test_stddev_pop_expr() -> Result<()> { + let funcs = vec![AggregateFunction::StddevPop]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), 
+ )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + if fun == AggregateFunction::Variance { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ) + } + } + } + Ok(()) + } + #[test] fn test_min_max() -> Result<()> { let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?; diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index d5d547fc859c7..d74b4e465c891 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -96,6 +96,15 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::VariancePop => { + if !is_variance_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } AggregateFunction::Stddev => { if !is_stddev_support_arg_type(&input_types[0]) { return Err(DataFusionError::Plan(format!( @@ -105,6 +114,15 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::StddevPop => { + if !is_stddev_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } } } diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index e7d42d3439ece..b3986c23b4855 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -54,6 +54,7 @@ mod stddev; mod sum; mod try_cast; mod variance; +mod stats; /// Module with some 
convenient methods used in expression building pub mod helpers { @@ -86,11 +87,12 @@ pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; -pub (crate )use stddev::{stddev_return_type, is_stddev_support_arg_type, Stddev}; +pub (crate )use stddev::{stddev_return_type, is_stddev_support_arg_type, Stddev, StddevPop}; pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; -pub (crate) use variance::{variance_return_type, is_variance_support_arg_type, Variance}; +pub (crate) use variance::{variance_return_type, is_variance_support_arg_type, Variance, VariancePop}; +pub use stats::StatsType; /// returns the name of the state pub fn format_state_name(name: &str, state_name: &str) -> String { diff --git a/datafusion/src/physical_plan/expressions/stats.rs b/datafusion/src/physical_plan/expressions/stats.rs new file mode 100644 index 0000000000000..1c102cb088f8c --- /dev/null +++ b/datafusion/src/physical_plan/expressions/stats.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +/// Enum used for differentiating population and sample for statistical functions +#[derive(Debug, Clone, Copy)] +pub enum StatsType { + /// Population + Population, + /// Sample + Sample, +} \ No newline at end of file diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs index 01d80e8a56f9f..5caf826c7ca92 100644 --- a/datafusion/src/physical_plan/expressions/stddev.rs +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -28,14 +28,20 @@ use crate::scalar::ScalarValue; use arrow::datatypes::DataType; use arrow::datatypes::Field; -use super::format_state_name; +use super::{format_state_name, StatsType}; -/// STDDEV (standard deviation) aggregate expression +/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression #[derive(Debug)] pub struct Stddev { name: String, expr: Arc, - data_type: DataType, +} + +/// STDDEV_POP population aggregate expression +#[derive(Debug)] +pub struct StddevPop { + name: String, + expr: Arc, } /// function return type of standard deviation @@ -86,7 +92,6 @@ impl Stddev { Self { name: name.into(), expr, - data_type, } } } @@ -98,11 +103,11 @@ impl AggregateExpr for Stddev { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + Ok(Field::new(&self.name, DataType::Float64, true)) } fn create_accumulator(&self) -> Result> { - Ok(Box::new(StddevAccumulator::try_new()?)) + Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) } fn state_fields(&self) -> Result> { @@ -134,6 +139,64 @@ impl AggregateExpr for Stddev { } } +impl StddevPop { + /// Create a new STDDEV_POP aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of stddev_pop only supports the FLOAT64 data type. 
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + } + } +} + +impl AggregateExpr for StddevPop { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} /// An accumulator to compute the average #[derive(Debug)] pub struct StddevAccumulator { @@ -142,9 +205,9 @@ pub struct StddevAccumulator { impl StddevAccumulator { /// Creates a new `StddevAccumulator` - pub fn try_new() -> Result { + pub fn try_new(s_type: StatsType) -> Result { Ok(Self { - variance: VarianceAccumulator::try_new()?, + variance: VarianceAccumulator::try_new(s_type)?, }) } } @@ -197,7 +260,7 @@ mod tests { generic_test_op!( a, DataType::Float64, - Stddev, + StddevPop, ScalarValue::from(0.5_f64), DataType::Float64 ) @@ -209,7 +272,7 @@ mod tests { generic_test_op!( a, DataType::Float64, - Stddev, + StddevPop, ScalarValue::from(0.7760297817881877), DataType::Float64 ) @@ -222,12 +285,25 @@ mod tests { generic_test_op!( a, DataType::Float64, - Stddev, + StddevPop, ScalarValue::from(std::f64::consts::SQRT_2), DataType::Float64 ) } + #[test] + fn stddev_f64_4() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); + generic_test_op!( + a, + DataType::Float64, + Stddev, + ScalarValue::from(0.9504384952922168), + DataType::Float64 + 
) + } + #[test] fn stddev_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 8c981fcac54b3..961acde31da96 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -26,15 +26,22 @@ use crate::scalar::ScalarValue; use arrow::datatypes::DataType; use arrow::datatypes::Field; -use super::format_state_name; +use super::{format_state_name, StatsType}; -/// VARIANCE aggregate expression +/// VAR and VAR_SAMP aggregate expression #[derive(Debug)] pub struct Variance { name: String, expr: Arc, } +/// VAR_POP aggregate expression +#[derive(Debug)] +pub struct VariancePop { + name: String, + expr: Arc, +} + /// function return type of variance pub(crate) fn variance_return_type(arg_type: &DataType) -> Result { match arg_type { @@ -98,7 +105,66 @@ impl AggregateExpr for Variance { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(VarianceAccumulator::try_new()?)) + Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + &format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new( + &format_state_name(&self.name, "m2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl VariancePop { + /// Create a new VAR_POP aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + // the result of variance just support FLOAT64 data type. 
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + } + } +} + +impl AggregateExpr for VariancePop { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(VarianceAccumulator::try_new(StatsType::Population)?)) } fn state_fields(&self) -> Result> { @@ -145,15 +211,17 @@ pub struct VarianceAccumulator { m2: ScalarValue, mean: ScalarValue, count: u64, + s_type: StatsType, } impl VarianceAccumulator { /// Creates a new `VarianceAccumulator` - pub fn try_new() -> Result { + pub fn try_new(s_type: StatsType) -> Result { Ok(Self { m2: ScalarValue::from(0 as f64), mean: ScalarValue::from(0 as f64), count: 0, + s_type: s_type, }) } @@ -241,12 +309,18 @@ impl Accumulator for VarianceAccumulator { } fn evaluate(&self) -> Result { + let count = + match self.s_type { + StatsType::Population => self.count, + StatsType::Sample => self.count - 1, + }; + match self.m2 { ScalarValue::Float64(e) => { if self.count == 0 { Ok(ScalarValue::Float64(None)) } else { - Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) + Ok(ScalarValue::Float64(e.map(|f| f / count as f64))) } } _ => Err(DataFusionError::Internal( @@ -289,6 +363,32 @@ mod tests { ) } + #[test] + fn variance_f64_3() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + generic_test_op!( + a, + DataType::Float64, + VariancePop, + ScalarValue::from(2.5_f64), + DataType::Float64 + ) + } + + #[test] + fn variance_f64_4() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); + generic_test_op!( + a, + DataType::Float64, + Variance, + ScalarValue::from(0.9033333333333333_f64), + DataType::Float64 + ) + } + #[test] fn variance_i32() -> Result<()> { let a: ArrayRef = 
Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index f4e5d4afcad90..66dfd9bd311c3 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -53,7 +53,7 @@ async fn csv_query_avg() -> Result<()> { async fn csv_query_variance_1() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT variance(c2) FROM aggregate_test_100"; + let sql = "SELECT var_pop(c2) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); let expected = vec![vec!["1.8675"]]; @@ -65,7 +65,7 @@ async fn csv_query_variance_1() -> Result<()> { async fn csv_query_variance_2() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT variance(c6) FROM aggregate_test_100"; + let sql = "SELECT var_pop(c6) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); let expected = vec![vec!["26156334342021890000000000000000000000"]]; @@ -77,7 +77,7 @@ async fn csv_query_variance_2() -> Result<()> { async fn csv_query_variance_3() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT variance(c12) FROM aggregate_test_100"; + let sql = "SELECT var_pop(c12) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); let expected = vec![vec!["0.09234223721582163"]]; @@ -85,11 +85,35 @@ async fn csv_query_variance_3() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_variance_4() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT var(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1.8863636363636365"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] 
+async fn csv_query_variance_5() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT var_samp(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1.8863636363636365"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_stddev_1() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT stddev(c2) FROM aggregate_test_100"; + let sql = "SELECT stddev_pop(c2) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); let expected = vec![vec!["1.3665650368716449"]]; @@ -101,7 +125,7 @@ async fn csv_query_stddev_1() -> Result<()> { async fn csv_query_stddev_2() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT stddev(c6) FROM aggregate_test_100"; + let sql = "SELECT stddev_pop(c6) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); let expected = vec![vec!["5114326382039172000"]]; @@ -113,7 +137,7 @@ async fn csv_query_stddev_2() -> Result<()> { async fn csv_query_stddev_3() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT stddev(c12) FROM aggregate_test_100"; + let sql = "SELECT stddev_pop(c12) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; actual.sort(); let expected = vec![vec!["0.30387865541334363"]]; @@ -121,6 +145,30 @@ async fn csv_query_stddev_3() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_stddev_4() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.3054095399405338"]]; + 
assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_stddev_5() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT stddev_samp(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.3054095399405338"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_external_table_count() { let mut ctx = ExecutionContext::new(); From 8e88267ab7a451f80dc78fcb0fb4ea4abd879c0b Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Fri, 7 Jan 2022 20:25:23 -0800 Subject: [PATCH 18/22] address more comments --- .../src/optimizer/simplify_expressions.rs | 2 +- datafusion/src/scalar.rs | 54 +++++++++++++++++++ datafusion/tests/sql/aggregates.rs | 12 +++++ 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index ff2c05c76f18c..2d9c1448e09d2 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -497,7 +497,7 @@ impl ConstEvaluator { } /// Internal helper to evaluates an Expr - fn evaluate_to_scalar(&self, expr: Expr) -> Result { + pub(crate) fn evaluate_to_scalar(&self, expr: Expr) -> Result { if let Expr::Literal(s) = expr { return Ok(s); } diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 5f66c2eed41b3..6b01019f935a5 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -554,6 +554,12 @@ impl ScalarValue { ))); } + if lhs.is_null() || rhs.is_null() { + return Err(DataFusionError::Internal( + "Addition does not support empty values".to_string() + )); + } + // TODO: Finding a good way to support operation between different types without // writing a hige match block. 
// TODO: Add support for decimal types @@ -650,6 +656,12 @@ impl ScalarValue { ))); } + if lhs.is_null() || rhs.is_null() { + return Err(DataFusionError::Internal( + "Multiplication does not support empty values".to_string() + )); + } + // TODO: Finding a good way to support operation between different types without // writing a hige match block. // TODO: Add support for decimal type @@ -717,6 +729,12 @@ impl ScalarValue { ))); } + if lhs.is_null() || rhs.is_null() { + return Err(DataFusionError::Internal( + "Division does not support empty values".to_string() + )); + } + // TODO: Finding a good way to support operation between different types without // writing a hige match block. // TODO: Add support for decimal types @@ -3456,6 +3474,18 @@ mod tests { let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); let v2 = &ScalarValue::from(2); assert!(ScalarValue::add(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::add(v1, v2).is_err()); + + let v2 = &ScalarValue::Float32(None); + let v1 = &ScalarValue::from(2); + assert!(ScalarValue::add(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::Float32(None); + assert!(ScalarValue::add(v1, v2).is_err()); } #[test] @@ -3536,6 +3566,18 @@ mod tests { let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); let v2 = &ScalarValue::from(2); assert!(ScalarValue::mul(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::mul(v1, v2).is_err()); + + let v2 = &ScalarValue::Float32(None); + let v1 = &ScalarValue::from(2); + assert!(ScalarValue::mul(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::Float32(None); + assert!(ScalarValue::mul(v1, v2).is_err()); } #[test] @@ -3561,5 +3603,17 @@ mod tests { let v1 = &ScalarValue::Decimal128(Some(1), 0, 0); let v2 = &ScalarValue::from(2); assert!(ScalarValue::div(v1, v2).is_err()); + + let v1 = 
&ScalarValue::Float32(None); + let v2 = &ScalarValue::from(2); + assert!(ScalarValue::div(v1, v2).is_err()); + + let v2 = &ScalarValue::Float32(None); + let v1 = &ScalarValue::from(2); + assert!(ScalarValue::div(v1, v2).is_err()); + + let v1 = &ScalarValue::Float32(None); + let v2 = &ScalarValue::Float32(None); + assert!(ScalarValue::div(v1, v2).is_err()); } } diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 66dfd9bd311c3..fd0964caf8ca0 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -169,6 +169,18 @@ async fn csv_query_stddev_5() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_stddev_6() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "select stddev(sq.column1) from (values (1.1), (2.0), (3.0)) as sq"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.9504384952922168"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_external_table_count() { let mut ctx = ExecutionContext::new(); From dd007588ea91bf9929226444e4798a37e229da4c Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Fri, 7 Jan 2022 20:33:33 -0800 Subject: [PATCH 19/22] fmt --- .../core/src/serde/logical_plan/to_proto.rs | 8 +++- datafusion/src/physical_plan/aggregates.rs | 14 ++++--- .../src/physical_plan/expressions/mod.rs | 12 ++++-- .../src/physical_plan/expressions/stats.rs | 2 +- .../src/physical_plan/expressions/stddev.rs | 15 ++++--- .../src/physical_plan/expressions/variance.rs | 41 +++++++++---------- datafusion/src/scalar.rs | 31 +++++++------- 7 files changed, 66 insertions(+), 57 deletions(-) diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 6eebd96cad587..3e0e5583604dc 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ 
b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1135,9 +1135,13 @@ impl TryInto for &Expr { AggregateFunction::Avg => protobuf::AggregateFunction::Avg, AggregateFunction::Count => protobuf::AggregateFunction::Count, AggregateFunction::Variance => protobuf::AggregateFunction::Variance, - AggregateFunction::VariancePop => protobuf::AggregateFunction::VariancePop, + AggregateFunction::VariancePop => { + protobuf::AggregateFunction::VariancePop + } AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, - AggregateFunction::StddevPop => protobuf::AggregateFunction::StddevPop, + AggregateFunction::StddevPop => { + protobuf::AggregateFunction::StddevPop + } }; let arg = &args[0]; diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 2b0ddb9a4966b..07b0ff8b33b29 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -242,11 +242,13 @@ pub fn create_aggregate_expr( "VAR(DISTINCT) aggregations are not available".to_string(), )); } - (AggregateFunction::VariancePop, false) => Arc::new(expressions::VariancePop::new( - coerced_phy_exprs[0].clone(), - name, - return_type, - )), + (AggregateFunction::VariancePop, false) => { + Arc::new(expressions::VariancePop::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )) + } (AggregateFunction::VariancePop, true) => { return Err(DataFusionError::NotImplemented( "VAR_POP(DISTINCT) aggregations are not available".to_string(), @@ -320,7 +322,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature { | AggregateFunction::Sum | AggregateFunction::Variance | AggregateFunction::VariancePop - | AggregateFunction::Stddev + | AggregateFunction::Stddev | AggregateFunction::StddevPop => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index b3986c23b4855..a85d867085572 100644 --- 
a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -50,11 +50,11 @@ mod nth_value; mod nullif; mod rank; mod row_number; +mod stats; mod stddev; mod sum; mod try_cast; mod variance; -mod stats; /// Module with some convenient methods used in expression building pub mod helpers { @@ -87,12 +87,16 @@ pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; -pub (crate )use stddev::{stddev_return_type, is_stddev_support_arg_type, Stddev, StddevPop}; +pub use stats::StatsType; +pub(crate) use stddev::{ + is_stddev_support_arg_type, stddev_return_type, Stddev, StddevPop, +}; pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; -pub (crate) use variance::{variance_return_type, is_variance_support_arg_type, Variance, VariancePop}; -pub use stats::StatsType; +pub(crate) use variance::{ + is_variance_support_arg_type, variance_return_type, Variance, VariancePop, +}; /// returns the name of the state pub fn format_state_name(name: &str, state_name: &str) -> String { diff --git a/datafusion/src/physical_plan/expressions/stats.rs b/datafusion/src/physical_plan/expressions/stats.rs index 1c102cb088f8c..3f2d266622dee 100644 --- a/datafusion/src/physical_plan/expressions/stats.rs +++ b/datafusion/src/physical_plan/expressions/stats.rs @@ -22,4 +22,4 @@ pub enum StatsType { Population, /// Sample Sample, -} \ No newline at end of file +} diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs index 5caf826c7ca92..3a2cea7404d19 100644 --- a/datafusion/src/physical_plan/expressions/stddev.rs +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -293,8 +293,7 @@ mod tests { #[test] fn stddev_f64_4() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 
3_f64])); + let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); generic_test_op!( a, DataType::Float64, @@ -310,7 +309,7 @@ mod tests { generic_test_op!( a, DataType::Int32, - Stddev, + StddevPop, ScalarValue::from(std::f64::consts::SQRT_2), DataType::Float64 ) @@ -323,7 +322,7 @@ mod tests { generic_test_op!( a, DataType::UInt32, - Stddev, + StddevPop, ScalarValue::from(std::f64::consts::SQRT_2), DataType::Float64 ) @@ -336,7 +335,7 @@ mod tests { generic_test_op!( a, DataType::Float32, - Stddev, + StddevPop, ScalarValue::from(std::f64::consts::SQRT_2), DataType::Float64 ) @@ -349,7 +348,7 @@ mod tests { assert_eq!(DataType::Float64, result_type); let data_type = DataType::Decimal(36, 10); - assert!(!stddev_return_type(&data_type).is_err()); + assert!(stddev_return_type(&data_type).is_err()); Ok(()) } @@ -365,7 +364,7 @@ mod tests { generic_test_op!( a, DataType::Int32, - Stddev, + StddevPop, ScalarValue::from(1.479019945774904), DataType::Float64 ) @@ -377,7 +376,7 @@ mod tests { generic_test_op!( a, DataType::Int32, - Stddev, + StddevPop, ScalarValue::Float64(None), DataType::Float64 ) diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 961acde31da96..5ccff4c4d778d 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -164,7 +164,9 @@ impl AggregateExpr for VariancePop { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(VarianceAccumulator::try_new(StatsType::Population)?)) + Ok(Box::new(VarianceAccumulator::try_new( + StatsType::Population, + )?)) } fn state_fields(&self) -> Result> { @@ -198,14 +200,13 @@ impl AggregateExpr for VariancePop { /// An accumulator to compute variance /// The algrithm used is an online implementation and numerically stable. It is based on this paper: -/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". 
+/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". /// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. -/// +/// /// The algorithm has been analyzed here: -/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". +/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". /// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. - #[derive(Debug)] pub struct VarianceAccumulator { m2: ScalarValue, @@ -309,11 +310,10 @@ impl Accumulator for VarianceAccumulator { } fn evaluate(&self) -> Result { - let count = - match self.s_type { - StatsType::Population => self.count, - StatsType::Sample => self.count - 1, - }; + let count = match self.s_type { + StatsType::Population => self.count, + StatsType::Sample => self.count - 1, + }; match self.m2 { ScalarValue::Float64(e) => { @@ -344,7 +344,7 @@ mod tests { generic_test_op!( a, DataType::Float64, - Variance, + VariancePop, ScalarValue::from(0.25_f64), DataType::Float64 ) @@ -357,7 +357,7 @@ mod tests { generic_test_op!( a, DataType::Float64, - Variance, + VariancePop, ScalarValue::from(2_f64), DataType::Float64 ) @@ -370,7 +370,7 @@ mod tests { generic_test_op!( a, DataType::Float64, - VariancePop, + Variance, ScalarValue::from(2.5_f64), DataType::Float64 ) @@ -378,8 +378,7 @@ mod tests { #[test] fn variance_f64_4() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); + let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); generic_test_op!( a, DataType::Float64, @@ -395,7 +394,7 @@ mod tests { generic_test_op!( a, DataType::Int32, - Variance, + VariancePop, ScalarValue::from(2_f64), DataType::Float64 ) @@ -408,7 +407,7 @@ mod tests { generic_test_op!( a, DataType::UInt32, - Variance, + VariancePop, ScalarValue::from(2.0f64), DataType::Float64 ) 
@@ -421,7 +420,7 @@ mod tests { generic_test_op!( a, DataType::Float32, - Variance, + VariancePop, ScalarValue::from(2_f64), DataType::Float64 ) @@ -434,7 +433,7 @@ mod tests { assert_eq!(DataType::Float64, result_type); let data_type = DataType::Decimal(36, 10); - assert!(!variance_return_type(&data_type).is_err()); + assert!(variance_return_type(&data_type).is_err()); Ok(()) } @@ -450,7 +449,7 @@ mod tests { generic_test_op!( a, DataType::Int32, - Variance, + VariancePop, ScalarValue::from(2.1875f64), DataType::Float64 ) @@ -462,7 +461,7 @@ mod tests { generic_test_op!( a, DataType::Int32, - Variance, + VariancePop, ScalarValue::Float64(None), DataType::Float64 ) diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 6b01019f935a5..cf6e8a1ac1c2f 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -528,18 +528,19 @@ macro_rules! eq_array_primitive { impl ScalarValue { /// Return true if the value is numeric pub fn is_numeric(&self) -> bool { - matches!(self, + matches!( + self, ScalarValue::Float32(_) - | ScalarValue::Float64(_) - | ScalarValue::Decimal128(_, _, _) - | ScalarValue::Int8(_) - | ScalarValue::Int16(_) - | ScalarValue::Int32(_) - | ScalarValue::Int64(_) - | ScalarValue::UInt8(_) - | ScalarValue::UInt16(_) - | ScalarValue::UInt32(_) - | ScalarValue::UInt64(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) ) } @@ -556,7 +557,7 @@ impl ScalarValue { if lhs.is_null() || rhs.is_null() { return Err(DataFusionError::Internal( - "Addition does not support empty values".to_string() + "Addition does not support empty values".to_string(), )); } @@ -658,7 +659,7 @@ impl ScalarValue { if lhs.is_null() || rhs.is_null() { return Err(DataFusionError::Internal( - "Multiplication does not support empty 
values".to_string() + "Multiplication does not support empty values".to_string(), )); } @@ -668,7 +669,7 @@ impl ScalarValue { match (lhs, rhs) { (ScalarValue::Decimal128(_, _, _), _) | (_, ScalarValue::Decimal128(_, _, _)) => Err(DataFusionError::Internal( - "Multiplication with Decimals are not supported for now".to_string() + "Multiplication with Decimals are not supported for now".to_string(), )), // f64 / _ (ScalarValue::Float64(f1), ScalarValue::Float64(f2)) => { @@ -731,7 +732,7 @@ impl ScalarValue { if lhs.is_null() || rhs.is_null() { return Err(DataFusionError::Internal( - "Division does not support empty values".to_string() + "Division does not support empty values".to_string(), )); } From 844ffd50152343ff5962945d898c5d6848391d02 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Fri, 7 Jan 2022 20:54:57 -0800 Subject: [PATCH 20/22] add test for less than 2 values --- .../src/physical_plan/expressions/stddev.rs | 18 ++++++++++++++ .../src/physical_plan/expressions/variance.rs | 24 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs index 3a2cea7404d19..17656a2bb0d94 100644 --- a/datafusion/src/physical_plan/expressions/stddev.rs +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -352,6 +352,24 @@ mod tests { Ok(()) } + #[test] + fn test_stddev_1_input() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64])); + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + + let agg = Arc::new(Stddev::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + #[test] fn stddev_i32_with_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![ diff --git a/datafusion/src/physical_plan/expressions/variance.rs 
b/datafusion/src/physical_plan/expressions/variance.rs index 5ccff4c4d778d..05ee8676054db 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -315,6 +315,12 @@ impl Accumulator for VarianceAccumulator { StatsType::Sample => self.count - 1, }; + if count <=1 { + return Err(DataFusionError::Internal( + "At least two values are needed to calculate variance".to_string(), + )) + } + match self.m2 { ScalarValue::Float64(e) => { if self.count == 0 { @@ -437,6 +443,24 @@ mod tests { Ok(()) } + #[test] + fn test_variance_1_input() -> Result<()> { + let a: ArrayRef = + Arc::new(Float64Array::from(vec![1_f64])); + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + + let agg = Arc::new(Variance::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) + } + #[test] fn variance_i32_with_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![ From 987a6047a1130df6bf8d37fda16728615cf31d49 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Fri, 7 Jan 2022 22:38:32 -0800 Subject: [PATCH 21/22] fix inconsistency in the merge logic --- .../src/physical_plan/expressions/stddev.rs | 38 +++++++----- .../src/physical_plan/expressions/variance.rs | 62 +++++++++++++------ 2 files changed, 63 insertions(+), 37 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs index 17656a2bb0d94..d6e28f18d3558 100644 --- a/datafusion/src/physical_plan/expressions/stddev.rs +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -354,20 +354,19 @@ mod tests { #[test] fn test_stddev_1_input() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64])); + let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); let schema = 
Schema::new(vec![Field::new("a", DataType::Float64, false)]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - let agg = Arc::new(Stddev::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - let actual = aggregate(&batch, agg); - assert!(actual.is_err()); + let agg = Arc::new(Stddev::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); - Ok(()) + Ok(()) } #[test] @@ -391,13 +390,18 @@ mod tests { #[test] fn stddev_i32_all_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!( - a, - DataType::Int32, - StddevPop, - ScalarValue::Float64(None), - DataType::Float64 - ) + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + + let agg = Arc::new(Stddev::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) } fn aggregate( diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 05ee8676054db..e739ae6c0a309 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -276,12 +276,24 @@ impl Accumulator for VarianceAccumulator { let mean = &states[1]; let m2 = &states[2]; let mut new_count: u64 = self.count; + // counts are summed if let ScalarValue::UInt64(Some(c)) = count { + if *c <= 0 as u64 { + return Ok(()); + } + + if self.count <= 0 { + self.count = *c; + self.mean = mean.clone(); + self.m2 = m2.clone(); + return Ok(()); + } new_count += c } else { unreachable!() }; + let new_mean = ScalarValue::div( &ScalarValue::add(&self.mean, mean)?, &ScalarValue::from(2_f64), @@ -312,13 +324,19 @@ impl Accumulator for VarianceAccumulator { fn evaluate(&self) -> Result 
{ let count = match self.s_type { StatsType::Population => self.count, - StatsType::Sample => self.count - 1, + StatsType::Sample => { + if self.count > 0 { + self.count - 1 + } else { + self.count + } + } }; - if count <=1 { + if count <= 1 { return Err(DataFusionError::Internal( "At least two values are needed to calculate variance".to_string(), - )) + )); } match self.m2 { @@ -445,20 +463,19 @@ mod tests { #[test] fn test_variance_1_input() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64])); + let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - let agg = Arc::new(Variance::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - let actual = aggregate(&batch, agg); - assert!(actual.is_err()); + let agg = Arc::new(Variance::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); - Ok(()) + Ok(()) } #[test] @@ -482,13 +499,18 @@ mod tests { #[test] fn variance_i32_all_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!( - a, - DataType::Int32, - VariancePop, - ScalarValue::Float64(None), - DataType::Float64 - ) + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + + let agg = Arc::new(Variance::new( + col("a", &schema)?, + "bla".to_string(), + DataType::Float64, + )); + let actual = aggregate(&batch, agg); + assert!(actual.is_err()); + + Ok(()) } fn aggregate( From d2ff16d1053290caf55d56739f61315c699571e4 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Fri, 7 Jan 2022 22:57:31 -0800 Subject: [PATCH 22/22] fix lint and clipy --- datafusion/src/physical_plan/expressions/variance.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 
deletions(-) diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index e739ae6c0a309..3f592b00fd4ef 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -212,7 +212,7 @@ pub struct VarianceAccumulator { m2: ScalarValue, mean: ScalarValue, count: u64, - s_type: StatsType, + stats_type: StatsType, } impl VarianceAccumulator { @@ -222,7 +222,7 @@ impl VarianceAccumulator { m2: ScalarValue::from(0 as f64), mean: ScalarValue::from(0 as f64), count: 0, - s_type: s_type, + stats_type: s_type, }) } @@ -279,11 +279,11 @@ impl Accumulator for VarianceAccumulator { // counts are summed if let ScalarValue::UInt64(Some(c)) = count { - if *c <= 0 as u64 { + if *c == 0_u64 { return Ok(()); } - if self.count <= 0 { + if self.count == 0 { self.count = *c; self.mean = mean.clone(); self.m2 = m2.clone(); @@ -322,7 +322,7 @@ impl Accumulator for VarianceAccumulator { } fn evaluate(&self) -> Result { - let count = match self.s_type { + let count = match self.stats_type { StatsType::Population => self.count, StatsType::Sample => { if self.count > 0 {