diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R index deb26859e8b..5640565b153 100644 --- a/r/R/dplyr-collect.R +++ b/r/R/dplyr-collect.R @@ -62,17 +62,18 @@ restore_dplyr_features <- function(df, query) { # An arrow_dplyr_query holds some attributes that Arrow doesn't know about # After calling collect(), make sure these features are carried over - if (length(query$group_by_vars) > 0) { - # Preserve groupings, if present + if (length(dplyr::group_vars(query))) { if (is.data.frame(df)) { - df <- dplyr::grouped_df( + # Preserve groupings, if present + df <- dplyr::group_by( df, - dplyr::group_vars(query), - drop = dplyr::group_by_drop_default(query) + !!!syms(dplyr::group_vars(query)), + .drop = dplyr::group_by_drop_default(query), + .add = FALSE ) } else { # This is a Table, via compute() or collect(as_data_frame = FALSE) - df$metadata$r$attributes$.group_vars <- query$group_by_vars + df$metadata$r$attributes$.group_vars <- dplyr::group_vars(query) } } df diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 1f529066dd3..c7b58bc0275 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -49,6 +49,12 @@ arrow_dplyr_query <- function(.data) { if (inherits(.data, "data.frame")) { .data <- Table$create(.data) } + # ARROW-17737: If .data is a Table, remove groups from metadata + # (we've already grabbed the groups above) + if (inherits(.data, "ArrowTabular")) { + .data <- ungroup.ArrowTabular(.data) + } + # Evaluating expressions on a dataset with duplicated fieldnames will error dupes <- duplicated(names(.data)) if (any(dupes)) { @@ -182,7 +188,7 @@ dim.arrow_dplyr_query <- function(x) { # Query on in-memory Table, so evaluate the filter # Don't need any columns x <- select.arrow_dplyr_query(x, NULL) - rows <- nrow(compute.arrow_dplyr_query(x)) + rows <- nrow(as_arrow_table(x)) } c(rows, cols) } diff --git a/r/tests/testthat/test-dplyr-group-by.R b/r/tests/testthat/test-dplyr-group-by.R index 9bb6aa9600d..319eec25ed7 100644 --- a/r/tests/testthat/test-dplyr-group-by.R +++ b/r/tests/testthat/test-dplyr-group-by.R @@ -79,6 +79,25 @@ test_that("ungroup", { ) }) +test_that("Groups before conversion to a Table must not be restored after collect() (ARROW-17737)", { + compare_dplyr_binding( + .input %>% + group_by(chr, .add = FALSE) %>% + ungroup() %>% + collect(), + tbl %>% + group_by(int) + ) + compare_dplyr_binding( + .input %>% + group_by(chr, .add = TRUE) %>% + ungroup() %>% + collect(), + tbl %>% + group_by(int) + ) +}) + test_that("group_by then rename", { compare_dplyr_binding( .input %>% @@ -196,6 +215,20 @@ test_that("group_by() with .add", { collect(), tbl ) + compare_dplyr_binding( + .input %>% + group_by(.add = FALSE) %>% + collect(), + tbl %>% + group_by(dbl2) + ) + compare_dplyr_binding( + .input %>% + group_by(.add = TRUE) %>% + collect(), + tbl %>% + group_by(dbl2) + ) compare_dplyr_binding( .input %>% group_by(chr, .add = FALSE) %>%