From 3d7badc3f3c286a0ea22c72fb5672b54c5bae117 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 24 Aug 2021 13:50:09 -0400 Subject: [PATCH 1/3] Start implementing dplyr::n() (blocked by segfault) --- r/R/dplyr-functions.R | 7 +++++++ r/src/compute.cpp | 8 ++++++-- r/tests/testthat/test-dplyr-aggregate.R | 21 ++++++++++++++++++++- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index a1fa715646d..83ce488bed1 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -825,3 +825,10 @@ agg_funcs$var <- function(x, na.rm = FALSE, ddof = 1) { options = list(ddof = ddof) ) } +agg_funcs$n <- function() { + list( + fun = "sum", + data = Expression$scalar(1L), + options = list() + ) +} diff --git a/r/src/compute.cpp b/r/src/compute.cpp index aec9d9a9f50..5468fa83113 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -177,8 +177,12 @@ std::shared_ptr make_compute_options( func_name == "hash_all") { using Options = arrow::compute::ScalarAggregateOptions; auto out = std::make_shared(Options::Defaults()); - out->min_count = cpp11::as_cpp(options["na.min_count"]); - out->skip_nulls = cpp11::as_cpp(options["na.rm"]); + if (!Rf_isNull(options["na.min_count"])) { + out->min_count = cpp11::as_cpp(options["na.min_count"]); + } + if (!Rf_isNull(options["na.rm"])) { + out->skip_nulls = cpp11::as_cpp(options["na.rm"]); + } return out; } diff --git a/r/tests/testthat/test-dplyr-aggregate.R b/r/tests/testthat/test-dplyr-aggregate.R index 7c2f7f890b9..81134f98259 100644 --- a/r/tests/testthat/test-dplyr-aggregate.R +++ b/r/tests/testthat/test-dplyr-aggregate.R @@ -154,6 +154,25 @@ test_that("Group by var on dataset", { ) }) +test_that("n()", { + withr::local_options(list(arrow.debug = TRUE)) + expect_dplyr_equal( + input %>% + summarize(counts = n()) %>% + collect(), + tbl + ) + + skip("segfault") + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize(counts = n()) %>% + arrange(some_grouping) %>% + collect(), + tbl + ) +}) test_that("Group by any/all", { withr::local_options(list(arrow.debug = TRUE)) @@ -255,4 +274,4 @@ test_that("Filter and aggregate", { collect(), tbl ) -}) +}) \ No newline at end of file From d43642890ad1849f3ec4b4d80f1d550bb318cd6d Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Mon, 30 Aug 2021 12:43:13 -0400 Subject: [PATCH 2/3] Unskip test and lint --- r/R/dplyr-functions.R | 4 ++-- r/tests/testthat/test-dplyr-aggregate.R | 29 +++++++++++-------------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index 83ce488bed1..b35c87e87f3 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -826,9 +826,9 @@ agg_funcs$var <- function(x, na.rm = FALSE, ddof = 1) { ) } agg_funcs$n <- function() { - list( + list( fun = "sum", data = Expression$scalar(1L), options = list() - ) + ) } diff --git a/r/tests/testthat/test-dplyr-aggregate.R b/r/tests/testthat/test-dplyr-aggregate.R index 81134f98259..8d68bbaf48e 100644 --- a/r/tests/testthat/test-dplyr-aggregate.R +++ b/r/tests/testthat/test-dplyr-aggregate.R @@ -155,28 +155,25 @@ test_that("Group by var on dataset", { }) test_that("n()", { - withr::local_options(list(arrow.debug = TRUE)) - expect_dplyr_equal( + withr::local_options(list(arrow.debug = TRUE)) + expect_dplyr_equal( input %>% - summarize(counts = n()) %>% - collect(), + summarize(counts = n()) %>% + collect(), tbl - ) - - skip("segfault") - expect_dplyr_equal( + ) + + expect_dplyr_equal( input %>% - group_by(some_grouping) %>% - summarize(counts = n()) %>% - arrange(some_grouping) %>% - collect(), + group_by(some_grouping) %>% + summarize(counts = n()) %>% + arrange(some_grouping) %>% + collect(), tbl - ) + ) }) test_that("Group by any/all", { - withr::local_options(list(arrow.debug = TRUE)) - expect_dplyr_equal( input %>% group_by(some_grouping) %>% @@ -274,4 +271,4 @@ test_that("Filter and aggregate", { collect(), tbl ) -}) \ No newline at end of file +}) From 2ccd6e36bf5ac404d0ebecdc53361fd3dcf859bb Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Mon, 30 Aug 2021 13:27:58 -0400 Subject: [PATCH 3/3] Move the case_when skip forward --- r/tests/testthat/test-dplyr.R | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index ed03c58a884..d3a9994b5f1 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -26,6 +26,7 @@ tbl$verses <- verses[[1]] # c(" a ", " b ", " c ", ...) increasing padding # nchar = 3 5 7 9 11 13 15 17 19 21 tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2 * (1:10) + 1, side = "both") +tbl$another_chr <- tail(letters, 10) test_that("basic select/filter/collect", { batch <- record_batch(tbl) @@ -961,7 +962,7 @@ test_that("No duplicate field names are allowed in an arrow_dplyr_query", { filter(int > 0), regexp = paste0( 'The following field names were found more than once in the data: "int", "dbl", ', - '"dbl2", "lgl", "false", "chr", "fct", "verses", and "padded_strings"' + '"dbl2", "lgl", "false", "chr", "fct", "verses", "padded_strings"' ) ) }) @@ -1109,9 +1110,6 @@ test_that("trig functions", { }) test_that("if_else and ifelse", { - tbl <- example_data - tbl$another_chr <- tail(letters, 10) - expect_dplyr_equal( input %>% mutate( @@ -1342,7 +1340,6 @@ test_that("case_when()", { ) ) - skip("case_when does not yet support with variable-width types (ARROW-13222)") expect_dplyr_equal( input %>% transmute(cw = case_when(lgl ~ "abc")) %>% @@ -1355,10 +1352,11 @@ test_that("case_when()", { collect(), tbl ) + skip("ARROW-13799: factor() should error but instead we get a string error message in its place") expect_dplyr_equal( input %>% mutate( - cw = paste0(case_when(!(!(!(lgl))) ~ factor(chr), TRUE ~ fct), "!") + cw = case_when(!(!(!(lgl))) ~ factor(chr), TRUE ~ fct) ) %>% collect(), tbl