diff --git a/r/NAMESPACE b/r/NAMESPACE index 5164e7c9f20..cabd8ffb5f8 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -306,6 +306,7 @@ importFrom(rlang,"%||%") importFrom(rlang,.data) importFrom(rlang,abort) importFrom(rlang,as_label) +importFrom(rlang,as_quosure) importFrom(rlang,caller_env) importFrom(rlang,dots_n) importFrom(rlang,enexpr) @@ -325,6 +326,7 @@ importFrom(rlang,is_quosure) importFrom(rlang,list2) importFrom(rlang,new_data_mask) importFrom(rlang,new_environment) +importFrom(rlang,quo_get_env) importFrom(rlang,quo_get_expr) importFrom(rlang,quo_is_null) importFrom(rlang,quo_name) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index c09b8f05319..e41fe58d696 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -21,8 +21,8 @@ #' @importFrom assertthat assert_that is.string #' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos #' @importFrom rlang eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec -#' @importFrom rlang is_bare_character quo_get_expr quo_set_expr .data seq2 is_quosure enexpr enexprs -#' @importFrom rlang expr caller_env is_character quo_name +#' @importFrom rlang is_bare_character quo_get_expr quo_get_env quo_set_expr .data seq2 +#' @importFrom rlang expr caller_env is_character quo_name is_quosure enexpr enexprs as_quosure #' @importFrom tidyselect vars_pull vars_rename vars_select eval_select #' @useDynLib arrow, .registration = TRUE #' @keywords internal diff --git a/r/R/dplyr-eval.R b/r/R/dplyr-eval.R index 89eec94e4d2..c65ed605834 100644 --- a/r/R/dplyr-eval.R +++ b/r/R/dplyr-eval.R @@ -28,6 +28,7 @@ arrow_eval <- function(expr, mask) { # else, for things not supported by Arrow return a "try-error", # which we'll handle differently msg <- conditionMessage(e) + if (getOption("arrow.debug", FALSE)) print(msg) patterns <- .cache$i18ized_error_pattern if (is.null(patterns)) { patterns <- i18ize_error_messages() diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index 3a6c76e28cb..459b7435a87 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ -27,7 +27,14 @@ summarise.arrow_dplyr_query <- function(.data, ..., .engine = c("arrow", "duckdb unlist(lapply(exprs, all.vars)), # vars referenced in summarise dplyr::group_vars(.data) # vars needed for grouping )) - .data <- dplyr::select(.data, vars_to_keep) + # If exprs rely on the results of previous exprs + # (total = sum(x), mean = total / n()) + # then not all vars will correspond to columns in the data, + # so don't try to select() them (use intersect() to exclude them) + # Note that this select() isn't useful for the Arrow summarize implementation + # because it will effectively project to keep what it needs anyway, + # but the duckdb and data.frame fallback versions do benefit from select here + .data <- dplyr::select(.data, intersect(vars_to_keep, names(.data))) if (match.arg(.engine) == "duckdb") { dplyr::summarise(to_duckdb(.data), ...) } else { @@ -42,6 +49,7 @@ summarise.arrow_dplyr_query <- function(.data, ..., .engine = c("arrow", "duckdb } summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query +# This is the Arrow summarize implementation do_arrow_summarize <- function(.data, ..., .groups = NULL) { if (!is.null(.groups)) { # ARROW-13550 @@ -49,26 +57,57 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) { } exprs <- ensure_named_exprs(quos(...)) - mask <- arrow_mask(.data, aggregation = TRUE) - - results <- empty_named_list() + # Create a stateful environment for recording our evaluated expressions + # It's more complex than other places because a single summarize() expr + # may result in multiple query nodes (Aggregate, Project), + # and we have to walk through the expressions to disentangle them. + ctx <- env( + mask = arrow_mask(.data, aggregation = TRUE), + aggregations = empty_named_list(), + post_mutate = empty_named_list() + ) for (i in seq_along(exprs)) { # Iterate over the indices and not the names because names may be repeated # (which overwrites the previous name) - new_var <- names(exprs)[i] - results[[new_var]] <- arrow_eval(exprs[[i]], mask) - if (inherits(results[[new_var]], "try-error")) { - msg <- handle_arrow_not_supported( - results[[new_var]], - as_label(exprs[[i]]) - ) - stop(msg, call. = FALSE) - } + summarize_eval(names(exprs)[i], exprs[[i]], ctx) } - .data$aggregations <- results - # TODO: should in-memory query evaluate eagerly? - collapse.arrow_dplyr_query(.data) + # Apply the results to the .data object. + # First, the aggregations + .data$aggregations <- ctx$aggregations + # Then collapse the query so that the resulting query object can have + # additional operations applied to it + out <- collapse.arrow_dplyr_query(.data) + # The expressions may have been translated into + # "first, aggregate, then transform the result further" + # nolint start + # For example, + # summarize(mean = sum(x) / n()) + # is effectively implemented as + # summarize(..temp0 = sum(x), ..temp1 = n()) %>% + # mutate(mean = ..temp0 / ..temp1) %>% + # select(-starts_with("..temp")) + # If this is the case, there will be expressions in post_mutate + # nolint end + if (length(ctx$post_mutate)) { + # Append post_mutate, and make sure order is correct + # according to input exprs (also dropping ..temp columns) + out$selected_columns <- c( + out$selected_columns, + ctx$post_mutate + )[c(.data$group_by_vars, names(exprs))] + } + out +} + +arrow_eval_or_stop <- function(expr, mask) { + # TODO: change arrow_eval error handling behavior? + out <- arrow_eval(expr, mask) + if (inherits(out, "try-error")) { + msg <- handle_arrow_not_supported(out, as_label(expr)) + stop(msg, call. = FALSE) + } + out } summarize_projection <- function(.data) { @@ -81,3 +120,109 @@ summarize_projection <- function(.data) { format_aggregation <- function(x) { paste0(x$fun, "(", x$data$ToString(), ")") } + +# This function handles each summarize expression and turns it into the +# appropriate combination of (1) aggregations (possibly temporary) and +# (2) post-aggregation transformations (mutate) +# The function returns nothing: it assigns into the `ctx` environment +summarize_eval <- function(name, quosure, ctx, recurse = FALSE) { + expr <- quo_get_expr(quosure) + ctx$quo_env <- quo_get_env(quosure) + + funs_in_expr <- all_funs(expr) + if (length(funs_in_expr) == 0) { + # If it is a scalar or field ref, no special handling required + ctx$aggregations[[name]] <- arrow_eval_or_stop(quosure, ctx$mask) + return() + } + + # Start inspecting the expr to see what aggregations it involves + agg_funs <- names(agg_funcs) + outer_agg <- funs_in_expr[1] %in% agg_funs + inner_agg <- funs_in_expr[-1] %in% agg_funs + + # First, pull out any aggregations wrapped in other function calls + if (any(inner_agg)) { + expr <- extract_aggregations(expr, ctx) + } + + # By this point, there are no more aggregation functions in expr + # except for possibly the outer function call: + # they've all been pulled out to ctx$aggregations, and in their place in expr + # there are variable names, which will correspond to field refs in the + # query object after aggregation and collapse(). + # So if we want to know if there are any aggregations inside expr, + # we have to look for them by their new var names + inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations) + + if (outer_agg) { + # This is something like agg(fun(x, y) + # It just works by normal arrow_eval, unless there's a mix of aggs and + # columns in the original data like agg(fun(x, agg(x))) + # (but that will have been caught in extract_aggregations()) + ctx$aggregations[[name]] <- arrow_eval_or_stop( + as_quosure(expr, ctx$quo_env), + ctx$mask + ) + return() + } else if (all(inner_agg_exprs)) { + # Something like: fun(agg(x), agg(y)) + # So based on the aggregations that have been extracted, mutate after + mutate_mask <- arrow_mask( + list(selected_columns = make_field_refs(names(ctx$aggregations))) + ) + ctx$post_mutate[[name]] <- arrow_eval_or_stop( + as_quosure(expr, ctx$quo_env), + mutate_mask + ) + return() + } + + # Backstop for any other odd cases, like fun(x, y) (i.e. no aggregation), + # or aggregation functions that aren't supported in Arrow (not in agg_funcs) + stop( + handle_arrow_not_supported( + quo_get_expr(quosure), + as_label(quo_get_expr(quosure)) + ), + call. = FALSE + ) +} + +# This function recurses through expr, pulls out any aggregation expressions, +# and inserts a variable name (field ref) in place of the aggregation +extract_aggregations <- function(expr, ctx) { + # Keep the input in case we need to raise an error message with it + original_expr <- expr + funs <- all_funs(expr) + if (length(funs) == 0) { + return(expr) + } else if (length(funs) > 1) { + # Recurse more + expr[-1] <- lapply(expr[-1], extract_aggregations, ctx) + } + if (funs[1] %in% names(agg_funcs)) { + inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations) + if (any(inner_agg_exprs) & !all(inner_agg_exprs)) { + # We can't aggregate over a combination of dataset columns and other + # aggregations (e.g. sum(x - mean(x))) + # TODO: support in ARROW-13926 + # TODO: Add "because" arg to explain _why_ it's not supported? + # TODO: this message could also say "not supported in summarize()" + # since some of these expressions may be legal elsewhere + stop( + handle_arrow_not_supported(original_expr, as_label(original_expr)), + call. = FALSE + ) + } + + # We have an aggregation expression with no other aggregations inside it, + # so arrow_eval the expression on the data and give it a ..temp name prefix, + # then insert that name (symbol) back into the expression so that we can + # mutate() on the result of the aggregation and reference this field. + tmpname <- paste0("..temp", length(ctx$aggregations)) + ctx$aggregations[[tmpname]] <- arrow_eval_or_stop(as_quosure(expr, ctx$quo_env), ctx$mask) + expr <- as.symbol(tmpname) + } + expr +} diff --git a/r/R/util.R b/r/R/util.R index 5958b0b3111..4811d5dbfbb 100644 --- a/r/R/util.R +++ b/r/R/util.R @@ -58,21 +58,10 @@ r_symbolic_constants <- c( "NA_integer_", "NA_real_", "NA_complex_", "NA_character_" ) -is_function <- function(expr, name) { - if (!is.call(expr)) { - return(FALSE) - } else { - if (deparse1(expr[[1]]) == name) { - return(TRUE) - } - out <- lapply(expr, is_function, name) - } - any(vapply(out, isTRUE, TRUE)) -} - all_funs <- function(expr) { - names <- all_names(expr) - names[vapply(names, function(name) is_function(expr, name), TRUE)] + # Don't use setdiff so that we preserve duplicates + out <- all.names(expr) + out[!(out %in% all.vars(expr))] } all_vars <- function(expr) { diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index 18596fcf30c..fa4bffe30d7 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -162,7 +162,6 @@ test_that("Group by var on dataset", { }) test_that("n()", { - withr::local_options(list(arrow.debug = TRUE)) expect_dplyr_equal( input %>% summarize(counts = n()) %>% @@ -350,7 +349,6 @@ test_that("Expressions on aggregations", { any = any(lgl), all = all(lgl) ) %>% - compute() %>% ungroup() %>% # TODO: loosen the restriction on mutate after group_by mutate(some = any & !all) %>% select(some_grouping, some) %>% @@ -358,7 +356,6 @@ test_that("Expressions on aggregations", { tbl ) # More concisely: - skip("TODO: ARROW-13778") expect_dplyr_equal( input %>% group_by(some_grouping) %>% @@ -366,6 +363,43 @@ test_that("Expressions on aggregations", { collect(), tbl ) + + # Save one of the aggregates first + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + any_lgl = any(lgl), + some = any_lgl & !all(lgl) + ) %>% + collect(), + tbl + ) + + # Make sure order of columns in result is correct + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + any_lgl = any(lgl), + some = any_lgl & !all(lgl), + n() + ) %>% + collect(), + tbl + ) + + # Aggregate on an aggregate (trivial but dplyr allows) + skip("Not supported") + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + any_lgl = any(any(lgl)) + ) %>% + collect(), + tbl + ) }) test_that("Summarize with 0 arguments", { @@ -377,3 +411,60 @@ test_that("Summarize with 0 arguments", { tbl ) }) + +test_that("Not (yet) supported: implicit join", { + withr::local_options(list(arrow.debug = TRUE)) + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + sum((dbl - mean(dbl))^2) + ) %>% + collect(), + tbl, + warning = "Expression sum\\(\\(dbl - mean\\(dbl\\)\\)\\^2\\) not supported in Arrow; pulling data into R" + ) + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + sum(dbl - mean(dbl)) + ) %>% + collect(), + tbl, + warning = "Expression sum\\(dbl - mean\\(dbl\\)\\) not supported in Arrow; pulling data into R" + ) + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + sqrt(sum((dbl - mean(dbl))^2) / (n() - 1L)) + ) %>% + collect(), + tbl, + warning = "Expression sum\\(\\(dbl - mean\\(dbl\\)\\)\\^2\\) not supported in Arrow; pulling data into R" + ) + + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + dbl - mean(dbl) + ) %>% + collect(), + tbl, + warning = "Expression dbl - mean\\(dbl\\) not supported in Arrow; pulling data into R" + ) + + # This one could possibly be supported--in mutate() + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + dbl - int + ) %>% + collect(), + tbl, + warning = "Expression dbl - int not supported in Arrow; pulling data into R" + ) +})