diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R index 051c5254e50..65421cd1e7f 100644 --- a/r/R/dplyr-mutate.R +++ b/r/R/dplyr-mutate.R @@ -38,11 +38,13 @@ mutate.arrow_dplyr_query <- function(.data, .data <- as_adq(.data) # Restrict the cases we support for now - if (length(dplyr::group_vars(.data)) > 0) { + has_aggregations <- any(unlist(lapply(exprs, all_funs)) %in% names(agg_funcs)) + if (has_aggregations) { + # ARROW-13926 # mutate() on a grouped dataset does calculations within groups # This doesn't matter on scalar ops (arithmetic etc.) but it does # for things with aggregations (e.g. subtracting the mean) - return(abandon_ship(call, .data, "mutate() on grouped data not supported in Arrow")) + return(abandon_ship(call, .data, "window functions not currently supported in Arrow")) } mask <- arrow_mask(.data) diff --git a/r/R/util.R b/r/R/util.R index 4811d5dbfbb..40b06fe959e 100644 --- a/r/R/util.R +++ b/r/R/util.R @@ -58,10 +58,29 @@ r_symbolic_constants <- c( "NA_integer_", "NA_real_", "NA_complex_", "NA_character_" ) +is_function <- function(expr, name) { + if (!is.call(expr)) { + return(FALSE) + } else { + if (deparse(expr[[1]]) == name) { + return(TRUE) + } + out <- lapply(expr, is_function, name) + } + any(map_lgl(out, isTRUE)) +} + all_funs <- function(expr) { - # Don't use setdiff so that we preserve duplicates - out <- all.names(expr) - out[!(out %in% all.vars(expr))] + # It is not sufficient to simply do something like + # setdiff(all.names, all.vars) + # here because that would fail to return the names of functions that + # share names with variables. + # To preserve duplicates, call `all.names()` not `all_names()` here. + if (is_quosure(expr)) { + expr <- quo_get_expr(expr) + } + names <- all.names(expr) + names[map_lgl(names, ~ is_function(expr, .))] } all_vars <- function(expr) { diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 41265f0e638..14157545d61 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -772,7 +772,7 @@ test_that("mutate() features not yet implemented", { ds %>% group_by(int) %>% mutate(avg = mean(int)), - "mutate() on grouped data not supported in Arrow\nCall collect() first to pull data into R.", + "window functions not currently supported in Arrow\nCall collect() first to pull data into R.", fixed = TRUE ) }) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 44127839108..30fc12ccf17 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -340,14 +340,57 @@ test_that("dplyr::mutate's examples", { # The mutate operation may yield different results on grouped # tibbles because the expressions are computed within groups. # The following normalises `mass` by the global average: - # TODO: ARROW-11702 + # TODO: ARROW-13926 expect_dplyr_equal( input %>% select(name, mass, species) %>% mutate(mass_norm = mass / mean(mass, na.rm = TRUE)) %>% collect(), starwars, - warning = TRUE + warning = "window function" + ) +}) + +test_that("Can mutate after group_by as long as there are no aggregations", { + expect_dplyr_equal( + input %>% + select(int, chr) %>% + group_by(chr) %>% + mutate(int = int + 6L) %>% + collect(), + tbl + ) + expect_dplyr_equal( + input %>% + select(mean = int, chr) %>% + # rename `int` to `mean` and use `mean` in `mutate()` to test that + # `all_funs()` does not incorrectly identify it as an aggregate function + group_by(chr) %>% + mutate(mean = mean + 6L) %>% + collect(), + tbl + ) + expect_warning( + tbl %>% + Table$create() %>% + select(int, chr) %>% + group_by(chr) %>% + mutate(avg_int = mean(int)) %>% + collect(), + "window functions not currently supported in Arrow; pulling data into R", + fixed = TRUE + ) + expect_warning( + tbl %>% + Table$create() %>% + select(mean = int, chr) %>% + # rename `int` to `mean` and use `mean(mean)` in `mutate()` to test that + # `all_funs()` detects `mean()` despite the collision with a column name + group_by(chr) %>% + mutate(avg_int = mean(mean)) %>% + collect(), + "window functions not currently supported in Arrow; pulling data into R", + fixed = TRUE ) })