Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions r/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ importFrom(rlang,"%||%")
importFrom(rlang,.data)
importFrom(rlang,abort)
importFrom(rlang,as_label)
importFrom(rlang,as_quosure)
importFrom(rlang,caller_env)
importFrom(rlang,dots_n)
importFrom(rlang,enexpr)
Expand All @@ -325,6 +326,7 @@ importFrom(rlang,is_quosure)
importFrom(rlang,list2)
importFrom(rlang,new_data_mask)
importFrom(rlang,new_environment)
importFrom(rlang,quo_get_env)
importFrom(rlang,quo_get_expr)
importFrom(rlang,quo_is_null)
importFrom(rlang,quo_name)
Expand Down
4 changes: 2 additions & 2 deletions r/R/arrow-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
#' @importFrom assertthat assert_that is.string
#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos
#' @importFrom rlang eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec
#' @importFrom rlang is_bare_character quo_get_expr quo_set_expr .data seq2 is_quosure enexpr enexprs
#' @importFrom rlang expr caller_env is_character quo_name
#' @importFrom rlang is_bare_character quo_get_expr quo_get_env quo_set_expr .data seq2
#' @importFrom rlang expr caller_env is_character quo_name is_quosure enexpr enexprs as_quosure
#' @importFrom tidyselect vars_pull vars_rename vars_select eval_select
#' @useDynLib arrow, .registration = TRUE
#' @keywords internal
Expand Down
1 change: 1 addition & 0 deletions r/R/dplyr-eval.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ arrow_eval <- function(expr, mask) {
# else, for things not supported by Arrow return a "try-error",
# which we'll handle differently
msg <- conditionMessage(e)
if (getOption("arrow.debug", FALSE)) print(msg)
patterns <- .cache$i18ized_error_pattern
if (is.null(patterns)) {
patterns <- i18ize_error_messages()
Expand Down
177 changes: 161 additions & 16 deletions r/R/dplyr-summarize.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ summarise.arrow_dplyr_query <- function(.data, ..., .engine = c("arrow", "duckdb
unlist(lapply(exprs, all.vars)), # vars referenced in summarise
dplyr::group_vars(.data) # vars needed for grouping
))
.data <- dplyr::select(.data, vars_to_keep)
# If exprs rely on the results of previous exprs
# (total = sum(x), mean = total / n())
# then not all vars will correspond to columns in the data,
# so don't try to select() them (use intersect() to exclude them)
# Note that this select() isn't useful for the Arrow summarize implementation
# because it will effectively project to keep what it needs anyway,
# but the duckdb and data.frame fallback versions do benefit from select here
.data <- dplyr::select(.data, intersect(vars_to_keep, names(.data)))
if (match.arg(.engine) == "duckdb") {
dplyr::summarise(to_duckdb(.data), ...)
} else {
Expand All @@ -42,33 +49,65 @@ summarise.arrow_dplyr_query <- function(.data, ..., .engine = c("arrow", "duckdb
}
summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query

# This is the Arrow summarize implementation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great — finding these can be tricky in this code sometimes

do_arrow_summarize <- function(.data, ..., .groups = NULL) {
if (!is.null(.groups)) {
# ARROW-13550
abort("`summarize()` with `.groups` argument not supported in Arrow")
}
exprs <- ensure_named_exprs(quos(...))

mask <- arrow_mask(.data, aggregation = TRUE)

results <- empty_named_list()
# Create a stateful environment for recording our evaluated expressions
# It's more complex than other places because a single summarize() expr
# may result in multiple query nodes (Aggregate, Project),
# and we have to walk through the expressions to disentangle them.
ctx <- env(
mask = arrow_mask(.data, aggregation = TRUE),
aggregations = empty_named_list(),
post_mutate = empty_named_list()
)
for (i in seq_along(exprs)) {
# Iterate over the indices and not the names because names may be repeated
# (which overwrites the previous name)
new_var <- names(exprs)[i]
results[[new_var]] <- arrow_eval(exprs[[i]], mask)
if (inherits(results[[new_var]], "try-error")) {
msg <- handle_arrow_not_supported(
results[[new_var]],
as_label(exprs[[i]])
)
stop(msg, call. = FALSE)
}
summarize_eval(names(exprs)[i], exprs[[i]], ctx)
}

.data$aggregations <- results
# TODO: should in-memory query evaluate eagerly?
collapse.arrow_dplyr_query(.data)
# Apply the results to the .data object.
# First, the aggregations
.data$aggregations <- ctx$aggregations
# Then collapse the query so that the resulting query object can have
# additional operations applied to it
out <- collapse.arrow_dplyr_query(.data)
# The expressions may have been translated into
# "first, aggregate, then transform the result further"
# nolint start
# For example,
# summarize(mean = sum(x) / n())
# is effectively implemented as
# summarize(..temp0 = sum(x), ..temp1 = n()) %>%
# mutate(mean = ..temp0 / ..temp1) %>%
# select(-starts_with("..temp"))
# If this is the case, there will be expressions in post_mutate
# nolint end
if (length(ctx$post_mutate)) {
# Append post_mutate, and make sure order is correct
# according to input exprs (also dropping ..temp columns)
out$selected_columns <- c(
out$selected_columns,
ctx$post_mutate
)[c(.data$group_by_vars, names(exprs))]
}
out
}

arrow_eval_or_stop <- function(expr, mask) {
# TODO: change arrow_eval error handling behavior?
out <- arrow_eval(expr, mask)
if (inherits(out, "try-error")) {
msg <- handle_arrow_not_supported(out, as_label(expr))
stop(msg, call. = FALSE)
}
out
}

summarize_projection <- function(.data) {
Expand All @@ -81,3 +120,109 @@ summarize_projection <- function(.data) {
format_aggregation <- function(x) {
paste0(x$fun, "(", x$data$ToString(), ")")
}

# This function handles each summarize expression and turns it into the
# appropriate combination of (1) aggregations (possibly temporary) and
# (2) post-aggregation transformations (mutate)
# The function returns nothing: it assigns into the `ctx` environment
summarize_eval <- function(name, quosure, ctx, recurse = FALSE) {
expr <- quo_get_expr(quosure)
ctx$quo_env <- quo_get_env(quosure)

funs_in_expr <- all_funs(expr)
if (length(funs_in_expr) == 0) {
# If it is a scalar or field ref, no special handling required
ctx$aggregations[[name]] <- arrow_eval_or_stop(quosure, ctx$mask)
return()
}

# Start inspecting the expr to see what aggregations it involves
agg_funs <- names(agg_funcs)
outer_agg <- funs_in_expr[1] %in% agg_funs
inner_agg <- funs_in_expr[-1] %in% agg_funs

# First, pull out any aggregations wrapped in other function calls
if (any(inner_agg)) {
expr <- extract_aggregations(expr, ctx)
}

# By this point, there are no more aggregation functions in expr
# except for possibly the outer function call:
# they've all been pulled out to ctx$aggregations, and in their place in expr
# there are variable names, which will correspond to field refs in the
# query object after aggregation and collapse().
# So if we want to know if there are any aggregations inside expr,
# we have to look for them by their new var names
inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)

if (outer_agg) {
# This is something like agg(fun(x, y)
# It just works by normal arrow_eval, unless there's a mix of aggs and
# columns in the original data like agg(fun(x, agg(x)))
# (but that will have been caught in extract_aggregations())
ctx$aggregations[[name]] <- arrow_eval_or_stop(
as_quosure(expr, ctx$quo_env),
ctx$mask
)
return()
} else if (all(inner_agg_exprs)) {
# Something like: fun(agg(x), agg(y))
# So based on the aggregations that have been extracted, mutate after
mutate_mask <- arrow_mask(
list(selected_columns = make_field_refs(names(ctx$aggregations)))
)
ctx$post_mutate[[name]] <- arrow_eval_or_stop(
as_quosure(expr, ctx$quo_env),
mutate_mask
)
return()
}

# Backstop for any other odd cases, like fun(x, y) (i.e. no aggregation),
# or aggregation functions that aren't supported in Arrow (not in agg_funcs)
stop(
handle_arrow_not_supported(
quo_get_expr(quosure),
as_label(quo_get_expr(quosure))
),
call. = FALSE
)
}

# This function recurses through expr, pulls out any aggregation expressions,
# and inserts a variable name (field ref) in place of the aggregation
extract_aggregations <- function(expr, ctx) {
# Keep the input in case we need to raise an error message with it
original_expr <- expr
funs <- all_funs(expr)
if (length(funs) == 0) {
return(expr)
} else if (length(funs) > 1) {
# Recurse more
expr[-1] <- lapply(expr[-1], extract_aggregations, ctx)
}
if (funs[1] %in% names(agg_funcs)) {
inner_agg_exprs <- all_vars(expr) %in% names(ctx$aggregations)
if (any(inner_agg_exprs) & !all(inner_agg_exprs)) {
# We can't aggregate over a combination of dataset columns and other
# aggregations (e.g. sum(x - mean(x)))
# TODO: support in ARROW-13926
# TODO: Add "because" arg to explain _why_ it's not supported?
# TODO: this message could also say "not supported in summarize()"
# since some of these expressions may be legal elsewhere
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This last TODO I think is important — I think anyone who gets a not supported message will assume that expression is not supported in Arrow at all anywhere.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah maybe so, but all of the examples I can think of are just bad form (like summarize(new_column = dbl - int), which really should be expressed in mutate()). Anyway I don't think this PR is the last word on any of this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stop(
handle_arrow_not_supported(original_expr, as_label(original_expr)),
call. = FALSE
)
}

# We have an aggregation expression with no other aggregations inside it,
# so arrow_eval the expression on the data and give it a ..temp name prefix,
# then insert that name (symbol) back into the expression so that we can
# mutate() on the result of the aggregation and reference this field.
tmpname <- paste0("..temp", length(ctx$aggregations))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is vanishingly rare, but could we at this point check for any fields named in tmpname here? I can't imagine anyone would have one, but it would be better to error. Or would this error later without proactive checking?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only way this would fail that I can think of is that if you deliberately did something like summarize(..temp1 = mean(x), avg = sum(x) / n()), and, well, can't you just not do that? Open to other suggestions, the names don't matter as long as they're unique and valid as names (so you can't just use the expression like "sum(x)" as the name).

ctx$aggregations[[tmpname]] <- arrow_eval_or_stop(as_quosure(expr, ctx$quo_env), ctx$mask)
expr <- as.symbol(tmpname)
}
expr
}
17 changes: 3 additions & 14 deletions r/R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,10 @@ r_symbolic_constants <- c(
"NA_integer_", "NA_real_", "NA_complex_", "NA_character_"
)

is_function <- function(expr, name) {
if (!is.call(expr)) {
return(FALSE)
} else {
if (deparse1(expr[[1]]) == name) {
return(TRUE)
}
out <- lapply(expr, is_function, name)
}
any(vapply(out, isTRUE, TRUE))
}

all_funs <- function(expr) {
names <- all_names(expr)
names[vapply(names, function(name) is_function(expr, name), TRUE)]
# Don't use setdiff so that we preserve duplicates
out <- all.names(expr)
out[!(out %in% all.vars(expr))]
}

all_vars <- function(expr) {
Expand Down
97 changes: 94 additions & 3 deletions r/tests/testthat/test-dplyr-summarize.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ test_that("Group by var on dataset", {
})

test_that("n()", {
withr::local_options(list(arrow.debug = TRUE))
expect_dplyr_equal(
input %>%
summarize(counts = n()) %>%
Expand Down Expand Up @@ -350,22 +349,57 @@ test_that("Expressions on aggregations", {
any = any(lgl),
all = all(lgl)
) %>%
compute() %>%
ungroup() %>% # TODO: loosen the restriction on mutate after group_by
mutate(some = any & !all) %>%
select(some_grouping, some) %>%
collect(),
tbl
)
# More concisely:
skip("TODO: ARROW-13778")
expect_dplyr_equal(
input %>%
group_by(some_grouping) %>%
summarize(any(lgl) & !all(lgl)) %>%
collect(),
tbl
)

# Save one of the aggregates first
expect_dplyr_equal(
input %>%
group_by(some_grouping) %>%
summarize(
any_lgl = any(lgl),
some = any_lgl & !all(lgl)
) %>%
collect(),
tbl
)

# Make sure order of columns in result is correct
expect_dplyr_equal(
input %>%
group_by(some_grouping) %>%
summarize(
any_lgl = any(lgl),
some = any_lgl & !all(lgl),
n()
) %>%
collect(),
tbl
)

# Aggregate on an aggregate (trivial but dplyr allows)
skip("Not supported")
expect_dplyr_equal(
input %>%
group_by(some_grouping) %>%
summarize(
any_lgl = any(any(lgl))
) %>%
collect(),
tbl
)
})

test_that("Summarize with 0 arguments", {
Expand All @@ -377,3 +411,60 @@ test_that("Summarize with 0 arguments", {
tbl
)
})

test_that("Not (yet) supported: implicit join", {
withr::local_options(list(arrow.debug = TRUE))
expect_dplyr_equal(
input %>%
group_by(some_grouping) %>%
summarize(
sum((dbl - mean(dbl))^2)
) %>%
collect(),
tbl,
warning = "Expression sum\\(\\(dbl - mean\\(dbl\\)\\)\\^2\\) not supported in Arrow; pulling data into R"
)
expect_dplyr_equal(
input %>%
group_by(some_grouping) %>%
summarize(
sum(dbl - mean(dbl))
) %>%
collect(),
tbl,
warning = "Expression sum\\(dbl - mean\\(dbl\\)\\) not supported in Arrow; pulling data into R"
)
expect_dplyr_equal(
input %>%
group_by(some_grouping) %>%
summarize(
sqrt(sum((dbl - mean(dbl))^2) / (n() - 1L))
) %>%
collect(),
tbl,
warning = "Expression sum\\(\\(dbl - mean\\(dbl\\)\\)\\^2\\) not supported in Arrow; pulling data into R"
)

expect_dplyr_equal(
input %>%
group_by(some_grouping) %>%
summarize(
dbl - mean(dbl)
) %>%
collect(),
tbl,
warning = "Expression dbl - mean\\(dbl\\) not supported in Arrow; pulling data into R"
)

# This one could possibly be supported--in mutate()
expect_dplyr_equal(
input %>%
group_by(some_grouping) %>%
summarize(
dbl - int
) %>%
collect(),
tbl,
warning = "Expression dbl - int not supported in Arrow; pulling data into R"
)
})