Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 25 additions & 19 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ use arrow::util::pretty::pretty_format_batches;
use datafusion::{assert_batches_eq, dataframe};
use datafusion_functions_aggregate::count::{count_all, count_all_window};
use datafusion_functions_aggregate::expr_fn::{
array_agg, avg, count, count_distinct, max, median, min, sum,
array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum,
sum_distinct,
};
use datafusion_functions_nested::make_array::make_array_udf;
use datafusion_functions_window::expr_fn::{first_value, row_number};
Expand Down Expand Up @@ -496,32 +497,35 @@ async fn drop_with_periods() -> Result<()> {
#[tokio::test]
async fn aggregate() -> Result<()> {
// build plan using DataFrame API
let df = test_table().await?;
// union so some of the distincts have a clearly distinct result
let df = test_table().await?.union(test_table().await?)?;
let group_expr = vec![col("c1")];
let aggr_expr = vec![
min(col("c12")),
max(col("c12")),
avg(col("c12")),
sum(col("c12")),
count(col("c12")),
count_distinct(col("c12")),
min(col("c4")).alias("min(c4)"),
max(col("c4")).alias("max(c4)"),
avg(col("c4")).alias("avg(c4)"),
avg_distinct(col("c4")).alias("avg_distinct(c4)"),
sum(col("c4")).alias("sum(c4)"),
sum_distinct(col("c4")).alias("sum_distinct(c4)"),
count(col("c4")).alias("count(c4)"),
count_distinct(col("c4")).alias("count_distinct(c4)"),
Comment on lines +504 to +511
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I switched to c4 from c12 as c12 had some precision variations for avg_distinct leading to inconsistent test results, and figured it was easier to switch columns than slap round on the outputs

];

let df: Vec<RecordBatch> = df.aggregate(group_expr, aggr_expr)?.collect().await?;

assert_snapshot!(
batches_to_sort_string(&df),
@r###"
+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+
| c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) |
+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+
| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |
| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |
| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 |
| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 |
| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 |
+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+
"###
@r"
+----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+
| c1 | min(c4) | max(c4) | avg(c4) | avg_distinct(c4) | sum(c4) | sum_distinct(c4) | count(c4) | count_distinct(c4) |
+----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+
| a | -28462 | 32064 | 306.04761904761904 | 306.04761904761904 | 12854 | 6427 | 42 | 21 |
| b | -28070 | 25286 | 7732.315789473684 | 7732.315789473684 | 293828 | 146914 | 38 | 19 |
| c | -30508 | 29106 | -1320.5238095238096 | -1320.5238095238096 | -55462 | -27731 | 42 | 21 |
| d | -24558 | 31106 | 10890.111111111111 | 10890.111111111111 | 392044 | 196022 | 36 | 18 |
| e | -31500 | 32514 | -4268.333333333333 | -4268.333333333333 | -179270 | -89635 | 42 | 21 |
+----+---------+---------+---------------------+---------------------+---------+------------------+-----------+--------------------+
"
);

Ok(())
Expand All @@ -536,7 +540,9 @@ async fn aggregate_assert_no_empty_batches() -> Result<()> {
min(col("c12")),
max(col("c12")),
avg(col("c12")),
avg_distinct(col("c12")),
sum(col("c12")),
sum_distinct(col("c12")),
count(col("c12")),
count_distinct(col("c12")),
median(col("c12")),
Expand Down
13 changes: 12 additions & 1 deletion datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_typ
use datafusion_expr::utils::format_state_name;
use datafusion_expr::Volatility::Immutable;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator,
Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator,
ReversedUDAF, Signature,
};

Expand All @@ -62,6 +62,17 @@ make_udaf_expr_and_func!(
avg_udaf
);

pub fn avg_distinct(expr: Expr) -> Expr {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
avg_udaf(),
vec![expr],
true,
None,
vec![],
None,
))
}
Comment on lines +65 to +74
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as how count handles it:

pub fn count_distinct(expr: Expr) -> Expr {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
count_udaf(),
vec![expr],
true,
None,
vec![],
None,
))
}


