Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions r/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ S3method(as_record_batch,arrow_dplyr_query)
S3method(as_record_batch,data.frame)
S3method(as_record_batch,pyarrow.lib.RecordBatch)
S3method(as_record_batch,pyarrow.lib.Table)
S3method(as_record_batch_reader,"function")
S3method(as_record_batch_reader,Dataset)
S3method(as_record_batch_reader,RecordBatch)
S3method(as_record_batch_reader,RecordBatchReader)
Expand Down
4 changes: 4 additions & 0 deletions r/R/arrowExports.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

68 changes: 54 additions & 14 deletions r/R/dataset-scan.R
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,16 @@ tail_from_batches <- function(batches, n) {
#' @param FUN A function or `purrr`-style lambda expression to apply to each
#' batch. It must return a RecordBatch or something coercible to one via
#' `as_record_batch()`.
#' @param .schema An optional [schema()]. If NULL, the schema will be inferred
#' from the first batch.
#' @param .lazy Use `TRUE` to evaluate `FUN` lazily as batches are read from
#' the result; use `FALSE` to evaluate `FUN` on all batches before returning
#' the reader.
#' @param ... Additional arguments passed to `FUN`
#' @param .data.frame Deprecated argument, ignored
#' @return A [RecordBatchReader].
#' @export
map_batches <- function(X, FUN, ..., .data.frame = NULL) {
map_batches <- function(X, FUN, ..., .schema = NULL, .lazy = FALSE, .data.frame = NULL) {
if (!is.null(.data.frame)) {
warning(
"The .data.frame argument is deprecated. ",
Expand All @@ -197,25 +202,60 @@ map_batches <- function(X, FUN, ..., .data.frame = NULL) {
}
FUN <- as_mapper(FUN)
reader <- as_record_batch_reader(X)
dots <- rlang::list2(...)

# TODO: for future consideration
# * Move eval to C++ and make it a generator so it can stream, not block
# * Accept an output schema argument: with that, we could make this lazy (via collapse)
batch <- reader$read_next_batch()
res <- vector("list", 1024)
i <- 0L
while (!is.null(batch)) {
i <- i + 1L
res[[i]] <- as_record_batch(FUN(batch, ...))
# If no schema is supplied, we have to evaluate the first batch here
if (is.null(.schema)) {
batch <- reader$read_next_batch()
if (is.null(batch)) {
abort("Can't infer schema from a RecordBatchReader with zero batches")
}

first_result <- as_record_batch(do.call(FUN, c(list(batch), dots)))
.schema <- first_result$schema
fun <- function() {
if (!is.null(first_result)) {
result <- first_result
first_result <<- NULL
result
} else {
batch <- reader$read_next_batch()
if (is.null(batch)) {
NULL
} else {
as_record_batch(
do.call(FUN, c(list(batch), dots)),
schema = .schema
)
}
}
}
} else {
fun <- function() {
batch <- reader$read_next_batch()
if (is.null(batch)) {
return(NULL)
}

as_record_batch(
do.call(FUN, c(list(batch), dots)),
schema = .schema
)
}
}

# Trim list back
if (i < length(res)) {
res <- res[seq_len(i)]
reader_out <- as_record_batch_reader(fun, schema = .schema)

# TODO(ARROW-17178) because there are some restrictions on evaluating
# reader_out in some ExecPlans, the default .lazy is FALSE for now.
if (!.lazy) {
reader_out <- RecordBatchReader$create(
batches = reader_out$batches(),
schema = .schema
)
}

RecordBatchReader$create(batches = res)
reader_out
}

#' @usage NULL
Expand Down
9 changes: 9 additions & 0 deletions r/R/record-batch-reader.R
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ RecordBatchFileReader$create <- function(file) {
#' Convert an object to an Arrow RecordBatchReader
#'
#' @param x An object to convert to a [RecordBatchReader]
#' @param schema The [schema()] that must match the schema returned by each
#' call to `x` when `x` is a function.
#' @param ... Passed to S3 methods
#'
#' @return A [RecordBatchReader]
Expand Down Expand Up @@ -234,6 +236,13 @@ as_record_batch_reader.Dataset <- function(x, ...) {
Scanner$create(x)$ToRecordBatchReader()
}

#' @rdname as_record_batch_reader
#' @export
as_record_batch_reader.function <- function(x, ..., schema) {
# Build a RecordBatchReader whose batches come from calling `x`.
#
# `x` is invoked with no arguments each time a batch is requested. A NULL
# result is handed on as a null batch (presumably signalling the end of the
# stream); any other result must be a RecordBatch whose schema equals
# `schema`, otherwise reading fails.
#
# `schema` has no default, so it must be supplied; reject anything that is
# not a Schema object before crossing into C++.
assert_that(inherits(schema, "Schema"))
# `...` is accepted for S3 generic compatibility and is not forwarded.
RecordBatchReader__from_function(x, schema)
}

#' @rdname as_record_batch_reader
#' @export
as_record_batch_reader.arrow_dplyr_query <- function(x, ...) {
Expand Down
6 changes: 6 additions & 0 deletions r/man/as_record_batch_reader.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 8 additions & 1 deletion r/man/map_batches.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions r/src/arrowExports.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 45 additions & 0 deletions r/src/recordbatchreader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

#include "./arrow_types.h"
#include "./safe-call-into-r.h"

#include <arrow/ipc/reader.h>
#include <arrow/table.h>
Expand Down Expand Up @@ -54,6 +55,50 @@ std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_batches(
}
}

// RecordBatchReader whose batches are produced by calling an R function.
//
// Each ReadNext() invokes the wrapped R function with no arguments through
// SafeCallIntoR(). An R NULL result is handed to the caller as a null batch;
// any other result must inherit from "RecordBatch" and carry a schema equal to
// the one supplied at construction.
class RFunctionRecordBatchReader : public arrow::RecordBatchReader {
 public:
  RFunctionRecordBatchReader(cpp11::sexp fun,
                             const std::shared_ptr<arrow::Schema>& schema)
      : fun_(fun), schema_(schema) {}

  // The schema declared at construction; it is not derived from the batches.
  std::shared_ptr<arrow::Schema> schema() const { return schema_; }

  // Calls the R function once and stores its result in *batch_out.
  // Returns an error status if the call into R fails or if the returned batch
  // does not have the declared schema.
  arrow::Status ReadNext(std::shared_ptr<arrow::RecordBatch>* batch_out) {
    using BatchPtr = std::shared_ptr<arrow::RecordBatch>;

    auto maybe_batch = SafeCallIntoR<BatchPtr>([&]() {
      cpp11::sexp value = fun_();

      // NULL from R is forwarded as a null batch pointer.
      if (value == R_NilValue) {
        return BatchPtr(nullptr);
      }

      // Anything else has to be an R-level RecordBatch before conversion.
      if (!Rf_inherits(value, "RecordBatch")) {
        cpp11::stop("Expected fun() to return an arrow::RecordBatch");
      }

      return cpp11::as_cpp<BatchPtr>(value);
    });

    RETURN_NOT_OK(maybe_batch);

    // Safe to unwrap: the status was checked just above.
    const BatchPtr& next = maybe_batch.ValueUnsafe();

    // Only non-null batches are validated against the declared schema.
    if (next != nullptr && !next->schema()->Equals(schema_)) {
      return arrow::Status::Invalid("Expected fun() to return batch with schema '",
                                    schema_->ToString(), "' but got batch with schema '",
                                    next->schema()->ToString(), "'");
    }

    *batch_out = next;
    return arrow::Status::OK();
  }

 private:
  cpp11::function fun_;
  std::shared_ptr<arrow::Schema> schema_;
};

// [[arrow::export]]
std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_function(
    cpp11::sexp fun_sexp, const std::shared_ptr<arrow::Schema>& schema) {
  // Wrap the R function and its declared schema in an
  // RFunctionRecordBatchReader; the result is returned through the base
  // RecordBatchReader pointer type.
  std::shared_ptr<arrow::RecordBatchReader> reader =
      std::make_shared<RFunctionRecordBatchReader>(fun_sexp, schema);
  return reader;
}

// [[arrow::export]]
std::shared_ptr<arrow::RecordBatchReader> RecordBatchReader__from_Table(
const std::shared_ptr<arrow::Table>& table) {
Expand Down
2 changes: 1 addition & 1 deletion r/src/safe-call-into-r-impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ bool CanRunWithCapturedR() {
on_old_windows = on_old_windows_fun();
}

return !on_old_windows;
return !on_old_windows && GetMainRThread().Executor() == nullptr;
#else
return false;
#endif
Expand Down
4 changes: 3 additions & 1 deletion r/src/safe-call-into-r.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
// and crash R in older versions (ARROW-16201). Crashes also occur
// on 32-bit R builds on R 3.6 and lower. Implementation provided
// in safe-call-into-r-impl.cpp so that we can skip some tests
// when this feature is not provided.
// when this feature is not provided. This also checks that there
// is not already an event loop registered (via MainRThread::Executor()),
// because only one of these can exist at any given time.
bool CanRunWithCapturedR();

// The MainRThread class keeps track of the thread on which it is safe
Expand Down
4 changes: 4 additions & 0 deletions r/tests/testthat/test-dataset-dplyr.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ test_that("filter() with %in%", {
})

test_that("filter() on timestamp columns", {
skip_if_not_available("re2")

ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8()))
expect_equal(
ds %>%
Expand Down Expand Up @@ -116,6 +118,8 @@ test_that("filter() on date32 columns", {
1L
)

skip_if_not_available("re2")

# Also with timestamp scalar
expect_equal(
open_dataset(tmp) %>%
Expand Down
86 changes: 86 additions & 0 deletions r/tests/testthat/test-dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,92 @@ test_that("map_batches", {
)
})

test_that("map_batches with explicit schema", {
  # Mapper that receives extra arguments through `...`: it emits a constant
  # column named by `first_col` plus column `a` cast to float64 as `b`.
  fun_with_dots <- function(batch, first_col, first_col_val) {
    record_batch(
      !!first_col := first_col_val,
      b = batch$a$cast(float64())
    )
  }

  # With a schema supplied, a reader holding zero batches maps to an empty
  # table that has the requested columns.
  empty_reader <- RecordBatchReader$create(
    batches = list(),
    schema = schema(a = int32())
  )
  mapped_empty <- map_batches(
    empty_reader,
    fun_with_dots,
    "first_col_name",
    "first_col_value",
    .schema = schema(first_col_name = string(), b = float64())
  )
  expect_equal(
    mapped_empty$read_table(),
    arrow_table(first_col_name = character(), b = double())
  )

  # Two single-row batches: every batch goes through the mapper.
  input_batches <- list(
    record_batch(a = 1, b = "two"),
    record_batch(a = 2, b = "three")
  )
  reader <- RecordBatchReader$create(batches = input_batches)
  mapped <- map_batches(
    reader,
    fun_with_dots,
    "first_col_name",
    "first_col_value",
    .schema = schema(first_col_name = string(), b = float64())
  )
  expected <- arrow_table(
    first_col_name = rep("first_col_value", 2),
    b = c(1, 2)
  )
  expect_equal(mapped$read_table(), expected)
})

test_that("map_batches without explicit schema", {
  # Mapper that receives extra arguments through `...`: it emits a constant
  # column named by `first_col` plus column `a` cast to float64 as `b`.
  fun_with_dots <- function(batch, first_col, first_col_val) {
    record_batch(
      !!first_col := first_col_val,
      b = batch$a$cast(float64())
    )
  }

  # Without a schema there is no first batch to infer one from, so a reader
  # with zero batches must fail. The map_batches() call stays inside
  # expect_error() because the error may be raised before read_table().
  empty_reader <- RecordBatchReader$create(
    batches = list(),
    schema = schema(a = int32())
  )
  expect_error(
    map_batches(
      empty_reader,
      fun_with_dots,
      "first_col_name",
      "first_col_value"
    )$read_table(),
    "Can't infer schema"
  )

  # With at least one batch the output schema is inferred and every batch
  # goes through the mapper.
  input_batches <- list(
    record_batch(a = 1, b = "two"),
    record_batch(a = 2, b = "three")
  )
  reader <- RecordBatchReader$create(batches = input_batches)
  mapped <- map_batches(
    reader,
    fun_with_dots,
    "first_col_name",
    "first_col_value"
  )
  expected <- arrow_table(
    first_col_name = rep("first_col_value", 2),
    b = c(1, 2)
  )
  expect_equal(mapped$read_table(), expected)
})

test_that("head/tail", {
# head/tail with no query are still deterministic order
ds <- open_dataset(dataset_dir)
Expand Down
Loading