diff --git a/r/NAMESPACE b/r/NAMESPACE index 0a120dc97a6..f8f174579a2 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -67,6 +67,7 @@ S3method(as_record_batch,arrow_dplyr_query) S3method(as_record_batch,data.frame) S3method(as_record_batch,pyarrow.lib.RecordBatch) S3method(as_record_batch,pyarrow.lib.Table) +S3method(as_record_batch_reader,"function") S3method(as_record_batch_reader,Dataset) S3method(as_record_batch_reader,RecordBatch) S3method(as_record_batch_reader,RecordBatchReader) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index dfe0db614ad..711a0abf2ac 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -1736,6 +1736,10 @@ RecordBatchReader__from_batches <- function(batches, schema_sxp) { .Call(`_arrow_RecordBatchReader__from_batches`, batches, schema_sxp) } +RecordBatchReader__from_function <- function(fun_sexp, schema) { + .Call(`_arrow_RecordBatchReader__from_function`, fun_sexp, schema) +} + RecordBatchReader__from_Table <- function(table) { .Call(`_arrow_RecordBatchReader__from_Table`, table) } diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R index cca92b676fe..53fe7078c23 100644 --- a/r/R/dataset-scan.R +++ b/r/R/dataset-scan.R @@ -183,11 +183,16 @@ tail_from_batches <- function(batches, n) { #' @param FUN A function or `purrr`-style lambda expression to apply to each #' batch. It must return a RecordBatch or something coercible to one via #' `as_record_batch()'. +#' @param .schema An optional [schema()]. If NULL, the schema will be inferred +#' from the first batch. +#' @param .lazy Use `TRUE` to evaluate `FUN` lazily as batches are read from +#' the result; use `FALSE` to evaluate `FUN` on all batches before returning +#' the reader. #' @param ... Additional arguments passed to `FUN` #' @param .data.frame Deprecated argument, ignored #' @return An `arrow_dplyr_query`. #' @export -map_batches <- function(X, FUN, ..., .data.frame = NULL) { +map_batches <- function(X, FUN, ..., .schema = NULL, .lazy = FALSE, .data.frame = NULL) { if (!is.null(.data.frame)) { warning( "The .data.frame argument is deprecated. ", @@ -197,25 +202,60 @@ map_batches <- function(X, FUN, ..., .data.frame = NULL) { } FUN <- as_mapper(FUN) reader <- as_record_batch_reader(X) + dots <- rlang::list2(...) - # TODO: for future consideration - # * Move eval to C++ and make it a generator so it can stream, not block - # * Accept an output schema argument: with that, we could make this lazy (via collapse) - batch <- reader$read_next_batch() - res <- vector("list", 1024) - i <- 0L - while (!is.null(batch)) { - i <- i + 1L - res[[i]] <- as_record_batch(FUN(batch, ...)) + # If no schema is supplied, we have to evaluate the first batch here + if (is.null(.schema)) { batch <- reader$read_next_batch() + if (is.null(batch)) { + abort("Can't infer schema from a RecordBatchReader with zero batches") + } + + first_result <- as_record_batch(do.call(FUN, c(list(batch), dots))) + .schema <- first_result$schema + fun <- function() { + if (!is.null(first_result)) { + result <- first_result + first_result <<- NULL + result + } else { + batch <- reader$read_next_batch() + if (is.null(batch)) { + NULL + } else { + as_record_batch( + do.call(FUN, c(list(batch), dots)), + schema = .schema + ) + } + } + } + } else { + fun <- function() { + batch <- reader$read_next_batch() + if (is.null(batch)) { + return(NULL) + } + + as_record_batch( + do.call(FUN, c(list(batch), dots)), + schema = .schema + ) + } } - # Trim list back - if (i < length(res)) { - res <- res[seq_len(i)] + reader_out <- as_record_batch_reader(fun, schema = .schema) + + # TODO(ARROW-17178) because there are some restrictions on evaluating + # reader_out in some ExecPlans, the default .lazy is FALSE for now. + if (!.lazy) { + reader_out <- RecordBatchReader$create( + batches = reader_out$batches(), + schema = .schema + ) } - RecordBatchReader$create(batches = res) + reader_out } #' @usage NULL diff --git a/r/R/record-batch-reader.R b/r/R/record-batch-reader.R index 8f6a600dfb1..3a985d8abce 100644 --- a/r/R/record-batch-reader.R +++ b/r/R/record-batch-reader.R @@ -191,6 +191,8 @@ RecordBatchFileReader$create <- function(file) { #' Convert an object to an Arrow RecordBatchReader #' #' @param x An object to convert to a [RecordBatchReader] +#' @param schema The [schema()] that must match the schema returned by each +#' call to `x` when `x` is a function. #' @param ... Passed to S3 methods #' #' @return A [RecordBatchReader] @@ -234,6 +236,13 @@ as_record_batch_reader.Dataset <- function(x, ...) { Scanner$create(x)$ToRecordBatchReader() } +#' @rdname as_record_batch_reader +#' @export +as_record_batch_reader.function <- function(x, ..., schema) { + assert_that(inherits(schema, "Schema")) + RecordBatchReader__from_function(x, schema) +} + #' @rdname as_record_batch_reader #' @export as_record_batch_reader.arrow_dplyr_query <- function(x, ...) { diff --git a/r/man/as_record_batch_reader.Rd b/r/man/as_record_batch_reader.Rd index e635c0b98bd..2ed54354760 100644 --- a/r/man/as_record_batch_reader.Rd +++ b/r/man/as_record_batch_reader.Rd @@ -7,6 +7,7 @@ \alias{as_record_batch_reader.RecordBatch} \alias{as_record_batch_reader.data.frame} \alias{as_record_batch_reader.Dataset} +\alias{as_record_batch_reader.function} \alias{as_record_batch_reader.arrow_dplyr_query} \alias{as_record_batch_reader.Scanner} \title{Convert an object to an Arrow RecordBatchReader} @@ -23,6 +24,8 @@ as_record_batch_reader(x, ...) \method{as_record_batch_reader}{Dataset}(x, ...) +\method{as_record_batch_reader}{`function`}(x, ..., schema) + \method{as_record_batch_reader}{arrow_dplyr_query}(x, ...) \method{as_record_batch_reader}{Scanner}(x, ...) @@ -31,6 +34,9 @@ as_record_batch_reader(x, ...) \item{x}{An object to convert to a \link{RecordBatchReader}} \item{...}{Passed to S3 methods} + +\item{schema}{The \code{\link[=schema]{schema()}} that must match the schema returned by each +call to \code{x} when \code{x} is a function.} } \value{ A \link{RecordBatchReader} diff --git a/r/man/map_batches.Rd b/r/man/map_batches.Rd index eaeab6013a6..0e4d48e024d 100644 --- a/r/man/map_batches.Rd +++ b/r/man/map_batches.Rd @@ -4,7 +4,7 @@ \alias{map_batches} \title{Apply a function to a stream of RecordBatches} \usage{ -map_batches(X, FUN, ..., .data.frame = NULL) +map_batches(X, FUN, ..., .schema = NULL, .lazy = FALSE, .data.frame = NULL) } \arguments{ \item{X}{A \code{Dataset} or \code{arrow_dplyr_query} object, as returned by the @@ -16,6 +16,13 @@ batch. It must return a RecordBatch or something coercible to one via \item{...}{Additional arguments passed to \code{FUN}} +\item{.schema}{An optional \code{\link[=schema]{schema()}}. If NULL, the schema will be inferred +from the first batch.} + +\item{.lazy}{Use \code{TRUE} to evaluate \code{FUN} lazily as batches are read from +the result; use \code{FALSE} to evaluate \code{FUN} on all batches before returning +the reader.} + \item{.data.frame}{Deprecated argument, ignored} } \value{ diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index fd9f92e5d1a..a5a75fc983c 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -4480,6 +4480,15 @@ BEGIN_CPP11 END_CPP11 } // recordbatchreader.cpp +std::shared_ptr RecordBatchReader__from_function(cpp11::sexp fun_sexp, const std::shared_ptr& schema); +extern "C" SEXP _arrow_RecordBatchReader__from_function(SEXP fun_sexp_sexp, SEXP schema_sexp){ +BEGIN_CPP11 + arrow::r::Input::type fun_sexp(fun_sexp_sexp); + arrow::r::Input&>::type schema(schema_sexp); + return cpp11::as_sexp(RecordBatchReader__from_function(fun_sexp, schema)); +END_CPP11 +} +// recordbatchreader.cpp std::shared_ptr RecordBatchReader__from_Table(const std::shared_ptr& table); extern "C" SEXP _arrow_RecordBatchReader__from_Table(SEXP table_sexp){ BEGIN_CPP11 @@ -5599,6 +5608,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_RecordBatchReader__ReadNext", (DL_FUNC) &_arrow_RecordBatchReader__ReadNext, 1}, { "_arrow_RecordBatchReader__batches", (DL_FUNC) &_arrow_RecordBatchReader__batches, 1}, { "_arrow_RecordBatchReader__from_batches", (DL_FUNC) &_arrow_RecordBatchReader__from_batches, 2}, + { "_arrow_RecordBatchReader__from_function", (DL_FUNC) &_arrow_RecordBatchReader__from_function, 2}, { "_arrow_RecordBatchReader__from_Table", (DL_FUNC) &_arrow_RecordBatchReader__from_Table, 1}, { "_arrow_Table__from_RecordBatchReader", (DL_FUNC) &_arrow_Table__from_RecordBatchReader, 1}, { "_arrow_RecordBatchReader__Head", (DL_FUNC) &_arrow_RecordBatchReader__Head, 2}, diff --git a/r/src/recordbatchreader.cpp b/r/src/recordbatchreader.cpp index fb173825f3b..c571d282da1 100644 --- a/r/src/recordbatchreader.cpp +++ b/r/src/recordbatchreader.cpp @@ -16,6 +16,7 @@ // under the License. #include "./arrow_types.h" +#include "./safe-call-into-r.h" #include #include @@ -54,6 +55,50 @@ std::shared_ptr RecordBatchReader__from_batches( } } +class RFunctionRecordBatchReader : public arrow::RecordBatchReader { + public: + RFunctionRecordBatchReader(cpp11::sexp fun, + const std::shared_ptr& schema) + : fun_(fun), schema_(schema) {} + + std::shared_ptr schema() const { return schema_; } + + arrow::Status ReadNext(std::shared_ptr* batch_out) { + auto batch = SafeCallIntoR>([&]() { + cpp11::sexp result_sexp = fun_(); + if (result_sexp == R_NilValue) { + return std::shared_ptr(nullptr); + } else if (!Rf_inherits(result_sexp, "RecordBatch")) { + cpp11::stop("Expected fun() to return an arrow::RecordBatch"); + } + + return cpp11::as_cpp>(result_sexp); + }); + + RETURN_NOT_OK(batch); + + if (batch.ValueUnsafe().get() != nullptr && + !batch.ValueUnsafe()->schema()->Equals(schema_)) { + return arrow::Status::Invalid("Expected fun() to return batch with schema '", + schema_->ToString(), "' but got batch with schema '", + batch.ValueUnsafe()->schema()->ToString(), "'"); + } + + *batch_out = batch.ValueUnsafe(); + return arrow::Status::OK(); + } + + private: + cpp11::function fun_; + std::shared_ptr schema_; +}; + +// [[arrow::export]] +std::shared_ptr RecordBatchReader__from_function( + cpp11::sexp fun_sexp, const std::shared_ptr& schema) { + return std::make_shared(fun_sexp, schema); +} + // [[arrow::export]] std::shared_ptr RecordBatchReader__from_Table( const std::shared_ptr& table) { diff --git a/r/src/safe-call-into-r-impl.cpp b/r/src/safe-call-into-r-impl.cpp index 7318c81bb55..4eec3a85df8 100644 --- a/r/src/safe-call-into-r-impl.cpp +++ b/r/src/safe-call-into-r-impl.cpp @@ -38,7 +38,7 @@ bool CanRunWithCapturedR() { on_old_windows = on_old_windows_fun(); } - return !on_old_windows; + return !on_old_windows && GetMainRThread().Executor() == nullptr; #else return false; #endif diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index 937163a05df..08e8a8c11b6 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -31,7 +31,9 @@ // and crash R in older versions (ARROW-16201). Crashes also occur // on 32-bit R builds on R 3.6 and lower. Implementation provided // in safe-call-into-r-impl.cpp so that we can skip some tests -// when this feature is not provided. +// when this feature is not provided. This also checks that there +// is not already an event loop registered (via MainRThread::Executor()), +// because only one of these can exist at any given time. bool CanRunWithCapturedR(); // The MainRThread class keeps track of the thread on which it is safe diff --git a/r/tests/testthat/test-dataset-dplyr.R b/r/tests/testthat/test-dataset-dplyr.R index fecda56c6c2..229c3e7c603 100644 --- a/r/tests/testthat/test-dataset-dplyr.R +++ b/r/tests/testthat/test-dataset-dplyr.R @@ -70,6 +70,8 @@ test_that("filter() with %in%", { }) test_that("filter() on timestamp columns", { + skip_if_not_available("re2") + ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) expect_equal( ds %>% @@ -116,6 +118,8 @@ test_that("filter() on date32 columns", { 1L ) + skip_if_not_available("re2") + # Also with timestamp scalar expect_equal( open_dataset(tmp) %>% diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 5c826d6dffc..3bcdd8bcde4 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -689,6 +689,92 @@ test_that("map_batches", { ) }) +test_that("map_batches with explicit schema", { + fun_with_dots <- function(batch, first_col, first_col_val) { + record_batch( + !! first_col := first_col_val, + b = batch$a$cast(float64()) + ) + } + + empty_reader <- RecordBatchReader$create( + batches = list(), + schema = schema(a = int32()) + ) + expect_equal( + map_batches( + empty_reader, + fun_with_dots, + "first_col_name", + "first_col_value", + .schema = schema(first_col_name = string(), b = float64()) + )$read_table(), + arrow_table(first_col_name = character(), b = double()) + ) + + reader <- RecordBatchReader$create( + batches = list( + record_batch(a = 1, b = "two"), + record_batch(a = 2, b = "three") + ) + ) + expect_equal( + map_batches( + reader, + fun_with_dots, + "first_col_name", + "first_col_value", + .schema = schema(first_col_name = string(), b = float64()) + )$read_table(), + arrow_table( + first_col_name = c("first_col_value", "first_col_value"), + b = as.numeric(1:2) + ) + ) +}) + +test_that("map_batches without explicit schema", { + fun_with_dots <- function(batch, first_col, first_col_val) { + record_batch( + !! first_col := first_col_val, + b = batch$a$cast(float64()) + ) + } + + empty_reader <- RecordBatchReader$create( + batches = list(), + schema = schema(a = int32()) + ) + expect_error( + map_batches( + empty_reader, + fun_with_dots, + "first_col_name", + "first_col_value" + )$read_table(), + "Can't infer schema" + ) + + reader <- RecordBatchReader$create( + batches = list( + record_batch(a = 1, b = "two"), + record_batch(a = 2, b = "three") + ) + ) + expect_equal( + map_batches( + reader, + fun_with_dots, + "first_col_name", + "first_col_value" + )$read_table(), + arrow_table( + first_col_name = c("first_col_value", "first_col_value"), + b = as.numeric(1:2) + ) + ) +}) + test_that("head/tail", { # head/tail with no query are still deterministic order ds <- open_dataset(dataset_dir) diff --git a/r/tests/testthat/test-record-batch-reader.R b/r/tests/testthat/test-record-batch-reader.R index 597187da459..3cd856de667 100644 --- a/r/tests/testthat/test-record-batch-reader.R +++ b/r/tests/testthat/test-record-batch-reader.R @@ -236,3 +236,36 @@ test_that("as_record_batch_reader() works for data.frame", { reader <- as_record_batch_reader(df) expect_equal(reader$read_next_batch(), record_batch(a = 1, b = "two")) }) + +test_that("as_record_batch_reader() works for function", { + batches <- list( + record_batch(a = 1, b = "two"), + record_batch(a = 2, b = "three") + ) + + i <- 0 + fun <- function() { + i <<- i + 1 + if (i > length(batches)) NULL else batches[[i]] + } + + reader <- as_record_batch_reader(fun, schema = batches[[1]]$schema) + expect_equal(reader$read_next_batch(), batches[[1]]) + expect_equal(reader$read_next_batch(), batches[[2]]) + expect_null(reader$read_next_batch()) + + # check invalid returns + fun_bad_type <- function() "not a record batch" + reader <- as_record_batch_reader(fun_bad_type, schema = schema()) + expect_error( + reader$read_next_batch(), + "Expected fun\\(\\) to return an arrow::RecordBatch" + ) + + fun_bad_schema <- function() record_batch(a = 1) + reader <- as_record_batch_reader(fun_bad_schema, schema = schema(a = string())) + expect_error( + reader$read_next_batch(), + "Expected fun\\(\\) to return batch with schema 'a: string'" + ) +})