diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R
index b7f58bfa4bd..21a5056f7e1 100644
--- a/r/R/dataset-scan.R
+++ b/r/R/dataset-scan.R
@@ -185,17 +185,36 @@ ScanTask <- R6Class("ScanTask",
 #' `data.frame`? Default `TRUE`
 #' @export
 map_batches <- function(X, FUN, ..., .data.frame = TRUE) {
-  if (.data.frame) {
-    lapply <- map_dfr
-  }
-  scanner <- Scanner$create(ensure_group_vars(X))
+  # TODO: ARROW-15271 possibly refactor do_exec_plan to return a RecordBatchReader
+  plan <- ExecPlan$create()
+  final_node <- plan$Build(X)
+  reader <- plan$Run(final_node)
   FUN <- as_mapper(FUN)
-  lapply(scanner$ScanBatches(), function(batch) {
-    # TODO: wrap batch in arrow_dplyr_query with X$selected_columns,
-    # X$temp_columns, and X$group_by_vars
-    # if X is arrow_dplyr_query, if some other arg (.dplyr?) == TRUE
-    FUN(batch, ...)
-  })
+
+  # TODO: wrap batch in arrow_dplyr_query with X$selected_columns,
+  # X$temp_columns, and X$group_by_vars
+  # if X is arrow_dplyr_query, if some other arg (.dplyr?) == TRUE
+  batch <- reader$read_next_batch()
+  res <- vector("list", 1024)
+  i <- 0L
+  while (!is.null(batch)) {
+    i <- i + 1L
+    res[[i]] <- FUN(batch, ...)
+    batch <- reader$read_next_batch()
+  }
+
+  # Trim the preallocated list back to the number of batches read
+  if (i < length(res)) {
+    res <- res[seq_len(i)]
+  }
+
+  if (.data.frame && length(res) > 0 && inherits(res[[1]], "arrow_dplyr_query")) {
+    res <- dplyr::bind_rows(map(res, collect))
+  } else if (.data.frame) {
+    res <- dplyr::bind_rows(map(res, as.data.frame))
+  }
+
+  res
 }
 
 #' @usage NULL
diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R
index 8465e93bed5..58e7458098e 100644
--- a/r/tests/testthat/test-dataset.R
+++ b/r/tests/testthat/test-dataset.R
@@ -453,15 +453,38 @@ test_that("Creating UnionDataset", {
 })
 
 test_that("map_batches", {
-  skip("map_batches() is broken (ARROW-14029)")
   ds <- open_dataset(dataset_dir, partitioning = "part")
+
+  # summarize returns arrow_dplyr_query, which gets collected into a tibble
   expect_equal(
     ds %>%
       filter(int > 5) %>%
       select(int, lgl) %>%
-      map_batches(~ summarize(., min_int = min(int))),
+      map_batches(~ summarize(., min_int = min(int))) %>%
+      arrange(min_int),
     tibble(min_int = c(6L, 101L))
   )
+
+  # $num_rows returns an integer, so the result is a list of integers
+  expect_equal(
+    ds %>%
+      filter(int > 5) %>%
+      select(int, lgl) %>%
+      map_batches(~ .$num_rows, .data.frame = FALSE) %>%
+      unlist() %>% # Returns list because .data.frame is FALSE
+      sort(),
+    c(5, 10)
+  )
+
+  # $Take returns RecordBatch, which gets bound into a tibble
+  expect_equal(
+    ds %>%
+      filter(int > 5) %>%
+      select(int, lgl) %>%
+      map_batches(~ .$Take(0)) %>%
+      arrange(int),
+    tibble(int = c(6, 101), lgl = c(TRUE, TRUE))
+  )
 })
 
 test_that("partitioning = NULL to ignore partition information (but why?)", {
diff --git a/r/vignettes/dataset.Rmd b/r/vignettes/dataset.Rmd
index a7e8b8050b4..f09185589e1 100644
--- a/r/vignettes/dataset.Rmd
+++ b/r/vignettes/dataset.Rmd
@@ -290,6 +290,80 @@ rows match the filter. Relatedly, since Parquet files contain row groups with
 statistics on the data within, there may be entire chunks of data you can
 avoid scanning because they have no rows where `total_amount > 100`.
 
+### Processing data in batches
+
+Sometimes you want to run R code on the entire dataset, but that dataset is much
+larger than memory. You can use `map_batches` on a dataset query to process
+it batch-by-batch.
+
+**Note**: `map_batches` is experimental and not recommended for production use.
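+
+For instance, here is a minimal sketch of the call shape, reusing the `ds`
+taxi dataset from above (the `batch_rows` column name is illustrative; each
+batch is passed to your function as a `RecordBatch`):
+
+```{r, eval = file.exists("nyc-taxi")}
+# Count the rows in each batch as it streams past, without materializing
+# the whole dataset in memory at once
+ds %>%
+  filter(year == 2015) %>%
+  select(total_amount) %>%
+  map_batches(~ data.frame(batch_rows = .$num_rows)) %>%
+  head()
+```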
+
+As an example, to randomly sample a dataset, use `map_batches` to sample a
+percentage of rows from each batch:
+
+```{r, eval = file.exists("nyc-taxi")}
+sampled_data <- ds %>%
+  filter(year == 2015) %>%
+  select(tip_amount, total_amount, passenger_count) %>%
+  map_batches(~ sample_frac(as.data.frame(.), 1e-4)) %>%
+  mutate(tip_pct = tip_amount / total_amount)
+
+str(sampled_data)
+```
+
+```{r, echo = FALSE, eval = !file.exists("nyc-taxi")}
+cat("
+'data.frame': 15603 obs. of  4 variables:
+ $ tip_amount     : num  0 0 1.55 1.45 5.2 ...
+ $ total_amount   : num  5.8 16.3 7.85 8.75 26 ...
+ $ passenger_count: int  1 1 1 1 1 6 5 1 2 1 ...
+ $ tip_pct        : num  0 0 0.197 0.166 0.2 ...
+")
+```
+
+This function can also be used to aggregate summary statistics over a dataset by
+computing partial results for each batch and then aggregating those partial
+results. Extending the example above, you could fit a model to the sample data
+and then use `map_batches` to compute the MSE on the full dataset.
+
+```{r, eval = file.exists("nyc-taxi")}
+model <- lm(tip_pct ~ total_amount + passenger_count, data = sampled_data)
+
+ds %>%
+  filter(year == 2015) %>%
+  select(tip_amount, total_amount, passenger_count) %>%
+  mutate(tip_pct = tip_amount / total_amount) %>%
+  map_batches(function(batch) {
+    batch %>%
+      as.data.frame() %>%
+      mutate(pred_tip_pct = predict(model, newdata = .)) %>%
+      filter(!is.nan(tip_pct)) %>%
+      summarize(sse_partial = sum((pred_tip_pct - tip_pct)^2), n_partial = n())
+  }) %>%
+  summarize(mse = sum(sse_partial) / sum(n_partial)) %>%
+  pull(mse)
+```
+
+```{r, echo = FALSE, eval = !file.exists("nyc-taxi")}
+cat("
+[1] 0.1304284
+")
+```
+
 ## More dataset options
 
 There are a few ways you can control the Dataset creation to adapt to special use cases.