From afca8e1ef892c4598f1724416a8a9d3e5890131f Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 14 Sep 2021 14:50:10 -0400 Subject: [PATCH 1/7] Allow mutate after group_by as long as there are no aggregations --- r/R/dplyr-mutate.R | 6 ++++-- r/tests/testthat/test-dplyr-mutate.R | 15 +++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R index 051c5254e50..aa3e1d15039 100644 --- a/r/R/dplyr-mutate.R +++ b/r/R/dplyr-mutate.R @@ -38,11 +38,13 @@ mutate.arrow_dplyr_query <- function(.data, .data <- as_adq(.data) # Restrict the cases we support for now - if (length(dplyr::group_vars(.data)) > 0) { + has_aggregations <- any(unlist(lapply(exprs, all_funs)) %in% names(agg_funcs)) + if (has_aggregations) { + # ARROW-13926 # mutate() on a grouped dataset does calculations within groups # This doesn't matter on scalar ops (arithmetic etc.) but it does # for things with aggregations (e.g. subtracting the mean) - return(abandon_ship(call, .data, "mutate() on grouped data not supported in Arrow")) + return(abandon_ship(call, .data, "window functions not supported in Arrow")) } mask <- arrow_mask(.data) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 44127839108..bc536e8e874 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -340,14 +340,25 @@ test_that("dplyr::mutate's examples", { # The mutate operation may yield different results on grouped # tibbles because the expressions are computed within groups. # The following normalises `mass` by the global average: - # TODO: ARROW-11702 + # TODO: ARROW-13926 expect_dplyr_equal( input %>% select(name, mass, species) %>% mutate(mass_norm = mass / mean(mass, na.rm = TRUE)) %>% collect(), starwars, - warning = TRUE + warning = "window function" + ) +}) + +test_that("Can mutate after group_by as long as there are no aggregations", { + expect_dplyr_equal( + input %>% + select(int, chr) %>% + group_by(chr) %>% + mutate(int = int + 6L) %>% + collect(), + tbl ) }) From 86ac1351069a7387342245e2230239f4182f2460 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 14 Sep 2021 15:57:19 -0400 Subject: [PATCH 2/7] Improve all_funs() utility function --- r/R/util.R | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/r/R/util.R b/r/R/util.R index 4811d5dbfbb..5695944dee1 100644 --- a/r/R/util.R +++ b/r/R/util.R @@ -58,10 +58,22 @@ r_symbolic_constants <- c( "NA_integer_", "NA_real_", "NA_complex_", "NA_character_" ) +is_function <- function(expr, name) { + if (!is.call(expr)) { + return(FALSE) + } else { + if (deparse(expr[[1]]) == name) { + return(TRUE) + } + out <- lapply(expr, is_function, name) + } + any(vapply(out, isTRUE, TRUE)) +} + all_funs <- function(expr) { # Don't use setdiff so that we preserve duplicates - out <- all.names(expr) - out[!(out %in% all.vars(expr))] + names <- all.names(expr) + names[vapply(names, function(name) {is_function(expr, name)}, TRUE)] } all_vars <- function(expr) { From be7d9e04f3efd38b608ca88e14a1a412e1365cda Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 14 Sep 2021 15:57:29 -0400 Subject: [PATCH 3/7] Add tests --- r/tests/testthat/test-dplyr-mutate.R | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index bc536e8e874..827141568de 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -360,6 +360,38 @@ test_that("Can mutate after group_by as long as there are no aggregations", { collect(), tbl ) + expect_dplyr_equal( + input %>% + select(mean = int, chr) %>% + # rename `int` to `mean` and use `mean` in `mutate()` to test that + # `all_funs()` does not incorrectly identify it as an aggregate function + group_by(chr) %>% + mutate(mean = mean + 6L) %>% + collect(), + tbl + ) + expect_warning( + tbl %>% + Table$create() %>% + select(int, chr) %>% + group_by(chr) %>% + mutate(avg_int = mean(int)) %>% + collect(), + "window functions not supported in Arrow; pulling data into R", + fixed = TRUE + ) + expect_warning( + tbl %>% + Table$create() %>% + select(mean = int, chr) %>% + # rename `int` to `mean` and use `mean(mean)` in `mutate()` to test that + # `all_funs()` detects `mean()` despite the collision with a column name + group_by(chr) %>% + mutate(avg_int = mean(mean)) %>% + collect(), + "window functions not supported in Arrow; pulling data into R", + fixed = TRUE + ) }) test_that("handle bad expressions", { From a4603ed271926775b93343c248a8e5190339db08 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 14 Sep 2021 16:40:25 -0400 Subject: [PATCH 4/7] Update r/R/util.R Co-authored-by: Neal Richardson --- r/R/util.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/R/util.R b/r/R/util.R index 5695944dee1..13090fdb1b3 100644 --- a/r/R/util.R +++ b/r/R/util.R @@ -67,7 +67,7 @@ is_function <- function(expr, name) { } out <- lapply(expr, is_function, name) } - any(vapply(out, isTRUE, TRUE)) + any(map_lgl(out, isTRUE)) } all_funs <- function(expr) { From ca8f7ae52424129d4d0ae4f32a2515dd614a936e Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 14 Sep 2021 16:41:24 -0400 Subject: [PATCH 5/7] Update r/R/util.R Co-authored-by: Neal Richardson --- r/R/util.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/R/util.R b/r/R/util.R index 13090fdb1b3..cec5cc4adc9 100644 --- a/r/R/util.R +++ b/r/R/util.R @@ -73,7 +73,7 @@ is_function <- function(expr, name) { all_funs <- function(expr) { # Don't use setdiff so that we preserve duplicates names <- all.names(expr) - names[vapply(names, function(name) {is_function(expr, name)}, TRUE)] + names[map_lgl(names, ~ is_function(expr, .))] } all_vars <- function(expr) { From ad83c052bc12583f8c3b313e645ceaa7226630d9 Mon Sep 17 00:00:00 2001 From: Ian Cook Date: Tue, 14 Sep 2021 17:06:41 -0400 Subject: [PATCH 6/7] Improve comment Co-authored-by: Neal Richardson --- r/R/util.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/r/R/util.R b/r/R/util.R index cec5cc4adc9..229c86ba63e 100644 --- a/r/R/util.R +++ b/r/R/util.R @@ -71,7 +71,9 @@ is_function <- function(expr, name) { } all_funs <- function(expr) { - # Don't use setdiff so that we preserve duplicates + # It is not sufficient to simply do something like `setdiff(all.names, all.vars)` here + # because that would fail to return the names of functions that share names with + # variables. To preserve duplicates, call `all.names()` not `all_names()` here. names <- all.names(expr) names[map_lgl(names, ~ is_function(expr, .))] } From 616f888b4d30afaaeb4293c62da082c74931202e Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 14 Sep 2021 17:34:47 -0400 Subject: [PATCH 7/7] Fix another test and squash a deprecation warning --- r/R/dplyr-mutate.R | 2 +- r/R/util.R | 11 ++++++++--- r/tests/testthat/test-dataset.R | 2 +- r/tests/testthat/test-dplyr-mutate.R | 4 ++-- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R index aa3e1d15039..65421cd1e7f 100644 --- a/r/R/dplyr-mutate.R +++ b/r/R/dplyr-mutate.R @@ -44,7 +44,7 @@ mutate.arrow_dplyr_query <- function(.data, # mutate() on a grouped dataset does calculations within groups # This doesn't matter on scalar ops (arithmetic etc.) but it does # for things with aggregations (e.g. subtracting the mean) - return(abandon_ship(call, .data, "window functions not supported in Arrow")) + return(abandon_ship(call, .data, "window functions not currently supported in Arrow")) } mask <- arrow_mask(.data) diff --git a/r/R/util.R b/r/R/util.R index 229c86ba63e..40b06fe959e 100644 --- a/r/R/util.R +++ b/r/R/util.R @@ -71,9 +71,14 @@ is_function <- function(expr, name) { } all_funs <- function(expr) { - # It is not sufficient to simply do something like `setdiff(all.names, all.vars)` here - # because that would fail to return the names of functions that share names with - # variables. To preserve duplicates, call `all.names()` not `all_names()` here. + # It is not sufficient to simply do something like + # setdiff(all.names, all.vars) + # here because that would fail to return the names of functions that + # share names with variables. + # To preserve duplicates, call `all.names()` not `all_names()` here. + if (is_quosure(expr)) { + expr <- quo_get_expr(expr) + } names <- all.names(expr) names[map_lgl(names, ~ is_function(expr, .))] } diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 41265f0e638..14157545d61 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -772,7 +772,7 @@ test_that("mutate() features not yet implemented", { ds %>% group_by(int) %>% mutate(avg = mean(int)), - "mutate() on grouped data not supported in Arrow\nCall collect() first to pull data into R.", + "window functions not currently supported in Arrow\nCall collect() first to pull data into R.", fixed = TRUE ) }) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 827141568de..30fc12ccf17 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -377,7 +377,7 @@ test_that("Can mutate after group_by as long as there are no aggregations", { group_by(chr) %>% mutate(avg_int = mean(int)) %>% collect(), - "window functions not supported in Arrow; pulling data into R", + "window functions not currently supported in Arrow; pulling data into R", fixed = TRUE ) expect_warning( @@ -389,7 +389,7 @@ test_that("Can mutate after group_by as long as there are no aggregations", { group_by(chr) %>% mutate(avg_int = mean(mean)) %>% collect(), - "window functions not supported in Arrow; pulling data into R", + "window functions not currently supported in Arrow; pulling data into R", fixed = TRUE ) })