From 9e07243eb814ccee8bff15d860c8a89c6ade2dc7 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 17 Feb 2021 17:41:25 -0800 Subject: [PATCH 01/15] Rework selected_columns to hold field refs. Move to a similar model for array_expressions --- r/R/arrowExports.R | 4 ++ r/R/dataset-scan.R | 3 +- r/R/dataset-write.R | 2 +- r/R/dplyr.R | 78 ++++++++++++++++++++---------- r/R/expression.R | 33 ++++++++++++- r/src/arrowExports.cpp | 9 ++++ r/src/expression.cpp | 7 +++ r/tests/testthat/test-dplyr.R | 14 ++++-- r/tests/testthat/test-expression.R | 12 +++++ 9 files changed, 130 insertions(+), 32 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 3d0f31ce8f3..790232c8e21 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -744,6 +744,10 @@ dataset___expr__field_ref <- function(name){ .Call(`_arrow_dataset___expr__field_ref`, name) } +dataset___expr__get_field_ref_name <- function(ref){ + .Call(`_arrow_dataset___expr__get_field_ref_name`, ref) +} + dataset___expr__scalar <- function(x){ .Call(`_arrow_dataset___expr__scalar`, x) } diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R index 45fc968ed08..16039d8fca6 100644 --- a/r/R/dataset-scan.R +++ b/r/R/dataset-scan.R @@ -71,7 +71,8 @@ Scanner$create <- function(dataset, if (inherits(dataset, "arrow_dplyr_query")) { return(Scanner$create( dataset$.data, - dataset$selected_columns, + # Note: selected_columns is no longer a character vector + map_chr(dataset$selected_columns, ~.$field_name), dataset$filtered_rows, use_threads, ... diff --git a/r/R/dataset-write.R b/r/R/dataset-write.R index c5c92926715..a6871aace9f 100644 --- a/r/R/dataset-write.R +++ b/r/R/dataset-write.R @@ -63,7 +63,7 @@ write_dataset <- function(dataset, ...) 
{ if (inherits(dataset, "arrow_dplyr_query")) { # We can select a subset of columns but we can't rename them - if (!all(dataset$selected_columns == names(dataset$selected_columns))) { + if (!all(map_chr(dataset$selected_columns, ~.$field_name) == names(dataset$selected_columns))) { stop("Renaming columns when writing a dataset is not yet supported", call. = FALSE) } # partitioning vars need to be in the `select` schema diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 32713741b53..60ae5c0b69d 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -30,14 +30,19 @@ arrow_dplyr_query <- function(.data) { if (inherits(.data, "arrow_dplyr_query")) { return(.data) } + # selected_columns is a named list: + # * contents are references/expressions pointing to the data + # * names are the names they should be in the end (i.e. this + # records any renaming) + if (inherits(.data, "Dataset")) { + selected_columns <- lapply(names(.data), Expression$field_ref) + } else { + selected_columns <- lapply(names(.data), function(x) array_expression("array_ref", field_name = x)) + } structure( list( .data = .data$clone(), - # selected_columns is a named character vector: - # * vector contents are the names of the columns in the data - # * vector names are the names they should be in the end (i.e. this - # records any renaming) - selected_columns = set_names(names(.data)), + selected_columns = set_names(selected_columns, names(.data)), # filtered_rows will be an Expression filtered_rows = TRUE, # group_by_vars is a character vector of columns (as renamed) @@ -52,6 +57,8 @@ arrow_dplyr_query <- function(.data) { print.arrow_dplyr_query <- function(x, ...) 
{ schm <- x$.data$schema cols <- x$selected_columns + # TODO: selected_columns is no longer a character vector + # TODO: if cols are expressions, we won't know what their type will be at this time fields <- map_chr(cols, ~schm$GetFieldByName(.)$ToString()) # Strip off the field names as they are in the dataset and add the renamed ones fields <- paste(names(cols), sub("^.*?: ", "", fields), sep = ": ", collapse = "\n") @@ -89,7 +96,7 @@ dim.arrow_dplyr_query <- function(x) { rows <- NA_integer_ } else { # Evaluate the filter expression to a BooleanArray and count - rows <- as.integer(sum(eval_array_expression(x$filtered_rows), na.rm = TRUE)) + rows <- as.integer(sum(eval_array_expression(x$filtered_rows, x$.data), na.rm = TRUE)) } c(rows, cols) } @@ -275,18 +282,14 @@ array_function_list <- build_function_list(build_array_expression) filter_mask <- function(.data) { if (query_on_dataset(.data)) { f_env <- new_environment(dataset_function_list) - var_binder <- function(x) Expression$field_ref(x) } else { f_env <- new_environment(array_function_list) - var_binder <- function(x) .data$.data[[x]] } # Add the column references - # Renaming is handled automatically by the named list - data_pronoun <- lapply(.data$selected_columns, var_binder) - env_bind(f_env, !!!data_pronoun) + env_bind(f_env, !!!.data$selected_columns) # Then bind the data pronoun - env_bind(f_env, .data = data_pronoun) + env_bind(f_env, .data = .data$selected_columns) new_data_mask(f_env) } @@ -309,8 +312,26 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { # See dataset.R for Dataset and Scanner(Builder) classes tab <- Scanner$create(x)$ToTable() } else { - # This is a Table/RecordBatch. 
See record-batch.R for the [ method - tab <- x$.data[x$filtered_rows, x$selected_columns, keep_na = FALSE] + # This is a Table or RecordBatch + + # Filter and select the data referenced in selected columns + if (isTRUE(x$filtered_rows)) { + filter <- TRUE + } else { + filter <- eval_array_expression(x$filtered_rows, x$.data) + } + tab <- x$.data[filter, find_array_refs(x$selected_columns), keep_na = FALSE] + # Now evaluate those expressions on the filtered table + cols <- lapply(x$selected_columns, eval_array_expression, data = tab) + if (length(cols) == 0) { + tab <- tab[, integer(0)] + } else { + if (inherits(x$.data, "Table")) { + tab <- Table$create(!!!cols) + } else { + tab <- RecordBatch$create(!!!cols) + } + } } if (as_data_frame) { df <- as.data.frame(tab) @@ -327,6 +348,14 @@ ensure_group_vars <- function(x) { if (inherits(x, "arrow_dplyr_query")) { # Before pulling data from Arrow, make sure all group vars are in the projection gv <- set_names(setdiff(dplyr::group_vars(x), names(x))) + if (length(gv)) { + # TODO: selected_columns is no longer a character vector, so assemble refs (correctly!) 
+ if (query_on_dataset(x)) { + gv <- lapply(gv, Expression$field_ref) + } else { + gv <- lapply(gv, function(var) x$.data[[var]]) + } + } x$selected_columns <- c(x$selected_columns, gv) } x @@ -337,21 +366,20 @@ restore_dplyr_features <- function(df, query) { # After calling collect(), make sure these features are carried over grouped <- length(query$group_by_vars) > 0 - renamed <- !identical(names(df), names(query)) - if (is.data.frame(df)) { + renamed <- ncol(df) && !identical(names(df), names(query)) + if (renamed) { # In case variables were renamed, apply those names - if (renamed && ncol(df)) { - names(df) <- names(query) - } + names(df) <- names(query) + } + if (grouped) { # Preserve groupings, if present - if (grouped) { + if (is.data.frame(df)) { df <- dplyr::grouped_df(df, dplyr::group_vars(query)) + } else { + # This is a Table, via collect(as_data_frame = FALSE) + df <- arrow_dplyr_query(df) + df$group_by_vars <- query$group_by_vars } - } else if (grouped || renamed) { - # This is a Table, via collect(as_data_frame = FALSE) - df <- arrow_dplyr_query(df) - names(df$selected_columns) <- names(query) - df$group_by_vars <- query$group_by_vars } df } diff --git a/r/R/expression.R b/r/R/expression.R index 878b800c652..a121c9194b2 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -143,7 +143,14 @@ cast_array_expression <- function(x, to_type, safe = TRUE, ...) 
{ .array_function_map <- c(.unary_function_map, .binary_function_map) -eval_array_expression <- function(x) { +eval_array_expression <- function(x, data = NULL) { + if (!is.null(data)) { + x <- bind_array_refs(x, data) + } + if (inherits(x, "ArrowDatum")) { + # Nothing to evaluate + return(x) + } x$args <- lapply(x$args, function (a) { if (inherits(a, "array_expression")) { eval_array_expression(a) @@ -154,6 +161,27 @@ eval_array_expression <- function(x) { call_function(x$fun, args = x$args, options = x$options %||% empty_named_list()) } +find_array_refs <- function(x) { + if (identical(x$fun, "array_ref")) { + out <- x$args$field_name + } else { + out <- lapply(x$args, find_array_refs) + } + unlist(out) +} + +# Take an array_expression and replace array_refs with arrays/chunkedarrays from data +bind_array_refs <- function(x, data) { + if (inherits(x, "array_expression")) { + if (identical(x$fun, "array_ref")) { + x <- data[[x$args$field_name]] + } else { + x$args <- lapply(x$args, bind_array_refs, data) + } + } + x +} + #' @export is.na.array_expression <- function(x) array_expression("is.na", x) @@ -217,6 +245,9 @@ Expression <- R6Class("Expression", inherit = ArrowObject, ) Expression$create("cast", self, options = modifyList(opts, list(...))) } + ), + active = list( + field_name = function() dataset___expr__get_field_ref_name(self) ) ) Expression$create <- function(function_name, diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 839c9d6c173..73ee64844a6 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -1569,6 +1569,14 @@ BEGIN_CPP11 END_CPP11 } // expression.cpp +std::string dataset___expr__get_field_ref_name(const std::shared_ptr& ref); +extern "C" SEXP _arrow_dataset___expr__get_field_ref_name(SEXP ref_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type ref(ref_sexp); + return cpp11::as_sexp(dataset___expr__get_field_ref_name(ref)); +END_CPP11 +} +// expression.cpp std::shared_ptr dataset___expr__scalar(const std::shared_ptr& 
x); extern "C" SEXP _arrow_dataset___expr__scalar(SEXP x_sexp){ BEGIN_CPP11 @@ -3702,6 +3710,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_FixedSizeListType__list_size", (DL_FUNC) &_arrow_FixedSizeListType__list_size, 1}, { "_arrow_dataset___expr__call", (DL_FUNC) &_arrow_dataset___expr__call, 3}, { "_arrow_dataset___expr__field_ref", (DL_FUNC) &_arrow_dataset___expr__field_ref, 1}, + { "_arrow_dataset___expr__get_field_ref_name", (DL_FUNC) &_arrow_dataset___expr__get_field_ref_name, 1}, { "_arrow_dataset___expr__scalar", (DL_FUNC) &_arrow_dataset___expr__scalar, 1}, { "_arrow_dataset___expr__ToString", (DL_FUNC) &_arrow_dataset___expr__ToString, 1}, { "_arrow_ipc___WriteFeather__Table", (DL_FUNC) &_arrow_ipc___WriteFeather__Table, 6}, diff --git a/r/src/expression.cpp b/r/src/expression.cpp index ddb1e72c309..76d8222967b 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -47,6 +47,13 @@ std::shared_ptr dataset___expr__field_ref(std::string name) { return std::make_shared(ds::field_ref(std::move(name))); } +// [[arrow::export]] +std::string dataset___expr__get_field_ref_name( + const std::shared_ptr& ref) { + auto refname = ref->field_ref()->name(); + return *refname; +} + // [[arrow::export]] std::shared_ptr dataset___expr__scalar( const std::shared_ptr& x) { diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 6d9945a115a..53535d9179d 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -318,9 +318,10 @@ test_that("filter environment scope", { tbl ) # Also for functions - # 'could not find function "isEqualTo"' + # 'could not find function "isEqualTo"' because we haven't defined it yet expect_dplyr_error(filter(batch, isEqualTo(int, 4))) + skip("Need to substitute in user defined function too") # TODO: fix this: this isEqualTo function is eagerly evaluating; it should # instead yield array_expressions. 
Probably bc the parent env of the function # has the Ops.Array methods defined; we need to move it so that the parent @@ -599,7 +600,7 @@ test_that("collect(as_data_frame=FALSE)", { select(int, strng = chr) %>% filter(int > 5) %>% collect(as_data_frame = FALSE) - expect_is(b3, "arrow_dplyr_query") + expect_is(b3, "RecordBatch") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% @@ -632,7 +633,7 @@ test_that("head", { select(int, strng = chr) %>% filter(int > 5) %>% head(2) - expect_is(b3, "arrow_dplyr_query") + expect_is(b3, "RecordBatch") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% @@ -641,6 +642,11 @@ test_that("head", { group_by(int) %>% head(2) expect_is(b4, "arrow_dplyr_query") + # print(b4) + print(as.data.frame(b4)) + print( expected %>% + rename(strng = chr) %>% + group_by(int)) expect_equal( as.data.frame(b4), expected %>% @@ -665,7 +671,7 @@ test_that("tail", { select(int, strng = chr) %>% filter(int > 5) %>% tail(2) - expect_is(b3, "arrow_dplyr_query") + expect_is(b3, "RecordBatch") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index 3c100812ff1..3df7270f4c5 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -34,8 +34,20 @@ test_that("array_expression print method", { ) }) +test_that("array_refs", { + tab <- Table$create(a = 1:5) + ex <- build_array_expression(">", array_expression("array_ref", field_name = "a"), 4) + expect_is(ex, "array_expression") + expect_identical(ex$args[[1]]$args$field_name, "a") + expect_identical(find_array_refs(ex), "a") + out <- eval_array_expression(ex, tab) + expect_is(out, "ChunkedArray") + expect_equal(as.vector(out), c(FALSE, FALSE, FALSE, FALSE, TRUE)) +}) + test_that("C++ expressions", { f <- Expression$field_ref("f") + expect_identical(f$field_name, "f") g <- 
Expression$field_ref("g") date <- Expression$scalar(as.Date("2020-01-15")) ts <- Expression$scalar(as.POSIXct("2020-01-17 11:11:11")) From 2866a3cbd12ec675d28e401d09ef9dd9443d4e3c Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 18 Feb 2021 12:37:28 -0800 Subject: [PATCH 02/15] Patch up tests --- r/R/dataset-scan.R | 16 ++++++++++++++-- r/R/dataset-write.R | 2 +- r/R/dplyr.R | 13 ++++++++++--- r/R/expression.R | 10 +++++++--- r/tests/testthat/test-dplyr.R | 7 +------ 5 files changed, 33 insertions(+), 15 deletions(-) diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R index 16039d8fca6..19d4afd3227 100644 --- a/r/R/dataset-scan.R +++ b/r/R/dataset-scan.R @@ -69,10 +69,16 @@ Scanner$create <- function(dataset, batch_size = NULL, ...) { if (inherits(dataset, "arrow_dplyr_query")) { + if (inherits(dataset$.data, "ArrowTabular")) { + # To handle mutate() on Table/RecordBatch, we need to collect(as_data_frame=FALSE) now + dataset <- dplyr::collect(dataset, as_data_frame = FALSE) + # Slight hack: replace selected_columns with named character vector, + # which ScannerBuilder$Project can handle. + # We can't keep array_refs here because they don't translate to + } return(Scanner$create( dataset$.data, - # Note: selected_columns is no longer a character vector - map_chr(dataset$selected_columns, ~.$field_name), + dataset$selected_columns, dataset$filtered_rows, use_threads, ... 
@@ -153,6 +159,12 @@ map_batches <- function(X, FUN, ..., .data.frame = TRUE) { ScannerBuilder <- R6Class("ScannerBuilder", inherit = ArrowObject, public = list( Project = function(cols) { + # cols is either a character vector or a named list of Expressions + if (!is.character(cols)) { + # We don't yet support mutate() on datasets, so this is just a list + # of FieldRefs, and we need to back out the field names + cols <- get_field_names(cols) + } assert_is(cols, "character") dataset___ScannerBuilder__Project(self, cols) self diff --git a/r/R/dataset-write.R b/r/R/dataset-write.R index a6871aace9f..61c20b31f2d 100644 --- a/r/R/dataset-write.R +++ b/r/R/dataset-write.R @@ -63,7 +63,7 @@ write_dataset <- function(dataset, ...) { if (inherits(dataset, "arrow_dplyr_query")) { # We can select a subset of columns but we can't rename them - if (!all(map_chr(dataset$selected_columns, ~.$field_name) == names(dataset$selected_columns))) { + if (!all(get_field_names(dataset) == names(dataset$selected_columns))) { stop("Renaming columns when writing a dataset is not yet supported", call. = FALSE) } # partitioning vars need to be in the `select` schema diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 60ae5c0b69d..bda81a78806 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -56,7 +56,7 @@ arrow_dplyr_query <- function(.data) { #' @export print.arrow_dplyr_query <- function(x, ...) { schm <- x$.data$schema - cols <- x$selected_columns + cols <- get_field_names(x) # TODO: selected_columns is no longer a character vector # TODO: if cols are expressions, we won't know what their type will be at this time fields <- map_chr(cols, ~schm$GetFieldByName(.)$ToString()) @@ -80,6 +80,13 @@ print.arrow_dplyr_query <- function(x, ...) 
{ invisible(x) } +get_field_names <- function(selected_cols) { + if (inherits(selected_cols, "arrow_dplyr_query")) { + selected_cols <- selected_cols$selected_columns + } + map_chr(selected_cols, ~.$field_name %||% .$args$field_name %||% "") +} + # These are the names reflecting all select/rename, not what is in Arrow #' @export names.arrow_dplyr_query <- function(x) names(x$selected_columns) @@ -351,9 +358,9 @@ ensure_group_vars <- function(x) { if (length(gv)) { # TODO: selected_columns is no longer a character vector, so assemble refs (correctly!) if (query_on_dataset(x)) { - gv <- lapply(gv, Expression$field_ref) + gv <- set_names(lapply(gv, Expression$field_ref), gv) } else { - gv <- lapply(gv, function(var) x$.data[[var]]) + gv <- set_names(lapply(gv, function(x) array_expression("array_ref", field_name = x)), gv) } } x$selected_columns <- c(x$selected_columns, gv) diff --git a/r/R/expression.R b/r/R/expression.R index a121c9194b2..a926007c1b1 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -209,9 +209,13 @@ print.array_expression <- function(x, ...) 
{ deparse(arg) } }) - # Prune this for readability - function_name <- sub("_kleene", "", x$fun) - paste0(function_name, "(", paste(printed_args, collapse = ", "), ")") + if (identical(x$fun, "array_ref")) { + x$args$field_name + } else { + # Prune this for readability + function_name <- sub("_kleene", "", x$fun) + paste0(function_name, "(", paste(printed_args, collapse = ", "), ")") + } } ########### diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 53535d9179d..7273178b443 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -247,7 +247,7 @@ test_that("Print method", { int: int32 chr: string -* Filter: and(and(greater(, 2), or(equal(, "d"), equal(, "f"))), less(, 5)) +* Filter: and(and(greater(dbl, 2), or(equal(chr, "d"), equal(chr, "f"))), less(int, 5)) See $.data for the source Arrow object', fixed = TRUE ) @@ -642,11 +642,6 @@ test_that("head", { group_by(int) %>% head(2) expect_is(b4, "arrow_dplyr_query") - # print(b4) - print(as.data.frame(b4)) - print( expected %>% - rename(strng = chr) %>% - group_by(int)) expect_equal( as.data.frame(b4), expected %>% From 1e968788db216400a2effff767c78cd8d175f7cc Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 18 Feb 2021 13:11:54 -0800 Subject: [PATCH 03/15] Clean up comments --- r/R/dataset-scan.R | 3 --- r/R/dplyr.R | 5 ++--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R index 19d4afd3227..ec6f85c4bab 100644 --- a/r/R/dataset-scan.R +++ b/r/R/dataset-scan.R @@ -72,9 +72,6 @@ Scanner$create <- function(dataset, if (inherits(dataset$.data, "ArrowTabular")) { # To handle mutate() on Table/RecordBatch, we need to collect(as_data_frame=FALSE) now dataset <- dplyr::collect(dataset, as_data_frame = FALSE) - # Slight hack: replace selected_columns with named character vector, - # which ScannerBuilder$Project can handle. 
- # We can't keep array_refs here because they don't translate to } return(Scanner$create( dataset$.data, diff --git a/r/R/dplyr.R b/r/R/dplyr.R index bda81a78806..a4ddf56f106 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -57,8 +57,7 @@ arrow_dplyr_query <- function(.data) { print.arrow_dplyr_query <- function(x, ...) { schm <- x$.data$schema cols <- get_field_names(x) - # TODO: selected_columns is no longer a character vector - # TODO: if cols are expressions, we won't know what their type will be at this time + # TODO: if cols are expressions, they won't be in the schema fields <- map_chr(cols, ~schm$GetFieldByName(.)$ToString()) # Strip off the field names as they are in the dataset and add the renamed ones fields <- paste(names(cols), sub("^.*?: ", "", fields), sep = ": ", collapse = "\n") @@ -356,7 +355,7 @@ ensure_group_vars <- function(x) { # Before pulling data from Arrow, make sure all group vars are in the projection gv <- set_names(setdiff(dplyr::group_vars(x), names(x))) if (length(gv)) { - # TODO: selected_columns is no longer a character vector, so assemble refs (correctly!) 
+ # selected_columns is no longer a character vector, so assemble refs if (query_on_dataset(x)) { gv <- set_names(lapply(gv, Expression$field_ref), gv) } else { From 3129eb67c4f47616e3869c129fc8b4cd751a5900 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 18 Feb 2021 13:41:18 -0800 Subject: [PATCH 04/15] Split up test-dplyr.R --- r/tests/testthat/helper-expectation.R | 63 +++++ r/tests/testthat/test-dplyr-filter.R | 287 +++++++++++++++++++++ r/tests/testthat/test-dplyr-mutate.R | 49 ++++ r/tests/testthat/test-dplyr.R | 349 -------------------------- 4 files changed, 399 insertions(+), 349 deletions(-) create mode 100644 r/tests/testthat/test-dplyr-filter.R create mode 100644 r/tests/testthat/test-dplyr-mutate.R diff --git a/r/tests/testthat/helper-expectation.R b/r/tests/testthat/helper-expectation.R index ce0f9de8a54..76edea61f57 100644 --- a/r/tests/testthat/helper-expectation.R +++ b/r/tests/testthat/helper-expectation.R @@ -59,3 +59,66 @@ verify_output <- function(...) { } testthat::verify_output(...) } + +expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its start + tbl, # A tbl/df as reference, will make RB/Table with + skip_record_batch = NULL, # Msg, if should skip RB test + skip_table = NULL, # Msg, if should skip Table test + ...) { + expr <- rlang::enquo(expr) + expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))) + + skip_msg <- NULL + + if (is.null(skip_record_batch)) { + via_batch <- rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = record_batch(tbl))) + ) + expect_equivalent(via_batch, expected, ...) + } else { + skip_msg <- c(skip_msg, skip_record_batch) + } + + if (is.null(skip_table)) { + via_table <- rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = Table$create(tbl))) + ) + expect_equivalent(via_table, expected, ...) 
+ } else { + skip_msg <- c(skip_msg, skip_table) + } + + if (!is.null(skip_msg)) { + skip(paste(skip_msg, collpase = "\n")) + } +} + +expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its start + tbl, # A tbl/df as reference, will make RB/Table with + ...) { + expr <- rlang::enquo(expr) + msg <- tryCatch( + rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))), + error = function (e) conditionMessage(e) + ) + expect_is(msg, "character", label = "dplyr on data.frame did not error") + + expect_error( + rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = record_batch(tbl))) + ), + msg, + ... + ) + expect_error( + rlang::eval_tidy( + expr, + rlang::new_data_mask(rlang::env(input = Table$create(tbl))) + ), + msg, + ... + ) +} \ No newline at end of file diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R new file mode 100644 index 00000000000..6e37d41cc2f --- /dev/null +++ b/r/tests/testthat/test-dplyr-filter.R @@ -0,0 +1,287 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +library(dplyr) +library(stringr) + +tbl <- example_data +# Add some better string data +tbl$verses <- verses[[1]] +# c(" a ", " b ", " c ", ...) 
increasing padding +# nchar = 3 5 7 9 11 13 15 17 19 21 +tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2*(1:10)+1, side = "both") + +test_that("filter() on is.na()", { + expect_dplyr_equal( + input %>% + filter(is.na(lgl)) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) +}) + +test_that("filter() with NAs in selection", { + expect_dplyr_equal( + input %>% + filter(lgl) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) +}) + +test_that("Filter returning an empty Table should not segfault (ARROW-8354)", { + expect_dplyr_equal( + input %>% + filter(false) %>% + select(chr, int, lgl) %>% + collect(), + tbl + ) +}) + +test_that("filtering with expression", { + char_sym <- "b" + expect_dplyr_equal( + input %>% + filter(chr == char_sym) %>% + select(string = chr, int) %>% + collect(), + tbl + ) +}) + +test_that("filtering with arithmetic", { + expect_dplyr_equal( + input %>% + filter(dbl + 1 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl / 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl / 2L > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int / 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int / 2L > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl %/% 2 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) +}) + +test_that("filtering with expression + autocasting", { + expect_dplyr_equal( + input %>% + filter(dbl + 1 > 3L) %>% # test autocasting with comparison to 3L + select(string = chr, int, dbl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(int + 1 > 3) %>% + select(string = chr, int, dbl) %>% + collect(), + tbl + ) +}) + +test_that("More complex 
select/filter", { + expect_dplyr_equal( + input %>% + filter(dbl > 2, chr == "d" | chr == "f") %>% + select(chr, int, lgl) %>% + filter(int < 5) %>% + select(int, chr) %>% + collect(), + tbl + ) +}) + +test_that("filter() with %in%", { + expect_dplyr_equal( + input %>% + filter(dbl > 2, chr %in% c("d", "f")) %>% + collect(), + tbl + ) +}) + +test_that("filter() with string ops", { + # Extra instrumentation to ensure that we're calling Arrow compute here + # because many base R string functions implicitly call as.character, + # which means they still work on Arrays but actually force data into R + # 1) wrapper that raises a warning if as.character is called. Can't wrap + # the whole test because as.character apparently gets called in other + # (presumably legitimate) places + # 2) Wrap the test in expect_warning(expr, NA) to catch the warning + + with_no_as_character <- function(expr) { + trace( + "as.character", + tracer = quote(warning("as.character was called")), + print = FALSE, + where = toupper + ) + on.exit(untrace("as.character", where = toupper)) + force(expr) + } + + expect_warning( + expect_dplyr_equal( + input %>% + filter(dbl > 2, with_no_as_character(toupper(chr)) %in% c("D", "F")) %>% + collect(), + tbl + ), + NA) + + expect_dplyr_equal( + input %>% + filter(dbl > 2, str_length(verses) > 25) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(dbl > 2, str_length(str_trim(padded_strings, "left")) > 5) %>% + collect(), + tbl + ) +}) + +test_that("filter environment scope", { + # "object 'b_var' not found" + expect_dplyr_error(input %>% filter(batch, chr == b_var)) + + b_var <- "b" + expect_dplyr_equal( + input %>% + filter(chr == b_var) %>% + collect(), + tbl + ) + # Also for functions + # 'could not find function "isEqualTo"' because we haven't defined it yet + expect_dplyr_error(filter(batch, isEqualTo(int, 4))) + + skip("Need to substitute in user defined function too") + # TODO: fix this: this isEqualTo function is eagerly 
evaluating; it should + # instead yield array_expressions. Probably bc the parent env of the function + # has the Ops.Array methods defined; we need to move it so that the parent + # env is the data mask we use in the dplyr eval + isEqualTo <- function(x, y) x == y & !is.na(x) + expect_dplyr_equal( + input %>% + select(-fct) %>% # factor levels aren't identical + filter(isEqualTo(int, 4)) %>% + collect(), + tbl + ) +}) + +test_that("Filtering on a column that doesn't exist errors correctly", { + skip("Error handling in filter() needs to be internationalized") + expect_error( + batch %>% filter(not_a_col == 42) %>% collect(), + "object 'not_a_col' not found" + ) +}) + +test_that("Filtering with a function that doesn't have an Array/expr method still works", { + expect_warning( + expect_dplyr_equal( + input %>% + filter(int > 2, pnorm(dbl) > .99) %>% + collect(), + tbl + ), + 'Filter expression not implemented in Arrow: pnorm(dbl) > 0.99; pulling data into R', + fixed = TRUE + ) +}) + +test_that("filter() with .data pronoun", { + expect_dplyr_equal( + input %>% + filter(.data$dbl > 4) %>% + select(.data$chr, .data$int, .data$lgl) %>% + collect(), + tbl + ) + + expect_dplyr_equal( + input %>% + filter(is.na(.data$lgl)) %>% + select(.data$chr, .data$int, .data$lgl) %>% + collect(), + tbl + ) + + # and the .env pronoun too! 
+ chr <- 4 + expect_dplyr_equal( + input %>% + filter(.data$dbl > .env$chr) %>% + select(.data$chr, .data$int, .data$lgl) %>% + collect(), + tbl + ) + + # but there is an error if we don't override the masking with `.env` + expect_dplyr_error( + tbl %>% + filter(.data$dbl > chr) %>% + select(.data$chr, .data$int, .data$lgl) %>% + collect() + ) +}) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R new file mode 100644 index 00000000000..4b8b8e57f68 --- /dev/null +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +library(dplyr) +library(stringr) + +tbl <- example_data +# Add some better string data +tbl$verses <- verses[[1]] +# c(" a ", " b ", " c ", ...) 
increasing padding +# nchar = 3 5 7 9 11 13 15 17 19 21 +tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2*(1:10)+1, side = "both") + +test_that("mutate", { + expect_dplyr_equal( + input %>% + select(int, chr) %>% + filter(int > 5) %>% + mutate(int = int + 6L) %>% + summarize(min_int = min(int)), + tbl + ) +}) + +test_that("transmute", { + skip("TODO: reimplement transmute (with dplyr 1.0, it no longer just works via mutate)") + expect_dplyr_equal( + input %>% + select(int, chr) %>% + filter(int > 5) %>% + transmute(int = int + 6L) %>% + summarize(min_int = min(int)), + tbl + ) +}) diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index 7273178b443..13610f1c6f1 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -15,74 +15,9 @@ # specific language governing permissions and limitations # under the License. -context("dplyr verbs") - library(dplyr) library(stringr) -expect_dplyr_equal <- function(expr, # A dplyr pipeline with `input` as its start - tbl, # A tbl/df as reference, will make RB/Table with - skip_record_batch = NULL, # Msg, if should skip RB test - skip_table = NULL, # Msg, if should skip Table test - ...) { - expr <- rlang::enquo(expr) - expected <- rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))) - - skip_msg <- NULL - - if (is.null(skip_record_batch)) { - via_batch <- rlang::eval_tidy( - expr, - rlang::new_data_mask(rlang::env(input = record_batch(tbl))) - ) - expect_equivalent(via_batch, expected, ...) - } else { - skip_msg <- c(skip_msg, skip_record_batch) - } - - if (is.null(skip_table)) { - via_table <- rlang::eval_tidy( - expr, - rlang::new_data_mask(rlang::env(input = Table$create(tbl))) - ) - expect_equivalent(via_table, expected, ...) 
- } else { - skip_msg <- c(skip_msg, skip_table) - } - - if (!is.null(skip_msg)) { - skip(paste(skip_msg, collpase = "\n")) - } -} - -expect_dplyr_error <- function(expr, # A dplyr pipeline with `input` as its start - tbl, # A tbl/df as reference, will make RB/Table with - ...) { - expr <- rlang::enquo(expr) - msg <- tryCatch( - rlang::eval_tidy(expr, rlang::new_data_mask(rlang::env(input = tbl))), - error = function (e) conditionMessage(e) - ) - expect_is(msg, "character", label = "dplyr on data.frame did not error") - - expect_error( - rlang::eval_tidy( - expr, - rlang::new_data_mask(rlang::env(input = record_batch(tbl))) - ), - msg, - ... - ) - expect_error( - rlang::eval_tidy( - expr, - rlang::new_data_mask(rlang::env(input = Table$create(tbl))) - ), - msg, - ... - ) -} - tbl <- example_data # Add some better string data tbl$verses <- verses[[1]] @@ -104,127 +39,6 @@ test_that("basic select/filter/collect", { expect_identical(collect(batch), tbl) }) -test_that("filter() on is.na()", { - expect_dplyr_equal( - input %>% - filter(is.na(lgl)) %>% - select(chr, int, lgl) %>% - collect(), - tbl - ) -}) - -test_that("filter() with NAs in selection", { - expect_dplyr_equal( - input %>% - filter(lgl) %>% - select(chr, int, lgl) %>% - collect(), - tbl - ) -}) - -test_that("Filter returning an empty Table should not segfault (ARROW-8354)", { - expect_dplyr_equal( - input %>% - filter(false) %>% - select(chr, int, lgl) %>% - collect(), - tbl - ) -}) - -test_that("filtering with expression", { - char_sym <- "b" - expect_dplyr_equal( - input %>% - filter(chr == char_sym) %>% - select(string = chr, int) %>% - collect(), - tbl - ) -}) - -test_that("filtering with arithmetic", { - expect_dplyr_equal( - input %>% - filter(dbl + 1 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(dbl / 2 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(dbl / 2L > 3) %>% 
- select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(int / 2 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(int / 2L > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(dbl %/% 2 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) -}) - -test_that("filtering with expression + autocasting", { - expect_dplyr_equal( - input %>% - filter(dbl + 1 > 3L) %>% # test autocasting with comparison to 3L - select(string = chr, int, dbl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(int + 1 > 3) %>% - select(string = chr, int, dbl) %>% - collect(), - tbl - ) -}) - -test_that("More complex select/filter", { - expect_dplyr_equal( - input %>% - filter(dbl > 2, chr == "d" | chr == "f") %>% - select(chr, int, lgl) %>% - filter(int < 5) %>% - select(int, chr) %>% - collect(), - tbl - ) -}) - test_that("dim() on query", { expect_dplyr_equal( input %>% @@ -253,146 +67,6 @@ See $.data for the source Arrow object', ) }) -test_that("filter() with %in%", { - expect_dplyr_equal( - input %>% - filter(dbl > 2, chr %in% c("d", "f")) %>% - collect(), - tbl - ) -}) - -test_that("filter() with string ops", { - # Extra instrumentation to ensure that we're calling Arrow compute here - # because many base R string functions implicitly call as.character, - # which means they still work on Arrays but actually force data into R - # 1) wrapper that raises a warning if as.character is called. 
Can't wrap - # the whole test because as.character apparently gets called in other - # (presumably legitimate) places - # 2) Wrap the test in expect_warning(expr, NA) to catch the warning - - with_no_as_character <- function(expr) { - trace( - "as.character", - tracer = quote(warning("as.character was called")), - print = FALSE, - where = toupper - ) - on.exit(untrace("as.character", where = toupper)) - force(expr) - } - - expect_warning( - expect_dplyr_equal( - input %>% - filter(dbl > 2, with_no_as_character(toupper(chr)) %in% c("D", "F")) %>% - collect(), - tbl - ), - NA) - - expect_dplyr_equal( - input %>% - filter(dbl > 2, str_length(verses) > 25) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(dbl > 2, str_length(str_trim(padded_strings, "left")) > 5) %>% - collect(), - tbl - ) -}) - -test_that("filter environment scope", { - # "object 'b_var' not found" - expect_dplyr_error(input %>% filter(batch, chr == b_var)) - - b_var <- "b" - expect_dplyr_equal( - input %>% - filter(chr == b_var) %>% - collect(), - tbl - ) - # Also for functions - # 'could not find function "isEqualTo"' because we haven't defined it yet - expect_dplyr_error(filter(batch, isEqualTo(int, 4))) - - skip("Need to substitute in user defined function too") - # TODO: fix this: this isEqualTo function is eagerly evaluating; it should - # instead yield array_expressions. 
Probably bc the parent env of the function - # has the Ops.Array methods defined; we need to move it so that the parent - # env is the data mask we use in the dplyr eval - isEqualTo <- function(x, y) x == y & !is.na(x) - expect_dplyr_equal( - input %>% - select(-fct) %>% # factor levels aren't identical - filter(isEqualTo(int, 4)) %>% - collect(), - tbl - ) -}) - -test_that("Filtering on a column that doesn't exist errors correctly", { - skip("Error handling in filter() needs to be internationalized") - expect_error( - batch %>% filter(not_a_col == 42) %>% collect(), - "object 'not_a_col' not found" - ) -}) - -test_that("Filtering with a function that doesn't have an Array/expr method still works", { - expect_warning( - expect_dplyr_equal( - input %>% - filter(int > 2, pnorm(dbl) > .99) %>% - collect(), - tbl - ), - 'Filter expression not implemented in Arrow: pnorm(dbl) > 0.99; pulling data into R', - fixed = TRUE - ) -}) - -test_that("filter() with .data pronoun", { - expect_dplyr_equal( - input %>% - filter(.data$dbl > 4) %>% - select(.data$chr, .data$int, .data$lgl) %>% - collect(), - tbl - ) - - expect_dplyr_equal( - input %>% - filter(is.na(.data$lgl)) %>% - select(.data$chr, .data$int, .data$lgl) %>% - collect(), - tbl - ) - - # and the .env pronoun too! 
- chr <- 4 - expect_dplyr_equal( - input %>% - filter(.data$dbl > .env$chr) %>% - select(.data$chr, .data$int, .data$lgl) %>% - collect(), - tbl - ) - - # but there is an error if we don't override the masking with `.env` - expect_dplyr_error( - tbl %>% - filter(.data$dbl > chr) %>% - select(.data$chr, .data$int, .data$lgl) %>% - collect() - ) -}) - test_that("summarize", { expect_dplyr_equal( input %>% @@ -411,29 +85,6 @@ test_that("summarize", { ) }) -test_that("mutate", { - expect_dplyr_equal( - input %>% - select(int, chr) %>% - filter(int > 5) %>% - mutate(int = int + 6L) %>% - summarize(min_int = min(int)), - tbl - ) -}) - -test_that("transmute", { - skip("TODO: reimplement transmute (with dplyr 1.0, it no longer just works via mutate)") - expect_dplyr_equal( - input %>% - select(int, chr) %>% - filter(int > 5) %>% - transmute(int = int + 6L) %>% - summarize(min_int = min(int)), - tbl - ) -}) - test_that("group_by groupings are recorded", { expect_dplyr_equal( input %>% From 7c136f7847c2c9794c15da1b88ad7d4c309b2f6b Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 18 Feb 2021 14:25:58 -0800 Subject: [PATCH 05/15] Proof of concept of mutate(); add transmute while we're at it --- r/R/arrow-package.R | 2 +- r/R/dplyr.R | 98 +++++++++++++++++++--------- r/tests/testthat/test-dplyr-mutate.R | 14 ++-- 3 files changed, 78 insertions(+), 36 deletions(-) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 66694a97867..818d85c8580 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -30,7 +30,7 @@ "dplyr::", c( "select", "filter", "collect", "summarise", "group_by", "groups", - "group_vars", "ungroup", "mutate", "arrange", "rename", "pull" + "group_vars", "ungroup", "mutate", "transmute", "arrange", "rename", "pull" ) ) for (cl in c("Dataset", "ArrowTabular", "arrow_dplyr_query")) { diff --git a/r/R/dplyr.R b/r/R/dplyr.R index a4ddf56f106..7b3a87293eb 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -200,29 +200,8 @@ filter.arrow_dplyr_query <- 
function(.data, ..., .preserve = FALSE) { } .data <- arrow_dplyr_query(.data) - # The filter() method works by evaluating the filters to generate Expressions - # with references to Arrays (if .data is Table/RecordBatch) or Fields (if - # .data is a Dataset). - dm <- filter_mask(.data) - filters <- lapply(filts, function (f) { - # This should yield an Expression as long as the filter function(s) are - # implemented in Arrow. - tryCatch(eval_tidy(f, dm), error = function(e) { - # Look for the cases where bad input was given, i.e. this would fail - # in regular dplyr anyway, and let those raise those as errors; - # else, for things not supported by Arrow return a "try-error", - # which we'll handle differently - msg <- conditionMessage(e) - # TODO: internationalization? - if (grepl("object '.*'.not.found", msg)) { - stop(e) - } - if (grepl('could not find function ".*"', msg)) { - stop(e) - } - invisible(structure(msg, class = "try-error", condition = e)) - }) - }) + # tidy-eval the filter expressions inside an Arrow data_mask + filters <- lapply(filts, arrow_eval, arrow_mask(.data)) bad_filters <- map_lgl(filters, ~inherits(., "try-error")) if (any(bad_filters)) { bads <- oxford_paste(map_chr(filts, as_label)[bad_filters], quote = FALSE) @@ -251,6 +230,30 @@ filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) { } filter.Dataset <- filter.ArrowTabular <- filter.arrow_dplyr_query +arrow_eval <- function (expr, mask) { + # filter(), mutate(), etc. work by evaluating the quoted `exprs` to generate Expressions + # with references to Arrays (if .data is Table/RecordBatch) or Fields (if + # .data is a Dataset). + + # This yields an Expression as long as the `exprs` are implemented in Arrow. + # Otherwise, it returns a try-error + tryCatch(eval_tidy(expr, mask), error = function(e) { + # Look for the cases where bad input was given, i.e. 
this would fail + # in regular dplyr anyway, and let those raise those as errors; + # else, for things not supported by Arrow return a "try-error", + # which we'll handle differently + msg <- conditionMessage(e) + # TODO: internationalization? + if (grepl("object '.*'.not.found", msg)) { + stop(e) + } + if (grepl('could not find function ".*"', msg)) { + stop(e) + } + invisible(structure(msg, class = "try-error", condition = e)) + }) +} + # Helper to assemble the functions that go in the NSE data mask # The only difference between the Dataset and the Table/RecordBatch versions # is that they use a different wrapping function (FUN) to hold the unevaluated @@ -284,8 +287,8 @@ build_function_list <- function(FUN) { dataset_function_list <- build_function_list(build_dataset_expression) array_function_list <- build_function_list(build_array_expression) -# Create a data mask for evaluating a filter expression -filter_mask <- function(.data) { +# Create a data mask for evaluating a dplyr expression +arrow_mask <- function(.data) { if (query_on_dataset(.data)) { f_env <- new_environment(dataset_function_list) } else { @@ -457,18 +460,51 @@ ungroup.arrow_dplyr_query <- function(x, ...) { } ungroup.Dataset <- ungroup.ArrowTabular <- force -mutate.arrow_dplyr_query <- function(.data, ...) { +mutate.arrow_dplyr_query <- function(.data, + ..., + .keep = c("all", "used", "unused", "none"), + .before = NULL, + .after = NULL) { + exprs <- quos(...) + if (length(exprs) == 0) { + # Nothing to do + return(.data) + } .data <- arrow_dplyr_query(.data) if (query_on_dataset(.data)) { not_implemented_for_dataset("mutate()") } - # TODO: see if we can defer evaluating the expressions and not collect here. - # It's different from filters (as currently implemented) because the basic - # vector transformation functions aren't yet implemented in Arrow C++. - dplyr::mutate(dplyr::collect(.data), ...) 
+ .keep <- match.arg(.keep) + # Restrict the cases we support for now + stopifnot( + .keep %in% c("all", "none"), + is.null(.before), + is.null(.after), + length(group_vars(.data)) == 0 + ) + + mask <- arrow_mask(.data) + results <- list() + for (new_var in names(exprs)) { + results[[new_var]] <- mask[[new_var]] <- arrow_eval(exprs[[new_var]], mask) + # TODO also update .data pronoun + # TODO: check for try-error + } + + if (.keep == "all") { + for (new_var in names(results)) { + .data$selected_columns[[new_var]] <- results[[new_var]] + } + } else if (.keep == "none") { + .data$selected_columns <- results + .data <- ensure_group_vars(.data) + } + .data } mutate.Dataset <- mutate.ArrowTabular <- mutate.arrow_dplyr_query -# TODO: add transmute() that does what summarise() does (select only the vars we need) + +transmute.arrow_dplyr_query <- function(.data, ...) dplyr::mutate(.data, ..., .keep = "none") +transmute.Dataset <- transmute.ArrowTabular <- transmute.arrow_dplyr_query arrange.arrow_dplyr_query <- function(.data, ...) 
{ .data <- arrow_dplyr_query(.data) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 4b8b8e57f68..80cd66332ce 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -25,25 +25,31 @@ tbl$verses <- verses[[1]] # nchar = 3 5 7 9 11 13 15 17 19 21 tbl$padded_strings <- stringr::str_pad(letters[1:10], width = 2*(1:10)+1, side = "both") -test_that("mutate", { +test_that("mutate() is lazy", { + expect_is( + tbl %>% record_batch() %>% mutate(int = int + 6L), + "arrow_dplyr_query" + ) +}) + +test_that("basic mutate", { expect_dplyr_equal( input %>% select(int, chr) %>% filter(int > 5) %>% mutate(int = int + 6L) %>% - summarize(min_int = min(int)), + collect(), tbl ) }) test_that("transmute", { - skip("TODO: reimplement transmute (with dplyr 1.0, it no longer just works via mutate)") expect_dplyr_equal( input %>% select(int, chr) %>% filter(int > 5) %>% transmute(int = int + 6L) %>% - summarize(min_int = min(int)), + collect(), tbl ) }) From ef3ce03e8db58083e101d7de24e6d2b65c5e3a9e Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 18 Feb 2021 15:06:55 -0800 Subject: [PATCH 06/15] Refactor data mask so we can mutate and use .data --- r/R/dplyr.R | 28 ++++++++++++++++------- r/tests/testthat/test-dplyr-mutate.R | 33 ++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 7b3a87293eb..35617e0751b 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -295,11 +295,16 @@ arrow_mask <- function(.data) { f_env <- new_environment(array_function_list) } - # Add the column references - env_bind(f_env, !!!.data$selected_columns) - # Then bind the data pronoun - env_bind(f_env, .data = .data$selected_columns) - new_data_mask(f_env) + # Add the column references and make the mask + out <- new_data_mask( + new_environment(.data$selected_columns, parent = f_env), + f_env + ) + # Then insert the data pronoun + # TODO: figure out what 
rlang::as_data_pronoun does/why we should use it + # (because if we do we get `Error: Can't modify the data pronoun` in mutate()) + out$.data <- .data$selected_columns + out } set_filters <- function(.data, expressions) { @@ -470,25 +475,31 @@ mutate.arrow_dplyr_query <- function(.data, # Nothing to do return(.data) } + .data <- arrow_dplyr_query(.data) if (query_on_dataset(.data)) { not_implemented_for_dataset("mutate()") } + .keep <- match.arg(.keep) # Restrict the cases we support for now stopifnot( .keep %in% c("all", "none"), is.null(.before), is.null(.after), + # mutate() on a grouped dataset does calculations within groups + # This doesn't matter on scalar ops (arithmetic etc.) but it does + # for things with aggregations (e.g. subtracting the mean) length(group_vars(.data)) == 0 ) - + mask <- arrow_mask(.data) results <- list() for (new_var in names(exprs)) { - results[[new_var]] <- mask[[new_var]] <- arrow_eval(exprs[[new_var]], mask) - # TODO also update .data pronoun + results[[new_var]] <- arrow_eval(exprs[[new_var]], mask) # TODO: check for try-error + # Put it in the data mask too + mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]] } if (.keep == "all") { @@ -497,6 +508,7 @@ mutate.arrow_dplyr_query <- function(.data, } } else if (.keep == "none") { .data$selected_columns <- results + # "none" (i.e. 
transmute) still keeps group vars .data <- ensure_group_vars(.data) } .data diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 80cd66332ce..36c549c5bc2 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -53,3 +53,36 @@ test_that("transmute", { tbl ) }) + +test_that("mutate and refer to previous mutants", { + expect_dplyr_equal( + input %>% + select(int, padded_strings) %>% + mutate( + line_lengths = nchar(padded_strings), + longer = line_lengths * 10 + ) %>% + filter(line_lengths > 15) %>% + collect(), + tbl + ) +}) + + +test_that("mutate with .data pronoun", { + expect_dplyr_equal( + input %>% + select(int, padded_strings) %>% + mutate( + line_lengths = nchar(padded_strings), + longer = .data$line_lengths * 10 + ) %>% + filter(line_lengths > 15) %>% + collect(), + tbl + ) +}) + +test_that("handle bad expressions", { + +}) \ No newline at end of file From fb5dabfb740a15d15f9220cca4c0d842dcfa294c Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Fri, 19 Feb 2021 13:50:09 -0800 Subject: [PATCH 07/15] Add dplyr::mutate examples as tests, support what we can and fall back to collecting data.frame where we can't --- r/R/dplyr.R | 124 +++++++++++++++------ r/tests/testthat/test-dplyr-filter.R | 2 +- r/tests/testthat/test-dplyr-mutate.R | 160 ++++++++++++++++++++++++++- 3 files changed, 252 insertions(+), 34 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 35617e0751b..35b8568a244 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -30,19 +30,14 @@ arrow_dplyr_query <- function(.data) { if (inherits(.data, "arrow_dplyr_query")) { return(.data) } - # selected_columns is a named list: - # * contents are references/expressions pointing to the data - # * names are the names they should be in the end (i.e. 
this - # records any renaming) - if (inherits(.data, "Dataset")) { - selected_columns <- lapply(names(.data), Expression$field_ref) - } else { - selected_columns <- lapply(names(.data), function(x) array_expression("array_ref", field_name = x)) - } structure( list( .data = .data$clone(), - selected_columns = set_names(selected_columns, names(.data)), + # selected_columns is a named list: + # * contents are references/expressions pointing to the data + # * names are the names they should be in the end (i.e. this + # records any renaming) + selected_columns = make_field_refs(names(.data), dataset = inherits(.data, "Dataset")), # filtered_rows will be an Expression filtered_rows = TRUE, # group_by_vars is a character vector of columns (as renamed) @@ -86,6 +81,15 @@ get_field_names <- function(selected_cols) { map_chr(selected_cols, ~.$field_name %||% .$args$field_name %||% "") } +make_field_refs <- function(field_names, dataset = TRUE) { + if (dataset) { + out <- lapply(field_names, Expression$field_ref) + } else { + out <- lapply(field_names, function(x) array_expression("array_ref", field_name = x)) + } + set_names(out, field_names) +} + # These are the names reflecting all select/rename, not what is in Arrow #' @export names.arrow_dplyr_query <- function(x) names(x$selected_columns) @@ -243,7 +247,7 @@ arrow_eval <- function (expr, mask) { # else, for things not supported by Arrow return a "try-error", # which we'll handle differently msg <- conditionMessage(e) - # TODO: internationalization? + # TODO(ARROW-11700): internationalization if (grepl("object '.*'.not.found", msg)) { stop(e) } @@ -295,6 +299,14 @@ arrow_mask <- function(.data) { f_env <- new_environment(array_function_list) } + # Add functions that need to error hard and clear. + # Some R functions will still try to evaluate on an Expression + # and return NA with a warning + fail <- function(...) 
stop("Not implemented") + for (f in c("mean")) { + f_env[[f]] <- fail + } + # Add the column references and make the mask out <- new_data_mask( new_environment(.data$selected_columns, parent = f_env), @@ -363,14 +375,12 @@ ensure_group_vars <- function(x) { # Before pulling data from Arrow, make sure all group vars are in the projection gv <- set_names(setdiff(dplyr::group_vars(x), names(x))) if (length(gv)) { - # selected_columns is no longer a character vector, so assemble refs - if (query_on_dataset(x)) { - gv <- set_names(lapply(gv, Expression$field_ref), gv) - } else { - gv <- set_names(lapply(gv, function(x) array_expression("array_ref", field_name = x)), gv) - } + # Add them back + x$selected_columns <- c( + x$selected_columns, + make_field_refs(gv, dataset = query_on_dataset(.data)) + ) } - x$selected_columns <- c(x$selected_columns, gv) } x } @@ -470,6 +480,7 @@ mutate.arrow_dplyr_query <- function(.data, .keep = c("all", "used", "unused", "none"), .before = NULL, .after = NULL) { + call <- match.call() exprs <- quos(...) if (length(exprs) == 0) { # Nothing to do @@ -482,36 +493,85 @@ mutate.arrow_dplyr_query <- function(.data, } .keep <- match.arg(.keep) + .before <- enquo(.before) + .after <- enquo(.after) # Restrict the cases we support for now - stopifnot( - .keep %in% c("all", "none"), - is.null(.before), - is.null(.after), + fall_back_to_r <- FALSE + if (!quo_is_null(.before) || !quo_is_null(.after)) { + warning( + '.before and .after arguments are not supported in Arrow; pulling data into R', + call. = FALSE + ) + fall_back_to_r <- TRUE + } else if (length(group_vars(.data)) > 0) { # mutate() on a grouped dataset does calculations within groups # This doesn't matter on scalar ops (arithmetic etc.) but it does # for things with aggregations (e.g. subtracting the mean) - length(group_vars(.data)) == 0 - ) + warning( + 'mutate() on grouped data not supported in Arrow; pulling data into R', + call. 
= FALSE + ) + fall_back_to_r <- TRUE + } else if (!all(nzchar(names(exprs)))) { + # This is either user error or a function that returns a data.frame + # e.g. across() that dplyr::mutate() will autosplice + warning( + 'all ... expressions must be named: ', + 'autosplicing multi-column results not supported in Arrow; ', + 'pulling data into R', + call. = FALSE + ) + fall_back_to_r <- TRUE + } + if (fall_back_to_r) { + # collect() and call mutate() on the data.frame + call$.data <- dplyr::collect(.data) + call[[1]] <- get("mutate", envir = asNamespace("dplyr")) + return(eval.parent(call)) + } mask <- arrow_mask(.data) results <- list() for (new_var in names(exprs)) { results[[new_var]] <- arrow_eval(exprs[[new_var]], mask) - # TODO: check for try-error + if (inherits(results[[new_var]], "try-error")) { + warning( + 'Expression ', as_label(exprs[[new_var]]), + ' not supported in Arrow; pulling data into R', + call. = FALSE + ) + # collect() and call mutate() on the data.frame + call$.data <- dplyr::collect(.data) + call[[1]] <- get("mutate", envir = asNamespace("dplyr")) + return(eval.parent(call)) + } # Put it in the data mask too mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]] } - if (.keep == "all") { - for (new_var in names(results)) { + # Assign the new columns into the .data$selected_columns, respecting the .keep param + if (.keep == "none") { + .data$selected_columns <- results + } else { + if (.keep != "all") { + # "used" or "unused" + used_vars <- unlist(lapply(exprs, all.vars), use.names = FALSE) + old_vars <- names(.data$selected_columns) + if (.keep == "used") { + .data$selected_columns <- .data$selected_columns[intersect(old_vars, used_vars)] + } else { + # "unused" + .data$selected_columns <- .data$selected_columns[setdiff(old_vars, used_vars)] + } + } + # Note that this is names(exprs) not names(results): + # if results$new_var is NULL, that means we are supposed to remove it + for (new_var in names(exprs)) { 
.data$selected_columns[[new_var]] <- results[[new_var]] } - } else if (.keep == "none") { - .data$selected_columns <- results - # "none" (i.e. transmute) still keeps group vars - .data <- ensure_group_vars(.data) } - .data + # Even if "none", we still keep group vars + ensure_group_vars(.data) } mutate.Dataset <- mutate.ArrowTabular <- mutate.arrow_dplyr_query diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R index 6e37d41cc2f..f73589496be 100644 --- a/r/tests/testthat/test-dplyr-filter.R +++ b/r/tests/testthat/test-dplyr-filter.R @@ -230,7 +230,7 @@ test_that("filter environment scope", { }) test_that("Filtering on a column that doesn't exist errors correctly", { - skip("Error handling in filter() needs to be internationalized") + skip("Error handling in arrow_eval() needs to be internationalized (ARROW-11700)") expect_error( batch %>% filter(not_a_col == 42) %>% collect(), "object 'not_a_col' not found" diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 36c549c5bc2..96bcbfe4a26 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -68,7 +68,6 @@ test_that("mutate and refer to previous mutants", { ) }) - test_that("mutate with .data pronoun", { expect_dplyr_equal( input %>% @@ -83,6 +82,165 @@ test_that("mutate with .data pronoun", { ) }) +test_that("mutate with single value for recycling", { + +}) + +test_that("dplyr::mutate's examples", { + # Newly created variables are available immediately + expect_dplyr_equal( + input %>% + select(name, mass) %>% + mutate( + mass2 = mass * 2, + mass2_squared = mass2 * mass2 + ) %>% + collect(), + starwars # this is a test dataset that ships with dplyr + ) + + # As well as adding new variables, you can use mutate() to + # remove variables and modify existing variables. 
+ expect_dplyr_equal( + input %>% + select(name, height, mass, homeworld) %>% + mutate( + mass = NULL, + height = height * 0.0328084 # convert to feet + ) %>% + collect(), + starwars + ) + + # Examples we don't support should succeed + # but warn that they're pulling data into R to do so + + # across + autosplicing: ARROW-11699 + expect_warning( + expect_dplyr_equal( + input %>% + select(name, homeworld, species) %>% + mutate(across(!name, as.factor)) %>% + collect(), + starwars + ), + "not supported in Arrow" + ) + + # group_by then mutate + expect_warning( + expect_dplyr_equal( + input %>% + select(name, mass, homeworld) %>% + group_by(homeworld) %>% + mutate(rank = min_rank(desc(mass))) %>% + collect(), + starwars + ), + "not supported in Arrow" + ) + + # `.before` and `.after` experimental args: ARROW-11701 + df <- tibble(x = 1, y = 2) + expect_dplyr_equal( + input %>% mutate(z = x + y) %>% collect(), + df + ) + #> # A tibble: 1 x 3 + #> x y z + #> + #> 1 1 2 3 + expect_warning( + expect_dplyr_equal( + input %>% mutate(z = x + y, .before = 1) %>% collect(), + df + ), + "not supported in Arrow" + ) + #> # A tibble: 1 x 3 + #> z x y + #> + #> 1 3 1 2 + expect_warning( + expect_dplyr_equal( + input %>% mutate(z = x + y, .after = x) %>% collect(), + df + ), + "not supported in Arrow" + ) + #> # A tibble: 1 x 3 + #> x z y + #> + #> 1 1 3 2 + + # By default, mutate() keeps all columns from the input data. 
+ # Experimental: You can override with `.keep` + df <- tibble(x = 1, y = 2, a = "a", b = "b") + expect_dplyr_equal( + input %>% mutate(z = x + y, .keep = "all") %>% collect(), # the default + df + ) + #> # A tibble: 1 x 5 + #> x y a b z + #> + #> 1 1 2 a b 3 + expect_dplyr_equal( + input %>% mutate(z = x + y, .keep = "used") %>% collect(), + df + ) + #> # A tibble: 1 x 3 + #> x y z + #> + #> 1 1 2 3 + expect_dplyr_equal( + input %>% mutate(z = x + y, .keep = "unused") %>% collect(), + df + ) + #> # A tibble: 1 x 3 + #> a b z + #> + #> 1 a b 3 + expect_dplyr_equal( + input %>% mutate(z = x + y, .keep = "none") %>% collect(), # same as transmute() + df + ) + #> # A tibble: 1 x 1 + #> z + #> + #> 1 3 + + # Grouping ---------------------------------------- + # The mutate operation may yield different results on grouped + # tibbles because the expressions are computed within groups. + # The following normalises `mass` by the global average: + # TODO(ARROW-11702) + expect_warning( + expect_dplyr_equal( + input %>% + select(name, mass, species) %>% + mutate(mass_norm = mass / mean(mass, na.rm = TRUE)) %>% + collect(), + starwars + ), + "not supported in Arrow" + ) +}) + test_that("handle bad expressions", { + # TODO: search for functions other than mean() (see above test) + # that need to be forced to fail because they error ambiguously + + skip("Error handling in arrow_eval() needs to be internationalized (ARROW-11700)") + expect_error( + Table$create(tbl) %>% mutate(newvar = NOTAVAR + 2), + "object 'NOTAVAR' not found" + ) +}) + +test_that("print a mutated dataset", { + +}) + +test_that("mutate and write_dataset", { }) \ No newline at end of file From e1cc3141bd28b8a41f32280dd46031779bfcb27d Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Fri, 19 Feb 2021 14:43:44 -0800 Subject: [PATCH 08/15] Factor out abandon_ship() so that mutate() may work for Datasets too --- r/R/dplyr.R | 69 ++++++++++++++++++++++++++--------------------------- 1 file changed, 34 
insertions(+), 35 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 35b8568a244..5deb23ae085 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -496,38 +496,20 @@ mutate.arrow_dplyr_query <- function(.data, .before <- enquo(.before) .after <- enquo(.after) # Restrict the cases we support for now - fall_back_to_r <- FALSE if (!quo_is_null(.before) || !quo_is_null(.after)) { - warning( - '.before and .after arguments are not supported in Arrow; pulling data into R', - call. = FALSE - ) - fall_back_to_r <- TRUE + # TODO(ARROW-11701) + return(abandon_ship(call, .data, '.before and .after arguments are not supported in Arrow')) } else if (length(group_vars(.data)) > 0) { # mutate() on a grouped dataset does calculations within groups # This doesn't matter on scalar ops (arithmetic etc.) but it does # for things with aggregations (e.g. subtracting the mean) - warning( - 'mutate() on grouped data not supported in Arrow; pulling data into R', - call. = FALSE - ) - fall_back_to_r <- TRUE + return(abandon_ship(call, .data, 'mutate() on grouped data not supported in Arrow')) } else if (!all(nzchar(names(exprs)))) { # This is either user error or a function that returns a data.frame # e.g. across() that dplyr::mutate() will autosplice - warning( - 'all ... expressions must be named: ', - 'autosplicing multi-column results not supported in Arrow; ', - 'pulling data into R', - call. = FALSE - ) - fall_back_to_r <- TRUE - } - if (fall_back_to_r) { - # collect() and call mutate() on the data.frame - call$.data <- dplyr::collect(.data) - call[[1]] <- get("mutate", envir = asNamespace("dplyr")) - return(eval.parent(call)) + # TODO(ARROW-11699) + msg <- 'all ...
expressions must be named: autosplicing multi-column results not supported in Arrow' + return(abandon_ship(call, .data, msg)) } mask <- arrow_mask(.data) @@ -535,15 +517,8 @@ mutate.arrow_dplyr_query <- function(.data, for (new_var in names(exprs)) { results[[new_var]] <- arrow_eval(exprs[[new_var]], mask) if (inherits(results[[new_var]], "try-error")) { - warning( - 'Expression ', as_label(exprs[[new_var]]), - ' not supported in Arrow; pulling data into R', - call. = FALSE - ) - # collect() and call mutate() on the data.frame - call$.data <- dplyr::collect(.data) - call[[1]] <- get("mutate", envir = asNamespace("dplyr")) - return(eval.parent(call)) + msg <- paste('Expression', as_label(exprs[[new_var]]), 'not supported in Arrow') + return(abandon_ship(call, .data, msg)) } # Put it in the data mask too mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]] @@ -578,13 +553,37 @@ mutate.Dataset <- mutate.ArrowTabular <- mutate.arrow_dplyr_query transmute.arrow_dplyr_query <- function(.data, ...) dplyr::mutate(.data, ..., .keep = "none") transmute.Dataset <- transmute.ArrowTabular <- transmute.arrow_dplyr_query +# Helper to handle unsupported dplyr features +# * For Table/RecordBatch, we collect() and then call the dplyr method in R +# * For Dataset, we just error +abandon_ship <- function(call, .data, msg = NULL) { + dplyr_fun_name <- sub("^(.*?)\\..*", "\\1", as.character(call[[1]])) + if (query_on_dataset(.data)) { + if (is.null(msg)) { + # Default message: function not implemented + not_implemented_for_dataset(paste0(dplyr_fun_name, "()")) + } else { + stop(msg, call. = FALSE) + } + } + + # else, collect and call dplyr method + if (!is.null(msg)) { + warning(msg, "; pulling data into R", immediate. = TRUE, call. = FALSE) + } + call$.data <- dplyr::collect(.data) + call[[1]] <- get(dplyr_fun_name, envir = asNamespace("dplyr")) + eval.parent(call, 2) +} + arrange.arrow_dplyr_query <- function(.data, ...) 
{ .data <- arrow_dplyr_query(.data) if (query_on_dataset(.data)) { not_implemented_for_dataset("arrange()") } - - dplyr::arrange(dplyr::collect(.data), ...) + # TODO(ARROW-11703) move this to Arrow + call <- match.call() + abandon_ship(call, .data) } arrange.Dataset <- arrange.ArrowTabular <- arrange.arrow_dplyr_query From c43f36823035f88bb8b45c276bbe12e625b0e975 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Fri, 19 Feb 2021 15:35:32 -0800 Subject: [PATCH 09/15] Update print method and add writing test --- r/R/dataset-write.R | 4 ++ r/R/dplyr.R | 11 +++++- r/tests/testthat/test-dplyr-mutate.R | 56 ++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/r/R/dataset-write.R b/r/R/dataset-write.R index 61c20b31f2d..5078bc3e371 100644 --- a/r/R/dataset-write.R +++ b/r/R/dataset-write.R @@ -62,6 +62,10 @@ write_dataset <- function(dataset, hive_style = TRUE, ...) { if (inherits(dataset, "arrow_dplyr_query")) { + if (inherits(dataset$.data, "ArrowTabular")) { + # collect() to materialize any mutate/rename + dataset <- dplyr::collect(dataset, as_data_frame = FALSE) + } # We can select a subset of columns but we can't rename them if (!all(get_field_names(dataset) == names(dataset$selected_columns))) { stop("Renaming columns when writing a dataset is not yet supported", call. = FALSE) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 5deb23ae085..4f4ed5dea5e 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -52,8 +52,14 @@ arrow_dplyr_query <- function(.data) { print.arrow_dplyr_query <- function(x, ...) 
{ schm <- x$.data$schema cols <- get_field_names(x) - # TODO: if cols are expressions, they won't be in the schema - fields <- map_chr(cols, ~schm$GetFieldByName(.)$ToString()) + # If cols are expressions, they won't be in the schema and will be "" in cols + fields <- map_chr(cols, function(name) { + if (nzchar(name)) { + schm$GetFieldByName(name)$ToString() + } else { + "expr" + } + }) # Strip off the field names as they are in the dataset and add the renamed ones fields <- paste(names(cols), sub("^.*?: ", "", fields), sep = ": ", collapse = "\n") cat(class(x$.data)[1], " (query)\n", sep = "") @@ -346,6 +352,7 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { } else { filter <- eval_array_expression(x$filtered_rows, x$.data) } + # TODO: shortcut if identical(names(x$.data), find_array_refs(x$selected_columns))? tab <- x$.data[filter, find_array_refs(x$selected_columns), keep_na = FALSE] # Now evaluate those expressions on the filtered table cols <- lapply(x$selected_columns, eval_array_expression, data = tab) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 96bcbfe4a26..b998125359b 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -238,9 +238,65 @@ test_that("handle bad expressions", { }) test_that("print a mutated dataset", { + expect_output( + Table$create(tbl) %>% + select(int) %>% + mutate(twice = int * 2) %>% + print(), +'Table (query) +int: int32 +twice: expr +See $.data for the source Arrow object', + fixed = TRUE) }) test_that("mutate and write_dataset", { + # See related test in test-dataset.R + skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-9651 + + first_date <- lubridate::ymd_hms("2015-04-29 03:12:39") + df1 <- tibble( + int = 1:10, + dbl = as.numeric(1:10), + lgl = rep(c(TRUE, FALSE, NA, TRUE, FALSE), 2), + chr = letters[1:10], + fct = factor(LETTERS[1:10]), + ts = first_date + lubridate::days(1:10) + ) + + 
second_date <- lubridate::ymd_hms("2017-03-09 07:01:02") + df2 <- tibble( + int = 101:110, + dbl = c(as.numeric(51:59), NaN), + lgl = rep(c(TRUE, FALSE, NA, TRUE, FALSE), 2), + chr = letters[10:1], + fct = factor(LETTERS[10:1]), + ts = second_date + lubridate::days(10:1) + ) + + dst_dir <- tempfile() + stacked <- record_batch(rbind(df1, df2)) + stacked %>% + mutate(twice = int * 2) %>% + group_by(int) %>% + write_dataset(dst_dir, format = "feather") + expect_true(dir.exists(dst_dir)) + expect_identical(dir(dst_dir), sort(paste("int", c(1:10, 101:110), sep = "="))) + + new_ds <- open_dataset(dst_dir, format = "feather") + + expect_equivalent( + new_ds %>% + select(string = chr, integer = int, twice) %>% + filter(integer > 6 & integer < 11) %>% + collect() %>% + summarize(mean = mean(integer)), + df1 %>% + select(string = chr, integer = int) %>% + mutate(twice = integer * 2) %>% + filter(integer > 6) %>% + summarize(mean = mean(integer)) + ) }) \ No newline at end of file From b7571ce3ca88e2ba8e6eab289379cc436d68869e Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Fri, 19 Feb 2021 16:29:00 -0800 Subject: [PATCH 10/15] One more skipped test --- r/R/expression.R | 2 +- r/tests/testthat/test-RecordBatch.R | 8 ++++++++ r/tests/testthat/test-dplyr-mutate.R | 11 ++++++++++- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/r/R/expression.R b/r/R/expression.R index a926007c1b1..74c1aefcae1 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -147,7 +147,7 @@ eval_array_expression <- function(x, data = NULL) { if (!is.null(data)) { x <- bind_array_refs(x, data) } - if (inherits(x, "ArrowDatum")) { + if (!inherits(x, "array_expression")) { # Nothing to evaluate return(x) } diff --git a/r/tests/testthat/test-RecordBatch.R b/r/tests/testthat/test-RecordBatch.R index aeee66d8710..708fd901880 100644 --- a/r/tests/testthat/test-RecordBatch.R +++ b/r/tests/testthat/test-RecordBatch.R @@ -416,6 +416,14 @@ test_that("record_batch() handles null type 
(ARROW-7064)", { expect_equivalent(batch$schema, schema(a = int32(), n = null())) }) +test_that("record_batch() scalar recycling", { + skip("Not implemented (ARROW-11705") + expect_data_frame( + record_batch(a = 1:10, b = 5), + tibble::tibble(a = 1:10, b = 5) + ) +}) + test_that("RecordBatch$Equals", { df <- tibble::tibble(x = 1:10, y = letters[1:10]) a <- record_batch(df) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index b998125359b..4678bd53621 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -83,7 +83,16 @@ test_that("mutate with .data pronoun", { }) test_that("mutate with single value for recycling", { - + skip("Not implemented (ARROW-11705") + expect_dplyr_equal( + input %>% + select(int, padded_strings) %>% + mutate( + dr_bronner = 1 # ALL ONE! + ) %>% + collect(), + tbl + ) }) test_that("dplyr::mutate's examples", { From 99b999961370d5dcef20c09aa3e51c256e3ab337 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 23 Feb 2021 14:13:24 -0800 Subject: [PATCH 11/15] ) Co-authored-by: Jonathan Keane --- r/tests/testthat/test-RecordBatch.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/tests/testthat/test-RecordBatch.R b/r/tests/testthat/test-RecordBatch.R index 708fd901880..a017823ce34 100644 --- a/r/tests/testthat/test-RecordBatch.R +++ b/r/tests/testthat/test-RecordBatch.R @@ -417,7 +417,7 @@ test_that("record_batch() handles null type (ARROW-7064)", { }) test_that("record_batch() scalar recycling", { - skip("Not implemented (ARROW-11705") + skip("Not implemented (ARROW-11705)") expect_data_frame( record_batch(a = 1:10, b = 5), tibble::tibble(a = 1:10, b = 5) From 8eedbafc81a2dad3f2a723f7ffd77c3d1968f00b Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 24 Feb 2021 11:13:10 -0800 Subject: [PATCH 12/15] Allow unnamed expressions in mutate --- r/R/dplyr.R | 10 ++++------ r/tests/testthat/test-dplyr-mutate.R | 15 ++++++++++++++- 2 files 
changed, 18 insertions(+), 7 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 4f4ed5dea5e..577ba8f802b 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -511,14 +511,12 @@ mutate.arrow_dplyr_query <- function(.data, # This doesn't matter on scalar ops (arithmetic etc.) but it does # for things with aggregations (e.g. subtracting the mean) return(abandon_ship(call, .data, 'mutate() on grouped data not supported in Arrow')) - } else if (!all(nzchar(names(exprs)))) { - # This is either user error or a function that returns a data.frame - # e.g. across() that dplyr::mutate() will autosplice - # TODO(ARROW-16999) - msg <- 'all ... expressions must be named: autosplicing multi-column results not supported in Arrow' - return(abandon_ship(call, .data, msg)) } + unnamed <- !nzchar(names(exprs)) + # Deparse and take the first element in case they're long expressions + names(exprs)[unnamed] <- map_chr(exprs[unnamed], ~deparse(.)[1]) + mask <- arrow_mask(.data) results <- list() for (new_var in names(exprs)) { diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 4678bd53621..c797ebefaf5 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -82,6 +82,19 @@ test_that("mutate with .data pronoun", { ) }) +test_that("mutate with unnamed expressions", { + expect_dplyr_equal( + input %>% + select(int, padded_strings) %>% + mutate( + nchar(padded_strings) + ) %>% + filter(int > 5) %>% + collect(), + tbl + ) +}) + test_that("mutate with single value for recycling", { skip("Not implemented (ARROW-11705") expect_dplyr_equal( @@ -133,7 +146,7 @@ test_that("dplyr::mutate's examples", { collect(), starwars ), - "not supported in Arrow" + "Expression across.*not supported in Arrow" ) # group_by then mutate From 08f78fb41554b2a0095e73a7ee216d6192f77119 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 24 Feb 2021 11:24:29 -0800 Subject: [PATCH 13/15] Make get_field_names more robust and 
intelligible --- r/R/dplyr.R | 14 +++++++++++++- r/tests/testthat/test-dplyr-mutate.R | 13 +++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 577ba8f802b..b7529b17357 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -84,7 +84,18 @@ get_field_names <- function(selected_cols) { if (inherits(selected_cols, "arrow_dplyr_query")) { selected_cols <- selected_cols$selected_columns } - map_chr(selected_cols, ~.$field_name %||% .$args$field_name %||% "") + map_chr(selected_cols, function(x) { + if (inherits(x, "Expression")) { + out <- x$field_name + } else if (inherits(x, "array_expression")) { + out <- x$args$field_name + } else { + out <- NULL + } + # If x isn't some kind of field reference, out is NULL, + # but we always need to return a string + out %||% "" + }) } make_field_refs <- function(field_names, dataset = TRUE) { @@ -513,6 +524,7 @@ mutate.arrow_dplyr_query <- function(.data, return(abandon_ship(call, .data, 'mutate() on grouped data not supported in Arrow')) } + # Check for unnamed expressions and fix if any unnamed <- !nzchar(names(exprs)) # Deparse and take the first element in case they're long expressions names(exprs)[unnamed] <- map_chr(exprs[unnamed], ~deparse(.)[1]) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index c797ebefaf5..6a2f9e44164 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -269,6 +269,19 @@ test_that("print a mutated dataset", { int: int32 twice: expr +See $.data for the source Arrow object', + fixed = TRUE) + + # Handling non-expressions/edge cases + expect_output( + Table$create(tbl) %>% + select(int) %>% + mutate(again = 1:10) %>% + print(), +'Table (query) +int: int32 +again: expr + See $.data for the source Arrow object', fixed = TRUE) }) From 020dad00d5e878cdef193d07577f02ddf3c2a8cf Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 24 Feb 2021 13:40:10 -0800 Subject: [PATCH 14/15] 
More edge case handling --- r/R/dplyr.R | 11 +++++++---- r/tests/testthat/test-dplyr-mutate.R | 15 ++++++++++++++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/r/R/dplyr.R b/r/R/dplyr.R index b7529b17357..2bd8170a1cb 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -527,14 +527,17 @@ mutate.arrow_dplyr_query <- function(.data, # Check for unnamed expressions and fix if any unnamed <- !nzchar(names(exprs)) # Deparse and take the first element in case they're long expressions - names(exprs)[unnamed] <- map_chr(exprs[unnamed], ~deparse(.)[1]) + names(exprs)[unnamed] <- map_chr(exprs[unnamed], as_label) mask <- arrow_mask(.data) results <- list() - for (new_var in names(exprs)) { - results[[new_var]] <- arrow_eval(exprs[[new_var]], mask) + for (i in seq_along(exprs)) { + # Iterate over the indices and not the names because names may be repeated + # (which overwrites the previous name) + new_var <- names(exprs)[i] + results[[new_var]] <- arrow_eval(exprs[[i]], mask) if (inherits(results[[new_var]], "try-error")) { - msg <- paste('Expression', as_label(exprs[[new_var]]), 'not supported in Arrow') + msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow') return(abandon_ship(call, .data, msg)) } # Put it in the data mask too diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 6a2f9e44164..56d7e368520 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -87,7 +87,8 @@ test_that("mutate with unnamed expressions", { input %>% select(int, padded_strings) %>% mutate( - nchar(padded_strings) + int, # bare column name + nchar(padded_strings) # expression ) %>% filter(int > 5) %>% collect(), @@ -95,6 +96,18 @@ test_that("mutate with unnamed expressions", { ) }) +test_that("mutate with reassigning same name", { + expect_dplyr_equal( + input %>% + transmute( + new = lgl, + new = chr + ) %>% + collect(), + tbl + ) +}) + test_that("mutate with single value for 
recycling", { skip("Not implemented (ARROW-11705") expect_dplyr_equal( From 4e0df35adce888bee47ec32dffc7c0f1e92a3b91 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 25 Feb 2021 10:30:12 -0800 Subject: [PATCH 15/15] Catch up on the news --- r/NEWS.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/r/NEWS.md b/r/NEWS.md index 65c4e2205cc..a008088ff82 100644 --- a/r/NEWS.md +++ b/r/NEWS.md @@ -19,6 +19,13 @@ # arrow 3.0.0.9000 +## dplyr methods + +* `dplyr::mutate()` on Arrow `Table` and `RecordBatch` is now supported in Arrow for many applications. Where not yet supported, the implementation falls back to pulling data into an R `data.frame` first. +* String functions `nchar()`, `tolower()`, and `toupper()`, along with their `stringr` spellings `str_length()`, `str_to_lower()`, and `str_to_upper()`, are supported in Arrow `dplyr` calls. `str_trim()` is also supported. + +## Other improvements + * `value_counts()` to tabulate values in an `Array` or `ChunkedArray`, similar to `base::table()`. * `StructArray` objects gain data.frame-like methods, including `names()`, `$`, `[[`, and `dim()`. * RecordBatch columns can now be added, replaced, or removed by assigning (`<-`) with either `$` or `[[`