diff --git a/r/NAMESPACE b/r/NAMESPACE index 61ca5d8fdc4..f89a7352ec9 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -323,6 +323,7 @@ importFrom(rlang,is_bare_character) importFrom(rlang,is_character) importFrom(rlang,is_false) importFrom(rlang,is_integerish) +importFrom(rlang,is_interactive) importFrom(rlang,is_quosure) importFrom(rlang,list2) importFrom(rlang,new_data_mask) diff --git a/r/NEWS.md b/r/NEWS.md index 381b6334909..bbefb33bbc0 100644 --- a/r/NEWS.md +++ b/r/NEWS.md @@ -23,7 +23,7 @@ There are now two ways to query Arrow data: ## 1. Grouped aggregation in Arrow -`dplyr::summarize()`, both grouped and ungrouped, is now implemented for Arrow Datasets, Tables, and RecordBatches. Because data is scanned in chunks, you can aggregate over larger-than-memory datasets backed by many files. Supported aggregation functions include `n()`, `n_distinct()`, `sum()`, `mean()`, `var()`, `sd()`, `any()`, and `all()`. +`dplyr::summarize()`, both grouped and ungrouped, is now implemented for Arrow Datasets, Tables, and RecordBatches. Because data is scanned in chunks, you can aggregate over larger-than-memory datasets backed by many files. Supported aggregation functions include `n()`, `n_distinct()`, `min()`, `max()`, `sum()`, `mean()`, `var()`, `sd()`, `any()`, and `all()`. `median()` and `quantile()` with one probability are also supported and currently return approximate results using the t-digest algorithm. This enhancement does change the behavior of `summarize()` and `collect()` in some cases: see "Breaking changes" below for details. 
diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index e2d3ec18846..ec9c0500494 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -21,7 +21,7 @@ #' @importFrom assertthat assert_that is.string #' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos #' @importFrom rlang eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec -#' @importFrom rlang is_bare_character quo_get_expr quo_get_env quo_set_expr .data seq2 +#' @importFrom rlang is_bare_character quo_get_expr quo_get_env quo_set_expr .data seq2 is_interactive #' @importFrom rlang expr caller_env is_character quo_name is_quosure enexpr enexprs as_quosure #' @importFrom tidyselect vars_pull vars_rename vars_select eval_select #' @useDynLib arrow, .registration = TRUE diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R index 8a5488bf599..06914abe072 100644 --- a/r/R/dplyr-collect.R +++ b/r/R/dplyr-collect.R @@ -96,9 +96,10 @@ implicit_schema <- function(.data) { # and they get projected to this order after aggregation) # * Infer the output types from the aggregations group_fields <- new_fields[.data$group_by_vars] + hash <- length(.data$group_by_vars) > 0 agg_fields <- imap( new_fields[setdiff(names(new_fields), .data$group_by_vars)], - ~ output_type(.data$aggregations[[.y]][["fun"]], .x) + ~ output_type(.data$aggregations[[.y]][["fun"]], .x, hash) ) new_fields <- c(group_fields, agg_fields) } diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index 765438ac908..e72842e1d2b 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -889,6 +889,37 @@ agg_funcs$var <- function(x, na.rm = FALSE, ddof = 1) { options = list(skip_nulls = na.rm, min_count = 0L, ddof = ddof) ) } +agg_funcs$quantile <- function(x, probs, na.rm = FALSE) { + if (length(probs) != 1) { + arrow_not_supported("quantile() with length(probs) != 1") + } + # TODO: Bind to the Arrow function that returns an exact quantile and remove + # this 
warning (ARROW-14021) + warn( + "quantile() currently returns an approximate quantile in Arrow", + .frequency = ifelse(is_interactive(), "once", "always"), + .frequency_id = "arrow.quantile.approximate" + ) + list( + fun = "tdigest", + data = x, + options = list(skip_nulls = na.rm, q = probs) + ) +} +agg_funcs$median <- function(x, na.rm = FALSE) { + # TODO: Bind to the Arrow function that returns an exact median and remove + # this warning (ARROW-14021) + warn( + "median() currently returns an approximate median in Arrow", + .frequency = ifelse(is_interactive(), "once", "always"), + .frequency_id = "arrow.median.approximate" + ) + list( + fun = "approximate_median", + data = x, + options = list(skip_nulls = na.rm) + ) +} agg_funcs$n_distinct <- function(x, na.rm = FALSE) { list( fun = "count_distinct", @@ -926,15 +957,21 @@ agg_funcs$max <- function(..., na.rm = FALSE) { ) } -output_type <- function(fun, input_type) { +output_type <- function(fun, input_type, hash) { # These are quick and dirty heuristics. 
if (fun %in% c("any", "all")) { bool() } else if (fun %in% "sum") { # It may upcast to a bigger type but this is close enough input_type - } else if (fun %in% c("mean", "stddev", "variance")) { + } else if (fun %in% c("mean", "stddev", "variance", "approximate_median")) { float64() + } else if (fun %in% "tdigest") { + if (hash) { + fixed_size_list_of(float64(), 1L) + } else { + float64() + } } else { # Just so things don't error, assume the resulting type is the same input_type diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index beb18e82039..6158026603f 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ -65,7 +65,12 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) { for (i in seq_along(exprs)) { # Iterate over the indices and not the names because names may be repeated # (which overwrites the previous name) - summarize_eval(names(exprs)[i], exprs[[i]], ctx) + summarize_eval( + names(exprs)[i], + exprs[[i]], + ctx, + length(.data$group_by_vars) > 0 + ) } # Apply the results to the .data object. @@ -150,7 +155,7 @@ format_aggregation <- function(x) { # appropriate combination of (1) aggregations (possibly temporary) and # (2) post-aggregation transformations (mutate) # The function returns nothing: it assigns into the `ctx` environment -summarize_eval <- function(name, quosure, ctx, recurse = FALSE) { +summarize_eval <- function(name, quosure, ctx, hash, recurse = FALSE) { expr <- quo_get_expr(quosure) ctx$quo_env <- quo_get_env(quosure) @@ -161,6 +166,15 @@ summarize_eval <- function(name, quosure, ctx, recurse = FALSE) { return() } + # For the quantile() binding in the hash aggregation case, we need to mutate + # the list output from the Arrow hash_tdigest kernel to flatten it into a + # column of type float64. We do that by modifying the unevaluated expression + # to replace quantile(...) 
with arrow_list_element(quantile(...), 0L) + if (hash && "quantile" %in% funs_in_expr) { + expr <- wrap_hash_quantile(expr) + funs_in_expr <- all_funs(expr) + } + # Start inspecting the expr to see what aggregations it involves agg_funs <- names(agg_funcs) outer_agg <- funs_in_expr[1] %in% agg_funs @@ -251,3 +265,17 @@ extract_aggregations <- function(expr, ctx) { } expr } + +# This function recurses through expr and wraps each call to quantile() with a +# call to arrow_list_element() +wrap_hash_quantile <- function(expr) { + if (length(expr) == 1) { + return(expr) + } else { + if (is.call(expr) && expr[[1]] == quote(quantile)) { + return(str2lang(paste0("arrow_list_element(", deparse1(expr), ", 0L)"))) + } else { + return(as.call(lapply(expr, wrap_hash_quantile))) + } + } +} diff --git a/r/R/util.R b/r/R/util.R index 94d3a78782f..918dba07eae 100644 --- a/r/R/util.R +++ b/r/R/util.R @@ -22,6 +22,13 @@ if (!exists("deparse1")) { } } +# for compatibility with R versions earlier than 3.6.0 +if (!exists("str2lang")) { + str2lang <- function(s) { + parse(text = s, keep.source = FALSE)[[1]] + } +} + oxford_paste <- function(x, conjunction = "and", quote = TRUE) { if (quote && is.character(x)) { x <- paste0('"', x, '"') diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 8268fd10eee..952a7cc3a9a 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -172,10 +172,11 @@ std::shared_ptr make_compute_options( } if (func_name == "all" || func_name == "hash_all" || func_name == "any" || - func_name == "hash_any" || func_name == "mean" || func_name == "hash_mean" || - func_name == "min_max" || func_name == "hash_min_max" || func_name == "min" || - func_name == "hash_min" || func_name == "max" || func_name == "hash_max" || - func_name == "sum" || func_name == "hash_sum") { + func_name == "hash_any" || func_name == "approximate_median" || + func_name == "hash_approximate_median" || func_name == "mean" || + func_name == "hash_mean" || func_name == "min_max" || func_name == 
"hash_min_max" || + func_name == "min" || func_name == "hash_min" || func_name == "max" || + func_name == "hash_max" || func_name == "sum" || func_name == "hash_sum") { using Options = arrow::compute::ScalarAggregateOptions; auto out = std::make_shared(Options::Defaults()); if (!Rf_isNull(options["min_count"])) { @@ -187,6 +188,18 @@ std::shared_ptr make_compute_options( return out; } + if (func_name == "tdigest" || func_name == "hash_tdigest") { + using Options = arrow::compute::TDigestOptions; + auto out = std::make_shared(Options::Defaults()); + if (!Rf_isNull(options["q"])) { + out->q = cpp11::as_cpp>(options["q"]); + } + if (!Rf_isNull(options["skip_nulls"])) { + out->skip_nulls = cpp11::as_cpp(options["skip_nulls"]); + } + return out; + } + if (func_name == "count") { using Options = arrow::compute::CountOptions; auto out = std::make_shared(Options::Defaults()); diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index c74ed6aa938..8739c70bbda 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -229,6 +229,151 @@ test_that("Group by n_distinct() on dataset", { ) }) +test_that("median()", { + # When medians are integer-valued, stats::median() sometimes returns output of + # type integer, whereas the Arrow approximate_median kernels always return + # output of type float64. The calls to median(int, ...) in the tests below + # are enclosed in as.double() to work around this known difference. 
+ + # with groups + expect_dplyr_equal( + input %>% + group_by(some_grouping) %>% + summarize( + med_dbl = median(dbl), + med_int = as.double(median(int)), + med_dbl_narmf = median(dbl, FALSE), + med_int_narmf = as.double(median(int, na.rm = FALSE)), + med_dbl_narmt = median(dbl, na.rm = TRUE), + med_int_narmt = as.double(median(int, TRUE)) + ) %>% + arrange(some_grouping) %>% + collect(), + tbl, + warning = "median\\(\\) currently returns an approximate median in Arrow" + ) + # without groups, with na.rm = TRUE + expect_dplyr_equal( + input %>% + summarize( + med_dbl_narmt = median(dbl, na.rm = TRUE), + med_int_narmt = as.double(median(int, TRUE)) + ) %>% + collect(), + tbl, + warning = "median\\(\\) currently returns an approximate median in Arrow" + ) + # without groups, with na.rm = FALSE (the default) + expect_dplyr_equal( + input %>% + summarize( + med_dbl = median(dbl), + med_int = as.double(median(int)), + med_dbl_narmf = median(dbl, FALSE), + med_int_narmf = as.double(median(int, na.rm = FALSE)) + ) %>% + collect(), + tbl, + warning = "median\\(\\) currently returns an approximate median in Arrow" + ) +}) + +test_that("quantile()", { + # The default method for stats::quantile() throws an error when na.rm = FALSE + # and the input contains NA or NaN, whereas the Arrow tdigest kernels return + # null in this situation. To work around this known difference, the tests + # below always use na.rm = TRUE when the data contains NA or NaN. + + # The default method for stats::quantile() has an argument `names` that + # controls whether the result has a names attribute. It defaults to + # names = TRUE. With Arrow, it is not possible to give the result a names + # attribute, so the quantile() binding in Arrow does not accept a `names` + # argument. Differences in this names attribute cause expect_dplyr_equal() to + # report that the objects are not equal, so we do not use expect_dplyr_equal() + # in the tests below. 
+ + # The tests below all use probs = 0.5 because other values cause differences + # between the exact quantiles returned by R and the approximate quantiles + # returned by Arrow. + + # When quantiles are integer-valued, stats::quantile() sometimes returns + # output of type integer, whereas the Arrow tdigest kernels always + # return output of type float64. The calls to quantile(int, ...) in the tests + # below are enclosed in as.double() to work around this known difference. + + # with groups + expect_warning( + expect_equal( + tbl %>% + group_by(some_grouping) %>% + summarize( + q_dbl = quantile(dbl, probs = 0.5, na.rm = TRUE, names = FALSE), + q_int = as.double( + quantile(int, probs = 0.5, na.rm = TRUE, names = FALSE) + ) + ) %>% + arrange(some_grouping), + Table$create(tbl) %>% + group_by(some_grouping) %>% + summarize( + q_dbl = quantile(dbl, probs = 0.5, na.rm = TRUE), + q_int = as.double(quantile(int, probs = 0.5, na.rm = TRUE)) + ) %>% + arrange(some_grouping) %>% + collect() + ), + "quantile() currently returns an approximate quantile in Arrow", + fixed = TRUE + ) + + # without groups + expect_warning( + expect_equal( + tbl %>% + summarize( + q_dbl = quantile(dbl, probs = 0.5, na.rm = TRUE, names = FALSE), + q_int = as.double( + quantile(int, probs = 0.5, na.rm = TRUE, names = FALSE) + ) + ), + Table$create(tbl) %>% + summarize( + q_dbl = quantile(dbl, probs = 0.5, na.rm = TRUE), + q_int = as.double(quantile(int, probs = 0.5, na.rm = TRUE)) + ) %>% + collect() + ), + "quantile() currently returns an approximate quantile in Arrow", + fixed = TRUE + ) + + # with missing values and na.rm = FALSE + expect_warning( + expect_equal( + tibble( + q_dbl = NA_real_, + q_int = NA_real_ + ), + Table$create(tbl) %>% + summarize( + q_dbl = quantile(dbl, probs = 0.5, na.rm = FALSE), + q_int = as.double(quantile(int, probs = 0.5, na.rm = FALSE)) + ) %>% + collect() + ), + "quantile() currently returns an approximate quantile in Arrow", + fixed = TRUE + ) + + # 
with a vector of 2+ probs + expect_warning( + Table$create(tbl) %>% + summarize(q = quantile(dbl, probs = c(0.2, 0.8), na.rm = FALSE)), + "quantile() with length(probs) != 1 not supported by Arrow", + fixed = TRUE + ) +}) + test_that("summarize() with min() and max()", { expect_dplyr_equal( input %>%