From eea5b2c64b811e99e7f732ffd256354177e67d5b Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Mon, 28 Dec 2020 12:02:07 +0100 Subject: [PATCH] Support count_distinct in DataFrame API --- rust/datafusion/src/execution/dataframe_impl.rs | 3 ++- rust/datafusion/src/logical_plan/expr.rs | 9 +++++++++ rust/datafusion/src/logical_plan/mod.rs | 6 +++--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index 7b47aa218dc..db8b86b0137 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -207,6 +207,7 @@ mod tests { avg(col("c12")), sum(col("c12")), count(col("c12")), + count_distinct(col("c12")), ]; let df = df.aggregate(group_expr, aggr_expr)?; @@ -214,7 +215,7 @@ mod tests { let plan = df.to_logical_plan(); // build same plan using SQL API - let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12) \ + let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12), COUNT(DISTINCT c12) \ FROM aggregate_test_100 \ GROUP BY c1"; let sql_plan = create_plan(sql)?; diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 0ae26a37364..4aee03c4c12 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -570,6 +570,15 @@ pub fn count(expr: Expr) -> Expr { } } +/// Create an expression to represent the count(distinct) aggregate function +pub fn count_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::Count, + distinct: true, + args: vec![expr], + } +} + /// Whether it can be represented as a literal expression pub trait Literal { /// convert the value to a Literal expression diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index f810b0162c4..0d37da6f6a3 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -35,9 +35,9 @@ pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, concat, cos, - count, create_udaf, create_udf, exp, exprlist_to_fields, floor, length, lit, ln, - log10, log2, lower, max, min, or, round, signum, sin, sqrt, sum, tan, trim, trunc, - upper, when, Expr, Literal, + count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, + length, lit, ln, log10, log2, lower, max, min, or, round, signum, sin, sqrt, sum, + tan, trim, trunc, upper, when, Expr, Literal, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator;