From bab7b466bee8f09a49d95ca52971ed0630557a15 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Tue, 1 Feb 2022 23:41:14 -0800 Subject: [PATCH 01/14] add median operator --- ballista/rust/core/proto/ballista.proto | 1 + .../core/src/serde/logical_plan/to_proto.rs | 2 + ballista/rust/core/src/serde/mod.rs | 1 + datafusion/src/physical_plan/aggregates.rs | 74 ++++- .../coercion_rule/aggregate_rule.rs | 14 +- .../expressions/approx_percentile_cont.rs | 4 + .../src/physical_plan/expressions/median.rs | 298 ++++++++++++++++++ .../src/physical_plan/expressions/mod.rs | 2 + datafusion/tests/sql/aggregates.rs | 36 +++ 9 files changed, 428 insertions(+), 4 deletions(-) create mode 100644 datafusion/src/physical_plan/expressions/median.rs diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index fb006e532ff34..df9dcc1d5040e 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -177,6 +177,7 @@ enum AggregateFunction { STDDEV_POP=12; CORRELATION=13; APPROX_PERCENTILE_CONT = 14; + MEDIAN=15; } message AggregateExprNode { diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 4b13ce577cfb9..0a3be7db9e333 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1100,6 +1100,7 @@ impl TryInto for &Expr { AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation } + AggregateFunction::Median => protobuf::AggregateFunction::Median, }; let aggregate_expr = protobuf::AggregateExprNode { @@ -1340,6 +1341,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::StddevPop => Self::StddevPop, AggregateFunction::Correlation => Self::Correlation, AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, + AggregateFunction::Median => Self::Median, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 64a60dc4da5d4..e62996171172e 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -132,6 +132,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::ApproxPercentileCont => { AggregateFunction::ApproxPercentileCont } + protobuf::AggregateFunction::Median => AggregateFunction::Median, } } } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 8fc94d3860147..239bfbcdd5651 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -82,6 +82,8 @@ pub enum AggregateFunction { Correlation, /// Approximate continuous percentile function ApproxPercentileCont, + /// Median + Median, } impl fmt::Display for AggregateFunction { @@ -113,6 +115,7 @@ impl FromStr for AggregateFunction { "covar_pop" => AggregateFunction::CovariancePop, "corr" => AggregateFunction::Correlation, "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, + "median" => AggregateFunction::Median, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -161,6 +164,7 @@ pub fn return_type( true, )))), AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()), + AggregateFunction::Median => Ok(coerced_data_types[0].clone()), } } @@ -349,6 +353,16 @@ pub fn create_aggregate_expr( .to_string(), )); } + (AggregateFunction::Median, false) => Arc::new(expressions::Median::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), + (AggregateFunction::Median, true) => { + return Err(DataFusionError::NotImplemented( + "MEDIAN(DISTINCT) aggregations are not available".to_string(), + )); + } }) } @@ -398,7 +412,8 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature { | AggregateFunction::Variance | AggregateFunction::VariancePop | AggregateFunction::Stddev - | AggregateFunction::StddevPop => { + | AggregateFunction::StddevPop + | AggregateFunction::Median => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::Covariance | AggregateFunction::CovariancePop => { @@ -423,7 +438,8 @@ mod tests { use super::*; use crate::physical_plan::expressions::{ ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count, - Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, + Covariance, DistinctArrayAgg, DistinctCount, Max, Median, Min, Stddev, Sum, + Variance, }; use crate::{error::Result, scalar::ScalarValue}; @@ -996,6 +1012,60 @@ mod tests { Ok(()) } + #[test] + fn test_median_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Median]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + + if fun == AggregateFunction::Median { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", data_type.clone(), true), + result_agg_phy_exprs.field().unwrap() + ); + } + } + } + Ok(()) + } + + #[test] + fn test_median() -> Result<()> { + let observed = return_type(&AggregateFunction::Median, &[DataType::Utf8]); + assert!(observed.is_err()); + + let observed = return_type(&AggregateFunction::Median, &[DataType::Int32])?; + assert_eq!(DataType::Int32, observed); + + let observed = + return_type(&AggregateFunction::Median, &[DataType::Decimal(10, 6)]); + assert!(observed.is_err()); + + Ok(()) + } + #[test] fn test_min_max() -> Result<()> { let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?; diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index bae2de74c7b74..cda223e5eaa1a 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -21,8 +21,9 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::{ is_avg_support_arg_type, is_correlation_support_arg_type, - is_covariance_support_arg_type, is_stddev_support_arg_type, is_sum_support_arg_type, - is_variance_support_arg_type, try_cast, + is_covariance_support_arg_type, is_median_support_arg_type, + is_stddev_support_arg_type, is_sum_support_arg_type, is_variance_support_arg_type, + try_cast, }; use crate::physical_plan::functions::{Signature, TypeSignature}; use crate::physical_plan::PhysicalExpr; @@ -154,6 +155,15 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::Median => { + if !is_median_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } } } diff --git a/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs index cba30ee481abc..688c94e6afa5b 100644 --- a/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs +++ b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs @@ -203,6 +203,10 @@ impl ApproxPercentileAccumulator { return_type, } } + + pub(crate) fn get_digest(&self) -> &TDigest { + &self.digest + } } impl Accumulator for ApproxPercentileAccumulator { diff --git a/datafusion/src/physical_plan/expressions/median.rs b/datafusion/src/physical_plan/expressions/median.rs new file mode 100644 index 0000000000000..d4e7d837a1956 --- /dev/null +++ b/datafusion/src/physical_plan/expressions/median.rs @@ -0,0 +1,298 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::error::Result; +use crate::physical_plan::{ + expressions::approx_percentile_cont::ApproxPercentileAccumulator, Accumulator, + AggregateExpr, PhysicalExpr, +}; +use crate::scalar::ScalarValue; +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; + +use super::format_state_name; + +/// MEDIAN aggregate expression +#[derive(Debug)] +pub struct Median { + name: String, + expr: Arc, + data_type: DataType, +} + +pub(crate) fn is_median_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +impl Median { + /// Create a new MEDIAN aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + Self { + name: name.into(), + expr, + data_type, + } + } +} + +impl AggregateExpr for Median { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(MedianAccumulator::try_new( + self.data_type.clone(), + )?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "max_size"), + DataType::UInt64, + false, + ), + Field::new( + &format_state_name(&self.name, "sum"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "count"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "max"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "min"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "centroids"), + DataType::List(Box::new(Field::new("item", DataType::Float64, true))), + false, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +/// An accumulator to compute the median. +/// It is using approx_percentile_cont under the hood, which is an approximation. +/// We will revist this and may provide an implementation to calculate the exact median in the future. +#[derive(Debug)] +pub struct MedianAccumulator { + perc_cont: ApproxPercentileAccumulator, +} + +impl MedianAccumulator { + /// Creates a new `MedianAccumulator` + pub fn try_new(data_type: DataType) -> Result { + Ok(Self { + perc_cont: ApproxPercentileAccumulator::new(0.5_f64, data_type), + }) + } +} + +impl Accumulator for MedianAccumulator { + fn state(&self) -> Result> { + Ok(self.perc_cont.get_digest().to_scalar_state()) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.perc_cont.update_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.perc_cont.merge_batch(states) + } + + fn evaluate(&self) -> Result { + self.perc_cont.evaluate() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::from_slice::FromSlice; + use crate::physical_plan::expressions::col; + use crate::{error::Result, generic_test_op}; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + + #[test] + fn median_f64_1() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64])); + generic_test_op!( + a, + DataType::Float64, + Median, + ScalarValue::from(1.5_f64), + DataType::Float64 + ) + } + + #[test] + fn median_f64_2() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1.1_f64, 2_f64, 3_f64])); + generic_test_op!( + a, + DataType::Float64, + Median, + ScalarValue::from(2_f64), + DataType::Float64 + ) + } + + #[test] + fn median_f64_3() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[ + 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, + ])); + generic_test_op!( + a, + DataType::Float64, + Median, + ScalarValue::from(3_f64), + DataType::Float64 + ) + } + + #[test] + fn median_i32() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5])); + generic_test_op!( + a, + DataType::Int32, + Median, + ScalarValue::from(3), + DataType::Int32 + ) + } + + #[test] + fn median_u32() -> Result<()> { + let a: ArrayRef = Arc::new(UInt32Array::from_slice(&[ + 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, + ])); + generic_test_op!( + a, + DataType::UInt32, + Median, + ScalarValue::from(3_u32), + DataType::UInt32 + ) + } + + #[test] + fn median_f32() -> Result<()> { + let a: ArrayRef = Arc::new(Float32Array::from_slice(&[ + 1_f32, 2_f32, 3_f32, 4_f32, 5_f32, + ])); + generic_test_op!( + a, + DataType::Float32, + Median, + ScalarValue::from(3_f32), + DataType::Float32 + ) + } + + #[test] + fn median_i32_with_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + Some(5), + ])); + generic_test_op!( + a, + DataType::Int32, + Median, + ScalarValue::from(3), + DataType::Int32 + ) + } + + #[test] + fn median_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![None])); + generic_test_op!( + a, + DataType::Int32, + Median, + ScalarValue::from(0), + DataType::Int32 + ) + } + + fn aggregate( + batch: &RecordBatch, + agg: Arc, + ) -> Result { + let mut accum = agg.create_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + accum.update_batch(&values)?; + accum.evaluate() + } +} diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 9344fbd6b1bc4..1ba17960243dc 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -47,6 +47,7 @@ mod min_max; mod correlation; mod covariance; mod distinct_expressions; +mod median; mod negative; mod not; mod nth_value; @@ -92,6 +93,7 @@ pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; pub use lead_lag::{lag, lead}; pub use literal::{lit, Literal}; +pub(crate) use median::{is_median_support_arg_type, Median}; pub use min_max::{Max, Min}; pub(crate) use min_max::{MaxAccumulator, MinAccumulator}; pub use negative::{negative, NegativeExpr}; diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index a025d4eeec860..69226d58bf95a 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -219,6 +219,42 @@ async fn csv_query_stddev_6() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_median_1() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT median(c2) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["3"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_median_2() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT median(c6) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["1146409980542786560"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_median_3() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT median(c12) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql).await; + actual.sort(); + let expected = vec![vec!["0.5550065410522981"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_external_table_count() { let mut ctx = ExecutionContext::new(); From 9b9abe38642b89c0e7a5d93fc3e702c7836bff25 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Tue, 1 Feb 2022 23:46:00 -0800 Subject: [PATCH 02/14] update doc --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1ea972cc4df17..6baff7a1b8437 100644 --- a/README.md +++ b/README.md @@ -328,7 +328,7 @@ This library currently supports many SQL constructs, including - `CAST` to change types, including e.g. `Timestamp(Nanosecond, None)` - Many mathematical unary and binary expressions such as `+`, `/`, `sqrt`, `tan`, `>=`. - `WHERE` to filter -- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population) +- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `MEDIAN`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population) - `ORDER BY` together with an expression and optional `ASC` or `DESC` and also optional `NULLS FIRST` or `NULLS LAST` ## Supported Functions From 07f58193b11f386c0bd61b62cf2b17c44334cc33 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Wed, 2 Feb 2022 23:01:30 -0800 Subject: [PATCH 03/14] rename median to approx_median --- ballista/rust/core/proto/ballista.proto | 2 +- .../core/src/serde/logical_plan/to_proto.rs | 6 +- ballista/rust/core/src/serde/mod.rs | 2 +- datafusion/src/physical_plan/aggregates.rs | 44 ++++++----- .../coercion_rule/aggregate_rule.rs | 8 +- .../{median.rs => approx_median.rs} | 75 ++++++++++++------- .../src/physical_plan/expressions/mod.rs | 4 +- datafusion/tests/sql/aggregates.rs | 9 +-- 8 files changed, 86 insertions(+), 64 deletions(-) rename datafusion/src/physical_plan/expressions/{median.rs => approx_median.rs} (82%) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index df9dcc1d5040e..0b0d364ca7c70 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -177,7 +177,7 @@ enum AggregateFunction { STDDEV_POP=12; CORRELATION=13; APPROX_PERCENTILE_CONT = 14; - MEDIAN=15; + APPROX_MEDIAN=15; } message AggregateExprNode { diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 0a3be7db9e333..84910b2c31fa9 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1100,7 +1100,9 @@ impl TryInto for &Expr { AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation } - AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::ApproxMedian => { + protobuf::AggregateFunction::ApproxMedian + } }; let aggregate_expr = protobuf::AggregateExprNode { @@ -1341,7 +1343,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::StddevPop => Self::StddevPop, AggregateFunction::Correlation => Self::Correlation, AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, - AggregateFunction::Median => Self::Median, + AggregateFunction::ApproxMedian => Self::ApproxMedian, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index e62996171172e..f7b0b9436c4cb 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -132,7 +132,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::ApproxPercentileCont => { AggregateFunction::ApproxPercentileCont } - protobuf::AggregateFunction::Median => AggregateFunction::Median, + protobuf::AggregateFunction::ApproxMedian => AggregateFunction::ApproxMedian, } } } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 239bfbcdd5651..8c4747d97492a 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -82,8 +82,8 @@ pub enum AggregateFunction { Correlation, /// Approximate continuous percentile function ApproxPercentileCont, - /// Median - Median, + /// ApproxMedian + ApproxMedian, } impl fmt::Display for AggregateFunction { @@ -115,7 +115,7 @@ impl FromStr for AggregateFunction { "covar_pop" => AggregateFunction::CovariancePop, "corr" => AggregateFunction::Correlation, "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, - "median" => AggregateFunction::Median, + "approx_median" => AggregateFunction::ApproxMedian, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -164,7 +164,7 @@ pub fn return_type( true, )))), AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()), - AggregateFunction::Median => Ok(coerced_data_types[0].clone()), + AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()), } } @@ -353,12 +353,14 @@ pub fn create_aggregate_expr( .to_string(), )); } - (AggregateFunction::Median, false) => Arc::new(expressions::Median::new( - coerced_phy_exprs[0].clone(), - name, - return_type, - )), - (AggregateFunction::Median, true) => { + (AggregateFunction::ApproxMedian, false) => { + Arc::new(expressions::ApproxMedian::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )) + } + (AggregateFunction::ApproxMedian, true) => { return Err(DataFusionError::NotImplemented( "MEDIAN(DISTINCT) aggregations are not available".to_string(), )); @@ -413,7 +415,7 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature { | AggregateFunction::VariancePop | AggregateFunction::Stddev | AggregateFunction::StddevPop - | AggregateFunction::Median => { + | AggregateFunction::ApproxMedian => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::Covariance | AggregateFunction::CovariancePop => { @@ -437,8 +439,8 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature { mod tests { use super::*; use crate::physical_plan::expressions::{ - ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count, - Covariance, DistinctArrayAgg, DistinctCount, Max, Median, Min, Stddev, Sum, + ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, Correlation, + Count, Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, }; use crate::{error::Result, scalar::ScalarValue}; @@ -1014,7 +1016,7 @@ mod tests { #[test] fn test_median_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Median]; + let funcs = vec![AggregateFunction::ApproxMedian]; let data_types = vec![ DataType::UInt32, DataType::UInt64, @@ -1038,8 +1040,8 @@ mod tests { "c1", )?; - if fun == AggregateFunction::Median { - assert!(result_agg_phy_exprs.as_any().is::()); + if fun == AggregateFunction::ApproxMedian { + assert!(result_agg_phy_exprs.as_any().is::()); assert_eq!("c1", result_agg_phy_exprs.name()); assert_eq!( Field::new("c1", data_type.clone(), true), @@ -1053,14 +1055,16 @@ mod tests { #[test] fn test_median() -> Result<()> { - let observed = return_type(&AggregateFunction::Median, &[DataType::Utf8]); + let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Utf8]); assert!(observed.is_err()); - let observed = return_type(&AggregateFunction::Median, &[DataType::Int32])?; + let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Int32])?; assert_eq!(DataType::Int32, observed); - let observed = - return_type(&AggregateFunction::Median, &[DataType::Decimal(10, 6)]); + let observed = return_type( + &AggregateFunction::ApproxMedian, + &[DataType::Decimal(10, 6)], + ); assert!(observed.is_err()); Ok(()) diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index cda223e5eaa1a..482f61f064840 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -20,8 +20,8 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::{ - is_avg_support_arg_type, is_correlation_support_arg_type, - is_covariance_support_arg_type, is_median_support_arg_type, + is_approx_median_support_arg_type, is_avg_support_arg_type, + is_correlation_support_arg_type, is_covariance_support_arg_type, is_stddev_support_arg_type, is_sum_support_arg_type, is_variance_support_arg_type, try_cast, }; @@ -155,8 +155,8 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } - AggregateFunction::Median => { - if !is_median_support_arg_type(&input_types[0]) { + AggregateFunction::ApproxMedian => { + if !is_approx_median_support_arg_type(&input_types[0]) { return Err(DataFusionError::Plan(format!( "The function {:?} does not support inputs of type {:?}.", agg_fun, input_types[0] diff --git a/datafusion/src/physical_plan/expressions/median.rs b/datafusion/src/physical_plan/expressions/approx_median.rs similarity index 82% rename from datafusion/src/physical_plan/expressions/median.rs rename to datafusion/src/physical_plan/expressions/approx_median.rs index d4e7d837a1956..1fe510078e4a6 100644 --- a/datafusion/src/physical_plan/expressions/median.rs +++ b/datafusion/src/physical_plan/expressions/approx_median.rs @@ -32,13 +32,13 @@ use super::format_state_name; /// MEDIAN aggregate expression #[derive(Debug)] -pub struct Median { +pub struct ApproxMedian { name: String, expr: Arc, data_type: DataType, } -pub(crate) fn is_median_support_arg_type(arg_type: &DataType) -> bool { +pub(crate) fn is_approx_median_support_arg_type(arg_type: &DataType) -> bool { matches!( arg_type, DataType::UInt8 @@ -54,8 +54,8 @@ pub(crate) fn is_median_support_arg_type(arg_type: &DataType) -> bool { ) } -impl Median { - /// Create a new MEDIAN aggregate function +impl ApproxMedian { + /// Create a new APPROX_MEDIAN aggregate function pub fn new( expr: Arc, name: impl Into, @@ -69,7 +69,7 @@ impl Median { } } -impl AggregateExpr for Median { +impl AggregateExpr for ApproxMedian { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -80,7 +80,7 @@ impl AggregateExpr for Median { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(MedianAccumulator::try_new( + Ok(Box::new(ApproxMedianAccumulator::try_new( self.data_type.clone(), )?)) } @@ -129,16 +129,16 @@ impl AggregateExpr for Median { } } -/// An accumulator to compute the median. +/// An accumulator to compute the approx_median. /// It is using approx_percentile_cont under the hood, which is an approximation. -/// We will revist this and may provide an implementation to calculate the exact median in the future. +/// We will revist this and may provide an implementation to calculate the exact approx_median in the future. #[derive(Debug)] -pub struct MedianAccumulator { +pub struct ApproxMedianAccumulator { perc_cont: ApproxPercentileAccumulator, } -impl MedianAccumulator { - /// Creates a new `MedianAccumulator` +impl ApproxMedianAccumulator { + /// Creates a new `ApproxMedianAccumulator` pub fn try_new(data_type: DataType) -> Result { Ok(Self { perc_cont: ApproxPercentileAccumulator::new(0.5_f64, data_type), @@ -146,7 +146,7 @@ impl MedianAccumulator { } } -impl Accumulator for MedianAccumulator { +impl Accumulator for ApproxMedianAccumulator { fn state(&self) -> Result> { Ok(self.perc_cont.get_digest().to_scalar_state()) } @@ -174,85 +174,85 @@ mod tests { use arrow::{array::*, datatypes::*}; #[test] - fn median_f64_1() -> Result<()> { + fn approx_median_f64_1() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64])); generic_test_op!( a, DataType::Float64, - Median, + ApproxMedian, ScalarValue::from(1.5_f64), DataType::Float64 ) } #[test] - fn median_f64_2() -> Result<()> { + fn approx_median_f64_2() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1.1_f64, 2_f64, 3_f64])); generic_test_op!( a, DataType::Float64, - Median, + ApproxMedian, ScalarValue::from(2_f64), DataType::Float64 ) } #[test] - fn median_f64_3() -> Result<()> { + fn approx_median_f64_3() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from_slice(&[ 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, ])); generic_test_op!( a, DataType::Float64, - Median, + ApproxMedian, ScalarValue::from(3_f64), DataType::Float64 ) } #[test] - fn median_i32() -> Result<()> { + fn approx_median_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5])); generic_test_op!( a, DataType::Int32, - Median, + ApproxMedian, ScalarValue::from(3), DataType::Int32 ) } #[test] - fn median_u32() -> Result<()> { + fn approx_median_u32() -> Result<()> { let a: ArrayRef = Arc::new(UInt32Array::from_slice(&[ 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, ])); generic_test_op!( a, DataType::UInt32, - Median, + ApproxMedian, ScalarValue::from(3_u32), DataType::UInt32 ) } #[test] - fn median_f32() -> Result<()> { + fn approx_median_f32() -> Result<()> { let a: ArrayRef = Arc::new(Float32Array::from_slice(&[ 1_f32, 2_f32, 3_f32, 4_f32, 5_f32, ])); generic_test_op!( a, DataType::Float32, - Median, + ApproxMedian, ScalarValue::from(3_f32), DataType::Float32 ) } #[test] - fn median_i32_with_nulls() -> Result<()> { + fn approx_median_i32_with_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![ Some(1), None, @@ -263,19 +263,38 @@ mod tests { generic_test_op!( a, DataType::Int32, - Median, + ApproxMedian, ScalarValue::from(3), DataType::Int32 ) } #[test] - fn median_i32_all_nulls() -> Result<()> { + fn approx_median_i32_with_nulls_2() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(5), + Some(1), + None, + None, + Some(3), + Some(4), + ])); + generic_test_op!( + a, + DataType::Int32, + ApproxMedian, + ScalarValue::from(2), + DataType::Int32 + ) + } + + #[test] + fn approx_median_i32_all_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![None])); generic_test_op!( a, DataType::Int32, - Median, + ApproxMedian, ScalarValue::from(0), DataType::Int32 ) diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 1ba17960243dc..0f83cf8c32232 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -44,10 +44,10 @@ mod lead_lag; mod literal; #[macro_use] mod min_max; +mod approx_median; mod correlation; mod covariance; mod distinct_expressions; -mod median; mod negative; mod not; mod nth_value; @@ -66,6 +66,7 @@ pub mod helpers { } pub use approx_distinct::ApproxDistinct; +pub(crate) use approx_median::{is_approx_median_support_arg_type, ApproxMedian}; pub use approx_percentile_cont::{ is_approx_percentile_cont_supported_arg_type, ApproxPercentileCont, }; @@ -93,7 +94,6 @@ pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; pub use lead_lag::{lag, lead}; pub use literal::{lit, Literal}; -pub(crate) use median::{is_median_support_arg_type, Median}; pub use min_max::{Max, Min}; pub(crate) use min_max::{MaxAccumulator, MinAccumulator}; pub use negative::{negative, NegativeExpr}; diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 69226d58bf95a..f26c3891926da 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -223,9 +223,8 @@ async fn csv_query_stddev_6() -> Result<()> { async fn csv_query_median_1() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT median(c2) FROM aggregate_test_100"; + let sql = "SELECT approx_median(c2) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; - actual.sort(); let expected = vec![vec!["3"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -235,9 +234,8 @@ async fn csv_query_median_1() -> Result<()> { async fn csv_query_median_2() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT median(c6) FROM aggregate_test_100"; + let sql = "SELECT approx_median(c6) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; - actual.sort(); let expected = vec![vec!["1146409980542786560"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -247,9 +245,8 @@ async fn csv_query_median_2() -> Result<()> { async fn csv_query_median_3() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT median(c12) FROM aggregate_test_100"; + let sql = "SELECT approx_median(c12) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; - actual.sort(); let expected = vec![vec!["0.5550065410522981"]]; assert_float_eq(&expected, &actual); Ok(()) From 170b8e6f78cef22edeb0425cd035b9aca3f2cd8a Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Wed, 2 Feb 2022 23:04:05 -0800 Subject: [PATCH 04/14] rename median to approx_median --- datafusion/tests/sql/aggregates.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index f26c3891926da..68d3a3997bf2c 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -224,7 +224,7 @@ async fn csv_query_median_1() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT approx_median(c2) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["3"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -235,7 +235,7 @@ async fn csv_query_median_2() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT approx_median(c6) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["1146409980542786560"]]; assert_float_eq(&expected, &actual); Ok(()) @@ -246,7 +246,7 @@ async fn csv_query_median_3() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT approx_median(c12) FROM aggregate_test_100"; - let mut actual = execute(&mut ctx, sql).await; + let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["0.5550065410522981"]]; assert_float_eq(&expected, &actual); Ok(()) From 03b64df9ed0d2988e54b8649eadf9c4398802daf Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Wed, 2 Feb 2022 23:10:20 -0800 Subject: [PATCH 05/14] add doc --- README.md | 2 +- datafusion/src/physical_plan/expressions/approx_median.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6baff7a1b8437..f99ca558ffc99 100644 --- a/README.md +++ b/README.md @@ -328,7 +328,7 @@ This library currently supports many SQL constructs, including - `CAST` to change types, including e.g. `Timestamp(Nanosecond, None)` - Many mathematical unary and binary expressions such as `+`, `/`, `sqrt`, `tan`, `>=`. - `WHERE` to filter -- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `MEDIAN`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population) +- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `APPROX_PERCENTILE_CONT`, `APPROX_MEDIAN`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population) - `ORDER BY` together with an expression and optional `ASC` or `DESC` and also optional `NULLS FIRST` or `NULLS LAST` ## Supported Functions diff --git a/datafusion/src/physical_plan/expressions/approx_median.rs b/datafusion/src/physical_plan/expressions/approx_median.rs index 1fe510078e4a6..cfd29bca15c7e 100644 --- a/datafusion/src/physical_plan/expressions/approx_median.rs +++ b/datafusion/src/physical_plan/expressions/approx_median.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions that can evaluated at runtime during query execution +//! Defines physical expressions for APPROX_MEDIAN that can be evaluated at runtime during query execution use std::any::Any; use std::sync::Arc; From 56b1ad342d2064d7bc46422217076ca70ddf7939 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Sat, 5 Feb 2022 02:05:23 +0000 Subject: [PATCH 06/14] test optimizer --- datafusion/src/execution/context.rs | 5 +- datafusion/src/optimizer/mod.rs | 1 + datafusion/src/optimizer/to_approx_perc.rs | 251 +++++++++++++++++++++ 3 files changed, 256 insertions(+), 1 deletion(-) create mode 100644 datafusion/src/optimizer/to_approx_perc.rs diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 023d3a0023be3..3f65b8b6b9e5f 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -72,13 +72,15 @@ use crate::optimizer::limit_push_down::LimitPushDown; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::projection_push_down::ProjectionPushDown; use crate::optimizer::simplify_expressions::SimplifyExpressions; +use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; +use crate::optimizer::to_approx_perc::ToApproxPerc; + use crate::physical_optimizer::coalesce_batches::CoalesceBatches; use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec; use crate::physical_optimizer::repartition::Repartition; use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use crate::logical_plan::plan::Explain; -use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::physical_plan::udf::ScalarUDF; use crate::physical_plan::ExecutionPlan; @@ -927,6 +929,7 @@ impl Default for ExecutionConfig { Arc::new(ProjectionPushDown::new()), Arc::new(FilterPushDown::new()), Arc::new(LimitPushDown::new()), + Arc::new(ToApproxPerc::new()), Arc::new(SingleDistinctToGroupBy::new()), ], physical_optimizers: vec![ diff --git a/datafusion/src/optimizer/mod.rs b/datafusion/src/optimizer/mod.rs index 984cbee909471..418eaad4bc5ce 100644 --- a/datafusion/src/optimizer/mod.rs +++ b/datafusion/src/optimizer/mod.rs @@ -27,4 +27,5 @@ pub mod optimizer; pub mod projection_push_down; pub mod simplify_expressions; pub mod single_distinct_to_groupby; +pub mod to_approx_perc; pub mod utils; diff --git a/datafusion/src/optimizer/to_approx_perc.rs b/datafusion/src/optimizer/to_approx_perc.rs new file mode 100644 index 0000000000000..39a9b30b102d6 --- /dev/null +++ b/datafusion/src/optimizer/to_approx_perc.rs @@ -0,0 +1,251 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! espression/function to approx_percentile optimizer rule + +use crate::error::Result; +use crate::execution::context::ExecutionProps; +use crate::logical_plan::plan::{Aggregate, Projection}; +use crate::logical_plan::{col, columnize_expr, DFSchema, Expr, LogicalPlan}; +use crate::physical_plan::aggregates; +use crate::optimizer::optimizer::OptimizerRule; +use crate::optimizer::utils; +use hashbrown::HashSet; +use std::sync::Arc; + +/// espression/function to approx_percentile optimizer rule +/// ```text +/// SELECT F1(s) +/// ... +/// +/// Into +/// +/// SELECT APPROX_PERCENTILE_CONT(s, lit(n)) as "F1(s)" +/// ... +/// ``` +pub struct ToApproxPerc {} + +impl ToApproxPerc { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +fn optimize(plan: &LogicalPlan) -> Result { + match plan { + LogicalPlan::Aggregate(Aggregate { + input, + aggr_expr, + schema, + group_expr, + }) => { + let new_aggr_expr = aggr_expr + .iter() + .map(|agg_expr| match agg_expr { + Expr::AggregateFunction { fun, args, .. } => { + let mut new_args = args.clone(); + match fun { + aggregates::AggregateFunction::ApproxMedian => { + //new_args.push(lit(0.5_f64)); + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::ApproxPercentileCont, + args: new_args, + distinct: false, + } + } + _ => agg_expr.clone(), + } + } + _ => agg_expr.clone(), + }) + .collect::>(); + + Ok(LogicalPlan::Aggregate(Aggregate { + input: input.clone(), + aggr_expr: new_aggr_expr, + schema: schema.clone(), + group_expr: group_expr.clone(), + })) + } + _ => Ok(plan.clone()) + } +} + +impl OptimizerRule for ToApproxPerc { + fn optimize( + &self, + plan: &LogicalPlan, + _execution_props: &ExecutionProps, + ) -> Result { + optimize(plan) + } + fn name(&self) -> &str { + "ToApproxPerc" + } +} + +// #[cfg(test)] +// mod tests { +// use super::*; +// use crate::logical_plan::{col, count, count_distinct, lit, max, LogicalPlanBuilder}; +// use crate::physical_plan::aggregates; +// use crate::test::*; + +// fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { +// let rule = SingleDistinctToGroupBy::new(); +// let optimized_plan = rule +// .optimize(plan, &ExecutionProps::new()) +// .expect("failed to optimize plan"); +// let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); +// assert_eq!(formatted_plan, expected); +// } + +// #[test] +// fn not_exist_distinct() -> Result<()> { +// let table_scan = test_table_scan()?; + +// let plan = LogicalPlanBuilder::from(table_scan) +// .aggregate(Vec::::new(), vec![max(col("b"))])? +// .build()?; + +// // Do nothing +// let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]] [MAX(test.b):UInt32;N]\ +// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + +// assert_optimized_plan_eq(&plan, expected); +// Ok(()) +// } + +// #[test] +// fn single_distinct() -> Result<()> { +// let table_scan = test_table_scan()?; + +// let plan = LogicalPlanBuilder::from(table_scan) +// .aggregate(Vec::::new(), vec![count_distinct(col("b"))])? +// .build()?; + +// // Should work +// let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):UInt64;N]\ +// \n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\ +// \n Aggregate: groupBy=[[#test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ +// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + +// assert_optimized_plan_eq(&plan, expected); +// Ok(()) +// } + +// #[test] +// fn single_distinct_expr() -> Result<()> { +// let table_scan = test_table_scan()?; + +// let plan = LogicalPlanBuilder::from(table_scan) +// .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? +// .build()?; + +// let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):UInt64;N]\ +// \n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\ +// \n Aggregate: groupBy=[[Int32(2) * #test.b AS alias1]], aggr=[[]] [alias1:Int32]\ +// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + +// assert_optimized_plan_eq(&plan, expected); +// Ok(()) +// } + +// #[test] +// fn single_distinct_and_groupby() -> Result<()> { +// let table_scan = test_table_scan()?; + +// let plan = LogicalPlanBuilder::from(table_scan) +// .aggregate(vec![col("a")], vec![count_distinct(col("b"))])? +// .build()?; + +// // Should work +// let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N]\ +// \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N]\ +// \n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ +// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + +// assert_optimized_plan_eq(&plan, expected); +// Ok(()) +// } + +// #[test] +// fn two_distinct_and_groupby() -> Result<()> { +// let table_scan = test_table_scan()?; + +// let plan = LogicalPlanBuilder::from(table_scan) +// .aggregate( +// vec![col("a")], +// vec![count_distinct(col("b")), count_distinct(col("c"))], +// )? +// .build()?; + +// // Do nothing +// let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(DISTINCT #test.c)]] [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, COUNT(DISTINCT test.c):UInt64;N]\ +// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + +// assert_optimized_plan_eq(&plan, expected); +// Ok(()) +// } + +// #[test] +// fn one_field_two_distinct_and_groupby() -> Result<()> { +// let table_scan = test_table_scan()?; + +// let plan = LogicalPlanBuilder::from(table_scan) +// .aggregate( +// vec![col("a")], +// vec![ +// count_distinct(col("b")), +// Expr::AggregateFunction { +// fun: aggregates::AggregateFunction::Max, +// distinct: true, +// args: vec![col("b")], +// }, +// ], +// )? +// .build()?; +// // Should work +// let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b), #MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, MAX(DISTINCT test.b):UInt32;N]\ +// \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1), MAX(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N, MAX(alias1):UInt32;N]\ +// \n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ +// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + +// assert_optimized_plan_eq(&plan, expected); +// Ok(()) +// } + +// #[test] +// fn distinct_and_common() -> Result<()> { +// let table_scan = test_table_scan()?; + +// let plan = LogicalPlanBuilder::from(table_scan) +// .aggregate( +// vec![col("a")], +// vec![count_distinct(col("b")), count(col("c"))], +// )? +// .build()?; + +// // Do nothing +// let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(#test.c)]] [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, COUNT(test.c):UInt64;N]\ +// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + +// assert_optimized_plan_eq(&plan, expected); +// Ok(()) +// } +// } From 9e63810556a855578c3f71ce089c487ad3ed19b5 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Fri, 4 Feb 2022 23:34:15 -0800 Subject: [PATCH 07/14] try rewriting logical plan --- datafusion/src/optimizer/to_approx_perc.rs | 260 +++++--------- .../expressions/approx_median.rs | 236 +------------ .../expressions/approx_median_old.rs | 320 ++++++++++++++++++ .../expressions/approx_percentile_cont.rs | 6 +- 4 files changed, 413 insertions(+), 409 deletions(-) create mode 100644 datafusion/src/physical_plan/expressions/approx_median_old.rs diff --git a/datafusion/src/optimizer/to_approx_perc.rs b/datafusion/src/optimizer/to_approx_perc.rs index 39a9b30b102d6..e88efe86019f9 100644 --- a/datafusion/src/optimizer/to_approx_perc.rs +++ b/datafusion/src/optimizer/to_approx_perc.rs @@ -19,13 +19,12 @@ use crate::error::Result; use crate::execution::context::ExecutionProps; -use crate::logical_plan::plan::{Aggregate, Projection}; -use crate::logical_plan::{col, columnize_expr, DFSchema, Expr, LogicalPlan}; -use crate::physical_plan::aggregates; +use crate::logical_plan::plan::Aggregate; +use crate::logical_plan::{Expr, LogicalPlan}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; -use hashbrown::HashSet; -use std::sync::Arc; +use crate::physical_plan::aggregates; +use crate::scalar::ScalarValue; /// espression/function to approx_percentile optimizer rule /// ```text @@ -46,6 +45,12 @@ impl ToApproxPerc { } } +impl Default for ToApproxPerc { + fn default() -> Self { + Self::new() + } +} + fn optimize(plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Aggregate(Aggregate { @@ -55,25 +60,9 @@ fn optimize(plan: &LogicalPlan) -> Result { group_expr, }) => { let new_aggr_expr = aggr_expr - .iter() - .map(|agg_expr| match agg_expr { - Expr::AggregateFunction { fun, args, .. } => { - let mut new_args = args.clone(); - match fun { - aggregates::AggregateFunction::ApproxMedian => { - //new_args.push(lit(0.5_f64)); - Expr::AggregateFunction { - fun: aggregates::AggregateFunction::ApproxPercentileCont, - args: new_args, - distinct: false, - } - } - _ => agg_expr.clone(), - } - } - _ => agg_expr.clone(), - }) - .collect::>(); + .iter() + .map(|agg_expr| replace_with_percentile(agg_expr).unwrap()) + .collect::>(); Ok(LogicalPlan::Aggregate(Aggregate { input: input.clone(), @@ -82,7 +71,41 @@ fn optimize(plan: &LogicalPlan) -> Result { group_expr: group_expr.clone(), })) } - _ => Ok(plan.clone()) + _ => optimize_children(plan), + } +} + +fn optimize_children(plan: &LogicalPlan) -> Result { + let expr = plan.expressions(); + let inputs = plan.inputs(); + let new_inputs = inputs + .iter() + .map(|plan| optimize(plan)) + .collect::>>()?; + utils::from_plan(plan, &expr, &new_inputs) +} + +fn replace_with_percentile(expr: &Expr) -> Result { + match expr { + Expr::AggregateFunction { + fun, + args, + distinct, + } => { + let mut new_args = args.clone(); + let mut new_func = fun.clone(); + if fun == &aggregates::AggregateFunction::ApproxMedian { + new_args.push(Expr::Literal(ScalarValue::Float64(Some(0.5_f64)))); + new_func = aggregates::AggregateFunction::ApproxPercentileCont; + } + + Ok(Expr::AggregateFunction { + fun: new_func, + args: new_args, + distinct: *distinct, + }) + } + _ => Ok(expr.clone()), } } @@ -99,153 +122,40 @@ impl OptimizerRule for ToApproxPerc { } } -// #[cfg(test)] -// mod tests { -// use super::*; -// use crate::logical_plan::{col, count, count_distinct, lit, max, LogicalPlanBuilder}; -// use crate::physical_plan::aggregates; -// use crate::test::*; - -// fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { -// let rule = SingleDistinctToGroupBy::new(); -// let optimized_plan = rule -// .optimize(plan, &ExecutionProps::new()) -// .expect("failed to optimize plan"); -// let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); -// assert_eq!(formatted_plan, expected); -// } - -// #[test] -// fn not_exist_distinct() -> Result<()> { -// let table_scan = test_table_scan()?; - -// let plan = LogicalPlanBuilder::from(table_scan) -// .aggregate(Vec::::new(), vec![max(col("b"))])? -// .build()?; - -// // Do nothing -// let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]] [MAX(test.b):UInt32;N]\ -// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; - -// assert_optimized_plan_eq(&plan, expected); -// Ok(()) -// } - -// #[test] -// fn single_distinct() -> Result<()> { -// let table_scan = test_table_scan()?; - -// let plan = LogicalPlanBuilder::from(table_scan) -// .aggregate(Vec::::new(), vec![count_distinct(col("b"))])? -// .build()?; - -// // Should work -// let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):UInt64;N]\ -// \n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\ -// \n Aggregate: groupBy=[[#test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ -// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; - -// assert_optimized_plan_eq(&plan, expected); -// Ok(()) -// } - -// #[test] -// fn single_distinct_expr() -> Result<()> { -// let table_scan = test_table_scan()?; - -// let plan = LogicalPlanBuilder::from(table_scan) -// .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? -// .build()?; - -// let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):UInt64;N]\ -// \n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\ -// \n Aggregate: groupBy=[[Int32(2) * #test.b AS alias1]], aggr=[[]] [alias1:Int32]\ -// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; - -// assert_optimized_plan_eq(&plan, expected); -// Ok(()) -// } - -// #[test] -// fn single_distinct_and_groupby() -> Result<()> { -// let table_scan = test_table_scan()?; - -// let plan = LogicalPlanBuilder::from(table_scan) -// .aggregate(vec![col("a")], vec![count_distinct(col("b"))])? -// .build()?; - -// // Should work -// let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N]\ -// \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N]\ -// \n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ -// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; - -// assert_optimized_plan_eq(&plan, expected); -// Ok(()) -// } - -// #[test] -// fn two_distinct_and_groupby() -> Result<()> { -// let table_scan = test_table_scan()?; - -// let plan = LogicalPlanBuilder::from(table_scan) -// .aggregate( -// vec![col("a")], -// vec![count_distinct(col("b")), count_distinct(col("c"))], -// )? -// .build()?; - -// // Do nothing -// let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(DISTINCT #test.c)]] [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, COUNT(DISTINCT test.c):UInt64;N]\ -// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; - -// assert_optimized_plan_eq(&plan, expected); -// Ok(()) -// } - -// #[test] -// fn one_field_two_distinct_and_groupby() -> Result<()> { -// let table_scan = test_table_scan()?; - -// let plan = LogicalPlanBuilder::from(table_scan) -// .aggregate( -// vec![col("a")], -// vec![ -// count_distinct(col("b")), -// Expr::AggregateFunction { -// fun: aggregates::AggregateFunction::Max, -// distinct: true, -// args: vec![col("b")], -// }, -// ], -// )? -// .build()?; -// // Should work -// let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b), #MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, MAX(DISTINCT test.b):UInt32;N]\ -// \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1), MAX(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N, MAX(alias1):UInt32;N]\ -// \n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ -// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; - -// assert_optimized_plan_eq(&plan, expected); -// Ok(()) -// } - -// #[test] -// fn distinct_and_common() -> Result<()> { -// let table_scan = test_table_scan()?; - -// let plan = LogicalPlanBuilder::from(table_scan) -// .aggregate( -// vec![col("a")], -// vec![count_distinct(col("b")), count(col("c"))], -// )? -// .build()?; - -// // Do nothing -// let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(#test.c)]] [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, COUNT(test.c):UInt64;N]\ -// \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; - -// assert_optimized_plan_eq(&plan, expected); -// Ok(()) -// } -// } +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::{col, LogicalPlanBuilder}; + use crate::physical_plan::aggregates; + use crate::test::*; + + fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { + let rule = ToApproxPerc::new(); + let optimized_plan = rule + .optimize(plan, &ExecutionProps::new()) + .expect("failed to optimize plan"); + let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); + assert_eq!(formatted_plan, expected); + } + + #[test] + fn median_1() -> Result<()> { + let table_scan = test_table_scan()?; + let expr = Expr::AggregateFunction { + fun: aggregates::AggregateFunction::ApproxMedian, + distinct: false, + args: vec![col("b")], + }; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![expr])? + .build()?; + + // Do nothing + let expected = "Aggregate: groupBy=[[]], aggr=[[APPROXPERCENTILECONT(#test.b, Float64(0.5))]] [APPROXMEDIAN(test.b):UInt32;N]\ + \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/expressions/approx_median.rs b/datafusion/src/physical_plan/expressions/approx_median.rs index cfd29bca15c7e..fd8c542de912e 100644 --- a/datafusion/src/physical_plan/expressions/approx_median.rs +++ b/datafusion/src/physical_plan/expressions/approx_median.rs @@ -15,20 +15,14 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions for APPROX_MEDIAN that can be evaluated at runtime during query execution +//! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution use std::any::Any; use std::sync::Arc; use crate::error::Result; -use crate::physical_plan::{ - expressions::approx_percentile_cont::ApproxPercentileAccumulator, Accumulator, - AggregateExpr, PhysicalExpr, -}; -use crate::scalar::ScalarValue; -use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; - -use super::format_state_name; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use arrow::{datatypes::DataType, datatypes::Field}; /// MEDIAN aggregate expression #[derive(Debug)] @@ -80,44 +74,11 @@ impl AggregateExpr for ApproxMedian { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(ApproxMedianAccumulator::try_new( - self.data_type.clone(), - )?)) + unimplemented!() } fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - &format_state_name(&self.name, "max_size"), - DataType::UInt64, - false, - ), - Field::new( - &format_state_name(&self.name, "sum"), - DataType::Float64, - false, - ), - Field::new( - &format_state_name(&self.name, "count"), - DataType::Float64, - false, - ), - Field::new( - &format_state_name(&self.name, "max"), - DataType::Float64, - false, - ), - Field::new( - &format_state_name(&self.name, "min"), - DataType::Float64, - false, - ), - Field::new( - &format_state_name(&self.name, "centroids"), - DataType::List(Box::new(Field::new("item", DataType::Float64, true))), - false, - ), - ]) + unimplemented!() } fn expressions(&self) -> Vec> { @@ -128,190 +89,3 @@ impl AggregateExpr for ApproxMedian { &self.name } } - -/// An accumulator to compute the approx_median. -/// It is using approx_percentile_cont under the hood, which is an approximation. -/// We will revist this and may provide an implementation to calculate the exact approx_median in the future. -#[derive(Debug)] -pub struct ApproxMedianAccumulator { - perc_cont: ApproxPercentileAccumulator, -} - -impl ApproxMedianAccumulator { - /// Creates a new `ApproxMedianAccumulator` - pub fn try_new(data_type: DataType) -> Result { - Ok(Self { - perc_cont: ApproxPercentileAccumulator::new(0.5_f64, data_type), - }) - } -} - -impl Accumulator for ApproxMedianAccumulator { - fn state(&self) -> Result> { - Ok(self.perc_cont.get_digest().to_scalar_state()) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.perc_cont.update_batch(values) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.perc_cont.merge_batch(states) - } - - fn evaluate(&self) -> Result { - self.perc_cont.evaluate() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::from_slice::FromSlice; - use crate::physical_plan::expressions::col; - use crate::{error::Result, generic_test_op}; - use arrow::record_batch::RecordBatch; - use arrow::{array::*, datatypes::*}; - - #[test] - fn approx_median_f64_1() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64])); - generic_test_op!( - a, - DataType::Float64, - ApproxMedian, - ScalarValue::from(1.5_f64), - DataType::Float64 - ) - } - - #[test] - fn approx_median_f64_2() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1.1_f64, 2_f64, 3_f64])); - generic_test_op!( - a, - DataType::Float64, - ApproxMedian, - ScalarValue::from(2_f64), - DataType::Float64 - ) - } - - #[test] - fn approx_median_f64_3() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(&[ - 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, - ])); - generic_test_op!( - a, - DataType::Float64, - ApproxMedian, - ScalarValue::from(3_f64), - DataType::Float64 - ) - } - - #[test] - fn approx_median_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Int32, - ApproxMedian, - ScalarValue::from(3), - DataType::Int32 - ) - } - - #[test] - fn approx_median_u32() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from_slice(&[ - 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, - ])); - generic_test_op!( - a, - DataType::UInt32, - ApproxMedian, - ScalarValue::from(3_u32), - DataType::UInt32 - ) - } - - #[test] - fn approx_median_f32() -> Result<()> { - let a: ArrayRef = Arc::new(Float32Array::from_slice(&[ - 1_f32, 2_f32, 3_f32, 4_f32, 5_f32, - ])); - generic_test_op!( - a, - DataType::Float32, - ApproxMedian, - ScalarValue::from(3_f32), - DataType::Float32 - ) - } - - #[test] - fn approx_median_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!( - a, - DataType::Int32, - ApproxMedian, - ScalarValue::from(3), - DataType::Int32 - ) - } - - #[test] - fn approx_median_i32_with_nulls_2() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(5), - Some(1), - None, - None, - Some(3), - Some(4), - ])); - generic_test_op!( - a, - DataType::Int32, - ApproxMedian, - ScalarValue::from(2), - DataType::Int32 - ) - } - - #[test] - fn approx_median_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None])); - generic_test_op!( - a, - DataType::Int32, - ApproxMedian, - ScalarValue::from(0), - DataType::Int32 - ) - } - - fn aggregate( - batch: &RecordBatch, - agg: Arc, - ) -> Result { - let mut accum = agg.create_accumulator()?; - let expr = agg.expressions(); - let values = expr - .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>()?; - accum.update_batch(&values)?; - accum.evaluate() - } -} diff --git a/datafusion/src/physical_plan/expressions/approx_median_old.rs b/datafusion/src/physical_plan/expressions/approx_median_old.rs new file mode 100644 index 0000000000000..cb621c8a72463 --- /dev/null +++ b/datafusion/src/physical_plan/expressions/approx_median_old.rs @@ -0,0 +1,320 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::error::Result; +use crate::physical_plan::{ + expressions::approx_percentile_cont::ApproxPercentileAccumulator, Accumulator, + AggregateExpr, PhysicalExpr, +}; +use crate::scalar::ScalarValue; +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; + +use super::format_state_name; + +/// MEDIAN aggregate expression +#[derive(Debug)] +pub struct ApproxMedian { + name: String, + expr: Arc, + data_type: DataType, +} + +pub(crate) fn is_approx_median_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + +impl ApproxMedian { + /// Create a new APPROX_MEDIAN aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + Self { + name: name.into(), + expr, + data_type, + } + } +} + +impl AggregateExpr for ApproxMedian { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(ApproxMedianAccumulator::try_new( + self.data_type.clone(), + )?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "max_size"), + DataType::UInt64, + false, + ), + Field::new( + &format_state_name(&self.name, "sum"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "count"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "max"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "min"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "centroids"), + DataType::List(Box::new(Field::new("item", DataType::Float64, true))), + false, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +/// An accumulator to compute the approx_median. +/// It is using approx_percentile_cont under the hood, which is an approximation. +/// We will revist this and may provide an implementation to calculate the exact approx_median in the future. +#[derive(Debug)] +pub struct ApproxMedianAccumulator { + perc_cont: ApproxPercentileAccumulator, +} + +impl ApproxMedianAccumulator { + /// Creates a new `ApproxMedianAccumulator` + pub fn try_new(data_type: DataType) -> Result { + Ok(Self { + perc_cont: ApproxPercentileAccumulator::new(0.5_f64, data_type), + }) + } +} + +impl Accumulator for ApproxMedianAccumulator { + fn state(&self) -> Result> { + Ok(self.perc_cont.get_digest().to_scalar_state()) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.perc_cont.update_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.perc_cont.merge_batch(states) + } + + fn evaluate(&self) -> Result { + self.perc_cont.evaluate() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::from_slice::FromSlice; + use crate::physical_plan::expressions::col; + use crate::{error::Result, generic_test_op}; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + + #[test] + fn approx_median_f64_1() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64])); + generic_test_op!( + a, + DataType::Float64, + ApproxMedian, + ScalarValue::from(1.5_f64), + DataType::Float64 + ) + } + + #[test] + fn approx_median_f64_2() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1.1_f64, 2_f64, 3_f64])); + generic_test_op!( + a, + DataType::Float64, + ApproxMedian, + ScalarValue::from(2_f64), + DataType::Float64 + ) + } + + #[test] + fn approx_median_f64_3() -> Result<()> { + let a: ArrayRef = Arc::new(Float64Array::from_slice(&[ + 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, + ])); + generic_test_op!( + a, + DataType::Float64, + ApproxMedian, + ScalarValue::from(3_f64), + DataType::Float64 + ) + } + + #[test] + fn approx_median_i32() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5])); + generic_test_op!( + a, + DataType::Int32, + ApproxMedian, + ScalarValue::from(3), + DataType::Int32 + ) + } + + #[test] + fn approx_median_u32() -> Result<()> { + let a: ArrayRef = Arc::new(UInt32Array::from_slice(&[ + 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, + ])); + generic_test_op!( + a, + DataType::UInt32, + ApproxMedian, + ScalarValue::from(3_u32), + DataType::UInt32 + ) + } + + #[test] + fn approx_median_f32() -> Result<()> { + let a: ArrayRef = Arc::new(Float32Array::from_slice(&[ + 1_f32, 2_f32, 3_f32, 4_f32, 5_f32, + ])); + generic_test_op!( + a, + DataType::Float32, + ApproxMedian, + ScalarValue::from(3_f32), + DataType::Float32 + ) + } + + #[test] + fn approx_median_i32_with_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(3), + Some(4), + Some(5), + ])); + generic_test_op!( + a, + DataType::Int32, + ApproxMedian, + ScalarValue::from(3), + DataType::Int32 + ) + } + + #[test] + fn approx_median_i32_with_nulls_2() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(7), + Some(6), + Some(1), + None, + None, + None, + None, + Some(5), + Some(4), + ])); + generic_test_op!( + a, + DataType::Int32, + ApproxMedian, + ScalarValue::from(2), + DataType::Int32 + ) + } + + #[test] + fn approx_median_i32_all_nulls() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![None])); + generic_test_op!( + a, + DataType::Int32, + ApproxMedian, + ScalarValue::from(0), + DataType::Int32 + ) + } + + fn aggregate( + batch: &RecordBatch, + agg: Arc, + ) -> Result { + let mut accum = agg.create_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + accum.update_batch(&values)?; + accum.evaluate() + } +} \ No newline at end of file diff --git a/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs index 688c94e6afa5b..2776d48eb0d5d 100644 --- a/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs +++ b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs @@ -204,9 +204,9 @@ impl ApproxPercentileAccumulator { } } - pub(crate) fn get_digest(&self) -> &TDigest { - &self.digest - } + // pub(crate) fn get_digest(&self) -> &TDigest { + // &self.digest + // } } impl Accumulator for ApproxPercentileAccumulator { From 89fcd51539aa497d3bca4543dede00c9bc776874 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Sun, 6 Feb 2022 12:46:12 -0800 Subject: [PATCH 08/14] move rewrite rule to earlier stages --- datafusion/src/execution/context.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 3f65b8b6b9e5f..301611c50363e 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -924,12 +924,14 @@ impl Default for ExecutionConfig { // Simplify expressions first to maximize the chance // of applying other optimizations Arc::new(SimplifyExpressions::new()), + // Renaming functions to percentile early in case + // other optimizations can be applied later + Arc::new(ToApproxPerc::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), Arc::new(ProjectionPushDown::new()), Arc::new(FilterPushDown::new()), Arc::new(LimitPushDown::new()), - Arc::new(ToApproxPerc::new()), Arc::new(SingleDistinctToGroupBy::new()), ], physical_optimizers: vec![ From 19848547f15de65d7d4b2938dc6046776b222d27 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Sun, 6 Feb 2022 13:13:36 -0800 Subject: [PATCH 09/14] fix lint --- datafusion/src/execution/context.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 301611c50363e..e2395e59191b8 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -924,7 +924,7 @@ impl Default for ExecutionConfig { // Simplify expressions first to maximize the chance // of applying other optimizations Arc::new(SimplifyExpressions::new()), - // Renaming functions to percentile early in case + // Renaming functions to percentile early in case // other optimizations can be applied later Arc::new(ToApproxPerc::new()), Arc::new(CommonSubexprEliminate::new()), From b9425d86095f89c9a01fa5dce19e9f9aa0061033 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Sun, 6 Feb 2022 13:42:40 -0800 Subject: [PATCH 10/14] move the rule after projection push down --- datafusion/src/execution/context.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index e2395e59191b8..01e48b93b21d0 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -924,12 +924,10 @@ impl Default for ExecutionConfig { // Simplify expressions first to maximize the chance // of applying other optimizations Arc::new(SimplifyExpressions::new()), - // Renaming functions to percentile early in case - // other optimizations can be applied later - Arc::new(ToApproxPerc::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), Arc::new(ProjectionPushDown::new()), + Arc::new(ToApproxPerc::new()), Arc::new(FilterPushDown::new()), Arc::new(LimitPushDown::new()), Arc::new(SingleDistinctToGroupBy::new()), From 7ddf81b5140294c23d4405a8e4935eb6445f0396 Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Mon, 7 Feb 2022 23:24:39 -0800 Subject: [PATCH 11/14] get ready to merge --- datafusion/src/execution/context.rs | 5 +- .../expressions/approx_median_old.rs | 320 ------------------ 2 files changed, 4 insertions(+), 321 deletions(-) delete mode 100644 datafusion/src/physical_plan/expressions/approx_median_old.rs diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 01e48b93b21d0..21bf59eaff8c0 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -927,10 +927,13 @@ impl Default for ExecutionConfig { Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), Arc::new(ProjectionPushDown::new()), - Arc::new(ToApproxPerc::new()), Arc::new(FilterPushDown::new()), Arc::new(LimitPushDown::new()), Arc::new(SingleDistinctToGroupBy::new()), + // ToApproxPerc must be applied last because + // it rewrites only the function and may interfere with + // other rules + Arc::new(ToApproxPerc::new()), ], physical_optimizers: vec![ Arc::new(AggregateStatistics::new()), diff --git a/datafusion/src/physical_plan/expressions/approx_median_old.rs b/datafusion/src/physical_plan/expressions/approx_median_old.rs deleted file mode 100644 index cb621c8a72463..0000000000000 --- a/datafusion/src/physical_plan/expressions/approx_median_old.rs +++ /dev/null @@ -1,320 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution - -use std::any::Any; -use std::sync::Arc; - -use crate::error::Result; -use crate::physical_plan::{ - expressions::approx_percentile_cont::ApproxPercentileAccumulator, Accumulator, - AggregateExpr, PhysicalExpr, -}; -use crate::scalar::ScalarValue; -use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; - -use super::format_state_name; - -/// MEDIAN aggregate expression -#[derive(Debug)] -pub struct ApproxMedian { - name: String, - expr: Arc, - data_type: DataType, -} - -pub(crate) fn is_approx_median_support_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - ) -} - -impl ApproxMedian { - /// Create a new APPROX_MEDIAN aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - } - } -} - -impl AggregateExpr for ApproxMedian { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(ApproxMedianAccumulator::try_new( - self.data_type.clone(), - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - &format_state_name(&self.name, "max_size"), - DataType::UInt64, - false, - ), - Field::new( - &format_state_name(&self.name, "sum"), - DataType::Float64, - false, - ), - Field::new( - &format_state_name(&self.name, "count"), - DataType::Float64, - false, - ), - Field::new( - &format_state_name(&self.name, "max"), - DataType::Float64, - false, - ), - Field::new( - &format_state_name(&self.name, "min"), - DataType::Float64, - false, - ), - Field::new( - &format_state_name(&self.name, "centroids"), - DataType::List(Box::new(Field::new("item", DataType::Float64, true))), - false, - ), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -/// An accumulator to compute the approx_median. -/// It is using approx_percentile_cont under the hood, which is an approximation. -/// We will revist this and may provide an implementation to calculate the exact approx_median in the future. -#[derive(Debug)] -pub struct ApproxMedianAccumulator { - perc_cont: ApproxPercentileAccumulator, -} - -impl ApproxMedianAccumulator { - /// Creates a new `ApproxMedianAccumulator` - pub fn try_new(data_type: DataType) -> Result { - Ok(Self { - perc_cont: ApproxPercentileAccumulator::new(0.5_f64, data_type), - }) - } -} - -impl Accumulator for ApproxMedianAccumulator { - fn state(&self) -> Result> { - Ok(self.perc_cont.get_digest().to_scalar_state()) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.perc_cont.update_batch(values) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.perc_cont.merge_batch(states) - } - - fn evaluate(&self) -> Result { - self.perc_cont.evaluate() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::from_slice::FromSlice; - use crate::physical_plan::expressions::col; - use crate::{error::Result, generic_test_op}; - use arrow::record_batch::RecordBatch; - use arrow::{array::*, datatypes::*}; - - #[test] - fn approx_median_f64_1() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1_f64, 2_f64])); - generic_test_op!( - a, - DataType::Float64, - ApproxMedian, - ScalarValue::from(1.5_f64), - DataType::Float64 - ) - } - - #[test] - fn approx_median_f64_2() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(&[1.1_f64, 2_f64, 3_f64])); - generic_test_op!( - a, - DataType::Float64, - ApproxMedian, - ScalarValue::from(2_f64), - DataType::Float64 - ) - } - - #[test] - fn approx_median_f64_3() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from_slice(&[ - 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, - ])); - generic_test_op!( - a, - DataType::Float64, - ApproxMedian, - ScalarValue::from(3_f64), - DataType::Float64 - ) - } - - #[test] - fn approx_median_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Int32, - ApproxMedian, - ScalarValue::from(3), - DataType::Int32 - ) - } - - #[test] - fn approx_median_u32() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from_slice(&[ - 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, - ])); - generic_test_op!( - a, - DataType::UInt32, - ApproxMedian, - ScalarValue::from(3_u32), - DataType::UInt32 - ) - } - - #[test] - fn approx_median_f32() -> Result<()> { - let a: ArrayRef = Arc::new(Float32Array::from_slice(&[ - 1_f32, 2_f32, 3_f32, 4_f32, 5_f32, - ])); - generic_test_op!( - a, - DataType::Float32, - ApproxMedian, - ScalarValue::from(3_f32), - DataType::Float32 - ) - } - - #[test] - fn approx_median_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!( - a, - DataType::Int32, - ApproxMedian, - ScalarValue::from(3), - DataType::Int32 - ) - } - - #[test] - fn approx_median_i32_with_nulls_2() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(7), - Some(6), - Some(1), - None, - None, - None, - None, - Some(5), - Some(4), - ])); - generic_test_op!( - a, - DataType::Int32, - ApproxMedian, - ScalarValue::from(2), - DataType::Int32 - ) - } - - #[test] - fn approx_median_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None])); - generic_test_op!( - a, - DataType::Int32, - ApproxMedian, - ScalarValue::from(0), - DataType::Int32 - ) - } - - fn aggregate( - batch: &RecordBatch, - agg: Arc, - ) -> Result { - let mut accum = agg.create_accumulator()?; - let expr = agg.expressions(); - let values = expr - .iter() - .map(|e| e.evaluate(batch)) - .map(|r| r.map(|v| v.into_array(batch.num_rows()))) - .collect::>>()?; - accum.update_batch(&values)?; - accum.evaluate() - } -} \ No newline at end of file From 5c514fbe5beec777ed814d4f031529a43608302b Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Mon, 7 Feb 2022 23:45:14 -0800 Subject: [PATCH 12/14] remove unused function --- .../coercion_rule/aggregate_rule.rs | 9 ++++----- .../physical_plan/expressions/approx_median.rs | 16 ---------------- datafusion/src/physical_plan/expressions/mod.rs | 2 +- 3 files changed, 5 insertions(+), 22 deletions(-) diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index 482f61f064840..47d406579241b 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -20,10 +20,9 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::{ - is_approx_median_support_arg_type, is_avg_support_arg_type, - is_correlation_support_arg_type, is_covariance_support_arg_type, - is_stddev_support_arg_type, is_sum_support_arg_type, is_variance_support_arg_type, - try_cast, + is_avg_support_arg_type, is_correlation_support_arg_type, + is_covariance_support_arg_type, is_stddev_support_arg_type, is_sum_support_arg_type, + is_variance_support_arg_type, try_cast, }; use crate::physical_plan::functions::{Signature, TypeSignature}; use crate::physical_plan::PhysicalExpr; @@ -156,7 +155,7 @@ pub(crate) fn coerce_types( Ok(input_types.to_vec()) } AggregateFunction::ApproxMedian => { - if !is_approx_median_support_arg_type(&input_types[0]) { + if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { return Err(DataFusionError::Plan(format!( "The function {:?} does not support inputs of type {:?}.", agg_fun, input_types[0] diff --git a/datafusion/src/physical_plan/expressions/approx_median.rs b/datafusion/src/physical_plan/expressions/approx_median.rs index fd8c542de912e..2ca585759c6b4 100644 --- a/datafusion/src/physical_plan/expressions/approx_median.rs +++ b/datafusion/src/physical_plan/expressions/approx_median.rs @@ -32,22 +32,6 @@ pub struct ApproxMedian { data_type: DataType, } -pub(crate) fn is_approx_median_support_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - ) -} - impl ApproxMedian { /// Create a new APPROX_MEDIAN aggregate function pub fn new( diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 0f83cf8c32232..06afe004ff344 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -66,7 +66,7 @@ pub mod helpers { } pub use approx_distinct::ApproxDistinct; -pub(crate) use approx_median::{is_approx_median_support_arg_type, ApproxMedian}; +pub(crate) use approx_median::ApproxMedian; pub use approx_percentile_cont::{ is_approx_percentile_cont_supported_arg_type, ApproxPercentileCont, }; From edf3495914c7369bca2bd3840137d96c0804e84b Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Tue, 8 Feb 2022 17:09:03 -0800 Subject: [PATCH 13/14] remove commented out code --- .../src/physical_plan/expressions/approx_percentile_cont.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs index 2776d48eb0d5d..cba30ee481abc 100644 --- a/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs +++ b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs @@ -203,10 +203,6 @@ impl ApproxPercentileAccumulator { return_type, } } - - // pub(crate) fn get_digest(&self) -> &TDigest { - // &self.digest - // } } impl Accumulator for ApproxPercentileAccumulator { From f48db261c301027cd3d5022e4eaa0a5869c16dcf Mon Sep 17 00:00:00 2001 From: Lin Ma Date: Tue, 8 Feb 2022 17:10:10 -0800 Subject: [PATCH 14/14] Update datafusion/src/optimizer/to_approx_perc.rs Co-authored-by: Andrew Lamb --- datafusion/src/optimizer/to_approx_perc.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/src/optimizer/to_approx_perc.rs b/datafusion/src/optimizer/to_approx_perc.rs index e88efe86019f9..c33c3f67602a1 100644 --- a/datafusion/src/optimizer/to_approx_perc.rs +++ b/datafusion/src/optimizer/to_approx_perc.rs @@ -151,7 +151,7 @@ mod tests { .aggregate(Vec::::new(), vec![expr])? .build()?; - // Do nothing + // Rewrite to use approx_percentile let expected = "Aggregate: groupBy=[[]], aggr=[[APPROXPERCENTILECONT(#test.b, Float64(0.5))]] [APPROXMEDIAN(test.b):UInt32;N]\ \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";