39 changes: 29 additions & 10 deletions r/R/dataset-scan.R
@@ -185,17 +185,36 @@ ScanTask <- R6Class("ScanTask",
 #' `data.frame`? Default `TRUE`
 #' @export
 map_batches <- function(X, FUN, ..., .data.frame = TRUE) {
-  if (.data.frame) {
-    lapply <- map_dfr
-  }
-  scanner <- Scanner$create(ensure_group_vars(X))
+  # TODO: ARROW-15271 possibly refactor do_exec_plan to return a RecordBatchReader
+  plan <- ExecPlan$create()
+  final_node <- plan$Build(X)
+  reader <- plan$Run(final_node)
   FUN <- as_mapper(FUN)
-  lapply(scanner$ScanBatches(), function(batch) {
-    # TODO: wrap batch in arrow_dplyr_query with X$selected_columns,
-    # X$temp_columns, and X$group_by_vars
-    # if X is arrow_dplyr_query, if some other arg (.dplyr?) == TRUE
-    FUN(batch, ...)
-  })
+
+  # TODO: wrap batch in arrow_dplyr_query with X$selected_columns,
+  # X$temp_columns, and X$group_by_vars
+  # if X is arrow_dplyr_query, if some other arg (.dplyr?) == TRUE
+  batch <- reader$read_next_batch()
+  res <- vector("list", 1024)
+  i <- 0L
+  while (!is.null(batch)) {
+    i <- i + 1L
+    res[[i]] <- FUN(batch, ...)
+    batch <- reader$read_next_batch()
+  }
+
+  # Trim list back
+  if (i < length(res)) {
+    res <- res[seq_len(i)]
+  }
+
+  if (.data.frame & inherits(res[[1]], "arrow_dplyr_query")) {
+    res <- dplyr::bind_rows(map(res, collect))
+  } else if (.data.frame) {
+    res <- dplyr::bind_rows(map(res, as.data.frame))
+  }
+
+  res
 }

#' @usage NULL
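
For context, here is a minimal usage sketch of the reworked `map_batches()` (an illustration, not part of the diff: it assumes the arrow and dplyr packages are attached, that `ds` is an already-opened Dataset with an `int` column as in the test file below, and the path is hypothetical):

```r
library(arrow)
library(dplyr)

# Hypothetical dataset; any Dataset or arrow_dplyr_query can be passed as X.
ds <- open_dataset("path/to/dataset")

# FUN is applied to each RecordBatch as it streams from the ExecPlan. With
# the default .data.frame = TRUE, the per-batch results are bound into a
# single tibble (one row per batch here).
per_batch_mins <- ds %>%
  map_batches(~ summarize(., min_int = min(int)))

# With .data.frame = FALSE, the raw per-batch results are returned as a list.
batch_sizes <- ds %>%
  map_batches(~ .$num_rows, .data.frame = FALSE)
```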
27 changes: 25 additions & 2 deletions r/tests/testthat/test-dataset.R
@@ -453,15 +453,38 @@ test_that("Creating UnionDataset", {
 })
 
 test_that("map_batches", {
-  skip("map_batches() is broken (ARROW-14029)")
   ds <- open_dataset(dataset_dir, partitioning = "part")
+
+  # summarize returns arrow_dplyr_query, which gets collected into a tibble
   expect_equal(
     ds %>%
       filter(int > 5) %>%
       select(int, lgl) %>%
-      map_batches(~ summarize(., min_int = min(int))),
+      map_batches(~ summarize(., min_int = min(int))) %>%
+      arrange(min_int),
     tibble(min_int = c(6L, 101L))
   )
+
+  # $num_rows returns integer vector
+  expect_equal(
+    ds %>%
+      filter(int > 5) %>%
+      select(int, lgl) %>%
+      map_batches(~ .$num_rows, .data.frame = FALSE) %>%
+      unlist() %>% # Returns list because .data.frame is FALSE

[Member]: Thanks for this fantastic extra clarification!

+      sort(),
+    c(5, 10)
+  )
+
+  # $Take returns RecordBatch, which gets bound into a tibble
+  expect_equal(
+    ds %>%
+      filter(int > 5) %>%
+      select(int, lgl) %>%
+      map_batches(~ .$Take(0)) %>%
+      arrange(int),
+    tibble(int = c(6, 101), lgl = c(TRUE, TRUE))
+  )
 })
 
 test_that("partitioning = NULL to ignore partition information (but why?)", {
60 changes: 60 additions & 0 deletions r/vignettes/dataset.Rmd
@@ -290,6 +290,66 @@ rows match the filter. Relatedly, since Parquet files contain row groups with
statistics on the data within, there may be entire chunks of data you can
avoid scanning because they have no rows where `total_amount > 100`.

### Processing data in batches

[Member Author]: This example was useful in testing, and hopefully gives some ideas for usage. Though perhaps it belongs more in the cookbook? LMK what you think.

[Member]: I like it a lot. And I think it totally belongs here in a vignette (especially in the tone you have here). But it wouldn't be bad to make an issue to add to the cookbook as well (though don't feel obligated to do that right now if you don't want to!).

Sometimes you want to run R code on the entire dataset, but that dataset is much
larger than memory. You can use `map_batches` on a dataset query to process
it batch-by-batch.

**Note**: `map_batches` is experimental and not recommended for production use.

As an example, to randomly sample a dataset, use `map_batches` to sample a
percentage of rows from each batch:

```{r, eval = file.exists("nyc-taxi")}
sampled_data <- ds %>%
filter(year == 2015) %>%
select(tip_amount, total_amount, passenger_count) %>%
map_batches(~ sample_frac(as.data.frame(.), 1e-4)) %>%
mutate(tip_pct = tip_amount / total_amount)

str(sampled_data)
```

```{r, echo = FALSE, eval = !file.exists("nyc-taxi")}
cat("
'data.frame': 15603 obs. of 4 variables:
$ tip_amount : num 0 0 1.55 1.45 5.2 ...
$ total_amount : num 5.8 16.3 7.85 8.75 26 ...
$ passenger_count: int 1 1 1 1 1 6 5 1 2 1 ...
$ tip_pct : num 0 0 0.197 0.166 0.2 ...
")
```

This function can also be used to aggregate summary statistics over a dataset by
computing partial results for each batch and then aggregating those partial
results. Extending the example above, you could fit a model to the sample data
and then use `map_batches` to compute the mean squared error (MSE) on the full dataset.

```{r, eval = file.exists("nyc-taxi")}
model <- lm(tip_pct ~ total_amount + passenger_count, data = sampled_data)

ds %>%
filter(year == 2015) %>%
select(tip_amount, total_amount, passenger_count) %>%
mutate(tip_pct = tip_amount / total_amount) %>%
map_batches(function(batch) {
batch %>%
as.data.frame() %>%
mutate(pred_tip_pct = predict(model, newdata = .)) %>%
filter(!is.nan(tip_pct)) %>%
summarize(sse_partial = sum((pred_tip_pct - tip_pct)^2), n_partial = n())
}) %>%
summarize(mse = sum(sse_partial) / sum(n_partial)) %>%
pull(mse)
```

```{r, echo = FALSE, eval = !file.exists("nyc-taxi")}
cat("
[1] 0.1304284
")
```
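
If you want the per-batch results themselves rather than one combined data frame, pass `.data.frame = FALSE` and `map_batches` will return a list with one element per batch. A small sketch (assuming the same `ds` as above; `$num_rows` is a field of each `RecordBatch`):

```{r, eval = file.exists("nyc-taxi")}
# Return each batch's row count; .data.frame = FALSE keeps the results as a
# list instead of binding them into a tibble
batch_rows <- ds %>%
  filter(year == 2015) %>%
  map_batches(~ .$num_rows, .data.frame = FALSE)

# Total rows scanned across all batches
sum(unlist(batch_rows))
```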

## More dataset options

There are a few ways you can control the Dataset creation to adapt to special use cases.