diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index 808956efe15..e9d0d17f730 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -839,7 +839,6 @@ agg_funcs$var <- function(x, na.rm = FALSE, ddof = 1) { options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof) ) } - agg_funcs$n_distinct <- function(x, na.rm = FALSE) { list( fun = "count_distinct", @@ -847,7 +846,6 @@ agg_funcs$n_distinct <- function(x, na.rm = FALSE) { options = list(na.rm = na.rm) ) } - agg_funcs$n <- function() { list( fun = "sum", @@ -855,6 +853,28 @@ agg_funcs$n <- function() { options = list() ) } +agg_funcs$min <- function(..., na.rm = FALSE) { + args <- list2(...) + if (length(args) > 1) { + arrow_not_supported("Multiple arguments to min()") + } + list( + fun = "min", + data = args[[1]], + options = list(skip_nulls = na.rm, min_count = 0L) + ) +} +agg_funcs$max <- function(..., na.rm = FALSE) { + args <- list2(...) + if (length(args) > 1) { + arrow_not_supported("Multiple arguments to max()") + } + list( + fun = "max", + data = args[[1]], + options = list(skip_nulls = na.rm, min_count = 0L) + ) +} output_type <- function(fun, input_type) { # These are quick and dirty heuristics. diff --git a/r/src/compute.cpp b/r/src/compute.cpp index c6ba0a28046..a01c35be4ef 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -171,10 +171,11 @@ std::shared_ptr make_compute_options( return out; } - if (func_name == "min_max" || func_name == "sum" || func_name == "mean" || - func_name == "any" || func_name == "all" || func_name == "hash_min_max" || - func_name == "hash_sum" || func_name == "hash_mean" || func_name == "hash_any" || - func_name == "hash_all") { + if (func_name == "all" || func_name == "hash_all" || func_name == "any" || + func_name == "hash_any" || func_name == "mean" || func_name == "hash_mean" || + func_name == "min_max" || func_name == "hash_min_max" || func_name == "min" || + func_name == "hash_min" || func_name == "max" || func_name == "hash_max" || + func_name == "sum" || func_name == "hash_sum") { using Options = arrow::compute::ScalarAggregateOptions; auto out = std::make_shared(Options::Defaults()); if (!Rf_isNull(options["min_count"])) { diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 14157545d61..837bf8048c5 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -621,17 +621,15 @@ test_that("Creating UnionDataset", { }) test_that("map_batches", { + skip("map_batches() is broken (ARROW-14029)") skip_if_not_available("parquet") ds <- open_dataset(dataset_dir, partitioning = "part") - expect_warning( - expect_equivalent( - ds %>% - filter(int > 5) %>% - select(int, lgl) %>% - map_batches(~ summarize(., min_int = min(int))), - tibble(min_int = c(6L, 101L)) - ), - "pulling data into R" # ARROW-13502 + expect_equivalent( + ds %>% + filter(int > 5) %>% + select(int, lgl) %>% + map_batches(~ summarize(., min_int = min(int))), + tibble(min_int = c(6L, 101L)) ) }) diff --git a/r/tests/testthat/test-dplyr-group-by.R b/r/tests/testthat/test-dplyr-group-by.R index 18be2a9304a..d6abb20c01c 100644 --- a/r/tests/testthat/test-dplyr-group-by.R +++ b/r/tests/testthat/test-dplyr-group-by.R @@ -28,9 +28,8 @@ test_that("group_by groupings are recorded", { group_by(chr) %>% select(int, chr) %>% filter(int > 5) %>% - summarize(min_int = min(int)), - tbl, - warning = TRUE + collect(), + tbl ) }) @@ -62,9 +61,23 @@ test_that("ungroup", { select(int, chr) %>% ungroup() %>% filter(int > 5) %>% - summarize(min_int = min(int)), - tbl, - warning = TRUE + collect(), + tbl + ) + + # to confirm that the above expectation is actually testing what we think it's + # testing, verify that expect_dplyr_equal() distinguishes between grouped and + # ungrouped tibbles + expect_error( + expect_dplyr_equal( + input %>% + group_by(chr) %>% + select(int, chr) %>% + (function(x) if (inherits(x, "tbl_df")) ungroup(x) else x) %>% + filter(int > 5) %>% + collect(), + tbl + ) ) }) diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index 12bb50fb3d5..c74ed6aa938 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -30,28 +30,6 @@ tbl$verses <- verses[[1]] tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2 * (1:10) + 1, side = "both") tbl$some_grouping <- rep(c(1, 2), 5) -test_that("summarize", { - expect_dplyr_equal( - input %>% - select(int, chr) %>% - filter(int > 5) %>% - summarize(min_int = min(int)) %>% - collect(), - tbl, - warning = TRUE - ) - - expect_dplyr_equal( - input %>% - select(int, chr) %>% - filter(int > 5) %>% - summarize(min_int = min(int) / 2) %>% - collect(), - tbl, - warning = TRUE - ) -}) - test_that("summarize() doesn't evaluate eagerly", { expect_s3_class( Table$create(tbl) %>% @@ -251,6 +229,131 @@ test_that("Group by n_distinct() on dataset", { ) }) +test_that("summarize() with min() and max()", { + expect_dplyr_equal( + input %>% + select(int, chr) %>% + filter(int > 5) %>% # this filters out the NAs in `int` + summarize(min_int = min(int), max_int = max(int)) %>% + collect(), + tbl, + ) + expect_dplyr_equal( + input %>% + select(int, chr) %>% + filter(int > 5) %>% # this filters out the NAs in `int` + summarize( + min_int = min(int + 4) / 2, + max_int = 3 / max(42 - int) + ) %>% + collect(), + tbl, + ) + expect_dplyr_equal( + input %>% + select(int, chr) %>% + summarize(min_int = min(int), max_int = max(int)) %>% + collect(), + tbl, + ) + expect_dplyr_equal( + input %>% + select(int) %>% + summarize( + min_int = min(int, na.rm = TRUE), + max_int = max(int, na.rm = TRUE) + ) %>% + collect(), + tbl, + ) + expect_dplyr_equal( + input %>% + select(dbl, int) %>% + summarize( + min_int = -min(log(ceiling(dbl)), na.rm = TRUE), + max_int = log(max(as.double(int), na.rm = TRUE)) + ) %>% + collect(), + tbl, + ) + + # multiple dots arguments to min(), max() not supported + expect_dplyr_equal( + input %>% + summarize(min_mult = min(dbl, int)) %>% + collect(), + tbl, + warning = "Multiple arguments to min\\(\\) not supported by Arrow" + ) + expect_dplyr_equal( + input %>% + select(int, dbl, dbl2) %>% + summarize(max_mult = max(int, dbl, dbl2)) %>% + collect(), + tbl, + warning = "Multiple arguments to max\\(\\) not supported by Arrow" + ) + + # min(logical) or max(logical) yields integer in R + # min(Boolean) or max(Boolean) yields Boolean in Arrow + expect_dplyr_equal( + input %>% + select(lgl) %>% + summarize( + max_lgl = as.logical(max(lgl, na.rm = TRUE)), + min_lgl = as.logical(min(lgl, na.rm = TRUE)) + ) %>% + collect(), + tbl, + ) +}) + +test_that("min() and max() on character strings", { + expect_dplyr_equal( + input %>% + summarize( + min_chr = min(chr, na.rm = TRUE), + max_chr = max(chr, na.rm = TRUE) + ) %>% + collect(), + tbl, + ) + skip("Strings not supported by hash_min_max (ARROW-13988)") + expect_dplyr_equal( + input %>% + group_by(fct) %>% + summarize( + min_chr = min(chr, na.rm = TRUE), + max_chr = max(chr, na.rm = TRUE) + ) %>% + collect(), + tbl, + ) +}) + +test_that("summarise() with !!sym()", { + test_chr_col <- "int" + test_dbl_col <- "dbl" + test_lgl_col <- "lgl" + expect_dplyr_equal( + input %>% + group_by(false) %>% + summarise( + sum = sum(!!sym(test_dbl_col)), + any = any(!!sym(test_lgl_col)), + all = all(!!sym(test_lgl_col)), + mean = mean(!!sym(test_dbl_col)), + sd = sd(!!sym(test_dbl_col)), + var = var(!!sym(test_dbl_col)), + n_distinct = n_distinct(!!sym(test_chr_col)), + min = min(!!sym(test_dbl_col)), + max = max(!!sym(test_dbl_col)) + ) %>% + collect(), + tbl + ) +}) + test_that("Filter and aggregate", { expect_dplyr_equal( input %>%