-
Notifications
You must be signed in to change notification settings - Fork 4k
ARROW-13778: [R] Handle complex summarize expressions #11108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
529f80f
0abb74e
9d58040
acb0461
8d037e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,7 +27,14 @@ summarise.arrow_dplyr_query <- function(.data, ..., .engine = c("arrow", "duckdb | |
| unlist(lapply(exprs, all.vars)), # vars referenced in summarise | ||
| dplyr::group_vars(.data) # vars needed for grouping | ||
| )) | ||
| .data <- dplyr::select(.data, vars_to_keep) | ||
| # If exprs rely on the results of previous exprs | ||
| # (total = sum(x), mean = total / n()) | ||
| # then not all vars will correspond to columns in the data, | ||
| # so don't try to select() them (use intersect() to exclude them) | ||
| # Note that this select() isn't useful for the Arrow summarize implementation | ||
| # because it will effectively project to keep what it needs anyway, | ||
| # but the duckdb and data.frame fallback versions do benefit from select here | ||
| .data <- dplyr::select(.data, intersect(vars_to_keep, names(.data))) | ||
| if (match.arg(.engine) == "duckdb") { | ||
| dplyr::summarise(to_duckdb(.data), ...) | ||
| } else { | ||
|
|
@@ -42,33 +49,65 @@ summarise.arrow_dplyr_query <- function(.data, ..., .engine = c("arrow", "duckdb | |
| } | ||
| summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query | ||
|
|
||
| # This is the Arrow summarize implementation | ||
| do_arrow_summarize <- function(.data, ..., .groups = NULL) { | ||
| if (!is.null(.groups)) { | ||
| # ARROW-13550 | ||
| abort("`summarize()` with `.groups` argument not supported in Arrow") | ||
| } | ||
| exprs <- ensure_named_exprs(quos(...)) | ||
|
|
||
| mask <- arrow_mask(.data, aggregation = TRUE) | ||
|
|
||
| results <- empty_named_list() | ||
| # Create a stateful environment for recording our evaluated expressions | ||
| # It's more complex than other places because a single summarize() expr | ||
| # may result in multiple query nodes (Aggregate, Project), | ||
| # and we have to walk through the expressions to disentangle them. | ||
| ctx <- env( | ||
| mask = arrow_mask(.data, aggregation = TRUE), | ||
| aggregations = empty_named_list(), | ||
| post_mutate = empty_named_list() | ||
| ) | ||
| for (i in seq_along(exprs)) { | ||
| # Iterate over the indices and not the names because names may be repeated | ||
| # (which overwrites the previous name) | ||
| new_var <- names(exprs)[i] | ||
| results[[new_var]] <- arrow_eval(exprs[[i]], mask) | ||
| if (inherits(results[[new_var]], "try-error")) { | ||
| msg <- handle_arrow_not_supported( | ||
| results[[new_var]], | ||
| as_label(exprs[[i]]) | ||
| ) | ||
| stop(msg, call. = FALSE) | ||
| } | ||
| summarize_eval(names(exprs)[i], exprs[[i]], ctx) | ||
| } | ||
|
|
||
| .data$aggregations <- results | ||
| # TODO: should in-memory query evaluate eagerly? | ||
| collapse.arrow_dplyr_query(.data) | ||
| # Apply the results to the .data object. | ||
| # First, the aggregations | ||
| .data$aggregations <- ctx$aggregations | ||
| # Then collapse the query so that the resulting query object can have | ||
| # additional operations applied to it | ||
| out <- collapse.arrow_dplyr_query(.data) | ||
| # The expressions may have been translated into | ||
| # "first, aggregate, then transform the result further" | ||
| # nolint start | ||
| # For example, | ||
nealrichardson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # summarize(mean = sum(x) / n()) | ||
| # is effectively implemented as | ||
| # summarize(..temp0 = sum(x), ..temp1 = n()) %>% | ||
| # mutate(mean = ..temp0 / ..temp1) %>% | ||
| # select(-starts_with("..temp")) | ||
| # If this is the case, there will be expressions in post_mutate | ||
nealrichardson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # nolint end | ||
| if (length(ctx$post_mutate)) { | ||
| # Append post_mutate, and make sure order is correct | ||
| # according to input exprs (also dropping ..temp columns) | ||
| out$selected_columns <- c( | ||
| out$selected_columns, | ||
| ctx$post_mutate | ||
| )[c(.data$group_by_vars, names(exprs))] | ||
| } | ||
| out | ||
| } | ||
|
|
||
| arrow_eval_or_stop <- function(expr, mask) { | ||
| # TODO: change arrow_eval error handling behavior? | ||
| out <- arrow_eval(expr, mask) | ||
| if (inherits(out, "try-error")) { | ||
| msg <- handle_arrow_not_supported(out, as_label(expr)) | ||
| stop(msg, call. = FALSE) | ||
| } | ||
| out | ||
| } | ||
|
|
||
| summarize_projection <- function(.data) { | ||
|
|
@@ -81,3 +120,109 @@ summarize_projection <- function(.data) { | |
| format_aggregation <- function(x) { | ||
| paste0(x$fun, "(", x$data$ToString(), ")") | ||
| } | ||
|
|
||
| # This function handles each summarize expression and turns it into the | ||
| # appropriate combination of (1) aggregations (possibly temporary) and | ||
| # (2) post-aggregation transformations (mutate) | ||
| # The function returns nothing: it assigns into the `ctx` environment | ||
| summarize_eval <- function(name, quosure, ctx, recurse = FALSE) { | ||
| expr <- quo_get_expr(quosure) | ||
| ctx$quo_env <- quo_get_env(quosure) | ||
|
|
||
| funs_in_expr <- all_funs(expr) | ||
| if (length(funs_in_expr) == 0) { | ||
| # If it is a scalar or field ref, no special handling required | ||
| ctx$aggregations[[name]] <- arrow_eval_or_stop(quosure, ctx$mask) | ||
| return() | ||
| } | ||
|
|
||
| # Start inspecting the expr to see what aggregations it involves | ||
| agg_funs <- names(agg_funcs) | ||
| outer_agg <- funs_in_expr[1] %in% agg_funs | ||
| inner_agg <- funs_in_expr[-1] %in% agg_funs | ||
|
|
||
| # First, pull out any aggregations wrapped in other function calls | ||
| if (any(inner_agg)) { | ||
| expr <- extract_aggregations(expr, ctx) | ||
| } | ||
|
|
||
| # By this point, there are no more aggregation functions in expr | ||
| # except for possibly the outer function call: | ||
| # they've all been pulled out to ctx$aggregations, and in their place in expr | ||
| # there are variable names, which will correspond to field refs in the | ||
| # query object after aggregation and collapse(). | ||
| # So if we want to know if there are any aggregations inside expr, | ||
| # we have to look for them by their new var names | ||
| inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations) | ||
|
|
||
| if (outer_agg) { | ||
| # This is something like agg(fun(x, y) | ||
| # It just works by normal arrow_eval, unless there's a mix of aggs and | ||
| # columns in the original data like agg(fun(x, agg(x))) | ||
| # (but that will have been caught in extract_aggregations()) | ||
| ctx$aggregations[[name]] <- arrow_eval_or_stop( | ||
| as_quosure(expr, ctx$quo_env), | ||
| ctx$mask | ||
| ) | ||
| return() | ||
| } else if (all(inner_agg_exprs)) { | ||
| # Something like: fun(agg(x), agg(y)) | ||
| # So based on the aggregations that have been extracted, mutate after | ||
| mutate_mask <- arrow_mask( | ||
| list(selected_columns = make_field_refs(names(ctx$aggregations))) | ||
| ) | ||
| ctx$post_mutate[[name]] <- arrow_eval_or_stop( | ||
| as_quosure(expr, ctx$quo_env), | ||
| mutate_mask | ||
| ) | ||
| return() | ||
| } | ||
|
|
||
| # Backstop for any other odd cases, like fun(x, y) (i.e. no aggregation), | ||
| # or aggregation functions that aren't supported in Arrow (not in agg_funcs) | ||
| stop( | ||
| handle_arrow_not_supported( | ||
| quo_get_expr(quosure), | ||
| as_label(quo_get_expr(quosure)) | ||
| ), | ||
| call. = FALSE | ||
| ) | ||
| } | ||
|
|
||
| # This function recurses through expr, pulls out any aggregation expressions, | ||
| # and inserts a variable name (field ref) in place of the aggregation | ||
| extract_aggregations <- function(expr, ctx) { | ||
| # Keep the input in case we need to raise an error message with it | ||
| original_expr <- expr | ||
| funs <- all_funs(expr) | ||
| if (length(funs) == 0) { | ||
| return(expr) | ||
| } else if (length(funs) > 1) { | ||
| # Recurse more | ||
| expr[-1] <- lapply(expr[-1], extract_aggregations, ctx) | ||
| } | ||
| if (funs[1] %in% names(agg_funcs)) { | ||
| inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations) | ||
| if (any(inner_agg_exprs) & !all(inner_agg_exprs)) { | ||
| # We can't aggregate over a combination of dataset columns and other | ||
| # aggregations (e.g. sum(x - mean(x))) | ||
| # TODO: support in ARROW-13926 | ||
| # TODO: Add "because" arg to explain _why_ it's not supported? | ||
| # TODO: this message could also say "not supported in summarize()" | ||
| # since some of these expressions may be legal elsewhere | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This last TODO I think is important — I think anyone who gets a not supported message will assume that expression is not supported in Arrow at all anywhere.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah maybe so, but all of the examples I can think of are just bad form (like
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| stop( | ||
| handle_arrow_not_supported(original_expr, as_label(original_expr)), | ||
| call. = FALSE | ||
| ) | ||
| } | ||
|
|
||
| # We have an aggregation expression with no other aggregations inside it, | ||
| # so arrow_eval the expression on the data and give it a ..temp name prefix, | ||
| # then insert that name (symbol) back into the expression so that we can | ||
| # mutate() on the result of the aggregation and reference this field. | ||
| tmpname <- paste0("..temp", length(ctx$aggregations)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is vanishingly rare, but could we at this point check for any fields named in
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The only way this would fail that I can think of is that if you deliberately did something like |
||
| ctx$aggregations[[tmpname]] <- arrow_eval_or_stop(as_quosure(expr, ctx$quo_env), ctx$mask) | ||
| expr <- as.symbol(tmpname) | ||
| } | ||
| expr | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great — finding these can be tricky in this code sometimes