diff --git a/r/NEWS.md b/r/NEWS.md
index a008088ff82..1f1acb89805 100644
--- a/r/NEWS.md
+++ b/r/NEWS.md
@@ -21,7 +21,7 @@
 
 ## dplyr methods
 
-* `dplyr::mutate()` on Arrow `Table` and `RecordBatch` is now supported in Arrow for many applications. Where not yet supported, the implementation falls back to pulling data into an R `data.frame` first.
+* `dplyr::mutate()` is now supported in Arrow for many applications. For queries on `Table` and `RecordBatch` that Arrow does not yet support, the implementation falls back to pulling the data into an R `data.frame` first, as in the previous release. For queries on `Dataset` (which can be larger than memory), it raises an error if the expression is not supported.
 * String functions `nchar()`, `tolower()`, and `toupper()`, along with their `stringr` spellings `str_length()`, `str_to_lower()`, and `str_to_upper()`, are supported in Arrow `dplyr` calls. `str_trim()` is also supported.
 
 ## Other improvements
diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R
index 249341982e6..31b47b044e8 100644
--- a/r/R/arrowExports.R
+++ b/r/R/arrowExports.R
@@ -448,8 +448,12 @@ dataset___HivePartitioning__MakeFactory <- function(null_fallback){
   .Call(`_arrow_dataset___HivePartitioning__MakeFactory`, null_fallback)
 }
 
-dataset___ScannerBuilder__Project <- function(sb, cols){
-  invisible(.Call(`_arrow_dataset___ScannerBuilder__Project`, sb, cols))
+dataset___ScannerBuilder__ProjectNames <- function(sb, cols){
+  invisible(.Call(`_arrow_dataset___ScannerBuilder__ProjectNames`, sb, cols))
+}
+
+dataset___ScannerBuilder__ProjectExprs <- function(sb, exprs, names){
+  invisible(.Call(`_arrow_dataset___ScannerBuilder__ProjectExprs`, sb, exprs, names))
 }
 
 dataset___ScannerBuilder__Filter <- function(sb, expr){
diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R
index ec6f85c4bab..1c71bf481b5 100644
--- a/r/R/dataset-scan.R
+++ b/r/R/dataset-scan.R
@@ -157,13 +157,19 @@ ScannerBuilder <- R6Class("ScannerBuilder", inherit = ArrowObject,
   public = list(
     Project = function(cols) {
       # cols is either a character vector or a named list of Expressions
-      if (!is.character(cols)) {
-        # We don't yet support mutate() on datasets, so this is just a list
-        # of FieldRefs, and we need to back out the field names
-        cols <- get_field_names(cols)
+      if (is.character(cols)) {
+        dataset___ScannerBuilder__ProjectNames(self, cols)
+      } else {
+        # If the expressions all turn out to be field refs, we can
+        # still call the simple projection by name
+        field_names <- get_field_names(cols)
+        if (all(nzchar(field_names))) {
+          dataset___ScannerBuilder__ProjectNames(self, field_names)
+        } else {
+          # Otherwise, we are projecting/mutating with computed expressions
+          dataset___ScannerBuilder__ProjectExprs(self, cols, names(cols))
+        }
       }
-      assert_is(cols, "character")
-      dataset___ScannerBuilder__Project(self, cols)
       self
     },
     Filter = function(expr) {
diff --git a/r/R/dplyr.R b/r/R/dplyr.R
index 2bd8170a1cb..3a4d6faba00 100644
--- a/r/R/dplyr.R
+++ b/r/R/dplyr.R
@@ -506,9 +506,6 @@ mutate.arrow_dplyr_query <- function(.data,
   }
   .data <- arrow_dplyr_query(.data)
 
-  if (query_on_dataset(.data)) {
-    not_implemented_for_dataset("mutate()")
-  }
 
   .keep <- match.arg(.keep)
   .before <- enquo(.before)
@@ -529,6 +526,7 @@ mutate.arrow_dplyr_query <- function(.data,
   # Deparse and take the first element in case they're long expressions
   names(exprs)[unnamed] <- map_chr(exprs[unnamed], as_label)
 
+  is_dataset <- query_on_dataset(.data)
   mask <- arrow_mask(.data)
   results <- list()
   for (i in seq_along(exprs)) {
@@ -539,6 +537,15 @@ mutate.arrow_dplyr_query <- function(.data,
     if (inherits(results[[new_var]], "try-error")) {
      msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow')
      return(abandon_ship(call, .data, msg))
+    } else if (is_dataset &&
+               !inherits(results[[new_var]], "Expression") &&
+               !is.null(results[[new_var]])) {
+      # Wrap literal values so they can be handled as Arrow Expressions
+      if (length(results[[new_var]]) != 1) {
+        msg <- paste0("In ", new_var, " = ", as_label(exprs[[i]]), ", only values of size one are recycled")
+        return(abandon_ship(call, .data, msg))
+      }
+      results[[new_var]] <- Expression$scalar(results[[new_var]])
     }
     # Put it in the data mask too
     mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]]
@@ -583,7 +590,7 @@ abandon_ship <- function(call, .data, msg = NULL) {
     # Default message: function not implemented
     not_implemented_for_dataset(paste0(dplyr_fun_name, "()"))
   } else {
-    stop(msg, call. = FALSE)
+    stop(msg, "\nCall collect() first to pull data into R.", call. = FALSE)
   }
 }
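A minimal sketch of the user-facing behavior the R changes above enable; the dataset path and the `int` column are hypothetical, while the error behavior comes from the dplyr.R changes in this patch:

    library(arrow)
    library(dplyr)

    ds <- open_dataset("/path/to/dataset")  # hypothetical on-disk dataset

    # Supported expressions stay in Arrow: mutate() builds an Arrow Expression
    # and nothing is pulled into R until collect()
    ds %>%
      mutate(twice = int * 2) %>%
      collect()

    # On a Dataset, an expression Arrow cannot handle now errors with the hint
    # appended in abandon_ship(), instead of silently collecting first:
    #> Error: Expression <...> not supported in Arrow
    #> Call collect() first to pull data into R.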
diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp
index 100c8087959..5cbebdf4562 100644
--- a/r/src/arrowExports.cpp
+++ b/r/src/arrowExports.cpp
@@ -970,12 +970,23 @@ BEGIN_CPP11
 END_CPP11
 }
 // dataset.cpp
-void dataset___ScannerBuilder__Project(const std::shared_ptr<ds::ScannerBuilder>& sb, const std::vector<std::string>& cols);
-extern "C" SEXP _arrow_dataset___ScannerBuilder__Project(SEXP sb_sexp, SEXP cols_sexp){
+void dataset___ScannerBuilder__ProjectNames(const std::shared_ptr<ds::ScannerBuilder>& sb, const std::vector<std::string>& cols);
+extern "C" SEXP _arrow_dataset___ScannerBuilder__ProjectNames(SEXP sb_sexp, SEXP cols_sexp){
 BEGIN_CPP11
   arrow::r::Input<const std::shared_ptr<ds::ScannerBuilder>&>::type sb(sb_sexp);
   arrow::r::Input<const std::vector<std::string>&>::type cols(cols_sexp);
-  dataset___ScannerBuilder__Project(sb, cols);
+  dataset___ScannerBuilder__ProjectNames(sb, cols);
+  return R_NilValue;
+END_CPP11
+}
+// dataset.cpp
+void dataset___ScannerBuilder__ProjectExprs(const std::shared_ptr<ds::ScannerBuilder>& sb, const std::vector<std::shared_ptr<ds::Expression>>& exprs, const std::vector<std::string>& names);
+extern "C" SEXP _arrow_dataset___ScannerBuilder__ProjectExprs(SEXP sb_sexp, SEXP exprs_sexp, SEXP names_sexp){
+BEGIN_CPP11
+  arrow::r::Input<const std::shared_ptr<ds::ScannerBuilder>&>::type sb(sb_sexp);
+  arrow::r::Input<const std::vector<std::shared_ptr<ds::Expression>>&>::type exprs(exprs_sexp);
+  arrow::r::Input<const std::vector<std::string>&>::type names(names_sexp);
+  dataset___ScannerBuilder__ProjectExprs(sb, exprs, names);
   return R_NilValue;
 END_CPP11
 }
@@ -3638,7 +3649,8 @@ static const R_CallMethodDef CallEntries[] = {
   { "_arrow_dataset___DirectoryPartitioning__MakeFactory", (DL_FUNC) &_arrow_dataset___DirectoryPartitioning__MakeFactory, 1},
   { "_arrow_dataset___HivePartitioning", (DL_FUNC) &_arrow_dataset___HivePartitioning, 2},
   { "_arrow_dataset___HivePartitioning__MakeFactory", (DL_FUNC) &_arrow_dataset___HivePartitioning__MakeFactory, 1},
-  { "_arrow_dataset___ScannerBuilder__Project", (DL_FUNC) &_arrow_dataset___ScannerBuilder__Project, 2},
+  { "_arrow_dataset___ScannerBuilder__ProjectNames", (DL_FUNC) &_arrow_dataset___ScannerBuilder__ProjectNames, 2},
+  { "_arrow_dataset___ScannerBuilder__ProjectExprs", (DL_FUNC) &_arrow_dataset___ScannerBuilder__ProjectExprs, 3},
   { "_arrow_dataset___ScannerBuilder__Filter", (DL_FUNC) &_arrow_dataset___ScannerBuilder__Filter, 2},
   { "_arrow_dataset___ScannerBuilder__UseThreads", (DL_FUNC) &_arrow_dataset___ScannerBuilder__UseThreads, 2},
   { "_arrow_dataset___ScannerBuilder__BatchSize", (DL_FUNC) &_arrow_dataset___ScannerBuilder__BatchSize, 2},
diff --git a/r/src/dataset.cpp b/r/src/dataset.cpp
index 001a498b44a..f8c24217ce3 100644
--- a/r/src/dataset.cpp
+++ b/r/src/dataset.cpp
@@ -308,11 +308,24 @@ std::shared_ptr<ds::PartitioningFactory> dataset___HivePartitioning__MakeFactory(
 // ScannerBuilder, Scanner
 
 // [[arrow::export]]
-void dataset___ScannerBuilder__Project(const std::shared_ptr<ds::ScannerBuilder>& sb,
-                                       const std::vector<std::string>& cols) {
+void dataset___ScannerBuilder__ProjectNames(const std::shared_ptr<ds::ScannerBuilder>& sb,
+                                            const std::vector<std::string>& cols) {
   StopIfNotOk(sb->Project(cols));
 }
 
+// [[arrow::export]]
+void dataset___ScannerBuilder__ProjectExprs(
+    const std::shared_ptr<ds::ScannerBuilder>& sb,
+    const std::vector<std::shared_ptr<ds::Expression>>& exprs,
+    const std::vector<std::string>& names) {
+  // We hold shared_ptrs to Expressions, but Project() takes them by value
+  std::vector<ds::Expression> expressions;
+  for (const auto& expr : exprs) {
+    expressions.push_back(*expr);
+  }
+  StopIfNotOk(sb->Project(expressions, names));
+}
+
 // [[arrow::export]]
 void dataset___ScannerBuilder__Filter(const std::shared_ptr<ds::ScannerBuilder>& sb,
                                       const std::shared_ptr<ds::Expression>& expr) {
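A rough sketch of how the two projection entry points above are reached from R, using the package-internal wrappers; `ds` is assumed to be an open Dataset, and `twice_expr` stands in for the computed Expression the dplyr mask would build for `int * 2`:

    sb <- ds$NewScan()

    # A plain select() yields only field references, so ScannerBuilder$Project()
    # in dataset-scan.R projects by name through the simple entry point:
    arrow:::dataset___ScannerBuilder__ProjectNames(sb, c("chr", "int"))

    # A mutate() yields computed Expressions, so Project() passes the
    # expressions along with their output names:
    arrow:::dataset___ScannerBuilder__ProjectExprs(sb, list(twice = twice_expr), "twice")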
diff --git a/r/src/expression.cpp b/r/src/expression.cpp
index 76d8222967b..e7c6dd4e0c0 100644
--- a/r/src/expression.cpp
+++ b/r/src/expression.cpp
@@ -50,8 +50,11 @@ std::shared_ptr<ds::Expression> dataset___expr__field_ref(std::string name) {
 // [[arrow::export]]
 std::string dataset___expr__get_field_ref_name(
     const std::shared_ptr<ds::Expression>& ref) {
-  auto refname = ref->field_ref()->name();
-  return *refname;
+  auto field_ref = ref->field_ref();
+  if (field_ref == nullptr) {
+    return "";
+  }
+  return *field_ref->name();
 }
 
 // [[arrow::export]]
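The null check above underpins the routing in dataset-scan.R: get_field_names() maps expressions through this accessor, so a pure field reference reports its name while anything computed now reports "" instead of dereferencing a null field_ref. A small sketch, assuming Expression$field_ref() wraps dataset___expr__field_ref() from this patch:

    f <- Expression$field_ref("int")
    arrow:::dataset___expr__get_field_ref_name(f)
    #> [1] "int"

    # For a computed expression, the C++ field_ref() is null, so this returns "";
    # the all(nzchar(field_names)) check in Project() then fails and the scanner
    # receives the expressions via ProjectExprs() instead.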
"is not currently implemented for Arrow Datasets") } expect_not_implemented(ds %>% arrange(int)) - expect_not_implemented(ds %>% mutate(int = int + 2)) expect_not_implemented(ds %>% filter(int == 1) %>% summarize(n())) }) @@ -1137,7 +1226,7 @@ test_that("Dataset writing: no partitioning", { test_that("Dataset writing: partition on null", { skip_on_os("windows") # https://issues.apache.org/jira/browse/ARROW-9651 ds <- open_dataset(hive_dir) - + dst_dir <- tempfile() partitioning = hive_partition(lgl = boolean()) write_dataset(ds, dst_dir, partitioning = partitioning)