Add n() support to Arrow dplyr summarize().

* r/R/dplyr-functions.R: register agg_funcs$n, expressed as a "sum"
  aggregation over the scalar 1L.
* r/src/compute.cpp: assign min_count / skip_nulls only when the R options
  list supplies a non-NULL "na.min_count" / "na.rm"; otherwise the values
  from Options::Defaults() are kept.
* r/tests/testthat/test-dplyr-aggregate.R: add ungrouped and grouped n()
  tests.
* r/tests/testthat/test-dplyr.R: define tbl$another_chr once at file level,
  adjust the duplicate-field-name regexp accordingly, run the string
  case_when() expectations, and skip the factor case (ARROW-13799).

diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R
index a1fa715646d..b35c87e87f3 100644
--- a/r/R/dplyr-functions.R
+++ b/r/R/dplyr-functions.R
@@ -825,3 +825,10 @@ agg_funcs$var <- function(x, na.rm = FALSE, ddof = 1) {
     options = list(ddof = ddof)
   )
 }
+agg_funcs$n <- function() {
+  list(
+    fun = "sum",
+    data = Expression$scalar(1L),
+    options = list()
+  )
+}
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index aec9d9a9f50..5468fa83113 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -177,8 +177,12 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
       func_name == "hash_all") {
     using Options = arrow::compute::ScalarAggregateOptions;
     auto out = std::make_shared<Options>(Options::Defaults());
-    out->min_count = cpp11::as_cpp<int>(options["na.min_count"]);
-    out->skip_nulls = cpp11::as_cpp<bool>(options["na.rm"]);
+    if (!Rf_isNull(options["na.min_count"])) {
+      out->min_count = cpp11::as_cpp<int>(options["na.min_count"]);
+    }
+    if (!Rf_isNull(options["na.rm"])) {
+      out->skip_nulls = cpp11::as_cpp<bool>(options["na.rm"]);
+    }
     return out;
   }
 
diff --git a/r/tests/testthat/test-dplyr-aggregate.R b/r/tests/testthat/test-dplyr-aggregate.R
index 7c2f7f890b9..8d68bbaf48e 100644
--- a/r/tests/testthat/test-dplyr-aggregate.R
+++ b/r/tests/testthat/test-dplyr-aggregate.R
@@ -154,10 +154,26 @@ test_that("Group by var on dataset", {
   )
 })
 
-
-test_that("Group by any/all", {
+test_that("n()", {
   withr::local_options(list(arrow.debug = TRUE))
 
+  expect_dplyr_equal(
+    input %>%
+      summarize(counts = n()) %>%
+      collect(),
+    tbl
+  )
+  expect_dplyr_equal(
+    input %>%
+      group_by(some_grouping) %>%
+      summarize(counts = n()) %>%
+      arrange(some_grouping) %>%
+      collect(),
+    tbl
+  )
+})
+
+test_that("Group by any/all", {
   expect_dplyr_equal(
     input %>%
       group_by(some_grouping) %>%
diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R
index ed03c58a884..d3a9994b5f1 100644
--- a/r/tests/testthat/test-dplyr.R
+++ b/r/tests/testthat/test-dplyr.R
@@ -26,6 +26,7 @@ tbl$verses <- verses[[1]]
 # c(" a ", "  b  ", "   c   ", ...) increasing padding
 # nchar = 3 5 7 9 11 13 15 17 19 21
 tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2 * (1:10) + 1, side = "both")
+tbl$another_chr <- tail(letters, 10)
 
 test_that("basic select/filter/collect", {
   batch <- record_batch(tbl)
@@ -961,7 +962,7 @@ test_that("No duplicate field names are allowed in an arrow_dplyr_query", {
       filter(int > 0),
     regexp = paste0(
       'The following field names were found more than once in the data: "int", "dbl", ',
-      '"dbl2", "lgl", "false", "chr", "fct", "verses", and "padded_strings"'
+      '"dbl2", "lgl", "false", "chr", "fct", "verses", "padded_strings"'
     )
   )
 })
@@ -1109,9 +1110,6 @@ test_that("trig functions", {
 })
 
 test_that("if_else and ifelse", {
-  tbl <- example_data
-  tbl$another_chr <- tail(letters, 10)
-
   expect_dplyr_equal(
     input %>%
       mutate(
@@ -1342,7 +1340,6 @@ test_that("case_when()", {
     )
   )
 
-  skip("case_when does not yet support with variable-width types (ARROW-13222)")
   expect_dplyr_equal(
     input %>%
       transmute(cw = case_when(lgl ~ "abc")) %>%
@@ -1355,10 +1352,11 @@ test_that("case_when()", {
       collect(),
     tbl
   )
+  skip("ARROW-13799: factor() should error but instead we get a string error message in its place")
   expect_dplyr_equal(
     input %>%
       mutate(
-        cw = paste0(case_when(!(!(!(lgl))) ~ factor(chr), TRUE ~ fct), "!")
+        cw = case_when(!(!(!(lgl))) ~ factor(chr), TRUE ~ fct)
       ) %>%
       collect(),
     tbl