From 529f80fd6ed52078ac21cbad367acfc22edca496 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 8 Sep 2021 09:49:30 -0400 Subject: [PATCH 1/5] Rough first commit that makes the test pass --- r/R/dplyr-eval.R | 1 + r/R/dplyr-summarize.R | 111 ++++++++++++++++++++---- r/R/util.R | 17 +--- r/tests/testthat/test-dplyr-summarize.R | 4 +- 4 files changed, 101 insertions(+), 32 deletions(-) diff --git a/r/R/dplyr-eval.R b/r/R/dplyr-eval.R index 89eec94e4d2..c65ed605834 100644 --- a/r/R/dplyr-eval.R +++ b/r/R/dplyr-eval.R @@ -28,6 +28,7 @@ arrow_eval <- function(expr, mask) { # else, for things not supported by Arrow return a "try-error", # which we'll handle differently msg <- conditionMessage(e) + if (getOption("arrow.debug", FALSE)) print(msg) patterns <- .cache$i18ized_error_pattern if (is.null(patterns)) { patterns <- i18ize_error_messages() diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index 3a6c76e28cb..80ac799116e 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ -49,26 +49,38 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) { } exprs <- ensure_named_exprs(quos(...)) - mask <- arrow_mask(.data, aggregation = TRUE) - - results <- empty_named_list() + # Create a stateful environment for recording our evaluated expressions + # It's more complex than other places because a single summarize() expr + # may result in multiple query nodes (Aggregate, Project) + ctx <- env( + mask = arrow_mask(.data, aggregation = TRUE), + results = empty_named_list(), + post_mutate = empty_named_list() + ) for (i in seq_along(exprs)) { # Iterate over the indices and not the names because names may be repeated # (which overwrites the previous name) - new_var <- names(exprs)[i] - results[[new_var]] <- arrow_eval(exprs[[i]], mask) - if (inherits(results[[new_var]], "try-error")) { - msg <- handle_arrow_not_supported( - results[[new_var]], - as_label(exprs[[i]]) - ) - stop(msg, call. = FALSE) - } + summarize_eval(names(exprs)[i], exprs[[i]], ctx) } - .data$aggregations <- results - # TODO: should in-memory query evaluate eagerly? - collapse.arrow_dplyr_query(.data) + .data$aggregations <- ctx$results + out <- collapse.arrow_dplyr_query(.data) + if (length(ctx$post_mutate)) { + # mutate() + # TODO: get order of columns correct + out$selected_columns <- c(out$selected_columns[-grep("^\\.\\.temp", names(out$selected_columns))], ctx$post_mutate) + } + out +} + +arrow_eval_or_stop <- function(expr, mask) { + # TODO: change arrow_eval error handling behavior? + out <- arrow_eval(expr, mask) + if (inherits(out, "try-error")) { + msg <- handle_arrow_not_supported(out, as_label(expr)) + stop(msg, call. = FALSE) + } + out } summarize_projection <- function(.data) { @@ -81,3 +93,72 @@ summarize_projection <- function(.data) { format_aggregation <- function(x) { paste0(x$fun, "(", x$data$ToString(), ")") } + +# Cases: +# * agg(fun(x, y)): OK +# * fun(agg(x), agg(y)): TODO now: pull out aggregates, insert fieldref, then mutate +# * z = agg(x); fun(z, agg(y)): TODO now +# * agg(fun(agg(x), agg(y))): TODO now too? is this meaningful? (dplyr doesn't error on it) +# * fun(agg(x), y): Later (implicit join; seems to be equivalent to doing it in mutate) +# * z = agg(x); fun(z, y): Later (same, implicit join) + +# find aggregation subcomponents +# eval, insert fieldref; give "..temp" prefix to name +# record fieldrefs in list and in mask +# + +summarize_eval <- function(name, quosure, ctx, recurse = FALSE) { + expr <- quo_get_expr(quosure) + ctx$quo_env <- rlang::quo_get_env(quosure) + funs_in_expr <- all_funs(expr) + + if (length(funs_in_expr) == 0) { + # Skip if it is a scalar or field ref + ctx$results[[name]] <- arrow_eval_or_stop(quosure, ctx$mask) + return() + } + + agg_funs <- names(agg_funcs) + outer_agg <- funs_in_expr[1] %in% agg_funs + inner_agg <- funs_in_expr[-1] %in% agg_funs + + # First, pull out any aggregations wrapped in other function calls + if (any(inner_agg)) { + expr <- extract_aggregations(expr, ctx) + } + + inner_agg_exprs <- all_vars(expr) %in% names(ctx$results) + + if (outer_agg) { + # This just works by normal arrow_eval, unless there's a mix of aggs and + # columns in the original data like agg(fun(x, agg(x))) + # TODO if this errors, check whether all/any inner_agg_exprs + ctx$results[[name]] <- arrow_eval_or_stop(quosure, ctx$mask) + return() + } else if (all(inner_agg_exprs)) { + # fun(agg(x), ...) + # So based on the aggregations that have been extracted, mutate after + mutate_mask <- arrow_mask(list(selected_columns = make_field_refs(names(ctx$results)))) + ctx$post_mutate[[name]] <- arrow_eval_or_stop(rlang::as_quosure(expr, ctx$quo_env), mutate_mask) + return() + } + # TODO: Handle some known cases + + stop(handle_arrow_not_supported(expr, as_label(expr)), call. = FALSE) +} + +extract_aggregations <- function(expr, ctx) { + funs <- all_funs(expr) + if (length(funs) == 0) { + return(expr) + } else if (length(funs) > 1) { + # Recurse more + expr[-1] <- lapply(expr[-1], extract_aggregations, ctx) + } + if (funs[1] %in% names(agg_funcs)) { + tmpname <- paste0("..temp", length(ctx$results)) + ctx$results[[tmpname]] <- arrow_eval_or_stop(rlang::as_quosure(expr, ctx$quo_env), ctx$mask) + expr <- as.symbol(tmpname) + } + expr +} diff --git a/r/R/util.R b/r/R/util.R index 5958b0b3111..4811d5dbfbb 100644 --- a/r/R/util.R +++ b/r/R/util.R @@ -58,21 +58,10 @@ r_symbolic_constants <- c( "NA_integer_", "NA_real_", "NA_complex_", "NA_character_" ) -is_function <- function(expr, name) { - if (!is.call(expr)) { - return(FALSE) - } else { - if (deparse1(expr[[1]]) == name) { - return(TRUE) - } - out <- lapply(expr, is_function, name) - } - any(vapply(out, isTRUE, TRUE)) -} - all_funs <- function(expr) { - names <- all_names(expr) - names[vapply(names, function(name) is_function(expr, name), TRUE)] + # Don't use setdiff so that we preserve duplicates + out <- all.names(expr) + out[!(out %in% all.vars(expr))] } all_vars <- function(expr) { diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index 18596fcf30c..dd58a1c9dd3 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -162,7 +162,6 @@ test_that("Group by var on dataset", { }) test_that("n()", { - withr::local_options(list(arrow.debug = TRUE)) expect_dplyr_equal( input %>% summarize(counts = n()) %>% @@ -350,7 +349,6 @@ test_that("Expressions on aggregations", { any = any(lgl), all = all(lgl) ) %>% - compute() %>% ungroup() %>% # TODO: loosen the restriction on mutate after group_by mutate(some = any & !all) %>% select(some_grouping, some) %>% @@ -358,7 +356,7 @@ test_that("Expressions on aggregations", { tbl ) # More concisely: - skip("TODO: ARROW-13778") + withr::local_options(list(arrow.debug = TRUE)) expect_dplyr_equal( input %>% group_by(some_grouping) %>% From 0abb74e4bfd3220cc481fb5aa10c485d9153ccd4 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 8 Sep 2021 10:56:32 -0400 Subject: [PATCH 2/5] Namespace and doc --- r/NAMESPACE | 2 ++ r/R/arrow-package.R | 4 ++-- r/R/dplyr-summarize.R | 6 +++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/r/NAMESPACE b/r/NAMESPACE index 5164e7c9f20..cabd8ffb5f8 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -306,6 +306,7 @@ importFrom(rlang,"%||%") importFrom(rlang,.data) importFrom(rlang,abort) importFrom(rlang,as_label) +importFrom(rlang,as_quosure) importFrom(rlang,caller_env) importFrom(rlang,dots_n) importFrom(rlang,enexpr) @@ -325,6 +326,7 @@ importFrom(rlang,is_quosure) importFrom(rlang,list2) importFrom(rlang,new_data_mask) importFrom(rlang,new_environment) +importFrom(rlang,quo_get_env) importFrom(rlang,quo_get_expr) importFrom(rlang,quo_is_null) importFrom(rlang,quo_name) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index c09b8f05319..e41fe58d696 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -21,8 +21,8 @@ #' @importFrom assertthat assert_that is.string #' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos #' @importFrom rlang eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec -#' @importFrom rlang is_bare_character quo_get_expr quo_set_expr .data seq2 is_quosure enexpr enexprs -#' @importFrom rlang expr caller_env is_character quo_name +#' @importFrom rlang is_bare_character quo_get_expr quo_get_env quo_set_expr .data seq2 +#' @importFrom rlang expr caller_env is_character quo_name is_quosure enexpr enexprs as_quosure #' @importFrom tidyselect vars_pull vars_rename vars_select eval_select #' @useDynLib arrow, .registration = TRUE #' @keywords internal diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index 80ac799116e..1d8eb266cf8 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ -109,7 +109,7 @@ format_aggregation <- function(x) { summarize_eval <- function(name, quosure, ctx, recurse = FALSE) { expr <- quo_get_expr(quosure) - ctx$quo_env <- rlang::quo_get_env(quosure) + ctx$quo_env <- quo_get_env(quosure) funs_in_expr <- all_funs(expr) if (length(funs_in_expr) == 0) { @@ -139,7 +139,7 @@ summarize_eval <- function(name, quosure, ctx, recurse = FALSE) { # fun(agg(x), ...) # So based on the aggregations that have been extracted, mutate after mutate_mask <- arrow_mask(list(selected_columns = make_field_refs(names(ctx$results)))) - ctx$post_mutate[[name]] <- arrow_eval_or_stop(rlang::as_quosure(expr, ctx$quo_env), mutate_mask) + ctx$post_mutate[[name]] <- arrow_eval_or_stop(as_quosure(expr, ctx$quo_env), mutate_mask) return() } # TODO: Handle some known cases @@ -157,7 +157,7 @@ extract_aggregations <- function(expr, ctx) { } if (funs[1] %in% names(agg_funcs)) { tmpname <- paste0("..temp", length(ctx$results)) - ctx$results[[tmpname]] <- arrow_eval_or_stop(rlang::as_quosure(expr, ctx$quo_env), ctx$mask) + ctx$results[[tmpname]] <- arrow_eval_or_stop(as_quosure(expr, ctx$quo_env), ctx$mask) expr <- as.symbol(tmpname) } expr From 9d580403df047c950e26c3b0310173923b723198 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 8 Sep 2021 16:22:54 -0400 Subject: [PATCH 3/5] Handle some unsupported cases --- r/R/dplyr-summarize.R | 27 +++++++++--- r/tests/testthat/test-dplyr-summarize.R | 58 ++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 6 deletions(-) diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index 1d8eb266cf8..b692abcf800 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ -27,7 +27,7 @@ summarise.arrow_dplyr_query <- function(.data, ..., .engine = c("arrow", "duckdb unlist(lapply(exprs, all.vars)), # vars referenced in summarise dplyr::group_vars(.data) # vars needed for grouping )) - .data <- dplyr::select(.data, vars_to_keep) + .data <- dplyr::select(.data, intersect(vars_to_keep, names(.data))) if (match.arg(.engine) == "duckdb") { dplyr::summarise(to_duckdb(.data), ...) } else { @@ -132,7 +132,7 @@ summarize_eval <- function(name, quosure, ctx, recurse = FALSE) { if (outer_agg) { # This just works by normal arrow_eval, unless there's a mix of aggs and # columns in the original data like agg(fun(x, agg(x))) - # TODO if this errors, check whether all/any inner_agg_exprs + # (but that will have been caught in extract_aggregations()) ctx$results[[name]] <- arrow_eval_or_stop(quosure, ctx$mask) return() } else if (all(inner_agg_exprs)) { @@ -142,12 +142,20 @@ summarize_eval <- function(name, quosure, ctx, recurse = FALSE) { ctx$post_mutate[[name]] <- arrow_eval_or_stop(as_quosure(expr, ctx$quo_env), mutate_mask) return() } - # TODO: Handle some known cases - - stop(handle_arrow_not_supported(expr, as_label(expr)), call. = FALSE) + # !outer_agg && !all(inner_agg_exprs) + # This is fun(x, agg(y)), which really should be in mutate() + # but summarize() allows it. (See also below in extract_aggregations) + # TODO: support in ARROW-13926 + # (This could also be fun(x, y), which would work in mutate() already + # if it were the only expression) + # TODO: this message should probably also say "not supported in summarize()" + # since some of these expressions may be legal elsewhere + stop(handle_arrow_not_supported(quo_get_expr(quosure), as_label(quo_get_expr(quosure))), call. = FALSE) } extract_aggregations <- function(expr, ctx) { + # Keep the input in case we need to raise an error message with it + original_expr <- expr funs <- all_funs(expr) if (length(funs) == 0) { return(expr) @@ -156,6 +164,15 @@ extract_aggregations <- function(expr, ctx) { expr[-1] <- lapply(expr[-1], extract_aggregations, ctx) } if (funs[1] %in% names(agg_funcs)) { + inner_agg_exprs <- all_vars(expr) %in% names(ctx$results) + if (any(inner_agg_exprs) & !all(inner_agg_exprs)) { + # We can't aggregate over a combination of dataset columns and other + # aggregations (e.g. sum(x - mean(x))) + # TODO: Add "because" arg to explain _why_ it's not supported? + # TODO: support in ARROW-13926 + stop(handle_arrow_not_supported(original_expr, as_label(original_expr)), call. = FALSE) + } + tmpname <- paste0("..temp", length(ctx$results)) ctx$results[[tmpname]] <- arrow_eval_or_stop(as_quosure(expr, ctx$quo_env), ctx$mask) expr <- as.symbol(tmpname) diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index dd58a1c9dd3..b7f62cd0422 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -356,7 +356,6 @@ test_that("Expressions on aggregations", { tbl ) # More concisely: - withr::local_options(list(arrow.debug = TRUE)) expect_dplyr_equal( input %>% group_by(some_grouping) %>% @@ -364,6 +363,18 @@ test_that("Expressions on aggregations", { collect(), tbl ) + + # Save one of the aggregates first + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + any_lgl = any(lgl), + some = any_lgl & !all(lgl) + ) %>% + collect(), + tbl + ) }) test_that("Summarize with 0 arguments", { @@ -375,3 +386,48 @@ test_that("Summarize with 0 arguments", { tbl ) }) + +test_that("Not (yet) supported: implicit join", { + withr::local_options(list(arrow.debug = TRUE)) + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + sum((dbl - mean(dbl))^2) + ) %>% + collect(), + tbl, + warning = "Expression sum\\(\\(dbl - mean\\(dbl\\)\\)\\^2\\) not supported in Arrow; pulling data into R" + ) + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + sqrt(sum((dbl - mean(dbl))^2) / (n() - 1L)) + ) %>% + collect(), + tbl, + warning = "Expression sum\\(\\(dbl - mean\\(dbl\\)\\)\\^2\\) not supported in Arrow; pulling data into R" + ) + + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + dbl - mean(dbl) + ) %>% + collect(), + tbl, + warning = "Expression dbl - mean\\(dbl\\) not supported in Arrow; pulling data into R" + ) + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + dbl - int + ) %>% + collect(), + tbl, + warning = "Expression dbl - int not supported in Arrow; pulling data into R" + ) +}) From acb04618cbe1666539b2f09aad61b366f2532ce8 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 9 Sep 2021 14:50:37 -0400 Subject: [PATCH 4/5] Add copious comments and a few more tests --- r/R/dplyr-summarize.R | 129 ++++++++++++++++-------- r/tests/testthat/test-dplyr-summarize.R | 37 +++++++ 2 files changed, 124 insertions(+), 42 deletions(-) diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index b692abcf800..fbc237446dc 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ -27,6 +27,13 @@ summarise.arrow_dplyr_query <- function(.data, ..., .engine = c("arrow", "duckdb unlist(lapply(exprs, all.vars)), # vars referenced in summarise dplyr::group_vars(.data) # vars needed for grouping )) + # If exprs rely on the results of previous exprs + # (total = sum(x), mean = total / n()) + # then not all vars will correspond to columns in the data, + # so don't try to select() them (use intersect() to exclude them) + # Note that this select() isn't useful for the Arrow summarize implementation + # because it will effectively project to keep what it needs anyway, + # but the duckdb and data.frame fallback versions do benefit from select here .data <- dplyr::select(.data, intersect(vars_to_keep, names(.data))) if (match.arg(.engine) == "duckdb") { dplyr::summarise(to_duckdb(.data), ...) @@ -42,6 +49,7 @@ summarise.arrow_dplyr_query <- function(.data, ..., .engine = c("arrow", "duckdb } summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query +# This is the Arrow summarize implementation do_arrow_summarize <- function(.data, ..., .groups = NULL) { if (!is.null(.groups)) { # ARROW-13550 @@ -51,10 +59,11 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) { # Create a stateful environment for recording our evaluated expressions # It's more complex than other places because a single summarize() expr - # may result in multiple query nodes (Aggregate, Project) + # may result in multiple query nodes (Aggregate, Project), + # and we have to walk through the expressions to disentangle them. ctx <- env( mask = arrow_mask(.data, aggregation = TRUE), - results = empty_named_list(), + aggregations = empty_named_list(), post_mutate = empty_named_list() ) for (i in seq_along(exprs)) { @@ -63,12 +72,28 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) { summarize_eval(names(exprs)[i], exprs[[i]], ctx) } - .data$aggregations <- ctx$results + # Apply the results to the .data object. + # First, the aggregations + .data$aggregations <- ctx$aggregations + # Then collapse the query so that the resulting query object can have + # additional operations applied to it out <- collapse.arrow_dplyr_query(.data) + # The expressions may have been translated into + # "first, aggregate, then transform the result further" + # For example, + # summarize(mean = sum(x) / n()) + # is effectively implemented as + # summarize(..temp0 = sum(x), ..temp1 = n()) %>% + # mutate(mean = ..temp0 / ..temp1) %>% + # select(-starts_with("..temp")) + # If this is the case, there will be expressions in post_mutate if (length(ctx$post_mutate)) { - # mutate() - # TODO: get order of columns correct - out$selected_columns <- c(out$selected_columns[-grep("^\\.\\.temp", names(out$selected_columns))], ctx$post_mutate) + # Append post_mutate, and make sure order is correct + # according to input exprs (also dropping ..temp columns) + out$selected_columns <- c( + out$selected_columns, + ctx$post_mutate + )[c(.data$group_by_vars, names(exprs))] } out } @@ -94,30 +119,22 @@ format_aggregation <- function(x) { paste0(x$fun, "(", x$data$ToString(), ")") } -# Cases: -# * agg(fun(x, y)): OK -# * fun(agg(x), agg(y)): TODO now: pull out aggregates, insert fieldref, then mutate -# * z = agg(x); fun(z, agg(y)): TODO now -# * agg(fun(agg(x), agg(y))): TODO now too? is this meaningful? (dplyr doesn't error on it) -# * fun(agg(x), y): Later (implicit join; seems to be equivalent to doing it in mutate) -# * z = agg(x); fun(z, y): Later (same, implicit join) - -# find aggregation subcomponents -# eval, insert fieldref; give "..temp" prefix to name -# record fieldrefs in list and in mask -# - +# This function handles each summarize expression and turns it into the +# appropriate combination of (1) aggregations (possibly temporary) and +# (2) post-aggregation transformations (mutate) +# The function returns nothing: it assigns into the `ctx` environment summarize_eval <- function(name, quosure, ctx, recurse = FALSE) { expr <- quo_get_expr(quosure) ctx$quo_env <- quo_get_env(quosure) - funs_in_expr <- all_funs(expr) + funs_in_expr <- all_funs(expr) if (length(funs_in_expr) == 0) { - # Skip if it is a scalar or field ref - ctx$results[[name]] <- arrow_eval_or_stop(quosure, ctx$mask) + # If it is a scalar or field ref, no special handling required + ctx$aggregations[[name]] <- arrow_eval_or_stop(quosure, ctx$mask) return() } + # Start inspecting the expr to see what aggregations it involves agg_funs <- names(agg_funcs) outer_agg <- funs_in_expr[1] %in% agg_funs inner_agg <- funs_in_expr[-1] %in% agg_funs @@ -127,32 +144,51 @@ summarize_eval <- function(name, quosure, ctx, recurse = FALSE) { expr <- extract_aggregations(expr, ctx) } - inner_agg_exprs <- all_vars(expr) %in% names(ctx$results) + # By this point, there are no more aggregation functions in expr + # except for possibly the outer function call: + # they've all been pulled out to ctx$aggregations, and in their place in expr + # there are variable names, which will correspond to field refs in the + # query object after aggregation and collapse(). + # So if we want to know if there are any aggregations inside expr, + # we have to look for them by their new var names + inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations) if (outer_agg) { - # This just works by normal arrow_eval, unless there's a mix of aggs and + # This is something like agg(fun(x, y) + # It just works by normal arrow_eval, unless there's a mix of aggs and # columns in the original data like agg(fun(x, agg(x))) # (but that will have been caught in extract_aggregations()) - ctx$results[[name]] <- arrow_eval_or_stop(quosure, ctx$mask) + ctx$aggregations[[name]] <- arrow_eval_or_stop( + as_quosure(expr, ctx$quo_env), + ctx$mask + ) return() } else if (all(inner_agg_exprs)) { - # fun(agg(x), ...) + # fun(agg(x), agg(y)) # So based on the aggregations that have been extracted, mutate after - mutate_mask <- arrow_mask(list(selected_columns = make_field_refs(names(ctx$results)))) - ctx$post_mutate[[name]] <- arrow_eval_or_stop(as_quosure(expr, ctx$quo_env), mutate_mask) + mutate_mask <- arrow_mask( + list(selected_columns = make_field_refs(names(ctx$aggregations))) + ) + ctx$post_mutate[[name]] <- arrow_eval_or_stop( + as_quosure(expr, ctx$quo_env), + mutate_mask + ) return() } - # !outer_agg && !all(inner_agg_exprs) - # This is fun(x, agg(y)), which really should be in mutate() - # but summarize() allows it. (See also below in extract_aggregations) - # TODO: support in ARROW-13926 - # (This could also be fun(x, y), which would work in mutate() already - # if it were the only expression) - # TODO: this message should probably also say "not supported in summarize()" - # since some of these expressions may be legal elsewhere - stop(handle_arrow_not_supported(quo_get_expr(quosure), as_label(quo_get_expr(quosure))), call. = FALSE) + + # Backstop for any other odd cases, like fun(x, y) (i.e. no aggregation), + # or aggregation functions that aren't supported in Arrow (not in agg_funcs) + stop( + handle_arrow_not_supported( + quo_get_expr(quosure), + as_label(quo_get_expr(quosure)) + ), + call. = FALSE + ) } +# This function recurses through expr, pulls out any aggregation expressions, +# and inserts a variable name (field ref) in place of the aggregation extract_aggregations <- function(expr, ctx) { # Keep the input in case we need to raise an error message with it original_expr <- expr @@ -164,17 +200,26 @@ extract_aggregations <- function(expr, ctx) { expr[-1] <- lapply(expr[-1], extract_aggregations, ctx) } if (funs[1] %in% names(agg_funcs)) { - inner_agg_exprs <- all_vars(expr) %in% names(ctx$results) + inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations) if (any(inner_agg_exprs) & !all(inner_agg_exprs)) { # We can't aggregate over a combination of dataset columns and other # aggregations (e.g. sum(x - mean(x))) - # TODO: Add "because" arg to explain _why_ it's not supported? # TODO: support in ARROW-13926 - stop(handle_arrow_not_supported(original_expr, as_label(original_expr)), call. = FALSE) + # TODO: Add "because" arg to explain _why_ it's not supported? + # TODO: this message could also say "not supported in summarize()" + # since some of these expressions may be legal elsewhere + stop( + handle_arrow_not_supported(original_expr, as_label(original_expr)), + call. = FALSE + ) } - tmpname <- paste0("..temp", length(ctx$results)) - ctx$results[[tmpname]] <- arrow_eval_or_stop(as_quosure(expr, ctx$quo_env), ctx$mask) + # We have an aggregation expression with no other aggregations inside it, + # so arrow_eval the expression on the data and give it a ..temp name prefix, + # then insert that name (symbol) back into the expression so that we can + # mutate() on the result of the aggregation and reference this field. + tmpname <- paste0("..temp", length(ctx$aggregations)) + ctx$aggregations[[tmpname]] <- arrow_eval_or_stop(as_quosure(expr, ctx$quo_env), ctx$mask) expr <- as.symbol(tmpname) } expr diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index b7f62cd0422..fa4bffe30d7 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -375,6 +375,31 @@ test_that("Expressions on aggregations", { collect(), tbl ) + + # Make sure order of columns in result is correct + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + any_lgl = any(lgl), + some = any_lgl & !all(lgl), + n() + ) %>% + collect(), + tbl + ) + + # Aggregate on an aggregate (trivial but dplyr allows) + skip("Not supported") + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + any_lgl = any(any(lgl)) + ) %>% + collect(), + tbl + ) }) test_that("Summarize with 0 arguments", { @@ -399,6 +424,16 @@ test_that("Not (yet) supported: implicit join", { tbl, warning = "Expression sum\\(\\(dbl - mean\\(dbl\\)\\)\\^2\\) not supported in Arrow; pulling data into R" ) + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + sum(dbl - mean(dbl)) + ) %>% + collect(), + tbl, + warning = "Expression sum\\(dbl - mean\\(dbl\\)\\) not supported in Arrow; pulling data into R" + ) expect_dplyr_equal( input %>% group_by(some_grouping) %>% @@ -420,6 +455,8 @@ test_that("Not (yet) supported: implicit join", { tbl, warning = "Expression dbl - mean\\(dbl\\) not supported in Arrow; pulling data into R" ) + + # This one could possibly be supported--in mutate() expect_dplyr_equal( input %>% group_by(some_grouping) %>% From 8d037e842d47ba115e4e3f345fa97a82a8e51c4e Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 9 Sep 2021 19:57:33 -0400 Subject: [PATCH 5/5] Subvert the linter Co-authored-by: Jonathan Keane --- r/R/dplyr-summarize.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index fbc237446dc..459b7435a87 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ -80,6 +80,7 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) { out <- collapse.arrow_dplyr_query(.data) # The expressions may have been translated into # "first, aggregate, then transform the result further" + # nolint start # For example, # summarize(mean = sum(x) / n()) # is effectively implemented as @@ -87,6 +88,7 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) { # mutate(mean = ..temp0 / ..temp1) %>% # select(-starts_with("..temp")) # If this is the case, there will be expressions in post_mutate + # nolint end if (length(ctx$post_mutate)) { # Append post_mutate, and make sure order is correct # according to input exprs (also dropping ..temp columns) @@ -164,7 +166,7 @@ summarize_eval <- function(name, quosure, ctx, recurse = FALSE) { ) return() } else if (all(inner_agg_exprs)) { - # fun(agg(x), agg(y)) + # Something like: fun(agg(x), agg(y)) # So based on the aggregations that have been extracted, mutate after mutate_mask <- arrow_mask( list(selected_columns = make_field_refs(names(ctx$aggregations)))