From 9a419b7056e743f5c67e211242b52d8bebb90742 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Fri, 26 Jul 2024 17:28:40 +0800 Subject: [PATCH 1/4] UDAF input types --- .../tests/user_defined/user_defined_plan.rs | 8 +++--- datafusion/expr/src/function.rs | 8 +++--- .../src/approx_distinct.rs | 2 +- .../functions-aggregate/src/approx_median.rs | 2 +- .../src/approx_percentile_cont.rs | 2 +- .../functions-aggregate/src/array_agg.rs | 12 ++++---- datafusion/functions-aggregate/src/average.rs | 16 +++++------ datafusion/functions-aggregate/src/count.rs | 4 +-- .../functions-aggregate/src/first_last.rs | 2 +- datafusion/functions-aggregate/src/median.rs | 4 +-- .../functions-aggregate/src/nth_value.rs | 4 +-- datafusion/functions-aggregate/src/stddev.rs | 4 +-- .../physical-expr-common/src/aggregate/mod.rs | 28 ++++++++++--------- 13 files changed, 50 insertions(+), 46 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index a44f522ba95ac..47804b927e641 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -68,6 +68,10 @@ use arrow::{ record_batch::RecordBatch, util::pretty::pretty_format_batches, }; +use async_trait::async_trait; +use futures::{Stream, StreamExt}; + +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ common::cast::{as_int64_array, as_string_array}, common::{arrow_datafusion_err, internal_err, DFSchemaRef}, @@ -90,16 +94,12 @@ use datafusion::{ physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, prelude::{SessionConfig, SessionContext}, }; - -use async_trait::async_trait; -use datafusion::execution::session_state::SessionStateBuilder; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; use datafusion_expr::Projection; use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; -use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches /// pretty printed as a String. diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index d722e55de487c..7333de705aadc 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -94,8 +94,8 @@ pub struct AccumulatorArgs<'a> { /// ``` pub is_distinct: bool, - /// The input type of the aggregate function. - pub input_type: &'a DataType, + /// The input types of the aggregate function. + pub input_type: &'a [DataType], /// The logical expression of arguments the aggregate function takes. pub input_exprs: &'a [Expr], @@ -109,8 +109,8 @@ pub struct StateFieldsArgs<'a> { /// The name of the aggregate function. pub name: &'a str, - /// The input type of the aggregate function. - pub input_type: &'a DataType, + /// The input types of the aggregate function. + pub input_type: &'a [DataType], /// The return type of the aggregate function. pub return_type: &'a DataType, diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 7c6aef9944f69..9a90822e4f0bd 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -277,7 +277,7 @@ impl AggregateUDFImpl for ApproxDistinct { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let accumulator: Box = match acc_args.input_type { + let accumulator: Box = match &acc_args.input_type[0] { // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL // TODO support for boolean (trivial case) // https://github.com/apache/datafusion/issues/1109 diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index bc723c8629539..7e790e354c9d3 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -113,7 +113,7 @@ impl AggregateUDFImpl for ApproxMedian { Ok(Box::new(ApproxPercentileAccumulator::new( 0.5_f64, - acc_args.input_type.clone(), + acc_args.input_type[0].clone(), ))) } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index dfb94a84cbecc..deb9ed9704115 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -104,7 +104,7 @@ impl ApproxPercentileCont { None }; - let accumulator: ApproxPercentileAccumulator = match args.input_type { + let accumulator: ApproxPercentileAccumulator = match &args.input_type[0] { t @ (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 96b39ae4121eb..1dc483085e0b6 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -89,14 +89,14 @@ impl AggregateUDFImpl for ArrayAgg { if args.is_distinct { return Ok(vec![Field::new_list( format_state_name(args.name, "distinct_array_agg"), - Field::new("item", args.input_type.clone(), true), + Field::new("item", args.input_type[0].clone(), true), true, )]); } let mut fields = vec![Field::new_list( format_state_name(args.name, "array_agg"), - Field::new("item", args.input_type.clone(), true), + Field::new("item", args.input_type[0].clone(), true), true, )]; @@ -117,12 +117,14 @@ impl AggregateUDFImpl for ArrayAgg { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if acc_args.is_distinct { return Ok(Box::new(DistinctArrayAggAccumulator::try_new( - acc_args.input_type, + &acc_args.input_type[0], )?)); } if acc_args.sort_exprs.is_empty() { - return Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?)); + return Ok(Box::new(ArrayAggAccumulator::try_new( + &acc_args.input_type[0], + )?)); } let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( @@ -136,7 +138,7 @@ impl AggregateUDFImpl for ArrayAgg { .collect::>>()?; OrderSensitiveArrayAggAccumulator::try_new( - acc_args.input_type, + &acc_args.input_type[0], &ordering_dtypes, ordering_req, acc_args.is_reversed, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 18642fb843293..310c3e6261d61 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -93,7 +93,7 @@ impl AggregateUDFImpl for Avg { } use DataType::*; // instantiate specialized accumulator based for the type - match (acc_args.input_type, acc_args.data_type) { + match (&acc_args.input_type[0], acc_args.data_type) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -120,7 +120,7 @@ impl AggregateUDFImpl for Avg { })), _ => exec_err!( "AvgAccumulator for ({} --> {})", - acc_args.input_type, + &acc_args.input_type[0], acc_args.data_type ), } @@ -135,7 +135,7 @@ impl AggregateUDFImpl for Avg { ), Field::new( format_state_name(args.name, "sum"), - args.input_type.clone(), + args.input_type[0].clone(), true, ), ]) @@ -154,10 +154,10 @@ impl AggregateUDFImpl for Avg { ) -> Result> { use DataType::*; // instantiate specialized accumulator based for the type - match (args.input_type, args.data_type) { + match (&args.input_type[0], args.data_type) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( - args.input_type, + &args.input_type[0], args.data_type, |sum: f64, count: u64| Ok(sum / count as f64), ))) @@ -176,7 +176,7 @@ impl AggregateUDFImpl for Avg { move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); Ok(Box::new(AvgGroupsAccumulator::::new( - args.input_type, + &args.input_type[0], args.data_type, avg_fn, ))) @@ -197,7 +197,7 @@ impl AggregateUDFImpl for Avg { }; Ok(Box::new(AvgGroupsAccumulator::::new( - args.input_type, + &args.input_type[0], args.data_type, avg_fn, ))) @@ -205,7 +205,7 @@ impl AggregateUDFImpl for Avg { _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", - args.input_type, + &args.input_type[0], args.data_type ), } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 0ead22e90a163..b475cc24035b6 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -125,7 +125,7 @@ impl AggregateUDFImpl for Count { if args.is_distinct { Ok(vec![Field::new_list( format_state_name(args.name, "count distinct"), - Field::new("item", args.input_type.clone(), true), + Field::new("item", args.input_type[0].clone(), true), false, )]) } else { @@ -146,7 +146,7 @@ impl AggregateUDFImpl for Count { return not_impl_err!("COUNT DISTINCT with multiple arguments"); } - let data_type = acc_args.input_type; + let data_type = &acc_args.input_type[0]; Ok(match data_type { // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator DataType::Int8 => Box::new( diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 8969937d377c4..f4f845f2b8497 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -447,7 +447,7 @@ impl AggregateUDFImpl for LastValue { } = args; let mut fields = vec![Field::new( format_state_name(name, "last_value"), - input_type.clone(), + input_type[0].clone(), true, )]; fields.extend(ordering_fields.to_vec()); diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index bb926b8da2712..956993eeb72b0 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -102,7 +102,7 @@ impl AggregateUDFImpl for Median { fn state_fields(&self, args: StateFieldsArgs) -> Result> { //Intermediate state is a list of the elements we have collected so far - let field = Field::new("item", args.input_type.clone(), true); + let field = Field::new("item", args.input_type[0].clone(), true); let state_name = if args.is_distinct { "distinct_median" } else { @@ -133,7 +133,7 @@ impl AggregateUDFImpl for Median { }; } - let dt = acc_args.input_type; + let dt = &acc_args.input_type[0]; downcast_integer! { dt => (helper, dt), DataType::Float16 => helper!(Float16Type, dt), diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 9bbd68c9bdf60..49c3a15d765c4 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -124,7 +124,7 @@ impl AggregateUDFImpl for NthValueAgg { NthValueAccumulator::try_new( n, - acc_args.input_type, + &acc_args.input_type[0], &ordering_dtypes, ordering_req, ) @@ -138,7 +138,7 @@ impl AggregateUDFImpl for NthValueAgg { // The hard-coded `true` should be changed once the field for // nullability is added to `StateFieldArgs` struct. // See: https://github.com/apache/datafusion/pull/11063 - Field::new("item", args.input_type.clone(), true), + Field::new("item", args.input_type[0].clone(), true), false, )]; let orderings = args.ordering_fields.to_vec(); diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 247962dc2ce11..5989845457400 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -335,7 +335,7 @@ mod tests { name: "a", is_distinct: false, is_reversed: false, - input_type: &DataType::Float64, + input_type: &[DataType::Float64], input_exprs: &[datafusion_expr::col("a")], }; @@ -348,7 +348,7 @@ mod tests { name: "a", is_distinct: false, is_reversed: false, - input_type: &DataType::Float64, + input_type: &[DataType::Float64], input_exprs: &[datafusion_expr::col("a")], }; diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index b58a5a6faf242..0d9587b8ab6ea 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -15,31 +15,33 @@ // specific language governing permissions and limitations // under the License. -pub mod count_distinct; -pub mod groups_accumulator; -pub mod merge_arrays; -pub mod stats; -pub mod tdigest; -pub mod utils; +use std::fmt::Debug; +use std::{any::Any, sync::Arc}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + +use datafusion_common::exec_err; use datafusion_common::{internal_err, not_impl_err, DFSchema, Result}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::type_coercion::aggregates::check_arg_count; +use datafusion_expr::utils::AggregateOrderSensitivity; use datafusion_expr::ReversedUDAF; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, }; -use std::fmt::Debug; -use std::{any::Any, sync::Arc}; -use self::utils::down_cast_any_ref; use crate::physical_expr::PhysicalExpr; use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; use crate::utils::reverse_order_bys; -use datafusion_common::exec_err; -use datafusion_expr::utils::AggregateOrderSensitivity; +use self::utils::down_cast_any_ref; + +pub mod count_distinct; +pub mod groups_accumulator; +pub mod merge_arrays; +pub mod stats; +pub mod tdigest; +pub mod utils; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. /// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. @@ -225,7 +227,7 @@ impl AggregateExprBuilder { ignore_nulls, ordering_fields, is_distinct, - input_type: input_exprs_types[0].clone(), + input_type: input_exprs_types, is_reversed, })) } @@ -466,7 +468,7 @@ pub struct AggregateFunctionExpr { ordering_fields: Vec, is_distinct: bool, is_reversed: bool, - input_type: DataType, + input_type: Vec, } impl AggregateFunctionExpr { From cfde493ef43137cbc613baa7d33aa93d23fd7a9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Fri, 26 Jul 2024 17:38:11 +0800 Subject: [PATCH 2/4] Rename --- datafusion/expr/src/function.rs | 4 ++-- .../functions-aggregate/src/approx_distinct.rs | 2 +- .../functions-aggregate/src/approx_median.rs | 2 +- .../src/approx_percentile_cont.rs | 2 +- datafusion/functions-aggregate/src/array_agg.rs | 10 +++++----- datafusion/functions-aggregate/src/average.rs | 16 ++++++++-------- datafusion/functions-aggregate/src/count.rs | 4 ++-- datafusion/functions-aggregate/src/first_last.rs | 4 ++-- datafusion/functions-aggregate/src/median.rs | 4 ++-- datafusion/functions-aggregate/src/nth_value.rs | 4 ++-- datafusion/functions-aggregate/src/stddev.rs | 4 ++-- .../physical-expr-common/src/aggregate/mod.rs | 14 +++++++------- 12 files changed, 35 insertions(+), 35 deletions(-) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 7333de705aadc..d8be2b4347323 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -95,7 +95,7 @@ pub struct AccumulatorArgs<'a> { pub is_distinct: bool, /// The input types of the aggregate function. - pub input_type: &'a [DataType], + pub input_types: &'a [DataType], /// The logical expression of arguments the aggregate function takes. pub input_exprs: &'a [Expr], @@ -110,7 +110,7 @@ pub struct StateFieldsArgs<'a> { pub name: &'a str, /// The input types of the aggregate function. - pub input_type: &'a [DataType], + pub input_types: &'a [DataType], /// The return type of the aggregate function. pub return_type: &'a DataType, diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 9a90822e4f0bd..56ef32e7ebe07 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -277,7 +277,7 @@ impl AggregateUDFImpl for ApproxDistinct { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let accumulator: Box = match &acc_args.input_type[0] { + let accumulator: Box = match &acc_args.input_types[0] { // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL // TODO support for boolean (trivial case) // https://github.com/apache/datafusion/issues/1109 diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 7e790e354c9d3..e12e3445a83ed 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -113,7 +113,7 @@ impl AggregateUDFImpl for ApproxMedian { Ok(Box::new(ApproxPercentileAccumulator::new( 0.5_f64, - acc_args.input_type[0].clone(), + acc_args.input_types[0].clone(), ))) } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index deb9ed9704115..16837dc80748c 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -104,7 +104,7 @@ impl ApproxPercentileCont { None }; - let accumulator: ApproxPercentileAccumulator = match &args.input_type[0] { + let accumulator: ApproxPercentileAccumulator = match &args.input_types[0] { t @ (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 1dc483085e0b6..7352ea12fe54a 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -89,14 +89,14 @@ impl AggregateUDFImpl for ArrayAgg { if args.is_distinct { return Ok(vec![Field::new_list( format_state_name(args.name, "distinct_array_agg"), - Field::new("item", args.input_type[0].clone(), true), + Field::new("item", args.input_types[0].clone(), true), true, )]); } let mut fields = vec![Field::new_list( format_state_name(args.name, "array_agg"), - Field::new("item", args.input_type[0].clone(), true), + Field::new("item", args.input_types[0].clone(), true), true, )]; @@ -117,13 +117,13 @@ impl AggregateUDFImpl for ArrayAgg { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if acc_args.is_distinct { return Ok(Box::new(DistinctArrayAggAccumulator::try_new( - &acc_args.input_type[0], + &acc_args.input_types[0], )?)); } if acc_args.sort_exprs.is_empty() { return Ok(Box::new(ArrayAggAccumulator::try_new( - &acc_args.input_type[0], + &acc_args.input_types[0], )?)); } @@ -138,7 +138,7 @@ impl AggregateUDFImpl for ArrayAgg { .collect::>>()?; OrderSensitiveArrayAggAccumulator::try_new( - &acc_args.input_type[0], + &acc_args.input_types[0], &ordering_dtypes, ordering_req, acc_args.is_reversed, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 310c3e6261d61..228bce1979a38 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -93,7 +93,7 @@ impl AggregateUDFImpl for Avg { } use DataType::*; // instantiate specialized accumulator based for the type - match (&acc_args.input_type[0], acc_args.data_type) { + match (&acc_args.input_types[0], acc_args.data_type) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -120,7 +120,7 @@ impl AggregateUDFImpl for Avg { })), _ => exec_err!( "AvgAccumulator for ({} --> {})", - &acc_args.input_type[0], + &acc_args.input_types[0], acc_args.data_type ), } @@ -135,7 +135,7 @@ impl AggregateUDFImpl for Avg { ), Field::new( format_state_name(args.name, "sum"), - args.input_type[0].clone(), + args.input_types[0].clone(), true, ), ]) @@ -154,10 +154,10 @@ impl AggregateUDFImpl for Avg { ) -> Result> { use DataType::*; // instantiate specialized accumulator based for the type - match (&args.input_type[0], args.data_type) { + match (&args.input_types[0], args.data_type) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( - &args.input_type[0], + &args.input_types[0], args.data_type, |sum: f64, count: u64| Ok(sum / count as f64), ))) @@ -176,7 +176,7 @@ impl AggregateUDFImpl for Avg { move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); Ok(Box::new(AvgGroupsAccumulator::::new( - &args.input_type[0], + &args.input_types[0], args.data_type, avg_fn, ))) @@ -197,7 +197,7 @@ impl AggregateUDFImpl for Avg { }; Ok(Box::new(AvgGroupsAccumulator::::new( - &args.input_type[0], + &args.input_types[0], args.data_type, avg_fn, ))) @@ -205,7 +205,7 @@ impl AggregateUDFImpl for Avg { _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", - &args.input_type[0], + &args.input_types[0], args.data_type ), } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index b475cc24035b6..206e9c33db8b1 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -125,7 +125,7 @@ impl AggregateUDFImpl for Count { if args.is_distinct { Ok(vec![Field::new_list( format_state_name(args.name, "count distinct"), - Field::new("item", args.input_type[0].clone(), true), + Field::new("item", args.input_types[0].clone(), true), false, )]) } else { @@ -146,7 +146,7 @@ impl AggregateUDFImpl for Count { return not_impl_err!("COUNT DISTINCT with multiple arguments"); } - let data_type = &acc_args.input_type[0]; + let data_type = &acc_args.input_types[0]; Ok(match data_type { // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator DataType::Int8 => Box::new( diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index f4f845f2b8497..587767b8e356a 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -440,14 +440,14 @@ impl AggregateUDFImpl for LastValue { fn state_fields(&self, args: StateFieldsArgs) -> Result> { let StateFieldsArgs { name, - input_type, + input_types, return_type: _, ordering_fields, is_distinct: _, } = args; let mut fields = vec![Field::new( format_state_name(name, "last_value"), - input_type[0].clone(), + input_types[0].clone(), true, )]; fields.extend(ordering_fields.to_vec()); diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 956993eeb72b0..febf1fcd2fefb 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -102,7 +102,7 @@ impl AggregateUDFImpl for Median { fn state_fields(&self, args: StateFieldsArgs) -> Result> { //Intermediate state is a list of the elements we have collected so far - let field = Field::new("item", args.input_type[0].clone(), true); + let field = Field::new("item", args.input_types[0].clone(), true); let state_name = if args.is_distinct { "distinct_median" } else { @@ -133,7 +133,7 @@ impl AggregateUDFImpl for Median { }; } - let dt = &acc_args.input_type[0]; + let dt = &acc_args.input_types[0]; downcast_integer! { dt => (helper, dt), DataType::Float16 => helper!(Float16Type, dt), diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 49c3a15d765c4..1473aef65fb22 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -124,7 +124,7 @@ impl AggregateUDFImpl for NthValueAgg { NthValueAccumulator::try_new( n, - &acc_args.input_type[0], + &acc_args.input_types[0], &ordering_dtypes, ordering_req, ) @@ -138,7 +138,7 @@ impl AggregateUDFImpl for NthValueAgg { // The hard-coded `true` should be changed once the field for // nullability is added to `StateFieldArgs` struct. // See: https://github.com/apache/datafusion/pull/11063 - Field::new("item", args.input_type[0].clone(), true), + Field::new("item", args.input_types[0].clone(), true), false, )]; let orderings = args.ordering_fields.to_vec(); diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 5989845457400..df757ddc04226 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -335,7 +335,7 @@ mod tests { name: "a", is_distinct: false, is_reversed: false, - input_type: &[DataType::Float64], + input_types: &[DataType::Float64], input_exprs: &[datafusion_expr::col("a")], }; @@ -348,7 +348,7 @@ mod tests { name: "a", is_distinct: false, is_reversed: false, - input_type: &[DataType::Float64], + input_types: &[DataType::Float64], input_exprs: &[datafusion_expr::col("a")], }; diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 0d9587b8ab6ea..0a11e6b451eac 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -227,7 +227,7 @@ impl AggregateExprBuilder { ignore_nulls, ordering_fields, is_distinct, - input_type: input_exprs_types, + input_types: input_exprs_types, is_reversed, })) } @@ -468,7 +468,7 @@ pub struct AggregateFunctionExpr { ordering_fields: Vec, is_distinct: bool, is_reversed: bool, - input_type: Vec, + input_types: Vec, } impl AggregateFunctionExpr { @@ -506,7 +506,7 @@ impl AggregateExpr for AggregateFunctionExpr { fn state_fields(&self) -> Result> { let args = StateFieldsArgs { name: &self.name, - input_type: &self.input_type, + input_types: &self.input_types, return_type: &self.data_type, ordering_fields: &self.ordering_fields, is_distinct: self.is_distinct, @@ -527,7 +527,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_type: &self.input_type, + input_types: &self.input_types, input_exprs: &self.logical_args, name: &self.name, is_reversed: self.is_reversed, @@ -544,7 +544,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_type: &self.input_type, + input_types: &self.input_types, input_exprs: &self.logical_args, name: &self.name, is_reversed: self.is_reversed, @@ -616,7 +616,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_type: &self.input_type, + input_types: &self.input_types, input_exprs: &self.logical_args, name: &self.name, is_reversed: self.is_reversed, @@ -632,7 +632,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_type: &self.input_type, + input_types: &self.input_types, input_exprs: &self.logical_args, name: &self.name, is_reversed: self.is_reversed, From 1a3c5ca7b0162efebe9bcc923c9d03307e41e445 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 27 Jul 2024 07:27:52 -0400 Subject: [PATCH 3/4] Update COMMENTS.md --- datafusion/functions-aggregate/COMMENTS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/COMMENTS.md b/datafusion/functions-aggregate/COMMENTS.md index 23a996faf0075..3816cc1827113 100644 --- a/datafusion/functions-aggregate/COMMENTS.md +++ b/datafusion/functions-aggregate/COMMENTS.md @@ -54,7 +54,7 @@ first argument and the definition looks like this: // `input_type` : data type of the first argument let mut fields = vec![Field::new_list( format_state_name(self.name(), "nth_value"), - Field::new("item", args.input_type.clone(), true /* nullable of list item */ ), + Field::new("item", args.input_types.clone(), true /* nullable of list item */ ), false, // nullable of list itself )]; ``` From 23f8878b79bd33d0c55d5b4b56df99a5e49eb6b2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 27 Jul 2024 08:59:08 -0400 Subject: [PATCH 4/4] Update datafusion/functions-aggregate/COMMENTS.md --- datafusion/functions-aggregate/COMMENTS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/COMMENTS.md b/datafusion/functions-aggregate/COMMENTS.md index 3816cc1827113..e669e13557115 100644 --- a/datafusion/functions-aggregate/COMMENTS.md +++ b/datafusion/functions-aggregate/COMMENTS.md @@ -54,7 +54,7 @@ first argument and the definition looks like this: // `input_type` : data type of the first argument let mut fields = vec![Field::new_list( format_state_name(self.name(), "nth_value"), - Field::new("item", args.input_types.clone(), true /* nullable of list item */ ), + Field::new("item", args.input_types[0].clone(), true /* nullable of list item */ ), false, // nullable of list itself )]; ```