Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion r/R/compute.R
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,17 @@ register_scalar_function <- function(name, fun, in_type, out_type,
RegisterScalarUDF(name, scalar_function)

# register with dplyr binding (enables its use in mutate(), filter(), etc.)
binding_fun <- function(...) build_expr(name, ...)

# inject the value of `name` into the expression to avoid saving this
# execution environment in the binding, which eliminates a warning when the
# same binding is registered twice
body(binding_fun) <- expr_substitute(body(binding_fun), sym("name"), name)
environment(binding_fun) <- asNamespace("arrow")

register_binding(
name,
function(...) build_expr(name, ...),
binding_fun,
update_cache = TRUE
)

Expand Down
2 changes: 1 addition & 1 deletion r/R/dplyr-funcs.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ register_binding <- function(fun_name,
previous_fun <- registry[[unqualified_name]]

# if the unqualified name exists in the registry, warn
if (!is.null(previous_fun)) {
if (!is.null(previous_fun) && !identical(fun, previous_fun)) {
warn(
paste0(
"A \"",
Expand Down
10 changes: 6 additions & 4 deletions r/src/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,8 +611,8 @@ class RScalarUDFKernelState : public arrow::compute::KernelState {
RScalarUDFKernelState(cpp11::sexp exec_func, cpp11::sexp resolver)
: exec_func_(exec_func), resolver_(resolver) {}

cpp11::function exec_func_;
cpp11::function resolver_;
cpp11::sexp exec_func_;
cpp11::sexp resolver_;
};

arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
Expand All @@ -630,7 +630,8 @@ arrow::Result<arrow::TypeHolder> ResolveScalarUDFOutputType(
cpp11::to_r6<arrow::DataType>(input_types[i].GetSharedPtr());
}

cpp11::sexp output_type_sexp = state->resolver_(input_types_sexp);
cpp11::sexp output_type_sexp =
cpp11::function(state->resolver_)(input_types_sexp);
if (!Rf_inherits(output_type_sexp, "DataType")) {
cpp11::stop(
"Function specified as arrow_scalar_function() out_type argument must "
Expand Down Expand Up @@ -674,7 +675,8 @@ arrow::Status CallRScalarUDF(arrow::compute::KernelContext* context,
cpp11::writable::list udf_context = {batch_length_sexp, output_type_sexp};
udf_context.names() = {"batch_length", "output_type"};

cpp11::sexp func_result_sexp = state->exec_func_(udf_context, args_sexp);
cpp11::sexp func_result_sexp =
cpp11::function(state->exec_func_)(udf_context, args_sexp);

if (Rf_inherits(func_result_sexp, "Array")) {
auto array = cpp11::as_cpp<std::shared_ptr<arrow::Array>>(func_result_sexp);
Expand Down
4 changes: 2 additions & 2 deletions r/src/recordbatchreader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class RFunctionRecordBatchReader : public arrow::RecordBatchReader {

arrow::Status ReadNext(std::shared_ptr<arrow::RecordBatch>* batch_out) {
auto batch = SafeCallIntoR<std::shared_ptr<arrow::RecordBatch>>([&]() {
cpp11::sexp result_sexp = fun_();
cpp11::sexp result_sexp = cpp11::function(fun_)();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it potentially defer any type check that would have occurred in the constructor before?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point... I looked that up, and in its current form, cpp11::function doesn't do any type checking (https://github.com/r-lib/cpp11/blob/main/inst/include/cpp11/function.hpp#L17).

if (result_sexp == R_NilValue) {
return std::shared_ptr<arrow::RecordBatch>(nullptr);
} else if (!Rf_inherits(result_sexp, "RecordBatch")) {
Expand All @@ -94,7 +94,7 @@ class RFunctionRecordBatchReader : public arrow::RecordBatchReader {
}

private:
cpp11::function fun_;
cpp11::sexp fun_;
std::shared_ptr<arrow::Schema> schema_;
};

Expand Down
3 changes: 3 additions & 0 deletions r/tests/testthat/test-dplyr-funcs.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ test_that("register_binding()/unregister_binding() works", {
register_binding("some.pkg2::some_fun", fun2, fake_registry),
"A \"some_fun\" binding already exists in the registry and will be overwritten."
)

# No warning when an identical function is re-registered
expect_silent(register_binding("some.pkg2::some_fun", fun2, fake_registry))
})

test_that("register_binding_agg() works", {
Expand Down