#[user_doc(
doc_section(label = "General Functions"),
description = "Returns the average of numeric values in the specified column.",
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ pub mod expr_fn {
pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight;
pub use super::array_agg::array_agg;
pub use super::average::avg;
pub use super::average::avg_distinct;
pub use super::bit_and_or_xor::bit_and;
pub use super::bit_and_or_xor::bit_or;
pub use super::bit_and_or_xor::bit_xor;
Expand Down Expand Up @@ -134,6 +135,7 @@ pub mod expr_fn {
pub use super::stddev::stddev;
pub use super::stddev::stddev_pop;
pub use super::sum::sum;
pub use super::sum::sum_distinct;
pub use super::variance::var_pop;
pub use super::variance::var_sample;
}
Expand Down
12 changes: 12 additions & 0 deletions datafusion/functions-aggregate/src/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

use ahash::RandomState;
use datafusion_expr::utils::AggregateOrderSensitivity;
use datafusion_expr::Expr;
use std::any::Any;
use std::mem::size_of_val;

Expand Down Expand Up @@ -53,6 +54,17 @@ make_udaf_expr_and_func!(
sum_udaf
);

pub fn sum_distinct(expr: Expr) -> Expr {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
sum_udaf(),
vec![expr],
true,
None,
vec![],
None,
))
}

/// Sum only supports a subset of numeric types, instead relying on type coercion
///
/// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive)
Expand Down
7 changes: 5 additions & 2 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use datafusion::execution::options::ArrowReadOptions;
use datafusion::optimizer::eliminate_nested_union::EliminateNestedUnion;
use datafusion::optimizer::Optimizer;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_functions_aggregate::sum::sum_distinct;
use prost::Message;
use std::any::Any;
use std::collections::HashMap;
Expand Down Expand Up @@ -82,8 +83,8 @@ use datafusion_expr::{
};
use datafusion_functions_aggregate::average::avg_udaf;
use datafusion_functions_aggregate::expr_fn::{
approx_distinct, array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr,
nth_value,
approx_distinct, array_agg, avg, avg_distinct, bit_and, bit_or, bit_xor, bool_and,
bool_or, corr, nth_value,
};
use datafusion_functions_aggregate::string_agg::string_agg;
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
Expand Down Expand Up @@ -967,10 +968,12 @@ async fn roundtrip_expr_api() -> Result<()> {
functions_window::nth_value::last_value(lit(1)),
functions_window::nth_value::nth_value(lit(1), 1),
avg(lit(1.5)),
avg_distinct(lit(1.5)),
covar_samp(lit(1.5), lit(2.2)),
covar_pop(lit(1.5), lit(2.2)),
corr(lit(1.5), lit(2.2)),
sum(lit(1)),
sum_distinct(lit(1)),
max(lit(1)),
median(lit(2)),
min(lit(2)),
Expand Down
4 changes: 3 additions & 1 deletion docs/source/user-guide/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ select log(-1), log(0), sqrt(-1);
| Syntax | Description |
| ------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
| avg(expr) | Сalculates the average value for `expr`. |
| avg_distinct(expr) | Creates an expression to represent the avg(distinct) aggregate function |
| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. |
| approx_median(expr) | Calculates an approximation of the median for `expr`. |
| approx_percentile_cont(expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr`. Optional `centroids` parameter controls accuracy (default: 100). |
Expand All @@ -298,14 +299,15 @@ select log(-1), log(0), sqrt(-1);
| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. |
| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. |
| count(expr) | Returns the number of rows for `expr`. |
| count_distinct | Creates an expression to represent the count(distinct) aggregate function |
| count_distinct(expr) | Creates an expression to represent the count(distinct) aggregate function |
| cube(exprs) | Creates a grouping set for all combination of `exprs` |
| grouping_set(exprs) | Create a grouping set. |
| max(expr) | Finds the maximum value of `expr`. |
| median(expr) | Сalculates the median of `expr`. |
| min(expr) | Finds the minimum value of `expr`. |
| rollup(exprs) | Creates a grouping set for rollup sets. |
| sum(expr) | Сalculates the sum of `expr`. |
| sum_distinct(expr) | Creates an expression to represent the sum(distinct) aggregate function |

## Aggregate Function Builder

Expand Down