diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index b35c87e87f3..e535546dd1b 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -801,7 +801,6 @@ agg_funcs$all <- function(x, na.rm = FALSE) { options = list(na.rm = na.rm, na.min_count = 0L) ) } - agg_funcs$mean <- function(x, na.rm = FALSE) { list( fun = "mean", @@ -825,6 +824,15 @@ agg_funcs$var <- function(x, na.rm = FALSE, ddof = 1) { options = list(ddof = ddof) ) } + +agg_funcs$n_distinct <- function(x, na.rm = FALSE) { + list( + fun = "count_distinct", + data = x, + options = list(na.rm = na.rm) + ) +} + agg_funcs$n <- function() { list( fun = "sum", diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 5468fa83113..788e75a03d2 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -194,6 +194,14 @@ std::shared_ptr make_compute_options( return out; } + if (func_name == "hash_count_distinct") { + using Options = arrow::compute::CountOptions; + auto out = std::make_shared(Options::Defaults()); + out->mode = + cpp11::as_cpp(options["na.rm"]) ? Options::ONLY_VALID : Options::ALL; + return out; + } + if (func_name == "min_element_wise" || func_name == "max_element_wise") { using Options = arrow::compute::ElementWiseAggregateOptions; bool skip_nulls = true; diff --git a/r/tests/testthat/test-dplyr-aggregate.R b/r/tests/testthat/test-dplyr-aggregate.R index 8d68bbaf48e..3a04b6d2314 100644 --- a/r/tests/testthat/test-dplyr-aggregate.R +++ b/r/tests/testthat/test-dplyr-aggregate.R @@ -235,6 +235,23 @@ test_that("Group by any/all", { ) }) +test_that("Group by n_distinct() on dataset", { + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize(distinct = n_distinct(lgl, na.rm = FALSE)) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize(distinct = n_distinct(lgl, na.rm = TRUE)) %>% + collect(), + tbl + ) +}) + test_that("Filter and aggregate", { expect_dplyr_equal( input %>%