From 34a2ccbba07da5630ec837ce9a613731fde3d288 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Fri, 12 Sep 2025 17:06:19 +0900 Subject: [PATCH 1/2] Introduce `avg_distinct()` and `sum_distinct()` functions to DataFrame API --- datafusion/core/tests/dataframe/mod.rs | 44 +++++++++++-------- datafusion/functions-aggregate/src/average.rs | 13 +++++- datafusion/functions-aggregate/src/lib.rs | 2 + datafusion/functions-aggregate/src/sum.rs | 12 +++++ docs/source/user-guide/expressions.md | 4 +- 5 files changed, 54 insertions(+), 21 deletions(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index a563459f42a11..fa4131c089d7b 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -35,7 +35,8 @@ use arrow::util::pretty::pretty_format_batches; use datafusion::{assert_batches_eq, dataframe}; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ - array_agg, avg, count, count_distinct, max, median, min, sum, + array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum, + sum_distinct, }; use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::expr_fn::{first_value, row_number}; @@ -496,32 +497,35 @@ async fn drop_with_periods() -> Result<()> { #[tokio::test] async fn aggregate() -> Result<()> { // build plan using DataFrame API - let df = test_table().await?; + // union so some of the distincts have a clearly distinct result + let df = test_table().await?.union(test_table().await?)?; let group_expr = vec![col("c1")]; let aggr_expr = vec![ - min(col("c12")), - max(col("c12")), - avg(col("c12")), - sum(col("c12")), - count(col("c12")), - count_distinct(col("c12")), + min(col("c4")).alias("min(c4)"), + max(col("c4")).alias("max(c4)"), + avg(col("c4")).alias("avg(c4)"), + avg_distinct(col("c4")).alias("avg_distinct(c4)"), + sum(col("c4")).alias("sum(c4)"), + sum_distinct(col("c4")).alias("sum_distinct(c4)"), + count(col("c4")).alias("count(c4)"), + count_distinct(col("c4")).alias("count_distinct(c4)"), ]; let df: Vec = df.aggregate(group_expr, aggr_expr)?.collect().await?; assert_snapshot!( batches_to_sort_string(&df), - @r###" - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ - | c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) | - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ - | a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 | - | b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 | - | c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 | - | d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 | - | e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 | - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ - "### + @r" + +----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+ + | c1 | min(c4) | max(c4) | avg(c4) | avg_distinct(c4) | sum(c4) | sum_distinct(c4) | count(c4) | count_distinct(c4) | + +----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+ + | a | -28462 | 32064 | 306.04761904761904 | 306.04761904761904 | 12854 | 6427 | 42 | 21 | + | b | -28070 | 25286 | 7732.315789473684 | 7732.315789473684 | 293828 | 146914 | 38 | 19 | + | c | -30508 | 29106 | -1320.5238095238096 | -1320.5238095238096 | -55462 | -27731 | 42 | 21 | + | d | -24558 | 31106 | 10890.111111111111 | 10890.111111111111 | 392044 | 196022 | 36 | 18 | + | e | -31500 | 32514 | -4268.333333333333 | -4268.333333333333 | -179270 | -89635 | 42 | 21 | + +----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+ + " ); Ok(()) @@ -536,7 +540,9 @@ async fn aggregate_assert_no_empty_batches() -> Result<()> { min(col("c12")), max(col("c12")), avg(col("c12")), + avg_distinct(col("c12")), sum(col("c12")), + sum_distinct(col("c12")), count(col("c12")), count_distinct(col("c12")), median(col("c12")), diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index f7cb74fd55a25..8694a7fc0dbd8 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -36,7 +36,7 @@ use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_typ use datafusion_expr::utils::format_state_name; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, + Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator, ReversedUDAF, Signature, }; @@ -62,6 +62,17 @@ make_udaf_expr_and_func!( avg_udaf ); +pub fn avg_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + avg_udaf(), + vec![expr], + true, + None, + vec![], + None, + )) +} + #[user_doc( doc_section(label = "General Functions"), description = "Returns the average of numeric values in the specified column.", diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index b5bb69f6da9d8..8236d456fd929 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -105,6 +105,7 @@ pub mod expr_fn { pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; pub use super::array_agg::array_agg; pub use super::average::avg; + pub use super::average::avg_distinct; pub use super::bit_and_or_xor::bit_and; pub use super::bit_and_or_xor::bit_or; pub use super::bit_and_or_xor::bit_xor; @@ -134,6 +135,7 @@ pub mod expr_fn { pub use super::stddev::stddev; pub use super::stddev::stddev_pop; pub use super::sum::sum; + pub use super::sum::sum_distinct; pub use super::variance::var_pop; pub use super::variance::var_sample; } diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 445c7dfe6b7af..5d0459ef703f3 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -19,6 +19,7 @@ use ahash::RandomState; use datafusion_expr::utils::AggregateOrderSensitivity; +use datafusion_expr::Expr; use std::any::Any; use std::mem::size_of_val; @@ -53,6 +54,17 @@ make_udaf_expr_and_func!( sum_udaf ); +pub fn sum_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + sum_udaf(), + vec![expr], + true, + None, + vec![], + None, + )) +} + /// Sum only supports a subset of numeric types, instead relying on type coercion /// /// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive) diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index abf0286fa85bd..56e4369a9b8b5 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -288,6 +288,7 @@ select log(-1), log(0), sqrt(-1); | Syntax | Description | | ------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | | avg(expr) | Сalculates the average value for `expr`. | +| avg_distinct(expr) | Creates an expression to represent the avg(distinct) aggregate function | | approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | | approx_median(expr) | Calculates an approximation of the median for `expr`. | | approx_percentile_cont(expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr`. Optional `centroids` parameter controls accuracy (default: 100). | @@ -298,7 +299,7 @@ select log(-1), log(0), sqrt(-1); | bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. | | bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. | | count(expr) | Returns the number of rows for `expr`. | -| count_distinct | Creates an expression to represent the count(distinct) aggregate function | +| count_distinct(expr) | Creates an expression to represent the count(distinct) aggregate function | | cube(exprs) | Creates a grouping set for all combination of `exprs` | | grouping_set(exprs) | Create a grouping set. | | max(expr) | Finds the maximum value of `expr`. | @@ -306,6 +307,7 @@ select log(-1), log(0), sqrt(-1); | min(expr) | Finds the minimum value of `expr`. | | rollup(exprs) | Creates a grouping set for rollup sets. | | sum(expr) | Сalculates the sum of `expr`. | +| sum_distinct(expr) | Creates an expression to represent the sum(distinct) aggregate function | ## Aggregate Function Builder From 99386feb229574a42966f316f1b6e46e0fcdef00 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Sun, 14 Sep 2025 16:00:35 +0900 Subject: [PATCH 2/2] Add to roundtrip proto tests --- datafusion/proto/tests/cases/roundtrip_logical_plan.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index c76036a4344fb..1a50857bf0a41 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -32,6 +32,7 @@ use datafusion::execution::options::ArrowReadOptions; use datafusion::optimizer::eliminate_nested_union::EliminateNestedUnion; use datafusion::optimizer::Optimizer; use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_functions_aggregate::sum::sum_distinct; use prost::Message; use std::any::Any; use std::collections::HashMap; @@ -82,8 +83,8 @@ use datafusion_expr::{ }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::expr_fn::{ - approx_distinct, array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, - nth_value, + approx_distinct, array_agg, avg, avg_distinct, bit_and, bit_or, bit_xor, bool_and, + bool_or, corr, nth_value, }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -967,10 +968,12 @@ async fn roundtrip_expr_api() -> Result<()> { functions_window::nth_value::last_value(lit(1)), functions_window::nth_value::nth_value(lit(1), 1), avg(lit(1.5)), + avg_distinct(lit(1.5)), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), corr(lit(1.5), lit(2.2)), sum(lit(1)), + sum_distinct(lit(1)), max(lit(1)), median(lit(2)), min(lit(2)),