From 2eb3c9b6578d1162b70a1a4aa04ae6db92ff5e2c Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 17 Jun 2022 10:23:37 -0400 Subject: [PATCH 01/92] add test file for general compute --- r/tests/testthat/test-compute.R | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 r/tests/testthat/test-compute.R diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R new file mode 100644 index 00000000000..7e377f2d315 --- /dev/null +++ b/r/tests/testthat/test-compute.R @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +test_that("list_compute_functions() works", { + expect_type(list_compute_functions(), "character") + expect_true(all(!grepl("^hash_", list_compute_functions()))) +}) From 80293f7d075c8a95bc2351c61c48082d68772377 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 17 Jun 2022 11:52:36 -0400 Subject: [PATCH 02/92] scalar function creator --- r/R/compute.R | 40 ++++++++++++++++++++++++++++++ r/tests/testthat/_snaps/compute.md | 8 ++++++ r/tests/testthat/test-compute.R | 37 +++++++++++++++++++++++++++ 3 files changed, 85 insertions(+) create mode 100644 r/tests/testthat/_snaps/compute.md diff --git a/r/R/compute.R b/r/R/compute.R index 1cd12f2e29d..ea29e506abb 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -306,3 +306,43 @@ cast_options <- function(safe = TRUE, ...) { ) modifyList(opts, list(...)) } + +arrow_scalar_function <- function(in_type, out_type, fun) { + if (is.list(in_type)) { + in_type <- lapply(in_type, as_in_types) + } else { + in_type <- list(as_in_types(in_type)) + } + + if (is.list(out_type)) { + out_type <- lapply(out_type, as_data_type) + } else if (is.function(out_type)) { + out_type <- lapply(in_type, out_type) + } else { + out_type <- list(as_data_type(out_type)) + } + + out_type <- rep_len(out_type, length(in_type)) + + fun <- rlang::as_function(fun) + if (length(formals(fun)) != 2) { + abort("`fun` must accept exactly two arguments (`kernel_context`, `batch`)") + } + + structure( + fun, + in_type = in_type, + out_type = out_type, + class = "arrow_scalar_function" + ) +} + +as_in_types <- function(x) { + if (inherits(x, "Field")) { + schema(x) + } else if (inherits(x, "DataType")) { + schema(".x" = x) + } else { + as_schema(x) + } +} diff --git a/r/tests/testthat/_snaps/compute.md b/r/tests/testthat/_snaps/compute.md new file mode 100644 index 00000000000..ca74a5b2765 --- /dev/null +++ b/r/tests/testthat/_snaps/compute.md @@ -0,0 +1,8 @@ +# arrow_scalar_function() works + + `fun` must accept exactly two arguments (`kernel_context`, 
`batch`) + +--- + + Can't convert `fun`, NULL, to a function. + diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 7e377f2d315..3381b50e4ca 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -19,3 +19,40 @@ test_that("list_compute_functions() works", { expect_type(list_compute_functions(), "character") expect_true(all(!grepl("^hash_", list_compute_functions()))) }) + + +test_that("arrow_scalar_function() works", { + # check in/out type as schema/data type + fun <- arrow_scalar_function(schema(.y = int32()), int64(), function(x, y) y[[1]]) + expect_equal(attr(fun, "in_type")[[1]], schema(.y = int32())) + expect_equal(attr(fun, "out_type")[[1]], int64()) + + # check in/out type as data type/data type + fun <- arrow_scalar_function(int32(), int64(), function(x, y) y[[1]]) + expect_equal(attr(fun, "in_type")[[1]], schema(.x = int32())) + expect_equal(attr(fun, "out_type")[[1]], int64()) + + # check in/out type as field/data type + fun <- arrow_scalar_function( + field("a_name", int32()), + int64(), + function(x, y) y[[1]] + ) + expect_equal(attr(fun, "in_type")[[1]], schema(a_name = int32())) + expect_equal(attr(fun, "out_type")[[1]], int64()) + + # check in/out type as lists + fun <- arrow_scalar_function( + list(int32(), int64()), + list(int64(), int32()), + function(x, y) y[[1]] + ) + + expect_equal(attr(fun, "in_type")[[1]], schema(.x = int32())) + expect_equal(attr(fun, "in_type")[[2]], schema(.x = int64())) + expect_equal(attr(fun, "out_type")[[1]], int64()) + expect_equal(attr(fun, "out_type")[[1]], int64()) + + expect_snapshot_error(arrow_scalar_function(int32(), int32(), identity)) + expect_snapshot_error(arrow_scalar_function(int32(), int32(), NULL)) +}) From 29d02d80d9c564bc8066e62089e1f5fa2be06ddd Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 17 Jun 2022 12:07:58 -0400 Subject: [PATCH 03/92] implement registration in R --- r/R/compute.R | 26 ++++++++++++++++++++++++++ 
r/tests/testthat/test-compute.R | 8 ++++++++ 2 files changed, 34 insertions(+) diff --git a/r/R/compute.R b/r/R/compute.R index ea29e506abb..82a2a3f95c4 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -307,6 +307,32 @@ cast_options <- function(safe = TRUE, ...) { modifyList(opts, list(...)) } +register_scalar_function <- function(name, scalar_function) { + assert_that( + is.string(name), + inherits(scalar_function, "arrow_scalar_function") + ) + + # use something obfuscated for the arrow name to avoid + # collisions with functions registered in the same .so (e.g., Python) + compute_registry_name <- paste0( + "r_scalar_", name, "_", + rlang::hash(scalar_function) + ) + + # register with Arrow C++ + # (not yet) + + # register with dplyr bindings + register_binding( + name, + function(...) Expression$create(compute_registry_name, ...) + ) + + # invalidate function cache + create_binding_cache() +} + arrow_scalar_function <- function(in_type, out_type, fun) { if (is.list(in_type)) { in_type <- lapply(in_type, as_in_types) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 3381b50e4ca..ae0ccb36782 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -56,3 +56,11 @@ test_that("arrow_scalar_function() works", { expect_snapshot_error(arrow_scalar_function(int32(), int32(), identity)) expect_snapshot_error(arrow_scalar_function(int32(), int32(), NULL)) }) + +test_that("register_scalar_function() creates a dplyr binding", { + fun <- arrow_scalar_function(int32(), int64(), function(x, y) y[[1]]) + register_scalar_function("my_test_scalar_function", fun) + expect_true("my_test_scalar_function" %in% names(arrow:::.cache$functions)) + + +}) From 190d059750ded2e6701257884e8677d48e27e455 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 17 Jun 2022 13:27:39 -0400 Subject: [PATCH 04/92] sketch C++ UDF behaviour --- r/R/arrowExports.R | 4 ++++ r/R/compute.R | 2 +- r/src/arrowExports.cpp | 11 +++++++++ 
r/src/compute.cpp | 53 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 1 deletion(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 84f6ee54fc7..963ffeb9673 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -480,6 +480,10 @@ compute__GetFunctionNames <- function() { .Call(`_arrow_compute__GetFunctionNames`) } +RegisterScalarUDF <- function(name, fun) { + invisible(.Call(`_arrow_RegisterScalarUDF`, name, fun)) +} + build_info <- function() { .Call(`_arrow_build_info`) } diff --git a/r/R/compute.R b/r/R/compute.R index 82a2a3f95c4..4b28a9b8bdf 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -321,7 +321,7 @@ register_scalar_function <- function(name, scalar_function) { ) # register with Arrow C++ - # (not yet) + RegisterScalarUDF(compute_registry_name, scalar_function) # register with dplyr bindings register_binding( diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index e89718144ab..d94ad1aef45 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -1099,6 +1099,16 @@ BEGIN_CPP11 return cpp11::as_sexp(compute__GetFunctionNames()); END_CPP11 } +// compute.cpp +void RegisterScalarUDF(std::string name, cpp11::sexp fun); +extern "C" SEXP _arrow_RegisterScalarUDF(SEXP name_sexp, SEXP fun_sexp){ +BEGIN_CPP11 + arrow::r::Input::type name(name_sexp); + arrow::r::Input::type fun(fun_sexp); + RegisterScalarUDF(name, fun); + return R_NilValue; +END_CPP11 +} // config.cpp std::vector build_info(); extern "C" SEXP _arrow_build_info(){ @@ -5258,6 +5268,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_Table__cast", (DL_FUNC) &_arrow_Table__cast, 3}, { "_arrow_compute__CallFunction", (DL_FUNC) &_arrow_compute__CallFunction, 3}, { "_arrow_compute__GetFunctionNames", (DL_FUNC) &_arrow_compute__GetFunctionNames, 0}, + { "_arrow_RegisterScalarUDF", (DL_FUNC) &_arrow_RegisterScalarUDF, 2}, { "_arrow_build_info", (DL_FUNC) &_arrow_build_info, 0}, { "_arrow_runtime_info", (DL_FUNC) 
&_arrow_runtime_info, 0}, { "_arrow_set_timezone_database", (DL_FUNC) &_arrow_set_timezone_database, 1}, diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 0db558972e8..cd8efdafc55 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -574,3 +574,56 @@ SEXP compute__CallFunction(std::string func_name, cpp11::list args, cpp11::list std::vector compute__GetFunctionNames() { return arrow::compute::GetFunctionRegistry()->GetFunctionNames(); } + +class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { + public: + RScalarUDFCallable(const std::shared_ptr& input_types, + const std::shared_ptr& output_type, cpp11::sexp fun) + : input_types_(input_types), output_type_(output_type), fun_(fun) {} + + arrow::Status operator()(arrow::compute::KernelContext* context, + const arrow::compute::ExecSpan& span, + arrow::compute::ExecResult* result) { + return arrow::Status::NotImplemented("did we get this far?"); + } + + private: + std::shared_ptr input_types_; + std::shared_ptr output_type_; + cpp11::sexp fun_; +}; + +// [[arrow::export]] +void RegisterScalarUDF(std::string name, cpp11::sexp fun) { + const arrow::compute::FunctionDoc dummy_function_doc{ + "A user-defined R function", "returns something", {"..."}}; + + auto func = std::make_shared( + name, arrow::compute::Arity::VarArgs(), dummy_function_doc); + + cpp11::list in_type_r(fun.attr("in_type")); + cpp11::list out_type_r(fun.attr("out_type")); + R_xlen_t n_kernels = in_type_r.size(); + + for (R_xlen_t i = 0; i < n_kernels; i++) { + auto in_types = cpp11::as_cpp>(in_type_r[i]); + auto out_type = cpp11::as_cpp>(out_type_r[i]); + + std::vector compute_in_types; + for (int64_t i = 0; i < in_types->num_fields(); i++) { + compute_in_types.push_back(arrow::compute::InputType(in_types->field(i)->type())); + } + + auto signature = std::make_shared(compute_in_types, + out_type, true); + arrow::compute::ScalarKernel kernel(signature, + RScalarUDFCallable(in_types, out_type, fun)); + kernel.mem_allocation = 
arrow::compute::MemAllocation::NO_PREALLOCATE; + kernel.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; + + StopIfNotOk(func->AddKernel(std::move(kernel))); + } + + auto registry = arrow::compute::GetFunctionRegistry(); + StopIfNotOk(registry->AddFunction(std::move(func))); +} From 5e3f6825726d75e22eb9cbb01472f53b5e7d01cd Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 17 Jun 2022 16:04:37 -0400 Subject: [PATCH 05/92] working R execution --- r/R/compute.R | 12 +++------- r/src/compute.cpp | 40 +++++++++++++++++++++++++++++++-- r/tests/testthat/test-compute.R | 14 +++++++++++- 3 files changed, 54 insertions(+), 12 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 4b28a9b8bdf..139378adf68 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -307,21 +307,15 @@ cast_options <- function(safe = TRUE, ...) { modifyList(opts, list(...)) } -register_scalar_function <- function(name, scalar_function) { +register_scalar_function <- function(name, scalar_function, registry_name = name) { assert_that( is.string(name), + is.string(registry_name), inherits(scalar_function, "arrow_scalar_function") ) - # use something obfuscated for the arrow name to avoid - # collisions with functions registered in the same .so (e.g., Python) - compute_registry_name <- paste0( - "r_scalar_", name, "_", - rlang::hash(scalar_function) - ) - # register with Arrow C++ - RegisterScalarUDF(compute_registry_name, scalar_function) + RegisterScalarUDF(registry_name, scalar_function) # register with dplyr bindings register_binding( diff --git a/r/src/compute.cpp b/r/src/compute.cpp index cd8efdafc55..2f35f8c13eb 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -16,7 +16,9 @@ // under the License. 
#include "./arrow_types.h" +#include "./safe-call-into-r.h" +#include #include #include #include @@ -584,13 +586,47 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { arrow::Status operator()(arrow::compute::KernelContext* context, const arrow::compute::ExecSpan& span, arrow::compute::ExecResult* result) { - return arrow::Status::NotImplemented("did we get this far?"); + std::vector> array_args; + for (int64_t i = 0; i < span.num_values(); i++) { + const arrow::compute::ExecValue& v = span[i]; + if (v.is_array()) { + array_args.push_back(v.array.ToArray()); + } else if (v.is_scalar()) { + auto array = ValueOrStop(arrow::MakeArrayFromScalar(*v.scalar, span.length)); + array_args.push_back(array); + } + } + + auto batch = arrow::RecordBatch::Make(input_types_, span.length, array_args); + + auto fun_result = SafeCallIntoR>([&]() { + cpp11::sexp batch_sexp = cpp11::to_r6(batch); + cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length); + + cpp11::writable::list udf_context = {batch_length_sexp}; + udf_context.names() = {"batch_length"}; + + cpp11::sexp fun_result_sexp = fun_(udf_context, batch_sexp); + if (!Rf_inherits(fun_result_sexp, "Array")) { + cpp11::stop("arrow_scalar_function must return an Array"); + } + + return cpp11::as_cpp>(fun_result_sexp); + }); + + if (!fun_result.ok()) { + return fun_result.status(); + } + + result->value.emplace>( + fun_result.ValueUnsafe()->data()); + return arrow::Status::OK(); } private: std::shared_ptr input_types_; std::shared_ptr output_type_; - cpp11::sexp fun_; + cpp11::function fun_; }; // [[arrow::export]] diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index ae0ccb36782..6e398f52283 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -58,9 +58,21 @@ test_that("arrow_scalar_function() works", { }) test_that("register_scalar_function() creates a dplyr binding", { - fun <- arrow_scalar_function(int32(), int64(), function(x, y) y[[1]]) + 
fun <- arrow_scalar_function( + int32(), + int64(), + function(x, y) { + y[[1]]$cast(int64()) + } + ) + register_scalar_function("my_test_scalar_function", fun) expect_true("my_test_scalar_function" %in% names(arrow:::.cache$functions)) + expect_true("my_test_scalar_function" %in% list_compute_functions()) + # segfaults but after the R execution step + # call_function("my_test_scalar_function", Scalar$create(1L)) + # fails because there's no event loop registered + # record_batch(a = 1L) |> dplyr::mutate(b = my_test_scalar_function(a)) |> dplyr::collect() }) From e1294714eeacb1461e81353d4ca1bf0536f5a108 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 20 Jun 2022 12:03:15 -0300 Subject: [PATCH 06/92] Update r/src/compute.cpp Co-authored-by: Vibhatha Lakmal Abeykoon --- r/src/compute.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 2f35f8c13eb..a894f479923 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -618,8 +618,7 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { return fun_result.status(); } - result->value.emplace>( - fun_result.ValueUnsafe()->data()); + result->value = std::move(fun_result->data()); return arrow::Status::OK(); } From 83ad7ad356fee4f44a293548fdbaf8f594569906 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 20 Jun 2022 14:23:49 -0300 Subject: [PATCH 07/92] Update r/src/compute.cpp Co-authored-by: Vibhatha Lakmal Abeykoon --- r/src/compute.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index a894f479923..8dad4359817 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -618,7 +618,7 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { return fun_result.status(); } - result->value = std::move(fun_result->data()); + result->value = std::move(ValueOrStop(fun_result)->data()); return arrow::Status::OK(); } From 94c0b2f3ecff21805df89f92f2f3b4566ab63170 
Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 20 Jun 2022 14:42:10 -0300 Subject: [PATCH 08/92] check Array argument --- r/src/compute.cpp | 4 ++-- r/tests/testthat/test-compute.R | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 8dad4359817..be9e3c98000 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -618,7 +618,7 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { return fun_result.status(); } - result->value = std::move(ValueOrStop(fun_result)->data()); + result->value = std::move(ValueOrStop(fun_result)->data()); return arrow::Status::OK(); } @@ -660,5 +660,5 @@ void RegisterScalarUDF(std::string name, cpp11::sexp fun) { } auto registry = arrow::compute::GetFunctionRegistry(); - StopIfNotOk(registry->AddFunction(std::move(func))); + StopIfNotOk(registry->AddFunction(std::move(func), true)); } diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 6e398f52283..674beb92a07 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -70,8 +70,10 @@ test_that("register_scalar_function() creates a dplyr binding", { expect_true("my_test_scalar_function" %in% names(arrow:::.cache$functions)) expect_true("my_test_scalar_function" %in% list_compute_functions()) - # segfaults but after the R execution step - # call_function("my_test_scalar_function", Scalar$create(1L)) + expect_equal( + call_function("my_test_scalar_function", Array$create(1L, int32())), + Array$create(1L, int64()) + ) # fails because there's no event loop registered # record_batch(a = 1L) |> dplyr::mutate(b = my_test_scalar_function(a)) |> dplyr::collect() From 9e1a3622faa09f8cb33a860d3543f5ababf3bf20 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 20 Jun 2022 16:15:07 -0300 Subject: [PATCH 09/92] don't force arguments to Array --- r/src/compute.cpp | 17 +++++++++++++++-- r/tests/testthat/test-compute.R | 8 +++++++- 2 files 
changed, 22 insertions(+), 3 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index be9e3c98000..48b0f0f1c69 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -600,13 +600,26 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { auto batch = arrow::RecordBatch::Make(input_types_, span.length, array_args); auto fun_result = SafeCallIntoR>([&]() { - cpp11::sexp batch_sexp = cpp11::to_r6(batch); + cpp11::writable::list args_sexp; + args_sexp.reserve(span.num_values()); + + for (int64_t i = 0; i < span.num_values(); i++) { + const arrow::compute::ExecValue& v = span[i]; + if (v.is_array()) { + std::shared_ptr array = v.array.ToArray(); + args_sexp.push_back(cpp11::to_r6(array)); + } else if (v.is_scalar()) { + std::shared_ptr scalar = v.scalar->shared_from_this(); + args_sexp.push_back(cpp11::to_r6(scalar)); + } + } + cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length); cpp11::writable::list udf_context = {batch_length_sexp}; udf_context.names() = {"batch_length"}; - cpp11::sexp fun_result_sexp = fun_(udf_context, batch_sexp); + cpp11::sexp fun_result_sexp = fun_(udf_context, args_sexp); if (!Rf_inherits(fun_result_sexp, "Array")) { cpp11::stop("arrow_scalar_function must return an Array"); } diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 674beb92a07..ef51abffa82 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -62,7 +62,7 @@ test_that("register_scalar_function() creates a dplyr binding", { int32(), int64(), function(x, y) { - y[[1]]$cast(int64()) + as_arrow_array(y[[1]])$cast(int64()) } ) @@ -75,6 +75,12 @@ test_that("register_scalar_function() creates a dplyr binding", { Array$create(1L, int64()) ) + # segfaults + # expect_equal( + # call_function("my_test_scalar_function", Scalar$create(1L, int32())), + # Array$create(1L, int64()) + # ) + # fails because there's no event loop registered # record_batch(a = 1L) |> dplyr::mutate(b = 
my_test_scalar_function(a)) |> dplyr::collect() }) From ddc0d46dc2193300568e3178a50d2d1bf7fa1ebf Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 20 Jun 2022 16:16:55 -0300 Subject: [PATCH 10/92] remove unused code --- r/src/compute.cpp | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 48b0f0f1c69..f7eabfb7a3c 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -586,19 +586,6 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { arrow::Status operator()(arrow::compute::KernelContext* context, const arrow::compute::ExecSpan& span, arrow::compute::ExecResult* result) { - std::vector> array_args; - for (int64_t i = 0; i < span.num_values(); i++) { - const arrow::compute::ExecValue& v = span[i]; - if (v.is_array()) { - array_args.push_back(v.array.ToArray()); - } else if (v.is_scalar()) { - auto array = ValueOrStop(arrow::MakeArrayFromScalar(*v.scalar, span.length)); - array_args.push_back(array); - } - } - - auto batch = arrow::RecordBatch::Make(input_types_, span.length, array_args); - auto fun_result = SafeCallIntoR>([&]() { cpp11::writable::list args_sexp; args_sexp.reserve(span.num_values()); From fadf258cc7c3f6c986f5de407a5febba3672a154 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 20 Jun 2022 21:43:22 -0300 Subject: [PATCH 11/92] Update r/src/compute.cpp Co-authored-by: Vibhatha Lakmal Abeykoon --- r/src/compute.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index f7eabfb7a3c..271031aa76f 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -596,8 +596,8 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { std::shared_ptr array = v.array.ToArray(); args_sexp.push_back(cpp11::to_r6(array)); } else if (v.is_scalar()) { - std::shared_ptr scalar = v.scalar->shared_from_this(); - args_sexp.push_back(cpp11::to_r6(scalar)); + std::shared_ptr scalar = v.scalar->Copy(); + 
args_sexp.push_back(cpp11::to_r6(scalar)); } } From 2eb48ae8c6cbb19757a7d16a77168f3fb9bb5425 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 20 Jun 2022 21:48:59 -0300 Subject: [PATCH 12/92] better names for variables --- r/src/compute.cpp | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 271031aa76f..d12900ace36 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -580,23 +580,24 @@ std::vector compute__GetFunctionNames() { class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { public: RScalarUDFCallable(const std::shared_ptr& input_types, - const std::shared_ptr& output_type, cpp11::sexp fun) - : input_types_(input_types), output_type_(output_type), fun_(fun) {} + const std::shared_ptr& output_type, + cpp11::sexp func) + : input_types_(input_types), output_type_(output_type), func_(func) {} arrow::Status operator()(arrow::compute::KernelContext* context, const arrow::compute::ExecSpan& span, arrow::compute::ExecResult* result) { - auto fun_result = SafeCallIntoR>([&]() { + auto func_result = SafeCallIntoR>([&]() { cpp11::writable::list args_sexp; args_sexp.reserve(span.num_values()); for (int64_t i = 0; i < span.num_values(); i++) { - const arrow::compute::ExecValue& v = span[i]; - if (v.is_array()) { - std::shared_ptr array = v.array.ToArray(); + const arrow::compute::ExecValue& exec_val = span[i]; + if (exec_val.is_array()) { + std::shared_ptr array = exec_val.array.ToArray(); args_sexp.push_back(cpp11::to_r6(array)); - } else if (v.is_scalar()) { - std::shared_ptr scalar = v.scalar->Copy(); + } else if (exec_val.is_scalar()) { + std::shared_ptr scalar = exec_val.scalar->Copy(); args_sexp.push_back(cpp11::to_r6(scalar)); } } @@ -606,7 +607,7 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { cpp11::writable::list udf_context = {batch_length_sexp}; udf_context.names() = {"batch_length"}; - cpp11::sexp fun_result_sexp = 
fun_(udf_context, args_sexp); + cpp11::sexp fun_result_sexp = func_(udf_context, args_sexp); if (!Rf_inherits(fun_result_sexp, "Array")) { cpp11::stop("arrow_scalar_function must return an Array"); } @@ -614,18 +615,18 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { return cpp11::as_cpp>(fun_result_sexp); }); - if (!fun_result.ok()) { - return fun_result.status(); + if (!func_result.ok()) { + return func_result.status(); } - result->value = std::move(ValueOrStop(fun_result)->data()); + result->value = std::move(ValueOrStop(func_result)->data()); return arrow::Status::OK(); } private: std::shared_ptr input_types_; std::shared_ptr output_type_; - cpp11::function fun_; + cpp11::function func_; }; // [[arrow::export]] From c29fc006a5e0732d75642b99bcc541bcfd38b40b Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 20 Jun 2022 22:50:56 -0300 Subject: [PATCH 13/92] handle more cases on execution --- r/src/compute.cpp | 52 +++++++++++++++++++++++++-------- r/tests/testthat/test-compute.R | 16 +++++----- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index d12900ace36..dc9e49a1eb9 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -587,7 +587,7 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { arrow::Status operator()(arrow::compute::KernelContext* context, const arrow::compute::ExecSpan& span, arrow::compute::ExecResult* result) { - auto func_result = SafeCallIntoR>([&]() { + return SafeCallIntoRVoid([&]() { cpp11::writable::list args_sexp; args_sexp.reserve(span.num_values()); @@ -607,20 +607,48 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { cpp11::writable::list udf_context = {batch_length_sexp}; udf_context.names() = {"batch_length"}; - cpp11::sexp fun_result_sexp = func_(udf_context, args_sexp); - if (!Rf_inherits(fun_result_sexp, "Array")) { - cpp11::stop("arrow_scalar_function must return an Array"); - } + cpp11::sexp 
func_result_sexp = func_(udf_context, args_sexp); - return cpp11::as_cpp>(fun_result_sexp); - }); + if (Rf_inherits(func_result_sexp, "Array")) { + auto array = cpp11::as_cpp>(func_result_sexp); - if (!func_result.ok()) { - return func_result.status(); - } + // handle an Array result of the wrong type + if (!array->type()->Equals(output_type_)) { + arrow::Datum out = ValueOrStop(arrow::compute::Cast(array, output_type_)); + std::shared_ptr out_array = out.make_array(); + array.swap(out_array); + } + + // make sure we assign the type that the result is expecting + if (result->is_array_data()) { + result->value = std::move(array->data()); + } else if (array->length() == 1) { + result->value = ValueOrStop(array->GetScalar(0)); + } else { + cpp11::stop("expected Scalar return value but got Array with length != 1"); + } + } else if (Rf_inherits(func_result_sexp, "Scalar")) { + auto scalar = cpp11::as_cpp>(func_result_sexp); + + // handle a Scalar result of the wrong type + if (!scalar->type->Equals(output_type_)) { + arrow::Datum out = ValueOrStop(arrow::compute::Cast(scalar, output_type_)); + std::shared_ptr out_scalar = out.scalar(); + scalar.swap(out_scalar); + } - result->value = std::move(ValueOrStop(func_result)->data()); - return arrow::Status::OK(); + // make sure we assign the type that the result is expecting + if (result->is_scalar()) { + result->value = std::move(scalar); + } else { + auto array = ValueOrStop( + arrow::MakeArrayFromScalar(*scalar, span.length, context->memory_pool())); + result->value = std::move(array->data()); + } + } else { + cpp11::stop("arrow_scalar_function must return an Array or Scalar"); + } + }); } private: diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index ef51abffa82..0b8a3661b19 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -61,12 +61,13 @@ test_that("register_scalar_function() creates a dplyr binding", { fun <- arrow_scalar_function( int32(), int64(), - 
function(x, y) { - as_arrow_array(y[[1]])$cast(int64()) + function(context, args) { + args[[1]] } ) register_scalar_function("my_test_scalar_function", fun) + expect_true("my_test_scalar_function" %in% names(arrow:::.cache$functions)) expect_true("my_test_scalar_function" %in% list_compute_functions()) @@ -75,12 +76,11 @@ test_that("register_scalar_function() creates a dplyr binding", { Array$create(1L, int64()) ) - # segfaults - # expect_equal( - # call_function("my_test_scalar_function", Scalar$create(1L, int32())), - # Array$create(1L, int64()) - # ) + expect_equal( + call_function("my_test_scalar_function", Scalar$create(1L, int32())), + Scalar$create(1L, int64()) + ) # fails because there's no event loop registered - # record_batch(a = 1L) |> dplyr::mutate(b = my_test_scalar_function(a)) |> dplyr::collect() + # record_batch(a = 1L) |> dplyr::mutate(b = arrow_my_test_scalar_function(a)) |> dplyr::collect() }) From b1c8cbf3efc338feb109977da498d622df00288d Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 21 Jun 2022 10:27:29 -0300 Subject: [PATCH 14/92] use Resolver as an R function --- r/R/arrowExports.R | 4 +-- r/R/compute.R | 15 +++++--- r/src/arrowExports.cpp | 8 ++--- r/src/compute.cpp | 64 +++++++++++++++++++++++---------- r/tests/testthat/test-compute.R | 17 ++++----- 5 files changed, 69 insertions(+), 39 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 963ffeb9673..83f5f2ef99b 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -480,8 +480,8 @@ compute__GetFunctionNames <- function() { .Call(`_arrow_compute__GetFunctionNames`) } -RegisterScalarUDF <- function(name, fun) { - invisible(.Call(`_arrow_RegisterScalarUDF`, name, fun)) +RegisterScalarUDF <- function(name, func_sexp) { + invisible(.Call(`_arrow_RegisterScalarUDF`, name, func_sexp)) } build_info <- function() { diff --git a/r/R/compute.R b/r/R/compute.R index 139378adf68..bca98397cef 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -335,11 +335,9 @@ 
arrow_scalar_function <- function(in_type, out_type, fun) { } if (is.list(out_type)) { - out_type <- lapply(out_type, as_data_type) - } else if (is.function(out_type)) { - out_type <- lapply(in_type, out_type) + out_type <- lapply(out_type, as_out_type) } else { - out_type <- list(as_data_type(out_type)) + out_type <- list(as_out_type(out_type)) } out_type <- rep_len(out_type, length(in_type)) @@ -366,3 +364,12 @@ as_in_types <- function(x) { as_schema(x) } } + +as_out_type <- function(x) { + if (is.function(x)) { + x + } else { + x <- as_data_type(x) + function(types) x + } +} diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index d94ad1aef45..fb5a46e2909 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -1100,12 +1100,12 @@ BEGIN_CPP11 END_CPP11 } // compute.cpp -void RegisterScalarUDF(std::string name, cpp11::sexp fun); -extern "C" SEXP _arrow_RegisterScalarUDF(SEXP name_sexp, SEXP fun_sexp){ +void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp); +extern "C" SEXP _arrow_RegisterScalarUDF(SEXP name_sexp, SEXP func_sexp_sexp){ BEGIN_CPP11 arrow::r::Input::type name(name_sexp); - arrow::r::Input::type fun(fun_sexp); - RegisterScalarUDF(name, fun); + arrow::r::Input::type func_sexp(func_sexp_sexp); + RegisterScalarUDF(name, func_sexp); return R_NilValue; END_CPP11 } diff --git a/r/src/compute.cpp b/r/src/compute.cpp index dc9e49a1eb9..d8b8294d650 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -577,12 +577,37 @@ std::vector compute__GetFunctionNames() { return arrow::compute::GetFunctionRegistry()->GetFunctionNames(); } +class RScalarUDFOutputTypeResolver : public arrow::compute::OutputType::Resolver { + public: + RScalarUDFOutputTypeResolver(cpp11::sexp func) : func_(func) {} + + arrow::Result operator()( + arrow::compute::KernelContext* context, + const std::vector& descr) { + return SafeCallIntoR([&]() -> arrow::ValueDescr { + cpp11::writable::list input_types_sexp; + input_types_sexp.reserve(descr.size()); + for 
(const auto& item : descr) { + input_types_sexp.push_back(cpp11::to_r6(item.type)); + } + + cpp11::sexp output_type_sexp = func_(input_types_sexp); + if (!Rf_inherits(output_type_sexp, "DataType")) { + cpp11::stop("arrow_scalar_function resolver must return a DataType"); + } + + return arrow::ValueDescr( + cpp11::as_cpp>(output_type_sexp)); + }); + } + + private: + cpp11::function func_; +}; + class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { public: - RScalarUDFCallable(const std::shared_ptr& input_types, - const std::shared_ptr& output_type, - cpp11::sexp func) - : input_types_(input_types), output_type_(output_type), func_(func) {} + RScalarUDFCallable(cpp11::sexp func) : func_(func) {} arrow::Status operator()(arrow::compute::KernelContext* context, const arrow::compute::ExecSpan& span, @@ -613,8 +638,9 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { auto array = cpp11::as_cpp>(func_result_sexp); // handle an Array result of the wrong type - if (!array->type()->Equals(output_type_)) { - arrow::Datum out = ValueOrStop(arrow::compute::Cast(array, output_type_)); + if (!result->type()->Equals(array->type())) { + arrow::Datum out = + ValueOrStop(arrow::compute::Cast(array, result->type()->Copy())); std::shared_ptr out_array = out.make_array(); array.swap(out_array); } @@ -631,8 +657,9 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { auto scalar = cpp11::as_cpp>(func_result_sexp); // handle a Scalar result of the wrong type - if (!scalar->type->Equals(output_type_)) { - arrow::Datum out = ValueOrStop(arrow::compute::Cast(scalar, output_type_)); + if (!result->type()->Equals(scalar->type)) { + arrow::Datum out = + ValueOrStop(arrow::compute::Cast(scalar, result->type()->Copy())); std::shared_ptr out_scalar = out.scalar(); scalar.swap(out_scalar); } @@ -652,36 +679,35 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { } private: - std::shared_ptr input_types_; - std::shared_ptr 
output_type_; cpp11::function func_; }; // [[arrow::export]] -void RegisterScalarUDF(std::string name, cpp11::sexp fun) { +void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { const arrow::compute::FunctionDoc dummy_function_doc{ "A user-defined R function", "returns something", {"..."}}; auto func = std::make_shared( name, arrow::compute::Arity::VarArgs(), dummy_function_doc); - cpp11::list in_type_r(fun.attr("in_type")); - cpp11::list out_type_r(fun.attr("out_type")); + cpp11::list in_type_r(func_sexp.attr("in_type")); + cpp11::list out_type_r(func_sexp.attr("out_type")); R_xlen_t n_kernels = in_type_r.size(); for (R_xlen_t i = 0; i < n_kernels; i++) { auto in_types = cpp11::as_cpp>(in_type_r[i]); - auto out_type = cpp11::as_cpp>(out_type_r[i]); + cpp11::sexp out_type_func = out_type_r[i]; std::vector compute_in_types; - for (int64_t i = 0; i < in_types->num_fields(); i++) { + for (int64_t j = 0; j < in_types->num_fields(); j++) { compute_in_types.push_back(arrow::compute::InputType(in_types->field(i)->type())); } - auto signature = std::make_shared(compute_in_types, - out_type, true); - arrow::compute::ScalarKernel kernel(signature, - RScalarUDFCallable(in_types, out_type, fun)); + arrow::compute::OutputType out_type((RScalarUDFOutputTypeResolver(out_type_func))); + + auto signature = std::make_shared( + compute_in_types, std::move(out_type), true); + arrow::compute::ScalarKernel kernel(signature, RScalarUDFCallable(func_sexp)); kernel.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE; kernel.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 0b8a3661b19..f8fb439b1c6 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -25,12 +25,12 @@ test_that("arrow_scalar_function() works", { # check in/out type as schema/data type fun <- arrow_scalar_function(schema(.y = int32()), int64(), function(x, y) y[[1]]) 
expect_equal(attr(fun, "in_type")[[1]], schema(.y = int32())) - expect_equal(attr(fun, "out_type")[[1]], int64()) + expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as data type/data type fun <- arrow_scalar_function(int32(), int64(), function(x, y) y[[1]]) expect_equal(attr(fun, "in_type")[[1]], schema(.x = int32())) - expect_equal(attr(fun, "out_type")[[1]], int64()) + expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as field/data type fun <- arrow_scalar_function( @@ -39,7 +39,7 @@ test_that("arrow_scalar_function() works", { function(x, y) y[[1]] ) expect_equal(attr(fun, "in_type")[[1]], schema(a_name = int32())) - expect_equal(attr(fun, "out_type")[[1]], int64()) + expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as lists fun <- arrow_scalar_function( @@ -50,8 +50,8 @@ test_that("arrow_scalar_function() works", { expect_equal(attr(fun, "in_type")[[1]], schema(.x = int32())) expect_equal(attr(fun, "in_type")[[2]], schema(.x = int64())) - expect_equal(attr(fun, "out_type")[[1]], int64()) - expect_equal(attr(fun, "out_type")[[1]], int64()) + expect_equal(attr(fun, "out_type")[[1]](), int64()) + expect_equal(attr(fun, "out_type")[[2]](), int32()) expect_snapshot_error(arrow_scalar_function(int32(), int32(), identity)) expect_snapshot_error(arrow_scalar_function(int32(), int32(), NULL)) @@ -59,11 +59,8 @@ test_that("arrow_scalar_function() works", { test_that("register_scalar_function() creates a dplyr binding", { fun <- arrow_scalar_function( - int32(), - int64(), - function(context, args) { - args[[1]] - } + int32(), int64(), + function(context, args) args[[1]] ) register_scalar_function("my_test_scalar_function", fun) From cf9863557828b82c91888b6907e66d15c742caf4 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 21 Jun 2022 10:58:51 -0300 Subject: [PATCH 15/92] add a more user-friendly scalar function wrapper --- r/R/compute.R | 23 +++++++++++++++++------ r/src/compute.cpp | 11 
+++++++++-- r/tests/testthat/_snaps/compute.md | 6 +++--- r/tests/testthat/test-compute.R | 30 ++++++++++++++++++++++-------- 4 files changed, 51 insertions(+), 19 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index bca98397cef..8a7707d735b 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -311,7 +311,7 @@ register_scalar_function <- function(name, scalar_function, registry_name = name assert_that( is.string(name), is.string(registry_name), - inherits(scalar_function, "arrow_scalar_function") + inherits(scalar_function, "arrow_base_scalar_function") ) # register with Arrow C++ @@ -328,6 +328,17 @@ register_scalar_function <- function(name, scalar_function, registry_name = name } arrow_scalar_function <- function(in_type, out_type, fun) { + force(fun) + base_fun <- function(kernel_context, args) { + args <- lapply(args, as.vector) + result <- do.call(fun, args) + as_arrow_array(result) + } + + arrow_base_scalar_function(in_type, out_type, base_fun) +} + +arrow_base_scalar_function <- function(in_type, out_type, base_fun) { if (is.list(in_type)) { in_type <- lapply(in_type, as_in_types) } else { @@ -342,16 +353,16 @@ arrow_scalar_function <- function(in_type, out_type, fun) { out_type <- rep_len(out_type, length(in_type)) - fun <- rlang::as_function(fun) - if (length(formals(fun)) != 2) { - abort("`fun` must accept exactly two arguments (`kernel_context`, `batch`)") + base_fun <- rlang::as_function(base_fun) + if (length(formals(base_fun)) != 2) { + abort("`base_fun` must accept exactly two arguments") } structure( - fun, + base_fun, in_type = in_type, out_type = out_type, - class = "arrow_scalar_function" + class = "arrow_base_scalar_function" ) } diff --git a/r/src/compute.cpp b/r/src/compute.cpp index d8b8294d650..fd002d9ee92 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -607,7 +607,8 @@ class RScalarUDFOutputTypeResolver : public arrow::compute::OutputType::Resolver class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { public: 
- RScalarUDFCallable(cpp11::sexp func) : func_(func) {} + RScalarUDFCallable(cpp11::sexp func, const std::vector& input_names) + : func_(func), input_names_(input_names) {} arrow::Status operator()(arrow::compute::KernelContext* context, const arrow::compute::ExecSpan& span, @@ -627,6 +628,8 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { } } + args_sexp.names() = input_names_; + cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length); cpp11::writable::list udf_context = {batch_length_sexp}; @@ -680,6 +683,7 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { private: cpp11::function func_; + cpp11::strings input_names_; }; // [[arrow::export]] @@ -699,15 +703,18 @@ void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { cpp11::sexp out_type_func = out_type_r[i]; std::vector compute_in_types; + std::vector compute_in_names; for (int64_t j = 0; j < in_types->num_fields(); j++) { compute_in_types.push_back(arrow::compute::InputType(in_types->field(i)->type())); + compute_in_names.push_back(in_types->field(i)->name()); } arrow::compute::OutputType out_type((RScalarUDFOutputTypeResolver(out_type_func))); auto signature = std::make_shared( compute_in_types, std::move(out_type), true); - arrow::compute::ScalarKernel kernel(signature, RScalarUDFCallable(func_sexp)); + arrow::compute::ScalarKernel kernel( + signature, RScalarUDFCallable(func_sexp, std::move(compute_in_names))); kernel.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE; kernel.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; diff --git a/r/tests/testthat/_snaps/compute.md b/r/tests/testthat/_snaps/compute.md index ca74a5b2765..a7ba33e24e4 100644 --- a/r/tests/testthat/_snaps/compute.md +++ b/r/tests/testthat/_snaps/compute.md @@ -1,8 +1,8 @@ -# arrow_scalar_function() works +# arrow_base_scalar_function() works - `fun` must accept exactly two arguments (`kernel_context`, `batch`) + `base_fun` must accept exactly two 
arguments --- - Can't convert `fun`, NULL, to a function. + Can't convert `base_fun`, NULL, to a function. diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index f8fb439b1c6..3ea003c0e89 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -21,19 +21,19 @@ test_that("list_compute_functions() works", { }) -test_that("arrow_scalar_function() works", { +test_that("arrow_base_scalar_function() works", { # check in/out type as schema/data type - fun <- arrow_scalar_function(schema(.y = int32()), int64(), function(x, y) y[[1]]) + fun <- arrow_base_scalar_function(schema(.y = int32()), int64(), function(x, y) y[[1]]) expect_equal(attr(fun, "in_type")[[1]], schema(.y = int32())) expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as data type/data type - fun <- arrow_scalar_function(int32(), int64(), function(x, y) y[[1]]) + fun <- arrow_base_scalar_function(int32(), int64(), function(x, y) y[[1]]) expect_equal(attr(fun, "in_type")[[1]], schema(.x = int32())) expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as field/data type - fun <- arrow_scalar_function( + fun <- arrow_base_scalar_function( field("a_name", int32()), int64(), function(x, y) y[[1]] @@ -42,7 +42,7 @@ test_that("arrow_scalar_function() works", { expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as lists - fun <- arrow_scalar_function( + fun <- arrow_base_scalar_function( list(int32(), int64()), list(int64(), int32()), function(x, y) y[[1]] @@ -53,12 +53,26 @@ test_that("arrow_scalar_function() works", { expect_equal(attr(fun, "out_type")[[1]](), int64()) expect_equal(attr(fun, "out_type")[[2]](), int32()) - expect_snapshot_error(arrow_scalar_function(int32(), int32(), identity)) - expect_snapshot_error(arrow_scalar_function(int32(), int32(), NULL)) + expect_snapshot_error(arrow_base_scalar_function(int32(), int32(), identity)) + 
expect_snapshot_error(arrow_base_scalar_function(int32(), int32(), NULL)) +}) + +test_that("arrow_scalar_function() returns a base scalar function", { + base_fun <- arrow_scalar_function( + list(float64(), float64()), + float64(), + function(x, y) { x + y } + ) + + expect_s3_class(base_fun, "arrow_base_scalar_function") + expect_equal( + base_fun(list(), list(Scalar$create(2), Array$create(3))), + Array$create(5) + ) }) test_that("register_scalar_function() creates a dplyr binding", { - fun <- arrow_scalar_function( + fun <- arrow_base_scalar_function( int32(), int64(), function(context, args) args[[1]] ) From c171da6bf6afe0146c3cd416462598e9cf88666a Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 21 Jun 2022 11:00:36 -0300 Subject: [PATCH 16/92] better fun resolution --- r/R/compute.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/R/compute.R b/r/R/compute.R index 8a7707d735b..fb01dd28713 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -328,7 +328,7 @@ register_scalar_function <- function(name, scalar_function, registry_name = name } arrow_scalar_function <- function(in_type, out_type, fun) { - force(fun) + fun <- rlang::as_function(fun) base_fun <- function(kernel_context, args) { args <- lapply(args, as.vector) result <- do.call(fun, args) From e1290286ddae5fcb1e840f19dfd264850fa52794 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 21 Jun 2022 12:12:31 -0300 Subject: [PATCH 17/92] add kernel state class --- r/R/compute.R | 2 +- r/src/compute.cpp | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/r/R/compute.R b/r/R/compute.R index fb01dd28713..df0553de661 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -370,7 +370,7 @@ as_in_types <- function(x) { if (inherits(x, "Field")) { schema(x) } else if (inherits(x, "DataType")) { - schema(".x" = x) + schema(field("", x)) } else { as_schema(x) } diff --git a/r/src/compute.cpp b/r/src/compute.cpp index fd002d9ee92..2f46213677d 100644 --- 
a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -577,6 +577,17 @@ std::vector compute__GetFunctionNames() { return arrow::compute::GetFunctionRegistry()->GetFunctionNames(); } +class RScalarUDFKernelState : public arrow::compute::KernelState { + public: + RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver, + const std::vector& input_names) + : exec_func_(exec_func), resolver_(resolver), input_names_(input_names) {} + + cpp11::function exec_func_; + cpp11::function resolver_; + cpp11::strings input_names_; +}; + class RScalarUDFOutputTypeResolver : public arrow::compute::OutputType::Resolver { public: RScalarUDFOutputTypeResolver(cpp11::sexp func) : func_(func) {} From 4631cb99df30c73434bdcc86a3792c31c1f14d7a Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 21 Jun 2022 13:18:00 -0300 Subject: [PATCH 18/92] use ScalarKernel.data to keep function references --- r/src/compute.cpp | 34 ++++++++++++++++----------------- r/tests/testthat/test-compute.R | 6 +++--- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 2f46213677d..e07b6b69c7e 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -590,19 +590,22 @@ class RScalarUDFKernelState : public arrow::compute::KernelState { class RScalarUDFOutputTypeResolver : public arrow::compute::OutputType::Resolver { public: - RScalarUDFOutputTypeResolver(cpp11::sexp func) : func_(func) {} - arrow::Result operator()( arrow::compute::KernelContext* context, const std::vector& descr) { return SafeCallIntoR([&]() -> arrow::ValueDescr { + auto kernel = + reinterpret_cast(context->kernel()); + auto state = std::dynamic_pointer_cast(kernel->data); + cpp11::writable::list input_types_sexp; input_types_sexp.reserve(descr.size()); for (const auto& item : descr) { input_types_sexp.push_back(cpp11::to_r6(item.type)); } + input_types_sexp.names() = state->input_names_; - cpp11::sexp output_type_sexp = func_(input_types_sexp); + cpp11::sexp output_type_sexp = 
state->resolver_(input_types_sexp); if (!Rf_inherits(output_type_sexp, "DataType")) { cpp11::stop("arrow_scalar_function resolver must return a DataType"); } @@ -611,20 +614,18 @@ class RScalarUDFOutputTypeResolver : public arrow::compute::OutputType::Resolver cpp11::as_cpp>(output_type_sexp)); }); } - - private: - cpp11::function func_; }; class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { public: - RScalarUDFCallable(cpp11::sexp func, const std::vector& input_names) - : func_(func), input_names_(input_names) {} - arrow::Status operator()(arrow::compute::KernelContext* context, const arrow::compute::ExecSpan& span, arrow::compute::ExecResult* result) { return SafeCallIntoRVoid([&]() { + auto kernel = + reinterpret_cast(context->kernel()); + auto state = std::dynamic_pointer_cast(kernel->data); + cpp11::writable::list args_sexp; args_sexp.reserve(span.num_values()); @@ -639,14 +640,14 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { } } - args_sexp.names() = input_names_; + args_sexp.names() = state->input_names_; cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length); cpp11::writable::list udf_context = {batch_length_sexp}; udf_context.names() = {"batch_length"}; - cpp11::sexp func_result_sexp = func_(udf_context, args_sexp); + cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp); if (Rf_inherits(func_result_sexp, "Array")) { auto array = cpp11::as_cpp>(func_result_sexp); @@ -691,10 +692,6 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { } }); } - - private: - cpp11::function func_; - cpp11::strings input_names_; }; // [[arrow::export]] @@ -720,14 +717,15 @@ void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { compute_in_names.push_back(in_types->field(i)->name()); } - arrow::compute::OutputType out_type((RScalarUDFOutputTypeResolver(out_type_func))); + arrow::compute::OutputType out_type((RScalarUDFOutputTypeResolver())); auto signature = std::make_shared( 
compute_in_types, std::move(out_type), true); - arrow::compute::ScalarKernel kernel( - signature, RScalarUDFCallable(func_sexp, std::move(compute_in_names))); + arrow::compute::ScalarKernel kernel(signature, RScalarUDFCallable()); kernel.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE; kernel.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.data = std::make_shared(func_sexp, out_type_func, + compute_in_names); StopIfNotOk(func->AddKernel(std::move(kernel))); } diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 3ea003c0e89..5213d5ca345 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -29,7 +29,7 @@ test_that("arrow_base_scalar_function() works", { # check in/out type as data type/data type fun <- arrow_base_scalar_function(int32(), int64(), function(x, y) y[[1]]) - expect_equal(attr(fun, "in_type")[[1]], schema(.x = int32())) + expect_equal(attr(fun, "in_type")[[1]][[1]], field("", int32())) expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as field/data type @@ -48,8 +48,8 @@ test_that("arrow_base_scalar_function() works", { function(x, y) y[[1]] ) - expect_equal(attr(fun, "in_type")[[1]], schema(.x = int32())) - expect_equal(attr(fun, "in_type")[[2]], schema(.x = int64())) + expect_equal(attr(fun, "in_type")[[1]][[1]], field("", int32())) + expect_equal(attr(fun, "in_type")[[2]][[1]], field("", int64())) expect_equal(attr(fun, "out_type")[[1]](), int64()) expect_equal(attr(fun, "out_type")[[2]](), int32()) From 5b82d79e173b17792d369764e906fa7d0835b014 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 21 Jun 2022 15:58:40 -0300 Subject: [PATCH 19/92] documentation for functions --- r/NAMESPACE | 3 + r/R/compute.R | 93 ++++++++++++++++++++++++++++--- r/man/register_scalar_function.Rd | 93 +++++++++++++++++++++++++++++++ r/src/compute.cpp | 14 +---- r/tests/testthat/test-compute.R | 6 +- 5 files changed, 188 
insertions(+), 21 deletions(-) create mode 100644 r/man/register_scalar_function.Rd diff --git a/r/NAMESPACE b/r/NAMESPACE index c7d2657baed..ac9858e01b5 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -249,7 +249,9 @@ export(Type) export(UnionDataset) export(all_of) export(arrow_available) +export(arrow_base_scalar_function) export(arrow_info) +export(arrow_scalar_function) export(arrow_table) export(arrow_with_dataset) export(arrow_with_gcs) @@ -343,6 +345,7 @@ export(read_schema) export(read_tsv_arrow) export(record_batch) export(register_extension_type) +export(register_scalar_function) export(reregister_extension_type) export(s3_bucket) export(schema) diff --git a/r/R/compute.R b/r/R/compute.R index df0553de661..81f947aa4ef 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -307,6 +307,77 @@ cast_options <- function(safe = TRUE, ...) { modifyList(opts, list(...)) } +#' Register user-defined functions +#' +#' These functions support calling R code from query engine execution +#' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]). +#' Use [arrow_scalar_function()] to define an R function that accepts and +#' returns R objects; use [arrow_base_scalar_function()] to define a +#' lower-level function that operates directly on Arrow objects. +#' +#' @param name The function name to be used in the dplyr bindings +#' @param scalar_function An object created with [arrow_scalar_function()] +#' or [arrow_base_scalar_function()]. +#' @param registry_name The function name to be used in the Arrow C++ +#' compute function registry. This may be different from `name`. +#' @param in_type A [DataType] of the input type or a [schema()] +#' for functions with more than one argument. This signature will be used +#' to determine if this function is appropriate for a given set of arguments. +#' If this function is appropriate for more than one signature, pass a +#' `list()` of the above. 
+#' @param out_type A [DataType] of the output type or a function accepting +#' a single argument (`types`), which is a `list()` of [DataType]s. If a +#' function, it must return a [DataType]. +#' @param fun An R function or rlang-style lambda expression. This function +#' will be called with R objects as arguments and must return an object +#' that can be converted to an [Array] using [as_arrow_array()]. Function +#' authors must take care to return an array castable to the output data +#' type specified by `out_type`. +#' @param base_fun An R function or rlang-style lambda expression. This +#' function will be called with exactly two arguments: `kernel_context`, +#' which is a `list()` of objects giving information about the +#' execution context, and `args`, which is a list of [Array] or [Scalar] +#' objects corresponding to the input arguments. +#' +#' @return +#' - `register_scalar_function()`: `NULL`, invisibly +#' - `arrow_scalar_function()` and `arrow_base_scalar_function()`: an object +#' of class "arrow_base_scalar_function" that can be passed to +#' `register_scalar_function()`. 
+#' @export +#' +#' @examples +#' fun_wrapper <- arrow_scalar_function( +#' schema(x = float64(), y = float64(), z = float64()), +#' float64(), +#' function(x, y, z) x + y + z +#' ) +#' register_scalar_function("example_add3", fun_wrapper) +#' +#' call_function( +#' "example_add3", +#' Scalar$create(1), +#' Scalar$create(2), +#' Array$create(3) +#' ) +#' +#' # use arrow_base_scalar_function() for a lower-level interface +#' base_fun_wrapper <- arrow_base_scalar_function( +#' schema(x = float64(), y = float64(), z = float64()), +#' float64(), +#' function(kernel_context, args) { +#' args[[1]] + args[[2]] + args[[3]] +#' } +#' ) +#' register_scalar_function("example_add3", base_fun_wrapper) +#' +#' call_function( +#' "example_add3", +#' Scalar$create(1), +#' Scalar$create(2), +#' Array$create(3) +#' ) +#' register_scalar_function <- function(name, scalar_function, registry_name = name) { assert_that( is.string(name), @@ -320,13 +391,17 @@ register_scalar_function <- function(name, scalar_function, registry_name = name # register with dplyr bindings register_binding( name, - function(...) Expression$create(compute_registry_name, ...) + function(...) build_expr(registry_name, ...) 
) - # invalidate function cache + # recreate dplyr binding cache create_binding_cache() + + invisible(NULL) } +#' @rdname register_scalar_function +#' @export arrow_scalar_function <- function(in_type, out_type, fun) { fun <- rlang::as_function(fun) base_fun <- function(kernel_context, args) { @@ -338,17 +413,19 @@ arrow_scalar_function <- function(in_type, out_type, fun) { arrow_base_scalar_function(in_type, out_type, base_fun) } +#' @rdname register_scalar_function +#' @export arrow_base_scalar_function <- function(in_type, out_type, base_fun) { if (is.list(in_type)) { - in_type <- lapply(in_type, as_in_types) + in_type <- lapply(in_type, as_scalar_function_in_type) } else { - in_type <- list(as_in_types(in_type)) + in_type <- list(as_scalar_function_in_type(in_type)) } if (is.list(out_type)) { - out_type <- lapply(out_type, as_out_type) + out_type <- lapply(out_type, as_scalar_function_out_type) } else { - out_type <- list(as_out_type(out_type)) + out_type <- list(as_scalar_function_out_type(out_type)) } out_type <- rep_len(out_type, length(in_type)) @@ -366,7 +443,7 @@ arrow_base_scalar_function <- function(in_type, out_type, base_fun) { ) } -as_in_types <- function(x) { +as_scalar_function_in_type <- function(x) { if (inherits(x, "Field")) { schema(x) } else if (inherits(x, "DataType")) { @@ -376,7 +453,7 @@ as_in_types <- function(x) { } } -as_out_type <- function(x) { +as_scalar_function_out_type <- function(x) { if (is.function(x)) { x } else { diff --git a/r/man/register_scalar_function.Rd b/r/man/register_scalar_function.Rd new file mode 100644 index 00000000000..570d750451c --- /dev/null +++ b/r/man/register_scalar_function.Rd @@ -0,0 +1,93 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/compute.R +\name{register_scalar_function} +\alias{register_scalar_function} +\alias{arrow_scalar_function} +\alias{arrow_base_scalar_function} +\title{Register user-defined functions} +\usage{ +register_scalar_function(name, 
scalar_function, registry_name = name) + +arrow_scalar_function(in_type, out_type, fun) + +arrow_base_scalar_function(in_type, out_type, base_fun) +} +\arguments{ +\item{name}{The function name to be used in the dplyr bindings} + +\item{scalar_function}{An object created with \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} +or \code{\link[=arrow_base_scalar_function]{arrow_base_scalar_function()}}.} + +\item{registry_name}{The function name to be used in the Arrow C++ +compute function registry. This may be different from \code{name}.} + +\item{in_type}{A \link{DataType} of the input type or a \code{\link[=schema]{schema()}} +for functions with more than one argument. This signature will be used +to determine if this function is appropriate for a given set of arguments. +If this function is appropriate for more than one signature, pass a +\code{list()} of the above.} + +\item{out_type}{A \link{DataType} of the output type or a function accepting +a single argument (\code{types}), which is a \code{list()} of \link{DataType}s. If a +function it must return a \link{DataType}.} + +\item{fun}{An R function or rlang-style lambda expression. This function +will be called with R objects as arguments and must return an object +that can be converted to an \link{Array} using \code{\link[=as_arrow_array]{as_arrow_array()}}. Function +authors must take care to return an array castable to the output data +type specified by \code{out_type}.} + +\item{base_fun}{An R function or rlang-style lambda expression. 
This +function will be called with exactly two arguments: \code{kernel_context}, +which is a \code{list()} of objects giving information about the +execution context, and \code{args}, which is a list of \link{Array} or \link{Scalar} +objects corresponding to the input arguments.} +} +\value{ +\itemize{ +\item \code{register_scalar_function()}: \code{NULL}, invisibly +\item \code{arrow_scalar_function()} and \code{arrow_base_scalar_function()}: an object +of class "arrow_base_scalar_function" that can be passed to +\code{register_scalar_function()}. +} +} +\description{ +These functions support calling R code from query engine execution +(i.e., a \code{\link[dplyr:mutate]{dplyr::mutate()}} or \code{\link[dplyr:filter]{dplyr::filter()}} on a \link{Table} or \link{Dataset}). +Use \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} to define an R function that accepts and +returns R objects; use \code{\link[=arrow_base_scalar_function]{arrow_base_scalar_function()}} to define a +lower-level function that operates directly on Arrow objects. 
+} +\examples{ +fun_wrapper <- arrow_scalar_function( + schema(x = float64(), y = float64(), z = float64()), + float64(), + function(x, y, z) x + y + z +) +register_scalar_function("example_add3", fun_wrapper) + +call_function( + "example_add3", + Scalar$create(1), + Scalar$create(2), + Array$create(3) +) + +# use arrow_base_scalar_function() for a lower-level interface +base_fun_wrapper <- arrow_base_scalar_function( + schema(x = float64(), y = float64(), z = float64()), + float64(), + function(kernel_context, args) { + args[[1]] + args[[2]] + args[[3]] + } +) +register_scalar_function("example_add3", base_fun_wrapper) + +call_function( + "example_add3", + Scalar$create(1), + Scalar$create(2), + Array$create(3) +) + +} diff --git a/r/src/compute.cpp b/r/src/compute.cpp index e07b6b69c7e..510facbdd77 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -579,13 +579,11 @@ std::vector compute__GetFunctionNames() { class RScalarUDFKernelState : public arrow::compute::KernelState { public: - RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver, - const std::vector& input_names) - : exec_func_(exec_func), resolver_(resolver), input_names_(input_names) {} + RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver) + : exec_func_(exec_func), resolver_(resolver) {} cpp11::function exec_func_; cpp11::function resolver_; - cpp11::strings input_names_; }; class RScalarUDFOutputTypeResolver : public arrow::compute::OutputType::Resolver { @@ -603,7 +601,6 @@ class RScalarUDFOutputTypeResolver : public arrow::compute::OutputType::Resolver for (const auto& item : descr) { input_types_sexp.push_back(cpp11::to_r6(item.type)); } - input_types_sexp.names() = state->input_names_; cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp); if (!Rf_inherits(output_type_sexp, "DataType")) { @@ -640,8 +637,6 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { } } - args_sexp.names() = state->input_names_; - cpp11::sexp batch_length_sexp 
= cpp11::as_sexp(span.length); cpp11::writable::list udf_context = {batch_length_sexp}; @@ -711,10 +706,8 @@ void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { cpp11::sexp out_type_func = out_type_r[i]; std::vector compute_in_types; - std::vector compute_in_names; for (int64_t j = 0; j < in_types->num_fields(); j++) { compute_in_types.push_back(arrow::compute::InputType(in_types->field(i)->type())); - compute_in_names.push_back(in_types->field(i)->name()); } arrow::compute::OutputType out_type((RScalarUDFOutputTypeResolver())); @@ -724,8 +717,7 @@ void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { arrow::compute::ScalarKernel kernel(signature, RScalarUDFCallable()); kernel.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE; kernel.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; - kernel.data = std::make_shared(func_sexp, out_type_func, - compute_in_names); + kernel.data = std::make_shared(func_sexp, out_type_func); StopIfNotOk(func->AddKernel(std::move(kernel))); } diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 5213d5ca345..0050788e14d 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -71,7 +71,7 @@ test_that("arrow_scalar_function() returns a base scalar function", { ) }) -test_that("register_scalar_function() creates a dplyr binding", { +test_that("register_scalar_function() adds a compute function to the registry", { fun <- arrow_base_scalar_function( int32(), int64(), function(context, args) args[[1]] @@ -93,5 +93,7 @@ test_that("register_scalar_function() creates a dplyr binding", { ) # fails because there's no event loop registered - # record_batch(a = 1L) |> dplyr::mutate(b = arrow_my_test_scalar_function(a)) |> dplyr::collect() + record_batch(a = 1L) %>% + dplyr::mutate(b = arrow_my_test_scalar_function(a)) %>% + dplyr::collect() }) From 99f72258f5bbf16f139a319f146a987db95f97b2 Mon Sep 17 00:00:00 2001 From: Dewey 
Dunnington Date: Tue, 21 Jun 2022 16:02:52 -0300 Subject: [PATCH 20/92] add the output type to the kernel context --- r/src/compute.cpp | 6 ++++-- r/tests/testthat/test-compute.R | 5 ++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 510facbdd77..9710e2d22b7 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -639,8 +639,10 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length); - cpp11::writable::list udf_context = {batch_length_sexp}; - udf_context.names() = {"batch_length"}; + std::shared_ptr output_type = result->type()->Copy(); + cpp11::sexp output_type_sexp = cpp11::to_r6(output_type); + cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp}; + udf_context.names() = {"batch_length", "output_type"}; cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp); diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 0050788e14d..3dd81de58aa 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -74,7 +74,10 @@ test_that("arrow_scalar_function() returns a base scalar function", { test_that("register_scalar_function() adds a compute function to the registry", { fun <- arrow_base_scalar_function( int32(), int64(), - function(context, args) args[[1]] + function(context, args) { + browser() + args[[1]] + } ) register_scalar_function("my_test_scalar_function", fun) From 80e56834636d540253ee66b29f940b6ccd8fe61d Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 21 Jun 2022 16:09:52 -0300 Subject: [PATCH 21/92] touch up --- r/R/compute.R | 6 +++--- r/man/register_scalar_function.Rd | 2 +- r/tests/testthat/test-compute.R | 7 +++---- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 81f947aa4ef..1d8000ed19c 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -365,7 +365,7 @@ cast_options 
<- function(safe = TRUE, ...) { #' base_fun_wrapper <- arrow_base_scalar_function( #' schema(x = float64(), y = float64(), z = float64()), #' float64(), -#' function(kernel_context, args) { +#' function(context, args) { #' args[[1]] + args[[2]] + args[[3]] #' } #' ) @@ -404,10 +404,10 @@ register_scalar_function <- function(name, scalar_function, registry_name = name #' @export arrow_scalar_function <- function(in_type, out_type, fun) { fun <- rlang::as_function(fun) - base_fun <- function(kernel_context, args) { + base_fun <- function(context, args) { args <- lapply(args, as.vector) result <- do.call(fun, args) - as_arrow_array(result) + as_arrow_array(result, type = context$output_type) } arrow_base_scalar_function(in_type, out_type, base_fun) diff --git a/r/man/register_scalar_function.Rd b/r/man/register_scalar_function.Rd index 570d750451c..f1ff7d73ee1 100644 --- a/r/man/register_scalar_function.Rd +++ b/r/man/register_scalar_function.Rd @@ -77,7 +77,7 @@ call_function( base_fun_wrapper <- arrow_base_scalar_function( schema(x = float64(), y = float64(), z = float64()), float64(), - function(kernel_context, args) { + function(context, args) { args[[1]] + args[[2]] + args[[3]] } ) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 3dd81de58aa..dcc15a451b0 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -75,7 +75,6 @@ test_that("register_scalar_function() adds a compute function to the registry", fun <- arrow_base_scalar_function( int32(), int64(), function(context, args) { - browser() args[[1]] } ) @@ -96,7 +95,7 @@ test_that("register_scalar_function() adds a compute function to the registry", ) # fails because there's no event loop registered - record_batch(a = 1L) %>% - dplyr::mutate(b = arrow_my_test_scalar_function(a)) %>% - dplyr::collect() + # record_batch(a = 1L) %>% + # dplyr::mutate(b = my_test_scalar_function(a)) %>% + # dplyr::collect() }) From 
1e343a29e18a62b6b2e6c3fa139d312fee67499f Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 21 Jun 2022 16:12:04 -0300 Subject: [PATCH 22/92] add reference page --- r/_pkgdown.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/r/_pkgdown.yml b/r/_pkgdown.yml index c0f599fb8a5..b04cab8195e 100644 --- a/r/_pkgdown.yml +++ b/r/_pkgdown.yml @@ -219,6 +219,7 @@ reference: - match_arrow - value_counts - list_compute_functions + - register_scalar_function - title: Connections to other systems contents: - to_arrow From df6ea0c166fd740984c97772c18b82f910479fa2 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 22 Jun 2022 09:26:00 -0300 Subject: [PATCH 23/92] don't create lists in growables --- r/src/compute.cpp | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 9710e2d22b7..b30dccc49c8 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -596,10 +596,9 @@ class RScalarUDFOutputTypeResolver : public arrow::compute::OutputType::Resolver reinterpret_cast(context->kernel()); auto state = std::dynamic_pointer_cast(kernel->data); - cpp11::writable::list input_types_sexp; - input_types_sexp.reserve(descr.size()); - for (const auto& item : descr) { - input_types_sexp.push_back(cpp11::to_r6(item.type)); + cpp11::writable::list input_types_sexp(descr.size()); + for (int i = 0; i < descr.size(); i++) { + input_types_sexp[i] = cpp11::to_r6(descr[i].type); } cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp); @@ -623,17 +622,16 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { reinterpret_cast(context->kernel()); auto state = std::dynamic_pointer_cast(kernel->data); - cpp11::writable::list args_sexp; - args_sexp.reserve(span.num_values()); + cpp11::writable::list args_sexp(span.num_values()); - for (int64_t i = 0; i < span.num_values(); i++) { + for (int i = 0; i < span.num_values(); i++) { const arrow::compute::ExecValue& exec_val = span[i]; if 
(exec_val.is_array()) { std::shared_ptr array = exec_val.array.ToArray(); - args_sexp.push_back(cpp11::to_r6(array)); + args_sexp[i] = cpp11::to_r6(array); } else if (exec_val.is_scalar()) { std::shared_ptr scalar = exec_val.scalar->Copy(); - args_sexp.push_back(cpp11::to_r6(scalar)); + args_sexp[i] = cpp11::to_r6(scalar); } } From 36feaacc3ed7613dba5e117f97367a448543393e Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 22 Jun 2022 09:44:41 -0300 Subject: [PATCH 24/92] separate ExecPlan_prepare and ExecPlan_run --- r/R/arrowExports.R | 4 ++++ r/src/arrowExports.cpp | 13 ++++++++++++ r/src/compute-exec.cpp | 48 ++++++++++++++++++++++++++++++++++++------ 3 files changed, 58 insertions(+), 7 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 83f5f2ef99b..321e0c7fb4a 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -408,6 +408,10 @@ ExecPlan_run <- function(plan, final_node, sort_options, metadata, head) { .Call(`_arrow_ExecPlan_run`, plan, final_node, sort_options, metadata, head) } +ExecPlan_read_table <- function(plan, final_node, sort_options, metadata, head) { + .Call(`_arrow_ExecPlan_read_table`, plan, final_node, sort_options, metadata, head) +} + ExecPlan_StopProducing <- function(plan) { invisible(.Call(`_arrow_ExecPlan_StopProducing`, plan)) } diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index fb5a46e2909..6a5456784b2 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -881,6 +881,18 @@ BEGIN_CPP11 END_CPP11 } // compute-exec.cpp +std::shared_ptr ExecPlan_read_table(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::list sort_options, cpp11::strings metadata, int64_t head); +extern "C" SEXP _arrow_ExecPlan_read_table(SEXP plan_sexp, SEXP final_node_sexp, SEXP sort_options_sexp, SEXP metadata_sexp, SEXP head_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type plan(plan_sexp); + arrow::r::Input&>::type final_node(final_node_sexp); + arrow::r::Input::type 
sort_options(sort_options_sexp); + arrow::r::Input::type metadata(metadata_sexp); + arrow::r::Input::type head(head_sexp); + return cpp11::as_sexp(ExecPlan_read_table(plan, final_node, sort_options, metadata, head)); +END_CPP11 +} +// compute-exec.cpp void ExecPlan_StopProducing(const std::shared_ptr& plan); extern "C" SEXP _arrow_ExecPlan_StopProducing(SEXP plan_sexp){ BEGIN_CPP11 @@ -5250,6 +5262,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_io___CompressedInputStream__Make", (DL_FUNC) &_arrow_io___CompressedInputStream__Make, 2}, { "_arrow_ExecPlan_create", (DL_FUNC) &_arrow_ExecPlan_create, 1}, { "_arrow_ExecPlan_run", (DL_FUNC) &_arrow_ExecPlan_run, 5}, + { "_arrow_ExecPlan_read_table", (DL_FUNC) &_arrow_ExecPlan_read_table, 5}, { "_arrow_ExecPlan_StopProducing", (DL_FUNC) &_arrow_ExecPlan_StopProducing, 1}, { "_arrow_ExecNode_output_schema", (DL_FUNC) &_arrow_ExecNode_output_schema, 1}, { "_arrow_ExecNode_Scan", (DL_FUNC) &_arrow_ExecNode_Scan, 4}, diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 76112b4cefd..703847775ca 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -55,11 +55,10 @@ std::shared_ptr MakeExecNodeOrStop( }); } -// [[arrow::export]] -std::shared_ptr ExecPlan_run( - const std::shared_ptr& plan, - const std::shared_ptr& final_node, cpp11::list sort_options, - cpp11::strings metadata, int64_t head = -1) { +std::pair, std::shared_ptr> +ExecPlan_prepare(const std::shared_ptr& plan, + const std::shared_ptr& final_node, + cpp11::list sort_options, cpp11::strings metadata, int64_t head = -1) { // For now, don't require R to construct SinkNodes. // Instead, just pass the node we should collect as an argument. arrow::AsyncGenerator> sink_gen; @@ -89,7 +88,6 @@ std::shared_ptr ExecPlan_run( } StopIfNotOk(plan->Validate()); - StopIfNotOk(plan->StartProducing()); // If the generator is destroyed before being completely drained, inform plan std::shared_ptr stop_producing{nullptr, [plan](...) 
{ @@ -109,9 +107,45 @@ std::shared_ptr ExecPlan_run( auto kv = strings_to_kvm(metadata); out_schema = out_schema->WithMetadata(kv); } - return compute::MakeGeneratorReader( + + std::pair, std::shared_ptr> + out; + out.first = plan; + out.second = compute::MakeGeneratorReader( out_schema, [stop_producing, plan, sink_gen] { return sink_gen(); }, gc_memory_pool()); + return out; +} + +// [[arrow::export]] +std::shared_ptr ExecPlan_run( + const std::shared_ptr& plan, + const std::shared_ptr& final_node, cpp11::list sort_options, + cpp11::strings metadata, int64_t head = -1) { + auto prepared_plan = ExecPlan_prepare(plan, final_node, sort_options, metadata, head); + StopIfNotOk(prepared_plan.first->StartProducing()); + return prepared_plan.second; +} + +// [[arrow::export]] +std::shared_ptr ExecPlan_read_table( + const std::shared_ptr& plan, + const std::shared_ptr& final_node, cpp11::list sort_options, + cpp11::strings metadata, int64_t head = -1) { + auto prepared_plan = ExecPlan_prepare(plan, final_node, sort_options, metadata, head); +#if !defined(HAS_SAFE_CALL_INTO_R) + StopIfNotOk(prepared_plan.first->StartProducing()); + return ValueOrStop(prepared_plan.second->ToTable()); +#else + const auto& io_context = arrow::io::default_io_context(); + auto result = RunWithCapturedR>([&]() { + return DeferNotOk(io_context.executor()->Submit([&]() { + StopIfNotOk(prepared_plan.first->StartProducing()); + return prepared_plan.second->ToTable(); + })); + }); + return ValueOrStop(result); +#endif } // [[arrow::export]] From 32e8d83a5d3c7cb804c8b10281f6b406212574ef Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 22 Jun 2022 10:13:50 -0300 Subject: [PATCH 25/92] push as much exec plan execution into C++ as is possibe --- r/NAMESPACE | 2 ++ r/R/dplyr-collect.R | 2 +- r/R/query-engine.R | 40 ++++++++++++++++++++++++++++------------ r/R/table.R | 15 +++++++++++++++ r/man/as_arrow_table.Rd | 6 ++++++ r/src/compute-exec.cpp | 2 ++ 6 files changed, 54 insertions(+), 13 
deletions(-) diff --git a/r/NAMESPACE b/r/NAMESPACE index ac9858e01b5..4f96c5752bd 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -45,7 +45,9 @@ S3method(as_arrow_array,data.frame) S3method(as_arrow_array,default) S3method(as_arrow_array,pyarrow.lib.Array) S3method(as_arrow_table,RecordBatch) +S3method(as_arrow_table,RecordBatchReader) S3method(as_arrow_table,Table) +S3method(as_arrow_table,arrow_dplyr_query) S3method(as_arrow_table,data.frame) S3method(as_arrow_table,default) S3method(as_arrow_table,pyarrow.lib.RecordBatch) diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R index 7f10ed307e8..3e83475a8c8 100644 --- a/r/R/dplyr-collect.R +++ b/r/R/dplyr-collect.R @@ -20,7 +20,7 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { tryCatch( - out <- as_record_batch_reader(x)$read_table(), + out <- as_arrow_table(x), # n = 4 because we want the error to show up as being from collect() # and not handle_csv_read_error() error = function(e, call = caller_env(n = 4)) { diff --git a/r/R/query-engine.R b/r/R/query-engine.R index 511bf3dbc27..f32d746d69f 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -191,7 +191,7 @@ ExecPlan <- R6Class("ExecPlan", } node }, - Run = function(node) { + Run = function(node, as_table = FALSE) { assert_is(node, "ExecNode") # Sorting and head/tail (if sorted) are handled in the SinkNode, @@ -209,13 +209,23 @@ ExecPlan <- R6Class("ExecPlan", sorting$orders <- as.integer(sorting$orders) } - out <- ExecPlan_run( - self, - node, - sorting, - prepare_key_value_metadata(node$final_metadata()), - select_k - ) + if (as_table) { + out <- ExecPlan_read_table( + self, + node, + sorting, + prepare_key_value_metadata(node$final_metadata()), + select_k + ) + } else { + out <- ExecPlan_run( + self, + node, + sorting, + prepare_key_value_metadata(node$final_metadata()), + select_k + ) + } if (!has_sorting) { # Since ExecPlans don't scan in deterministic order, head/tail are both @@ -232,10 +242,12 @@ ExecPlan <- 
R6Class("ExecPlan", } else if (!is.null(node$extras$tail)) { # TODO(ARROW-16630): proper BottomK support # Reverse the row order to get back what we expect - out <- out$read_table() + out <- as_arrow_table(out) out <- out[rev(seq_len(nrow(out))), , drop = FALSE] # Put back into RBR - out <- as_record_batch_reader(out) + if (!as_table) { + out <- as_record_batch_reader(out) + } } # If arrange() created $temp_columns, make sure to omit them from the result @@ -243,9 +255,13 @@ ExecPlan <- R6Class("ExecPlan", # happens in the end (SinkNode) so nothing comes after it. # TODO(ARROW-16631): move into ExecPlan if (length(node$extras$sort$temp_columns) > 0) { - tab <- out$read_table() + tab <- as_arrow_table(out) tab <- tab[, setdiff(names(tab), node$extras$sort$temp_columns), drop = FALSE] - out <- as_record_batch_reader(tab) + if (!as_table) { + out <- as_record_batch_reader(tab) + } else { + out <- tab + } } out diff --git a/r/R/table.R b/r/R/table.R index 305f305129e..5579c676d51 100644 --- a/r/R/table.R +++ b/r/R/table.R @@ -318,3 +318,18 @@ as_arrow_table.RecordBatch <- function(x, ..., schema = NULL) { as_arrow_table.data.frame <- function(x, ..., schema = NULL) { Table$create(x, schema = schema) } + +#' @rdname as_arrow_table +#' @export +as_arrow_table.RecordBatchReader <- function(x, ...) { + x$read_table() +} + +#' @rdname as_arrow_table +#' @export +as_arrow_table.arrow_dplyr_query <- function(x, ...) 
{ + # See query-engine.R for ExecPlan/Nodes + plan <- ExecPlan$create() + final_node <- plan$Build(x) + plan$Run(final_node, as_table = TRUE) +} diff --git a/r/man/as_arrow_table.Rd b/r/man/as_arrow_table.Rd index 0ba563f581b..aac4495e7c6 100644 --- a/r/man/as_arrow_table.Rd +++ b/r/man/as_arrow_table.Rd @@ -6,6 +6,8 @@ \alias{as_arrow_table.Table} \alias{as_arrow_table.RecordBatch} \alias{as_arrow_table.data.frame} +\alias{as_arrow_table.RecordBatchReader} +\alias{as_arrow_table.arrow_dplyr_query} \title{Convert an object to an Arrow Table} \usage{ as_arrow_table(x, ..., schema = NULL) @@ -17,6 +19,10 @@ as_arrow_table(x, ..., schema = NULL) \method{as_arrow_table}{RecordBatch}(x, ..., schema = NULL) \method{as_arrow_table}{data.frame}(x, ..., schema = NULL) + +\method{as_arrow_table}{RecordBatchReader}(x, ...) + +\method{as_arrow_table}{arrow_dplyr_query}(x, ...) } \arguments{ \item{x}{An object to convert to an Arrow Table} diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 703847775ca..9b0213f79a7 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -16,12 +16,14 @@ // under the License. 
#include "./arrow_types.h" +#include "./safe-call-into-r.h" #include #include #include #include #include +#include #include #include #include From 8ff947b6f0dd70c95cd2a4884e0e8377405bbc83 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 22 Jun 2022 10:35:34 -0300 Subject: [PATCH 26/92] test UDF in dplyr query --- r/tests/testthat/test-compute.R | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index dcc15a451b0..62ca1108f05 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -75,7 +75,7 @@ test_that("register_scalar_function() adds a compute function to the registry", fun <- arrow_base_scalar_function( int32(), int64(), function(context, args) { - args[[1]] + args[[1]] + 1L } ) @@ -86,16 +86,18 @@ test_that("register_scalar_function() adds a compute function to the registry", expect_equal( call_function("my_test_scalar_function", Array$create(1L, int32())), - Array$create(1L, int64()) + Array$create(2L, int64()) ) expect_equal( call_function("my_test_scalar_function", Scalar$create(1L, int32())), - Scalar$create(1L, int64()) + Scalar$create(2L, int64()) ) - # fails because there's no event loop registered - # record_batch(a = 1L) %>% - # dplyr::mutate(b = my_test_scalar_function(a)) %>% - # dplyr::collect() + expect_identical( + record_batch(a = 1L) %>% + dplyr::mutate(b = my_test_scalar_function(a)) %>% + dplyr::collect(), + tibble::tibble(a = 1L, b = 2L) + ) }) From 87402fc26e3eda37b815bea9766df1c5a737847e Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 22 Jun 2022 10:44:20 -0300 Subject: [PATCH 27/92] clang-format --- r/src/compute-exec.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 9b0213f79a7..0d4509a13d9 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -22,8 +22,8 @@ #include #include #include -#include #include 
+#include #include #include #include From 6d609211d13ea391ebab7b3d0201db8274a2340b Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 22 Jun 2022 13:19:24 -0300 Subject: [PATCH 28/92] maybe fixed sign compare error --- r/src/compute.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index b30dccc49c8..59326e707cc 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -597,7 +597,7 @@ class RScalarUDFOutputTypeResolver : public arrow::compute::OutputType::Resolver auto state = std::dynamic_pointer_cast(kernel->data); cpp11::writable::list input_types_sexp(descr.size()); - for (int i = 0; i < descr.size(); i++) { + for (size_t i = 0; i < descr.size(); i++) { input_types_sexp[i] = cpp11::to_r6(descr[i].type); } From e338f7d2cfd9414003aa58fbaadc8e400f8c126f Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 22 Jun 2022 14:10:17 -0300 Subject: [PATCH 29/92] limit scope on test --- r/tests/testthat/test-compute.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 62ca1108f05..eac242b8c22 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -72,6 +72,8 @@ test_that("arrow_scalar_function() returns a base scalar function", { }) test_that("register_scalar_function() adds a compute function to the registry", { + skip_if_not_available("dataset") + fun <- arrow_base_scalar_function( int32(), int64(), function(context, args) { From e0dd5c06dbfb6deb0517bf8107424c26575fcc7c Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 22 Jun 2022 14:13:48 -0300 Subject: [PATCH 30/92] try to fix lintr errors --- r/R/query-engine.R | 25 ++++++++----------------- r/tests/testthat/test-compute.R | 4 +++- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/r/R/query-engine.R b/r/R/query-engine.R index f32d746d69f..1b59534a5a3 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -209,23 
+209,14 @@ ExecPlan <- R6Class("ExecPlan", sorting$orders <- as.integer(sorting$orders) } - if (as_table) { - out <- ExecPlan_read_table( - self, - node, - sorting, - prepare_key_value_metadata(node$final_metadata()), - select_k - ) - } else { - out <- ExecPlan_run( - self, - node, - sorting, - prepare_key_value_metadata(node$final_metadata()), - select_k - ) - } + read_func <- if (as_table) ExecPlan_read_table else ExecPlan_run + out <- read_func( + self, + node, + sorting, + prepare_key_value_metadata(node$final_metadata()), + select_k + ) if (!has_sorting) { # Since ExecPlans don't scan in deterministic order, head/tail are both diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index eac242b8c22..cd3395d72f7 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -61,7 +61,9 @@ test_that("arrow_scalar_function() returns a base scalar function", { base_fun <- arrow_scalar_function( list(float64(), float64()), float64(), - function(x, y) { x + y } + function(x, y) { + x + y + } ) expect_s3_class(base_fun, "arrow_base_scalar_function") From cdecb55fed9f641e02be1b53394eb9791fea8638 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 22 Jun 2022 14:57:39 -0300 Subject: [PATCH 31/92] see if this example is the problem on 32-bit windows --- r/R/compute.R | 2 +- r/man/register_scalar_function.Rd | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 1d8000ed19c..e5f12b1d74f 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -346,7 +346,7 @@ cast_options <- function(safe = TRUE, ...) { #' `register_scalar_function()`. 
#' @export #' -#' @examples +#' @examplesIf .Machine$sizeof.pointer >= 8 #' fun_wrapper <- arrow_scalar_function( #' schema(x = float64(), y = float64(), z = float64()), #' float64(), diff --git a/r/man/register_scalar_function.Rd b/r/man/register_scalar_function.Rd index f1ff7d73ee1..6d239adba30 100644 --- a/r/man/register_scalar_function.Rd +++ b/r/man/register_scalar_function.Rd @@ -59,6 +59,7 @@ returns R objects; use \code{\link[=arrow_base_scalar_function]{arrow_base_scala lower-level function that operates directly on Arrow objects. } \examples{ +\dontshow{if (.Machine$sizeof.pointer >= 8) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} fun_wrapper <- arrow_scalar_function( schema(x = float64(), y = float64(), z = float64()), float64(), @@ -89,5 +90,5 @@ call_function( Scalar$create(2), Array$create(3) ) - +\dontshow{\}) # examplesIf} } From 65a5dc0b11792f35200b1caad292887aba3d02e3 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 22 Jun 2022 15:54:06 -0300 Subject: [PATCH 32/92] maybe fix on old windows --- r/R/arrowExports.R | 4 ++-- r/R/query-engine.R | 26 ++++++++++++++++++-------- r/src/arrowExports.cpp | 9 +++++---- r/src/compute-exec.cpp | 7 ++++++- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 321e0c7fb4a..65ca3020f0b 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -408,8 +408,8 @@ ExecPlan_run <- function(plan, final_node, sort_options, metadata, head) { .Call(`_arrow_ExecPlan_run`, plan, final_node, sort_options, metadata, head) } -ExecPlan_read_table <- function(plan, final_node, sort_options, metadata, head) { - .Call(`_arrow_ExecPlan_read_table`, plan, final_node, sort_options, metadata, head) +ExecPlan_read_table <- function(plan, final_node, sort_options, metadata, head, on_old_windows) { + .Call(`_arrow_ExecPlan_read_table`, plan, final_node, sort_options, metadata, head, on_old_windows) } ExecPlan_StopProducing <- function(plan) { 
diff --git a/r/R/query-engine.R b/r/R/query-engine.R index 1b59534a5a3..a98e87f78e8 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -209,14 +209,24 @@ ExecPlan <- R6Class("ExecPlan", sorting$orders <- as.integer(sorting$orders) } - read_func <- if (as_table) ExecPlan_read_table else ExecPlan_run - out <- read_func( - self, - node, - sorting, - prepare_key_value_metadata(node$final_metadata()), - select_k - ) + if (as_table) { + out <- ExecPlan_read_table( + self, + node, + sorting, + prepare_key_value_metadata(node$final_metadata()), + select_k, + on_old_windows() + ) + } else { + out <- ExecPlan_run( + self, + node, + sorting, + prepare_key_value_metadata(node$final_metadata()), + select_k + ) + } if (!has_sorting) { # Since ExecPlans don't scan in deterministic order, head/tail are both diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 6a5456784b2..cf022e26857 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -881,15 +881,16 @@ BEGIN_CPP11 END_CPP11 } // compute-exec.cpp -std::shared_ptr ExecPlan_read_table(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::list sort_options, cpp11::strings metadata, int64_t head); -extern "C" SEXP _arrow_ExecPlan_read_table(SEXP plan_sexp, SEXP final_node_sexp, SEXP sort_options_sexp, SEXP metadata_sexp, SEXP head_sexp){ +std::shared_ptr ExecPlan_read_table(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::list sort_options, cpp11::strings metadata, int64_t head, bool on_old_windows); +extern "C" SEXP _arrow_ExecPlan_read_table(SEXP plan_sexp, SEXP final_node_sexp, SEXP sort_options_sexp, SEXP metadata_sexp, SEXP head_sexp, SEXP on_old_windows_sexp){ BEGIN_CPP11 arrow::r::Input&>::type plan(plan_sexp); arrow::r::Input&>::type final_node(final_node_sexp); arrow::r::Input::type sort_options(sort_options_sexp); arrow::r::Input::type metadata(metadata_sexp); arrow::r::Input::type head(head_sexp); - return cpp11::as_sexp(ExecPlan_read_table(plan, 
final_node, sort_options, metadata, head)); + arrow::r::Input::type on_old_windows(on_old_windows_sexp); + return cpp11::as_sexp(ExecPlan_read_table(plan, final_node, sort_options, metadata, head, on_old_windows)); END_CPP11 } // compute-exec.cpp @@ -5262,7 +5263,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_io___CompressedInputStream__Make", (DL_FUNC) &_arrow_io___CompressedInputStream__Make, 2}, { "_arrow_ExecPlan_create", (DL_FUNC) &_arrow_ExecPlan_create, 1}, { "_arrow_ExecPlan_run", (DL_FUNC) &_arrow_ExecPlan_run, 5}, - { "_arrow_ExecPlan_read_table", (DL_FUNC) &_arrow_ExecPlan_read_table, 5}, + { "_arrow_ExecPlan_read_table", (DL_FUNC) &_arrow_ExecPlan_read_table, 6}, { "_arrow_ExecPlan_StopProducing", (DL_FUNC) &_arrow_ExecPlan_StopProducing, 1}, { "_arrow_ExecNode_output_schema", (DL_FUNC) &_arrow_ExecNode_output_schema, 1}, { "_arrow_ExecNode_Scan", (DL_FUNC) &_arrow_ExecNode_Scan, 4}, diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 0d4509a13d9..1400c979d6a 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -133,12 +133,17 @@ std::shared_ptr ExecPlan_run( std::shared_ptr ExecPlan_read_table( const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::list sort_options, - cpp11::strings metadata, int64_t head = -1) { + cpp11::strings metadata, int64_t head = -1, bool on_old_windows = false) { auto prepared_plan = ExecPlan_prepare(plan, final_node, sort_options, metadata, head); #if !defined(HAS_SAFE_CALL_INTO_R) StopIfNotOk(prepared_plan.first->StartProducing()); return ValueOrStop(prepared_plan.second->ToTable()); #else + if (on_old_windows) { + StopIfNotOk(prepared_plan.first->StartProducing()); + return ValueOrStop(prepared_plan.second->ToTable()); + } + const auto& io_context = arrow::io::default_io_context(); auto result = RunWithCapturedR>([&]() { return DeferNotOk(io_context.executor()->Submit([&]() { From f5ec713ab47bc6dc5351eee1421025fbb921da3d Mon Sep 17 00:00:00 2001 From: Dewey 
Dunnington Date: Wed, 6 Jul 2022 13:13:01 -0300 Subject: [PATCH 33/92] add larger-scale dataset test whilst executing a user-defined function --- r/tests/testthat/test-compute.R | 35 +++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index cd3395d72f7..39ffbd8c72d 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -105,3 +105,38 @@ test_that("register_scalar_function() adds a compute function to the registry", tibble::tibble(a = 1L, b = 2L) ) }) + +test_that("user-defined functions work during multi-threaded execution", { + skip_if_not_available("dataset") + + n_rows <- 10000 + n_partitions <- 10 + example_df <- expand.grid( + part = letters[seq_len(n_partitions)], + value = seq_len(n_rows), + stringsAsFactors = FALSE + ) + + # make sure values are different for each partition + example_df$row_num <- seq_len(nrow(example_df)) + example_df$value <- example_df$value + match(example_df$part, letters) + + tf <- tempfile() + on.exit(unlink(tf)) + write_dataset(example_df, tf, partitioning = "part") + + times_32 <- arrow_scalar_function( + int32(), float64(), + function(x) x * 32 + ) + + register_scalar_function("times_32", times_32) + + result <- open_dataset(tf) %>% + dplyr::mutate(fun_result = times_32(value)) %>% + dplyr::collect() %>% + dplyr::arrange(row_num) %>% + tibble::as_tibble() + + expect_identical(result$fun_result, example_df$value * 32) +}) From bb382746a35e2a1d2cce024a63b193856b351b35 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 6 Jul 2022 13:37:44 -0300 Subject: [PATCH 34/92] better variable names in tests --- r/tests/testthat/test-compute.R | 44 ++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 39ffbd8c72d..bc014d8eb2c 100644 --- a/r/tests/testthat/test-compute.R +++ 
b/r/tests/testthat/test-compute.R @@ -23,12 +23,18 @@ test_that("list_compute_functions() works", { test_that("arrow_base_scalar_function() works", { # check in/out type as schema/data type - fun <- arrow_base_scalar_function(schema(.y = int32()), int64(), function(x, y) y[[1]]) + fun <- arrow_base_scalar_function( + schema(.y = int32()), int64(), + function(kernel_context, args) args[[1]] + ) expect_equal(attr(fun, "in_type")[[1]], schema(.y = int32())) expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as data type/data type - fun <- arrow_base_scalar_function(int32(), int64(), function(x, y) y[[1]]) + fun <- arrow_base_scalar_function( + int32(), int64(), + function(kernel_context, args) args[[1]] + ) expect_equal(attr(fun, "in_type")[[1]][[1]], field("", int32())) expect_equal(attr(fun, "out_type")[[1]](), int64()) @@ -36,7 +42,7 @@ test_that("arrow_base_scalar_function() works", { fun <- arrow_base_scalar_function( field("a_name", int32()), int64(), - function(x, y) y[[1]] + function(kernel_context, args) args[[1]] ) expect_equal(attr(fun, "in_type")[[1]], schema(a_name = int32())) expect_equal(attr(fun, "out_type")[[1]](), int64()) @@ -45,7 +51,7 @@ test_that("arrow_base_scalar_function() works", { fun <- arrow_base_scalar_function( list(int32(), int64()), list(int64(), int32()), - function(x, y) y[[1]] + function(kernel_context, args) args[[1]] ) expect_equal(attr(fun, "in_type")[[1]][[1]], field("", int32())) @@ -76,33 +82,31 @@ test_that("arrow_scalar_function() returns a base scalar function", { test_that("register_scalar_function() adds a compute function to the registry", { skip_if_not_available("dataset") - fun <- arrow_base_scalar_function( - int32(), int64(), - function(context, args) { - args[[1]] + 1L - } + times_32 <- arrow_scalar_function( + int32(), float64(), + function(x) x * 32.0 ) - register_scalar_function("my_test_scalar_function", fun) + register_scalar_function("times_32", times_32) - 
expect_true("my_test_scalar_function" %in% names(arrow:::.cache$functions)) - expect_true("my_test_scalar_function" %in% list_compute_functions()) + expect_true("times_32" %in% names(arrow:::.cache$functions)) + expect_true("times_32" %in% list_compute_functions()) expect_equal( - call_function("my_test_scalar_function", Array$create(1L, int32())), - Array$create(2L, int64()) + call_function("times_32", Array$create(1L, int32())), + Array$create(32L, float64()) ) expect_equal( - call_function("my_test_scalar_function", Scalar$create(1L, int32())), - Scalar$create(2L, int64()) + call_function("times_32", Scalar$create(1L, int32())), + Scalar$create(32L, float64()) ) expect_identical( record_batch(a = 1L) %>% - dplyr::mutate(b = my_test_scalar_function(a)) %>% + dplyr::mutate(b = times_32(a)) %>% dplyr::collect(), - tibble::tibble(a = 1L, b = 2L) + tibble::tibble(a = 1L, b = 32.0) ) }) @@ -117,7 +121,7 @@ test_that("user-defined functions work during multi-threaded execution", { stringsAsFactors = FALSE ) - # make sure values are different for each partition + # make sure values are different for each partition and keep track of row order example_df$row_num <- seq_len(nrow(example_df)) example_df$value <- example_df$value + match(example_df$part, letters) @@ -127,7 +131,7 @@ test_that("user-defined functions work during multi-threaded execution", { times_32 <- arrow_scalar_function( int32(), float64(), - function(x) x * 32 + function(x) x * 32.0 ) register_scalar_function("times_32", times_32) From 565c5b51681e1e14731479f93b0ffaa3e66bea45 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 6 Jul 2022 14:45:28 -0300 Subject: [PATCH 35/92] base_scalar_function -> advanced_scalar_function --- r/NAMESPACE | 2 +- r/R/compute.R | 20 ++++++++++---------- r/man/register_scalar_function.Rd | 16 ++++++++-------- r/tests/testthat/_snaps/compute.md | 2 +- r/tests/testthat/test-compute.R | 16 ++++++++-------- 5 files changed, 28 insertions(+), 28 deletions(-) diff --git a/r/NAMESPACE
b/r/NAMESPACE index 4f96c5752bd..e09ee33d459 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -250,8 +250,8 @@ export(TimestampParser) export(Type) export(UnionDataset) export(all_of) +export(arrow_advanced_scalar_function) export(arrow_available) -export(arrow_base_scalar_function) export(arrow_info) export(arrow_scalar_function) export(arrow_table) diff --git a/r/R/compute.R b/r/R/compute.R index e5f12b1d74f..9a25fcd45f2 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -312,12 +312,12 @@ cast_options <- function(safe = TRUE, ...) { #' These functions support calling R code from query engine execution #' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]). #' Use [arrow_scalar_function()] to define an R function that accepts and -#' returns R objects; use [arrow_base_scalar_function()] to define a +#' returns R objects; use [arrow_advanced_scalar_function()] to define a #' lower-level function that operates directly on Arrow objects. #' #' @param name The function name to be used in the dplyr bindings #' @param scalar_function An object created with [arrow_scalar_function()] -#' or [arrow_base_scalar_function()]. +#' or [arrow_advanced_scalar_function()]. #' @param registry_name The function name to be used in the Arrow C++ #' compute function registry. This may be different from `name`. #' @param in_type A [DataType] of the input type or a [schema()] @@ -342,7 +342,7 @@ cast_options <- function(safe = TRUE, ...) { #' @return #' - `register_scalar_function()`: `NULL`, invisibly #' - `arrow_scalar_function()`: returns an object of class -#' "arrow_base_scalar_function" that can be passed to +#' "arrow_advanced_scalar_function" that can be passed to #' `register_scalar_function()`. #' @export #' @@ -361,15 +361,15 @@ cast_options <- function(safe = TRUE, ...) 
{ #' Array$create(3) #' ) #' -#' # use arrow_base_scalar_function() for a lower-level interface -#' base_fun_wrapper <- arrow_base_scalar_function( +#' # use arrow_advanced_scalar_function() for a lower-level interface +#' advanced_fun_wrapper <- arrow_advanced_scalar_function( #' schema(x = float64(), y = float64(), z = float64()), #' float64(), #' function(context, args) { #' args[[1]] + args[[2]] + args[[3]] #' } #' ) -#' register_scalar_function("example_add3", base_fun_wrapper) +#' register_scalar_function("example_add3", advanced_fun_wrapper) #' #' call_function( #' "example_add3", @@ -382,7 +382,7 @@ register_scalar_function <- function(name, scalar_function, registry_name = name assert_that( is.string(name), is.string(registry_name), - inherits(scalar_function, "arrow_base_scalar_function") + inherits(scalar_function, "arrow_advanced_scalar_function") ) # register with Arrow C++ @@ -410,12 +410,12 @@ arrow_scalar_function <- function(in_type, out_type, fun) { as_arrow_array(result, type = context$output_type) } - arrow_base_scalar_function(in_type, out_type, base_fun) + arrow_advanced_scalar_function(in_type, out_type, base_fun) } #' @rdname register_scalar_function #' @export -arrow_base_scalar_function <- function(in_type, out_type, base_fun) { +arrow_advanced_scalar_function <- function(in_type, out_type, base_fun) { if (is.list(in_type)) { in_type <- lapply(in_type, as_scalar_function_in_type) } else { @@ -439,7 +439,7 @@ arrow_base_scalar_function <- function(in_type, out_type, base_fun) { base_fun, in_type = in_type, out_type = out_type, - class = "arrow_base_scalar_function" + class = "arrow_advanced_scalar_function" ) } diff --git a/r/man/register_scalar_function.Rd b/r/man/register_scalar_function.Rd index 6d239adba30..f9bb7654a25 100644 --- a/r/man/register_scalar_function.Rd +++ b/r/man/register_scalar_function.Rd @@ -3,20 +3,20 @@ \name{register_scalar_function} \alias{register_scalar_function} \alias{arrow_scalar_function} 
-\alias{arrow_base_scalar_function} +\alias{arrow_advanced_scalar_function} \title{Register user-defined functions} \usage{ register_scalar_function(name, scalar_function, registry_name = name) arrow_scalar_function(in_type, out_type, fun) -arrow_base_scalar_function(in_type, out_type, base_fun) +arrow_advanced_scalar_function(in_type, out_type, base_fun) } \arguments{ \item{name}{The function name to be used in the dplyr bindings} \item{scalar_function}{An object created with \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} -or \code{\link[=arrow_base_scalar_function]{arrow_base_scalar_function()}}.} +or \code{\link[=arrow_advanced_scalar_function]{arrow_advanced_scalar_function()}}.} \item{registry_name}{The function name to be used in the Arrow C++ compute function registry. This may be different from \code{name}.} @@ -47,7 +47,7 @@ objects corresponding to the input arguments.} \itemize{ \item \code{register_scalar_function()}: \code{NULL}, invisibly \item \code{arrow_scalar_function()}: returns an object of class -"arrow_base_scalar_function" that can be passed to +"arrow_advanced_scalar_function" that can be passed to \code{register_scalar_function()}. } } @@ -55,7 +55,7 @@ objects corresponding to the input arguments.} These functions support calling R code from query engine execution (i.e., a \code{\link[dplyr:mutate]{dplyr::mutate()}} or \code{\link[dplyr:filter]{dplyr::filter()}} on a \link{Table} or \link{Dataset}). Use \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} to define an R function that accepts and -returns R objects; use \code{\link[=arrow_base_scalar_function]{arrow_base_scalar_function()}} to define a +returns R objects; use \code{\link[=arrow_advanced_scalar_function]{arrow_advanced_scalar_function()}} to define a lower-level function that operates directly on Arrow objects. 
} \examples{ @@ -74,15 +74,15 @@ call_function( Array$create(3) ) -# use arrow_base_scalar_function() for a lower-level interface -base_fun_wrapper <- arrow_base_scalar_function( +# use arrow_advanced_scalar_function() for a lower-level interface +advanced_fun_wrapper <- arrow_advanced_scalar_function( schema(x = float64(), y = float64(), z = float64()), float64(), function(context, args) { args[[1]] + args[[2]] + args[[3]] } ) -register_scalar_function("example_add3", base_fun_wrapper) +register_scalar_function("example_add3", advanced_fun_wrapper) call_function( "example_add3", diff --git a/r/tests/testthat/_snaps/compute.md b/r/tests/testthat/_snaps/compute.md index a7ba33e24e4..36f9238ec2d 100644 --- a/r/tests/testthat/_snaps/compute.md +++ b/r/tests/testthat/_snaps/compute.md @@ -1,4 +1,4 @@ -# arrow_base_scalar_function() works +# arrow_advanced_scalar_function() works `base_fun` must accept exactly two arguments diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index bc014d8eb2c..387b1d825a9 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -21,9 +21,9 @@ test_that("list_compute_functions() works", { }) -test_that("arrow_base_scalar_function() works", { +test_that("arrow_advanced_scalar_function() works", { # check in/out type as schema/data type - fun <- arrow_base_scalar_function( + fun <- arrow_advanced_scalar_function( schema(.y = int32()), int64(), function(kernel_context, args) args[[1]] ) @@ -31,7 +31,7 @@ test_that("arrow_base_scalar_function() works", { expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as data type/data type - fun <- arrow_base_scalar_function( + fun <- arrow_advanced_scalar_function( int32(), int64(), function(kernel_context, args) args[[1]] ) @@ -39,7 +39,7 @@ test_that("arrow_base_scalar_function() works", { expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as field/data type - fun <- arrow_base_scalar_function( + fun <- 
arrow_advanced_scalar_function( field("a_name", int32()), int64(), function(kernel_context, args) args[[1]] @@ -48,7 +48,7 @@ test_that("arrow_base_scalar_function() works", { expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as lists - fun <- arrow_base_scalar_function( + fun <- arrow_advanced_scalar_function( list(int32(), int64()), list(int64(), int32()), function(kernel_context, args) args[[1]] @@ -59,8 +59,8 @@ test_that("arrow_base_scalar_function() works", { expect_equal(attr(fun, "out_type")[[1]](), int64()) expect_equal(attr(fun, "out_type")[[2]](), int32()) - expect_snapshot_error(arrow_base_scalar_function(int32(), int32(), identity)) - expect_snapshot_error(arrow_base_scalar_function(int32(), int32(), NULL)) + expect_snapshot_error(arrow_advanced_scalar_function(int32(), int32(), identity)) + expect_snapshot_error(arrow_advanced_scalar_function(int32(), int32(), NULL)) }) test_that("arrow_scalar_function() returns a base scalar function", { @@ -72,7 +72,7 @@ test_that("arrow_scalar_function() returns a base scalar function", { } ) - expect_s3_class(base_fun, "arrow_base_scalar_function") + expect_s3_class(base_fun, "arrow_advanced_scalar_function") expect_equal( base_fun(list(), list(Scalar$create(2), Array$create(3))), Array$create(5) From b96469b5afe4be00cb2898cdf3224946c4172291 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 6 Jul 2022 16:25:55 -0300 Subject: [PATCH 36/92] get write_dataset() to work with user-defined function --- r/R/arrowExports.R | 4 ++-- r/R/query-engine.R | 3 ++- r/src/arrowExports.cpp | 11 ++++++----- r/src/compute-exec.cpp | 21 ++++++++++++++++++++- r/tests/testthat/test-compute.R | 24 ++++++++++++++++++------ 5 files changed, 48 insertions(+), 15 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 65ca3020f0b..9e2adb50ca2 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -424,8 +424,8 @@ ExecNode_Scan <- function(plan, dataset, filter, materialized_field_names) { 
.Call(`_arrow_ExecNode_Scan`, plan, dataset, filter, materialized_field_names) } -ExecPlan_Write <- function(plan, final_node, metadata, file_write_options, filesystem, base_dir, partitioning, basename_template, existing_data_behavior, max_partitions, max_open_files, max_rows_per_file, min_rows_per_group, max_rows_per_group) { - invisible(.Call(`_arrow_ExecPlan_Write`, plan, final_node, metadata, file_write_options, filesystem, base_dir, partitioning, basename_template, existing_data_behavior, max_partitions, max_open_files, max_rows_per_file, min_rows_per_group, max_rows_per_group)) +ExecPlan_Write <- function(plan, final_node, metadata, file_write_options, filesystem, base_dir, partitioning, basename_template, existing_data_behavior, max_partitions, max_open_files, max_rows_per_file, min_rows_per_group, max_rows_per_group, on_old_windows) { + invisible(.Call(`_arrow_ExecPlan_Write`, plan, final_node, metadata, file_write_options, filesystem, base_dir, partitioning, basename_template, existing_data_behavior, max_partitions, max_open_files, max_rows_per_file, min_rows_per_group, max_rows_per_group, on_old_windows)) } ExecNode_Filter <- function(input, filter) { diff --git a/r/R/query-engine.R b/r/R/query-engine.R index a98e87f78e8..ece212181f7 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -273,7 +273,8 @@ ExecPlan <- R6Class("ExecPlan", self, node, prepare_key_value_metadata(node$final_metadata()), - ... 
+ ..., + on_old_windows = on_old_windows() ) }, Stop = function() ExecPlan_StopProducing(self) diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index cf022e26857..9109d817738 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -930,8 +930,8 @@ extern "C" SEXP _arrow_ExecNode_Scan(SEXP plan_sexp, SEXP dataset_sexp, SEXP fil // compute-exec.cpp #if defined(ARROW_R_WITH_DATASET) -void ExecPlan_Write(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::strings metadata, const std::shared_ptr& file_write_options, const std::shared_ptr& filesystem, std::string base_dir, const std::shared_ptr& partitioning, std::string basename_template, arrow::dataset::ExistingDataBehavior existing_data_behavior, int max_partitions, uint32_t max_open_files, uint64_t max_rows_per_file, uint64_t min_rows_per_group, uint64_t max_rows_per_group); -extern "C" SEXP _arrow_ExecPlan_Write(SEXP plan_sexp, SEXP final_node_sexp, SEXP metadata_sexp, SEXP file_write_options_sexp, SEXP filesystem_sexp, SEXP base_dir_sexp, SEXP partitioning_sexp, SEXP basename_template_sexp, SEXP existing_data_behavior_sexp, SEXP max_partitions_sexp, SEXP max_open_files_sexp, SEXP max_rows_per_file_sexp, SEXP min_rows_per_group_sexp, SEXP max_rows_per_group_sexp){ +void ExecPlan_Write(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::strings metadata, const std::shared_ptr& file_write_options, const std::shared_ptr& filesystem, std::string base_dir, const std::shared_ptr& partitioning, std::string basename_template, arrow::dataset::ExistingDataBehavior existing_data_behavior, int max_partitions, uint32_t max_open_files, uint64_t max_rows_per_file, uint64_t min_rows_per_group, uint64_t max_rows_per_group, bool on_old_windows); +extern "C" SEXP _arrow_ExecPlan_Write(SEXP plan_sexp, SEXP final_node_sexp, SEXP metadata_sexp, SEXP file_write_options_sexp, SEXP filesystem_sexp, SEXP base_dir_sexp, SEXP partitioning_sexp, SEXP basename_template_sexp, SEXP 
existing_data_behavior_sexp, SEXP max_partitions_sexp, SEXP max_open_files_sexp, SEXP max_rows_per_file_sexp, SEXP min_rows_per_group_sexp, SEXP max_rows_per_group_sexp, SEXP on_old_windows_sexp){ BEGIN_CPP11 arrow::r::Input&>::type plan(plan_sexp); arrow::r::Input&>::type final_node(final_node_sexp); @@ -947,12 +947,13 @@ BEGIN_CPP11 arrow::r::Input::type max_rows_per_file(max_rows_per_file_sexp); arrow::r::Input::type min_rows_per_group(min_rows_per_group_sexp); arrow::r::Input::type max_rows_per_group(max_rows_per_group_sexp); - ExecPlan_Write(plan, final_node, metadata, file_write_options, filesystem, base_dir, partitioning, basename_template, existing_data_behavior, max_partitions, max_open_files, max_rows_per_file, min_rows_per_group, max_rows_per_group); + arrow::r::Input::type on_old_windows(on_old_windows_sexp); + ExecPlan_Write(plan, final_node, metadata, file_write_options, filesystem, base_dir, partitioning, basename_template, existing_data_behavior, max_partitions, max_open_files, max_rows_per_file, min_rows_per_group, max_rows_per_group, on_old_windows); return R_NilValue; END_CPP11 } #else -extern "C" SEXP _arrow_ExecPlan_Write(SEXP plan_sexp, SEXP final_node_sexp, SEXP metadata_sexp, SEXP file_write_options_sexp, SEXP filesystem_sexp, SEXP base_dir_sexp, SEXP partitioning_sexp, SEXP basename_template_sexp, SEXP existing_data_behavior_sexp, SEXP max_partitions_sexp, SEXP max_open_files_sexp, SEXP max_rows_per_file_sexp, SEXP min_rows_per_group_sexp, SEXP max_rows_per_group_sexp){ +extern "C" SEXP _arrow_ExecPlan_Write(SEXP plan_sexp, SEXP final_node_sexp, SEXP metadata_sexp, SEXP file_write_options_sexp, SEXP filesystem_sexp, SEXP base_dir_sexp, SEXP partitioning_sexp, SEXP basename_template_sexp, SEXP existing_data_behavior_sexp, SEXP max_partitions_sexp, SEXP max_open_files_sexp, SEXP max_rows_per_file_sexp, SEXP min_rows_per_group_sexp, SEXP max_rows_per_group_sexp, SEXP on_old_windows_sexp){ Rf_error("Cannot call ExecPlan_Write(). 
See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); } #endif @@ -5267,7 +5268,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_ExecPlan_StopProducing", (DL_FUNC) &_arrow_ExecPlan_StopProducing, 1}, { "_arrow_ExecNode_output_schema", (DL_FUNC) &_arrow_ExecNode_output_schema, 1}, { "_arrow_ExecNode_Scan", (DL_FUNC) &_arrow_ExecNode_Scan, 4}, - { "_arrow_ExecPlan_Write", (DL_FUNC) &_arrow_ExecPlan_Write, 14}, + { "_arrow_ExecPlan_Write", (DL_FUNC) &_arrow_ExecPlan_Write, 15}, { "_arrow_ExecNode_Filter", (DL_FUNC) &_arrow_ExecNode_Filter, 2}, { "_arrow_ExecNode_Project", (DL_FUNC) &_arrow_ExecNode_Project, 3}, { "_arrow_ExecNode_Aggregate", (DL_FUNC) &_arrow_ExecNode_Aggregate, 3}, diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 1400c979d6a..3bddedb63ee 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -214,7 +214,7 @@ void ExecPlan_Write( const std::shared_ptr& partitioning, std::string basename_template, arrow::dataset::ExistingDataBehavior existing_data_behavior, int max_partitions, uint32_t max_open_files, uint64_t max_rows_per_file, uint64_t min_rows_per_group, - uint64_t max_rows_per_group) { + uint64_t max_rows_per_group, bool on_old_windows) { arrow::dataset::internal::Initialize(); // TODO(ARROW-16200): expose FileSystemDatasetWriteOptions in R @@ -237,8 +237,27 @@ void ExecPlan_Write( ds::WriteNodeOptions{std::move(opts), std::move(kv)}); StopIfNotOk(plan->Validate()); + +#if !defined(HAS_SAFE_CALL_INTO_R) StopIfNotOk(plan->StartProducing()); StopIfNotOk(plan->finished().status()); +#else + if (on_old_windows) { + StopIfNotOk(plan->StartProducing()); + StopIfNotOk(plan->finished().status()); + } else { + const auto& io_context = arrow::io::default_io_context(); + auto result = RunWithCapturedR([&]() { + return DeferNotOk(io_context.executor()->Submit([&]() -> arrow::Result { + RETURN_NOT_OK(plan->StartProducing()); + RETURN_NOT_OK(plan->finished().status()); + 
return true; + })); + }); + + StopIfNotOk(result.status()); + } +#endif } #endif diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 387b1d825a9..4fa65399440 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -125,9 +125,10 @@ test_that("user-defined functions work during multi-threaded execution", { example_df$row_num <- seq_len(nrow(example_df)) example_df$value <- example_df$value + match(example_df$part, letters) - tf <- tempfile() - on.exit(unlink(tf)) - write_dataset(example_df, tf, partitioning = "part") + tf_dataset <- tempfile() + tf_dest <- tempfile() + on.exit(unlink(c(tf_dataset, tf_dest))) + write_dataset(example_df, tf_dataset, partitioning = "part") times_32 <- arrow_scalar_function( int32(), float64(), @@ -136,11 +137,22 @@ test_that("user-defined functions work during multi-threaded execution", { register_scalar_function("times_32", times_32) - result <- open_dataset(tf) %>% + # check a regular collect() + result <- open_dataset(tf_dataset) %>% dplyr::mutate(fun_result = times_32(value)) %>% dplyr::collect() %>% - dplyr::arrange(row_num) %>% - tibble::as_tibble() + dplyr::arrange(row_num) expect_identical(result$fun_result, example_df$value * 32) + + # check a write_dataset() + open_dataset(tf_dataset) %>% + dplyr::mutate(fun_result = times_32(value)) %>% + write_dataset(tf_dest) + + result2 <- dplyr::collect(open_dataset(tf_dest)) %>% + dplyr::arrange(row_num) %>% + dplyr::collect() + + expect_identical(result2$fun_result, example_df$value * 32) }) From 0c139d17e026ba1d86cd7f3bb44bca20bdec7164 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 6 Jul 2022 20:51:54 -0300 Subject: [PATCH 37/92] better names for arrow_scalar_function() test --- r/tests/testthat/test-compute.R | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 4fa65399440..fb51e8dd006 100644 --- 
a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -63,19 +63,21 @@ test_that("arrow_advanced_scalar_function() works", { expect_snapshot_error(arrow_advanced_scalar_function(int32(), int32(), NULL)) }) -test_that("arrow_scalar_function() returns a base scalar function", { - base_fun <- arrow_scalar_function( - list(float64(), float64()), +test_that("arrow_scalar_function() returns an advanced scalar function", { + times_32_wrapper <- arrow_scalar_function( float64(), - function(x, y) { - x + y + float64(), + function(x) { + x * 32 } ) - expect_s3_class(base_fun, "arrow_advanced_scalar_function") + dummy_kernel_context <- list() + + expect_s3_class(times_32_wrapper, "arrow_advanced_scalar_function") expect_equal( - base_fun(list(), list(Scalar$create(2), Array$create(3))), - Array$create(5) + times_32_wrapper(dummy_kernel_context, list(Scalar$create(2))), + Array$create(2 * 32) ) }) From 52880b16bfb76be0cfe64ab0420ea3b79b9c2259 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 6 Jul 2022 21:04:04 -0300 Subject: [PATCH 38/92] change argument order for scalar function constructor --- r/R/compute.R | 18 ++++++++-------- r/man/register_scalar_function.Rd | 20 +++++++++--------- r/tests/testthat/_snaps/compute.md | 4 ++-- r/tests/testthat/test-compute.R | 34 ++++++++++++++---------------- 4 files changed, 37 insertions(+), 39 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 9a25fcd45f2..f54a9ec4b2b 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -348,9 +348,9 @@ cast_options <- function(safe = TRUE, ...) 
{ #' #' @examplesIf .Machine$sizeof.pointer >= 8 #' fun_wrapper <- arrow_scalar_function( -#' schema(x = float64(), y = float64(), z = float64()), -#' float64(), #' function(x, y, z) x + y + z +#' schema(x = float64(), y = float64(), z = float64()), +#' float64() #' ) #' register_scalar_function("example_add3", fun_wrapper) #' @@ -402,7 +402,7 @@ register_scalar_function <- function(name, scalar_function, registry_name = name #' @rdname register_scalar_function #' @export -arrow_scalar_function <- function(in_type, out_type, fun) { +arrow_scalar_function <- function(fun, in_type, out_type) { fun <- rlang::as_function(fun) base_fun <- function(context, args) { args <- lapply(args, as.vector) @@ -410,12 +410,12 @@ arrow_scalar_function <- function(in_type, out_type, fun) { as_arrow_array(result, type = context$output_type) } - arrow_advanced_scalar_function(in_type, out_type, base_fun) + arrow_advanced_scalar_function(base_fun, in_type, out_type) } #' @rdname register_scalar_function #' @export -arrow_advanced_scalar_function <- function(in_type, out_type, base_fun) { +arrow_advanced_scalar_function <- function(fun, in_type, out_type) { if (is.list(in_type)) { in_type <- lapply(in_type, as_scalar_function_in_type) } else { @@ -430,13 +430,13 @@ arrow_advanced_scalar_function <- function(in_type, out_type, base_fun) { out_type <- rep_len(out_type, length(in_type)) - base_fun <- rlang::as_function(base_fun) - if (length(formals(base_fun)) != 2) { - abort("`base_fun` must accept exactly two arguments") + fun <- rlang::as_function(fun) + if (length(formals(fun)) != 2) { + abort("`fun` must accept exactly two arguments") } structure( - base_fun, + fun, in_type = in_type, out_type = out_type, class = "arrow_advanced_scalar_function" diff --git a/r/man/register_scalar_function.Rd b/r/man/register_scalar_function.Rd index f9bb7654a25..8e0e74471ed 100644 --- a/r/man/register_scalar_function.Rd +++ b/r/man/register_scalar_function.Rd @@ -8,9 +8,9 @@ \usage{ 
register_scalar_function(name, scalar_function, registry_name = name) -arrow_scalar_function(in_type, out_type, fun) +arrow_scalar_function(fun, in_type, out_type) -arrow_advanced_scalar_function(in_type, out_type, base_fun) +arrow_advanced_scalar_function(fun, in_type, out_type) } \arguments{ \item{name}{The function name to be used in the dplyr bindings} @@ -21,6 +21,12 @@ or \code{\link[=arrow_advanced_scalar_function]{arrow_advanced_scalar_function() \item{registry_name}{The function name to be used in the Arrow C++ compute function registry. This may be different from \code{name}.} +\item{fun}{An R function or rlang-style lambda expression. This function +will be called with R objects as arguments and must return an object +that can be converted to an \link{Array} using \code{\link[=as_arrow_array]{as_arrow_array()}}. Function +authors must take care to return an array castable to the output data +type specified by \code{out_type}.} + \item{in_type}{A \link{DataType} of the input type or a \code{\link[=schema]{schema()}} for functions with more than one argument. This signature will be used to determine if this function is appropriate for a given set of arguments. @@ -31,12 +37,6 @@ If this function is appropriate for more than one signature, pass a a single argument (\code{types}), which is a \code{list()} of \link{DataType}s. If a function it must return a \link{DataType}.} -\item{fun}{An R function or rlang-style lambda expression. This function -will be called with R objects as arguments and must return an object -that can be converted to an \link{Array} using \code{\link[=as_arrow_array]{as_arrow_array()}}. Function -authors must take care to return an array castable to the output data -type specified by \code{out_type}.} - \item{base_fun}{An R function or rlang-style lambda expression. 
This function will be called with exactly two arguments: \code{kernel_context}, which is a \code{list()} of objects giving information about the @@ -61,9 +61,9 @@ lower-level function that operates directly on Arrow objects. \examples{ \dontshow{if (.Machine$sizeof.pointer >= 8) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} fun_wrapper <- arrow_scalar_function( - schema(x = float64(), y = float64(), z = float64()), - float64(), function(x, y, z) x + y + z + schema(x = float64(), y = float64(), z = float64()), + float64() ) register_scalar_function("example_add3", fun_wrapper) diff --git a/r/tests/testthat/_snaps/compute.md b/r/tests/testthat/_snaps/compute.md index 36f9238ec2d..792d29ea0ef 100644 --- a/r/tests/testthat/_snaps/compute.md +++ b/r/tests/testthat/_snaps/compute.md @@ -1,8 +1,8 @@ # arrow_advanced_scalar_function() works - `base_fun` must accept exactly two arguments + `fun` must accept exactly two arguments --- - Can't convert `base_fun`, NULL, to a function. + Can't convert `fun`, NULL, to a function. 
diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index fb51e8dd006..9fb0fb6f883 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -24,34 +24,34 @@ test_that("list_compute_functions() works", { test_that("arrow_advanced_scalar_function() works", { # check in/out type as schema/data type fun <- arrow_advanced_scalar_function( - schema(.y = int32()), int64(), - function(kernel_context, args) args[[1]] + function(kernel_context, args) args[[1]], + schema(.y = int32()), int64() ) expect_equal(attr(fun, "in_type")[[1]], schema(.y = int32())) expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as data type/data type fun <- arrow_advanced_scalar_function( - int32(), int64(), - function(kernel_context, args) args[[1]] + function(kernel_context, args) args[[1]], + int32(), int64() ) expect_equal(attr(fun, "in_type")[[1]][[1]], field("", int32())) expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as field/data type fun <- arrow_advanced_scalar_function( + function(kernel_context, args) args[[1]], field("a_name", int32()), - int64(), - function(kernel_context, args) args[[1]] + int64() ) expect_equal(attr(fun, "in_type")[[1]], schema(a_name = int32())) expect_equal(attr(fun, "out_type")[[1]](), int64()) # check in/out type as lists fun <- arrow_advanced_scalar_function( + function(kernel_context, args) args[[1]], list(int32(), int64()), - list(int64(), int32()), - function(kernel_context, args) args[[1]] + list(int64(), int32()) ) expect_equal(attr(fun, "in_type")[[1]][[1]], field("", int32())) @@ -59,17 +59,15 @@ test_that("arrow_advanced_scalar_function() works", { expect_equal(attr(fun, "out_type")[[1]](), int64()) expect_equal(attr(fun, "out_type")[[2]](), int32()) - expect_snapshot_error(arrow_advanced_scalar_function(int32(), int32(), identity)) - expect_snapshot_error(arrow_advanced_scalar_function(int32(), int32(), NULL)) + 
expect_snapshot_error(arrow_advanced_scalar_function(identity, int32(), int32())) + expect_snapshot_error(arrow_advanced_scalar_function(NULL, int32(), int32())) }) test_that("arrow_scalar_function() returns an advanced scalar function", { times_32_wrapper <- arrow_scalar_function( + function(x) x * 32, float64(), - float64(), - function(x) { - x * 32 - } + float64() ) dummy_kernel_context <- list() @@ -85,8 +83,8 @@ test_that("register_scalar_function() adds a compute function to the registry", skip_if_not_available("dataset") times_32 <- arrow_scalar_function( - int32(), float64(), - function(x) x * 32.0 + function(x) x * 32.0, + int32(), float64() ) register_scalar_function("times_32", times_32) @@ -133,8 +131,8 @@ test_that("user-defined functions work during multi-threaded execution", { write_dataset(example_df, tf_dataset, partitioning = "part") times_32 <- arrow_scalar_function( - int32(), float64(), - function(x) x * 32.0 + function(x) x * 32.0, + int32(), float64() ) register_scalar_function("times_32", times_32) From e8856b71b624e90f7847530cbd617ceb88f8921a Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 6 Jul 2022 21:07:21 -0300 Subject: [PATCH 39/92] register_scalar_function -> register_user_defined_function --- r/NAMESPACE | 2 +- r/R/compute.R | 10 +++++----- ...r_function.Rd => register_user_defined_function.Rd} | 10 +++++----- r/tests/testthat/test-compute.R | 6 +++--- 4 files changed, 14 insertions(+), 14 deletions(-) rename r/man/{register_scalar_function.Rd => register_user_defined_function.Rd} (92%) diff --git a/r/NAMESPACE b/r/NAMESPACE index e09ee33d459..ced43d01463 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -347,7 +347,7 @@ export(read_schema) export(read_tsv_arrow) export(record_batch) export(register_extension_type) -export(register_scalar_function) +export(register_user_defined_function) export(reregister_extension_type) export(s3_bucket) export(schema) diff --git a/r/R/compute.R b/r/R/compute.R index f54a9ec4b2b..aee7adc8295 
100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -352,7 +352,7 @@ cast_options <- function(safe = TRUE, ...) { #' schema(x = float64(), y = float64(), z = float64()), #' float64() #' ) -#' register_scalar_function("example_add3", fun_wrapper) +#' register_user_defined_function("example_add3", fun_wrapper) #' #' call_function( #' "example_add3", @@ -369,7 +369,7 @@ cast_options <- function(safe = TRUE, ...) { #' args[[1]] + args[[2]] + args[[3]] #' } #' ) -#' register_scalar_function("example_add3", advanced_fun_wrapper) +#' register_user_defined_function("example_add3", advanced_fun_wrapper) #' #' call_function( #' "example_add3", @@ -378,7 +378,7 @@ cast_options <- function(safe = TRUE, ...) { #' Array$create(3) #' ) #' -register_scalar_function <- function(name, scalar_function, registry_name = name) { +register_user_defined_function <- function(name, scalar_function, registry_name = name) { assert_that( is.string(name), is.string(registry_name), @@ -400,7 +400,7 @@ register_scalar_function <- function(name, scalar_function, registry_name = name invisible(NULL) } -#' @rdname register_scalar_function +#' @rdname register_user_defined_function #' @export arrow_scalar_function <- function(fun, in_type, out_type) { fun <- rlang::as_function(fun) @@ -413,7 +413,7 @@ arrow_scalar_function <- function(fun, in_type, out_type) { arrow_advanced_scalar_function(base_fun, in_type, out_type) } -#' @rdname register_scalar_function +#' @rdname register_user_defined_function #' @export arrow_advanced_scalar_function <- function(fun, in_type, out_type) { if (is.list(in_type)) { diff --git a/r/man/register_scalar_function.Rd b/r/man/register_user_defined_function.Rd similarity index 92% rename from r/man/register_scalar_function.Rd rename to r/man/register_user_defined_function.Rd index 8e0e74471ed..cc671808e80 100644 --- a/r/man/register_scalar_function.Rd +++ b/r/man/register_user_defined_function.Rd @@ -1,12 +1,12 @@ % Generated by roxygen2: do not edit by hand % Please edit 
documentation in R/compute.R -\name{register_scalar_function} -\alias{register_scalar_function} +\name{register_user_defined_function} +\alias{register_user_defined_function} \alias{arrow_scalar_function} \alias{arrow_advanced_scalar_function} \title{Register user-defined functions} \usage{ -register_scalar_function(name, scalar_function, registry_name = name) +register_user_defined_function(name, scalar_function, registry_name = name) arrow_scalar_function(fun, in_type, out_type) @@ -65,7 +65,7 @@ fun_wrapper <- arrow_scalar_function( schema(x = float64(), y = float64(), z = float64()), float64() ) -register_scalar_function("example_add3", fun_wrapper) +register_user_defined_function("example_add3", fun_wrapper) call_function( "example_add3", @@ -82,7 +82,7 @@ advanced_fun_wrapper <- arrow_advanced_scalar_function( args[[1]] + args[[2]] + args[[3]] } ) -register_scalar_function("example_add3", advanced_fun_wrapper) +register_user_defined_function("example_add3", advanced_fun_wrapper) call_function( "example_add3", diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 9fb0fb6f883..34ef029edb9 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -79,7 +79,7 @@ test_that("arrow_scalar_function() returns an advanced scalar function", { ) }) -test_that("register_scalar_function() adds a compute function to the registry", { +test_that("register_user_defined_function() adds a compute function to the registry", { skip_if_not_available("dataset") times_32 <- arrow_scalar_function( @@ -87,7 +87,7 @@ test_that("register_scalar_function() adds a compute function to the registry", int32(), float64() ) - register_scalar_function("times_32", times_32) + register_user_defined_function("times_32", times_32) expect_true("times_32" %in% names(arrow:::.cache$functions)) expect_true("times_32" %in% list_compute_functions()) @@ -135,7 +135,7 @@ test_that("user-defined functions work during multi-threaded execution", { 
int32(), float64() ) - register_scalar_function("times_32", times_32) + register_user_defined_function("times_32", times_32) # check a regular collect() result <- open_dataset(tf_dataset) %>% From 2665fdfe6109c84e9750289bf62c1a1cd3795a8f Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 6 Jul 2022 21:31:26 -0300 Subject: [PATCH 40/92] better argument names and inline comments --- r/R/compute.R | 42 +++++++++++++------------ r/man/register_user_defined_function.Rd | 30 +++++++++--------- r/tests/testthat/_snaps/compute.md | 4 +-- r/tests/testthat/test-compute.R | 8 ++--- 4 files changed, 43 insertions(+), 41 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index aee7adc8295..4ba5ef6eb2a 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -333,26 +333,26 @@ cast_options <- function(safe = TRUE, ...) { #' that can be converted to an [Array] using [as_arrow_array()]. Function #' authors must take care to return an array castable to the output data #' type specified by `out_type`. -#' @param base_fun An R function or rlang-style lambda expression. This +#' @param advanced_fun An R function or rlang-style lambda expression. This #' function will be called with exactly two arguments: `kernel_context`, #' which is a `list()` of objects giving information about the #' execution context and `args`, which is a list of [Array] or [Scalar] #' objects corresponding to the input arguments. #' #' @return -#' - `register_scalar_function()`: `NULL`, invisibly +#' - `register_user_defined_function()`: `NULL`, invisibly #' - `arrow_scalar_function()`: returns an object of class #' "arrow_advanced_scalar_function" that can be passed to -#' `register_scalar_function()`. +#' `register_user_defined_function()`. 
#' @export #' #' @examplesIf .Machine$sizeof.pointer >= 8 #' fun_wrapper <- arrow_scalar_function( -#' function(x, y, z) x + y + z +#' function(x, y, z) x + y + z, #' schema(x = float64(), y = float64(), z = float64()), #' float64() #' ) -#' register_user_defined_function("example_add3", fun_wrapper) +#' register_user_defined_function(fun_wrapper, "example_add3") #' #' call_function( #' "example_add3", @@ -363,13 +363,13 @@ cast_options <- function(safe = TRUE, ...) { #' #' # use arrow_advanced_scalar_function() for a lower-level interface #' advanced_fun_wrapper <- arrow_advanced_scalar_function( -#' schema(x = float64(), y = float64(), z = float64()), -#' float64(), #' function(context, args) { #' args[[1]] + args[[2]] + args[[3]] -#' } +#' }, +#' schema(x = float64(), y = float64(), z = float64()), +#' float64() #' ) -#' register_user_defined_function("example_add3", advanced_fun_wrapper) +#' register_user_defined_function(advanced_fun_wrapper, "example_add3") #' #' call_function( #' "example_add3", @@ -378,20 +378,19 @@ cast_options <- function(safe = TRUE, ...) { #' Array$create(3) #' ) #' -register_user_defined_function <- function(name, scalar_function, registry_name = name) { +register_user_defined_function <- function(scalar_function, name) { assert_that( is.string(name), - is.string(registry_name), inherits(scalar_function, "arrow_advanced_scalar_function") ) # register with Arrow C++ - RegisterScalarUDF(registry_name, scalar_function) + RegisterScalarUDF(name, scalar_function) # register with dplyr bindings register_binding( name, - function(...) build_expr(registry_name, ...) + function(...) build_expr(name, ...) 
) # recreate dplyr binding cache @@ -404,18 +403,21 @@ register_user_defined_function <- function(name, scalar_function, registry_name #' @export arrow_scalar_function <- function(fun, in_type, out_type) { fun <- rlang::as_function(fun) - base_fun <- function(context, args) { + + # create a small wrapper that converts Scalar/Array arguments to R vectors + # and converts the result back to an Array + advanced_fun <- function(context, args) { args <- lapply(args, as.vector) result <- do.call(fun, args) as_arrow_array(result, type = context$output_type) } - arrow_advanced_scalar_function(base_fun, in_type, out_type) + arrow_advanced_scalar_function(advanced_fun, in_type, out_type) } #' @rdname register_user_defined_function #' @export -arrow_advanced_scalar_function <- function(fun, in_type, out_type) { +arrow_advanced_scalar_function <- function(advanced_fun, in_type, out_type) { if (is.list(in_type)) { in_type <- lapply(in_type, as_scalar_function_in_type) } else { @@ -430,13 +432,13 @@ arrow_advanced_scalar_function <- function(fun, in_type, out_type) { out_type <- rep_len(out_type, length(in_type)) - fun <- rlang::as_function(fun) - if (length(formals(fun)) != 2) { - abort("`fun` must accept exactly two arguments") + advanced_fun <- rlang::as_function(advanced_fun) + if (length(formals(advanced_fun)) != 2) { + abort("`advanced_fun` must accept exactly two arguments") } structure( - fun, + advanced_fun, in_type = in_type, out_type = out_type, class = "arrow_advanced_scalar_function" diff --git a/r/man/register_user_defined_function.Rd b/r/man/register_user_defined_function.Rd index cc671808e80..d6ebb848f93 100644 --- a/r/man/register_user_defined_function.Rd +++ b/r/man/register_user_defined_function.Rd @@ -6,20 +6,17 @@ \alias{arrow_advanced_scalar_function} \title{Register user-defined functions} \usage{ -register_user_defined_function(name, scalar_function, registry_name = name) +register_user_defined_function(scalar_function, name) arrow_scalar_function(fun, 
in_type, out_type) -arrow_advanced_scalar_function(fun, in_type, out_type) +arrow_advanced_scalar_function(advanced_fun, in_type, out_type) } \arguments{ -\item{name}{The function name to be used in the dplyr bindings} - \item{scalar_function}{An object created with \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} or \code{\link[=arrow_advanced_scalar_function]{arrow_advanced_scalar_function()}}.} -\item{registry_name}{The function name to be used in the Arrow C++ -compute function registry. This may be different from \code{name}.} +\item{name}{The function name to be used in the dplyr bindings} \item{fun}{An R function or rlang-style lambda expression. This function will be called with R objects as arguments and must return an object @@ -37,18 +34,21 @@ If this function is appropriate for more than one signature, pass a a single argument (\code{types}), which is a \code{list()} of \link{DataType}s. If a function it must return a \link{DataType}.} -\item{base_fun}{An R function or rlang-style lambda expression. This +\item{advanced_fun}{An R function or rlang-style lambda expression. This function will be called with exactly two arguments: \code{kernel_context}, which is a \code{list()} of objects giving information about the execution context and \code{args}, which is a list of \link{Array} or \link{Scalar} objects corresponding to the input arguments.} + +\item{registry_name}{The function name to be used in the Arrow C++ +compute function registry. This may be different from \code{name}.} } \value{ \itemize{ -\item \code{register_scalar_function()}: \code{NULL}, invisibly +\item \code{register_user_defined_function()}: \code{NULL}, invisibly \item \code{arrow_scalar_function()}: returns an object of class "arrow_advanced_scalar_function" that can be passed to -\code{register_scalar_function()}. +\code{register_user_defined_function()}. } } \description{ @@ -61,11 +61,11 @@ lower-level function that operates directly on Arrow objects. 
\examples{ \dontshow{if (.Machine$sizeof.pointer >= 8) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} fun_wrapper <- arrow_scalar_function( - function(x, y, z) x + y + z + function(x, y, z) x + y + z, schema(x = float64(), y = float64(), z = float64()), float64() ) -register_user_defined_function("example_add3", fun_wrapper) +register_user_defined_function(fun_wrapper, "example_add3") call_function( "example_add3", @@ -76,13 +76,13 @@ call_function( # use arrow_advanced_scalar_function() for a lower-level interface advanced_fun_wrapper <- arrow_advanced_scalar_function( - schema(x = float64(), y = float64(), z = float64()), - float64(), function(context, args) { args[[1]] + args[[2]] + args[[3]] - } + }, + schema(x = float64(), y = float64(), z = float64()), + float64() ) -register_user_defined_function("example_add3", advanced_fun_wrapper) +register_user_defined_function(advanced_fun_wrapper, "example_add3") call_function( "example_add3", diff --git a/r/tests/testthat/_snaps/compute.md b/r/tests/testthat/_snaps/compute.md index 792d29ea0ef..1885067a873 100644 --- a/r/tests/testthat/_snaps/compute.md +++ b/r/tests/testthat/_snaps/compute.md @@ -1,8 +1,8 @@ # arrow_advanced_scalar_function() works - `fun` must accept exactly two arguments + `advanced_fun` must accept exactly two arguments --- - Can't convert `fun`, NULL, to a function. + Can't convert `advanced_fun`, NULL, to a function. 
diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 34ef029edb9..353e73f611d 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -82,12 +82,12 @@ test_that("arrow_scalar_function() returns an advanced scalar function", { test_that("register_user_defined_function() adds a compute function to the registry", { skip_if_not_available("dataset") - times_32 <- arrow_scalar_function( + times_32_wrapper <- arrow_scalar_function( function(x) x * 32.0, int32(), float64() ) - register_user_defined_function("times_32", times_32) + register_user_defined_function(times_32_wrapper, "times_32") expect_true("times_32" %in% names(arrow:::.cache$functions)) expect_true("times_32" %in% list_compute_functions()) @@ -130,12 +130,12 @@ test_that("user-defined functions work during multi-threaded execution", { on.exit(unlink(c(tf_dataset, tf_dest))) write_dataset(example_df, tf_dataset, partitioning = "part") - times_32 <- arrow_scalar_function( + times_32_wrapper <- arrow_scalar_function( function(x) x * 32.0, int32(), float64() ) - register_user_defined_function("times_32", times_32) + register_user_defined_function(times_32_wrapper, "times_32") # check a regular collect() result <- open_dataset(tf_dataset) %>% From 010ccf629c4d3fc9a1e89e59947e3837f7e75802 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 6 Jul 2022 21:50:09 -0300 Subject: [PATCH 41/92] fix pkgdown reference --- r/_pkgdown.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/_pkgdown.yml b/r/_pkgdown.yml index b04cab8195e..5241c93671b 100644 --- a/r/_pkgdown.yml +++ b/r/_pkgdown.yml @@ -219,7 +219,7 @@ reference: - match_arrow - value_counts - list_compute_functions - - register_scalar_function + - register_user_defined_function - title: Connections to other systems contents: - to_arrow From f732505a94f10b7e94a568dfd3b800382441ccd8 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 7 Jul 2022 09:00:21 -0300 Subject: 
[PATCH 42/92] remove unused doc entry --- r/R/compute.R | 2 -- r/man/register_user_defined_function.Rd | 3 --- 2 files changed, 5 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 4ba5ef6eb2a..a58ef0a73ff 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -318,8 +318,6 @@ cast_options <- function(safe = TRUE, ...) { #' @param name The function name to be used in the dplyr bindings #' @param scalar_function An object created with [arrow_scalar_function()] #' or [arrow_advanced_scalar_function()]. -#' @param registry_name The function name to be used in the Arrow C++ -#' compute function registry. This may be different from `name`. #' @param in_type A [DataType] of the input type or a [schema()] #' for functions with more than one argument. This signature will be used #' to determine if this function is appropriate for a given set of arguments. diff --git a/r/man/register_user_defined_function.Rd b/r/man/register_user_defined_function.Rd index d6ebb848f93..e462381d5e4 100644 --- a/r/man/register_user_defined_function.Rd +++ b/r/man/register_user_defined_function.Rd @@ -39,9 +39,6 @@ function will be called with exactly two arguments: \code{kernel_context}, which is a \code{list()} of objects giving information about the execution context and \code{args}, which is a list of \link{Array} or \link{Scalar} objects corresponding to the input arguments.} - -\item{registry_name}{The function name to be used in the Arrow C++ -compute function registry. 
This may be different from \code{name}.} } \value{ \itemize{ From 4c01654d1c4416e87234059ead5d270863459411 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 7 Jul 2022 09:54:45 -0300 Subject: [PATCH 43/92] simplify detection of when we can and can't use SafeCallIntoR() --- r/R/arrowExports.R | 16 ++++++++-------- r/R/feather.R | 4 ++-- r/R/query-engine.R | 29 +++++++++-------------------- r/src/arrowExports.cpp | 38 +++++++++++++++++--------------------- r/src/compute-exec.cpp | 37 ++++++++++++++----------------------- r/src/csv.cpp | 8 ++++---- r/src/feather.cpp | 25 ++++++++----------------- r/src/safe-call-into-r.h | 21 ++++++++++++++------- 8 files changed, 76 insertions(+), 102 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 9e2adb50ca2..4b4d90cf34a 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -408,8 +408,8 @@ ExecPlan_run <- function(plan, final_node, sort_options, metadata, head) { .Call(`_arrow_ExecPlan_run`, plan, final_node, sort_options, metadata, head) } -ExecPlan_read_table <- function(plan, final_node, sort_options, metadata, head, on_old_windows) { - .Call(`_arrow_ExecPlan_read_table`, plan, final_node, sort_options, metadata, head, on_old_windows) +ExecPlan_read_table <- function(plan, final_node, sort_options, metadata, head) { + .Call(`_arrow_ExecPlan_read_table`, plan, final_node, sort_options, metadata, head) } ExecPlan_StopProducing <- function(plan) { @@ -424,8 +424,8 @@ ExecNode_Scan <- function(plan, dataset, filter, materialized_field_names) { .Call(`_arrow_ExecNode_Scan`, plan, dataset, filter, materialized_field_names) } -ExecPlan_Write <- function(plan, final_node, metadata, file_write_options, filesystem, base_dir, partitioning, basename_template, existing_data_behavior, max_partitions, max_open_files, max_rows_per_file, min_rows_per_group, max_rows_per_group, on_old_windows) { - invisible(.Call(`_arrow_ExecPlan_Write`, plan, final_node, metadata, file_write_options, filesystem, base_dir, 
partitioning, basename_template, existing_data_behavior, max_partitions, max_open_files, max_rows_per_file, min_rows_per_group, max_rows_per_group, on_old_windows)) +ExecPlan_Write <- function(plan, final_node, metadata, file_write_options, filesystem, base_dir, partitioning, basename_template, existing_data_behavior, max_partitions, max_open_files, max_rows_per_file, min_rows_per_group, max_rows_per_group) { + invisible(.Call(`_arrow_ExecPlan_Write`, plan, final_node, metadata, file_write_options, filesystem, base_dir, partitioning, basename_template, existing_data_behavior, max_partitions, max_open_files, max_rows_per_file, min_rows_per_group, max_rows_per_group)) } ExecNode_Filter <- function(input, filter) { @@ -1116,12 +1116,12 @@ ipc___feather___Reader__version <- function(reader) { .Call(`_arrow_ipc___feather___Reader__version`, reader) } -ipc___feather___Reader__Read <- function(reader, columns, on_old_windows) { - .Call(`_arrow_ipc___feather___Reader__Read`, reader, columns, on_old_windows) +ipc___feather___Reader__Read <- function(reader, columns) { + .Call(`_arrow_ipc___feather___Reader__Read`, reader, columns) } -ipc___feather___Reader__Open <- function(stream, on_old_windows) { - .Call(`_arrow_ipc___feather___Reader__Open`, stream, on_old_windows) +ipc___feather___Reader__Open <- function(stream) { + .Call(`_arrow_ipc___feather___Reader__Open`, stream) } ipc___feather___Reader__schema <- function(reader) { diff --git a/r/R/feather.R b/r/R/feather.R index 02871396fa6..73eb5d8b6fd 100644 --- a/r/R/feather.R +++ b/r/R/feather.R @@ -190,7 +190,7 @@ FeatherReader <- R6Class("FeatherReader", inherit = ArrowObject, public = list( Read = function(columns) { - ipc___feather___Reader__Read(self, columns, on_old_windows()) + ipc___feather___Reader__Read(self, columns) }, print = function(...) 
{ cat("FeatherReader:\n") @@ -211,5 +211,5 @@ names.FeatherReader <- function(x) x$column_names FeatherReader$create <- function(file) { assert_is(file, "RandomAccessFile") - ipc___feather___Reader__Open(file, on_old_windows()) + ipc___feather___Reader__Open(file) } diff --git a/r/R/query-engine.R b/r/R/query-engine.R index ece212181f7..16dbb1c7772 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -209,24 +209,14 @@ ExecPlan <- R6Class("ExecPlan", sorting$orders <- as.integer(sorting$orders) } - if (as_table) { - out <- ExecPlan_read_table( - self, - node, - sorting, - prepare_key_value_metadata(node$final_metadata()), - select_k, - on_old_windows() - ) - } else { - out <- ExecPlan_run( - self, - node, - sorting, - prepare_key_value_metadata(node$final_metadata()), - select_k - ) - } + exec_fun <- if (as_table) ExecPlan_read_table else ExecPlan_run + out <- exec_fun( + self, + node, + sorting, + prepare_key_value_metadata(node$final_metadata()), + select_k + ) if (!has_sorting) { # Since ExecPlans don't scan in deterministic order, head/tail are both @@ -273,8 +263,7 @@ ExecPlan <- R6Class("ExecPlan", self, node, prepare_key_value_metadata(node$final_metadata()), - ..., - on_old_windows = on_old_windows() + ... 
) }, Stop = function() ExecPlan_StopProducing(self) diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 9109d817738..e82af9be48b 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -881,16 +881,15 @@ BEGIN_CPP11 END_CPP11 } // compute-exec.cpp -std::shared_ptr ExecPlan_read_table(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::list sort_options, cpp11::strings metadata, int64_t head, bool on_old_windows); -extern "C" SEXP _arrow_ExecPlan_read_table(SEXP plan_sexp, SEXP final_node_sexp, SEXP sort_options_sexp, SEXP metadata_sexp, SEXP head_sexp, SEXP on_old_windows_sexp){ +std::shared_ptr ExecPlan_read_table(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::list sort_options, cpp11::strings metadata, int64_t head); +extern "C" SEXP _arrow_ExecPlan_read_table(SEXP plan_sexp, SEXP final_node_sexp, SEXP sort_options_sexp, SEXP metadata_sexp, SEXP head_sexp){ BEGIN_CPP11 arrow::r::Input&>::type plan(plan_sexp); arrow::r::Input&>::type final_node(final_node_sexp); arrow::r::Input::type sort_options(sort_options_sexp); arrow::r::Input::type metadata(metadata_sexp); arrow::r::Input::type head(head_sexp); - arrow::r::Input::type on_old_windows(on_old_windows_sexp); - return cpp11::as_sexp(ExecPlan_read_table(plan, final_node, sort_options, metadata, head, on_old_windows)); + return cpp11::as_sexp(ExecPlan_read_table(plan, final_node, sort_options, metadata, head)); END_CPP11 } // compute-exec.cpp @@ -930,8 +929,8 @@ extern "C" SEXP _arrow_ExecNode_Scan(SEXP plan_sexp, SEXP dataset_sexp, SEXP fil // compute-exec.cpp #if defined(ARROW_R_WITH_DATASET) -void ExecPlan_Write(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::strings metadata, const std::shared_ptr& file_write_options, const std::shared_ptr& filesystem, std::string base_dir, const std::shared_ptr& partitioning, std::string basename_template, arrow::dataset::ExistingDataBehavior existing_data_behavior, int max_partitions, 
uint32_t max_open_files, uint64_t max_rows_per_file, uint64_t min_rows_per_group, uint64_t max_rows_per_group, bool on_old_windows); -extern "C" SEXP _arrow_ExecPlan_Write(SEXP plan_sexp, SEXP final_node_sexp, SEXP metadata_sexp, SEXP file_write_options_sexp, SEXP filesystem_sexp, SEXP base_dir_sexp, SEXP partitioning_sexp, SEXP basename_template_sexp, SEXP existing_data_behavior_sexp, SEXP max_partitions_sexp, SEXP max_open_files_sexp, SEXP max_rows_per_file_sexp, SEXP min_rows_per_group_sexp, SEXP max_rows_per_group_sexp, SEXP on_old_windows_sexp){ +void ExecPlan_Write(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::strings metadata, const std::shared_ptr& file_write_options, const std::shared_ptr& filesystem, std::string base_dir, const std::shared_ptr& partitioning, std::string basename_template, arrow::dataset::ExistingDataBehavior existing_data_behavior, int max_partitions, uint32_t max_open_files, uint64_t max_rows_per_file, uint64_t min_rows_per_group, uint64_t max_rows_per_group); +extern "C" SEXP _arrow_ExecPlan_Write(SEXP plan_sexp, SEXP final_node_sexp, SEXP metadata_sexp, SEXP file_write_options_sexp, SEXP filesystem_sexp, SEXP base_dir_sexp, SEXP partitioning_sexp, SEXP basename_template_sexp, SEXP existing_data_behavior_sexp, SEXP max_partitions_sexp, SEXP max_open_files_sexp, SEXP max_rows_per_file_sexp, SEXP min_rows_per_group_sexp, SEXP max_rows_per_group_sexp){ BEGIN_CPP11 arrow::r::Input&>::type plan(plan_sexp); arrow::r::Input&>::type final_node(final_node_sexp); @@ -947,13 +946,12 @@ BEGIN_CPP11 arrow::r::Input::type max_rows_per_file(max_rows_per_file_sexp); arrow::r::Input::type min_rows_per_group(min_rows_per_group_sexp); arrow::r::Input::type max_rows_per_group(max_rows_per_group_sexp); - arrow::r::Input::type on_old_windows(on_old_windows_sexp); - ExecPlan_Write(plan, final_node, metadata, file_write_options, filesystem, base_dir, partitioning, basename_template, existing_data_behavior, max_partitions, 
max_open_files, max_rows_per_file, min_rows_per_group, max_rows_per_group, on_old_windows); + ExecPlan_Write(plan, final_node, metadata, file_write_options, filesystem, base_dir, partitioning, basename_template, existing_data_behavior, max_partitions, max_open_files, max_rows_per_file, min_rows_per_group, max_rows_per_group); return R_NilValue; END_CPP11 } #else -extern "C" SEXP _arrow_ExecPlan_Write(SEXP plan_sexp, SEXP final_node_sexp, SEXP metadata_sexp, SEXP file_write_options_sexp, SEXP filesystem_sexp, SEXP base_dir_sexp, SEXP partitioning_sexp, SEXP basename_template_sexp, SEXP existing_data_behavior_sexp, SEXP max_partitions_sexp, SEXP max_open_files_sexp, SEXP max_rows_per_file_sexp, SEXP min_rows_per_group_sexp, SEXP max_rows_per_group_sexp, SEXP on_old_windows_sexp){ +extern "C" SEXP _arrow_ExecPlan_Write(SEXP plan_sexp, SEXP final_node_sexp, SEXP metadata_sexp, SEXP file_write_options_sexp, SEXP filesystem_sexp, SEXP base_dir_sexp, SEXP partitioning_sexp, SEXP basename_template_sexp, SEXP existing_data_behavior_sexp, SEXP max_partitions_sexp, SEXP max_open_files_sexp, SEXP max_rows_per_file_sexp, SEXP min_rows_per_group_sexp, SEXP max_rows_per_group_sexp){ Rf_error("Cannot call ExecPlan_Write(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. 
"); } #endif @@ -2812,22 +2810,20 @@ BEGIN_CPP11 END_CPP11 } // feather.cpp -std::shared_ptr ipc___feather___Reader__Read(const std::shared_ptr& reader, cpp11::sexp columns, bool on_old_windows); -extern "C" SEXP _arrow_ipc___feather___Reader__Read(SEXP reader_sexp, SEXP columns_sexp, SEXP on_old_windows_sexp){ +std::shared_ptr ipc___feather___Reader__Read(const std::shared_ptr& reader, cpp11::sexp columns); +extern "C" SEXP _arrow_ipc___feather___Reader__Read(SEXP reader_sexp, SEXP columns_sexp){ BEGIN_CPP11 arrow::r::Input&>::type reader(reader_sexp); arrow::r::Input::type columns(columns_sexp); - arrow::r::Input::type on_old_windows(on_old_windows_sexp); - return cpp11::as_sexp(ipc___feather___Reader__Read(reader, columns, on_old_windows)); + return cpp11::as_sexp(ipc___feather___Reader__Read(reader, columns)); END_CPP11 } // feather.cpp -std::shared_ptr ipc___feather___Reader__Open(const std::shared_ptr& stream, bool on_old_windows); -extern "C" SEXP _arrow_ipc___feather___Reader__Open(SEXP stream_sexp, SEXP on_old_windows_sexp){ +std::shared_ptr ipc___feather___Reader__Open(const std::shared_ptr& stream); +extern "C" SEXP _arrow_ipc___feather___Reader__Open(SEXP stream_sexp){ BEGIN_CPP11 arrow::r::Input&>::type stream(stream_sexp); - arrow::r::Input::type on_old_windows(on_old_windows_sexp); - return cpp11::as_sexp(ipc___feather___Reader__Open(stream, on_old_windows)); + return cpp11::as_sexp(ipc___feather___Reader__Open(stream)); END_CPP11 } // feather.cpp @@ -5264,11 +5260,11 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_io___CompressedInputStream__Make", (DL_FUNC) &_arrow_io___CompressedInputStream__Make, 2}, { "_arrow_ExecPlan_create", (DL_FUNC) &_arrow_ExecPlan_create, 1}, { "_arrow_ExecPlan_run", (DL_FUNC) &_arrow_ExecPlan_run, 5}, - { "_arrow_ExecPlan_read_table", (DL_FUNC) &_arrow_ExecPlan_read_table, 6}, + { "_arrow_ExecPlan_read_table", (DL_FUNC) &_arrow_ExecPlan_read_table, 5}, { "_arrow_ExecPlan_StopProducing", (DL_FUNC) 
&_arrow_ExecPlan_StopProducing, 1}, { "_arrow_ExecNode_output_schema", (DL_FUNC) &_arrow_ExecNode_output_schema, 1}, { "_arrow_ExecNode_Scan", (DL_FUNC) &_arrow_ExecNode_Scan, 4}, - { "_arrow_ExecPlan_Write", (DL_FUNC) &_arrow_ExecPlan_Write, 15}, + { "_arrow_ExecPlan_Write", (DL_FUNC) &_arrow_ExecPlan_Write, 14}, { "_arrow_ExecNode_Filter", (DL_FUNC) &_arrow_ExecNode_Filter, 2}, { "_arrow_ExecNode_Project", (DL_FUNC) &_arrow_ExecNode_Project, 3}, { "_arrow_ExecNode_Aggregate", (DL_FUNC) &_arrow_ExecNode_Aggregate, 3}, @@ -5441,8 +5437,8 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_arrow__UnregisterRExtensionType", (DL_FUNC) &_arrow_arrow__UnregisterRExtensionType, 1}, { "_arrow_ipc___WriteFeather__Table", (DL_FUNC) &_arrow_ipc___WriteFeather__Table, 6}, { "_arrow_ipc___feather___Reader__version", (DL_FUNC) &_arrow_ipc___feather___Reader__version, 1}, - { "_arrow_ipc___feather___Reader__Read", (DL_FUNC) &_arrow_ipc___feather___Reader__Read, 3}, - { "_arrow_ipc___feather___Reader__Open", (DL_FUNC) &_arrow_ipc___feather___Reader__Open, 2}, + { "_arrow_ipc___feather___Reader__Read", (DL_FUNC) &_arrow_ipc___feather___Reader__Read, 2}, + { "_arrow_ipc___feather___Reader__Open", (DL_FUNC) &_arrow_ipc___feather___Reader__Open, 1}, { "_arrow_ipc___feather___Reader__schema", (DL_FUNC) &_arrow_ipc___feather___Reader__schema, 1}, { "_arrow_Field__initialize", (DL_FUNC) &_arrow_Field__initialize, 3}, { "_arrow_Field__ToString", (DL_FUNC) &_arrow_Field__ToString, 1}, diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 3bddedb63ee..15782c06876 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -133,26 +133,22 @@ std::shared_ptr ExecPlan_run( std::shared_ptr ExecPlan_read_table( const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::list sort_options, - cpp11::strings metadata, int64_t head = -1, bool on_old_windows = false) { + cpp11::strings metadata, int64_t head = -1) { auto prepared_plan = ExecPlan_prepare(plan, 
final_node, sort_options, metadata, head); -#if !defined(HAS_SAFE_CALL_INTO_R) - StopIfNotOk(prepared_plan.first->StartProducing()); - return ValueOrStop(prepared_plan.second->ToTable()); -#else - if (on_old_windows) { + if (!CanSafeCallIntoR()) { StopIfNotOk(prepared_plan.first->StartProducing()); return ValueOrStop(prepared_plan.second->ToTable()); - } + } else { + const auto& io_context = arrow::io::default_io_context(); + auto result = RunWithCapturedR>([&]() { + return DeferNotOk(io_context.executor()->Submit([&]() { + StopIfNotOk(prepared_plan.first->StartProducing()); + return prepared_plan.second->ToTable(); + })); + }); - const auto& io_context = arrow::io::default_io_context(); - auto result = RunWithCapturedR>([&]() { - return DeferNotOk(io_context.executor()->Submit([&]() { - StopIfNotOk(prepared_plan.first->StartProducing()); - return prepared_plan.second->ToTable(); - })); - }); - return ValueOrStop(result); -#endif + return ValueOrStop(result); + } } // [[arrow::export]] @@ -214,7 +210,7 @@ void ExecPlan_Write( const std::shared_ptr& partitioning, std::string basename_template, arrow::dataset::ExistingDataBehavior existing_data_behavior, int max_partitions, uint32_t max_open_files, uint64_t max_rows_per_file, uint64_t min_rows_per_group, - uint64_t max_rows_per_group, bool on_old_windows) { + uint64_t max_rows_per_group) { arrow::dataset::internal::Initialize(); // TODO(ARROW-16200): expose FileSystemDatasetWriteOptions in R @@ -238,11 +234,7 @@ void ExecPlan_Write( StopIfNotOk(plan->Validate()); -#if !defined(HAS_SAFE_CALL_INTO_R) - StopIfNotOk(plan->StartProducing()); - StopIfNotOk(plan->finished().status()); -#else - if (on_old_windows) { + if (!CanSafeCallIntoR()) { StopIfNotOk(plan->StartProducing()); StopIfNotOk(plan->finished().status()); } else { @@ -257,7 +249,6 @@ void ExecPlan_Write( StopIfNotOk(result.status()); } -#endif } #endif diff --git a/r/src/csv.cpp b/r/src/csv.cpp index d031cc87cac..0a9cbb72b46 100644 --- a/r/src/csv.cpp +++ 
b/r/src/csv.cpp @@ -162,16 +162,16 @@ std::shared_ptr csv___TableReader__Make( // [[arrow::export]] std::shared_ptr csv___TableReader__Read( const std::shared_ptr& table_reader) { -#if !defined(HAS_SAFE_CALL_INTO_R) - return ValueOrStop(table_reader->Read()); -#else + if (!CanSafeCallIntoR()) { + return ValueOrStop(table_reader->Read()); + } + const auto& io_context = arrow::io::default_io_context(); auto result = RunWithCapturedR>([&]() { return DeferNotOk( io_context.executor()->Submit([&]() { return table_reader->Read(); })); }); return ValueOrStop(result); -#endif } // [[arrow::export]] diff --git a/r/src/feather.cpp b/r/src/feather.cpp index debabe49689..4d264beb07b 100644 --- a/r/src/feather.cpp +++ b/r/src/feather.cpp @@ -49,8 +49,7 @@ int ipc___feather___Reader__version( // [[arrow::export]] std::shared_ptr ipc___feather___Reader__Read( - const std::shared_ptr& reader, cpp11::sexp columns, - bool on_old_windows) { + const std::shared_ptr& reader, cpp11::sexp columns) { bool use_names = columns != R_NilValue; std::vector names; if (use_names) { @@ -77,37 +76,29 @@ std::shared_ptr ipc___feather___Reader__Read( } }; -#if !defined(HAS_SAFE_CALL_INTO_R) - return ValueOrStop(read_table()); -#else - if (!on_old_windows) { + if (!CanSafeCallIntoR()) { + return ValueOrStop(read_table()); + } else { const auto& io_context = arrow::io::default_io_context(); auto result = RunWithCapturedR>( [&]() { return DeferNotOk(io_context.executor()->Submit(read_table)); }); return ValueOrStop(result); - } else { - return ValueOrStop(read_table()); } -#endif } // [[arrow::export]] std::shared_ptr ipc___feather___Reader__Open( - const std::shared_ptr& stream, bool on_old_windows) { -#if !defined(HAS_SAFE_CALL_INTO_R) - return ValueOrStop(arrow::ipc::feather::Reader::Open(stream)); -#else - if (!on_old_windows) { + const std::shared_ptr& stream) { + if (!CanSafeCallIntoR()) { + return ValueOrStop(arrow::ipc::feather::Reader::Open(stream)); + } else { const auto& io_context = 
arrow::io::default_io_context(); auto result = RunWithCapturedR>([&]() { return DeferNotOk(io_context.executor()->Submit( [&]() { return arrow::ipc::feather::Reader::Open(stream); })); }); return ValueOrStop(result); - } else { - return ValueOrStop(arrow::ipc::feather::Reader::Open(stream)); } -#endif } // [[arrow::export]] diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index 0555628d7d5..7e1f2c83599 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -27,11 +27,17 @@ #include // Unwind protection was added in R 3.5 and some calls here use it -// and crash R in older versions (ARROW-16201). We use this define -// to make sure we don't crash on R 3.4 and lower. +// and crash R in older versions (ARROW-16201). Crashes also occur +// on 32-bit R builds on R 3.6 and lower. +static inline bool CanSafeCallIntoR() { #if defined(HAS_UNWIND_PROTECT) -#define HAS_SAFE_CALL_INTO_R + cpp11::function on_old_windows_fun = cpp11::package("arrow")["on_old_windows"]; + bool on_old_windows = on_old_windows_fun(); + return !on_old_windows; +#else + return false; #endif +} // The MainRThread class keeps track of the thread on which it is safe // to call the R API to facilitate its safe use (or erroring @@ -139,9 +145,11 @@ static inline arrow::Status SafeCallIntoRVoid(std::function fun) { template arrow::Result RunWithCapturedR(std::function()> make_arrow_call) { -#if !defined(HAS_SAFE_CALL_INTO_R) - return arrow::Status::NotImplemented("RunWithCapturedR() without UnwindProtect"); -#else + if (!CanSafeCallIntoR()) { + return arrow::Status::NotImplemented( + "RunWithCapturedR() without UnwindProtect or on 32-bit Windows + R <= 3.6"); + } + if (GetMainRThread().Executor() != nullptr) { return arrow::Status::AlreadyExists("Attempt to use more than one R Executor()"); } @@ -158,7 +166,6 @@ arrow::Result RunWithCapturedR(std::function()> make_arrow_c GetMainRThread().ClearError(); return result; -#endif } #endif From 
21a932a5bf1fd651ecee10b1712d5b84a533c277 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 7 Jul 2022 10:30:21 -0300 Subject: [PATCH 44/92] abstract and document RunWithCapturedR usage --- r/src/compute-exec.cpp | 40 ++++++++++++---------------------------- r/src/csv.cpp | 11 ++--------- r/src/feather.cpp | 26 ++++++-------------------- r/src/safe-call-into-r.h | 25 +++++++++++++++++++++++++ 4 files changed, 45 insertions(+), 57 deletions(-) diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 15782c06876..4821361357d 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -22,7 +22,6 @@ #include #include #include -#include #include #include #include @@ -135,20 +134,14 @@ std::shared_ptr ExecPlan_read_table( const std::shared_ptr& final_node, cpp11::list sort_options, cpp11::strings metadata, int64_t head = -1) { auto prepared_plan = ExecPlan_prepare(plan, final_node, sort_options, metadata, head); - if (!CanSafeCallIntoR()) { - StopIfNotOk(prepared_plan.first->StartProducing()); - return ValueOrStop(prepared_plan.second->ToTable()); - } else { - const auto& io_context = arrow::io::default_io_context(); - auto result = RunWithCapturedR>([&]() { - return DeferNotOk(io_context.executor()->Submit([&]() { - StopIfNotOk(prepared_plan.first->StartProducing()); + + auto result = RunWithCapturedRIfPossible>( + [&]() -> arrow::Result> { + ARROW_RETURN_NOT_OK(prepared_plan.first->StartProducing()); return prepared_plan.second->ToTable(); - })); - }); + }); - return ValueOrStop(result); - } + return ValueOrStop(result); } // [[arrow::export]] @@ -233,22 +226,13 @@ void ExecPlan_Write( ds::WriteNodeOptions{std::move(opts), std::move(kv)}); StopIfNotOk(plan->Validate()); + auto result = RunWithCapturedRIfPossible([&]() -> arrow::Result { + RETURN_NOT_OK(plan->StartProducing()); + RETURN_NOT_OK(plan->finished().status()); + return true; + }); - if (!CanSafeCallIntoR()) { - StopIfNotOk(plan->StartProducing()); - 
StopIfNotOk(plan->finished().status()); - } else { - const auto& io_context = arrow::io::default_io_context(); - auto result = RunWithCapturedR([&]() { - return DeferNotOk(io_context.executor()->Submit([&]() -> arrow::Result { - RETURN_NOT_OK(plan->StartProducing()); - RETURN_NOT_OK(plan->finished().status()); - return true; - })); - }); - - StopIfNotOk(result.status()); - } + StopIfNotOk(result.status()); } #endif diff --git a/r/src/csv.cpp b/r/src/csv.cpp index 0a9cbb72b46..7ce55feb5fe 100644 --- a/r/src/csv.cpp +++ b/r/src/csv.cpp @@ -162,15 +162,8 @@ std::shared_ptr csv___TableReader__Make( // [[arrow::export]] std::shared_ptr csv___TableReader__Read( const std::shared_ptr& table_reader) { - if (!CanSafeCallIntoR()) { - return ValueOrStop(table_reader->Read()); - } - - const auto& io_context = arrow::io::default_io_context(); - auto result = RunWithCapturedR>([&]() { - return DeferNotOk( - io_context.executor()->Submit([&]() { return table_reader->Read(); })); - }); + auto result = RunWithCapturedRIfPossible>( + [&]() { return table_reader->Read(); }); return ValueOrStop(result); } diff --git a/r/src/feather.cpp b/r/src/feather.cpp index 4d264beb07b..cf68faef1b5 100644 --- a/r/src/feather.cpp +++ b/r/src/feather.cpp @@ -60,7 +60,7 @@ std::shared_ptr ipc___feather___Reader__Read( } } - auto read_table = [&]() { + auto result = RunWithCapturedRIfPossible>([&]() { std::shared_ptr table; arrow::Status read_result; if (use_names) { @@ -74,31 +74,17 @@ std::shared_ptr ipc___feather___Reader__Read( } else { return arrow::Result>(read_result); } - }; + }); - if (!CanSafeCallIntoR()) { - return ValueOrStop(read_table()); - } else { - const auto& io_context = arrow::io::default_io_context(); - auto result = RunWithCapturedR>( - [&]() { return DeferNotOk(io_context.executor()->Submit(read_table)); }); - return ValueOrStop(result); - } + return ValueOrStop(result); } // [[arrow::export]] std::shared_ptr ipc___feather___Reader__Open( const std::shared_ptr& stream) { - if 
(!CanSafeCallIntoR()) { - return ValueOrStop(arrow::ipc::feather::Reader::Open(stream)); - } else { - const auto& io_context = arrow::io::default_io_context(); - auto result = RunWithCapturedR>([&]() { - return DeferNotOk(io_context.executor()->Submit( - [&]() { return arrow::ipc::feather::Reader::Open(stream); })); - }); - return ValueOrStop(result); - } + auto result = RunWithCapturedRIfPossible>( + [&]() { return arrow::ipc::feather::Reader::Open(stream); }); + return ValueOrStop(result); } // [[arrow::export]] diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index 7e1f2c83599..775250ef573 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -20,6 +20,7 @@ #include "./arrow_types.h" +#include #include #include @@ -143,6 +144,9 @@ static inline arrow::Status SafeCallIntoRVoid(std::function fun) { return future.status(); } +// Performs an Arrow call (e.g., run an exec plan) in such a way that background threads +// can use SafeCallIntoR(). This version is useful for Arrow calls that already +// return a Future<>. template arrow::Result RunWithCapturedR(std::function()> make_arrow_call) { if (!CanSafeCallIntoR()) { @@ -168,4 +172,25 @@ arrow::Result RunWithCapturedR(std::function()> make_arrow_c return result; } +// Performs an Arrow call (e.g., run an exec plan) in such a way that background threads +// can use SafeCallIntoR(). This version is useful for Arrow calls that do not already +// return a Future<>(). If it is not possible to use RunWithCapturedR() (i.e., +// CanSafeCallIntoR() returns false), this will run make_arrow_call on the main +// R thread (which will cause background threads that try to SafeCallIntoR() to +// error). +template +arrow::Result RunWithCapturedRIfPossible( + std::function()> make_arrow_call) { + if (CanSafeCallIntoR()) { + // Note that the use of the io_context here is arbitrary (i.e. we could use + // any construct that launches a background thread). 
+ const auto& io_context = arrow::io::default_io_context(); + return RunWithCapturedR([&]() { + return DeferNotOk(io_context.executor()->Submit(std::move(make_arrow_call))); + }); + } else { + return make_arrow_call(); + } +} + #endif From 0c1b8cf6fef4d9fad7a73b51bef0fe1e38d503cd Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 7 Jul 2022 10:50:32 -0300 Subject: [PATCH 45/92] better failure mode for calling user-defined functions when we can't execute them --- r/src/compute.cpp | 178 ++++++++++++----------- r/src/extension-impl.cpp | 25 ++-- r/src/io.cpp | 68 +++++---- r/src/safe-call-into-r.h | 12 +- r/tests/testthat/test-compute.R | 23 +++ r/tests/testthat/test-safe-call-into-r.R | 4 +- 6 files changed, 174 insertions(+), 136 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 59326e707cc..f29421b0b2a 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -591,24 +591,26 @@ class RScalarUDFOutputTypeResolver : public arrow::compute::OutputType::Resolver arrow::Result operator()( arrow::compute::KernelContext* context, const std::vector& descr) { - return SafeCallIntoR([&]() -> arrow::ValueDescr { - auto kernel = - reinterpret_cast(context->kernel()); - auto state = std::dynamic_pointer_cast(kernel->data); - - cpp11::writable::list input_types_sexp(descr.size()); - for (size_t i = 0; i < descr.size(); i++) { - input_types_sexp[i] = cpp11::to_r6(descr[i].type); - } - - cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp); - if (!Rf_inherits(output_type_sexp, "DataType")) { - cpp11::stop("arrow_scalar_function resolver must return a DataType"); - } - - return arrow::ValueDescr( - cpp11::as_cpp>(output_type_sexp)); - }); + return SafeCallIntoR( + [&]() -> arrow::ValueDescr { + auto kernel = + reinterpret_cast(context->kernel()); + auto state = std::dynamic_pointer_cast(kernel->data); + + cpp11::writable::list input_types_sexp(descr.size()); + for (size_t i = 0; i < descr.size(); i++) { + input_types_sexp[i] = 
cpp11::to_r6(descr[i].type); + } + + cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp); + if (!Rf_inherits(output_type_sexp, "DataType")) { + cpp11::stop("arrow_scalar_function resolver must return a DataType"); + } + + return arrow::ValueDescr( + cpp11::as_cpp>(output_type_sexp)); + }, + "resolve scalar user-defined function output data type"); } }; @@ -617,75 +619,77 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { arrow::Status operator()(arrow::compute::KernelContext* context, const arrow::compute::ExecSpan& span, arrow::compute::ExecResult* result) { - return SafeCallIntoRVoid([&]() { - auto kernel = - reinterpret_cast(context->kernel()); - auto state = std::dynamic_pointer_cast(kernel->data); - - cpp11::writable::list args_sexp(span.num_values()); - - for (int i = 0; i < span.num_values(); i++) { - const arrow::compute::ExecValue& exec_val = span[i]; - if (exec_val.is_array()) { - std::shared_ptr array = exec_val.array.ToArray(); - args_sexp[i] = cpp11::to_r6(array); - } else if (exec_val.is_scalar()) { - std::shared_ptr scalar = exec_val.scalar->Copy(); - args_sexp[i] = cpp11::to_r6(scalar); - } - } - - cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length); - - std::shared_ptr output_type = result->type()->Copy(); - cpp11::sexp output_type_sexp = cpp11::to_r6(output_type); - cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp}; - udf_context.names() = {"batch_length", "output_type"}; - - cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp); - - if (Rf_inherits(func_result_sexp, "Array")) { - auto array = cpp11::as_cpp>(func_result_sexp); - - // handle an Array result of the wrong type - if (!result->type()->Equals(array->type())) { - arrow::Datum out = - ValueOrStop(arrow::compute::Cast(array, result->type()->Copy())); - std::shared_ptr out_array = out.make_array(); - array.swap(out_array); - } - - // make sure we assign the type that the result is expecting - if 
(result->is_array_data()) { - result->value = std::move(array->data()); - } else if (array->length() == 1) { - result->value = ValueOrStop(array->GetScalar(0)); - } else { - cpp11::stop("expected Scalar return value but got Array with length != 1"); - } - } else if (Rf_inherits(func_result_sexp, "Scalar")) { - auto scalar = cpp11::as_cpp>(func_result_sexp); - - // handle a Scalar result of the wrong type - if (!result->type()->Equals(scalar->type)) { - arrow::Datum out = - ValueOrStop(arrow::compute::Cast(scalar, result->type()->Copy())); - std::shared_ptr out_scalar = out.scalar(); - scalar.swap(out_scalar); - } - - // make sure we assign the type that the result is expecting - if (result->is_scalar()) { - result->value = std::move(scalar); - } else { - auto array = ValueOrStop( - arrow::MakeArrayFromScalar(*scalar, span.length, context->memory_pool())); - result->value = std::move(array->data()); - } - } else { - cpp11::stop("arrow_scalar_function must return an Array or Scalar"); - } - }); + return SafeCallIntoRVoid( + [&]() { + auto kernel = + reinterpret_cast(context->kernel()); + auto state = std::dynamic_pointer_cast(kernel->data); + + cpp11::writable::list args_sexp(span.num_values()); + + for (int i = 0; i < span.num_values(); i++) { + const arrow::compute::ExecValue& exec_val = span[i]; + if (exec_val.is_array()) { + std::shared_ptr array = exec_val.array.ToArray(); + args_sexp[i] = cpp11::to_r6(array); + } else if (exec_val.is_scalar()) { + std::shared_ptr scalar = exec_val.scalar->Copy(); + args_sexp[i] = cpp11::to_r6(scalar); + } + } + + cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length); + + std::shared_ptr output_type = result->type()->Copy(); + cpp11::sexp output_type_sexp = cpp11::to_r6(output_type); + cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp}; + udf_context.names() = {"batch_length", "output_type"}; + + cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp); + + if 
(Rf_inherits(func_result_sexp, "Array")) { + auto array = cpp11::as_cpp>(func_result_sexp); + + // handle an Array result of the wrong type + if (!result->type()->Equals(array->type())) { + arrow::Datum out = + ValueOrStop(arrow::compute::Cast(array, result->type()->Copy())); + std::shared_ptr out_array = out.make_array(); + array.swap(out_array); + } + + // make sure we assign the type that the result is expecting + if (result->is_array_data()) { + result->value = std::move(array->data()); + } else if (array->length() == 1) { + result->value = ValueOrStop(array->GetScalar(0)); + } else { + cpp11::stop("expected Scalar return value but got Array with length != 1"); + } + } else if (Rf_inherits(func_result_sexp, "Scalar")) { + auto scalar = cpp11::as_cpp>(func_result_sexp); + + // handle a Scalar result of the wrong type + if (!result->type()->Equals(scalar->type)) { + arrow::Datum out = + ValueOrStop(arrow::compute::Cast(scalar, result->type()->Copy())); + std::shared_ptr out_scalar = out.scalar(); + scalar.swap(out_scalar); + } + + // make sure we assign the type that the result is expecting + if (result->is_scalar()) { + result->value = std::move(scalar); + } else { + auto array = ValueOrStop(arrow::MakeArrayFromScalar( + *scalar, span.length, context->memory_pool())); + result->value = std::move(array->data()); + } + } else { + cpp11::stop("arrow_scalar_function must return an Array or Scalar"); + } + }, + "execute scalar user-defined function"); } }; diff --git a/r/src/extension-impl.cpp b/r/src/extension-impl.cpp index efb9f0f4675..e6efcf36479 100644 --- a/r/src/extension-impl.cpp +++ b/r/src/extension-impl.cpp @@ -38,18 +38,19 @@ bool RExtensionType::ExtensionEquals(const arrow::ExtensionType& other) const { // With any ambiguity, we need to materialize the R6 instance and call its // ExtensionEquals method. We can't do this on the non-R thread. - // After ARROW-15841, we can use SafeCallIntoR. 
- arrow::Result result = SafeCallIntoR([&]() { - cpp11::environment instance = r6_instance(); - cpp11::function instance_ExtensionEquals(instance["ExtensionEquals"]); - - std::shared_ptr other_shared = - ValueOrStop(other.Deserialize(other.storage_type(), other.Serialize())); - cpp11::sexp other_r6 = cpp11::to_r6(other_shared, "ExtensionType"); - - cpp11::logicals result(instance_ExtensionEquals(other_r6)); - return cpp11::as_cpp(result); - }); + arrow::Result result = SafeCallIntoR( + [&]() { + cpp11::environment instance = r6_instance(); + cpp11::function instance_ExtensionEquals(instance["ExtensionEquals"]); + + std::shared_ptr other_shared = + ValueOrStop(other.Deserialize(other.storage_type(), other.Serialize())); + cpp11::sexp other_r6 = cpp11::to_r6(other_shared, "ExtensionType"); + + cpp11::logicals result(instance_ExtensionEquals(other_r6)); + return cpp11::as_cpp(result); + }, + "RExtensionType$ExtensionEquals()"); if (!result.ok()) { throw std::runtime_error(result.status().message()); diff --git a/r/src/io.cpp b/r/src/io.cpp index 42766ddd2f5..321b1b17feb 100644 --- a/r/src/io.cpp +++ b/r/src/io.cpp @@ -223,8 +223,8 @@ class RConnectionFileInterface : public virtual arrow::io::FileInterface { closed_ = true; - return SafeCallIntoRVoid( - [&]() { cpp11::package("base")["close"](connection_sexp_); }); + return SafeCallIntoRVoid([&]() { cpp11::package("base")["close"](connection_sexp_); }, + "close() on R connection"); } arrow::Result Tell() const { @@ -232,10 +232,12 @@ class RConnectionFileInterface : public virtual arrow::io::FileInterface { return arrow::Status::IOError("R connection is closed"); } - return SafeCallIntoR([&]() { - cpp11::sexp result = cpp11::package("base")["seek"](connection_sexp_); - return cpp11::as_cpp(result); - }); + return SafeCallIntoR( + [&]() { + cpp11::sexp result = cpp11::package("base")["seek"](connection_sexp_); + return cpp11::as_cpp(result); + }, + "tell() on R connection"); } bool closed() const { return closed_; } @@ 
-251,17 +253,19 @@ class RConnectionFileInterface : public virtual arrow::io::FileInterface { return arrow::Status::IOError("R connection is closed"); } - return SafeCallIntoR([&] { - cpp11::function read_bin = cpp11::package("base")["readBin"]; - cpp11::writable::raws ptype((R_xlen_t)0); - cpp11::integers n = cpp11::as_sexp(nbytes); + return SafeCallIntoR( + [&] { + cpp11::function read_bin = cpp11::package("base")["readBin"]; + cpp11::writable::raws ptype((R_xlen_t)0); + cpp11::integers n = cpp11::as_sexp(nbytes); - cpp11::sexp result = read_bin(connection_sexp_, ptype, n); + cpp11::sexp result = read_bin(connection_sexp_, ptype, n); - int64_t result_size = cpp11::safe[Rf_xlength](result); - memcpy(out, cpp11::safe[RAW](result), result_size); - return result_size; - }); + int64_t result_size = cpp11::safe[Rf_xlength](result); + memcpy(out, cpp11::safe[RAW](result), result_size); + return result_size; + }, + "readBin() on R connection"); } arrow::Result> ReadBase(int64_t nbytes) { @@ -278,13 +282,15 @@ class RConnectionFileInterface : public virtual arrow::io::FileInterface { return arrow::Status::IOError("R connection is closed"); } - return SafeCallIntoRVoid([&]() { - cpp11::writable::raws data_raw(nbytes); - memcpy(cpp11::safe[RAW](data_raw), data, nbytes); - - cpp11::function write_bin = cpp11::package("base")["writeBin"]; - write_bin(data_raw, connection_sexp_); - }); + return SafeCallIntoRVoid( + [&]() { + cpp11::writable::raws data_raw(nbytes); + memcpy(cpp11::safe[RAW](data_raw), data, nbytes); + + cpp11::function write_bin = cpp11::package("base")["writeBin"]; + write_bin(data_raw, connection_sexp_); + }, + "writeBin() on R connection"); } arrow::Status SeekBase(int64_t pos) { @@ -292,9 +298,11 @@ class RConnectionFileInterface : public virtual arrow::io::FileInterface { return arrow::Status::IOError("R connection is closed"); } - return SafeCallIntoRVoid([&]() { - cpp11::package("base")["seek"](connection_sexp_, cpp11::as_sexp(pos)); - }); + return 
SafeCallIntoRVoid( + [&]() { + cpp11::package("base")["seek"](connection_sexp_, cpp11::as_sexp(pos)); + }, + "seek() on R connection"); } private: @@ -305,10 +313,12 @@ class RConnectionFileInterface : public virtual arrow::io::FileInterface { return true; } - auto is_open_result = SafeCallIntoR([&]() { - cpp11::sexp result = cpp11::package("base")["isOpen"](connection_sexp_); - return cpp11::as_cpp(result); - }); + auto is_open_result = SafeCallIntoR( + [&]() { + cpp11::sexp result = cpp11::package("base")["isOpen"](connection_sexp_); + return cpp11::as_cpp(result); + }, + "isOpen() on R connection"); if (!is_open_result.ok()) { closed_ = true; diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index 775250ef573..d98abcf77aa 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -100,7 +100,7 @@ MainRThread& GetMainRThread(); // a SEXP (use cpp11::as_cpp to convert it to a C++ type inside // `fun`). template -arrow::Future SafeCallIntoRAsync(std::function(void)> fun) { +arrow::Future SafeCallIntoRAsync(std::function(void)> fun, std::string reason = "unspecified") { MainRThread& main_r_thread = GetMainRThread(); if (main_r_thread.IsMainThread()) { // If we're on the main thread, run the task immediately and let @@ -126,21 +126,21 @@ arrow::Future SafeCallIntoRAsync(std::function(void)> fun) { })); } else { return arrow::Status::NotImplemented( - "Call to R from a non-R thread without calling RunWithCapturedR"); + "Call to R (", reason, ") from a non-R thread from an unsupported context"); } } template -arrow::Result SafeCallIntoR(std::function fun) { - arrow::Future future = SafeCallIntoRAsync(std::move(fun)); +arrow::Result SafeCallIntoR(std::function fun, std::string reason = "unspecified") { + arrow::Future future = SafeCallIntoRAsync(std::move(fun), reason); return future.result(); } -static inline arrow::Status SafeCallIntoRVoid(std::function fun) { +static inline arrow::Status SafeCallIntoRVoid(std::function fun, std::string 
reason = "unspecified") { arrow::Future future = SafeCallIntoRAsync([&fun]() { fun(); return true; - }); + }, reason); return future.status(); } diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 353e73f611d..b82aee56a4f 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -156,3 +156,26 @@ test_that("user-defined functions work during multi-threaded execution", { expect_identical(result2$fun_result, example_df$value * 32) }) + +test_that("user-defined error when called from an unsupported context", { + skip_if_not_available("dataset") + + times_32_wrapper <- arrow_scalar_function( + function(x) x * 32.0, + int32(), float64() + ) + + register_user_defined_function(times_32_wrapper, "times_32") + + stream_plan_with_udf <- function() { + rbr <- record_batch(a = 1:1000) %>% + dplyr::mutate(b = times_32(a)) %>% + as_record_batch_reader() + rbr$read_table() + } + + expect_error( + stream_plan_with_udf(), + "Call to R \\(.*?\\) from a non-R thread from an unsupported context" + ) +}) diff --git a/r/tests/testthat/test-safe-call-into-r.R b/r/tests/testthat/test-safe-call-into-r.R index a8027ac4237..d3d1d341010 100644 --- a/r/tests/testthat/test-safe-call-into-r.R +++ b/r/tests/testthat/test-safe-call-into-r.R @@ -52,11 +52,11 @@ test_that("SafeCallIntoR errors from the non-R thread", { expect_error( TestSafeCallIntoR(function() "string one!", opt = "async_without_executor"), - "Call to R from a non-R thread" + "Call to R \\(unspecified\\) from a non-R thread" ) expect_error( TestSafeCallIntoR(function() stop("an error!"), opt = "async_without_executor"), - "Call to R from a non-R thread" + "Call to R \\(unspecified\\) from a non-R thread" ) }) From 017f681bb678895797508fc2734d19b1bf464b9f Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 7 Jul 2022 11:35:20 -0300 Subject: [PATCH 46/92] fix + clarify registration --- r/R/compute.R | 11 +++++------ r/R/dplyr-funcs.R | 11 ++++++++++- 
r/man/register_binding.Rd | 6 +++++- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index a58ef0a73ff..1b6d0ec30a3 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -382,18 +382,17 @@ register_user_defined_function <- function(scalar_function, name) { inherits(scalar_function, "arrow_advanced_scalar_function") ) - # register with Arrow C++ + # register with Arrow C++ function registry (enables its use in + # call_function() and Expression$create()) RegisterScalarUDF(name, scalar_function) - # register with dplyr bindings + # register with dplyr binding (enables its use in mutate(), filter(), etc.) register_binding( name, - function(...) build_expr(name, ...) + function(...) build_expr(name, ...), + update_cache = TRUE ) - # recreate dplyr binding cache - create_binding_cache() - invisible(NULL) } diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R index 7c4ed99e2ed..1df83746744 100644 --- a/r/R/dplyr-funcs.R +++ b/r/R/dplyr-funcs.R @@ -50,6 +50,9 @@ NULL #' - `fun`: string function name #' - `data`: `Expression` (these are all currently a single field) #' - `options`: list of function options, as passed to call_function +#' @param update_cache Update .cache$functions at the time of registration. +#' the default is FALSE because the majority of usage is to register +#' bindings at package load, after which we create the cache once. #' @param registry An environment in which the functions should be #' assigned. #' @@ -57,7 +60,7 @@ NULL #' registered function existed. 
#' @keywords internal #' -register_binding <- function(fun_name, fun, registry = nse_funcs) { +register_binding <- function(fun_name, fun, registry = nse_funcs, update_cache = FALSE) { unqualified_name <- sub("^.*?:{+}", "", fun_name) previous_fun <- registry[[unqualified_name]] @@ -80,6 +83,12 @@ register_binding <- function(fun_name, fun, registry = nse_funcs) { registry[[unqualified_name]] <- fun } + if (update_cache) { + fun_cache <- .cache$functions + fun_cache[[unqualified_name]] <- fun + .cache$functions <- fun_cache + } + invisible(previous_fun) } diff --git a/r/man/register_binding.Rd b/r/man/register_binding.Rd index e776e7b3f5b..77049c24c3b 100644 --- a/r/man/register_binding.Rd +++ b/r/man/register_binding.Rd @@ -4,7 +4,7 @@ \alias{register_binding} \title{Register compute bindings} \usage{ -register_binding(fun_name, fun, registry = nse_funcs) +register_binding(fun_name, fun, registry = nse_funcs, update_cache = FALSE) } \arguments{ \item{fun_name}{A string containing a function name in the form \code{"function"} or @@ -18,6 +18,10 @@ This function must accept \code{Expression} objects as arguments and return \item{registry}{An environment in which the functions should be assigned.} +\item{update_cache}{Update .cache$functions at the time of registration. +the default is FALSE because the majority of usage is to register +bindings at package load, after which we create the cache once.} + \item{agg_fun}{An aggregate function or \code{NULL} to un-register a previous aggregate function. This function must accept \code{Expression} objects as arguments and return
This function must accept \code{Expression} objects as arguments and return a \code{list()} with components: From 85519a20e1f9eac829f47928b757e525c9fc0155 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 7 Jul 2022 11:38:08 -0300 Subject: [PATCH 47/92] don't namespace rlang:: --- r/R/compute.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 1b6d0ec30a3..ccd07cb8dec 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -399,7 +399,7 @@ register_user_defined_function <- function(scalar_function, name) { #' @rdname register_user_defined_function #' @export arrow_scalar_function <- function(fun, in_type, out_type) { - fun <- rlang::as_function(fun) + fun <- as_function(fun) # create a small wrapper that converts Scalar/Array arguments to R vectors # and converts the result back to an Array @@ -429,7 +429,7 @@ arrow_advanced_scalar_function <- function(advanced_fun, in_type, out_type) { out_type <- rep_len(out_type, length(in_type)) - advanced_fun <- rlang::as_function(advanced_fun) + advanced_fun <- as_function(advanced_fun) if (length(formals(advanced_fun)) != 2) { abort("`advanced_fun` must accept exactly two arguments") } From 9f251fcde9209648e2c9e8e8de117422d0130dcd Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 7 Jul 2022 11:56:50 -0300 Subject: [PATCH 48/92] clang-format --- r/src/safe-call-into-r.h | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index d98abcf77aa..94c8e5dec31 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -100,7 +100,8 @@ MainRThread& GetMainRThread(); // a SEXP (use cpp11::as_cpp to convert it to a C++ type inside // `fun`). 
template -arrow::Future SafeCallIntoRAsync(std::function(void)> fun, std::string reason = "unspecified") { +arrow::Future SafeCallIntoRAsync(std::function(void)> fun, + std::string reason = "unspecified") { MainRThread& main_r_thread = GetMainRThread(); if (main_r_thread.IsMainThread()) { // If we're on the main thread, run the task immediately and let @@ -131,16 +132,20 @@ arrow::Future SafeCallIntoRAsync(std::function(void)> fun, s } template -arrow::Result SafeCallIntoR(std::function fun, std::string reason = "unspecified") { +arrow::Result SafeCallIntoR(std::function fun, + std::string reason = "unspecified") { arrow::Future future = SafeCallIntoRAsync(std::move(fun), reason); return future.result(); } -static inline arrow::Status SafeCallIntoRVoid(std::function fun, std::string reason = "unspecified") { - arrow::Future future = SafeCallIntoRAsync([&fun]() { - fun(); - return true; - }, reason); +static inline arrow::Status SafeCallIntoRVoid(std::function fun, + std::string reason = "unspecified") { + arrow::Future future = SafeCallIntoRAsync( + [&fun]() { + fun(); + return true; + }, + reason); return future.status(); } From ebc0b84a900260b500b9706c1e58996d3ce6e658 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 7 Jul 2022 14:44:10 -0300 Subject: [PATCH 49/92] constrain Arity specification to a fixed number of arguments (per function) for now --- r/src/compute.cpp | 42 ++++++++++++++++++++++++++++----- r/tests/testthat/test-compute.R | 22 +++++++++++++++++ 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index f29421b0b2a..7f717d17967 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -695,16 +695,46 @@ class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { // [[arrow::export]] void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { - const arrow::compute::FunctionDoc dummy_function_doc{ - "A user-defined R function", "returns something", {"..."}}; - - auto func = 
std::make_shared( - name, arrow::compute::Arity::VarArgs(), dummy_function_doc); - cpp11::list in_type_r(func_sexp.attr("in_type")); cpp11::list out_type_r(func_sexp.attr("out_type")); R_xlen_t n_kernels = in_type_r.size(); + if (n_kernels == 0) { + cpp11::stop("Can't register user-defined function with zero kernels"); + } + + // compute the Arity from the list of input kernels + std::vector n_args(n_kernels); + for (R_xlen_t i = 0; i < n_kernels; i++) { + auto in_types = cpp11::as_cpp>(in_type_r[i]); + n_args[i] = in_types->num_fields(); + } + + int64_t min_args = *std::min_element(n_args.begin(), n_args.end()); + int64_t max_args = *std::max_element(n_args.begin(), n_args.end()); + + // We can't currently handle variable numbers of arguments in a user-defined + // function and we don't have a mechanism for the user to specify a variable + // number of arguments at the end of a signature. + if (min_args != max_args) { + cpp11::stop( + "User-defined function with a variable number of arguments is not supported"); + } + + arrow::compute::Arity arity(min_args, false); + + // The function documentation isn't currently accessible from R but is required + // for the C++ function constructor. 
+ std::vector dummy_argument_names(min_args); + for (int64_t i = 0; i < min_args; i++) { + dummy_argument_names[i] = "arg"; + } + const arrow::compute::FunctionDoc dummy_function_doc{ + "A user-defined R function", "returns something", std::move(dummy_argument_names)}; + + auto func = + std::make_shared(name, arity, dummy_function_doc); + for (R_xlen_t i = 0; i < n_kernels; i++) { auto in_types = cpp11::as_cpp>(in_type_r[i]); cpp11::sexp out_type_func = out_type_r[i]; diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index b82aee56a4f..f1fd274e821 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -110,6 +110,28 @@ test_that("register_user_defined_function() adds a compute function to the regis ) }) +test_that("register_user_defined_function() errors for unsupported specifications", { + no_kernel_wrapper <- arrow_scalar_function( + function(...) NULL, + list(), + list() + ) + expect_error( + register_user_defined_function(no_kernel_wrapper, "no_kernels"), + "Can't register user-defined function with zero kernels" + ) + + varargs_kernel_wrapper <- arrow_scalar_function( + function(...) 
NULL, + list(float64(), schema(x = float64(), y = float64())), + list(float64()) + ) + expect_error( + register_user_defined_function(varargs_kernel_wrapper, "var_kernels"), + "User-defined function with a variable number of arguments is not supported" + ) +}) + test_that("user-defined functions work during multi-threaded execution", { skip_if_not_available("dataset") From ed735e14ff9e4ee1eb3f7b62c28b414703260e4b Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 8 Jul 2022 08:59:10 -0300 Subject: [PATCH 50/92] Update r/src/compute.cpp Co-authored-by: Vibhatha Lakmal Abeykoon --- r/src/compute.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 7f717d17967..ac2e9df02fd 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -710,8 +710,8 @@ void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { n_args[i] = in_types->num_fields(); } - int64_t min_args = *std::min_element(n_args.begin(), n_args.end()); - int64_t max_args = *std::max_element(n_args.begin(), n_args.end()); + const int64_t min_args = *std::min_element(n_args.begin(), n_args.end()); + const int64_t max_args = *std::max_element(n_args.begin(), n_args.end()); // We can't currently handle variable numbers of arguments in a user-defined // function and we don't have a mechanism for the user to specify a variable From 8877c1216015016f8581ee792abe190359f8cbbd Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 8 Jul 2022 08:59:21 -0300 Subject: [PATCH 51/92] Update r/src/compute.cpp Co-authored-by: Vibhatha Lakmal Abeykoon --- r/src/compute.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index ac2e9df02fd..4dd4b4658ef 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -739,9 +739,9 @@ void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { auto in_types = cpp11::as_cpp>(in_type_r[i]); cpp11::sexp out_type_func = out_type_r[i]; - 
std::vector compute_in_types; + std::vector compute_in_types(in_types->num_fields()); for (int64_t j = 0; j < in_types->num_fields(); j++) { - compute_in_types.push_back(arrow::compute::InputType(in_types->field(i)->type())); + compute_in_types.emplace_back(arrow::compute::InputType(in_types->field(j)->type())); } arrow::compute::OutputType out_type((RScalarUDFOutputTypeResolver())); From a89ce07a845e4416e28bdef12c833678bf9ba635 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 8 Jul 2022 08:59:39 -0300 Subject: [PATCH 52/92] Update r/src/compute.cpp Co-authored-by: Vibhatha Lakmal Abeykoon --- r/src/compute.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 4dd4b4658ef..e68f05ab38d 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -748,7 +748,7 @@ void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { auto signature = std::make_shared( compute_in_types, std::move(out_type), true); - arrow::compute::ScalarKernel kernel(signature, RScalarUDFCallable()); + arrow::compute::ScalarKernel kernel(std::move(signature), RScalarUDFCallable()); kernel.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE; kernel.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; kernel.data = std::make_shared(func_sexp, out_type_func); From f6874519acc2cdd4332bfd66aabe3fbbfe62fa8f Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 8 Jul 2022 09:04:09 -0300 Subject: [PATCH 53/92] clang-format --- r/src/compute.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index e68f05ab38d..33cbf49dc53 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -741,7 +741,8 @@ void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { std::vector compute_in_types(in_types->num_fields()); for (int64_t j = 0; j < in_types->num_fields(); j++) { - 
compute_in_types.emplace_back(arrow::compute::InputType(in_types->field(j)->type())); + compute_in_types.emplace_back( + arrow::compute::InputType(in_types->field(j)->type())); } arrow::compute::OutputType out_type((RScalarUDFOutputTypeResolver())); From 2e9e26168c39da4e76ea8ffe2946fbc24bcbd51a Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 8 Jul 2022 09:12:47 -0300 Subject: [PATCH 54/92] more readable path when RunWithCapturedR does not return a Result --- r/src/compute-exec.cpp | 7 ++++--- r/src/safe-call-into-r.h | 12 ++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 4821361357d..afddf8dfa46 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -226,13 +226,14 @@ void ExecPlan_Write( ds::WriteNodeOptions{std::move(opts), std::move(kv)}); StopIfNotOk(plan->Validate()); - auto result = RunWithCapturedRIfPossible([&]() -> arrow::Result { + + arrow::Status result = RunWithCapturedRIfPossibleVoid([&]() { RETURN_NOT_OK(plan->StartProducing()); RETURN_NOT_OK(plan->finished().status()); - return true; + arrow::Status::OK(); }); - StopIfNotOk(result.status()); + StopIfNotOk(result); } #endif diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index 94c8e5dec31..d1791b93ad1 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -198,4 +198,16 @@ arrow::Result RunWithCapturedRIfPossible( } } +// Like RunWithCapturedRIfPossible<>() but for arrow calls that don't return +// a Result. 
+arrow::Status RunWithCapturedRIfPossibleVoid( + std::function make_arrow_call) { + auto result = RunWithCapturedRIfPossible([&]() -> arrow::Result { + ARROW_RETURN_NOT_OK(make_arrow_call()); + return true; + }); + ARROW_RETURN_NOT_OK(result); + return arrow::Status::OK(); +} + #endif From 514d91ef0ae940c7689907f91bf500e1c1860063 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 8 Jul 2022 09:18:44 -0300 Subject: [PATCH 55/92] fix the void version of RunWithCapturedR --- r/src/compute-exec.cpp | 2 +- r/src/safe-call-into-r.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index afddf8dfa46..e348675fc17 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -230,7 +230,7 @@ void ExecPlan_Write( arrow::Status result = RunWithCapturedRIfPossibleVoid([&]() { RETURN_NOT_OK(plan->StartProducing()); RETURN_NOT_OK(plan->finished().status()); - arrow::Status::OK(); + return arrow::Status::OK(); }); StopIfNotOk(result); diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index d1791b93ad1..95edd7468e7 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -200,7 +200,7 @@ arrow::Result RunWithCapturedRIfPossible( // Like RunWithCapturedRIfPossible<>() but for arrow calls that don't return // a Result. 
-arrow::Status RunWithCapturedRIfPossibleVoid( +static inline arrow::Status RunWithCapturedRIfPossibleVoid( std::function make_arrow_call) { auto result = RunWithCapturedRIfPossible([&]() -> arrow::Result { ARROW_RETURN_NOT_OK(make_arrow_call()); From 58c857380eb930b3228612e5d0c429b09e544993 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 8 Jul 2022 10:25:29 -0300 Subject: [PATCH 56/92] fix kernel signature assignment --- r/src/compute.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 33cbf49dc53..34ed658985e 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -741,15 +741,14 @@ void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { std::vector compute_in_types(in_types->num_fields()); for (int64_t j = 0; j < in_types->num_fields(); j++) { - compute_in_types.emplace_back( - arrow::compute::InputType(in_types->field(j)->type())); + compute_in_types[i] = arrow::compute::InputType(in_types->field(j)->type()); } arrow::compute::OutputType out_type((RScalarUDFOutputTypeResolver())); auto signature = std::make_shared( - compute_in_types, std::move(out_type), true); - arrow::compute::ScalarKernel kernel(std::move(signature), RScalarUDFCallable()); + std::move(compute_in_types), std::move(out_type), true); + arrow::compute::ScalarKernel kernel(signature, RScalarUDFCallable()); kernel.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE; kernel.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; kernel.data = std::make_shared(func_sexp, out_type_func); From b4154afef18ae0cbdec9220a15ea0cc3eedd3b7f Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 8 Jul 2022 11:50:00 -0300 Subject: [PATCH 57/92] fix for updated master --- r/src/compute.cpp | 187 +++++++++++++++++++++------------------------- 1 file changed, 86 insertions(+), 101 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 34ed658985e..6b0e9f28779 100644 --- 
a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -586,112 +586,97 @@ class RScalarUDFKernelState : public arrow::compute::KernelState { cpp11::function resolver_; }; -class RScalarUDFOutputTypeResolver : public arrow::compute::OutputType::Resolver { - public: - arrow::Result operator()( - arrow::compute::KernelContext* context, - const std::vector& descr) { - return SafeCallIntoR( - [&]() -> arrow::ValueDescr { - auto kernel = - reinterpret_cast(context->kernel()); - auto state = std::dynamic_pointer_cast(kernel->data); - - cpp11::writable::list input_types_sexp(descr.size()); - for (size_t i = 0; i < descr.size(); i++) { - input_types_sexp[i] = cpp11::to_r6(descr[i].type); - } +arrow::Result ResolveScalarUDFOutputType( + arrow::compute::KernelContext* context, + const std::vector& input_types) { + return SafeCallIntoR( + [&]() -> arrow::TypeHolder { + auto kernel = + reinterpret_cast(context->kernel()); + auto state = std::dynamic_pointer_cast(kernel->data); + + cpp11::writable::list input_types_sexp(input_types.size()); + for (size_t i = 0; i < input_types.size(); i++) { + input_types_sexp[i] = + cpp11::to_r6(input_types[i].GetSharedPtr()); + } + + cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp); + if (!Rf_inherits(output_type_sexp, "DataType")) { + cpp11::stop("arrow_scalar_function resolver must return a DataType"); + } + + return arrow::TypeHolder( + cpp11::as_cpp>(output_type_sexp)); + }, + "resolve scalar user-defined function output data type"); +} - cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp); - if (!Rf_inherits(output_type_sexp, "DataType")) { - cpp11::stop("arrow_scalar_function resolver must return a DataType"); +arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context, + const arrow::compute::ExecSpan& span, + arrow::compute::ExecResult* result) { + if (result->is_array_span()) { + return arrow::Status::NotImplemented("ArraySpan result from R scalar UDF"); + } + + return SafeCallIntoRVoid( + [&]() { + 
auto kernel = + reinterpret_cast(context->kernel()); + auto state = std::dynamic_pointer_cast(kernel->data); + + cpp11::writable::list args_sexp(span.num_values()); + + for (int i = 0; i < span.num_values(); i++) { + const arrow::compute::ExecValue& exec_val = span[i]; + if (exec_val.is_array()) { + std::shared_ptr array = exec_val.array.ToArray(); + args_sexp[i] = cpp11::to_r6(array); + } else if (exec_val.is_scalar()) { + std::shared_ptr scalar = exec_val.scalar->GetSharedPtr(); + args_sexp[i] = cpp11::to_r6(scalar); } + } - return arrow::ValueDescr( - cpp11::as_cpp>(output_type_sexp)); - }, - "resolve scalar user-defined function output data type"); - } -}; + cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length); -class RScalarUDFCallable : public arrow::compute::ArrayKernelExec { - public: - arrow::Status operator()(arrow::compute::KernelContext* context, - const arrow::compute::ExecSpan& span, - arrow::compute::ExecResult* result) { - return SafeCallIntoRVoid( - [&]() { - auto kernel = - reinterpret_cast(context->kernel()); - auto state = std::dynamic_pointer_cast(kernel->data); - - cpp11::writable::list args_sexp(span.num_values()); - - for (int i = 0; i < span.num_values(); i++) { - const arrow::compute::ExecValue& exec_val = span[i]; - if (exec_val.is_array()) { - std::shared_ptr array = exec_val.array.ToArray(); - args_sexp[i] = cpp11::to_r6(array); - } else if (exec_val.is_scalar()) { - std::shared_ptr scalar = exec_val.scalar->Copy(); - args_sexp[i] = cpp11::to_r6(scalar); - } + std::shared_ptr output_type = result->type()->GetSharedPtr(); + cpp11::sexp output_type_sexp = cpp11::to_r6(output_type); + cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp}; + udf_context.names() = {"batch_length", "output_type"}; + + cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp); + + if (Rf_inherits(func_result_sexp, "Array")) { + auto array = cpp11::as_cpp>(func_result_sexp); + + // handle an Array result of the wrong 
type + if (!result->type()->Equals(array->type())) { + arrow::Datum out = ValueOrStop(arrow::compute::Cast(array, result->type())); + std::shared_ptr out_array = out.make_array(); + array.swap(out_array); } - cpp11::sexp batch_length_sexp = cpp11::as_sexp(span.length); - - std::shared_ptr output_type = result->type()->Copy(); - cpp11::sexp output_type_sexp = cpp11::to_r6(output_type); - cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp}; - udf_context.names() = {"batch_length", "output_type"}; - - cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp); - - if (Rf_inherits(func_result_sexp, "Array")) { - auto array = cpp11::as_cpp>(func_result_sexp); - - // handle an Array result of the wrong type - if (!result->type()->Equals(array->type())) { - arrow::Datum out = - ValueOrStop(arrow::compute::Cast(array, result->type()->Copy())); - std::shared_ptr out_array = out.make_array(); - array.swap(out_array); - } - - // make sure we assign the type that the result is expecting - if (result->is_array_data()) { - result->value = std::move(array->data()); - } else if (array->length() == 1) { - result->value = ValueOrStop(array->GetScalar(0)); - } else { - cpp11::stop("expected Scalar return value but got Array with length != 1"); - } - } else if (Rf_inherits(func_result_sexp, "Scalar")) { - auto scalar = cpp11::as_cpp>(func_result_sexp); - - // handle a Scalar result of the wrong type - if (!result->type()->Equals(scalar->type)) { - arrow::Datum out = - ValueOrStop(arrow::compute::Cast(scalar, result->type()->Copy())); - std::shared_ptr out_scalar = out.scalar(); - scalar.swap(out_scalar); - } - - // make sure we assign the type that the result is expecting - if (result->is_scalar()) { - result->value = std::move(scalar); - } else { - auto array = ValueOrStop(arrow::MakeArrayFromScalar( - *scalar, span.length, context->memory_pool())); - result->value = std::move(array->data()); - } - } else { - cpp11::stop("arrow_scalar_function must 
return an Array or Scalar"); + result->value = std::move(array->data()); + } else if (Rf_inherits(func_result_sexp, "Scalar")) { + auto scalar = cpp11::as_cpp>(func_result_sexp); + + // handle a Scalar result of the wrong type + if (!result->type()->Equals(scalar->type)) { + arrow::Datum out = ValueOrStop(arrow::compute::Cast(scalar, result->type())); + std::shared_ptr out_scalar = out.scalar(); + scalar.swap(out_scalar); } - }, - "execute scalar user-defined function"); - } -}; + + auto array = ValueOrStop( + arrow::MakeArrayFromScalar(*scalar, span.length, context->memory_pool())); + result->value = std::move(array->data()); + } else { + cpp11::stop("arrow_scalar_function must return an Array or Scalar"); + } + }, + "execute scalar user-defined function"); +} // [[arrow::export]] void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { @@ -744,11 +729,11 @@ void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { compute_in_types[i] = arrow::compute::InputType(in_types->field(j)->type()); } - arrow::compute::OutputType out_type((RScalarUDFOutputTypeResolver())); + arrow::compute::OutputType out_type((&ResolveScalarUDFOutputType)); auto signature = std::make_shared( std::move(compute_in_types), std::move(out_type), true); - arrow::compute::ScalarKernel kernel(signature, RScalarUDFCallable()); + arrow::compute::ScalarKernel kernel(signature, &CallRScalarUDF); kernel.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE; kernel.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; kernel.data = std::make_shared(func_sexp, out_type_func); From 1ed6d25f3c6fcd0eb2958dbc8d6311b51c89ab1c Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 11 Jul 2022 12:00:15 -0300 Subject: [PATCH 58/92] inline some short variable definitions --- r/src/compute.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 6b0e9f28779..b2b4c1b3083 100644 --- a/r/src/compute.cpp +++ 
b/r/src/compute.cpp @@ -630,11 +630,9 @@ arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context, for (int i = 0; i < span.num_values(); i++) { const arrow::compute::ExecValue& exec_val = span[i]; if (exec_val.is_array()) { - std::shared_ptr array = exec_val.array.ToArray(); - args_sexp[i] = cpp11::to_r6(array); + args_sexp[i] = cpp11::to_r6(exec_val.array.ToArray()); } else if (exec_val.is_scalar()) { - std::shared_ptr scalar = exec_val.scalar->GetSharedPtr(); - args_sexp[i] = cpp11::to_r6(scalar); + args_sexp[i] = cpp11::to_r6(exec_val.scalar->GetSharedPtr()); } } From 88bf4d2f0d6ec1df37e09d210da6bb947ec96ba1 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 11 Jul 2022 12:07:28 -0300 Subject: [PATCH 59/92] documentation updates --- r/R/compute.R | 2 +- r/R/dplyr-funcs.R | 6 +++++- r/man/register_binding.Rd | 6 +++++- r/man/register_user_defined_function.Rd | 2 +- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index ccd07cb8dec..7304740f0d1 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -344,7 +344,7 @@ cast_options <- function(safe = TRUE, ...) { #' `register_user_defined_function()`. #' @export #' -#' @examplesIf .Machine$sizeof.pointer >= 8 +#' @examplesIf arrow_with_dataset() #' fun_wrapper <- arrow_scalar_function( #' function(x, y, z) x + y + z, #' schema(x = float64(), y = float64(), z = float64()), diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R index 1df83746744..0bc278f0241 100644 --- a/r/R/dplyr-funcs.R +++ b/r/R/dplyr-funcs.R @@ -52,7 +52,11 @@ NULL #' - `options`: list of function options, as passed to call_function #' @param update_cache Update .cache$functions at the time of registration. #' the default is FALSE because the majority of usage is to register -#' bindings at package load, after which we create the cache once. +#' bindings at package load, after which we create the cache once. 
The +#' reason why .cache$functions is needed in addition to nse_funcs for +#' non-aggregate functions could be revisited...it is currently used +#' as the data mask in mutate, filter, and aggregate (but not +#' summarise) because the data mask has to be a list. #' @param registry An environment in which the functions should be #' assigned. #' diff --git a/r/man/register_binding.Rd b/r/man/register_binding.Rd index 77049c24c3b..c53df707516 100644 --- a/r/man/register_binding.Rd +++ b/r/man/register_binding.Rd @@ -20,7 +20,11 @@ assigned.} \item{update_cache}{Update .cache$functions at the time of registration. the default is FALSE because the majority of usage is to register -bindings at package load, after which we create the cache once.} +bindings at package load, after which we create the cache once. The +reason why .cache$functions is needed in addition to nse_funcs for +non-aggregate functions could be revisited...it is currently used +as the data mask in mutate, filter, and aggregate (but not +summarise) because the data mask has to be a list.} \item{agg_fun}{An aggregate function or \code{NULL} to un-register a previous aggregate function. This function must accept \code{Expression} objects as diff --git a/r/man/register_user_defined_function.Rd b/r/man/register_user_defined_function.Rd index e462381d5e4..44cc28c5ea5 100644 --- a/r/man/register_user_defined_function.Rd +++ b/r/man/register_user_defined_function.Rd @@ -56,7 +56,7 @@ returns R objects; use \code{\link[=arrow_advanced_scalar_function]{arrow_advanc lower-level function that operates directly on Arrow objects. 
} \examples{ -\dontshow{if (.Machine$sizeof.pointer >= 8) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +\dontshow{if (arrow_with_dataset()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} fun_wrapper <- arrow_scalar_function( function(x, y, z) x + y + z, schema(x = float64(), y = float64(), z = float64()), From 6f3d6012c04f70630a961af3e9bab2d0996aecd5 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 11 Jul 2022 12:27:25 -0300 Subject: [PATCH 60/92] see if the lack of error on Windows is because it actually works --- r/tests/testthat/test-compute.R | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index f1fd274e821..46c92b9992a 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -196,8 +196,17 @@ test_that("user-defined error when called from an unsupported context", { rbr$read_table() } - expect_error( - stream_plan_with_udf(), - "Call to R \\(.*?\\) from a non-R thread from an unsupported context" - ) + if (identical(tolower(Sys.info()[["sysname"]]), "windows")) { + expect_equal( + stream_plan_with_udf(), + record_batch(a = 1:1000) %>% + dplyr::mutate(b = times_32(a)) %>% + dplyr::collect(as_data_frame = FALSE) + ) + } else { + expect_error( + stream_plan_with_udf(), + "Call to R \\(.*?\\) from a non-R thread from an unsupported context" + ) + } }) From 031ec64400ff9cebbc07abaebe102e69f346ea39 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 11 Jul 2022 12:38:49 -0300 Subject: [PATCH 61/92] test adding multiple kernels at once --- r/src/compute.cpp | 2 +- r/tests/testthat/test-compute.R | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index b2b4c1b3083..29f349d1334 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -724,7 +724,7 @@ void RegisterScalarUDF(std::string name, 
cpp11::sexp func_sexp) { std::vector compute_in_types(in_types->num_fields()); for (int64_t j = 0; j < in_types->num_fields(); j++) { - compute_in_types[i] = arrow::compute::InputType(in_types->field(j)->type()); + compute_in_types[j] = arrow::compute::InputType(in_types->field(j)->type()); } arrow::compute::OutputType out_type((&ResolveScalarUDFOutputType)); diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 46c92b9992a..8bb10be1990 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -110,6 +110,33 @@ test_that("register_user_defined_function() adds a compute function to the regis ) }) +test_that("register_user_defined_function() can register multiple kernels", { + skip_if_not_available("dataset") + + times_32_wrapper <- arrow_scalar_function( + function(x) x * 32L, + in_type = list(int32(), int64(), float64()), + out_type = function(in_types) in_types[[1]] + ) + + register_user_defined_function(times_32_wrapper, "times_32") + + expect_equal( + call_function("times_32", Scalar$create(1L, int32())), + Scalar$create(32L, int32()) + ) + + expect_equal( + call_function("times_32", Scalar$create(1L, int64())), + Scalar$create(32L, int64()) + ) + + expect_equal( + call_function("times_32", Scalar$create(1L, float64())), + Scalar$create(32L, float64()) + ) +}) + test_that("register_user_defined_function() errors for unsupported specifications", { no_kernel_wrapper <- arrow_scalar_function( function(...) 
NULL, From 0652ae0f04c6abb830a87cc9ccabb4236c2c18ec Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 11 Jul 2022 12:56:20 -0300 Subject: [PATCH 62/92] cleaner handling of number of arguments in user-provided kernels --- r/src/compute.cpp | 30 ++++++++++++------------------ r/tests/testthat/test-compute.R | 2 +- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 29f349d1334..dad0b26a5fa 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -686,30 +686,24 @@ void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { cpp11::stop("Can't register user-defined function with zero kernels"); } - // compute the Arity from the list of input kernels - std::vector n_args(n_kernels); - for (R_xlen_t i = 0; i < n_kernels; i++) { + // Compute the Arity from the list of input kernels. We don't currently handle + // variable numbers of arguments in a user-defined function. + int64_t n_args = + cpp11::as_cpp>(in_type_r[0])->num_fields(); + for (R_xlen_t i = 1; i < n_kernels; i++) { auto in_types = cpp11::as_cpp>(in_type_r[i]); - n_args[i] = in_types->num_fields(); - } - - const int64_t min_args = *std::min_element(n_args.begin(), n_args.end()); - const int64_t max_args = *std::max_element(n_args.begin(), n_args.end()); - - // We can't currently handle variable numbers of arguments in a user-defined - // function and we don't have a mechanism for the user to specify a variable - // number of arguments at the end of a signature. - if (min_args != max_args) { - cpp11::stop( - "User-defined function with a variable number of arguments is not supported"); + if (in_types->num_fields() != n_args) { + cpp11::stop( + "Kernels for user-defined function must accept the same number of arguments"); + } } - arrow::compute::Arity arity(min_args, false); + arrow::compute::Arity arity(n_args, false); // The function documentation isn't currently accessible from R but is required // for the C++ function constructor. 
- std::vector dummy_argument_names(min_args); - for (int64_t i = 0; i < min_args; i++) { + std::vector dummy_argument_names(n_args); + for (int64_t i = 0; i < n_args; i++) { dummy_argument_names[i] = "arg"; } const arrow::compute::FunctionDoc dummy_function_doc{ diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 8bb10be1990..5992fc050b7 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -155,7 +155,7 @@ test_that("register_user_defined_function() errors for unsupported specification ) expect_error( register_user_defined_function(varargs_kernel_wrapper, "var_kernels"), - "User-defined function with a variable number of arguments is not supported" + "Kernels for user-defined function must accept the same number of arguments" ) }) From 49261d6acd89d7da9753e66d73896a9d6c466f9e Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 11 Jul 2022 14:42:38 -0300 Subject: [PATCH 63/92] improvements for readability and performance in safe-call-into-r.h --- r/src/safe-call-into-r.h | 61 +++++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index 95edd7468e7..11ca1ce1c2a 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -32,8 +32,13 @@ // on 32-bit R builds on R 3.6 and lower. 
static inline bool CanSafeCallIntoR() { #if defined(HAS_UNWIND_PROTECT) - cpp11::function on_old_windows_fun = cpp11::package("arrow")["on_old_windows"]; - bool on_old_windows = on_old_windows_fun(); + // Cache value to avoid calling into R more than once to check + static int on_old_windows = -1; + if (on_old_windows == -1) { + cpp11::function on_old_windows_fun = cpp11::package("arrow")["on_old_windows"]; + on_old_windows = on_old_windows_fun(); + } + return !on_old_windows; #else return false; @@ -55,7 +60,7 @@ class MainRThread { void Initialize() { thread_id_ = std::this_thread::get_id(); initialized_ = true; - SetError(R_NilValue); + ResetError(); } bool IsInitialized() { return initialized_; } @@ -63,33 +68,34 @@ class MainRThread { // Check if the current thread is the main R thread bool IsMainThread() { return initialized_ && std::this_thread::get_id() == thread_id_; } + // Check if a SafeCallIntoR call is able to execute + bool CanExecuteSafeCallIntoR() { return IsMainThread() || executor_ != nullptr; } + // The Executor that is running on the main R thread, if it exists arrow::internal::Executor*& Executor() { return executor_; } - // Save an error token generated from a cpp11::unwind_exception - // so that it can be properly handled after some cleanup code - // has run (e.g., cancelling some futures or waiting for them - // to finish). - void SetError(cpp11::sexp token) { error_token_ = token; } + // Save an error (possibly with an error token generated from + // a cpp11::unwind_exception) so that it can be properly handled + // after some cleanup code has run (e.g., cancelling some futures + // or waiting for them to finish). 
+ void SetError(arrow::Status status) { status_ = status; } - void ResetError() { error_token_ = R_NilValue; } + void ResetError() { status_ = arrow::Status::OK(); } // Check if there is a saved error - bool HasError() { return error_token_ != R_NilValue; } + bool HasError() { return !status_.ok(); } - // Throw a cpp11::unwind_exception() with the saved token if it exists + // Throw a cpp11::unwind_exception() if void ClearError() { - if (HasError()) { - cpp11::unwind_exception e(error_token_); - ResetError(); - throw e; - } + arrow::Status maybe_error_status = status_; + ResetError(); + arrow::StopIfNotOk(maybe_error_status); } private: bool initialized_; std::thread::id thread_id_; - cpp11::sexp error_token_; + arrow::Status status_; arrow::internal::Executor* executor_; }; @@ -108,21 +114,32 @@ arrow::Future SafeCallIntoRAsync(std::function(void)> fun, // the cpp11::unwind_exception be thrown since it will be caught // at the top level. return fun(); - } else if (main_r_thread.Executor() != nullptr) { + } else if (main_r_thread.CanExecuteSafeCallIntoR()) { // If we are not on the main thread and have an Executor, // use it to run the task on the main R thread. We can't throw // a cpp11::unwind_exception here, so we need to propagate it back // to RunWithCapturedR through the MainRThread singleton. - return DeferNotOk(main_r_thread.Executor()->Submit([fun]() { + return DeferNotOk(main_r_thread.Executor()->Submit([fun, reason]() { + // This occurs when some other R code that was previously scheduled to run + // has errored, in which case we skip execution and let the original + // error surface. 
if (GetMainRThread().HasError()) { - return arrow::Result(arrow::Status::UnknownError("R code execution error")); + return arrow::Result( + arrow::Status::Cancelled("Previous R code execution error (", reason, ")")); } try { return fun(); } catch (cpp11::unwind_exception& e) { - GetMainRThread().SetError(e.token); - return arrow::Result(arrow::Status::UnknownError("R code execution error")); + // Here we save the token and set the main R thread to an error state + GetMainRThread().SetError(arrow::StatusUnwindProtect(e.token)); + + // We also return an error although this should not surface because + // main_r_thread.ClearError() will get called before this value can be + // returned and will StopIfNotOk(). We don't save the error token here + // to ensure that it will only get thrown once. + return arrow::Result( + arrow::Status::UnknownError("R code execution error (", reason, ")")); } })); } else { From c1207eb84a3ec78b48fd68aec64794070c774882 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 11 Jul 2022 16:18:29 -0300 Subject: [PATCH 64/92] revert change to executor checking --- r/src/safe-call-into-r.h | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index 11ca1ce1c2a..1e62941abbe 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -68,9 +68,6 @@ class MainRThread { // Check if the current thread is the main R thread bool IsMainThread() { return initialized_ && std::this_thread::get_id() == thread_id_; } - // Check if a SafeCallIntoR call is able to execute - bool CanExecuteSafeCallIntoR() { return IsMainThread() || executor_ != nullptr; } - // The Executor that is running on the main R thread, if it exists arrow::internal::Executor*& Executor() { return executor_; } @@ -114,7 +111,7 @@ arrow::Future SafeCallIntoRAsync(std::function(void)> fun, // the cpp11::unwind_exception be thrown since it will be caught // at the top level. 
return fun(); - } else if (main_r_thread.CanExecuteSafeCallIntoR()) { + } else if (main_r_thread.Executor() != nullptr) { // If we are not on the main thread and have an Executor, // use it to run the task on the main R thread. We can't throw // a cpp11::unwind_exception here, so we need to propagate it back From 72d650d8bcbd957cc0afa78c819e2e1c1012a0ff Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 11 Jul 2022 16:48:13 -0300 Subject: [PATCH 65/92] don't automatically cast output types --- r/R/compute.R | 4 +++- r/man/register_user_defined_function.Rd | 4 +++- r/src/compute.cpp | 20 +++++++++++-------- r/tests/testthat/test-compute.R | 26 +++++++++++++++++++++++++ 4 files changed, 44 insertions(+), 10 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 7304740f0d1..fc43c3f48f2 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -317,7 +317,9 @@ cast_options <- function(safe = TRUE, ...) { #' #' @param name The function name to be used in the dplyr bindings #' @param scalar_function An object created with [arrow_scalar_function()] -#' or [arrow_advanced_scalar_function()]. +#' or [arrow_advanced_scalar_function()]. Scalar functions must be +#' stateless and return output with the same shape (i.e., the same +#' number of rows) as the input. #' @param in_type A [DataType] of the input type or a [schema()] #' for functions with more than one argument. This signature will be used #' to determine if this function is appropriate for a given set of arguments. 
diff --git a/r/man/register_user_defined_function.Rd b/r/man/register_user_defined_function.Rd index 44cc28c5ea5..ac8db0137c7 100644 --- a/r/man/register_user_defined_function.Rd +++ b/r/man/register_user_defined_function.Rd @@ -14,7 +14,9 @@ arrow_advanced_scalar_function(advanced_fun, in_type, out_type) } \arguments{ \item{scalar_function}{An object created with \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} -or \code{\link[=arrow_advanced_scalar_function]{arrow_advanced_scalar_function()}}.} +or \code{\link[=arrow_advanced_scalar_function]{arrow_advanced_scalar_function()}}. Scalar functions must be +stateless and return output with the same shape (i.e., the same +number of rows) as the input.} \item{name}{The function name to be used in the dplyr bindings} diff --git a/r/src/compute.cpp b/r/src/compute.cpp index dad0b26a5fa..e0d69cb2f9f 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -603,7 +603,9 @@ arrow::Result ResolveScalarUDFOutputType( cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp); if (!Rf_inherits(output_type_sexp, "DataType")) { - cpp11::stop("arrow_scalar_function resolver must return a DataType"); + cpp11::stop( + "Function specified as arrow_scalar_function() out_type argument must " + "return a DataType"); } return arrow::TypeHolder( @@ -648,11 +650,12 @@ arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context, if (Rf_inherits(func_result_sexp, "Array")) { auto array = cpp11::as_cpp>(func_result_sexp); - // handle an Array result of the wrong type + // Error for an Array result of the wrong type if (!result->type()->Equals(array->type())) { - arrow::Datum out = ValueOrStop(arrow::compute::Cast(array, result->type())); - std::shared_ptr out_array = out.make_array(); - array.swap(out_array); + return cpp11::stop( + "Expected return Array or Scalar with type '%s' from user-defined " + "function but got Array with type '%s'", + result->type()->ToString().c_str(), 
array->type()->ToString().c_str()); } result->value = std::move(array->data()); @@ -661,9 +664,10 @@ arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context, // handle a Scalar result of the wrong type if (!result->type()->Equals(scalar->type)) { - arrow::Datum out = ValueOrStop(arrow::compute::Cast(scalar, result->type())); - std::shared_ptr out_scalar = out.scalar(); - scalar.swap(out_scalar); + return cpp11::stop( + "Expected return Array or Scalar with type '%s' from user-defined " + "function but got Scalar with type '%s'", + result->type()->ToString().c_str(), scalar->type->ToString().c_str()); } auto array = ValueOrStop( diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 5992fc050b7..dc87cb0b445 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -110,6 +110,32 @@ test_that("register_user_defined_function() adds a compute function to the regis ) }) +test_that("arrow_scalar_function() with bad return type errors", { + skip_if_not_available("dataset") + + times_32_wrapper <- arrow_advanced_scalar_function( + function(context, args) Array$create(args[[1]], int32()), + int32(), float64() + ) + + register_user_defined_function(times_32_wrapper, "times_32_bad_return_type") + expect_error( + call_function("times_32_bad_return_type", Array$create(1L)), + "Expected return Array or Scalar with type 'double'" + ) + + times_32_wrapper <- arrow_advanced_scalar_function( + function(context, args) Scalar$create(args[[1]], int32()), + int32(), float64() + ) + + register_user_defined_function(times_32_wrapper, "times_32_bad_return_type") + expect_error( + call_function("times_32_bad_return_type", Array$create(1L)), + "Expected return Array or Scalar with type 'double'" + ) +}) + test_that("register_user_defined_function() can register multiple kernels", { skip_if_not_available("dataset") From a1f8b53a6868e61e8d34a654613475441f890ca0 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 11 
Jul 2022 16:56:23 -0300 Subject: [PATCH 66/92] document return value of advanced_fun --- r/R/compute.R | 5 ++++- r/man/register_user_defined_function.Rd | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index fc43c3f48f2..5a6a6d3b461 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -337,7 +337,10 @@ cast_options <- function(safe = TRUE, ...) { #' function will be called with exactly two arguments: `kernel_context`, #' which is a `list()` of objects giving information about the #' execution context and `args`, which is a list of [Array] or [Scalar] -#' objects corresponding to the input arguments. +#' objects corresponding to the input arguments. The function must return +#' an Array or Scalar with the type equal to +#' `kernel_context$output_type` and length equal to +#' `kernel_context$batch_length`. #' #' @return #' - `register_user_defined_function()`: `NULL`, invisibly diff --git a/r/man/register_user_defined_function.Rd b/r/man/register_user_defined_function.Rd index ac8db0137c7..4aa9c0b23d2 100644 --- a/r/man/register_user_defined_function.Rd +++ b/r/man/register_user_defined_function.Rd @@ -40,7 +40,10 @@ function it must return a \link{DataType}.} function will be called with exactly two arguments: \code{kernel_context}, which is a \code{list()} of objects giving information about the execution context and \code{args}, which is a list of \link{Array} or \link{Scalar} -objects corresponding to the input arguments.} +objects corresponding to the input arguments. 
The function must return +an Array or Scalar with the type equal to +\code{kernel_context$output_type} and length equal to +\code{kernel_context$batch_length}.} } \value{ \itemize{ From 4acaa619137b46eaa3d5fa265f94c50d75ad715d Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 11 Jul 2022 17:04:46 -0300 Subject: [PATCH 67/92] revert static variable change --- r/src/safe-call-into-r.h | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index 1e62941abbe..e13238e9619 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -32,13 +32,8 @@ // on 32-bit R builds on R 3.6 and lower. static inline bool CanSafeCallIntoR() { #if defined(HAS_UNWIND_PROTECT) - // Cache value to avoid calling into R more than once to check - static int on_old_windows = -1; - if (on_old_windows == -1) { - cpp11::function on_old_windows_fun = cpp11::package("arrow")["on_old_windows"]; - on_old_windows = on_old_windows_fun(); - } - + cpp11::function on_old_windows_fun = cpp11::package("arrow")["on_old_windows"]; + bool on_old_windows = on_old_windows_fun(); return !on_old_windows; #else return false; From 7ccb23b5fba0aff789bca17c7db647f5dba82c18 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 11 Jul 2022 21:40:14 -0300 Subject: [PATCH 68/92] keep old error behaviour --- r/src/safe-call-into-r.h | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index e13238e9619..64a70cc43d1 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -32,8 +32,12 @@ // on 32-bit R builds on R 3.6 and lower. 
static inline bool CanSafeCallIntoR() { #if defined(HAS_UNWIND_PROTECT) - cpp11::function on_old_windows_fun = cpp11::package("arrow")["on_old_windows"]; - bool on_old_windows = on_old_windows_fun(); + static int on_old_windows = -1; + if (on_old_windows == -1) { + cpp11::function on_old_windows_fun = cpp11::package("arrow")["on_old_windows"]; + on_old_windows = on_old_windows_fun(); + } + return !on_old_windows; #else return false; @@ -66,28 +70,30 @@ class MainRThread { // The Executor that is running on the main R thread, if it exists arrow::internal::Executor*& Executor() { return executor_; } - // Save an error (possibly with an error token generated from - // a cpp11::unwind_exception) so that it can be properly handled - // after some cleanup code has run (e.g., cancelling some futures - // or waiting for them to finish). - void SetError(arrow::Status status) { status_ = status; } + // Save an error token generated from a cpp11::unwind_exception + // so that it can be properly handled after some cleanup code + // has run (e.g., cancelling some futures or waiting for them + // to finish). 
+ void SetError(cpp11::sexp token) { error_token_ = token; } - void ResetError() { status_ = arrow::Status::OK(); } + void ResetError() { error_token_ = R_NilValue; } // Check if there is a saved error - bool HasError() { return !status_.ok(); } + bool HasError() { return error_token_ != R_NilValue; } - // Throw a cpp11::unwind_exception() if + // Throw a cpp11::unwind_exception() with the saved token if it exists void ClearError() { - arrow::Status maybe_error_status = status_; - ResetError(); - arrow::StopIfNotOk(maybe_error_status); + if (HasError()) { + cpp11::unwind_exception e(error_token_); + ResetError(); + throw e; + } } private: bool initialized_; std::thread::id thread_id_; - arrow::Status status_; + cpp11::sexp error_token_; arrow::internal::Executor* executor_; }; From 6ff4fb2f49c09ab1ab7768a653179e072f58cca8 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Mon, 11 Jul 2022 21:49:46 -0300 Subject: [PATCH 69/92] fix one more safe call into R change --- r/src/safe-call-into-r.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index 64a70cc43d1..bd5d7b20ed6 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -130,7 +130,7 @@ arrow::Future SafeCallIntoRAsync(std::function(void)> fun, return fun(); } catch (cpp11::unwind_exception& e) { // Here we save the token and set the main R thread to an error state - GetMainRThread().SetError(arrow::StatusUnwindProtect(e.token)); + GetMainRThread().SetError(e.token); // We also return an error although this should not surface because // main_r_thread.ClearError() will get called before this value can be From 83aa148f0555420e12f91426326980608ae06a27 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 12 Jul 2022 11:12:33 -0300 Subject: [PATCH 70/92] unify test skipping based on whether or not we can runwithcapturedr and reinstate improvements to safe-call-into-r.h --- r/R/arrowExports.R | 4 +++ r/src/arrowExports.cpp | 
8 +++++ r/src/safe-call-into-r-impl.cpp | 15 ++++++++ r/src/safe-call-into-r.h | 57 +++++++++++++------------------ r/tests/testthat/test-compute.R | 8 +++-- r/tests/testthat/test-csv.R | 4 +-- r/tests/testthat/test-extension.R | 1 + r/tests/testthat/test-feather.R | 6 +--- 8 files changed, 59 insertions(+), 44 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 4b4d90cf34a..dfe0db614ad 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -1800,6 +1800,10 @@ InitializeMainRThread <- function() { invisible(.Call(`_arrow_InitializeMainRThread`)) } +CanRunWithCapturedR <- function() { + .Call(`_arrow_CanRunWithCapturedR`) +} + TestSafeCallIntoR <- function(r_fun_that_returns_a_string, opt) { .Call(`_arrow_TestSafeCallIntoR`, r_fun_that_returns_a_string, opt) } diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index e82af9be48b..83ed4338446 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -4621,6 +4621,13 @@ BEGIN_CPP11 END_CPP11 } // safe-call-into-r-impl.cpp +bool CanRunWithCapturedR(); +extern "C" SEXP _arrow_CanRunWithCapturedR(){ +BEGIN_CPP11 + return cpp11::as_sexp(CanRunWithCapturedR()); +END_CPP11 +} +// safe-call-into-r-impl.cpp std::string TestSafeCallIntoR(cpp11::function r_fun_that_returns_a_string, std::string opt); extern "C" SEXP _arrow_TestSafeCallIntoR(SEXP r_fun_that_returns_a_string_sexp, SEXP opt_sexp){ BEGIN_CPP11 @@ -5608,6 +5615,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_ipc___RecordBatchFileWriter__Open", (DL_FUNC) &_arrow_ipc___RecordBatchFileWriter__Open, 4}, { "_arrow_ipc___RecordBatchStreamWriter__Open", (DL_FUNC) &_arrow_ipc___RecordBatchStreamWriter__Open, 4}, { "_arrow_InitializeMainRThread", (DL_FUNC) &_arrow_InitializeMainRThread, 0}, + { "_arrow_CanRunWithCapturedR", (DL_FUNC) &_arrow_CanRunWithCapturedR, 0}, { "_arrow_TestSafeCallIntoR", (DL_FUNC) &_arrow_TestSafeCallIntoR, 2}, { "_arrow_Array__GetScalar", (DL_FUNC) &_arrow_Array__GetScalar, 2}, { 
"_arrow_Scalar__ToString", (DL_FUNC) &_arrow_Scalar__ToString, 1}, diff --git a/r/src/safe-call-into-r-impl.cpp b/r/src/safe-call-into-r-impl.cpp index 7c5e75b788e..7318c81bb55 100644 --- a/r/src/safe-call-into-r-impl.cpp +++ b/r/src/safe-call-into-r-impl.cpp @@ -29,6 +29,21 @@ MainRThread& GetMainRThread() { // [[arrow::export]] void InitializeMainRThread() { GetMainRThread().Initialize(); } +// [[arrow::export]] +bool CanRunWithCapturedR() { +#if defined(HAS_UNWIND_PROTECT) + static int on_old_windows = -1; + if (on_old_windows == -1) { + cpp11::function on_old_windows_fun = cpp11::package("arrow")["on_old_windows"]; + on_old_windows = on_old_windows_fun(); + } + + return !on_old_windows; +#else + return false; +#endif +} + // [[arrow::export]] std::string TestSafeCallIntoR(cpp11::function r_fun_that_returns_a_string, std::string opt) { diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index bd5d7b20ed6..937163a05df 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -29,20 +29,10 @@ // Unwind protection was added in R 3.5 and some calls here use it // and crash R in older versions (ARROW-16201). Crashes also occur -// on 32-bit R builds on R 3.6 and lower. -static inline bool CanSafeCallIntoR() { -#if defined(HAS_UNWIND_PROTECT) - static int on_old_windows = -1; - if (on_old_windows == -1) { - cpp11::function on_old_windows_fun = cpp11::package("arrow")["on_old_windows"]; - on_old_windows = on_old_windows_fun(); - } - - return !on_old_windows; -#else - return false; -#endif -} +// on 32-bit R builds on R 3.6 and lower. Implementation provided +// in safe-call-into-r-impl.cpp so that we can skip some tests +// when this feature is not provided. 
+bool CanRunWithCapturedR(); // The MainRThread class keeps track of the thread on which it is safe // to call the R API to facilitate its safe use (or erroring @@ -67,33 +57,34 @@ class MainRThread { // Check if the current thread is the main R thread bool IsMainThread() { return initialized_ && std::this_thread::get_id() == thread_id_; } + // Check if a SafeCallIntoR call is able to execute + bool CanExecuteSafeCallIntoR() { return IsMainThread() || executor_ != nullptr; } + // The Executor that is running on the main R thread, if it exists arrow::internal::Executor*& Executor() { return executor_; } - // Save an error token generated from a cpp11::unwind_exception - // so that it can be properly handled after some cleanup code - // has run (e.g., cancelling some futures or waiting for them - // to finish). - void SetError(cpp11::sexp token) { error_token_ = token; } + // Save an error (possibly with an error token generated from + // a cpp11::unwind_exception) so that it can be properly handled + // after some cleanup code has run (e.g., cancelling some futures + // or waiting for them to finish). 
+ void SetError(arrow::Status status) { status_ = status; } - void ResetError() { error_token_ = R_NilValue; } + void ResetError() { status_ = arrow::Status::OK(); } // Check if there is a saved error - bool HasError() { return error_token_ != R_NilValue; } + bool HasError() { return !status_.ok(); } - // Throw a cpp11::unwind_exception() with the saved token if it exists + // Throw a cpp11::unwind_exception() if void ClearError() { - if (HasError()) { - cpp11::unwind_exception e(error_token_); - ResetError(); - throw e; - } + arrow::Status maybe_error_status = status_; + ResetError(); + arrow::StopIfNotOk(maybe_error_status); } private: bool initialized_; std::thread::id thread_id_; - cpp11::sexp error_token_; + arrow::Status status_; arrow::internal::Executor* executor_; }; @@ -112,7 +103,7 @@ arrow::Future SafeCallIntoRAsync(std::function(void)> fun, // the cpp11::unwind_exception be thrown since it will be caught // at the top level. return fun(); - } else if (main_r_thread.Executor() != nullptr) { + } else if (main_r_thread.CanExecuteSafeCallIntoR()) { // If we are not on the main thread and have an Executor, // use it to run the task on the main R thread. We can't throw // a cpp11::unwind_exception here, so we need to propagate it back @@ -130,7 +121,7 @@ arrow::Future SafeCallIntoRAsync(std::function(void)> fun, return fun(); } catch (cpp11::unwind_exception& e) { // Here we save the token and set the main R thread to an error state - GetMainRThread().SetError(e.token); + GetMainRThread().SetError(arrow::StatusUnwindProtect(e.token)); // We also return an error although this should not surface because // main_r_thread.ClearError() will get called before this value can be @@ -169,7 +160,7 @@ static inline arrow::Status SafeCallIntoRVoid(std::function fun, // return a Future<>. 
template arrow::Result RunWithCapturedR(std::function()> make_arrow_call) { - if (!CanSafeCallIntoR()) { + if (!CanRunWithCapturedR()) { return arrow::Status::NotImplemented( "RunWithCapturedR() without UnwindProtect or on 32-bit Windows + R <= 3.6"); } @@ -195,13 +186,13 @@ arrow::Result RunWithCapturedR(std::function()> make_arrow_c // Performs an Arrow call (e.g., run an exec plan) in such a way that background threads // can use SafeCallIntoR(). This version is useful for Arrow calls that do not already // return a Future<>(). If it is not possible to use RunWithCapturedR() (i.e., -// CanSafeCallIntoR() returns false), this will run make_arrow_call on the main +// CanRunWithCapturedR() returns false), this will run make_arrow_call on the main // R thread (which will cause background threads that try to SafeCallIntoR() to // error). template arrow::Result RunWithCapturedRIfPossible( std::function()> make_arrow_call) { - if (CanSafeCallIntoR()) { + if (CanRunWithCapturedR()) { // Note that the use of the io_context here is arbitrary (i.e. we could use // any construct that launches a background thread). 
const auto& io_context = arrow::io::default_io_context(); diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index dc87cb0b445..3ea87d4520b 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -80,7 +80,7 @@ test_that("arrow_scalar_function() returns an advanced scalar function", { }) test_that("register_user_defined_function() adds a compute function to the registry", { - skip_if_not_available("dataset") + skip_if_not(CanRunWithCapturedR()) times_32_wrapper <- arrow_scalar_function( function(x) x * 32.0, @@ -111,7 +111,7 @@ test_that("register_user_defined_function() adds a compute function to the regis }) test_that("arrow_scalar_function() with bad return type errors", { - skip_if_not_available("dataset") + skip_if_not(CanRunWithCapturedR()) times_32_wrapper <- arrow_advanced_scalar_function( function(context, args) Array$create(args[[1]], int32()), @@ -137,7 +137,7 @@ test_that("arrow_scalar_function() with bad return type errors", { }) test_that("register_user_defined_function() can register multiple kernels", { - skip_if_not_available("dataset") + skip_if_not(CanRunWithCapturedR()) times_32_wrapper <- arrow_scalar_function( function(x) x * 32L, @@ -186,6 +186,7 @@ test_that("register_user_defined_function() errors for unsupported specification }) test_that("user-defined functions work during multi-threaded execution", { + skip_if_not(CanRunWithCapturedR()) skip_if_not_available("dataset") n_rows <- 10000 @@ -234,6 +235,7 @@ test_that("user-defined functions work during multi-threaded execution", { test_that("user-defined error when called from an unsupported context", { skip_if_not_available("dataset") + skip_if_not(CanRunWithCapturedR()) times_32_wrapper <- arrow_scalar_function( function(x) x * 32.0, diff --git a/r/tests/testthat/test-csv.R b/r/tests/testthat/test-csv.R index fca717cc051..d4878e6d670 100644 --- a/r/tests/testthat/test-csv.R +++ b/r/tests/testthat/test-csv.R @@ -293,9 +293,7 @@ 
test_that("more informative error when reading a CSV with headers and schema", { }) test_that("read_csv_arrow() and write_csv_arrow() accept connection objects", { - # connections with csv need RunWithCapturedR, which is not available - # in R <= 3.4.4 - skip_on_r_older_than("3.5") + skip_if_not(CanRunWithCapturedR()) tf <- tempfile() on.exit(unlink(tf)) diff --git a/r/tests/testthat/test-extension.R b/r/tests/testthat/test-extension.R index 638869dc8c3..55a1f8d21ee 100644 --- a/r/tests/testthat/test-extension.R +++ b/r/tests/testthat/test-extension.R @@ -312,6 +312,7 @@ test_that("Table can roundtrip extension types", { test_that("Dataset/arrow_dplyr_query can roundtrip extension types", { skip_if_not_available("dataset") + skip_if_not(CanRunWithCapturedR()) tf <- tempfile() on.exit(unlink(tf, recursive = TRUE)) diff --git a/r/tests/testthat/test-feather.R b/r/tests/testthat/test-feather.R index bed097762a2..99dc8ab9c90 100644 --- a/r/tests/testthat/test-feather.R +++ b/r/tests/testthat/test-feather.R @@ -179,11 +179,7 @@ test_that("read_feather requires RandomAccessFile and errors nicely otherwise (A }) test_that("read_feather() and write_feather() accept connection objects", { - # connection object don't work on Windows i386 before R 4.0 - skip_if(on_old_windows()) - # connections with feather need RunWithCapturedR, which is not available - # in R <= 3.4.4 - skip_on_r_older_than("3.5") + skip_if_not(CanRunWithCapturedR()) tf <- tempfile() on.exit(unlink(tf)) From 8a8955fda98d35c909a20ce0b6d2c4251bfab092 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Tue, 12 Jul 2022 11:55:51 -0300 Subject: [PATCH 71/92] more skips aligned with run with captured R usage --- r/tests/testthat/test-safe-call-into-r.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r/tests/testthat/test-safe-call-into-r.R b/r/tests/testthat/test-safe-call-into-r.R index d3d1d341010..c07d90433fd 100644 --- a/r/tests/testthat/test-safe-call-into-r.R +++ 
b/r/tests/testthat/test-safe-call-into-r.R @@ -32,7 +32,7 @@ test_that("SafeCallIntoR works from the main R thread", { }) test_that("SafeCallIntoR works within RunWithCapturedR", { - skip_on_r_older_than("3.5") + skip_if_not(CanRunWithCapturedR()) skip_on_cran() expect_identical( @@ -47,7 +47,7 @@ test_that("SafeCallIntoR works within RunWithCapturedR", { }) test_that("SafeCallIntoR errors from the non-R thread", { - skip_on_r_older_than("3.5") + skip_if_not(CanRunWithCapturedR()) skip_on_cran() expect_error( From 0d3520a1be4f291aecd319a01ec514970e2e85b3 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 13 Jul 2022 22:58:29 -0300 Subject: [PATCH 72/92] nix the advanced interface --- r/NAMESPACE | 1 - r/R/compute.R | 90 ++++++++++--------------- r/man/register_user_defined_function.Rd | 51 ++++---------- r/src/arrowExports.cpp | 4 +- r/src/compute.cpp | 9 +-- r/tests/testthat/_snaps/compute.md | 8 +-- r/tests/testthat/test-compute.R | 86 ++++++++++++----------- 7 files changed, 101 insertions(+), 148 deletions(-) diff --git a/r/NAMESPACE b/r/NAMESPACE index ced43d01463..4040dc3ac2f 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -250,7 +250,6 @@ export(TimestampParser) export(Type) export(UnionDataset) export(all_of) -export(arrow_advanced_scalar_function) export(arrow_available) export(arrow_info) export(arrow_scalar_function) diff --git a/r/R/compute.R b/r/R/compute.R index 5a6a6d3b461..e3d36e2c0ac 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -328,32 +328,28 @@ cast_options <- function(safe = TRUE, ...) { #' @param out_type A [DataType] of the output type or a function accepting #' a single argument (`types`), which is a `list()` of [DataType]s. If a #' function it must return a [DataType]. -#' @param fun An R function or rlang-style lambda expression. This function -#' will be called with R objects as arguments and must return an object -#' that can be converted to an [Array] using [as_arrow_array()]. 
Function -#' authors must take care to return an array castable to the output data -#' type specified by `out_type`. -#' @param advanced_fun An R function or rlang-style lambda expression. This -#' function will be called with exactly two arguments: `kernel_context`, -#' which is a `list()` of objects giving information about the -#' execution context and `args`, which is a list of [Array] or [Scalar] -#' objects corresponding to the input arguments. The function must return -#' an Array or Scalar with the type equal to -#' `kernel_context$output_type` and length equal to -#' `kernel_context$batch_length`. +#' @param fun An R function or rlang-style lambda expression. The function +#' will be called with a first argument `context` which is a `list()` +#' with elements `batch_size` (the expected length of the output) and +#' `output_type` (the required [DataType] of the output). Subsequent +#' arguments are passed by position as specified by `in_types`. If +#' `auto_convert` is `TRUE`, subsequent arguments are converted to +#' R vectors before being passed to `fun` and the output is automatically +#' constructed with the expected output type via [as_arrow_array()]. #' #' @return #' - `register_user_defined_function()`: `NULL`, invisibly #' - `arrow_scalar_function()`: returns an object of class -#' "arrow_advanced_scalar_function" that can be passed to +#' "arrow_scalar_function" that can be passed to #' `register_user_defined_function()`. #' @export #' #' @examplesIf arrow_with_dataset() #' fun_wrapper <- arrow_scalar_function( -#' function(x, y, z) x + y + z, +#' function(context, x, y, z) x + y + z, #' schema(x = float64(), y = float64(), z = float64()), -#' float64() +#' float64(), +#' auto_convert = TRUE #' ) #' register_user_defined_function(fun_wrapper, "example_add3") #' @@ -364,27 +360,10 @@ cast_options <- function(safe = TRUE, ...) 
{ #' Array$create(3) #' ) #' -#' # use arrow_advanced_scalar_function() for a lower-level interface -#' advanced_fun_wrapper <- arrow_advanced_scalar_function( -#' function(context, args) { -#' args[[1]] + args[[2]] + args[[3]] -#' }, -#' schema(x = float64(), y = float64(), z = float64()), -#' float64() -#' ) -#' register_user_defined_function(advanced_fun_wrapper, "example_add3") -#' -#' call_function( -#' "example_add3", -#' Scalar$create(1), -#' Scalar$create(2), -#' Array$create(3) -#' ) -#' register_user_defined_function <- function(scalar_function, name) { assert_that( is.string(name), - inherits(scalar_function, "arrow_advanced_scalar_function") + inherits(scalar_function, "arrow_scalar_function") ) # register with Arrow C++ function registry (enables its use in @@ -403,23 +382,25 @@ register_user_defined_function <- function(scalar_function, name) { #' @rdname register_user_defined_function #' @export -arrow_scalar_function <- function(fun, in_type, out_type) { +arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) { fun <- as_function(fun) - # create a small wrapper that converts Scalar/Array arguments to R vectors - # and converts the result back to an Array - advanced_fun <- function(context, args) { - args <- lapply(args, as.vector) - result <- do.call(fun, args) - as_arrow_array(result, type = context$output_type) + # Create a small wrapper function that is easier to call from C++. + # This wrapper could be implemented in C/C++ to reduce evaluation + # overhead and generate prettier backtraces when errors occur + # (probably using a similar approach to purrr). 
+ if (auto_convert) { + wrapper_fun <- function(context, args) { + args <- lapply(args, as.vector) + result <- do.call(fun, c(list(context), args)) + as_arrow_array(result, type = context$output_type) + } + } else { + wrapper_fun <- function(context, args) { + do.call(fun, c(list(context), args)) + } } - arrow_advanced_scalar_function(advanced_fun, in_type, out_type) -} - -#' @rdname register_user_defined_function -#' @export -arrow_advanced_scalar_function <- function(advanced_fun, in_type, out_type) { if (is.list(in_type)) { in_type <- lapply(in_type, as_scalar_function_in_type) } else { @@ -434,16 +415,13 @@ arrow_advanced_scalar_function <- function(advanced_fun, in_type, out_type) { out_type <- rep_len(out_type, length(in_type)) - advanced_fun <- as_function(advanced_fun) - if (length(formals(advanced_fun)) != 2) { - abort("`advanced_fun` must accept exactly two arguments") - } - structure( - advanced_fun, - in_type = in_type, - out_type = out_type, - class = "arrow_advanced_scalar_function" + list( + wrapper_fun = wrapper_fun, + in_type = in_type, + out_type = out_type + ), + class = "arrow_scalar_function" ) } diff --git a/r/man/register_user_defined_function.Rd b/r/man/register_user_defined_function.Rd index 4aa9c0b23d2..3fb2c2efcd3 100644 --- a/r/man/register_user_defined_function.Rd +++ b/r/man/register_user_defined_function.Rd @@ -3,14 +3,11 @@ \name{register_user_defined_function} \alias{register_user_defined_function} \alias{arrow_scalar_function} -\alias{arrow_advanced_scalar_function} \title{Register user-defined functions} \usage{ register_user_defined_function(scalar_function, name) -arrow_scalar_function(fun, in_type, out_type) - -arrow_advanced_scalar_function(advanced_fun, in_type, out_type) +arrow_scalar_function(fun, in_type, out_type, auto_convert = FALSE) } \arguments{ \item{scalar_function}{An object created with \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} @@ -20,11 +17,14 @@ number of rows) as the input.} \item{name}{The 
function name to be used in the dplyr bindings} -\item{fun}{An R function or rlang-style lambda expression. This function -will be called with R objects as arguments and must return an object -that can be converted to an \link{Array} using \code{\link[=as_arrow_array]{as_arrow_array()}}. Function -authors must take care to return an array castable to the output data -type specified by \code{out_type}.} +\item{fun}{An R function or rlang-style lambda expression. The function +will be called with a first argument \code{context} which is a \code{list()} +with elements \code{batch_size} (the expected length of the output) and +\code{output_type} (the required \link{DataType} of the output). Subsequent +arguments are passed by position as specified by \code{in_types}. If +\code{auto_convert} is \code{TRUE}, subsequent arguments are converted to +R vectors before being passed to \code{fun} and the output is automatically +constructed with the expected output type via \code{\link[=as_arrow_array]{as_arrow_array()}}.} \item{in_type}{A \link{DataType} of the input type or a \code{\link[=schema]{schema()}} for functions with more than one argument. This signature will be used @@ -35,21 +35,12 @@ If this function is appropriate for more than one signature, pass a \item{out_type}{A \link{DataType} of the output type or a function accepting a single argument (\code{types}), which is a \code{list()} of \link{DataType}s. If a function it must return a \link{DataType}.} - -\item{advanced_fun}{An R function or rlang-style lambda expression. This -function will be called with exactly two arguments: \code{kernel_context}, -which is a \code{list()} of objects giving information about the -execution context and \code{args}, which is a list of \link{Array} or \link{Scalar} -objects corresponding to the input arguments. 
The function must return -an Array or Scalar with the type equal to -\code{kernel_context$output_type} and length equal to -\code{kernel_context$batch_length}.} } \value{ \itemize{ \item \code{register_user_defined_function()}: \code{NULL}, invisibly \item \code{arrow_scalar_function()}: returns an object of class -"arrow_advanced_scalar_function" that can be passed to +"arrow_scalar_function" that can be passed to \code{register_user_defined_function()}. } } @@ -63,29 +54,13 @@ lower-level function that operates directly on Arrow objects. \examples{ \dontshow{if (arrow_with_dataset()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} fun_wrapper <- arrow_scalar_function( - function(x, y, z) x + y + z, + function(context, x, y, z) x + y + z, schema(x = float64(), y = float64(), z = float64()), - float64() + float64(), + auto_convert = TRUE ) register_user_defined_function(fun_wrapper, "example_add3") -call_function( - "example_add3", - Scalar$create(1), - Scalar$create(2), - Array$create(3) -) - -# use arrow_advanced_scalar_function() for a lower-level interface -advanced_fun_wrapper <- arrow_advanced_scalar_function( - function(context, args) { - args[[1]] + args[[2]] + args[[3]] - }, - schema(x = float64(), y = float64(), z = float64()), - float64() -) -register_user_defined_function(advanced_fun_wrapper, "example_add3") - call_function( "example_add3", Scalar$create(1), diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index 83ed4338446..fd9f92e5d1a 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -1112,11 +1112,11 @@ BEGIN_CPP11 END_CPP11 } // compute.cpp -void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp); +void RegisterScalarUDF(std::string name, cpp11::list func_sexp); extern "C" SEXP _arrow_RegisterScalarUDF(SEXP name_sexp, SEXP func_sexp_sexp){ BEGIN_CPP11 arrow::r::Input::type name(name_sexp); - arrow::r::Input::type func_sexp(func_sexp_sexp); + arrow::r::Input::type 
func_sexp(func_sexp_sexp); RegisterScalarUDF(name, func_sexp); return R_NilValue; END_CPP11 diff --git a/r/src/compute.cpp b/r/src/compute.cpp index e0d69cb2f9f..f15117f7e48 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -681,9 +681,9 @@ arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context, } // [[arrow::export]] -void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { - cpp11::list in_type_r(func_sexp.attr("in_type")); - cpp11::list out_type_r(func_sexp.attr("out_type")); +void RegisterScalarUDF(std::string name, cpp11::list func_sexp) { + cpp11::list in_type_r(func_sexp["in_type"]); + cpp11::list out_type_r(func_sexp["out_type"]); R_xlen_t n_kernels = in_type_r.size(); if (n_kernels == 0) { @@ -732,7 +732,8 @@ void RegisterScalarUDF(std::string name, cpp11::sexp func_sexp) { arrow::compute::ScalarKernel kernel(signature, &CallRScalarUDF); kernel.mem_allocation = arrow::compute::MemAllocation::NO_PREALLOCATE; kernel.null_handling = arrow::compute::NullHandling::COMPUTED_NO_PREALLOCATE; - kernel.data = std::make_shared(func_sexp, out_type_func); + kernel.data = + std::make_shared(func_sexp["wrapper_fun"], out_type_func); StopIfNotOk(func->AddKernel(std::move(kernel))); } diff --git a/r/tests/testthat/_snaps/compute.md b/r/tests/testthat/_snaps/compute.md index 1885067a873..626608ad2e7 100644 --- a/r/tests/testthat/_snaps/compute.md +++ b/r/tests/testthat/_snaps/compute.md @@ -1,8 +1,4 @@ -# arrow_advanced_scalar_function() works +# arrow_scalar_function() works - `advanced_fun` must accept exactly two arguments - ---- - - Can't convert `advanced_fun`, NULL, to a function. + Can't convert `fun`, NULL, to a function. 
diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 3ea87d4520b..d1b541bf46a 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -21,60 +21,60 @@ test_that("list_compute_functions() works", { }) -test_that("arrow_advanced_scalar_function() works", { +test_that("arrow_scalar_function() works", { # check in/out type as schema/data type - fun <- arrow_advanced_scalar_function( - function(kernel_context, args) args[[1]], - schema(.y = int32()), int64() + fun <- arrow_scalar_function( + function(context, x) x$cast(int64()), + schema(x = int32()), int64() ) - expect_equal(attr(fun, "in_type")[[1]], schema(.y = int32())) - expect_equal(attr(fun, "out_type")[[1]](), int64()) + expect_equal(fun$in_type[[1]], schema(x = int32())) + expect_equal(fun$out_type[[1]](), int64()) # check in/out type as data type/data type - fun <- arrow_advanced_scalar_function( - function(kernel_context, args) args[[1]], + fun <- arrow_scalar_function( + function(context, x) x$cast(int64()), int32(), int64() ) - expect_equal(attr(fun, "in_type")[[1]][[1]], field("", int32())) - expect_equal(attr(fun, "out_type")[[1]](), int64()) + expect_equal(fun$in_type[[1]][[1]], field("", int32())) + expect_equal(fun$out_type[[1]](), int64()) # check in/out type as field/data type - fun <- arrow_advanced_scalar_function( - function(kernel_context, args) args[[1]], + fun <- arrow_scalar_function( + function(context, a_name) x$cast(int64()), field("a_name", int32()), int64() ) - expect_equal(attr(fun, "in_type")[[1]], schema(a_name = int32())) - expect_equal(attr(fun, "out_type")[[1]](), int64()) + expect_equal(fun$in_type[[1]], schema(a_name = int32())) + expect_equal(fun$out_type[[1]](), int64()) # check in/out type as lists - fun <- arrow_advanced_scalar_function( - function(kernel_context, args) args[[1]], + fun <- arrow_scalar_function( + function(context, x) x, list(int32(), int64()), - list(int64(), int32()) + list(int64(), int32()), + 
auto_convert = TRUE ) - expect_equal(attr(fun, "in_type")[[1]][[1]], field("", int32())) - expect_equal(attr(fun, "in_type")[[2]][[1]], field("", int64())) - expect_equal(attr(fun, "out_type")[[1]](), int64()) - expect_equal(attr(fun, "out_type")[[2]](), int32()) + expect_equal(fun$in_type[[1]][[1]], field("", int32())) + expect_equal(fun$in_type[[2]][[1]], field("", int64())) + expect_equal(fun$out_type[[1]](), int64()) + expect_equal(fun$out_type[[2]](), int32()) - expect_snapshot_error(arrow_advanced_scalar_function(identity, int32(), int32())) - expect_snapshot_error(arrow_advanced_scalar_function(NULL, int32(), int32())) + expect_snapshot_error(arrow_scalar_function(NULL, int32(), int32())) }) -test_that("arrow_scalar_function() returns an advanced scalar function", { +test_that("arrow_scalar_function() works with auto_convert = TRUE", { times_32_wrapper <- arrow_scalar_function( - function(x) x * 32, + function(context, x) x * 32, float64(), - float64() + float64(), + auto_convert = TRUE ) dummy_kernel_context <- list() - expect_s3_class(times_32_wrapper, "arrow_advanced_scalar_function") expect_equal( - times_32_wrapper(dummy_kernel_context, list(Scalar$create(2))), + times_32_wrapper$wrapper_fun(dummy_kernel_context, list(Scalar$create(2))), Array$create(2 * 32) ) }) @@ -83,13 +83,14 @@ test_that("register_user_defined_function() adds a compute function to the regis skip_if_not(CanRunWithCapturedR()) times_32_wrapper <- arrow_scalar_function( - function(x) x * 32.0, - int32(), float64() + function(context, x) x * 32.0, + int32(), float64(), + auto_convert = TRUE ) register_user_defined_function(times_32_wrapper, "times_32") - expect_true("times_32" %in% names(arrow:::.cache$functions)) + expect_true("times_32" %in% names(asNamespace("arrow")$.cache$functions)) expect_true("times_32" %in% list_compute_functions()) expect_equal( @@ -113,8 +114,8 @@ test_that("register_user_defined_function() adds a compute function to the regis 
test_that("arrow_scalar_function() with bad return type errors", { skip_if_not(CanRunWithCapturedR()) - times_32_wrapper <- arrow_advanced_scalar_function( - function(context, args) Array$create(args[[1]], int32()), + times_32_wrapper <- arrow_scalar_function( + function(context, x) Array$create(x, int32()), int32(), float64() ) @@ -124,8 +125,8 @@ test_that("arrow_scalar_function() with bad return type errors", { "Expected return Array or Scalar with type 'double'" ) - times_32_wrapper <- arrow_advanced_scalar_function( - function(context, args) Scalar$create(args[[1]], int32()), + times_32_wrapper <- arrow_scalar_function( + function(context, x) Scalar$create(x, int32()), int32(), float64() ) @@ -140,9 +141,10 @@ test_that("register_user_defined_function() can register multiple kernels", { skip_if_not(CanRunWithCapturedR()) times_32_wrapper <- arrow_scalar_function( - function(x) x * 32L, + function(context, x) x * 32L, in_type = list(int32(), int64(), float64()), - out_type = function(in_types) in_types[[1]] + out_type = function(in_types) in_types[[1]], + auto_convert = TRUE ) register_user_defined_function(times_32_wrapper, "times_32") @@ -207,8 +209,9 @@ test_that("user-defined functions work during multi-threaded execution", { write_dataset(example_df, tf_dataset, partitioning = "part") times_32_wrapper <- arrow_scalar_function( - function(x) x * 32.0, - int32(), float64() + function(context, x) x * 32.0, + int32(), float64(), + auto_convert = TRUE ) register_user_defined_function(times_32_wrapper, "times_32") @@ -238,8 +241,9 @@ test_that("user-defined error when called from an unsupported context", { skip_if_not(CanRunWithCapturedR()) times_32_wrapper <- arrow_scalar_function( - function(x) x * 32.0, - int32(), float64() + function(context, x) x * 32.0, + int32(), float64(), + auto_convert = TRUE ) register_user_defined_function(times_32_wrapper, "times_32") From 4ac9ec5eab27892271fd4bdda308f596b55dbd1b Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: 
Thu, 14 Jul 2022 07:36:15 -0300 Subject: [PATCH 73/92] document auto_convert --- r/R/compute.R | 4 ++++ r/man/register_user_defined_function.Rd | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/r/R/compute.R b/r/R/compute.R index e3d36e2c0ac..54e9b537ff5 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -336,6 +336,10 @@ cast_options <- function(safe = TRUE, ...) { #' `auto_convert` is `TRUE`, subsequent arguments are converted to #' R vectors before being passed to `fun` and the output is automatically #' constructed with the expected output type via [as_arrow_array()]. +#' @param auto_convert Use `TRUE` to convert inputs before passing to `fun` +#' and construct an Array of the correct type from the output. Use this +#' option to write functions of R objects as opposed to functions of +#' Arrow R6 objects. #' #' @return #' - `register_user_defined_function()`: `NULL`, invisibly diff --git a/r/man/register_user_defined_function.Rd b/r/man/register_user_defined_function.Rd index 3fb2c2efcd3..f91f9a620b8 100644 --- a/r/man/register_user_defined_function.Rd +++ b/r/man/register_user_defined_function.Rd @@ -35,6 +35,11 @@ If this function is appropriate for more than one signature, pass a \item{out_type}{A \link{DataType} of the output type or a function accepting a single argument (\code{types}), which is a \code{list()} of \link{DataType}s. If a function it must return a \link{DataType}.} + +\item{auto_convert}{Use \code{TRUE} to convert inputs before passing to \code{fun} +and construct an Array of the correct type from the output. 
Use this +option to write functions of R objects as opposed to functions of +Arrow R6 objects.} } \value{ \itemize{ From ba34d1ad142b80999268cdfbe514efd196aa5f50 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 14 Jul 2022 09:02:32 -0300 Subject: [PATCH 74/92] fix one more doc link --- r/R/compute.R | 5 ++--- r/man/register_user_defined_function.Rd | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 54e9b537ff5..8213da1b76e 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -317,9 +317,8 @@ cast_options <- function(safe = TRUE, ...) { #' #' @param name The function name to be used in the dplyr bindings #' @param scalar_function An object created with [arrow_scalar_function()] -#' or [arrow_advanced_scalar_function()]. Scalar functions must be -#' stateless and return output with the same shape (i.e., the same -#' number of rows) as the input. +#' Scalar functions must be stateless and return output with the same +#' shape (i.e., the same number of rows) as the input. #' @param in_type A [DataType] of the input type or a [schema()] #' for functions with more than one argument. This signature will be used #' to determine if this function is appropriate for a given set of arguments. diff --git a/r/man/register_user_defined_function.Rd b/r/man/register_user_defined_function.Rd index f91f9a620b8..c1511ba7321 100644 --- a/r/man/register_user_defined_function.Rd +++ b/r/man/register_user_defined_function.Rd @@ -11,9 +11,8 @@ arrow_scalar_function(fun, in_type, out_type, auto_convert = FALSE) } \arguments{ \item{scalar_function}{An object created with \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} -or \code{\link[=arrow_advanced_scalar_function]{arrow_advanced_scalar_function()}}. 
Scalar functions must be -stateless and return output with the same shape (i.e., the same -number of rows) as the input.} +Scalar functions must be stateless and return output with the same +shape (i.e., the same number of rows) as the input.} \item{name}{The function name to be used in the dplyr bindings} From dfbdbc2ee743f1beae85d1da46f6cf7148108b91 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 14 Jul 2022 10:00:20 -0300 Subject: [PATCH 75/92] fix link in documentation --- r/R/compute.R | 6 +++--- r/man/register_user_defined_function.Rd | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 8213da1b76e..67812d721b9 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -311,9 +311,9 @@ cast_options <- function(safe = TRUE, ...) { #' #' These functions support calling R code from query engine execution #' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]). -#' Use [arrow_scalar_function()] to define an R function that accepts and -#' returns R objects; use [arrow_advanced_scalar_function()] to define a -#' lower-level function that operates directly on Arrow objects. +#' Use [arrow_scalar_function()] attach input and output types to a vectorized +#' R function; use [register_user_defined_function()] to make it available +#' for use in the dplyr interface and/or [call_function()]. 
#' #' @param name The function name to be used in the dplyr bindings #' @param scalar_function An object created with [arrow_scalar_function()] diff --git a/r/man/register_user_defined_function.Rd b/r/man/register_user_defined_function.Rd index c1511ba7321..82b97b35756 100644 --- a/r/man/register_user_defined_function.Rd +++ b/r/man/register_user_defined_function.Rd @@ -51,9 +51,9 @@ Arrow R6 objects.} \description{ These functions support calling R code from query engine execution (i.e., a \code{\link[dplyr:mutate]{dplyr::mutate()}} or \code{\link[dplyr:filter]{dplyr::filter()}} on a \link{Table} or \link{Dataset}). -Use \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} to define an R function that accepts and -returns R objects; use \code{\link[=arrow_advanced_scalar_function]{arrow_advanced_scalar_function()}} to define a -lower-level function that operates directly on Arrow objects. +Use \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} attach input and output types to a vectorized +R function; use \code{\link[=register_user_defined_function]{register_user_defined_function()}} to make it available +for use in the dplyr interface and/or \code{\link[=call_function]{call_function()}}. 
} \examples{ \dontshow{if (arrow_with_dataset()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} From 7fd6a77e4a1ec3fe1c7243426abd1afbada4aefd Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 15 Jul 2022 14:34:27 -0300 Subject: [PATCH 76/92] back to the Python interface --- r/NAMESPACE | 3 +- r/R/compute.R | 25 ++++--- ...unction.Rd => register_scalar_function.Rd} | 28 +++----- r/tests/testthat/test-compute.R | 66 +++++++++---------- 4 files changed, 56 insertions(+), 66 deletions(-) rename r/man/{register_user_defined_function.Rd => register_scalar_function.Rd} (83%) diff --git a/r/NAMESPACE b/r/NAMESPACE index 4040dc3ac2f..60f53524c14 100644 --- a/r/NAMESPACE +++ b/r/NAMESPACE @@ -252,7 +252,6 @@ export(UnionDataset) export(all_of) export(arrow_available) export(arrow_info) -export(arrow_scalar_function) export(arrow_table) export(arrow_with_dataset) export(arrow_with_gcs) @@ -346,7 +345,7 @@ export(read_schema) export(read_tsv_arrow) export(record_batch) export(register_extension_type) -export(register_user_defined_function) +export(register_scalar_function) export(reregister_extension_type) export(s3_bucket) export(schema) diff --git a/r/R/compute.R b/r/R/compute.R index 67812d721b9..a79cc0e8df9 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -340,21 +340,17 @@ cast_options <- function(safe = TRUE, ...) { #' option to write functions of R objects as opposed to functions of #' Arrow R6 objects. #' -#' @return -#' - `register_user_defined_function()`: `NULL`, invisibly -#' - `arrow_scalar_function()`: returns an object of class -#' "arrow_scalar_function" that can be passed to -#' `register_user_defined_function()`. 
+#' @return `NULL`, invisibly #' @export #' #' @examplesIf arrow_with_dataset() -#' fun_wrapper <- arrow_scalar_function( +#' register_scalar_function( +#' "example_add3", #' function(context, x, y, z) x + y + z, #' schema(x = float64(), y = float64(), z = float64()), #' float64(), #' auto_convert = TRUE #' ) -#' register_user_defined_function(fun_wrapper, "example_add3") #' #' call_function( #' "example_add3", @@ -363,10 +359,15 @@ cast_options <- function(safe = TRUE, ...) { #' Array$create(3) #' ) #' -register_user_defined_function <- function(scalar_function, name) { - assert_that( - is.string(name), - inherits(scalar_function, "arrow_scalar_function") +register_scalar_function <- function(name, fun, in_type, out_type, + auto_convert = FALSE) { + assert_that(is.string(name)) + + scalar_function <- arrow_scalar_function( + fun, + in_type, + out_type, + auto_convert = auto_convert ) # register with Arrow C++ function registry (enables its use in @@ -383,8 +384,6 @@ register_user_defined_function <- function(scalar_function, name) { invisible(NULL) } -#' @rdname register_user_defined_function -#' @export arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) { fun <- as_function(fun) diff --git a/r/man/register_user_defined_function.Rd b/r/man/register_scalar_function.Rd similarity index 83% rename from r/man/register_user_defined_function.Rd rename to r/man/register_scalar_function.Rd index 82b97b35756..200c051e51d 100644 --- a/r/man/register_user_defined_function.Rd +++ b/r/man/register_scalar_function.Rd @@ -1,19 +1,12 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/compute.R -\name{register_user_defined_function} -\alias{register_user_defined_function} -\alias{arrow_scalar_function} +\name{register_scalar_function} +\alias{register_scalar_function} \title{Register user-defined functions} \usage{ -register_user_defined_function(scalar_function, name) - -arrow_scalar_function(fun, in_type, out_type, 
auto_convert = FALSE) +register_scalar_function(name, fun, in_type, out_type, auto_convert = FALSE) } \arguments{ -\item{scalar_function}{An object created with \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} -Scalar functions must be stateless and return output with the same -shape (i.e., the same number of rows) as the input.} - \item{name}{The function name to be used in the dplyr bindings} \item{fun}{An R function or rlang-style lambda expression. The function @@ -39,14 +32,13 @@ function it must return a \link{DataType}.} and construct an Array of the correct type from the output. Use this option to write functions of R objects as opposed to functions of Arrow R6 objects.} + +\item{scalar_function}{An object created with \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} +Scalar functions must be stateless and return output with the same +shape (i.e., the same number of rows) as the input.} } \value{ -\itemize{ -\item \code{register_user_defined_function()}: \code{NULL}, invisibly -\item \code{arrow_scalar_function()}: returns an object of class -"arrow_scalar_function" that can be passed to -\code{register_user_defined_function()}. 
-} +\code{NULL}, invisibly } \description{ These functions support calling R code from query engine execution @@ -57,13 +49,13 @@ for use in the dplyr interface and/or \code{\link[=call_function]{call_function( } \examples{ \dontshow{if (arrow_with_dataset()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} -fun_wrapper <- arrow_scalar_function( +register_scalar_function( + "example_add3", function(context, x, y, z) x + y + z, schema(x = float64(), y = float64(), z = float64()), float64(), auto_convert = TRUE ) -register_user_defined_function(fun_wrapper, "example_add3") call_function( "example_add3", diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index d1b541bf46a..3268ce3c9ca 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -79,17 +79,16 @@ test_that("arrow_scalar_function() works with auto_convert = TRUE", { ) }) -test_that("register_user_defined_function() adds a compute function to the registry", { +test_that("register_scalar_function() adds a compute function to the registry", { skip_if_not(CanRunWithCapturedR()) - times_32_wrapper <- arrow_scalar_function( + register_scalar_function( + "times_32", function(context, x) x * 32.0, int32(), float64(), auto_convert = TRUE ) - register_user_defined_function(times_32_wrapper, "times_32") - expect_true("times_32" %in% names(asNamespace("arrow")$.cache$functions)) expect_true("times_32" %in% list_compute_functions()) @@ -114,23 +113,25 @@ test_that("register_user_defined_function() adds a compute function to the regis test_that("arrow_scalar_function() with bad return type errors", { skip_if_not(CanRunWithCapturedR()) - times_32_wrapper <- arrow_scalar_function( + register_scalar_function( + "times_32_bad_return_type", function(context, x) Array$create(x, int32()), - int32(), float64() + int32(), + float64() ) - register_user_defined_function(times_32_wrapper, "times_32_bad_return_type") expect_error( 
call_function("times_32_bad_return_type", Array$create(1L)), "Expected return Array or Scalar with type 'double'" ) - times_32_wrapper <- arrow_scalar_function( + register_scalar_function( + "times_32_bad_return_type", function(context, x) Scalar$create(x, int32()), - int32(), float64() + int32(), + float64() ) - register_user_defined_function(times_32_wrapper, "times_32_bad_return_type") expect_error( call_function("times_32_bad_return_type", Array$create(1L)), "Expected return Array or Scalar with type 'double'" @@ -140,15 +141,14 @@ test_that("arrow_scalar_function() with bad return type errors", { test_that("register_user_defined_function() can register multiple kernels", { skip_if_not(CanRunWithCapturedR()) - times_32_wrapper <- arrow_scalar_function( + register_scalar_function( + "times_32", function(context, x) x * 32L, in_type = list(int32(), int64(), float64()), out_type = function(in_types) in_types[[1]], auto_convert = TRUE ) - register_user_defined_function(times_32_wrapper, "times_32") - expect_equal( call_function("times_32", Scalar$create(1L, int32())), Scalar$create(32L, int32()) @@ -166,23 +166,23 @@ test_that("register_user_defined_function() can register multiple kernels", { }) test_that("register_user_defined_function() errors for unsupported specifications", { - no_kernel_wrapper <- arrow_scalar_function( - function(...) NULL, - list(), - list() - ) expect_error( - register_user_defined_function(no_kernel_wrapper, "no_kernels"), + register_scalar_function( + "no_kernels", + function(...) NULL, + list(), + list() + ), "Can't register user-defined function with zero kernels" ) - varargs_kernel_wrapper <- arrow_scalar_function( - function(...) NULL, - list(float64(), schema(x = float64(), y = float64())), - list(float64()) - ) expect_error( - register_user_defined_function(varargs_kernel_wrapper, "var_kernels"), + register_scalar_function( + "var_kernels", + function(...) 
NULL, + list(float64(), schema(x = float64(), y = float64())), + float64() + ), "Kernels for user-defined function must accept the same number of arguments" ) }) @@ -208,14 +208,14 @@ test_that("user-defined functions work during multi-threaded execution", { on.exit(unlink(c(tf_dataset, tf_dest))) write_dataset(example_df, tf_dataset, partitioning = "part") - times_32_wrapper <- arrow_scalar_function( + register_scalar_function( + "times_32", function(context, x) x * 32.0, - int32(), float64(), + int32(), + float64(), auto_convert = TRUE ) - register_user_defined_function(times_32_wrapper, "times_32") - # check a regular collect() result <- open_dataset(tf_dataset) %>% dplyr::mutate(fun_result = times_32(value)) %>% @@ -240,14 +240,14 @@ test_that("user-defined error when called from an unsupported context", { skip_if_not_available("dataset") skip_if_not(CanRunWithCapturedR()) - times_32_wrapper <- arrow_scalar_function( + register_scalar_function( + "times_32", function(context, x) x * 32.0, - int32(), float64(), + int32(), + float64(), auto_convert = TRUE ) - register_user_defined_function(times_32_wrapper, "times_32") - stream_plan_with_udf <- function() { rbr <- record_batch(a = 1:1000) %>% dplyr::mutate(b = times_32(a)) %>% From b07e736a4b50cc8e50a1d21555715719cdf24cb6 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 15 Jul 2022 14:53:02 -0300 Subject: [PATCH 77/92] better example, fix pkgdown reference --- r/R/compute.R | 23 +++++++++++++---------- r/_pkgdown.yml | 2 +- r/man/register_scalar_function.Rd | 23 +++++++++++++---------- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index a79cc0e8df9..2f88ac4f49e 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -344,20 +344,23 @@ cast_options <- function(safe = TRUE, ...) 
{ #' @export #' #' @examplesIf arrow_with_dataset() +#' library(dplyr, warn.conflicts = FALSE) +#' +#' some_model <- lm(mpg ~ disp + cyl, data = mtcars) #' register_scalar_function( -#' "example_add3", -#' function(context, x, y, z) x + y + z, -#' schema(x = float64(), y = float64(), z = float64()), -#' float64(), +#' "mtcars_predict_mpg", +#' function(context, disp, cyl) { +#' predict(some_model, newdata = data.frame(disp, cyl)) +#' }, +#' in_type = schema(disp = float64(), cyl = float64()), +#' out_type = float64(), #' auto_convert = TRUE #' ) #' -#' call_function( -#' "example_add3", -#' Scalar$create(1), -#' Scalar$create(2), -#' Array$create(3) -#' ) +#' as_arrow_table(mtcars) %>% +#' transmute(mpg, mpg_predicted = mtcars_predict_mpg(disp, cyl)) %>% +#' collect() %>% +#' head() #' register_scalar_function <- function(name, fun, in_type, out_type, auto_convert = FALSE) { diff --git a/r/_pkgdown.yml b/r/_pkgdown.yml index 5241c93671b..b04cab8195e 100644 --- a/r/_pkgdown.yml +++ b/r/_pkgdown.yml @@ -219,7 +219,7 @@ reference: - match_arrow - value_counts - list_compute_functions - - register_user_defined_function + - register_scalar_function - title: Connections to other systems contents: - to_arrow diff --git a/r/man/register_scalar_function.Rd b/r/man/register_scalar_function.Rd index 200c051e51d..2eae361680c 100644 --- a/r/man/register_scalar_function.Rd +++ b/r/man/register_scalar_function.Rd @@ -49,19 +49,22 @@ for use in the dplyr interface and/or \code{\link[=call_function]{call_function( } \examples{ \dontshow{if (arrow_with_dataset()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} +library(dplyr, warn.conflicts = FALSE) + +some_model <- lm(mpg ~ disp + cyl, data = mtcars) register_scalar_function( - "example_add3", - function(context, x, y, z) x + y + z, - schema(x = float64(), y = float64(), z = float64()), - float64(), + "mtcars_predict_mpg", + function(context, disp, cyl) { + predict(some_model, newdata = data.frame(disp, 
cyl)) + }, + in_type = schema(disp = float64(), cyl = float64()), + out_type = float64(), auto_convert = TRUE ) -call_function( - "example_add3", - Scalar$create(1), - Scalar$create(2), - Array$create(3) -) +as_arrow_table(mtcars) \%>\% + transmute(mpg, mpg_predicted = mtcars_predict_mpg(disp, cyl)) \%>\% + collect() \%>\% + head() \dontshow{\}) # examplesIf} } From c79aca546161c8da2f005217a18886807be37e3b Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 15 Jul 2022 15:32:42 -0300 Subject: [PATCH 78/92] fix doc --- r/R/compute.R | 8 +++++--- r/man/register_scalar_function.Rd | 9 +++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 2f88ac4f49e..5534be644e9 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -313,12 +313,14 @@ cast_options <- function(safe = TRUE, ...) { #' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]). #' Use [arrow_scalar_function()] attach input and output types to a vectorized #' R function; use [register_user_defined_function()] to make it available -#' for use in the dplyr interface and/or [call_function()]. +#' for use in the dplyr interface and/or [call_function()]. Scalar functions +#' are currently the only type of user-defined function supported. +#' In Arrow, scalar functions must be stateless and return output with the +#' same shape (i.e., the same `number of rows) as the input. #' #' @param name The function name to be used in the dplyr bindings #' @param scalar_function An object created with [arrow_scalar_function()] -#' Scalar functions must be stateless and return output with the same -#' shape (i.e., the same number of rows) as the input. +#' #' @param in_type A [DataType] of the input type or a [schema()] #' for functions with more than one argument. This signature will be used #' to determine if this function is appropriate for a given set of arguments. 
diff --git a/r/man/register_scalar_function.Rd b/r/man/register_scalar_function.Rd index 2eae361680c..76f52f9cb74 100644 --- a/r/man/register_scalar_function.Rd +++ b/r/man/register_scalar_function.Rd @@ -33,9 +33,7 @@ and construct an Array of the correct type from the output. Use this option to write functions of R objects as opposed to functions of Arrow R6 objects.} -\item{scalar_function}{An object created with \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} -Scalar functions must be stateless and return output with the same -shape (i.e., the same number of rows) as the input.} +\item{scalar_function}{An object created with \code{\link[=arrow_scalar_function]{arrow_scalar_function()}}} } \value{ \code{NULL}, invisibly @@ -45,7 +43,10 @@ These functions support calling R code from query engine execution (i.e., a \code{\link[dplyr:mutate]{dplyr::mutate()}} or \code{\link[dplyr:filter]{dplyr::filter()}} on a \link{Table} or \link{Dataset}). Use \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} attach input and output types to a vectorized R function; use \code{\link[=register_user_defined_function]{register_user_defined_function()}} to make it available -for use in the dplyr interface and/or \code{\link[=call_function]{call_function()}}. +for use in the dplyr interface and/or \code{\link[=call_function]{call_function()}}. Scalar functions +are currently the only type of user-defined function supported. +In Arrow, scalar functions must be stateless and return output with the +same shape (i.e., the same `number of rows) as the input. 
} \examples{ \dontshow{if (arrow_with_dataset()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} From 8d187547928e63a2aeb08c8761fe9c8a1455cfa4 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 15 Jul 2022 16:52:59 -0300 Subject: [PATCH 79/92] maybe fix doc again --- r/R/compute.R | 2 -- r/man/register_scalar_function.Rd | 2 -- 2 files changed, 4 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 5534be644e9..10e893ecb45 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -319,8 +319,6 @@ cast_options <- function(safe = TRUE, ...) { #' same shape (i.e., the same `number of rows) as the input. #' #' @param name The function name to be used in the dplyr bindings -#' @param scalar_function An object created with [arrow_scalar_function()] -#' #' @param in_type A [DataType] of the input type or a [schema()] #' for functions with more than one argument. This signature will be used #' to determine if this function is appropriate for a given set of arguments. diff --git a/r/man/register_scalar_function.Rd b/r/man/register_scalar_function.Rd index 76f52f9cb74..5c165e570e6 100644 --- a/r/man/register_scalar_function.Rd +++ b/r/man/register_scalar_function.Rd @@ -32,8 +32,6 @@ function it must return a \link{DataType}.} and construct an Array of the correct type from the output. 
Use this option to write functions of R objects as opposed to functions of Arrow R6 objects.} - -\item{scalar_function}{An object created with \code{\link[=arrow_scalar_function]{arrow_scalar_function()}}} } \value{ \code{NULL}, invisibly From 8d33b0700c79a38eb226bcf4cae4535172ef34de Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 15 Jul 2022 16:56:03 -0300 Subject: [PATCH 80/92] more doc fixes --- r/R/compute.R | 12 ++++++------ r/man/register_scalar_function.Rd | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 10e893ecb45..c665de4e833 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -311,12 +311,12 @@ cast_options <- function(safe = TRUE, ...) { #' #' These functions support calling R code from query engine execution #' (i.e., a [dplyr::mutate()] or [dplyr::filter()] on a [Table] or [Dataset]). -#' Use [arrow_scalar_function()] attach input and output types to a vectorized -#' R function; use [register_user_defined_function()] to make it available -#' for use in the dplyr interface and/or [call_function()]. Scalar functions -#' are currently the only type of user-defined function supported. -#' In Arrow, scalar functions must be stateless and return output with the -#' same shape (i.e., the same `number of rows) as the input. +#' Use [register_scalar_function()] attach Arrow input and output types to an +#' R function and make it available for use in the dplyr interface and/or +#' [call_function()]. Scalar functions are currently the only type of +#' user-defined function supported. In Arrow, scalar functions must be +#' stateless and return output with the same shape (i.e., the same number +#' of rows) as the input. 
#' #' @param name The function name to be used in the dplyr bindings #' @param in_type A [DataType] of the input type or a [schema()] diff --git a/r/man/register_scalar_function.Rd b/r/man/register_scalar_function.Rd index 5c165e570e6..c445ef6f4b9 100644 --- a/r/man/register_scalar_function.Rd +++ b/r/man/register_scalar_function.Rd @@ -39,12 +39,12 @@ Arrow R6 objects.} \description{ These functions support calling R code from query engine execution (i.e., a \code{\link[dplyr:mutate]{dplyr::mutate()}} or \code{\link[dplyr:filter]{dplyr::filter()}} on a \link{Table} or \link{Dataset}). -Use \code{\link[=arrow_scalar_function]{arrow_scalar_function()}} attach input and output types to a vectorized -R function; use \code{\link[=register_user_defined_function]{register_user_defined_function()}} to make it available -for use in the dplyr interface and/or \code{\link[=call_function]{call_function()}}. Scalar functions -are currently the only type of user-defined function supported. -In Arrow, scalar functions must be stateless and return output with the -same shape (i.e., the same `number of rows) as the input. +Use \code{\link[=register_scalar_function]{register_scalar_function()}} attach Arrow input and output types to an +R function and make it available for use in the dplyr interface and/or +\code{\link[=call_function]{call_function()}}. Scalar functions are currently the only type of +user-defined function supported. In Arrow, scalar functions must be +stateless and return output with the same shape (i.e., the same number +of rows) as the input. 
} \examples{ \dontshow{if (arrow_with_dataset()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} From abd938a87b47b9d5337926898e028ae8a6c98a8c Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Fri, 15 Jul 2022 22:54:35 -0300 Subject: [PATCH 81/92] adapt for updated register_binding() --- r/R/dplyr-funcs.R | 36 ++++++++++++++++++++++------- r/tests/testthat/test-compute.R | 18 +++++++++++---- r/tests/testthat/test-dplyr-funcs.R | 7 +++--- 3 files changed, 46 insertions(+), 15 deletions(-) diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R index 0bc278f0241..c1dcdd17744 100644 --- a/r/R/dplyr-funcs.R +++ b/r/R/dplyr-funcs.R @@ -64,13 +64,14 @@ NULL #' registered function existed. #' @keywords internal #' -register_binding <- function(fun_name, fun, registry = nse_funcs, update_cache = FALSE) { +register_binding <- function(fun_name, fun, registry = nse_funcs, + update_cache = FALSE) { unqualified_name <- sub("^.*?:{+}", "", fun_name) previous_fun <- registry[[unqualified_name]] # if the unqualified name exists in the registry, warn - if (!is.null(fun) && !is.null(previous_fun)) { + if (!is.null(previous_fun)) { warn( paste0( "A \"", @@ -80,16 +81,35 @@ register_binding <- function(fun_name, fun, registry = nse_funcs, update_cache = } # register both as `pkg::fun` and as `fun` if `qualified_name` is prefixed - if (grepl("::", fun_name)) { - registry[[unqualified_name]] <- fun - registry[[fun_name]] <- fun - } else { - registry[[unqualified_name]] <- fun + # unqualified_name and fun_name will be the same if not prefixed + registry[[unqualified_name]] <- fun + registry[[fun_name]] <- fun + + if (update_cache) { + fun_cache <- .cache$functions + fun_cache[[unqualified_name]] <- fun + fun_cache[[fun_name]] <- fun + .cache$functions <- fun_cache } + invisible(previous_fun) +} + +unregister_binding <- function(fun_name, registry = nse_funcs, + update_cache = FALSE) { + unqualified_name <- sub("^.*?:{+}", "", fun_name) + previous_fun <- 
registry[[unqualified_name]] + + rm( + list = unique(c(fun_name, unqualified_name)), + envir = registry, + inherits = FALSE + ) + if (update_cache) { fun_cache <- .cache$functions - fun_cache[[name]] <- fun + fun_cache[[unqualified_name]] <- NULL + fun_cache[[fun_name]] <- NULL .cache$functions <- fun_cache } diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 3268ce3c9ca..2662aa33efe 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -88,6 +88,7 @@ test_that("register_scalar_function() adds a compute function to the registry", int32(), float64(), auto_convert = TRUE ) + on.exit(unregister_binding("times_32", update_cache = TRUE)) expect_true("times_32" %in% names(asNamespace("arrow")$.cache$functions)) expect_true("times_32" %in% list_compute_functions()) @@ -114,26 +115,32 @@ test_that("arrow_scalar_function() with bad return type errors", { skip_if_not(CanRunWithCapturedR()) register_scalar_function( - "times_32_bad_return_type", + "times_32_bad_return_type_array", function(context, x) Array$create(x, int32()), int32(), float64() ) + on.exit( + unregister_binding("times_32_bad_return_type_array", update_cache = TRUE) + ) expect_error( - call_function("times_32_bad_return_type", Array$create(1L)), + call_function("times_32_bad_return_type_array", Array$create(1L)), "Expected return Array or Scalar with type 'double'" ) register_scalar_function( - "times_32_bad_return_type", + "times_32_bad_return_type_scalar", function(context, x) Scalar$create(x, int32()), int32(), float64() ) + on.exit( + unregister_binding("times_32_bad_return_type_scalar", update_cache = TRUE) + ) expect_error( - call_function("times_32_bad_return_type", Array$create(1L)), + call_function("times_32_bad_return_type_scalar", Array$create(1L)), "Expected return Array or Scalar with type 'double'" ) }) @@ -148,6 +155,7 @@ test_that("register_user_defined_function() can register multiple kernels", { out_type = function(in_types) 
in_types[[1]], auto_convert = TRUE ) + on.exit(unregister_binding("times_32", update_cache = TRUE)) expect_equal( call_function("times_32", Scalar$create(1L, int32())), @@ -215,6 +223,7 @@ test_that("user-defined functions work during multi-threaded execution", { float64(), auto_convert = TRUE ) + on.exit(unregister_binding("times_32", update_cache = TRUE)) # check a regular collect() result <- open_dataset(tf_dataset) %>% @@ -247,6 +256,7 @@ test_that("user-defined error when called from an unsupported context", { float64(), auto_convert = TRUE ) + on.exit(unregister_binding("times_32", update_cache = TRUE)) stream_plan_with_udf <- function() { rbr <- record_batch(a = 1:1000) %>% diff --git a/r/tests/testthat/test-dplyr-funcs.R b/r/tests/testthat/test-dplyr-funcs.R index 2156ad9af06..86f984dd32c 100644 --- a/r/tests/testthat/test-dplyr-funcs.R +++ b/r/tests/testthat/test-dplyr-funcs.R @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -test_that("register_binding() works", { +test_that("register_binding()/unregister_binding() works", { fake_registry <- new.env(parent = emptyenv()) fun1 <- function() NULL fun2 <- function() "Hello" @@ -24,8 +24,9 @@ test_that("register_binding() works", { expect_identical(fake_registry$some_fun, fun1) expect_identical(fake_registry$`some.pkg::some_fun`, fun1) - expect_identical(register_binding("some.pkg::some_fun", NULL, fake_registry), fun1) - expect_silent(expect_null(register_binding("some.pkg::some_fun", NULL, fake_registry))) + expect_identical(unregister_binding("some.pkg::some_fun", fake_registry), fun1) + expect_false("some.pkg::some_fun" %in% names(fake_registry)) + expect_false("some_fun" %in% names(fake_registry)) expect_null(register_binding("somePkg::some_fun", fun1, fake_registry)) expect_identical(fake_registry$some_fun, fun1) From 86c1c7e94033393885393fe0c80c2bc2d7d81faa Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 20 Jul 2022 15:41:58 -0300 Subject: 
[PATCH 82/92] improve comments in compute.R --- r/R/compute.R | 17 ++++++++++++++--- r/R/query-engine.R | 6 ++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index c665de4e833..36ecb65cafe 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -391,9 +391,9 @@ arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) fun <- as_function(fun) # Create a small wrapper function that is easier to call from C++. - # This wrapper could be implemented in C/C++ to reduce evaluation - # overhead and generate prettier backtraces when errors occur - # (probably using a similar approach to purrr). + # TODO(ARROW-17148): This wrapper could be implemented in C/C++ to + # reduce evaluation overhead and generate prettier backtraces when + # errors occur (probably using a similar approach to purrr). if (auto_convert) { wrapper_fun <- function(context, args) { args <- lapply(args, as.vector) @@ -406,18 +406,22 @@ arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) } } + # in_type can be a list() if registering multiple kernels at once if (is.list(in_type)) { in_type <- lapply(in_type, as_scalar_function_in_type) } else { in_type <- list(as_scalar_function_in_type(in_type)) } + # out_type can be a list() if registering multiple kernels at once if (is.list(out_type)) { out_type <- lapply(out_type, as_scalar_function_out_type) } else { out_type <- list(as_scalar_function_out_type(out_type)) } + # recycle out_type (which is frequently length 1 even if multiple kernels + # are being registered at once) out_type <- rep_len(out_type, length(in_type)) structure( @@ -430,6 +434,10 @@ arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) ) } +# This function sanitizes the in_type argument for arrow_scalar_function(), +# which can be a data type (e.g., int32()), a field for a unary function +# or a schema() for functions accepting more than one argument. 
C++ expects +# a schema(). as_scalar_function_in_type <- function(x) { if (inherits(x, "Field")) { schema(x) @@ -440,6 +448,9 @@ as_scalar_function_in_type <- function(x) { } } +# This function sanitizes the out_type argument for arrow_scalar_function(), +# which can be a data type (e.g., int32()) or a function of the input types. +# C++ currently expects a function. as_scalar_function_out_type <- function(x) { if (is.function(x)) { x diff --git a/r/R/query-engine.R b/r/R/query-engine.R index 16dbb1c7772..d7370e1b957 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -209,6 +209,12 @@ ExecPlan <- R6Class("ExecPlan", sorting$orders <- as.integer(sorting$orders) } + # If we are going to return a Table anyway, we do this in one step and + # entirely in one C++ call to ensure that we can execute user-defined + # functions from the worker threads spawned by the ExecPlan. If not, we + # use ExecPlan_run which returns a RecordBatchReader that can be + # manipulated in R code (but that right now won't work with + # user-defined functions). exec_fun <- if (as_table) ExecPlan_read_table else ExecPlan_run out <- exec_fun( self, From 652175f554790c9995cb8be562aff234062bb47f Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 20 Jul 2022 15:49:49 -0300 Subject: [PATCH 83/92] maybe fix linter error --- r/R/query-engine.R | 3 +++ 1 file changed, 3 insertions(+) diff --git a/r/R/query-engine.R b/r/R/query-engine.R index d7370e1b957..e63fa75ebf1 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# nolint start: cyclocomp_linter, ExecPlan <- R6Class("ExecPlan", inherit = ArrowObject, public = list( @@ -275,6 +276,8 @@ ExecPlan <- R6Class("ExecPlan", Stop = function() ExecPlan_StopProducing(self) ) ) +# nolint end. 
+ ExecPlan$create <- function(use_threads = option_use_threads()) { ExecPlan_create(use_threads) } From 12b972158a1f825a6a5fa9eed23f91a06f91be49 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 20 Jul 2022 15:57:48 -0300 Subject: [PATCH 84/92] better names for in_type/out_type sanitizers --- r/R/compute.R | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 36ecb65cafe..57c54a43a98 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -408,16 +408,16 @@ arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) # in_type can be a list() if registering multiple kernels at once if (is.list(in_type)) { - in_type <- lapply(in_type, as_scalar_function_in_type) + in_type <- lapply(in_type, in_type_as_schema) } else { - in_type <- list(as_scalar_function_in_type(in_type)) + in_type <- list(in_type_as_schema(in_type)) } # out_type can be a list() if registering multiple kernels at once if (is.list(out_type)) { - out_type <- lapply(out_type, as_scalar_function_out_type) + out_type <- lapply(out_type, out_type_as_function) } else { - out_type <- list(as_scalar_function_out_type(out_type)) + out_type <- list(out_type_as_function(out_type)) } # recycle out_type (which is frequently length 1 even if multiple kernels @@ -438,7 +438,7 @@ arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) # which can be a data type (e.g., int32()), a field for a unary function # or a schema() for functions accepting more than one argument. C++ expects # a schema(). -as_scalar_function_in_type <- function(x) { +in_type_as_schema <- function(x) { if (inherits(x, "Field")) { schema(x) } else if (inherits(x, "DataType")) { @@ -451,7 +451,7 @@ as_scalar_function_in_type <- function(x) { # This function sanitizes the out_type argument for arrow_scalar_function(), # which can be a data type (e.g., int32()) or a function of the input types. # C++ currently expects a function. 
-as_scalar_function_out_type <- function(x) { +out_type_as_function <- function(x) { if (is.function(x)) { x } else { From 259eed9b0700d6fa7b6405611964316affa7733e Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 20 Jul 2022 16:23:13 -0300 Subject: [PATCH 85/92] make sure an exec plan with head() works --- r/R/query-engine.R | 4 ++-- r/tests/testthat/test-compute.R | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/r/R/query-engine.R b/r/R/query-engine.R index e63fa75ebf1..06a70164894 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -82,8 +82,8 @@ ExecPlan <- R6Class("ExecPlan", # head and tail are not ExecNodes; at best we can handle them via # SinkNode, so if there are any steps done after head/tail, we need to # evaluate the query up to then and then do a new query for the rest. - # as_record_batch_reader() will build and run an ExecPlan - node <- self$SourceNode(as_record_batch_reader(.data$.data)) + # as_arrow_table() will build and run an ExecPlan + node <- self$SourceNode(as_arrow_table(.data$.data)) } else { # Recurse node <- self$Build(.data$.data) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 2662aa33efe..1a45229b65b 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -233,6 +233,13 @@ test_that("user-defined functions work during multi-threaded execution", { expect_identical(result$fun_result, example_df$value * 32) + # check an exec plan with head() + result <- open_dataset(tf_dataset) %>% + dplyr::mutate(fun_result = times_32(value)) %>% + head(11) %>% + dplyr::collect() + expect_equal(nrow(result), 11) + # check a write_dataset() open_dataset(tf_dataset) %>% dplyr::mutate(fun_result = times_32(value)) %>% From 1f8b24832f4f92d1dc8898ec254c90879120345e Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 21 Jul 2022 10:17:50 -0300 Subject: [PATCH 86/92] clarify the `context` argument --- r/R/compute.R | 3 ++- 
r/man/register_scalar_function.Rd | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 57c54a43a98..aa0ce626bd6 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -330,7 +330,8 @@ cast_options <- function(safe = TRUE, ...) { #' @param fun An R function or rlang-style lambda expression. The function #' will be called with a first argument `context` which is a `list()` #' with elements `batch_size` (the expected length of the output) and -#' `output_type` (the required [DataType] of the output). Subsequent +#' `output_type` (the required [DataType] of the output) that may be used +#' to ensure that the output has the correct type and length. Subsequent #' arguments are passed by position as specified by `in_types`. If #' `auto_convert` is `TRUE`, subsequent arguments are converted to #' R vectors before being passed to `fun` and the output is automatically diff --git a/r/man/register_scalar_function.Rd b/r/man/register_scalar_function.Rd index c445ef6f4b9..4da8f54f645 100644 --- a/r/man/register_scalar_function.Rd +++ b/r/man/register_scalar_function.Rd @@ -12,7 +12,8 @@ register_scalar_function(name, fun, in_type, out_type, auto_convert = FALSE) \item{fun}{An R function or rlang-style lambda expression. The function will be called with a first argument \code{context} which is a \code{list()} with elements \code{batch_size} (the expected length of the output) and -\code{output_type} (the required \link{DataType} of the output). Subsequent +\code{output_type} (the required \link{DataType} of the output) that may be used +to ensure that the output has the correct type and length. Subsequent arguments are passed by position as specified by \code{in_types}. 
If \code{auto_convert} is \code{TRUE}, subsequent arguments are converted to R vectors before being passed to \code{fun} and the output is automatically From 510221f733d1e68644f87cfab0569041e0401029 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 21 Jul 2022 10:46:38 -0300 Subject: [PATCH 87/92] don't allow rlang-style lambdas quite yet --- r/R/compute.R | 2 +- r/tests/testthat/_snaps/compute.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index aa0ce626bd6..f8d05134d49 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -389,7 +389,7 @@ register_scalar_function <- function(name, fun, in_type, out_type, } arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) { - fun <- as_function(fun) + assert_that(is.function(fun)) # Create a small wrapper function that is easier to call from C++. # TODO(ARROW-17148): This wrapper could be implemented in C/C++ to diff --git a/r/tests/testthat/_snaps/compute.md b/r/tests/testthat/_snaps/compute.md index 626608ad2e7..89506a7fbc2 100644 --- a/r/tests/testthat/_snaps/compute.md +++ b/r/tests/testthat/_snaps/compute.md @@ -1,4 +1,4 @@ # arrow_scalar_function() works - Can't convert `fun`, NULL, to a function. 
+ fun is not a function From e906632827379dd31baddc51abb90568f2010b23 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 21 Jul 2022 11:28:05 -0300 Subject: [PATCH 88/92] check formals of fun --- r/R/compute.R | 20 ++++++++++++++++++++ r/tests/testthat/test-compute.R | 12 +++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/r/R/compute.R b/r/R/compute.R index f8d05134d49..8812df40b8a 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -425,6 +425,26 @@ arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) # are being registered at once) out_type <- rep_len(out_type, length(in_type)) + # check n_kernels and number of args in fun + n_kernels <- length(in_type) + if (n_kernels == 0) { + abort("Can't register user-defined scalar function with 0 kernels") + } + + expected_n_args <- in_type[[1]]$num_fields + 1L + fun_formals_have_dots <- any(names(formals(fun)) == "...") + if (!fun_formals_have_dots && length(formals(fun)) != expected_n_args) { + abort( + glue::glue( + paste0( + "Expected `fun` to accept {expected_n_args} argument(s)\n", + "but found a function that acccepts {length(formals(fun))} argument(s)\n", + "Did you forget to include `context` as the first argument?" 
+ ) + ) + ) + } + structure( list( wrapper_fun = wrapper_fun, diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 1a45229b65b..fe016e18cce 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -181,7 +181,17 @@ test_that("register_user_defined_function() errors for unsupported specification list(), list() ), - "Can't register user-defined function with zero kernels" + "Can't register user-defined scalar function with 0 kernels" + ) + + expect_error( + register_scalar_function( + "wrong_n_args", + function(x) NULL, + int32(), + int32() + ), + "Expected `fun` to accept 2 argument\\(s\\)" ) expect_error( From 1805836eed93b17060d858a3067eb208dcf57b62 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 21 Jul 2022 16:20:39 -0300 Subject: [PATCH 89/92] don't use glue::glue --- r/R/compute.R | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/r/R/compute.R b/r/R/compute.R index 8812df40b8a..0985e73a5f2 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -435,12 +435,14 @@ arrow_scalar_function <- function(fun, in_type, out_type, auto_convert = FALSE) fun_formals_have_dots <- any(names(formals(fun)) == "...") if (!fun_formals_have_dots && length(formals(fun)) != expected_n_args) { abort( - glue::glue( + sprintf( paste0( - "Expected `fun` to accept {expected_n_args} argument(s)\n", - "but found a function that acccepts {length(formals(fun))} argument(s)\n", + "Expected `fun` to accept %d argument(s)\n", + "but found a function that acccepts %d argument(s)\n", "Did you forget to include `context` as the first argument?" 
- ) + ), + expected_n_args, + length(formals(fun)) ) ) } From aa9165f293ccd8ab15234ff1df6f6e012b99b04e Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 21 Jul 2022 21:55:27 -0300 Subject: [PATCH 90/92] Update r/tests/testthat/test-compute.R Co-authored-by: Neal Richardson --- r/tests/testthat/test-compute.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index fe016e18cce..dc94c7e8c95 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -276,10 +276,10 @@ test_that("user-defined error when called from an unsupported context", { on.exit(unregister_binding("times_32", update_cache = TRUE)) stream_plan_with_udf <- function() { - rbr <- record_batch(a = 1:1000) %>% + record_batch(a = 1:1000) %>% dplyr::mutate(b = times_32(a)) %>% - as_record_batch_reader() - rbr$read_table() + as_record_batch_reader() %>% + as_arrow_table() } if (identical(tolower(Sys.info()[["sysname"]]), "windows")) { From 7952710f2da628fb2e7ad38dd3c2e60b790cb677 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 21 Jul 2022 22:02:48 -0300 Subject: [PATCH 91/92] revert very bad eager evaluation! --- r/R/query-engine.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/r/R/query-engine.R b/r/R/query-engine.R index 06a70164894..e63fa75ebf1 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -82,8 +82,8 @@ ExecPlan <- R6Class("ExecPlan", # head and tail are not ExecNodes; at best we can handle them via # SinkNode, so if there are any steps done after head/tail, we need to # evaluate the query up to then and then do a new query for the rest. 
- # as_arrow_table() will build and run an ExecPlan - node <- self$SourceNode(as_arrow_table(.data$.data)) + # as_record_batch_reader() will build and run an ExecPlan + node <- self$SourceNode(as_record_batch_reader(.data$.data)) } else { # Recurse node <- self$Build(.data$.data) From e31f2b1e5932f52bf6c42eea4f8ae2ce467b1c05 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Thu, 21 Jul 2022 22:14:17 -0300 Subject: [PATCH 92/92] fix test for plan with head() --- r/tests/testthat/test-compute.R | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index dc94c7e8c95..946583ae004 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -243,13 +243,6 @@ test_that("user-defined functions work during multi-threaded execution", { expect_identical(result$fun_result, example_df$value * 32) - # check an exec plan with head() - result <- open_dataset(tf_dataset) %>% - dplyr::mutate(fun_result = times_32(value)) %>% - head(11) %>% - dplyr::collect() - expect_equal(nrow(result), 11) - # check a write_dataset() open_dataset(tf_dataset) %>% dplyr::mutate(fun_result = times_32(value)) %>% @@ -282,6 +275,13 @@ test_that("user-defined error when called from an unsupported context", { as_arrow_table() } + collect_plan_with_head <- function() { + record_batch(a = 1:1000) %>% + dplyr::mutate(fun_result = times_32(a)) %>% + head(11) %>% + dplyr::collect() + } + if (identical(tolower(Sys.info()[["sysname"]]), "windows")) { expect_equal( stream_plan_with_udf(), @@ -289,10 +289,17 @@ test_that("user-defined error when called from an unsupported context", { dplyr::mutate(b = times_32(a)) %>% dplyr::collect(as_data_frame = FALSE) ) + + result <- collect_plan_with_head() + expect_equal(nrow(result), 11) } else { expect_error( stream_plan_with_udf(), "Call to R \\(.*?\\) from a non-R thread from an unsupported context" ) + expect_error( + 
collect_plan_with_head(), + "Call to R \\(.*?\\) from a non-R thread from an unsupported context" + ) } })