diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index 459b7435a87..beb18e82039 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ -51,10 +51,6 @@ summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query # This is the Arrow summarize implementation do_arrow_summarize <- function(.data, ..., .groups = NULL) { - if (!is.null(.groups)) { - # ARROW-13550 - abort("`summarize()` with `.groups` argument not supported in Arrow") - } exprs <- ensure_named_exprs(quos(...)) # Create a stateful environment for recording our evaluated expressions @@ -97,6 +93,35 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) { ctx$post_mutate )[c(.data$group_by_vars, names(exprs))] } + + # Handle .groups argument + if (length(.data$group_by_vars)) { + if (is.null(.groups)) { + # dplyr docs say: + # When ‘.groups’ is not specified, it is chosen based on the + # number of rows of the results: + # • If all the results have 1 row, you get "drop_last". + # • If the number of rows varies, you get "keep". + # + # But we don't support anything that returns multiple rows now + .groups <- "drop_last" + } else { + assert_that(is.string(.groups)) + } + if (.groups == "drop_last") { + out$group_by_vars <- head(.data$group_by_vars, -1) + } else if (.groups == "keep") { + out$group_by_vars <- .data$group_by_vars + } else if (.groups == "rowwise") { + stop(arrow_not_supported('.groups = "rowwise"')) + } else if (.groups != "drop") { + # Drop means don't group by anything so there's nothing to do. + # Anything else is invalid + stop(paste("Invalid .groups argument:", .groups)) + } + # TODO: shouldn't we be doing something with `drop_empty_groups` in summarize? (ARROW-14044) + out$drop_empty_groups <- .data$drop_empty_groups + } out } diff --git a/r/tests/testthat/helper-expectation.R b/r/tests/testthat/helper-expectation.R index 72f07f32c96..e765bd6cf54 100644 --- a/r/tests/testthat/helper-expectation.R +++ b/r/tests/testthat/helper-expectation.R @@ -105,7 +105,7 @@ expect_dplyr_equal <- function(expr, ), warning ) - expect_equivalent(via_batch, expected, ...) + expect_equal(via_batch, expected, ...) } else { skip_msg <- c(skip_msg, skip_record_batch) } @@ -118,7 +118,7 @@ expect_dplyr_equal <- function(expr, ), warning ) - expect_equivalent(via_table, expected, ...) + expect_equal(via_table, expected, ...) } else { skip_msg <- c(skip_msg, skip_table) } diff --git a/r/tests/testthat/test-chunked-array.R b/r/tests/testthat/test-chunked-array.R index 8ec8952a129..8ff4e6684c4 100644 --- a/r/tests/testthat/test-chunked-array.R +++ b/r/tests/testthat/test-chunked-array.R @@ -203,10 +203,12 @@ test_that("ChunkedArray supports difftime", { }) test_that("ChunkedArray supports empty arrays (ARROW-13761)", { - types <- c(int8(), int16(), int32(), int64(), uint8(), uint16(), uint32(), - uint64(), float32(), float64(), timestamp("ns"), binary(), - large_binary(), fixed_size_binary(32), date32(), date64(), - decimal(4, 2)) + types <- c( + int8(), int16(), int32(), int64(), uint8(), uint16(), uint32(), + uint64(), float32(), float64(), timestamp("ns"), binary(), + large_binary(), fixed_size_binary(32), date32(), date64(), + decimal(4, 2) + ) empty_filter <- ChunkedArray$create(type = bool()) for (type in types) { diff --git a/r/tests/testthat/test-compute-no-bindings.R b/r/tests/testthat/test-compute-no-bindings.R index 05beb924d77..0546b98c0af 100644 --- a/r/tests/testthat/test-compute-no-bindings.R +++ b/r/tests/testthat/test-compute-no-bindings.R @@ -121,7 +121,6 @@ test_that("non-bound compute kernels using ModeOptions", { }) test_that("non-bound compute kernels using PartitionNthOptions", { - result <- call_function( "partition_nth_indices", Array$create(c(11:20)), @@ -131,7 +130,6 @@ test_that("non-bound compute kernels using PartitionNthOptions", { # (depends on C++ standard library implementation) expect_true(all(as.vector(result[1:3]) < 3)) expect_true(all(as.vector(result[4:10]) >= 3)) - }) diff --git a/r/tests/testthat/test-dplyr-string-functions.R b/r/tests/testthat/test-dplyr-string-functions.R index b6b8f5a714a..1098619f3a5 100644 --- a/r/tests/testthat/test-dplyr-string-functions.R +++ b/r/tests/testthat/test-dplyr-string-functions.R @@ -410,49 +410,63 @@ test_that("strsplit and str_split", { input %>% mutate(x = strsplit(x, "and")) %>% collect(), - df + df, + # Pass check.attributes = FALSE through to expect_equal + # (which gives you expect_equivalent() behavior). + # This is because the vctr that comes back from arrow (ListArray) + # has type information in it, but it's just a bare list from R/dplyr. + # Note also that whenever we bump up to testthat 3rd edition (ARROW-12871), + # the parameter is called `ignore_attr = TRUE` + check.attributes = FALSE ) expect_dplyr_equal( input %>% mutate(x = strsplit(x, "and.*", fixed = TRUE)) %>% collect(), - df + df, + check.attributes = FALSE ) expect_dplyr_equal( input %>% mutate(x = strsplit(x, " +and +")) %>% collect(), - df + df, + check.attributes = FALSE ) expect_dplyr_equal( input %>% mutate(x = str_split(x, "and")) %>% collect(), - df + df, + check.attributes = FALSE ) expect_dplyr_equal( input %>% mutate(x = str_split(x, "and", n = 2)) %>% collect(), - df + df, + check.attributes = FALSE ) expect_dplyr_equal( input %>% mutate(x = str_split(x, fixed("and"), n = 2)) %>% collect(), - df + df, + check.attributes = FALSE ) expect_dplyr_equal( input %>% mutate(x = str_split(x, regex("and"), n = 2)) %>% collect(), - df + df, + check.attributes = FALSE ) expect_dplyr_equal( input %>% mutate(x = str_split(x, "Foo|bar", n = 2)) %>% collect(), - df + df, + check.attributes = FALSE ) }) diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index fa4bffe30d7..12bb50fb3d5 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -390,7 +390,7 @@ test_that("Expressions on aggregations", { ) # Aggregate on an aggregate (trivial but dplyr allows) - skip("Not supported") + skip("Aggregate on an aggregate not supported") expect_dplyr_equal( input %>% group_by(some_grouping) %>% @@ -468,3 +468,54 @@ test_that("Not (yet) supported: implicit join", { warning = "Expression dbl - int not supported in Arrow; pulling data into R" ) }) + +test_that(".groups argument", { + expect_dplyr_equal( + input %>% + group_by(some_grouping, int < 6) %>% + summarize(count = n()) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + group_by(some_grouping, int < 6) %>% + summarize(count = n(), .groups = "drop_last") %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + group_by(some_grouping, int < 6) %>% + summarize(count = n(), .groups = "keep") %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + group_by(some_grouping, int < 6) %>% + summarize(count = n(), .groups = "drop") %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + group_by(some_grouping, int < 6) %>% + summarize(count = n(), .groups = "rowwise") %>% + collect(), + tbl, + warning = TRUE + ) + + # abandon_ship() raises the warning, then dplyr itself errors + # This isn't ideal but it's fine and won't be an issue on Datasets + expect_error( + expect_warning( + Table$create(tbl) %>% + group_by(some_grouping, int < 6) %>% + summarize(count = n(), .groups = "NOTVALID"), + "Invalid .groups argument" + ), + "NOTVALID" + ) +})